diff --git a/src/network/graph/from.rs b/src/network/graph/from.rs index 47f1364..4802590 100644 --- a/src/network/graph/from.rs +++ b/src/network/graph/from.rs @@ -1,4 +1,7 @@ -use crate::network::residual::{NeuraAxisDefault, NeuraSplitInputs}; +use crate::network::{ + residual::{NeuraAxisDefault, NeuraSplitInputs}, + sequential::NeuraSequentialLast, +}; use super::*; @@ -10,9 +13,9 @@ pub trait FromSequential { ) -> Self; } -impl FromSequential<(), Data> for NeuraGraph { +impl FromSequential for NeuraGraph { fn from_sequential_rec( - _seq: &(), + _seq: &NeuraSequentialLast, nodes: Vec>, input_shape: NeuraShape, ) -> Self { diff --git a/src/network/residual/node.rs b/src/network/residual/node.rs index 6b04e9c..cfc99aa 100644 --- a/src/network/residual/node.rs +++ b/src/network/residual/node.rs @@ -115,7 +115,6 @@ impl< { #[inline(always)] fn output_shape(&self) -> NeuraShape { - todo!("output_shape for NeuraResidualNode is not yet ready"); self.child_network.output_shape() } diff --git a/src/network/sequential/construct.rs b/src/network/sequential/construct.rs index f8a50f5..4f8153a 100644 --- a/src/network/sequential/construct.rs +++ b/src/network/sequential/construct.rs @@ -2,17 +2,17 @@ use crate::err::NeuraRecursiveErr; use super::*; -impl NeuraPartialLayer for NeuraSequential { - type Constructed = NeuraSequential; - type Err = Layer::Err; +// impl NeuraPartialLayer for NeuraSequential { +// type Constructed = NeuraSequential; +// type Err = Layer::Err; - fn construct(self, input_shape: NeuraShape) -> Result { - Ok(NeuraSequential { - layer: self.layer.construct(input_shape)?, - child_network: Box::new(()), - }) - } -} +// fn construct(self, input_shape: NeuraShape) -> Result { +// Ok(NeuraSequential { +// layer: self.layer.construct(input_shape)?, +// child_network: Box::new(()), +// }) +// } +// } impl NeuraPartialLayer for NeuraSequential diff --git a/src/network/sequential/layer_impl.rs b/src/network/sequential/layer_impl.rs index 1d5aa20..283b55c 100644 --- a/src/network/sequential/layer_impl.rs +++ b/src/network/sequential/layer_impl.rs @@ -1,19 +1,11 @@ use super::*; use crate::layer::{NeuraLayer, NeuraLayerBase}; -// impl NeuraLayerBase for NeuraSequential { -// #[inline(always)] -// fn output_shape(&self) -> NeuraShape { -// self.layer.output_shape() -// } -// } - impl NeuraLayerBase for NeuraSequential { #[inline(always)] fn output_shape(&self) -> NeuraShape { - todo!("Have output_shape return Option"); self.child_network.output_shape() } diff --git a/src/network/sequential/lock.rs b/src/network/sequential/lock.rs index 45e2591..211b970 100644 --- a/src/network/sequential/lock.rs +++ b/src/network/sequential/lock.rs @@ -8,11 +8,11 @@ pub trait NeuraSequentialLock { fn lock(self) -> Self::Locked; } -impl NeuraSequentialLock for () { - type Locked = (); +impl NeuraSequentialLock for NeuraSequentialLast { + type Locked = NeuraSequentialLast; fn lock(self) -> Self::Locked { - () + self } } diff --git a/src/network/sequential/mod.rs b/src/network/sequential/mod.rs index a4dca4e..bd01449 100644 --- a/src/network/sequential/mod.rs +++ b/src/network/sequential/mod.rs @@ -42,7 +42,7 @@ pub use tail::*; /// instance. /// /// The operations on the tail end are more complex, and require recursively traversing the `NeuraSequential` structure, -/// until an instance of `NeuraSequential` is found. +/// until an instance of `NeuraSequential` is found. /// If your network feeds into a type that does not implement `NeuraSequentialTail`, then you will not be able to use those operations. #[derive(Clone, Debug)] pub struct NeuraSequential { @@ -76,11 +76,11 @@ impl NeuraSequential { } } -impl From for NeuraSequential { +impl From for NeuraSequential { fn from(layer: Layer) -> Self { Self { layer, - child_network: Box::new(()), + child_network: Box::new(NeuraSequentialLast::default()), } } } @@ -154,7 +154,7 @@ where #[macro_export] macro_rules! neura_sequential { [] => { - () + $crate::network::sequential::NeuraSequentialLast::default() }; [ .. $network:expr $(,)? ] => { diff --git a/src/network/sequential/tail.rs b/src/network/sequential/tail.rs index dd6e9e2..efc4c6d 100644 --- a/src/network/sequential/tail.rs +++ b/src/network/sequential/tail.rs @@ -1,5 +1,127 @@ use super::*; +/// Last element of a NeuraSequential network +#[derive(Clone, Debug, PartialEq, Copy)] +pub struct NeuraSequentialLast { + shape: Option, +} + +impl NeuraPartialLayer for NeuraSequentialLast { + type Constructed = NeuraSequentialLast; + + type Err = (); + + fn construct(mut self, input_shape: NeuraShape) -> Result { + self.shape = Some(input_shape); + Ok(self) + } +} + +impl NeuraLayerBase for NeuraSequentialLast { + type Gradient = (); + + #[inline(always)] + fn output_shape(&self) -> NeuraShape { + self.shape + .expect("Called NeuraSequentialLast::output_shape() without building it") + } + + #[inline(always)] + fn default_gradient(&self) -> Self::Gradient { + () + } +} + +impl NeuraLayer for NeuraSequentialLast { + type Output = Input; + type IntermediaryRepr = (); + + #[inline(always)] + fn eval_training(&self, input: &Input) -> (Self::Output, Self::IntermediaryRepr) { + (input.clone(), ()) + } + + #[inline(always)] + fn backprop_layer( + &self, + _input: &Input, + _intermediary: &Self::IntermediaryRepr, + epsilon: &Self::Output, + ) -> Input { + epsilon.clone() + } +} + +impl NeuraNetworkBase for NeuraSequentialLast { + type Layer = (); + + #[inline(always)] + fn get_layer(&self) -> &Self::Layer { + &() + } +} + +impl NeuraNetworkRec for NeuraSequentialLast { + type NextNode = (); + + #[inline(always)] + fn get_next(&self) -> &Self::NextNode { + &() + } + + #[inline(always)] + fn merge_gradient( + &self, + rec_gradient: ::Gradient, + _layer_gradient: ::Gradient, + ) -> Self::Gradient + where + Self::Layer: NeuraLayerBase, + { + rec_gradient + } +} + +impl NeuraNetwork for NeuraSequentialLast { + type LayerInput = Input; + type NodeOutput = Input; + + fn map_input<'a>(&'_ self, input: &'a Input) -> Cow<'a, Self::LayerInput> { + Cow::Borrowed(input) + } + + fn map_output<'a>( + &'_ self, + _input: &'_ Input, + layer_output: &'a Input, + ) -> Cow<'a, Self::NodeOutput> { + Cow::Borrowed(layer_output) + } + + fn map_gradient_in<'a>( + &'_ self, + _input: &'_ Input, + gradient_in: &'a Self::NodeOutput, + ) -> Cow<'a, Input> { + Cow::Borrowed(gradient_in) + } + + fn map_gradient_out<'a>( + &'_ self, + _input: &'_ Input, + _gradient_in: &'_ Self::NodeOutput, + gradient_out: &'a Self::LayerInput, + ) -> Cow<'a, Input> { + Cow::Borrowed(gradient_out) + } +} + +impl Default for NeuraSequentialLast { + fn default() -> Self { + Self { shape: None } + } +} + /// Operations on the tail end of a sequential network pub trait NeuraSequentialTail { type TailTrimmed; @@ -10,13 +132,13 @@ pub trait NeuraSequentialTail { } // Trimming the last layer returns an empty network -impl NeuraSequentialTail for NeuraSequential { - type TailTrimmed = (); +impl NeuraSequentialTail for NeuraSequential { + type TailTrimmed = NeuraSequentialLast; // GAT :3 - type TailPushed = NeuraSequential>; + type TailPushed = NeuraSequential>; fn trim_tail(self) -> Self::TailTrimmed { - () + NeuraSequentialLast::default() } fn push_tail(self, layer: T) -> Self::TailPushed { @@ -24,7 +146,7 @@ impl NeuraSequentialTail for NeuraSequential { layer: self.layer, child_network: Box::new(NeuraSequential { layer, - child_network: Box::new(()), + child_network: Box::new(NeuraSequentialLast::default()), }), } }