Implement From<NeuraSequential> and NeuraLayer for NeuraGraph

main
Shad Amethyst 2 years ago
parent efaed91f83
commit 251e4d02d2

@ -71,6 +71,7 @@ pub enum NeuraIsolateLayerErr {
pub enum NeuraAxisErr { pub enum NeuraAxisErr {
NoInput, NoInput,
ConflictingShape(NeuraShape, NeuraShape), ConflictingShape(NeuraShape, NeuraShape),
InvalidAmount(usize, usize, Option<usize>),
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]

@ -0,0 +1,61 @@
use crate::network::residual::{NeuraAxisDefault, NeuraSplitInputs};
use super::*;
trait FromSequential<Seq, Data> {
fn from_sequential(
seq: &Seq,
nodes: Vec<NeuraGraphNodeConstructed<Data>>,
output_shape: NeuraShape,
) -> Self;
}
impl<Data> FromSequential<(), Data> for NeuraGraph<Data> {
fn from_sequential(
_seq: &(),
nodes: Vec<NeuraGraphNodeConstructed<Data>>,
output_shape: NeuraShape,
) -> Self {
Self {
output_index: nodes.len(),
buffer_size: nodes.len() + 1,
nodes: nodes,
output_shape,
}
}
}
impl<
Data: Clone,
Layer: NeuraLayer<Data, Output = Data> + Clone + std::fmt::Debug + 'static,
ChildNetwork,
> FromSequential<NeuraSequential<Layer, ChildNetwork>, Data> for NeuraGraph<Data>
where
NeuraGraph<Data>: FromSequential<ChildNetwork, Data>,
NeuraAxisDefault: NeuraSplitInputs<Data, Combined = Data>,
{
fn from_sequential(
seq: &NeuraSequential<Layer, ChildNetwork>,
mut nodes: Vec<NeuraGraphNodeConstructed<Data>>,
output_shape: NeuraShape,
) -> Self {
nodes.push(NeuraGraphNodeConstructed {
node: Box::new(NeuraGraphNode::from(seq.layer.clone())),
inputs: vec![nodes.len()],
output: nodes.len() + 1,
});
Self::from_sequential(&seq.child_network, nodes, output_shape)
}
}
impl<Data, Layer, ChildNetwork> From<NeuraSequential<Layer, ChildNetwork>> for NeuraGraph<Data>
where
NeuraGraph<Data>: FromSequential<NeuraSequential<Layer, ChildNetwork>, Data>,
NeuraSequential<Layer, ChildNetwork>: NeuraShapedLayer,
{
fn from(network: NeuraSequential<Layer, ChildNetwork>) -> Self {
let output_shape = network.output_shape();
Self::from_sequential(&network, vec![], output_shape)
}
}

@ -1,131 +1,12 @@
#![allow(dead_code)] // TODO: remove this use crate::{layer::NeuraShapedLayer, prelude::*};
use std::collections::{HashMap, HashSet, VecDeque};
use crate::prelude::*;
use crate::{err::NeuraGraphErr, layer::NeuraShapedLayer};
mod node; mod node;
pub use node::*; pub use node::*;
pub struct NeuraGraphPartial<Data> { mod partial;
pub nodes: Vec<Box<dyn NeuraGraphNodePartial<Data>>>, pub use partial::NeuraGraphPartial;
pub output: String,
pub input: String,
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)]
enum GraphIndex {
Input,
Node(usize),
}
impl<Data> NeuraGraphPartial<Data> {
fn get_index_map(&self) -> Result<HashMap<String, GraphIndex>, NeuraGraphErr> {
let mut result = HashMap::with_capacity(self.nodes.len());
result.insert(self.input.clone(), GraphIndex::Input);
for (index, node) in self.nodes.iter().enumerate() {
if result.contains_key(node.name()) {
return Err(NeuraGraphErr::InvalidName(node.name().to_string()));
}
result.insert(node.name().to_string(), GraphIndex::Node(index));
}
Ok(result)
}
fn get_reverse_graph(
&self,
index_map: &HashMap<String, GraphIndex>,
) -> Result<HashMap<GraphIndex, HashSet<GraphIndex>>, NeuraGraphErr> {
let mut result = HashMap::new();
result.insert(GraphIndex::Input, HashSet::new());
for i in 0..self.nodes.len() {
result.insert(GraphIndex::Node(i), HashSet::new());
}
for (index, node) in self.nodes.iter().enumerate() {
for input in node.inputs() {
let input_index = index_map
.get(input)
.copied()
.ok_or_else(|| NeuraGraphErr::MissingNode(input.clone()))?;
result
.get_mut(&input_index)
.expect("index_map returned invalid values")
.insert(GraphIndex::Node(index));
}
}
Ok(result)
}
fn get_node_order(
&self,
index_map: &HashMap<String, GraphIndex>,
reverse_graph: &HashMap<GraphIndex, HashSet<GraphIndex>>,
) -> Result<Vec<usize>, NeuraGraphErr> {
let mut result: Vec<usize> = Vec::new();
let mut closed: HashSet<GraphIndex> = HashSet::with_capacity(self.nodes.len());
let mut open = VecDeque::with_capacity(self.nodes.len());
open.push_front(GraphIndex::Input);
/*
index_map.get(&self.output)
.copied()
.ok_or_else(|| NeuraGraphErr::MissingNode(self.output.clone()))?
*/
while let Some(current) = open.pop_back() {
if closed.contains(&current) {
continue;
}
closed.insert(current);
// Do not put 0 (the input) in result
if let GraphIndex::Node(index) = current {
result.push(index);
}
println!("{:?}", current);
for next_node in reverse_graph[&current].iter().copied() {
// Ignore nodes that are already in the closed set
if closed.contains(&next_node) {
continue;
}
let GraphIndex::Node(node_index) = next_node else {
panic!("Unreachable: cannot have GraphIndex::Input as the output of a node");
};
let inputs = self.nodes[node_index].inputs();
// Only consider nodes whose inputs are in the closed set (meaning they would be ready to be evaluated)
if !inputs
.iter()
.all(|input| closed.contains(&index_map[input]))
{
continue;
}
open.push_front(next_node);
}
}
if result.len() != self.nodes.len() { mod from;
// TODO: verify that if result.len() != self.nodes.len(), then there is a cyclic subgraph
return Err(NeuraGraphErr::Cyclic);
}
Ok(result)
}
}
#[derive(Debug)] #[derive(Debug)]
struct NeuraGraphNodeConstructed<Data> { struct NeuraGraphNodeConstructed<Data> {
@ -143,7 +24,7 @@ pub struct NeuraGraph<Data> {
/// - `nodes[0].inputs = [0]` /// - `nodes[0].inputs = [0]`
nodes: Vec<NeuraGraphNodeConstructed<Data>>, nodes: Vec<NeuraGraphNodeConstructed<Data>>,
input_shape: NeuraShape, // input_shape: NeuraShape,
output_shape: NeuraShape, output_shape: NeuraShape,
output_index: usize, output_index: usize,
@ -156,82 +37,57 @@ impl<Data> NeuraShapedLayer for NeuraGraph<Data> {
} }
} }
impl<Data> NeuraPartialLayer for NeuraGraphPartial<Data> { impl<Data> NeuraGraph<Data> {
type Constructed = NeuraGraph<Data>; fn create_buffer(&self) -> Vec<Option<Data>> {
let mut res = Vec::with_capacity(self.buffer_size);
for _ in 0..self.buffer_size {
res.push(None);
}
type Err = NeuraGraphErr; res
}
fn construct(self, input_shape: NeuraShape) -> Result<Self::Constructed, Self::Err> { fn eval_in(&self, input: &Data, buffer: &mut Vec<Option<Data>>)
let index_map = self.get_index_map()?; where
let reverse_graph = self.get_reverse_graph(&index_map)?; Data: Clone,
{
buffer[0] = Some(input.clone());
// List out the nodes in their execution order for node in self.nodes.iter() {
let node_order = self.get_node_order(&index_map, &reverse_graph)?; // PERF: re-use the allocation for `inputs`, and `.take()` the elements only needed once?
let mut new_index_map: HashMap<String, usize> = HashMap::from_iter( let inputs: Vec<_> = node
node_order .inputs
.iter() .iter()
.map(|&i| (self.nodes[i].name().to_string(), i)), .map(|&i| {
); buffer[i]
new_index_map.insert(self.input.clone(), 0); .clone()
.expect("Unreachable: output of previous layer was not set")
// TODO: filter out the nodes that are not necessary for computing the result (BFS from the output node back to the inputs) })
// A temporary solution can be to trim the graph .collect();
let output_index = new_index_map let result = node.node.eval(&inputs);
.get(&self.output) buffer[node.output] = Some(result);
.copied()
.ok_or_else(|| NeuraGraphErr::MissingNode(self.output.clone()))?;
let mut nodes = Vec::with_capacity(self.nodes.len());
let mut shapes: Vec<Option<NeuraShape>> = vec![None; self.nodes.len() + 1];
shapes[0] = Some(input_shape);
for index in node_order.into_iter() {
let node = &*self.nodes[index];
let node_inputs = node.inputs();
let mut inputs = Vec::with_capacity(node_inputs.len());
let mut input_shapes = Vec::with_capacity(node_inputs.len());
for input in node_inputs {
let input_index = new_index_map.get(input).copied().expect(
"Unreachable: new_index_map should contain all nodes defined and all nodes should have existing nodes as input"
);
inputs.push(input_index);
input_shapes.push(shapes[input_index].expect(
"Unreachable: the order of execution should guarantee that all inputs have appeared before")
);
} }
}
}
let (constructed, output_shape) = node impl<Data: Clone> NeuraLayer<Data> for NeuraGraph<Data> {
.construct(input_shapes) type Output = Data;
.map_err(|e| NeuraGraphErr::LayerErr(e))?;
shapes[index] = Some(output_shape);
nodes.push(NeuraGraphNodeConstructed { fn eval(&self, input: &Data) -> Self::Output {
node: constructed, let mut buffer = self.create_buffer();
inputs,
output: new_index_map
.get(node.name())
.copied()
.unwrap_or_else(|| unreachable!()),
});
}
let output_shape = shapes[output_index].unwrap_or_else(|| unreachable!()); self.eval_in(input, &mut buffer);
Ok(NeuraGraph { buffer[self.output_index]
nodes, .take()
input_shape, .expect("Unreachable: output was not set")
output_shape,
output_index,
buffer_size: self.nodes.len() + 1,
})
} }
} }
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use crate::network::residual::NeuraAxisAppend; use crate::{err::NeuraGraphErr, network::residual::NeuraAxisAppend, utils::uniform_vector};
use super::*; use super::*;
@ -346,4 +202,27 @@ mod test {
NeuraGraphErr::MissingNode(String::from("missing")) NeuraGraphErr::MissingNode(String::from("missing"))
); );
} }
#[test]
fn test_eval_equal_sequential() {
let network = neura_sequential![
neura_layer!("dense", 4, f64),
neura_layer!("dense", 2, f64),
neura_layer!("softmax")
]
.construct(NeuraShape::Vector(3))
.unwrap();
let graph = NeuraGraph::from(network.clone());
for _ in 0..10 {
let input = uniform_vector(3);
let seq_result = network.eval(&input);
let graph_result = graph.eval(&input);
assert_eq!(seq_result.shape(), graph_result.shape());
approx::assert_relative_eq!(seq_result[0], graph_result[0]);
approx::assert_relative_eq!(seq_result[1], graph_result[1]);
}
}
} }

