🎨 Clean up axis operators, move them to crate root

main
Shad Amethyst · 2 years ago
parent fdc906c220
commit 99d0cb4408

@@ -1,46 +1,64 @@
 use std::borrow::Borrow;
+use std::fmt::Debug;
 
 use nalgebra::{Const, DVector, Dyn, Scalar, VecStorage};
 
-use crate::{err::NeuraAxisErr, prelude::NeuraShape};
+use crate::err::NeuraAxisErr;
+use crate::prelude::NeuraShape;
 
-// TODO: create a NeuraAxis trait
+pub trait NeuraAxisBase: Clone + Debug + 'static {
+    type Err: Debug;
 
-#[derive(Clone, Copy, Debug)]
-pub struct NeuraAxisAppend;
+    fn shape(&self, input_shapes: &[NeuraShape]) -> Result<NeuraShape, Self::Err>;
+}
 
-pub trait NeuraCombineInputs<T> {
-    type Combined;
+/// Axis operators take in a set of inputs, and combine them together into one output,
+/// which is then usually fed to a layer.
+pub trait NeuraAxis<Input>: NeuraAxisBase {
+    type Combined: 'static;
 
-    fn combine(&self, inputs: Vec<impl Borrow<T>>) -> Self::Combined;
-}
+    fn combine(&self, inputs: &[impl Borrow<Input>]) -> Self::Combined;
 
-pub trait NeuraSplitInputs<T>: NeuraCombineInputs<T> {
-    fn split(&self, combined: &Self::Combined, input_shapes: &[NeuraShape]) -> Vec<T>;
+    fn split(&self, combined: &Self::Combined, input_shapes: &[NeuraShape]) -> Vec<Input>;
 }
 
-impl<F: Clone> NeuraCombineInputs<DVector<F>> for NeuraAxisAppend {
-    type Combined = DVector<F>;
+/// An axis operator that expects a single input and passes it through unchanged.
+#[derive(Clone, Debug)]
+pub struct NeuraAxisDefault;
 
-    fn combine(&self, inputs: Vec<impl Borrow<DVector<F>>>) -> Self::Combined {
-        assert!(inputs.len() > 0);
-        let mut res = Vec::with_capacity(inputs.iter().map(|vec| vec.borrow().len()).sum());
-        for input in inputs {
-            for x in input.borrow().iter() {
-                res.push(x.clone());
-            }
-        }
+impl NeuraAxisBase for NeuraAxisDefault {
+    type Err = NeuraAxisErr;
 
-        DVector::from_data(VecStorage::new(Dyn(res.len()), Const as Const<1>, res))
+    fn shape(&self, inputs: &[NeuraShape]) -> Result<NeuraShape, NeuraAxisErr> {
+        if inputs.len() != 1 {
+            Err(NeuraAxisErr::InvalidAmount(inputs.len(), 1, Some(1)))
+        } else {
+            Ok(*inputs[0].borrow())
+        }
     }
 }
 
-// TODO: use another trait for combining NeuraShape, or make it another member of the trait
-impl NeuraCombineInputs<NeuraShape> for NeuraAxisAppend {
-    type Combined = Result<NeuraShape, NeuraAxisErr>;
+impl<Data: Clone + 'static> NeuraAxis<Data> for NeuraAxisDefault {
+    type Combined = Data;
 
-    fn combine(&self, inputs: Vec<impl Borrow<NeuraShape>>) -> Self::Combined {
-        let mut inputs = inputs.into_iter().map(|x| *x.borrow());
+    fn combine(&self, inputs: &[impl Borrow<Data>]) -> Self::Combined {
+        assert!(inputs.len() == 1);
+        inputs[0].borrow().clone()
+    }
+
+    fn split(&self, combined: &Self::Combined, _input_shapes: &[NeuraShape]) -> Vec<Data> {
+        vec![combined.clone()]
+    }
+}
+
+#[derive(Clone, Copy, Debug)]
+pub struct NeuraAxisAppend;
+
+impl NeuraAxisBase for NeuraAxisAppend {
+    type Err = NeuraAxisErr;
+
+    fn shape(&self, inputs: &[NeuraShape]) -> Result<NeuraShape, NeuraAxisErr> {
+        let mut inputs = inputs.into_iter().map(|x| *x.borrow());
         if let Some(mut res) = inputs.next() {
             for operand in inputs {
@@ -60,7 +78,22 @@ impl NeuraCombineInputs<NeuraShape> for NeuraAxisAppend {
     }
 }
 
-impl<F: Clone + Scalar + Default> NeuraSplitInputs<DVector<F>> for NeuraAxisAppend {
+impl<F: Clone + Default + Scalar> NeuraAxis<DVector<F>> for NeuraAxisAppend {
+    type Combined = DVector<F>;
+
+    fn combine(&self, inputs: &[impl Borrow<DVector<F>>]) -> Self::Combined {
+        assert!(inputs.len() > 0);
+        let mut res = Vec::with_capacity(inputs.iter().map(|vec| vec.borrow().len()).sum());
+        for input in inputs {
+            for x in input.borrow().iter() {
+                res.push(x.clone());
+            }
+        }
+
+        DVector::from_data(VecStorage::new(Dyn(res.len()), Const as Const<1>, res))
+    }
+
     fn split(&self, combined: &Self::Combined, input_shapes: &[NeuraShape]) -> Vec<DVector<F>> {
         let mut result = Vec::with_capacity(input_shapes.len());
         let mut offset = 0;
@@ -83,37 +116,3 @@ impl<F: Clone + Scalar + Default> NeuraSplitInputs<DVector<F>> for NeuraAxisAppend {
         result
     }
 }
-
-#[derive(Clone, Debug)]
-pub struct NeuraAxisDefault;
-
-impl<F: Clone> NeuraCombineInputs<DVector<F>> for NeuraAxisDefault {
-    type Combined = DVector<F>;
-
-    fn combine(&self, inputs: Vec<impl Borrow<DVector<F>>>) -> Self::Combined {
-        assert!(inputs.len() == 1);
-        inputs[0].borrow().clone()
-    }
-}
-
-impl NeuraCombineInputs<NeuraShape> for NeuraAxisDefault {
-    type Combined = Result<NeuraShape, NeuraAxisErr>;
-
-    fn combine(&self, inputs: Vec<impl Borrow<NeuraShape>>) -> Self::Combined {
-        if inputs.len() != 1 {
-            Err(NeuraAxisErr::InvalidAmount(inputs.len(), 1, Some(1)))
-        } else {
-            Ok(*inputs[0].borrow())
-        }
-    }
-}
-
-impl<Data: Clone> NeuraSplitInputs<Data> for NeuraAxisDefault
-where
-    NeuraAxisDefault: NeuraCombineInputs<Data, Combined = Data>,
-{
-    fn split(&self, combined: &Self::Combined, _input_shapes: &[NeuraShape]) -> Vec<Data> {
-        vec![combined.clone()]
-    }
-}
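
For orientation, a minimal usage sketch of the new trait pair defined above. It assumes the
crate is consumed as neuramethyst and that NeuraShape has a Vector(usize) variant; neither
name is shown in this diff, so treat both as guesses rather than as the crate's API:

use nalgebra::DVector;
use neuramethyst::axis::{NeuraAxis, NeuraAxisAppend, NeuraAxisBase};
use neuramethyst::prelude::NeuraShape;

fn main() {
    let axis = NeuraAxisAppend;

    // combine appends two vectors of lengths 2 and 3 into one of length 5.
    let a = DVector::from_vec(vec![1.0f64, 2.0]);
    let b = DVector::from_vec(vec![3.0, 4.0, 5.0]);
    let combined = axis.combine(&[a, b]);
    assert_eq!(combined.len(), 5);

    // split undoes combine, given the original input shapes.
    let shapes = [NeuraShape::Vector(2), NeuraShape::Vector(3)];
    let parts = axis.split(&combined, &shapes);
    assert_eq!(parts[0].len(), 2);
    assert_eq!(parts[1].len(), 3);

    // shape (from NeuraAxisBase) validates and computes the output shape up front.
    assert!(axis.shape(&shapes).is_ok());
}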

@@ -75,10 +75,10 @@ pub enum NeuraAxisErr {
 }
 
 #[derive(Clone, Debug)]
-pub enum NeuraResidualConstructErr<LayerErr> {
+pub enum NeuraResidualConstructErr<LayerErr, AxisErr> {
     Layer(LayerErr),
     WrongConnection(isize),
-    AxisErr(NeuraAxisErr),
+    AxisErr(AxisErr),
     NoOutput,
 }
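
The extra AxisErr parameter surfaces each axis's own error type instead of hard-wiring
NeuraAxisErr. A sketch of matching on the generalized enum, using the variants listed above
(the describe function and the String layer-error type are illustrative, not part of the crate):

use neuramethyst::err::{NeuraAxisErr, NeuraResidualConstructErr};

fn describe(err: &NeuraResidualConstructErr<String, NeuraAxisErr>) -> String {
    match err {
        NeuraResidualConstructErr::Layer(e) => format!("layer error: {}", e),
        NeuraResidualConstructErr::WrongConnection(offset) => {
            format!("wrong connection offset: {}", offset)
        }
        NeuraResidualConstructErr::AxisErr(e) => format!("axis error: {:?}", e),
        NeuraResidualConstructErr::NoOutput => String::from("no output"),
    }
}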

@@ -1,4 +1,5 @@
 pub mod algebra;
+pub mod axis;
 pub mod derivable;
 pub mod err;
 pub mod gradient_solver;

@@ -1,7 +1,4 @@
-use crate::network::{
-    residual::{NeuraAxisDefault, NeuraSplitInputs},
-    sequential::NeuraSequentialLast,
-};
+use crate::network::sequential::NeuraSequentialLast;
 
 use super::*;
@@ -32,7 +29,6 @@ impl<Data: Clone + 'static, Layer: NeuraLayer<Data, Output = Data>, ChildNetwork
     FromSequential<NeuraSequential<Layer, ChildNetwork>, Data> for NeuraGraph<Data>
 where
     NeuraGraph<Data>: FromSequential<ChildNetwork, Data>,
-    NeuraAxisDefault: NeuraSplitInputs<Data, Combined = Data>,
     Layer::IntermediaryRepr: 'static,
 {
     fn from_sequential_rec(

@@ -13,11 +13,6 @@ pub use partial::NeuraGraphPartial;
 mod from;
 pub use from::FromSequential;
 
-#[deprecated]
-pub trait NeuraTrainableLayerFull<Input>: NeuraLayer<Input> {}
-
-impl<Input, T> NeuraTrainableLayerFull<Input> for T where T: NeuraLayer<Input> {}
-
 #[derive(Debug)]
 pub struct NeuraGraphNodeConstructed<Data> {
     node: Box<dyn NeuraGraphNodeEval<Data>>,
@@ -202,7 +197,7 @@ impl<Data: Clone + std::fmt::Debug + 'static> NeuraLayer<Data> for NeuraGraph<Data>
 #[cfg(test)]
 mod test {
-    use crate::{err::NeuraGraphErr, network::residual::NeuraAxisAppend, utils::uniform_vector};
+    use crate::{axis::NeuraAxisAppend, err::NeuraGraphErr, utils::uniform_vector};
 
     use super::*;

@@ -3,14 +3,12 @@ use std::{any::Any, fmt::Debug};
 use crate::{
     algebra::NeuraDynVectorSpace,
-    err::NeuraAxisErr,
-    network::residual::{NeuraAxisDefault, NeuraCombineInputs, NeuraSplitInputs},
+    axis::{NeuraAxis, NeuraAxisDefault},
     prelude::{NeuraPartialLayer, NeuraShape},
 };
 
 use super::*;
 
-// TODO: split into two traits
 pub trait NeuraGraphNodePartial<Data>: DynClone + Debug {
     fn inputs<'a>(&'a self) -> &'a [String];
     fn name<'a>(&'a self) -> &'a str;
@@ -73,19 +71,11 @@ impl<Axis, Layer> NeuraGraphNode<Axis, Layer> {
     pub fn as_boxed<Data: Clone>(self) -> Box<dyn NeuraGraphNodePartial<Data>>
     where
-        Axis: NeuraSplitInputs<Data>
-            + NeuraCombineInputs<NeuraShape, Combined = Result<NeuraShape, NeuraAxisErr>>
-            + Clone
-            + Debug
-            + 'static,
+        Axis: NeuraAxis<Data>,
         Layer: NeuraPartialLayer + Clone + Debug + 'static,
-        Layer::Constructed:
-            NeuraLayer<<Axis as NeuraCombineInputs<Data>>::Combined, Output = Data>,
+        Layer::Constructed: NeuraLayer<Axis::Combined, Output = Data>,
         Layer::Err: Debug,
-        <Layer::Constructed as NeuraLayer<
-            <Axis as NeuraCombineInputs<Data>>::Combined,
-        >>::IntermediaryRepr: 'static,
-        <Axis as NeuraCombineInputs<Data>>::Combined: 'static,
+        <Layer::Constructed as NeuraLayer<Axis::Combined>>::IntermediaryRepr: 'static,
     {
         Box::new(self)
     }
@@ -95,9 +85,8 @@ impl<Axis, Layer> NeuraGraphNode<Axis, Layer> {
         intermediary: &'a dyn Any,
     ) -> &'a Intermediary<Axis::Combined, Layer>
     where
-        Axis: NeuraCombineInputs<Data>,
+        Axis: NeuraAxis<Data>,
         Layer: NeuraLayer<Axis::Combined>,
-        Axis::Combined: 'static,
     {
         intermediary
             .downcast_ref::<Intermediary<Axis::Combined, Layer>>()
@@ -113,23 +102,16 @@ where
     layer_intermediary: Layer::IntermediaryRepr,
 }
 
-impl<
-        Data: Clone,
-        Axis: NeuraSplitInputs<Data> + Clone + Debug,
-        Layer: NeuraLayer<<Axis as NeuraCombineInputs<Data>>::Combined, Output = Data>,
-    > NeuraGraphNodeEval<Data> for NeuraGraphNode<Axis, Layer>
-where
-    Layer::IntermediaryRepr: 'static,
-    Axis::Combined: 'static,
+impl<Data: Clone, Axis: NeuraAxis<Data>, Layer: NeuraLayer<Axis::Combined, Output = Data>>
+    NeuraGraphNodeEval<Data> for NeuraGraphNode<Axis, Layer>
 {
     fn eval<'a>(&'a self, inputs: &[Data]) -> Data {
-        // TODO: use to_vec_in?
-        let combined = self.axis.combine(inputs.to_vec());
+        let combined = self.axis.combine(inputs);
         self.layer.eval(&combined)
     }
 
     fn eval_training<'a>(&self, inputs: &[Data]) -> (Data, Box<dyn Any>) {
-        let combined = self.axis.combine(inputs.to_vec());
+        let combined = self.axis.combine(inputs);
         let (result, layer_intermediary) = self.layer.eval_training(&combined);
 
         let intermediary: Intermediary<Axis::Combined, Layer> = Intermediary {
@@ -177,20 +159,11 @@ where
     }
 }
 
-impl<
-        Data: Clone,
-        Axis: NeuraSplitInputs<Data>
-            + NeuraCombineInputs<NeuraShape, Combined = Result<NeuraShape, NeuraAxisErr>>
-            + Clone
-            + Debug
-            + 'static,
-        Layer: NeuraPartialLayer + Clone + Debug,
-    > NeuraGraphNodePartial<Data> for NeuraGraphNode<Axis, Layer>
+impl<Data: Clone, Axis: NeuraAxis<Data>, Layer: NeuraPartialLayer + Clone + Debug>
+    NeuraGraphNodePartial<Data> for NeuraGraphNode<Axis, Layer>
 where
-    Layer::Constructed: NeuraLayer<<Axis as NeuraCombineInputs<Data>>::Combined, Output = Data>,
+    Layer::Constructed: NeuraLayer<Axis::Combined, Output = Data>,
     Layer::Err: Debug,
-    <Layer::Constructed as NeuraLayer<<Axis as NeuraCombineInputs<Data>>::Combined>>::IntermediaryRepr: 'static,
-    <Axis as NeuraCombineInputs<Data>>::Combined: 'static,
 {
     fn inputs<'a>(&'a self) -> &'a [String] {
         &self.inputs
@@ -206,7 +179,7 @@ where
     ) -> Result<(Box<dyn NeuraGraphNodeEval<Data>>, NeuraShape), String> {
         let combined = self
             .axis
-            .combine(input_shapes.clone())
+            .shape(&input_shapes)
             .map_err(|err| format!("{:?}", err))?;
 
         let constructed_layer = self
@@ -222,7 +195,7 @@ where
                 axis: self.axis.clone(),
                 layer: constructed_layer,
                 name: self.name.clone(),
-                input_shapes: Some(input_shapes)
+                input_shapes: Some(input_shapes),
             }),
             output_shape,
         ))
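
The where-clauses above collapse because Clone, Debug and 'static now come bundled with
NeuraAxisBase, and Combined: 'static with NeuraAxis. A sketch of what a custom axis looks
like under the new pair; NeuraAxisSum is hypothetical, and the InvalidAmount argument order
is inferred from the NeuraAxisDefault impl earlier in this commit:

use std::borrow::Borrow;

use nalgebra::DVector;
use neuramethyst::axis::{NeuraAxis, NeuraAxisBase};
use neuramethyst::err::NeuraAxisErr;
use neuramethyst::prelude::NeuraShape;

// Hypothetical axis that sums exactly two equally-shaped inputs.
#[derive(Clone, Debug)]
pub struct NeuraAxisSum;

impl NeuraAxisBase for NeuraAxisSum {
    type Err = NeuraAxisErr;

    fn shape(&self, inputs: &[NeuraShape]) -> Result<NeuraShape, NeuraAxisErr> {
        if inputs.len() != 2 {
            // Assumed to mean (got, min, max), mirroring NeuraAxisDefault.
            Err(NeuraAxisErr::InvalidAmount(inputs.len(), 2, Some(2)))
        } else {
            // This sketch does not check that the two shapes actually match.
            Ok(inputs[0])
        }
    }
}

impl NeuraAxis<DVector<f64>> for NeuraAxisSum {
    type Combined = DVector<f64>;

    fn combine(&self, inputs: &[impl Borrow<DVector<f64>>]) -> Self::Combined {
        assert!(inputs.len() == 2);
        inputs[0].borrow() + inputs[1].borrow()
    }

    fn split(&self, combined: &Self::Combined, input_shapes: &[NeuraShape]) -> Vec<DVector<f64>> {
        // For a sum, each operand receives the combined value back unchanged.
        vec![combined.clone(); input_shapes.len()]
    }
}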

@@ -1,4 +1,4 @@
-use crate::err::*;
+use crate::{axis::NeuraAxisBase, err::*};
 
 use super::*;
 use NeuraResidualConstructErr::*;
@@ -15,13 +15,12 @@ pub trait NeuraResidualConstruct {
     ) -> Result<Self::Constructed, Self::Err>;
 }
 
-impl<Layer: NeuraPartialLayer, ChildNetwork: NeuraResidualConstruct, Axis> NeuraResidualConstruct
-    for NeuraResidualNode<Layer, ChildNetwork, Axis>
-where
-    Axis: NeuraCombineInputs<NeuraShape, Combined = Result<NeuraShape, NeuraAxisErr>>,
+impl<Layer: NeuraPartialLayer, ChildNetwork: NeuraResidualConstruct, Axis: NeuraAxisBase>
+    NeuraResidualConstruct for NeuraResidualNode<Layer, ChildNetwork, Axis>
 {
     type Constructed = NeuraResidualNode<Layer::Constructed, ChildNetwork::Constructed, Axis>;
-    type Err = NeuraRecursiveErr<NeuraResidualConstructErr<Layer::Err>, ChildNetwork::Err>;
+    type Err =
+        NeuraRecursiveErr<NeuraResidualConstructErr<Layer::Err, Axis::Err>, ChildNetwork::Err>;
 
     fn construct_residual(
         self,
@@ -36,7 +35,7 @@ where
         let layer_input_shape = self
             .axis
-            .combine(input_shapes)
+            .shape(&input_shapes.iter().map(|shape| **shape).collect::<Vec<_>>())
            .map_err(|e| NeuraRecursiveErr::Current(AxisErr(e)))?;
 
         let layer = self
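
One detail in the hunk above: input_shapes is evidently a collection of &NeuraShape here,
hence the **shape deref-copy into an owned Vec before calling shape, which takes
&[NeuraShape]. A standalone sketch of the same adapter pattern (the combined_shape helper
is illustrative only):

use neuramethyst::axis::NeuraAxisBase;
use neuramethyst::prelude::NeuraShape;

fn combined_shape<Axis: NeuraAxisBase>(
    axis: &Axis,
    input_shapes: &[&NeuraShape],
) -> Result<NeuraShape, Axis::Err> {
    // Deref-copy the borrowed shapes into owned ones, then delegate.
    let owned: Vec<NeuraShape> = input_shapes.iter().map(|shape| **shape).collect();
    axis.shape(&owned)
}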

@@ -31,7 +31,7 @@ impl Default for NeuraResidualLast {
 impl NeuraResidualConstruct for NeuraResidualLast {
     type Constructed = NeuraResidualLast;
-    type Err = NeuraRecursiveErr<NeuraResidualConstructErr<()>, ()>;
+    type Err = NeuraRecursiveErr<NeuraResidualConstructErr<(), NeuraAxisErr>, ()>;
 
     fn construct_residual(
         self,

@@ -8,9 +8,6 @@ pub use wrapper::*;
 mod input;
 pub use input::*;
 
-mod axis;
-pub use axis::*;
-
 mod construct;
 pub use construct::NeuraResidualConstruct;

@@ -1,6 +1,6 @@
 use std::borrow::Cow;
 
-use crate::network::*;
+use crate::{axis::*, network::*};
 
 use super::*;
@@ -68,12 +68,12 @@ impl<Layer, ChildNetwork, Axis> NeuraResidualNode<Layer, ChildNetwork, Axis> {
         input: &NeuraResidualInput<Data>,
     ) -> (Axis::Combined, NeuraResidualInput<Data>)
     where
-        Axis: NeuraCombineInputs<Data>,
+        Axis: NeuraAxis<Data>,
         Layer: NeuraLayer<Axis::Combined>,
     {
         let (inputs, rest) = input.shift();
-        let layer_input = self.axis.combine(inputs);
+        let layer_input = self.axis.combine(&inputs);
 
         (layer_input, rest)
     }
@@ -94,9 +94,9 @@ impl<Layer, ChildNetwork, Axis> NeuraResidualNode<Layer, ChildNetwork, Axis> {
     pub(crate) fn map_input_owned<Data>(&self, input: &NeuraResidualInput<Data>) -> Axis::Combined
     where
-        Axis: NeuraCombineInputs<Data>,
+        Axis: NeuraAxis<Data>,
     {
-        self.axis.combine(input.shift().0)
+        self.axis.combine(&input.shift().0)
     }
 }
@@ -148,7 +148,7 @@ impl<
 impl<Data: Clone + 'static, Layer, ChildNetwork, Axis: Clone + std::fmt::Debug + 'static>
     NeuraLayer<NeuraResidualInput<Data>> for NeuraResidualNode<Layer, ChildNetwork, Axis>
 where
-    Axis: NeuraCombineInputs<Data>,
+    Axis: NeuraAxis<Data>,
     Layer: NeuraLayer<Axis::Combined, Output = Data>,
     ChildNetwork: NeuraLayer<NeuraResidualInput<Data>>,
 {
@@ -239,7 +239,7 @@ impl<
 impl<
     Data: Clone + std::fmt::Debug,
-    Axis: NeuraCombineInputs<Data> + NeuraSplitInputs<Data>,
+    Axis: NeuraAxis<Data>,
     Layer: NeuraLayer<Axis::Combined, Output = Data> + std::fmt::Debug,
     ChildNetwork,
 > NeuraNetwork<NeuraResidualInput<Data>> for NeuraResidualNode<Layer, ChildNetwork, Axis>
