diff --git a/src/lib.rs b/src/lib.rs
index 80b548d..391cd9f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -5,6 +5,7 @@ pub mod algebra;
pub mod derivable;
pub mod layer;
pub mod network;
+pub mod optimize;
pub mod train;
mod utils;
@@ -21,5 +22,6 @@ pub mod prelude {
pub use crate::network::sequential::{
NeuraSequential, NeuraSequentialConstruct, NeuraSequentialTail,
};
- pub use crate::train::{NeuraBackprop, NeuraBatchedTrainer};
+ pub use crate::optimize::NeuraBackprop;
+ pub use crate::train::NeuraBatchedTrainer;
}
diff --git a/src/network/mod.rs b/src/network/mod.rs
index 5527e21..ac37723 100644
--- a/src/network/mod.rs
+++ b/src/network/mod.rs
@@ -1,25 +1,29 @@
-use crate::{algebra::NeuraVectorSpace, derivable::NeuraLoss, layer::NeuraLayer};
+use crate::{algebra::NeuraVectorSpace, layer::NeuraLayer, optimize::NeuraOptimizerBase};
pub mod sequential;
-pub trait NeuraTrainableNetwork<Input>: NeuraLayer<Input> {
+pub trait NeuraTrainableNetworkBase<Input>: NeuraLayer<Input> {
type Gradient: NeuraVectorSpace;
+ type LayerOutput;
fn default_gradient(&self) -> Self::Gradient;
fn apply_gradient(&mut self, gradient: &Self::Gradient);
- /// Should implement the backpropagation algorithm, see `NeuraTrainableLayer::backpropagate` for more information.
-    fn backpropagate<Loss: NeuraLoss<Self::Output>>(
- &self,
- input: &Input,
- target: &Loss::Target,
- loss: Loss,
- ) -> (Input, Self::Gradient);
-
/// Should return the regularization gradient
fn regularize(&self) -> Self::Gradient;
/// Called before an iteration begins, to allow the network to set itself up for training or not.
fn prepare(&mut self, train_iteration: bool);
}
+
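+/// Networks that can be traversed by an `Optimizer`, producing the optimizer's `Output`
+/// (for backpropagation: the epsilon to pass upstream together with the accumulated gradient).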
+pub trait NeuraTrainableNetwork<Input, Optimizer>: NeuraTrainableNetworkBase<Input>
+where
+ Optimizer: NeuraOptimizerBase,
+{
+ fn traverse(
+ &self,
+ input: &Input,
+ optimizer: &Optimizer,
+    ) -> Optimizer::Output<Input, Self::Gradient>;
+}
diff --git a/src/network/sequential/mod.rs b/src/network/sequential/mod.rs
index ed95a87..7a0ee72 100644
--- a/src/network/sequential/mod.rs
+++ b/src/network/sequential/mod.rs
@@ -1,7 +1,7 @@
-use super::NeuraTrainableNetwork;
+use super::{NeuraTrainableNetwork, NeuraTrainableNetworkBase};
use crate::{
- derivable::NeuraLoss,
layer::{NeuraLayer, NeuraPartialLayer, NeuraShape, NeuraTrainableLayer},
+ optimize::{NeuraOptimizerFinal, NeuraOptimizerTransient},
};
mod construct;
@@ -129,10 +129,11 @@ impl<
impl<
Input,
        Layer: NeuraTrainableLayer<Input>,
-        ChildNetwork: NeuraTrainableNetwork<Layer::Output>,
-    > NeuraTrainableNetwork<Input> for NeuraSequential<Layer, ChildNetwork>
+        ChildNetwork: NeuraTrainableNetworkBase<Layer::Output>,
+    > NeuraTrainableNetworkBase<Input> for NeuraSequential<Layer, ChildNetwork>
{
    type Gradient = (Layer::Gradient, Box<ChildNetwork::Gradient>);
+ type LayerOutput = Layer::Output;
fn default_gradient(&self) -> Self::Gradient {
(
@@ -146,25 +147,6 @@ impl<
self.child_network.apply_gradient(&gradient.1);
}
-    fn backpropagate<Loss: NeuraLoss<Self::Output>>(
- &self,
- input: &Input,
- target: &Loss::Target,
- loss: Loss,
- ) -> (Input, Self::Gradient) {
- let next_activation = self.layer.eval(input);
- let (backprop_gradient, weights_gradient) =
- self.child_network
- .backpropagate(&next_activation, target, loss);
- let (backprop_gradient, layer_gradient) =
- self.layer.backprop_layer(input, backprop_gradient);
-
- (
- backprop_gradient,
- (layer_gradient, Box::new(weights_gradient)),
- )
- }
-
fn regularize(&self) -> Self::Gradient {
(
self.layer.regularize_layer(),
@@ -179,8 +161,9 @@ impl<
}
/// A dummy implementation of `NeuraTrainableNetwork`, which simply calls `loss.eval` in `backpropagate`.
-impl<Input: Clone> NeuraTrainableNetwork<Input> for () {
+impl<Input: Clone> NeuraTrainableNetworkBase<Input> for () {
type Gradient = ();
+ type LayerOutput = Input;
#[inline(always)]
fn default_gradient(&self) -> () {
@@ -192,18 +175,6 @@ impl<Input: Clone> NeuraTrainableNetwork<Input> for () {
// Noop
}
- #[inline(always)]
-    fn backpropagate<Loss: NeuraLoss<Self::Output>>(
- &self,
- final_activation: &Input,
- target: &Loss::Target,
- loss: Loss,
- ) -> (Input, Self::Gradient) {
- let backprop_epsilon = loss.nabla(target, &final_activation);
-
- (backprop_epsilon, ())
- }
-
#[inline(always)]
fn regularize(&self) -> () {
()
@@ -215,6 +186,44 @@ impl<Input: Clone> NeuraTrainableNetwork<Input> for () {
}
}
+impl<
+ Input,
+        Layer: NeuraTrainableLayer<Input>,
+        Optimizer: NeuraOptimizerTransient<Layer::Output>,
+        ChildNetwork: NeuraTrainableNetworkBase<Layer::Output>,
+    > NeuraTrainableNetwork<Input, Optimizer> for NeuraSequential<Layer, ChildNetwork>
+where
+    ChildNetwork: NeuraTrainableNetwork<Layer::Output, Optimizer>,
+{
+ fn traverse(
+ &self,
+ input: &Input,
+ optimizer: &Optimizer,
+    ) -> Optimizer::Output<Input, Self::Gradient> {
+ let next_activation = self.layer.eval(input);
+ let child_result = self.child_network.traverse(&next_activation, optimizer);
+
+ optimizer.eval_layer(
+ &self.layer,
+ input,
+ child_result,
+ |layer_gradient, child_gradient| (layer_gradient, Box::new(child_gradient)),
+ )
+ }
+}
+
+impl<Input: Clone, Optimizer: NeuraOptimizerFinal<Input>> NeuraTrainableNetwork<Input, Optimizer>
+ for ()
+{
+ fn traverse(
+ &self,
+ input: &Input,
+ optimizer: &Optimizer,
+    ) -> Optimizer::Output<Input, Self::Gradient> {
+ optimizer.eval_final(input.clone())
+ }
+}
+
impl<Layer> From<Layer> for NeuraSequential<Layer, ()> {
fn from(layer: Layer) -> Self {
Self {
diff --git a/src/optimize.rs b/src/optimize.rs
new file mode 100644
index 0000000..9a58b71
--- /dev/null
+++ b/src/optimize.rs
@@ -0,0 +1,112 @@
+use num::ToPrimitive;
+
+use crate::{
+ derivable::NeuraLoss,
+ layer::NeuraTrainableLayer,
+ network::{NeuraTrainableNetwork, NeuraTrainableNetworkBase},
+};
+
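+/// Base trait for optimizers that visit the network layer by layer;
+/// `Output` is the value built up while traversing the network.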
+pub trait NeuraOptimizerBase {
+    type Output<NetworkInput, NetworkGradient>;
+}
+
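+/// Evaluated once the traversal reaches the end of the network (the `()` tail).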
+pub trait NeuraOptimizerFinal<LayerOutput>: NeuraOptimizerBase {
+    fn eval_final(&self, output: LayerOutput) -> Self::Output<LayerOutput, ()>;
+}
+
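+/// Evaluated on every intermediate layer, after the rest of the network has been
+/// traversed; `combine_gradients` merges the layer's gradient with the child network's.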
+pub trait NeuraOptimizerTransient<LayerOutput>: NeuraOptimizerBase {
+ fn eval_layer<
+ Input,
+ NetworkGradient,
+ RecGradient,
+        Layer: NeuraTrainableLayer<Input, Output = LayerOutput>,
+ >(
+ &self,
+ layer: &Layer,
+ input: &Input,
+        rec_opt_output: Self::Output<LayerOutput, RecGradient>,
+ combine_gradients: impl Fn(Layer::Gradient, RecGradient) -> NetworkGradient,
+    ) -> Self::Output<Input, NetworkGradient>;
+}
+
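+/// High-level interface consumed by `NeuraBatchedTrainer`: computes the full gradient
+/// and a scalar score for a given input/target pair.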
+pub trait NeuraOptimizer<Input, Target, Trainable: NeuraTrainableNetworkBase<Input>> {
+ fn get_gradient(
+ &self,
+ trainable: &Trainable,
+ input: &Input,
+ target: &Target,
+ ) -> Trainable::Gradient;
+
+ fn score(&self, trainable: &Trainable, input: &Input, target: &Target) -> f64;
+}
+
+pub struct NeuraBackprop<Loss> {
+ loss: Loss,
+}
+
+impl<Loss> NeuraBackprop<Loss> {
+ pub fn new(loss: Loss) -> Self {
+ Self { loss }
+ }
+}
+
+impl<
+ Input,
+ Target,
+        Trainable: NeuraTrainableNetworkBase<Input>,
+        Loss: NeuraLoss<Trainable::Output, Target = Target> + Clone,
+    > NeuraOptimizer<Input, Target, Trainable> for NeuraBackprop<Loss>
+where
+    <Loss as NeuraLoss<Trainable::Output>>::Output: ToPrimitive,
+    Trainable: for<'a> NeuraTrainableNetwork<Input, (&'a NeuraBackprop<Loss>, &'a Target)>,
+{
+ fn get_gradient(
+ &self,
+ trainable: &Trainable,
+ input: &Input,
+ target: &Target,
+ ) -> Trainable::Gradient {
+ let (_, gradient) = trainable.traverse(input, &(self, target));
+
+ gradient
+ }
+
+ fn score(&self, trainable: &Trainable, input: &Input, target: &Target) -> f64 {
+ let output = trainable.eval(&input);
+ self.loss.eval(target, &output).to_f64().unwrap()
+ }
+}
+
+impl<Loss, Target> NeuraOptimizerBase for (&NeuraBackprop<Loss>, &Target) {
+    type Output<NetworkInput, NetworkGradient> = (NetworkInput, NetworkGradient); // epsilon, gradient
+}
+
+impl<LayerOutput, Target, Loss: NeuraLoss<LayerOutput, Target = Target>>
+    NeuraOptimizerFinal<LayerOutput> for (&NeuraBackprop<Loss>, &Target)
+{
+    fn eval_final(&self, output: LayerOutput) -> Self::Output<LayerOutput, ()> {
+ (self.0.loss.nabla(self.1, &output), ())
+ }
+}
+
+impl<LayerOutput, Target, Loss> NeuraOptimizerTransient<LayerOutput>
+    for (&NeuraBackprop<Loss>, &Target)
+{
+ fn eval_layer<
+ Input,
+ NetworkGradient,
+ RecGradient,
+        Layer: NeuraTrainableLayer<Input, Output = LayerOutput>,
+ >(
+ &self,
+ layer: &Layer,
+ input: &Input,
+        rec_opt_output: Self::Output<LayerOutput, RecGradient>,
+ combine_gradients: impl Fn(Layer::Gradient, RecGradient) -> NetworkGradient,
+    ) -> Self::Output<Input, NetworkGradient> {
+ let (epsilon_in, rec_gradient) = rec_opt_output;
+ let (epsilon_out, layer_gradient) = layer.backprop_layer(input, epsilon_in);
+
+ (epsilon_out, combine_gradients(layer_gradient, rec_gradient))
+ }
+}
diff --git a/src/train.rs b/src/train.rs
index a331955..1e45d5e 100644
--- a/src/train.rs
+++ b/src/train.rs
@@ -1,52 +1,6 @@
-use num::ToPrimitive;
-
-use crate::{algebra::NeuraVectorSpace, derivable::NeuraLoss, network::NeuraTrainableNetwork};
-
-pub trait NeuraGradientSolver<Input, Target, Trainable: NeuraTrainableNetwork<Input>> {
- fn get_gradient(
- &self,
- trainable: &Trainable,
- input: &Input,
- target: &Target,
- ) -> Trainable::Gradient;
-
- fn score(&self, trainable: &Trainable, input: &Input, target: &Target) -> f64;
-}
-
-#[non_exhaustive]
-pub struct NeuraBackprop<Loss> {
- loss: Loss,
-}
-
-impl<Loss> NeuraBackprop<Loss> {
- pub fn new(loss: Loss) -> Self {
- Self { loss }
- }
-}
-
-impl<
- Input,
- Target,
-        Trainable: NeuraTrainableNetwork<Input>,
-        Loss: NeuraLoss<Trainable::Output, Target = Target> + Clone,
-    > NeuraGradientSolver<Input, Target, Trainable> for NeuraBackprop<Loss>
-where
-    <Loss as NeuraLoss<Trainable::Output>>::Output: ToPrimitive,
-{
- fn get_gradient(
- &self,
- trainable: &Trainable,
- input: &Input,
- target: &Target,
- ) -> Trainable::Gradient {
- trainable.backpropagate(input, target, self.loss.clone()).1
- }
-
- fn score(&self, trainable: &Trainable, input: &Input, target: &Target) -> f64 {
- let output = trainable.eval(&input);
- self.loss.eval(target, &output).to_f64().unwrap()
- }
-}
+use crate::{
+ algebra::NeuraVectorSpace, network::NeuraTrainableNetworkBase, optimize::NeuraOptimizer,
+};
#[non_exhaustive]
pub struct NeuraBatchedTrainer {
@@ -118,8 +72,8 @@ impl NeuraBatchedTrainer {
pub fn train<
Input: Clone,
Target: Clone,
-        Network: NeuraTrainableNetwork<Input>,
-        GradientSolver: NeuraGradientSolver<Input, Target, Network>,
+        Network: NeuraTrainableNetworkBase<Input>,
+        GradientSolver: NeuraOptimizer<Input, Target, Network>,
        Inputs: IntoIterator<Item = (Input, Target)>,
>(
&self,
@@ -185,10 +139,11 @@ mod test {
use super::*;
use crate::{
assert_approx,
- derivable::{activation::Linear, loss::Euclidean, regularize::NeuraL0},
+ derivable::{activation::Linear, loss::Euclidean, regularize::NeuraL0, NeuraLoss},
layer::{dense::NeuraDenseLayer, NeuraLayer},
network::sequential::{NeuraSequential, NeuraSequentialTail},
neura_sequential,
+ optimize::NeuraBackprop,
};
#[test]