@ -1,14 +1,15 @@
use dyn_clone::DynClone; use dyn_clone::DynClone;
use std::fmt::Debug;
use crate::{ use crate::{
err::NeuraAxisErr, err::NeuraAxisErr,
layer::{NeuraLayer, NeuraShapedLayer}, layer::{NeuraLayer, NeuraShapedLayer},
network::residual::{NeuraCombineInputs, NeuraSplitInputs}, network::residual::{NeuraAxisDefault, NeuraCombineInputs, NeuraSplitInputs},
prelude::{NeuraPartialLayer, NeuraShape}, prelude::{NeuraPartialLayer, NeuraShape},
}; };
// TODO: split into two traits // TODO: split into two traits
pub trait NeuraGraphNodePartial<Data>: DynClone + std::fmt::Debug { pub trait NeuraGraphNodePartial<Data>: DynClone + Debug {
fn inputs<'a>(&'a self) -> &'a [String]; fn inputs<'a>(&'a self) -> &'a [String];
fn name<'a>(&'a self) -> &'a str; fn name<'a>(&'a self) -> &'a str;
@ -18,7 +19,7 @@ pub trait NeuraGraphNodePartial<Data>: DynClone + std::fmt::Debug {
) -> Result<(Box<dyn NeuraGraphNodeEval<Data>>, NeuraShape), String>; ) -> Result<(Box<dyn NeuraGraphNodeEval<Data>>, NeuraShape), String>;
} }
pub trait NeuraGraphNodeEval<Data>: DynClone + std::fmt::Debug { pub trait NeuraGraphNodeEval<Data>: DynClone + Debug {
fn eval<'a>(&'a self, inputs: &[Data]) -> Data; fn eval<'a>(&'a self, inputs: &[Data]) -> Data;
} }
@ -46,15 +47,15 @@ impl<Axis, Layer> NeuraGraphNode<Axis, Layer> {
Axis: NeuraSplitInputs<Data> Axis: NeuraSplitInputs<Data>
+ NeuraCombineInputs<NeuraShape, Combined = Result<NeuraShape, NeuraAxisErr>> + NeuraCombineInputs<NeuraShape, Combined = Result<NeuraShape, NeuraAxisErr>>
+ Clone + Clone
+ std::fmt::Debug + Debug
+ 'static, + 'static,
Layer: NeuraPartialLayer + Clone + std::fmt::Debug + 'static, Layer: NeuraPartialLayer + Clone + Debug + 'static,
Layer::Constructed: NeuraShapedLayer Layer::Constructed: NeuraShapedLayer
+ NeuraLayer<<Axis as NeuraCombineInputs<Data>>::Combined, Output = Data> + NeuraLayer<<Axis as NeuraCombineInputs<Data>>::Combined, Output = Data>
+ Clone + Clone
+ std::fmt::Debug + Debug
+ 'static, + 'static,
Layer::Err: std::fmt::Debug, Layer::Err: Debug,
{ {
Box::new(self) Box::new(self)
} }
@ -62,10 +63,8 @@ impl<Axis, Layer> NeuraGraphNode<Axis, Layer> {
impl< impl<
Data: Clone, Data: Clone,
Axis: NeuraSplitInputs<Data> + Clone + std::fmt::Debug, Axis: NeuraSplitInputs<Data> + Clone + Debug,
Layer: NeuraLayer<<Axis as NeuraCombineInputs<Data>>::Combined, Output = Data> Layer: NeuraLayer<<Axis as NeuraCombineInputs<Data>>::Combined, Output = Data> + Clone + Debug,
+ Clone
+ std::fmt::Debug,
> NeuraGraphNodeEval<Data> for NeuraGraphNode<Axis, Layer> > NeuraGraphNodeEval<Data> for NeuraGraphNode<Axis, Layer>
{ {
fn eval<'a>(&'a self, inputs: &[Data]) -> Data { fn eval<'a>(&'a self, inputs: &[Data]) -> Data {
@ -75,22 +74,33 @@ impl<
} }
} }
impl<Layer: Clone + Debug> From<Layer> for NeuraGraphNode<NeuraAxisDefault, Layer> {
fn from(layer: Layer) -> Self {
Self {
inputs: vec![],
axis: NeuraAxisDefault,
layer,
name: random_name(),
}
}
}
impl< impl<
Data: Clone, Data: Clone,
Axis: NeuraSplitInputs<Data> Axis: NeuraSplitInputs<Data>
+ NeuraCombineInputs<NeuraShape, Combined = Result<NeuraShape, NeuraAxisErr>> + NeuraCombineInputs<NeuraShape, Combined = Result<NeuraShape, NeuraAxisErr>>
+ Clone + Clone
+ std::fmt::Debug + Debug
+ 'static, + 'static,
Layer: NeuraPartialLayer + Clone + std::fmt::Debug, Layer: NeuraPartialLayer + Clone + Debug,
> NeuraGraphNodePartial<Data> for NeuraGraphNode<Axis, Layer> > NeuraGraphNodePartial<Data> for NeuraGraphNode<Axis, Layer>
where where
Layer::Constructed: NeuraShapedLayer Layer::Constructed: NeuraShapedLayer
+ NeuraLayer<<Axis as NeuraCombineInputs<Data>>::Combined, Output = Data> + NeuraLayer<<Axis as NeuraCombineInputs<Data>>::Combined, Output = Data>
+ Clone + Clone
+ std::fmt::Debug + Debug
+ 'static, + 'static,
Layer::Err: std::fmt::Debug, Layer::Err: Debug,
{ {
fn inputs<'a>(&'a self) -> &'a [String] { fn inputs<'a>(&'a self) -> &'a [String] {
&self.inputs &self.inputs
@ -127,3 +137,20 @@ where
)) ))
} }
} }
pub fn random_name() -> String {
use rand::Rng;
use std::fmt::Write;
let mut res = String::with_capacity(10);
write!(&mut res, "value_").unwrap();
let mut rng = rand::thread_rng();
for _ in 0..4 {
let ch = char::from_u32(rng.gen_range((b'a' as u32)..(b'z' as u32))).unwrap();
write!(&mut res, "{}", ch).unwrap();
}
res
}

