diff --git a/src/optimize.rs b/src/gradient_solver/backprop.rs
similarity index 57%
rename from src/optimize.rs
rename to src/gradient_solver/backprop.rs
index 9a58b71..f924f1f 100644
--- a/src/optimize.rs
+++ b/src/gradient_solver/backprop.rs
@@ -1,44 +1,9 @@
 use num::ToPrimitive;
 
-use crate::{
-    derivable::NeuraLoss,
-    layer::NeuraTrainableLayer,
-    network::{NeuraTrainableNetwork, NeuraTrainableNetworkBase},
-};
+use crate::{network::NeuraTrainableNetworkBase, derivable::NeuraLoss, layer::NeuraTrainableLayer};
 
-pub trait NeuraOptimizerBase {
-    type Output<NetworkInput, NetworkGradient>;
-}
-
-pub trait NeuraOptimizerFinal<LayerOutput>: NeuraOptimizerBase {
-    fn eval_final(&self, output: LayerOutput) -> Self::Output<LayerOutput, ()>;
-}
+use super::*;
 
-pub trait NeuraOptimizerTransient<LayerOutput>: NeuraOptimizerBase {
-    fn eval_layer<
-        Input,
-        NetworkGradient,
-        RecGradient,
-        Layer: NeuraTrainableLayer<Input, Output = LayerOutput>,
-    >(
-        &self,
-        layer: &Layer,
-        input: &Input,
-        rec_opt_output: Self::Output<LayerOutput, RecGradient>,
-        combine_gradients: impl Fn(Layer::Gradient, RecGradient) -> NetworkGradient,
-    ) -> Self::Output<Input, NetworkGradient>;
-}
-
-pub trait NeuraOptimizer<Input, Target, Trainable: NeuraTrainableNetworkBase<Input>> {
-    fn get_gradient(
-        &self,
-        trainable: &Trainable,
-        input: &Input,
-        target: &Target,
-    ) -> Trainable::Gradient;
-
-    fn score(&self, trainable: &Trainable, input: &Input, target: &Target) -> f64;
-}
 
 pub struct NeuraBackprop<Loss> {
     loss: Loss,
 }
@@ -55,7 +20,7 @@ impl<
         Target,
         Trainable: NeuraTrainableNetworkBase<Input>,
         Loss: NeuraLoss<Trainable::Output, Target = Target> + Clone,
-    > NeuraOptimizer<Input, Target, Trainable> for NeuraBackprop<Loss>
+    > NeuraGradientSolver<Input, Target, Trainable> for NeuraBackprop<Loss>
 where
     <Loss as NeuraLoss<Trainable::Output>>::Output: ToPrimitive,
     Trainable: for<'a> NeuraTrainableNetwork<Input, (&'a NeuraBackprop<Loss>, &'a Target)>,
@@ -77,19 +42,19 @@
     }
 }
 
-impl<Loss, Target> NeuraOptimizerBase for (&NeuraBackprop<Loss>, &Target) {
+impl<Loss, Target> NeuraGradientSolverBase for (&NeuraBackprop<Loss>, &Target) {
     type Output<NetworkInput, NetworkGradient> = (NetworkInput, NetworkGradient); // epsilon, gradient
 }
 
 impl<LayerOutput, Target, Loss: NeuraLoss<LayerOutput, Target = Target>>
-    NeuraOptimizerFinal<LayerOutput> for (&NeuraBackprop<Loss>, &Target)
+    NeuraGradientSolverFinal<LayerOutput> for (&NeuraBackprop<Loss>, &Target)
 {
     fn eval_final(&self, output: LayerOutput) -> Self::Output<LayerOutput, ()> {
         (self.0.loss.nabla(self.1, &output), ())
     }
 }
 
-impl<LayerOutput, Target, Loss> NeuraOptimizerTransient<LayerOutput>
+impl<LayerOutput, Target, Loss> NeuraGradientSolverTransient<LayerOutput>
     for (&NeuraBackprop<Loss>, &Target)
 {
     fn eval_layer<
diff --git a/src/gradient_solver/mod.rs b/src/gradient_solver/mod.rs
new file mode 100644
index 0000000..83bc595
--- /dev/null
+++ b/src/gradient_solver/mod.rs
@@ -0,0 +1,41 @@
+mod backprop;
+pub use backprop::NeuraBackprop;
+
+use crate::{
+    layer::NeuraTrainableLayer,
+    network::{NeuraTrainableNetwork, NeuraTrainableNetworkBase},
+};
+
+pub trait NeuraGradientSolverBase {
+    type Output<NetworkInput, NetworkGradient>;
+}
+
+pub trait NeuraGradientSolverFinal<LayerOutput>: NeuraGradientSolverBase {
+    fn eval_final(&self, output: LayerOutput) -> Self::Output<LayerOutput, ()>;
+}
+
+pub trait NeuraGradientSolverTransient<LayerOutput>: NeuraGradientSolverBase {
+    fn eval_layer<
+        Input,
+        NetworkGradient,
+        RecGradient,
+        Layer: NeuraTrainableLayer<Input, Output = LayerOutput>,
+    >(
+        &self,
+        layer: &Layer,
+        input: &Input,
+        rec_opt_output: Self::Output<LayerOutput, RecGradient>,
+        combine_gradients: impl Fn(Layer::Gradient, RecGradient) -> NetworkGradient,
+    ) -> Self::Output<Input, NetworkGradient>;
+}
+
+pub trait NeuraGradientSolver<Input, Target, Trainable: NeuraTrainableNetworkBase<Input>> {
+    fn get_gradient(
+        &self,
+        trainable: &Trainable,
+        input: &Input,
+        target: &Target,
+    ) -> Trainable::Gradient;
+
+    fn score(&self, trainable: &Trainable, input: &Input, target: &Target) -> f64;
+}
diff --git a/src/lib.rs b/src/lib.rs
index 391cd9f..d8f1896 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -5,7 +5,7 @@ pub mod algebra;
 pub mod derivable;
 pub mod layer;
 pub mod network;
-pub mod optimize;
+pub mod gradient_solver;
 pub mod train;
 
 mod utils;
@@ -22,6 +22,6 @@ pub mod prelude {
     pub use crate::network::sequential::{
         NeuraSequential, NeuraSequentialConstruct, NeuraSequentialTail,
     };
-    pub use crate::optimize::NeuraBackprop;
+    pub use crate::gradient_solver::NeuraBackprop;
     pub use crate::train::NeuraBatchedTrainer;
 }
diff --git a/src/network/mod.rs b/src/network/mod.rs
index ac37723..0e2556b 100644
--- a/src/network/mod.rs
+++ b/src/network/mod.rs
@@ -1,4 +1,4 @@
-use crate::{algebra::NeuraVectorSpace, layer::NeuraLayer, optimize::NeuraOptimizerBase};
+use crate::{algebra::NeuraVectorSpace, layer::NeuraLayer, gradient_solver::NeuraGradientSolverBase};
 
 pub mod sequential;
 
@@ -19,7 +19,7 @@ pub trait NeuraTrainableNetworkBase<Input>: NeuraLayer<Input> {
 
 pub trait NeuraTrainableNetwork<Input, Optimizer>: NeuraTrainableNetworkBase<Input>
 where
-    Optimizer: NeuraOptimizerBase,
+    Optimizer: NeuraGradientSolverBase,
 {
     fn traverse(
         &self,
diff --git a/src/network/sequential/mod.rs b/src/network/sequential/mod.rs
index 7a0ee72..637a42f 100644
--- a/src/network/sequential/mod.rs
+++ b/src/network/sequential/mod.rs
@@ -1,7 +1,7 @@
 use super::{NeuraTrainableNetwork, NeuraTrainableNetworkBase};
 use crate::{
     layer::{NeuraLayer, NeuraPartialLayer, NeuraShape, NeuraTrainableLayer},
-    optimize::{NeuraOptimizerFinal, NeuraOptimizerTransient},
+    gradient_solver::{NeuraGradientSolverFinal, NeuraGradientSolverTransient},
 };
 
 mod construct;
@@ -189,7 +189,7 @@ impl<Input> NeuraTrainableNetworkBase<Input> for () {
 impl<
         Input,
         Layer: NeuraTrainableLayer<Input>,
-        Optimizer: NeuraOptimizerTransient<Layer::Output>,
+        Optimizer: NeuraGradientSolverTransient<Layer::Output>,
         ChildNetwork: NeuraTrainableNetworkBase<Layer::Output>,
     > NeuraTrainableNetwork<Input, Optimizer> for NeuraSequential<Layer, ChildNetwork>
 where
@@ -212,7 +212,7 @@ where
     }
 }
 
-impl<Input, Optimizer: NeuraOptimizerFinal<Input>> NeuraTrainableNetwork<Input, Optimizer>
+impl<Input, Optimizer: NeuraGradientSolverFinal<Input>> NeuraTrainableNetwork<Input, Optimizer>
     for ()
 {
     fn traverse(
diff --git a/src/train.rs b/src/train.rs
index 1e45d5e..1288578 100644
--- a/src/train.rs
+++ b/src/train.rs
@@ -1,5 +1,5 @@
 use crate::{
-    algebra::NeuraVectorSpace, network::NeuraTrainableNetworkBase, optimize::NeuraOptimizer,
+    algebra::NeuraVectorSpace, network::NeuraTrainableNetworkBase, gradient_solver::NeuraGradientSolver,
 };
 
 #[non_exhaustive]
@@ -73,7 +73,7 @@ impl NeuraBatchedTrainer {
         Input: Clone,
         Target: Clone,
         Network: NeuraTrainableNetworkBase<Input>,
-        GradientSolver: NeuraOptimizer<Input, Target, Network>,
+        GradientSolver: NeuraGradientSolver<Input, Target, Network>,
         Inputs: IntoIterator<Item = (Input, Target)>,
     >(
        &self,
@@ -143,7 +143,7 @@ mod test {
         layer::{dense::NeuraDenseLayer, NeuraLayer},
         network::sequential::{NeuraSequential, NeuraSequentialTail},
         neura_sequential,
-        optimize::NeuraBackprop,
+        gradient_solver::NeuraBackprop,
     };
 
     #[test]
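A note on the traversal pattern this commit moves into src/gradient_solver/mod.rs: the solver is threaded through the network as a `(&NeuraBackprop, &Target)` pair, `eval_final` seeds the backward pass at the output, and `eval_layer` walks it back one layer at a time, pairing the epsilon flowing backwards with the gradient accumulated so far (the `(NetworkInput, NetworkGradient)` in `NeuraGradientSolverBase::Output`). The sketch below is a self-contained toy illustrating that shape, not crate code: the `Toy`-prefixed names are hypothetical, and a single scaling layer with a squared-error loss stands in for a real network.

```rust
// Toy mirror of the NeuraGradientSolverFinal / NeuraGradientSolverTransient
// split: each step returns (epsilon, gradient-so-far), like the GAT
// Output<NetworkInput, NetworkGradient> in the trait above.
struct ToyBackprop;

impl ToyBackprop {
    /// Counterpart of `eval_final`: seed the backward pass with
    /// epsilon = dLoss/dOutput (nabla of a squared-error loss) and the
    /// empty gradient `()` of the tail network.
    fn eval_final(&self, output: f64, target: f64) -> (f64, ()) {
        (2.0 * (output - target), ())
    }

    /// Counterpart of `eval_layer` for a scaling layer `y = weight * x`:
    /// derive this layer's gradient from the incoming epsilon, prepend it
    /// to the recursive gradient via `combine_gradients`, and pass
    /// `epsilon * weight` further back.
    fn eval_layer<Rec, Net>(
        &self,
        weight: f64,
        input: f64,
        (epsilon, rec_gradient): (f64, Rec),
        combine_gradients: impl Fn(f64, Rec) -> Net,
    ) -> (f64, Net) {
        let layer_gradient = epsilon * input; // epsilon * d(w*x)/dw
        (epsilon * weight, combine_gradients(layer_gradient, rec_gradient))
    }
}

fn main() {
    // One-layer "network" y = w * x with a squared-error loss.
    let (w, x, target) = (3.0, 2.0, 5.0);
    let solver = ToyBackprop;
    let seed = solver.eval_final(w * x, target);
    let (_epsilon, gradient) = solver.eval_layer(w, x, seed, |g, rec| (g, rec));
    assert_eq!(gradient.0, 4.0); // 2 * (3*2 - 5) * 2
    println!("dLoss/dw = {}", gradient.0);
}
```

The real `NeuraBackprop` performs the same two steps generically, with `NeuraLoss::nabla` supplying the seed and each `NeuraTrainableLayer` supplying its own gradient, which is what lets the same traversal traits back other gradient solvers later.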