🚚 rename optimize to gradient_solver

main
Shad Amethyst 2 years ago
parent 81de6ddbcd
commit 6d45eafbe7

@@ -1,44 +1,9 @@
use num::ToPrimitive;
use crate::{
    derivable::NeuraLoss,
    layer::NeuraTrainableLayer,
    network::{NeuraTrainableNetwork, NeuraTrainableNetworkBase},
};
use crate::{network::NeuraTrainableNetworkBase, derivable::NeuraLoss, layer::NeuraTrainableLayer};
pub trait NeuraOptimizerBase {
    type Output<NetworkInput, NetworkGradient>;
}
pub trait NeuraOptimizerFinal<LayerOutput>: NeuraOptimizerBase {
    fn eval_final(&self, output: LayerOutput) -> Self::Output<LayerOutput, ()>;
}
use super::*;
pub trait NeuraOptimizerTransient<LayerOutput>: NeuraOptimizerBase {
    fn eval_layer<
        Input,
        NetworkGradient,
        RecGradient,
        Layer: NeuraTrainableLayer<Input, Output = LayerOutput>,
    >(
        &self,
        layer: &Layer,
        input: &Input,
        rec_opt_output: Self::Output<LayerOutput, RecGradient>,
        combine_gradients: impl Fn(Layer::Gradient, RecGradient) -> NetworkGradient,
    ) -> Self::Output<Input, NetworkGradient>;
}
pub trait NeuraOptimizer<Input, Target, Trainable: NeuraTrainableNetworkBase<Input>> {
    fn get_gradient(
        &self,
        trainable: &Trainable,
        input: &Input,
        target: &Target,
    ) -> Trainable::Gradient;
    fn score(&self, trainable: &Trainable, input: &Input, target: &Target) -> f64;
}
pub struct NeuraBackprop<Loss> {
    loss: Loss,
@@ -55,7 +20,7 @@ impl<
        Target,
        Trainable: NeuraTrainableNetworkBase<Input>,
        Loss: NeuraLoss<Trainable::Output, Target = Target> + Clone,
    > NeuraOptimizer<Input, Target, Trainable> for NeuraBackprop<Loss>
    > NeuraGradientSolver<Input, Target, Trainable> for NeuraBackprop<Loss>
where
    <Loss as NeuraLoss<Trainable::Output>>::Output: ToPrimitive,
    Trainable: for<'a> NeuraTrainableNetwork<Input, (&'a NeuraBackprop<Loss>, &'a Target)>,
@@ -77,19 +42,19 @@ where
}
}
impl<Loss, Target> NeuraOptimizerBase for (&NeuraBackprop<Loss>, &Target) {
impl<Loss, Target> NeuraGradientSolverBase for (&NeuraBackprop<Loss>, &Target) {
    type Output<NetworkInput, NetworkGradient> = (NetworkInput, NetworkGradient); // epsilon, gradient
}
impl<LayerOutput, Target, Loss: NeuraLoss<LayerOutput, Target = Target>>
    NeuraOptimizerFinal<LayerOutput> for (&NeuraBackprop<Loss>, &Target)
    NeuraGradientSolverFinal<LayerOutput> for (&NeuraBackprop<Loss>, &Target)
{
    fn eval_final(&self, output: LayerOutput) -> Self::Output<LayerOutput, ()> {
        (self.0.loss.nabla(self.1, &output), ())
    }
}
impl<LayerOutput, Target, Loss> NeuraOptimizerTransient<LayerOutput>
impl<LayerOutput, Target, Loss> NeuraGradientSolverTransient<LayerOutput>
    for (&NeuraBackprop<Loss>, &Target)
{
    fn eval_layer<

@@ -0,0 +1,41 @@
mod backprop;
pub use backprop::NeuraBackprop;
use crate::{
    layer::NeuraTrainableLayer,
    network::{NeuraTrainableNetwork, NeuraTrainableNetworkBase},
};
pub trait NeuraGradientSolverBase {
    type Output<NetworkInput, NetworkGradient>;
}
pub trait NeuraGradientSolverFinal<LayerOutput>: NeuraGradientSolverBase {
    fn eval_final(&self, output: LayerOutput) -> Self::Output<LayerOutput, ()>;
}
pub trait NeuraGradientSolverTransient<LayerOutput>: NeuraGradientSolverBase {
    fn eval_layer<
        Input,
        NetworkGradient,
        RecGradient,
        Layer: NeuraTrainableLayer<Input, Output = LayerOutput>,
    >(
        &self,
        layer: &Layer,
        input: &Input,
        rec_opt_output: Self::Output<LayerOutput, RecGradient>,
        combine_gradients: impl Fn(Layer::Gradient, RecGradient) -> NetworkGradient,
    ) -> Self::Output<Input, NetworkGradient>;
}
pub trait NeuraGradientSolver<Input, Target, Trainable: NeuraTrainableNetworkBase<Input>> {
    fn get_gradient(
        &self,
        trainable: &Trainable,
        input: &Input,
        target: &Target,
    ) -> Trainable::Gradient;
    fn score(&self, trainable: &Trainable, input: &Input, target: &Target) -> f64;
}
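The core of the renamed module is the `Output<NetworkInput, NetworkGradient>` generic associated type on `NeuraGradientSolverBase`: each solver chooses what it threads back through the network while the traversal recurses over the layers, and the backprop impls above instantiate it as an `(epsilon, gradient)` pair that `eval_final` seeds at the tail and `eval_layer` extends at every layer. Below is a minimal standalone sketch of that pattern, not the crate's actual code: `SolverBase`, `ToyBackprop`, and the scalar `f64` "layers" are all invented for illustration, whereas the real traits are generic over the layer type, its `Gradient`, and the `Input`.

```rust
// Standalone sketch of the gradient-solver trait pattern from this commit.
// All names and types below (SolverBase, ToyBackprop, scalar f64 "layers")
// are made up for illustration; they are NOT the crate's real definitions.

trait SolverBase {
    // GAT: the solver decides what flows back through the network
    // (for the toy backprop below: an (epsilon, gradient) pair).
    type Output<NetworkInput, NetworkGradient>;
}

trait SolverFinal<LayerOutput>: SolverBase {
    // Called at the `()` tail of the network to seed the backward pass.
    fn eval_final(&self, output: LayerOutput) -> Self::Output<LayerOutput, ()>;
}

trait SolverTransient<LayerOutput>: SolverBase {
    // Called once per layer on the way back out of the recursion.
    fn eval_layer<NetworkGradient, RecGradient>(
        &self,
        weight: f64,
        input: f64,
        rec_output: Self::Output<LayerOutput, RecGradient>,
        combine_gradients: impl Fn(f64, RecGradient) -> NetworkGradient,
    ) -> Self::Output<f64, NetworkGradient>;
}

// Toy squared-error backprop over scalar layers y = w * x.
struct ToyBackprop {
    target: f64,
}

impl SolverBase for ToyBackprop {
    // epsilon (error signal at this point) plus the gradient collected so far.
    type Output<NetworkInput, NetworkGradient> = (NetworkInput, NetworkGradient);
}

impl SolverFinal<f64> for ToyBackprop {
    fn eval_final(&self, output: f64) -> Self::Output<f64, ()> {
        // d/dy of (y - target)^2, and an empty gradient for the empty tail.
        (2.0 * (output - self.target), ())
    }
}

impl SolverTransient<f64> for ToyBackprop {
    fn eval_layer<NetworkGradient, RecGradient>(
        &self,
        weight: f64,
        input: f64,
        rec_output: Self::Output<f64, RecGradient>,
        combine_gradients: impl Fn(f64, RecGradient) -> NetworkGradient,
    ) -> Self::Output<f64, NetworkGradient> {
        let (epsilon, rec_gradient) = rec_output;
        // For y = w * x: dL/dw = epsilon * x, and the epsilon handed to the
        // previous layer is dL/dx = epsilon * w.
        (
            epsilon * weight,
            combine_gradients(epsilon * input, rec_gradient),
        )
    }
}

fn main() {
    let solver = ToyBackprop { target: 1.0 };

    // Two scalar "layers": x -> w1 * x -> w2 * (w1 * x).
    let (w1, w2) = (0.5, 3.0);
    let x = 2.0;
    let h = w1 * x;
    let y = w2 * h;

    // The tail seeds the pass, then each layer prepends its own gradient.
    let tail = solver.eval_final(y);
    let after_layer2 = solver.eval_layer(w2, h, tail, |g, rec| (g, rec));
    let (epsilon_at_input, gradients) =
        solver.eval_layer(w1, x, after_layer2, |g, rec| (g, rec));

    println!("epsilon at the input: {}", epsilon_at_input);
    println!("per-layer gradients, input-side layer first: {:?}", gradients);
}
```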

@@ -5,7 +5,7 @@ pub mod algebra;
pub mod derivable;
pub mod layer;
pub mod network;
pub mod optimize;
pub mod gradient_solver;
pub mod train;
mod utils;
@@ -22,6 +22,6 @@ pub mod prelude {
    pub use crate::network::sequential::{
        NeuraSequential, NeuraSequentialConstruct, NeuraSequentialTail,
    };
    pub use crate::optimize::NeuraBackprop;
    pub use crate::gradient_solver::NeuraBackprop;
    pub use crate::train::NeuraBatchedTrainer;
}
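For code importing the solver directly, the observable effect of the rename is the changed module path; the prelude re-export above keeps glob imports working unchanged. A before/after sketch, assuming the library is consumed as a dependency named `neuramethyst` (the crate name is not shown in this diff):

```rust
// Before this commit:
// use neuramethyst::optimize::NeuraBackprop;

// After this commit:
use neuramethyst::gradient_solver::NeuraBackprop;

// Alternatively, the prelude still re-exports the solver:
// use neuramethyst::prelude::NeuraBackprop;
```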

@@ -1,4 +1,4 @@
use crate::{algebra::NeuraVectorSpace, layer::NeuraLayer, optimize::NeuraOptimizerBase};
use crate::{algebra::NeuraVectorSpace, layer::NeuraLayer, gradient_solver::NeuraGradientSolverBase};
pub mod sequential;
@@ -19,7 +19,7 @@ pub trait NeuraTrainableNetworkBase<Input>: NeuraLayer<Input> {
pub trait NeuraTrainableNetwork<Input, Optimizer>: NeuraTrainableNetworkBase<Input>
where
    Optimizer: NeuraOptimizerBase,
    Optimizer: NeuraGradientSolverBase,
{
    fn traverse(
        &self,

@@ -1,7 +1,7 @@
use super::{NeuraTrainableNetwork, NeuraTrainableNetworkBase};
use crate::{
    layer::{NeuraLayer, NeuraPartialLayer, NeuraShape, NeuraTrainableLayer},
    optimize::{NeuraOptimizerFinal, NeuraOptimizerTransient},
    gradient_solver::{NeuraGradientSolverFinal, NeuraGradientSolverTransient},
};
mod construct;
@@ -189,7 +189,7 @@ impl<Input: Clone> NeuraTrainableNetworkBase<Input> for () {
impl<
        Input,
        Layer: NeuraTrainableLayer<Input>,
        Optimizer: NeuraOptimizerTransient<Layer::Output>,
        Optimizer: NeuraGradientSolverTransient<Layer::Output>,
        ChildNetwork: NeuraTrainableNetworkBase<Layer::Output>,
    > NeuraTrainableNetwork<Input, Optimizer> for NeuraSequential<Layer, ChildNetwork>
where
@@ -212,7 +212,7 @@ where
}
}
impl<Input: Clone, Optimizer: NeuraOptimizerFinal<Input>> NeuraTrainableNetwork<Input, Optimizer>
impl<Input: Clone, Optimizer: NeuraGradientSolverFinal<Input>> NeuraTrainableNetwork<Input, Optimizer>
    for ()
{
    fn traverse(

@@ -1,5 +1,5 @@
use crate::{
    algebra::NeuraVectorSpace, network::NeuraTrainableNetworkBase, optimize::NeuraOptimizer,
    algebra::NeuraVectorSpace, network::NeuraTrainableNetworkBase, gradient_solver::NeuraGradientSolver,
};
#[non_exhaustive]
@@ -73,7 +73,7 @@ impl NeuraBatchedTrainer {
        Input: Clone,
        Target: Clone,
        Network: NeuraTrainableNetworkBase<Input>,
        GradientSolver: NeuraOptimizer<Input, Target, Network>,
        GradientSolver: NeuraGradientSolver<Input, Target, Network>,
        Inputs: IntoIterator<Item = (Input, Target)>,
    >(
        &self,
@@ -143,7 +143,7 @@ mod test {
        layer::{dense::NeuraDenseLayer, NeuraLayer},
        network::sequential::{NeuraSequential, NeuraSequentialTail},
        neura_sequential,
        optimize::NeuraBackprop,
        gradient_solver::NeuraBackprop,
    };
    #[test]