@ -0,0 +1,196 @@
use crate::err::NeuraGraphErr;
use std::collections::{HashMap, HashSet, VecDeque};
use super::*;
pub struct NeuraGraphPartial<Data> {
pub nodes: Vec<Box<dyn NeuraGraphNodePartial<Data>>>,
pub output: String,
pub input: String,
}
#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)]
pub(crate) enum GraphIndex {
Input,
Node(usize),
}
impl<Data> NeuraGraphPartial<Data> {
pub(crate) fn get_index_map(&self) -> Result<HashMap<String, GraphIndex>, NeuraGraphErr> {
let mut result = HashMap::with_capacity(self.nodes.len());
result.insert(self.input.clone(), GraphIndex::Input);
for (index, node) in self.nodes.iter().enumerate() {
if result.contains_key(node.name()) {
return Err(NeuraGraphErr::InvalidName(node.name().to_string()));
}
result.insert(node.name().to_string(), GraphIndex::Node(index));
}
Ok(result)
}
pub(crate) fn get_reverse_graph(
&self,
index_map: &HashMap<String, GraphIndex>,
) -> Result<HashMap<GraphIndex, HashSet<GraphIndex>>, NeuraGraphErr> {
let mut result = HashMap::new();
result.insert(GraphIndex::Input, HashSet::new());
for i in 0..self.nodes.len() {
result.insert(GraphIndex::Node(i), HashSet::new());
}
for (index, node) in self.nodes.iter().enumerate() {
for input in node.inputs() {
let input_index = index_map
.get(input)
.copied()
.ok_or_else(|| NeuraGraphErr::MissingNode(input.clone()))?;
result
.get_mut(&input_index)
.expect("index_map returned invalid values")
.insert(GraphIndex::Node(index));
}
}
Ok(result)
}
pub(crate) fn get_node_order(
&self,
index_map: &HashMap<String, GraphIndex>,
reverse_graph: &HashMap<GraphIndex, HashSet<GraphIndex>>,
) -> Result<Vec<usize>, NeuraGraphErr> {
let mut result: Vec<usize> = Vec::new();
let mut closed: HashSet<GraphIndex> = HashSet::with_capacity(self.nodes.len());
let mut open = VecDeque::with_capacity(self.nodes.len());
open.push_front(GraphIndex::Input);
/*
index_map.get(&self.output)
.copied()
.ok_or_else(|| NeuraGraphErr::MissingNode(self.output.clone()))?
*/
while let Some(current) = open.pop_back() {
if closed.contains(&current) {
continue;
}
closed.insert(current);
// Do not put 0 (the input) in result
if let GraphIndex::Node(index) = current {
result.push(index);
}
println!("{:?}", current);
for next_node in reverse_graph[&current].iter().copied() {
// Ignore nodes that are already in the closed set
if closed.contains(&next_node) {
continue;
}
let GraphIndex::Node(node_index) = next_node else {
panic!("Unreachable: cannot have GraphIndex::Input as the output of a node");
};
let inputs = self.nodes[node_index].inputs();
// Only consider nodes whose inputs are in the closed set (meaning they would be ready to be evaluated)
if !inputs
.iter()
.all(|input| closed.contains(&index_map[input]))
{
continue;
}
open.push_front(next_node);
}
}
if result.len() != self.nodes.len() {
// TODO: verify that if result.len() != self.nodes.len(), then there is a cyclic subgraph
return Err(NeuraGraphErr::Cyclic);
}
Ok(result)
}
}
impl<Data> NeuraPartialLayer for NeuraGraphPartial<Data> {
type Constructed = NeuraGraph<Data>;
type Err = NeuraGraphErr;
fn construct(self, input_shape: NeuraShape) -> Result<Self::Constructed, Self::Err> {
let index_map = self.get_index_map()?;
let reverse_graph = self.get_reverse_graph(&index_map)?;
// List out the nodes in their execution order
let node_order = self.get_node_order(&index_map, &reverse_graph)?;
let mut new_index_map: HashMap<String, usize> = HashMap::from_iter(
node_order
.iter()
.map(|&i| (self.nodes[i].name().to_string(), i)),
);
new_index_map.insert(self.input.clone(), 0);
// TODO: filter out the nodes that are not necessary for computing the result (BFS from the output node back to the inputs)
// A temporary solution can be to trim the graph
let output_index = new_index_map
.get(&self.output)
.copied()
.ok_or_else(|| NeuraGraphErr::MissingNode(self.output.clone()))?;
let mut nodes = Vec::with_capacity(self.nodes.len());
let mut shapes: Vec<Option<NeuraShape>> = vec![None; self.nodes.len() + 1];
shapes[0] = Some(input_shape);
for index in node_order.into_iter() {
let node = &*self.nodes[index];
let node_inputs = node.inputs();
let mut inputs = Vec::with_capacity(node_inputs.len());
let mut input_shapes = Vec::with_capacity(node_inputs.len());
for input in node_inputs {
let input_index = new_index_map.get(input).copied().expect(
"Unreachable: new_index_map should contain all nodes defined and all nodes should have existing nodes as input"
);
inputs.push(input_index);
input_shapes.push(shapes[input_index].expect(
"Unreachable: the order of execution should guarantee that all inputs have appeared before")
);
}
let (constructed, output_shape) = node
.construct(input_shapes)
.map_err(|e| NeuraGraphErr::LayerErr(e))?;
shapes[index] = Some(output_shape);
nodes.push(NeuraGraphNodeConstructed {
node: constructed,
inputs,
output: new_index_map
.get(node.name())
.copied()
.unwrap_or_else(|| unreachable!()),
});
}
let output_shape = shapes[output_index].unwrap_or_else(|| unreachable!());
Ok(NeuraGraph {
nodes,
// input_shape,
output_shape,
output_index,
buffer_size: self.nodes.len() + 1,
})
}
}

