diff --git a/src/gradient_solver/backprop.rs b/src/gradient_solver/backprop.rs
index d3e2c34..1214a60 100644
--- a/src/gradient_solver/backprop.rs
+++ b/src/gradient_solver/backprop.rs
@@ -1,8 +1,8 @@
use num::ToPrimitive;
use crate::{
- derivable::NeuraLoss, layer::NeuraTrainableLayerBackprop, layer::NeuraTrainableLayerSelf,
- network::NeuraOldTrainableNetworkBase,
+ derivable::NeuraLoss, layer::*,
+ network::*,
};
use super::*;
@@ -91,9 +91,46 @@ impl<
}
}
+trait BackpropRecurse<Input, Network, Gradient> {
+    fn recurse(&self, network: &Network, input: &Input) -> (Input, Gradient);
+}
+
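+// Base case: the rest of the network is `()`, so the epsilon returned is simply the gradient of
+// the loss with respect to the network's output, and there is no weight gradient to build (hence `()`).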
+impl<Input, Loss: NeuraLoss<Input>> BackpropRecurse<Input, (), ()> for (&NeuraBackprop<Loss>, &Loss::Target) {
+ fn recurse(&self, _network: &(), input: &Input) -> (Input, ()) {
+ (self.0.loss.nabla(self.1, input), ())
+ }
+}
+
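+// Recursive case: feed the input through the current node's layer and recurse into the rest of
+// the network; the backward pass through the layer itself is still left as a `todo!()` below.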
+impl<
+ Input: Clone,
+    Network: NeuraNetworkRec + NeuraNetwork<Input> + NeuraTrainableLayerBase<Input>,
+ Loss,
+ Target
+> BackpropRecurse<Input, Network, Network::Gradient> for (&NeuraBackprop<Loss>, &Target)
+where
+ // Verify that we can traverse recursively
+    for<'a> (&'a NeuraBackprop<Loss>, &'a Target): BackpropRecurse<Network::NodeOutput, Network::NextNode, <Network::NextNode as NeuraTrainableLayerBase<Network::NodeOutput>>::Gradient>,
+ // Verify that the current layer implements the right traits
+    Network::Layer: NeuraTrainableLayerSelf<Network::LayerInput> + NeuraTrainableLayerBackprop<Network::LayerInput>,
+ // Verify that the layer output can be cloned
+    <Network::Layer as NeuraLayer<Network::LayerInput>>::Output: Clone,
+    Network::NextNode: NeuraTrainableLayerBase<Network::NodeOutput>,
+{
+ fn recurse(&self, network: &Network, input: &Input) -> (Input, Network::Gradient) {
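+        // Forward pass: map the node's input into the enclosed layer's input space, evaluate the
+        // layer, and map the layer's output back into the node's output space.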
+ let layer_input = network.map_input(input);
+ let (layer_output, layer_intermediary) = network.get_layer().eval_training(layer_input.as_ref());
+ let output = network.map_output(input, &layer_output);
+
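+        // Recurse into the rest of the network to obtain the epsilon flowing back into this node
+        // and the gradient of the downstream nodes.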
+ let (epsilon_in, gradient_rec) = self.recurse(network.get_next(), output.as_ref());
+
+ todo!()
+ }
+}
+
#[cfg(test)]
mod test {
use approx::assert_relative_eq;
+ use nalgebra::dvector;
use super::*;
use crate::{
@@ -161,4 +198,12 @@ mod test {
assert_relative_eq!(gradient1_actual.as_slice(), gradient1_expected.as_slice());
}
}
+
+ #[test]
+ fn test_recursive() {
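+        // Only exercises the base case of `BackpropRecurse`: an empty rest-network `()`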
+ let backprop = NeuraBackprop::new(Euclidean);
+ let target = dvector![0.0];
+
+ (&backprop, &target).recurse(&(), &dvector![0.0]);
+ }
}
diff --git a/src/layer/mod.rs b/src/layer/mod.rs
index 914bd7b..87fabf7 100644
--- a/src/layer/mod.rs
+++ b/src/layer/mod.rs
pub trait NeuraTrainableLayerBase<Input>: NeuraLayer<Input> {
/// Applies `δW_l` to the weights of the layer
fn apply_gradient(&mut self, gradient: &Self::Gradient);
+ // TODO: move this into another trait
fn eval_training(&self, input: &Input) -> (Self::Output, Self::IntermediaryRepr);
/// Arbitrary computation that can be executed at the start of an epoch
diff --git a/src/network/mod.rs b/src/network/mod.rs
index abc148a..b40d08d 100644
--- a/src/network/mod.rs
+++ b/src/network/mod.rs
@@ -5,6 +5,9 @@ use crate::{
// pub mod residual;
pub mod sequential;
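+// Traits describing how a network node wraps a layer and chains into the rest of the network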
+mod traits;
+pub use traits::*;
+
// TODO: extract regularize from this, so that we can drop the trait constraints on NeuraSequential's impl
pub trait NeuraOldTrainableNetworkBase<Input>: NeuraLayer<Input> {
type Gradient: NeuraVectorSpace;
diff --git a/src/network/traits.rs b/src/network/traits.rs
new file mode 100644
index 0000000..93d8e4f
--- /dev/null
+++ b/src/network/traits.rs
@@ -0,0 +1,44 @@
+use std::borrow::Cow;
+
+use super::*;
+
+/// This trait has to be non-generic, to ensure that no downstream crate can implement it for foreign types,
+/// as that would otherwise cause infinite recursion when dealing with `NeuraNetworkRec`.
+pub trait NeuraNetworkBase {
+ /// The type of the enclosed layer
+ type Layer;
+
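+    /// Returns a reference to the layer enclosed in this network node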
+ fn get_layer(&self) -> &Self::Layer;
+}
+
+pub trait NeuraNetwork<NodeInput: Clone>: NeuraNetworkBase
+where
+    Self::Layer: NeuraLayer<Self::LayerInput>,
+    <Self::Layer as NeuraLayer<Self::LayerInput>>::Output: Clone
+{
+ /// The type of the input to `Self::Layer`
+ type LayerInput: Clone;
+
+ /// The type of the output of this node
+ type NodeOutput: Clone;
+
+    /// Maps the input of the network node to the input of the enclosed layer
+ fn map_input<'a>(&'_ self, input: &'a NodeInput) -> Cow<'a, Self::LayerInput>;
+ /// Maps the output of the enclosed layer to the output of the network node
+    fn map_output<'a>(&'_ self, input: &'_ NodeInput, layer_output: &'a <Self::Layer as NeuraLayer<Self::LayerInput>>::Output) -> Cow<'a, Self::NodeOutput>;
+
+ /// Maps a gradient in the format of the node's output into the format of the enclosed layer's output
+    fn map_gradient_in<'a>(&'_ self, input: &'_ NodeInput, gradient_in: &'a Self::NodeOutput) -> Cow<'a, <Self::Layer as NeuraLayer<Self::LayerInput>>::Output>;
+ /// Maps a gradient in the format of the enclosed layer's input into the format of the node's input
+ fn map_gradient_out<'a>(&'_ self, input: &'_ NodeInput, gradient_in: &'_ Self::NodeOutput, gradient_out: &'a Self::LayerInput) -> Cow<'a, NodeInput>;
+}
+
+pub trait NeuraNetworkRec: NeuraNetworkBase {
+    /// The type of the child network; it does not need to implement `NeuraNetworkBase`,
+    /// although many functions will expect it to be either `()` or an implementation of `NeuraNetworkRec`.
+ type NextNode;
+
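+    /// Returns a reference to the child network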
+ fn get_next(&self) -> &Self::NextNode;
+}