diff --git a/src/err.rs b/src/err.rs
index 7b742ef..878d22d 100644
--- a/src/err.rs
+++ b/src/err.rs
@@ -91,3 +91,10 @@ pub struct NeuraDimensionsMismatch {
     pub existing: usize,
     pub new: NeuraShape,
 }
+
+/// Errors raised while assembling a `NeuraGraph` from its partial description.
+#[derive(Clone, Debug)]
+pub enum NeuraGraphErr {
+    MissingNode(String),
+    InvalidName(String),
+}
diff --git a/src/network/graph/mod.rs b/src/network/graph/mod.rs
new file mode 100644
index 0000000..a46bd65
--- /dev/null
+++ b/src/network/graph/mod.rs
@@ -0,0 +1,150 @@
+use std::collections::{HashMap, HashSet, VecDeque};
+
+use crate::prelude::*;
+use crate::{err::NeuraGraphErr, layer::NeuraShapedLayer};
+
+mod node;
+pub use node::*;
+
+/// Builder for `NeuraGraph`: a set of named nodes plus the names of the
+/// graph's designated input and output. NOTE(review): the generic parameter
+/// lists in this file were stripped by extraction; `<Data>` etc. are
+/// reconstructed from usage — confirm against the upstream crate.
+pub struct NeuraGraphPartial<Data> {
+    pub nodes: Vec<Box<dyn NeuraGraphNodeTrait<Data>>>,
+    pub output: String,
+    pub input: String,
+}
+
+impl<Data> NeuraGraphPartial<Data> {
+    /// Maps node names to indices: the graph input is pinned to index 0,
+    /// and `self.nodes[i]` gets index `i + 1`.
+    ///
+    /// Returns `InvalidName` if two nodes share a name (or a node shadows
+    /// the input's name).
+    fn get_index_map(&self) -> Result<HashMap<String, usize>, NeuraGraphErr> {
+        let mut result = HashMap::with_capacity(self.nodes.len());
+
+        result.insert(self.input.clone(), 0);
+
+        for (index, node) in self.nodes.iter().enumerate() {
+            if result.contains_key(node.name()) {
+                return Err(NeuraGraphErr::InvalidName(node.name().to_string()));
+            }
+            result.insert(node.name().to_string(), index + 1);
+        }
+
+        Ok(result)
+    }
+
+    /// Builds the reverse adjacency list: `result[i]` is the set of node
+    /// indices that consume the output of node `i`.
+    ///
+    /// Returns `MissingNode` if a node names an input that nothing provides.
+    fn get_reverse_graph(
+        &self,
+        index_map: &HashMap<String, usize>,
+    ) -> Result<Vec<HashSet<usize>>, NeuraGraphErr> {
+        // Fix: indices run from 0 (graph input) to self.nodes.len() inclusive,
+        // so nodes.len() + 1 slots are needed; the original allocated only
+        // nodes.len() and could index out of bounds below.
+        let mut result = vec![HashSet::new(); self.nodes.len() + 1];
+
+        for (index, node) in self.nodes.iter().enumerate() {
+            for input in node.inputs() {
+                let input_index = index_map
+                    .get(input)
+                    .copied()
+                    .ok_or_else(|| NeuraGraphErr::MissingNode(input.clone()))?;
+                result[input_index].insert(index + 1);
+            }
+        }
+
+        Ok(result)
+    }
+
+    /// Topologically sorts the graph with a Kahn-style traversal starting
+    /// from the input node, returning node indices in execution order.
+    fn get_node_order(
+        &self,
+        index_map: &HashMap<String, usize>,
+        reverse_graph: &Vec<HashSet<usize>>,
+    ) -> Result<Vec<usize>, NeuraGraphErr> {
+        let mut result: Vec<usize> = Vec::new();
+        let mut closed: HashSet<usize> = HashSet::with_capacity(self.nodes.len());
+        let mut open = VecDeque::with_capacity(self.nodes.len());
+        open.push_front(0usize);
+
+        while let Some(current) = open.pop_back() {
+            if closed.contains(&current) {
+                continue;
+            }
+
+            closed.insert(current);
+            result.push(current);
+
+            for output_index in reverse_graph[current].iter().copied() {
+                // Index 0 is the graph input; it consumes nothing, so it can
+                // never show up as a consumer of another node.
+                assert!(output_index > 0);
+
+                // Ignore nodes that are already in the closed set
+                if closed.contains(&output_index) {
+                    continue;
+                }
+
+                let inputs = self.nodes[output_index - 1].inputs();
+
+                // Only consider nodes whose inputs are in the closed set
+                if !inputs
+                    .iter()
+                    .all(|input| closed.contains(&index_map[input]))
+                {
+                    continue;
+                }
+
+                // Fix: the original never scheduled the ready node, so the
+                // traversal stopped after visiting the input node alone.
+                open.push_front(output_index);
+            }
+        }
+
+        Ok(result)
+    }
+}
+
+/// A node resolved to numeric buffer indices, ready for execution.
+struct NeuraGraphNodeConstructed<Data> {
+    node: Box<dyn NeuraGraphNodeTrait<Data>>,
+    inputs: Vec<usize>,
+    output: usize,
+}
+
+pub struct NeuraGraph<Data> {
+    /// ## Class invariants
+    ///
+    /// - The order of nodes should match with the order of execution, ie.
+    ///   `forall (x, y), nodes = [..., x, ..., y, ...] => !(y in x.node.inputs)`
+    /// - `nodes[0].inputs = [0]`
+    /// - `nodes[nodes.len() - 1].output = buffer.len() - 1`
+    nodes: Vec<NeuraGraphNodeConstructed<Data>>,
+
+    input_shape: NeuraShape,
+    output_shape: NeuraShape,
+}
+
+impl<Data> NeuraShapedLayer for NeuraGraph<Data> {
+    fn output_shape(&self) -> NeuraShape {
+        self.output_shape
+    }
+}
+
+impl<Data> NeuraPartialLayer for NeuraGraphPartial<Data> {
+    type Constructed = NeuraGraph<Data>;
+
+    type Err = NeuraGraphErr;
+
+    // TODO: wire get_index_map / get_reverse_graph / get_node_order together.
+    fn construct(self, input_shape: NeuraShape) -> Result<Self::Constructed, Self::Err> {
+        todo!()
+    }
+}
diff --git a/src/network/graph/node.rs b/src/network/graph/node.rs
new file mode 100644
index 0000000..5777b51
--- /dev/null
+++ b/src/network/graph/node.rs
@@ -0,0 +1,65 @@
+use crate::{layer::NeuraLayer, network::residual::NeuraSplitInputs};
+
+/// Object-safe interface for a graph node: evaluate it, and query its name
+/// and the names of the nodes it takes its inputs from.
+pub trait NeuraGraphNodeTrait<Data> {
+    fn eval<'a>(&'a self, inputs: &[Data]) -> Data;
+
+    fn inputs<'a>(&'a self) -> &'a [String];
+    fn name<'a>(&'a self) -> &'a str;
+}
+
+/// A named layer together with the names of its inputs and the axis
+/// operation used to combine those inputs before the layer is applied.
+pub struct NeuraGraphNode<Axis, Layer> {
+    inputs: Vec<String>,
+    axis: Axis,
+    layer: Layer,
+    name: String,
+}
+
+impl<Axis, Layer> NeuraGraphNode<Axis, Layer> {
+    pub fn new(inputs: Vec<String>, axis: Axis, layer: Layer, name: String) -> Self {
+        // Check that `name not in inputs` ?
+        Self {
+            inputs,
+            axis,
+            layer,
+            name,
+        }
+    }
+
+    /// Type-erases the node behind the object-safe trait.
+    ///
+    /// NOTE(review): the generic parameters and bounds in this file were
+    /// stripped by the extraction; they are reconstructed from usage and
+    /// should be confirmed against the upstream crate.
+    pub fn as_boxed<Data: Clone>(self) -> Box<dyn NeuraGraphNodeTrait<Data>>
+    where
+        Axis: NeuraSplitInputs<Data> + 'static,
+        Layer: NeuraLayer<Axis::Combined, Output = Data> + 'static,
+    {
+        Box::new(self)
+    }
+}
+
+impl<
+        Data: Clone,
+        Axis: NeuraSplitInputs<Data>,
+        Layer: NeuraLayer<Axis::Combined, Output = Data>,
+    > NeuraGraphNodeTrait<Data> for NeuraGraphNode<Axis, Layer>
+{
+    fn eval<'a>(&'a self, inputs: &[Data]) -> Data {
+        // TODO: use to_vec_in?
+        let combined = self.axis.combine(inputs.to_vec());
+        self.layer.eval(&combined)
+    }
+
+    fn inputs<'a>(&'a self) -> &'a [String] {
+        &self.inputs
+    }
+
+    fn name<'a>(&'a self) -> &'a str {
+        &self.name
+    }
+}
diff --git a/src/network/mod.rs b/src/network/mod.rs
index fcc8836..cd61cf7 100644
--- a/src/network/mod.rs
+++ b/src/network/mod.rs
@@ -1,2 +1,3 @@
+pub mod graph;
 pub mod residual;
 pub mod sequential;