|
|
|
@ -1,46 +1,64 @@
|
|
|
|
|
use std::borrow::Borrow;
|
|
|
|
|
use std::fmt::Debug;
|
|
|
|
|
|
|
|
|
|
use nalgebra::{Const, DVector, Dyn, Scalar, VecStorage};
|
|
|
|
|
|
|
|
|
|
use crate::{err::NeuraAxisErr, prelude::NeuraShape};
|
|
|
|
|
use crate::err::NeuraAxisErr;
|
|
|
|
|
use crate::prelude::NeuraShape;
|
|
|
|
|
|
|
|
|
|
// TODO: create a NeuraAxis trait
|
|
|
|
|
pub trait NeuraAxisBase: Clone + Debug + 'static {
|
|
|
|
|
type Err: Debug;
|
|
|
|
|
|
|
|
|
|
#[derive(Clone, Copy, Debug)]
|
|
|
|
|
pub struct NeuraAxisAppend;
|
|
|
|
|
fn shape(&self, input_shapes: &[NeuraShape]) -> Result<NeuraShape, Self::Err>;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub trait NeuraCombineInputs<T> {
|
|
|
|
|
type Combined;
|
|
|
|
|
/// Axis operators take in a set of inputs, and combine them together into one output,
|
|
|
|
|
/// which is then usually fed to a layer.
|
|
|
|
|
pub trait NeuraAxis<Input>: NeuraAxisBase {
|
|
|
|
|
type Combined: 'static;
|
|
|
|
|
|
|
|
|
|
fn combine(&self, inputs: Vec<impl Borrow<T>>) -> Self::Combined;
|
|
|
|
|
}
|
|
|
|
|
fn combine(&self, inputs: &[impl Borrow<Input>]) -> Self::Combined;
|
|
|
|
|
|
|
|
|
|
pub trait NeuraSplitInputs<T>: NeuraCombineInputs<T> {
|
|
|
|
|
fn split(&self, combined: &Self::Combined, input_shapes: &[NeuraShape]) -> Vec<T>;
|
|
|
|
|
fn split(&self, combined: &Self::Combined, input_shapes: &[NeuraShape]) -> Vec<Input>;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl<F: Clone> NeuraCombineInputs<DVector<F>> for NeuraAxisAppend {
|
|
|
|
|
type Combined = DVector<F>;
|
|
|
|
|
/// An axis operator that
|
|
|
|
|
#[derive(Clone, Debug)]
|
|
|
|
|
pub struct NeuraAxisDefault;
|
|
|
|
|
|
|
|
|
|
fn combine(&self, inputs: Vec<impl Borrow<DVector<F>>>) -> Self::Combined {
|
|
|
|
|
assert!(inputs.len() > 0);
|
|
|
|
|
let mut res = Vec::with_capacity(inputs.iter().map(|vec| vec.borrow().len()).sum());
|
|
|
|
|
impl NeuraAxisBase for NeuraAxisDefault {
|
|
|
|
|
type Err = NeuraAxisErr;
|
|
|
|
|
|
|
|
|
|
for input in inputs {
|
|
|
|
|
for x in input.borrow().iter() {
|
|
|
|
|
res.push(x.clone());
|
|
|
|
|
}
|
|
|
|
|
fn shape(&self, inputs: &[NeuraShape]) -> Result<NeuraShape, NeuraAxisErr> {
|
|
|
|
|
if inputs.len() != 1 {
|
|
|
|
|
Err(NeuraAxisErr::InvalidAmount(inputs.len(), 1, Some(1)))
|
|
|
|
|
} else {
|
|
|
|
|
Ok(*inputs[0].borrow())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
DVector::from_data(VecStorage::new(Dyn(res.len()), Const as Const<1>, res))
|
|
|
|
|
impl<Data: Clone + 'static> NeuraAxis<Data> for NeuraAxisDefault {
|
|
|
|
|
type Combined = Data;
|
|
|
|
|
|
|
|
|
|
fn combine(&self, inputs: &[impl Borrow<Data>]) -> Self::Combined {
|
|
|
|
|
assert!(inputs.len() == 1);
|
|
|
|
|
|
|
|
|
|
inputs[0].borrow().clone()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn split(&self, combined: &Self::Combined, _input_shapes: &[NeuraShape]) -> Vec<Data> {
|
|
|
|
|
vec![combined.clone()]
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// TODO: use another trait for combining NeuraShape, or make it another member of the trait
|
|
|
|
|
impl NeuraCombineInputs<NeuraShape> for NeuraAxisAppend {
|
|
|
|
|
type Combined = Result<NeuraShape, NeuraAxisErr>;
|
|
|
|
|
#[derive(Clone, Copy, Debug)]
|
|
|
|
|
pub struct NeuraAxisAppend;
|
|
|
|
|
|
|
|
|
|
impl NeuraAxisBase for NeuraAxisAppend {
|
|
|
|
|
type Err = NeuraAxisErr;
|
|
|
|
|
|
|
|
|
|
fn combine(&self, inputs: Vec<impl Borrow<NeuraShape>>) -> Self::Combined {
|
|
|
|
|
fn shape(&self, inputs: &[NeuraShape]) -> Result<NeuraShape, NeuraAxisErr> {
|
|
|
|
|
let mut inputs = inputs.into_iter().map(|x| *x.borrow());
|
|
|
|
|
if let Some(mut res) = inputs.next() {
|
|
|
|
|
for operand in inputs {
|
|
|
|
@ -60,7 +78,22 @@ impl NeuraCombineInputs<NeuraShape> for NeuraAxisAppend {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl<F: Clone + Scalar + Default> NeuraSplitInputs<DVector<F>> for NeuraAxisAppend {
|
|
|
|
|
impl<F: Clone + Default + Scalar> NeuraAxis<DVector<F>> for NeuraAxisAppend {
|
|
|
|
|
type Combined = DVector<F>;
|
|
|
|
|
|
|
|
|
|
fn combine(&self, inputs: &[impl Borrow<DVector<F>>]) -> Self::Combined {
|
|
|
|
|
assert!(inputs.len() > 0);
|
|
|
|
|
let mut res = Vec::with_capacity(inputs.iter().map(|vec| vec.borrow().len()).sum());
|
|
|
|
|
|
|
|
|
|
for input in inputs {
|
|
|
|
|
for x in input.borrow().iter() {
|
|
|
|
|
res.push(x.clone());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
DVector::from_data(VecStorage::new(Dyn(res.len()), Const as Const<1>, res))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn split(&self, combined: &Self::Combined, input_shapes: &[NeuraShape]) -> Vec<DVector<F>> {
|
|
|
|
|
let mut result = Vec::with_capacity(input_shapes.len());
|
|
|
|
|
let mut offset = 0;
|
|
|
|
@ -83,37 +116,3 @@ impl<F: Clone + Scalar + Default> NeuraSplitInputs<DVector<F>> for NeuraAxisAppe
|
|
|
|
|
result
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Clone, Debug)]
|
|
|
|
|
pub struct NeuraAxisDefault;
|
|
|
|
|
|
|
|
|
|
impl<F: Clone> NeuraCombineInputs<DVector<F>> for NeuraAxisDefault {
|
|
|
|
|
type Combined = DVector<F>;
|
|
|
|
|
|
|
|
|
|
fn combine(&self, inputs: Vec<impl Borrow<DVector<F>>>) -> Self::Combined {
|
|
|
|
|
assert!(inputs.len() == 1);
|
|
|
|
|
|
|
|
|
|
inputs[0].borrow().clone()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl NeuraCombineInputs<NeuraShape> for NeuraAxisDefault {
|
|
|
|
|
type Combined = Result<NeuraShape, NeuraAxisErr>;
|
|
|
|
|
|
|
|
|
|
fn combine(&self, inputs: Vec<impl Borrow<NeuraShape>>) -> Self::Combined {
|
|
|
|
|
if inputs.len() != 1 {
|
|
|
|
|
Err(NeuraAxisErr::InvalidAmount(inputs.len(), 1, Some(1)))
|
|
|
|
|
} else {
|
|
|
|
|
Ok(*inputs[0].borrow())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl<Data: Clone> NeuraSplitInputs<Data> for NeuraAxisDefault
|
|
|
|
|
where
|
|
|
|
|
NeuraAxisDefault: NeuraCombineInputs<Data, Combined = Data>,
|
|
|
|
|
{
|
|
|
|
|
fn split(&self, combined: &Self::Combined, _input_shapes: &[NeuraShape]) -> Vec<Data> {
|
|
|
|
|
vec![combined.clone()]
|
|
|
|
|
}
|
|
|
|
|
}
|