@ -4,6 +4,8 @@ use nalgebra::{Const, DVector, Dyn, Scalar, VecStorage};
use crate::{err::NeuraAxisErr, prelude::NeuraShape}; use crate::{err::NeuraAxisErr, prelude::NeuraShape};
// TODO: create a NeuraAxis trait
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug)]
pub struct NeuraAxisAppend; pub struct NeuraAxisAppend;
@ -34,6 +36,7 @@ impl<F: Clone> NeuraCombineInputs<DVector<F>> for NeuraAxisAppend {
} }
} }
// TODO: use another trait for combining NeuraShape, or make it another member of the trait
impl NeuraCombineInputs<NeuraShape> for NeuraAxisAppend { impl NeuraCombineInputs<NeuraShape> for NeuraAxisAppend {
type Combined = Result<NeuraShape, NeuraAxisErr>; type Combined = Result<NeuraShape, NeuraAxisErr>;
@ -80,3 +83,37 @@ impl<F: Clone + Scalar + Default> NeuraSplitInputs<DVector<F>> for NeuraAxisAppe
result result
} }
} }
#[derive(Clone, Debug)]
pub struct NeuraAxisDefault;
impl<F: Clone> NeuraCombineInputs<DVector<F>> for NeuraAxisDefault {
type Combined = DVector<F>;
fn combine(&self, inputs: Vec<impl Borrow<DVector<F>>>) -> Self::Combined {
assert!(inputs.len() == 1);
inputs[0].borrow().clone()
}
}
impl NeuraCombineInputs<NeuraShape> for NeuraAxisDefault {
type Combined = Result<NeuraShape, NeuraAxisErr>;
fn combine(&self, inputs: Vec<impl Borrow<NeuraShape>>) -> Self::Combined {
if inputs.len() != 1 {
Err(NeuraAxisErr::InvalidAmount(inputs.len(), 1, Some(1)))
} else {
Ok(*inputs[0].borrow())
}
}
}
impl<Data: Clone> NeuraSplitInputs<Data> for NeuraAxisDefault
where
NeuraAxisDefault: NeuraCombineInputs<Data, Combined = Data>,
{
fn split(&self, combined: &Self::Combined, _input_shapes: &[NeuraShape]) -> Vec<Data> {
vec![combined.clone()]
}
}

Loading…
Cancel
Save