From d82cab788b65f8015966a222052b2922269937fa Mon Sep 17 00:00:00 2001 From: Adrien Burgun Date: Thu, 27 Apr 2023 17:18:12 +0200 Subject: [PATCH] :sparkles: Create NeuraNetwork traits, WIP --- src/gradient_solver/backprop.rs | 49 +++++++++++++++++++++++++++++++-- src/layer/mod.rs | 1 + src/network/mod.rs | 3 ++ src/network/traits.rs | 44 +++++++++++++++++++++++++++++ 4 files changed, 95 insertions(+), 2 deletions(-) create mode 100644 src/network/traits.rs diff --git a/src/gradient_solver/backprop.rs b/src/gradient_solver/backprop.rs index d3e2c34..1214a60 100644 --- a/src/gradient_solver/backprop.rs +++ b/src/gradient_solver/backprop.rs @@ -1,8 +1,8 @@ use num::ToPrimitive; use crate::{ - derivable::NeuraLoss, layer::NeuraTrainableLayerBackprop, layer::NeuraTrainableLayerSelf, - network::NeuraOldTrainableNetworkBase, + derivable::NeuraLoss, layer::*, + network::*, }; use super::*; @@ -91,9 +91,46 @@ impl< } } +trait BackpropRecurse { + fn recurse(&self, network: &Network, input: &Input) -> (Input, Gradient); +} + +impl> BackpropRecurse for (&NeuraBackprop, &Loss::Target) { + fn recurse(&self, _network: &(), input: &Input) -> (Input, ()) { + (self.0.loss.nabla(self.1, input), ()) + } +} + +impl< + Input: Clone, + Network: NeuraNetworkRec + NeuraNetwork + NeuraTrainableLayerBase, + Loss, + Target +> BackpropRecurse for (&NeuraBackprop, &Target) +where + // Verify that we can traverse recursively + for<'a> (&'a NeuraBackprop, &'a Target): BackpropRecurse>::Gradient>, + // Verify that the current layer implements the right traits + Network::Layer: NeuraTrainableLayerSelf + NeuraTrainableLayerBackprop, + // Verify that the layer output can be cloned + >::Output: Clone, + Network::NextNode: NeuraTrainableLayerBase, +{ + fn recurse(&self, network: &Network, input: &Input) -> (Input, Network::Gradient) { + let layer_input = network.map_input(input); + let (layer_output, layer_intermediary) = network.get_layer().eval_training(layer_input.as_ref()); + let output = network.map_output(input, &layer_output); + + let (epsilon_in, gradient_rec) = self.recurse(network.get_next(), output.as_ref()); + + todo!() + } +} + #[cfg(test)] mod test { use approx::assert_relative_eq; + use nalgebra::dvector; use super::*; use crate::{ @@ -161,4 +198,12 @@ mod test { assert_relative_eq!(gradient1_actual.as_slice(), gradient1_expected.as_slice()); } } + + #[test] + fn test_recursive() { + let backprop = NeuraBackprop::new(Euclidean); + let target = dvector![0.0]; + + (&backprop, &target).recurse(&(), &dvector![0.0]); + } } diff --git a/src/layer/mod.rs b/src/layer/mod.rs index 914bd7b..87fabf7 100644 --- a/src/layer/mod.rs +++ b/src/layer/mod.rs @@ -71,6 +71,7 @@ pub trait NeuraTrainableLayerBase: NeuraLayer { /// Applies `δW_l` to the weights of the layer fn apply_gradient(&mut self, gradient: &Self::Gradient); + // TODO: move this into another trait fn eval_training(&self, input: &Input) -> (Self::Output, Self::IntermediaryRepr); /// Arbitrary computation that can be executed at the start of an epoch diff --git a/src/network/mod.rs b/src/network/mod.rs index abc148a..b40d08d 100644 --- a/src/network/mod.rs +++ b/src/network/mod.rs @@ -5,6 +5,9 @@ use crate::{ // pub mod residual; pub mod sequential; +mod traits; +pub use traits::*; + // TODO: extract regularize from this, so that we can drop the trait constraints on NeuraSequential's impl pub trait NeuraOldTrainableNetworkBase: NeuraLayer { type Gradient: NeuraVectorSpace; diff --git a/src/network/traits.rs b/src/network/traits.rs new file mode 100644 index 0000000..93d8e4f --- /dev/null +++ b/src/network/traits.rs @@ -0,0 +1,44 @@ +use std::borrow::Cow; + +use super::*; + +/// This trait has to be non-generic, to ensure that no downstream crate can implement it for foreign types, +/// as that would otherwise cause infinite recursion when dealing with `NeuraNetworkRec`. +pub trait NeuraNetworkBase { + /// The type of the enclosed layer + type Layer; + + fn get_layer(&self) -> &Self::Layer; +} + +pub trait NeuraNetwork: NeuraNetworkBase +where + Self::Layer: NeuraLayer, + >::Output: Clone +{ + /// The type of the input to `Self::Layer` + type LayerInput: Clone; + + /// The type of the output of this node + type NodeOutput: Clone; + + /// Maps the input of network node to the enclosed layer + fn map_input<'a>(&'_ self, input: &'a NodeInput) -> Cow<'a, Self::LayerInput>; + /// Maps the output of the enclosed layer to the output of the network node + fn map_output<'a>(&'_ self, input: &'_ NodeInput, layer_output: &'a >::Output) -> Cow<'a, Self::NodeOutput>; + + /// Maps a gradient in the format of the node's output into the format of the enclosed layer's output + fn map_gradient_in<'a>(&'_ self, input: &'_ NodeInput, gradient_in: &'a Self::NodeOutput) -> Cow<'a, >::Output>; + /// Maps a gradient in the format of the enclosed layer's input into the format of the node's input + fn map_gradient_out<'a>(&'_ self, input: &'_ NodeInput, gradient_in: &'_ Self::NodeOutput, gradient_out: &'a Self::LayerInput) -> Cow<'a, NodeInput>; +} + +pub trait NeuraNetworkRec: NeuraNetworkBase { + /// The type of the children network, it does not need to implement `NeuraNetworkBase`, + /// although many functions will expect it to be either `()` or an implementation of `NeuraNetworkRec`. + type NextNode; + + fn get_next(&self) -> &Self::NextNode; + + +}