From 99d0cb44081615a1e17e160d0197f12436685653 Mon Sep 17 00:00:00 2001 From: Adrien Burgun Date: Tue, 9 May 2023 12:17:40 +0200 Subject: [PATCH] :art: Clean up axis operators, move them to crate root --- src/{network/residual => }/axis.rs | 117 ++++++++++++++--------------- src/err.rs | 4 +- src/lib.rs | 1 + src/network/graph/from.rs | 6 +- src/network/graph/mod.rs | 7 +- src/network/graph/node.rs | 55 ++++---------- src/network/residual/construct.rs | 13 ++-- src/network/residual/last.rs | 2 +- src/network/residual/mod.rs | 3 - src/network/residual/node.rs | 14 ++-- 10 files changed, 91 insertions(+), 131 deletions(-) rename src/{network/residual => }/axis.rs (64%) diff --git a/src/network/residual/axis.rs b/src/axis.rs similarity index 64% rename from src/network/residual/axis.rs rename to src/axis.rs index 9e34c8d..a264a40 100644 --- a/src/network/residual/axis.rs +++ b/src/axis.rs @@ -1,46 +1,64 @@ use std::borrow::Borrow; +use std::fmt::Debug; use nalgebra::{Const, DVector, Dyn, Scalar, VecStorage}; -use crate::{err::NeuraAxisErr, prelude::NeuraShape}; +use crate::err::NeuraAxisErr; +use crate::prelude::NeuraShape; -// TODO: create a NeuraAxis trait +pub trait NeuraAxisBase: Clone + Debug + 'static { + type Err: Debug; -#[derive(Clone, Copy, Debug)] -pub struct NeuraAxisAppend; + fn shape(&self, input_shapes: &[NeuraShape]) -> Result; +} -pub trait NeuraCombineInputs { - type Combined; +/// Axis operators take in a set of inputs, and combine them together into one output, +/// which is then usually fed to a layer. +pub trait NeuraAxis: NeuraAxisBase { + type Combined: 'static; - fn combine(&self, inputs: Vec>) -> Self::Combined; -} + fn combine(&self, inputs: &[impl Borrow]) -> Self::Combined; -pub trait NeuraSplitInputs: NeuraCombineInputs { - fn split(&self, combined: &Self::Combined, input_shapes: &[NeuraShape]) -> Vec; + fn split(&self, combined: &Self::Combined, input_shapes: &[NeuraShape]) -> Vec; } -impl NeuraCombineInputs> for NeuraAxisAppend { - type Combined = DVector; +/// An axis operator that +#[derive(Clone, Debug)] +pub struct NeuraAxisDefault; - fn combine(&self, inputs: Vec>>) -> Self::Combined { - assert!(inputs.len() > 0); - let mut res = Vec::with_capacity(inputs.iter().map(|vec| vec.borrow().len()).sum()); +impl NeuraAxisBase for NeuraAxisDefault { + type Err = NeuraAxisErr; - for input in inputs { - for x in input.borrow().iter() { - res.push(x.clone()); - } + fn shape(&self, inputs: &[NeuraShape]) -> Result { + if inputs.len() != 1 { + Err(NeuraAxisErr::InvalidAmount(inputs.len(), 1, Some(1))) + } else { + Ok(*inputs[0].borrow()) } + } +} - DVector::from_data(VecStorage::new(Dyn(res.len()), Const as Const<1>, res)) +impl NeuraAxis for NeuraAxisDefault { + type Combined = Data; + + fn combine(&self, inputs: &[impl Borrow]) -> Self::Combined { + assert!(inputs.len() == 1); + + inputs[0].borrow().clone() + } + + fn split(&self, combined: &Self::Combined, _input_shapes: &[NeuraShape]) -> Vec { + vec![combined.clone()] } } -// TODO: use another trait for combining NeuraShape, or make it another member of the trait -impl NeuraCombineInputs for NeuraAxisAppend { - type Combined = Result; +#[derive(Clone, Copy, Debug)] +pub struct NeuraAxisAppend; + +impl NeuraAxisBase for NeuraAxisAppend { + type Err = NeuraAxisErr; - fn combine(&self, inputs: Vec>) -> Self::Combined { + fn shape(&self, inputs: &[NeuraShape]) -> Result { let mut inputs = inputs.into_iter().map(|x| *x.borrow()); if let Some(mut res) = inputs.next() { for operand in inputs { @@ -60,7 +78,22 @@ impl NeuraCombineInputs for NeuraAxisAppend { } } -impl NeuraSplitInputs> for NeuraAxisAppend { +impl NeuraAxis> for NeuraAxisAppend { + type Combined = DVector; + + fn combine(&self, inputs: &[impl Borrow>]) -> Self::Combined { + assert!(inputs.len() > 0); + let mut res = Vec::with_capacity(inputs.iter().map(|vec| vec.borrow().len()).sum()); + + for input in inputs { + for x in input.borrow().iter() { + res.push(x.clone()); + } + } + + DVector::from_data(VecStorage::new(Dyn(res.len()), Const as Const<1>, res)) + } + fn split(&self, combined: &Self::Combined, input_shapes: &[NeuraShape]) -> Vec> { let mut result = Vec::with_capacity(input_shapes.len()); let mut offset = 0; @@ -83,37 +116,3 @@ impl NeuraSplitInputs> for NeuraAxisAppe result } } - -#[derive(Clone, Debug)] -pub struct NeuraAxisDefault; - -impl NeuraCombineInputs> for NeuraAxisDefault { - type Combined = DVector; - - fn combine(&self, inputs: Vec>>) -> Self::Combined { - assert!(inputs.len() == 1); - - inputs[0].borrow().clone() - } -} - -impl NeuraCombineInputs for NeuraAxisDefault { - type Combined = Result; - - fn combine(&self, inputs: Vec>) -> Self::Combined { - if inputs.len() != 1 { - Err(NeuraAxisErr::InvalidAmount(inputs.len(), 1, Some(1))) - } else { - Ok(*inputs[0].borrow()) - } - } -} - -impl NeuraSplitInputs for NeuraAxisDefault -where - NeuraAxisDefault: NeuraCombineInputs, -{ - fn split(&self, combined: &Self::Combined, _input_shapes: &[NeuraShape]) -> Vec { - vec![combined.clone()] - } -} diff --git a/src/err.rs b/src/err.rs index 91a204c..611b630 100644 --- a/src/err.rs +++ b/src/err.rs @@ -75,10 +75,10 @@ pub enum NeuraAxisErr { } #[derive(Clone, Debug)] -pub enum NeuraResidualConstructErr { +pub enum NeuraResidualConstructErr { Layer(LayerErr), WrongConnection(isize), - AxisErr(NeuraAxisErr), + AxisErr(AxisErr), NoOutput, } diff --git a/src/lib.rs b/src/lib.rs index 8b5885d..bf94f7d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ pub mod algebra; +pub mod axis; pub mod derivable; pub mod err; pub mod gradient_solver; diff --git a/src/network/graph/from.rs b/src/network/graph/from.rs index 4802590..2e2f05a 100644 --- a/src/network/graph/from.rs +++ b/src/network/graph/from.rs @@ -1,7 +1,4 @@ -use crate::network::{ - residual::{NeuraAxisDefault, NeuraSplitInputs}, - sequential::NeuraSequentialLast, -}; +use crate::network::sequential::NeuraSequentialLast; use super::*; @@ -32,7 +29,6 @@ impl, ChildNetwork FromSequential, Data> for NeuraGraph where NeuraGraph: FromSequential, - NeuraAxisDefault: NeuraSplitInputs, Layer::IntermediaryRepr: 'static, { fn from_sequential_rec( diff --git a/src/network/graph/mod.rs b/src/network/graph/mod.rs index 3713fd9..f307285 100644 --- a/src/network/graph/mod.rs +++ b/src/network/graph/mod.rs @@ -13,11 +13,6 @@ pub use partial::NeuraGraphPartial; mod from; pub use from::FromSequential; -#[deprecated] -pub trait NeuraTrainableLayerFull: NeuraLayer {} - -impl NeuraTrainableLayerFull for T where T: NeuraLayer {} - #[derive(Debug)] pub struct NeuraGraphNodeConstructed { node: Box>, @@ -202,7 +197,7 @@ impl NeuraLayer for NeuraGraph: DynClone + Debug { fn inputs<'a>(&'a self) -> &'a [String]; fn name<'a>(&'a self) -> &'a str; @@ -73,19 +71,11 @@ impl NeuraGraphNode { pub fn as_boxed(self) -> Box> where - Axis: NeuraSplitInputs - + NeuraCombineInputs> - + Clone - + Debug - + 'static, + Axis: NeuraAxis, Layer: NeuraPartialLayer + Clone + Debug + 'static, - Layer::Constructed: - NeuraLayer<>::Combined, Output = Data>, + Layer::Constructed: NeuraLayer, Layer::Err: Debug, - >::Combined, - >>::IntermediaryRepr: 'static, - >::Combined: 'static, + >::IntermediaryRepr: 'static, { Box::new(self) } @@ -95,9 +85,8 @@ impl NeuraGraphNode { intermediary: &'a dyn Any, ) -> &'a Intermediary where - Axis: NeuraCombineInputs, + Axis: NeuraAxis, Layer: NeuraLayer, - Axis::Combined: 'static, { intermediary .downcast_ref::>() @@ -113,23 +102,16 @@ where layer_intermediary: Layer::IntermediaryRepr, } -impl< - Data: Clone, - Axis: NeuraSplitInputs + Clone + Debug, - Layer: NeuraLayer<>::Combined, Output = Data>, - > NeuraGraphNodeEval for NeuraGraphNode -where - Layer::IntermediaryRepr: 'static, - Axis::Combined: 'static, +impl, Layer: NeuraLayer> + NeuraGraphNodeEval for NeuraGraphNode { fn eval<'a>(&'a self, inputs: &[Data]) -> Data { - // TODO: use to_vec_in? - let combined = self.axis.combine(inputs.to_vec()); + let combined = self.axis.combine(inputs); self.layer.eval(&combined) } fn eval_training<'a>(&self, inputs: &[Data]) -> (Data, Box) { - let combined = self.axis.combine(inputs.to_vec()); + let combined = self.axis.combine(inputs); let (result, layer_intermediary) = self.layer.eval_training(&combined); let intermediary: Intermediary = Intermediary { @@ -177,20 +159,11 @@ where } } -impl< - Data: Clone, - Axis: NeuraSplitInputs - + NeuraCombineInputs> - + Clone - + Debug - + 'static, - Layer: NeuraPartialLayer + Clone + Debug, - > NeuraGraphNodePartial for NeuraGraphNode +impl, Layer: NeuraPartialLayer + Clone + Debug> + NeuraGraphNodePartial for NeuraGraphNode where - Layer::Constructed: NeuraLayer<>::Combined, Output = Data>, + Layer::Constructed: NeuraLayer, Layer::Err: Debug, - >::Combined>>::IntermediaryRepr: 'static, - >::Combined: 'static, { fn inputs<'a>(&'a self) -> &'a [String] { &self.inputs @@ -206,7 +179,7 @@ where ) -> Result<(Box>, NeuraShape), String> { let combined = self .axis - .combine(input_shapes.clone()) + .shape(&input_shapes) .map_err(|err| format!("{:?}", err))?; let constructed_layer = self @@ -222,7 +195,7 @@ where axis: self.axis.clone(), layer: constructed_layer, name: self.name.clone(), - input_shapes: Some(input_shapes) + input_shapes: Some(input_shapes), }), output_shape, )) diff --git a/src/network/residual/construct.rs b/src/network/residual/construct.rs index 5c1159c..6a7f305 100644 --- a/src/network/residual/construct.rs +++ b/src/network/residual/construct.rs @@ -1,4 +1,4 @@ -use crate::err::*; +use crate::{axis::NeuraAxisBase, err::*}; use super::*; use NeuraResidualConstructErr::*; @@ -15,13 +15,12 @@ pub trait NeuraResidualConstruct { ) -> Result; } -impl NeuraResidualConstruct - for NeuraResidualNode -where - Axis: NeuraCombineInputs>, +impl + NeuraResidualConstruct for NeuraResidualNode { type Constructed = NeuraResidualNode; - type Err = NeuraRecursiveErr, ChildNetwork::Err>; + type Err = + NeuraRecursiveErr, ChildNetwork::Err>; fn construct_residual( self, @@ -36,7 +35,7 @@ where let layer_input_shape = self .axis - .combine(input_shapes) + .shape(&input_shapes.iter().map(|shape| **shape).collect::>()) .map_err(|e| NeuraRecursiveErr::Current(AxisErr(e)))?; let layer = self diff --git a/src/network/residual/last.rs b/src/network/residual/last.rs index 85619a3..6347138 100644 --- a/src/network/residual/last.rs +++ b/src/network/residual/last.rs @@ -31,7 +31,7 @@ impl Default for NeuraResidualLast { impl NeuraResidualConstruct for NeuraResidualLast { type Constructed = NeuraResidualLast; - type Err = NeuraRecursiveErr, ()>; + type Err = NeuraRecursiveErr, ()>; fn construct_residual( self, diff --git a/src/network/residual/mod.rs b/src/network/residual/mod.rs index f88ed24..0e26a4f 100644 --- a/src/network/residual/mod.rs +++ b/src/network/residual/mod.rs @@ -8,9 +8,6 @@ pub use wrapper::*; mod input; pub use input::*; -mod axis; -pub use axis::*; - mod construct; pub use construct::NeuraResidualConstruct; diff --git a/src/network/residual/node.rs b/src/network/residual/node.rs index cfc99aa..4c7edd7 100644 --- a/src/network/residual/node.rs +++ b/src/network/residual/node.rs @@ -1,6 +1,6 @@ use std::borrow::Cow; -use crate::network::*; +use crate::{axis::*, network::*}; use super::*; @@ -68,12 +68,12 @@ impl NeuraResidualNode { input: &NeuraResidualInput, ) -> (Axis::Combined, NeuraResidualInput) where - Axis: NeuraCombineInputs, + Axis: NeuraAxis, Layer: NeuraLayer, { let (inputs, rest) = input.shift(); - let layer_input = self.axis.combine(inputs); + let layer_input = self.axis.combine(&inputs); (layer_input, rest) } @@ -94,9 +94,9 @@ impl NeuraResidualNode { pub(crate) fn map_input_owned(&self, input: &NeuraResidualInput) -> Axis::Combined where - Axis: NeuraCombineInputs, + Axis: NeuraAxis, { - self.axis.combine(input.shift().0) + self.axis.combine(&input.shift().0) } } @@ -148,7 +148,7 @@ impl< impl NeuraLayer> for NeuraResidualNode where - Axis: NeuraCombineInputs, + Axis: NeuraAxis, Layer: NeuraLayer, ChildNetwork: NeuraLayer>, { @@ -239,7 +239,7 @@ impl< impl< Data: Clone + std::fmt::Debug, - Axis: NeuraCombineInputs + NeuraSplitInputs, + Axis: NeuraAxis, Layer: NeuraLayer + std::fmt::Debug, ChildNetwork, > NeuraNetwork> for NeuraResidualNode