diff --git a/src/gradient_solver/backprop.rs b/src/gradient_solver/backprop.rs
index 7c9ede8..921e2b9 100644
--- a/src/gradient_solver/backprop.rs
+++ b/src/gradient_solver/backprop.rs
@@ -17,20 +17,23 @@ impl NeuraBackprop {
 impl<
         Input,
         Target,
-        Trainable: NeuraOldTrainableNetworkBase<Input>,
+        Trainable: NeuraTrainableLayerBase + NeuraLayer<Input> + NeuraNetworkRec,
         Loss: NeuraLoss<Trainable::Output, Target = Target> + Clone,
     > NeuraGradientSolver<Input, Target, Trainable> for NeuraBackprop<Loss>
 where
     <Loss as NeuraLoss<Trainable::Output>>::Output: ToPrimitive,
-    Trainable: for<'a> NeuraOldTrainableNetwork<Input, (&'a NeuraBackprop<Loss>, &'a Target)>,
+    // Trainable: NeuraOldTrainableNetworkBase<Input, Gradient = <Trainable as NeuraTrainableLayerBase>::Gradient>,
+    // Trainable: for<'a> NeuraOldTrainableNetwork<Input, (&'a NeuraBackprop<Loss>, &'a Target)>,
+    for<'a> (&'a NeuraBackprop<Loss>, &'a Target): BackpropRecurse<Input, Trainable, <Trainable as NeuraTrainableLayerBase>::Gradient>
 {
     fn get_gradient(
         &self,
         trainable: &Trainable,
         input: &Input,
         target: &Target,
-    ) -> Trainable::Gradient {
-        let (_, gradient) = trainable.traverse(input, &(self, target));
+    ) -> <Trainable as NeuraTrainableLayerBase>::Gradient {
+        let (_, gradient) = (self, target).recurse(trainable, input);
+        // let (_, gradient) = trainable.traverse(input, &(self, target));
 
         gradient
     }
@@ -119,14 +122,24 @@ where
     Network::NextNode: NeuraTrainableLayerEval,
 {
     fn recurse(&self, network: &Network, input: &Input) -> (Input, Network::Gradient) {
+        let layer = network.get_layer();
+        // Get layer output
         let layer_input = network.map_input(input);
-        let (layer_output, layer_intermediary) =
-            network.get_layer().eval_training(layer_input.as_ref());
+        let (layer_output, layer_intermediary) = layer.eval_training(layer_input.as_ref());
         let output = network.map_output(input, &layer_output);
 
+        // Recurse
         let (epsilon_in, gradient_rec) = self.recurse(network.get_next(), output.as_ref());
 
-        todo!()
+        // Get layer outgoing gradient vector
+        let layer_epsilon_in = network.map_gradient_in(input, &epsilon_in);
+        let layer_epsilon_out = layer.backprop_layer(&layer_input, &layer_intermediary, &layer_epsilon_in);
+        let epsilon_out = network.map_gradient_out(input, &epsilon_in, &layer_epsilon_out);
+
+        // Get layer parameter gradient
+        let gradient = layer.get_gradient(&layer_input, &layer_intermediary, &layer_epsilon_in);
+
+        (epsilon_out.into_owned(), network.merge_gradient(gradient_rec, gradient))
     }
 }
 
diff --git a/src/gradient_solver/forward_forward.rs b/src/gradient_solver/forward_forward.rs
index 1ada58e..777ac26 100644
--- a/src/gradient_solver/forward_forward.rs
+++ b/src/gradient_solver/forward_forward.rs
@@ -30,17 +30,18 @@ impl<
         F,
         Act: Clone + NeuraDerivable<F>,
         Input,
-        Trainable: NeuraOldTrainableNetwork<Input, NeuraForwardPair<Act>, Output = DVector<F>>,
+        Trainable: NeuraTrainableLayerBase,
     > NeuraGradientSolver<Input, bool, Trainable> for NeuraForwardForward<Act>
 where
     F: ToPrimitive,
+    Trainable: NeuraOldTrainableNetwork<Input, NeuraForwardPair<Act>, Output = DVector<F>, Gradient = <Trainable as NeuraTrainableLayerBase>::Gradient>
 {
     fn get_gradient(
         &self,
         trainable: &Trainable,
         input: &Input,
         target: &bool,
-    ) -> Trainable::Gradient {
+    ) -> <Trainable as NeuraTrainableLayerBase>::Gradient {
         let target = *target;
 
         trainable.traverse(
diff --git a/src/gradient_solver/mod.rs b/src/gradient_solver/mod.rs
index e6291a2..126bd2f 100644
--- a/src/gradient_solver/mod.rs
+++ b/src/gradient_solver/mod.rs
@@ -37,7 +37,7 @@ pub trait NeuraGradientSolverTransient
     ) -> Self::Output;
 }
 
-pub trait NeuraGradientSolver<Input, Target, Trainable: NeuraOldTrainableNetworkBase<Input>> {
+pub trait NeuraGradientSolver<Input, Target, Trainable: NeuraTrainableLayerBase> {
     fn get_gradient(
         &self,
         trainable: &Trainable,
diff --git a/src/network/sequential/mod.rs b/src/network/sequential/mod.rs
index 0301f65..8dd5013 100644
--- a/src/network/sequential/mod.rs
+++ b/src/network/sequential/mod.rs
@@ -1,4 +1,4 @@
-use super::{NeuraOldTrainableNetwork, NeuraOldTrainableNetworkBase};
+use super::*;
 use crate::{
     gradient_solver::NeuraGradientSolverTransient,
     layer::{
@@ -177,6 +177,30 @@ impl From for NeuraSequential {
     }
 }
 
+impl<Layer, ChildNetwork> NeuraNetworkBase for NeuraSequential<Layer, ChildNetwork> {
+    type Layer = Layer;
+
+    fn get_layer(&self) -> &Self::Layer {
+        &self.layer
+    }
+}
+
+impl<Layer: NeuraTrainableLayerBase, ChildNetwork: NeuraTrainableLayerBase> NeuraNetworkRec for NeuraSequential<Layer, ChildNetwork> {
+    type NextNode = ChildNetwork;
+
+    fn get_next(&self) -> &Self::NextNode {
+        &self.child_network
+    }
+
+    fn merge_gradient(
+        &self,
+        rec_gradient: <Self::NextNode as NeuraTrainableLayerBase>::Gradient,
+        layer_gradient: <Self::Layer as NeuraTrainableLayerBase>::Gradient
+    ) -> Self::Gradient {
+        (rec_gradient, Box::new(layer_gradient))
+    }
+}
+
 /// An utility to recursively create a NeuraSequential network, while writing it in a declarative and linear fashion.
 /// Note that this can quickly create big and unwieldly types.
 #[macro_export]
diff --git a/src/network/traits.rs b/src/network/traits.rs
index 7dd2847..a16af93 100644
--- a/src/network/traits.rs
+++ b/src/network/traits.rs
@@ -1,5 +1,7 @@
 use std::borrow::Cow;
 
+use crate::prelude::NeuraTrainableLayerBase;
+
 use super::*;
 
 /// This trait has to be non-generic, to ensure that no downstream crate can implement it for foreign types,
@@ -46,10 +48,17 @@ where
     ) -> Cow<'a, NodeInput>;
 }
 
-pub trait NeuraNetworkRec: NeuraNetworkBase {
+pub trait NeuraNetworkRec: NeuraNetworkBase + NeuraTrainableLayerBase {
     /// The type of the children network, it does not need to implement `NeuraNetworkBase`,
     /// although many functions will expect it to be either `()` or an implementation of `NeuraNetworkRec`.
-    type NextNode;
+    type NextNode: NeuraTrainableLayerBase;
 
     fn get_next(&self) -> &Self::NextNode;
+
+    fn merge_gradient(
+        &self,
+        rec_gradient: <Self::NextNode as NeuraTrainableLayerBase>::Gradient,
+        layer_gradient: <Self::Layer as NeuraTrainableLayerBase>::Gradient
+    ) -> Self::Gradient
+    where Self::Layer: NeuraTrainableLayerBase;
 }
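For context, the `recurse` implementation above follows the usual backpropagation recursion: evaluate the current layer, recurse into the rest of the network to obtain the incoming error signal, then derive both the outgoing error signal and this layer's parameter gradient, merging the latter with the recursive gradient (which is what `merge_gradient` is for). The snippet below is a minimal, self-contained sketch of that pattern; the `Recurse` trait and the `Scale`/`Tail` types are made up for illustration and are not part of the crate's API.

```rust
// Minimal, illustrative sketch of the recursive backprop pattern used above.
// `Recurse`, `Scale` and `Tail` are made-up names, not part of the crate.

trait Recurse {
    type Gradient;

    /// Returns `(epsilon_out, gradient)`: the error signal to hand back to the
    /// caller, and the parameter gradients of this node and everything after it.
    fn recurse(&self, input: &[f64]) -> (Vec<f64>, Self::Gradient);
}

/// Base case: the empty tail of the network.
/// Here we pretend the loss gradient is simply the network output itself.
struct Tail;

impl Recurse for Tail {
    type Gradient = ();

    fn recurse(&self, input: &[f64]) -> (Vec<f64>, Self::Gradient) {
        (input.to_vec(), ())
    }
}

/// Recursive case: one "layer" (a scalar scaling) followed by the rest of the network.
struct Scale<Next> {
    factor: f64,
    next: Next,
}

impl<Next: Recurse> Recurse for Scale<Next> {
    // Mirrors `merge_gradient`: recursive gradient first, boxed layer gradient second.
    type Gradient = (Next::Gradient, Box<f64>);

    fn recurse(&self, input: &[f64]) -> (Vec<f64>, Self::Gradient) {
        // Forward pass through this layer.
        let output: Vec<f64> = input.iter().map(|x| x * self.factor).collect();

        // Recurse into the rest of the network and get the incoming error signal.
        let (epsilon_in, gradient_rec) = self.next.recurse(&output);

        // Outgoing error signal for the caller (chain rule through this layer).
        let epsilon_out: Vec<f64> = epsilon_in.iter().map(|e| e * self.factor).collect();

        // Parameter gradient of this layer: sum of input[i] * epsilon_in[i].
        let gradient: f64 = input.iter().zip(&epsilon_in).map(|(x, e)| x * e).sum();

        (epsilon_out, (gradient_rec, Box::new(gradient)))
    }
}

fn main() {
    let network = Scale {
        factor: 2.0,
        next: Scale { factor: 3.0, next: Tail },
    };
    let (epsilon, gradient) = network.recurse(&[1.0, -1.0]);
    println!("epsilon = {epsilon:?}, gradient = {gradient:?}");
}
```

The `(gradient_rec, Box::new(gradient))` pairing deliberately mirrors the tuple-and-box gradient layout that `NeuraSequential::merge_gradient` builds in the diff above.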