diff --git a/Cargo.toml b/Cargo.toml index d9b2515..6db7a22 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ rand_distr = "0.4.3" textplots = "0.8.0" image = { version = "0.24.6", optional = true } viuer = { version = "0.6.2", optional = true } +dyn-clone = "1.0.11" [dev-dependencies] image = "0.24.6" diff --git a/src/err.rs b/src/err.rs index 878d22d..07d9e5c 100644 --- a/src/err.rs +++ b/src/err.rs @@ -92,8 +92,10 @@ pub struct NeuraDimensionsMismatch { pub new: NeuraShape, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum NeuraGraphErr { MissingNode(String), InvalidName(String), + LayerErr(String), + Cyclic, } diff --git a/src/network/graph/mod.rs b/src/network/graph/mod.rs index a46bd65..1c172e1 100644 --- a/src/network/graph/mod.rs +++ b/src/network/graph/mod.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] // TODO: remove this + use std::collections::{HashMap, HashSet, VecDeque}; use crate::prelude::*; @@ -7,22 +9,28 @@ mod node; pub use node::*; pub struct NeuraGraphPartial { - pub nodes: Vec>>, + pub nodes: Vec>>, pub output: String, pub input: String, } +#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)] +enum GraphIndex { + Input, + Node(usize), +} + impl NeuraGraphPartial { - fn get_index_map(&self) -> Result, NeuraGraphErr> { + fn get_index_map(&self) -> Result, NeuraGraphErr> { let mut result = HashMap::with_capacity(self.nodes.len()); - result.insert(self.input.clone(), 0); + result.insert(self.input.clone(), GraphIndex::Input); for (index, node) in self.nodes.iter().enumerate() { if result.contains_key(node.name()) { return Err(NeuraGraphErr::InvalidName(node.name().to_string())); } - result.insert(node.name().to_string(), index + 1); + result.insert(node.name().to_string(), GraphIndex::Node(index)); } Ok(result) @@ -30,9 +38,15 @@ impl NeuraGraphPartial { fn get_reverse_graph( &self, - index_map: &HashMap, - ) -> Result>, NeuraGraphErr> { - let mut result = vec![HashSet::new(); self.nodes.len()]; + index_map: &HashMap, + 
) -> Result>, NeuraGraphErr> { + let mut result = HashMap::new(); + + result.insert(GraphIndex::Input, HashSet::new()); + + for i in 0..self.nodes.len() { + result.insert(GraphIndex::Node(i), HashSet::new()); + } for (index, node) in self.nodes.iter().enumerate() { for input in node.inputs() { @@ -40,7 +54,10 @@ impl NeuraGraphPartial { .get(input) .copied() .ok_or_else(|| NeuraGraphErr::MissingNode(input.clone()))?; - result[input_index].insert(index + 1); + result + .get_mut(&input_index) + .expect("index_map returned invalid values") + .insert(GraphIndex::Node(index)); } } @@ -49,13 +66,13 @@ impl NeuraGraphPartial { fn get_node_order( &self, - index_map: &HashMap, - reverse_graph: &Vec>, + index_map: &HashMap, + reverse_graph: &HashMap>, ) -> Result, NeuraGraphErr> { let mut result: Vec = Vec::new(); - let mut closed: HashSet = HashSet::with_capacity(self.nodes.len()); + let mut closed: HashSet = HashSet::with_capacity(self.nodes.len()); let mut open = VecDeque::with_capacity(self.nodes.len()); - open.push_front(0usize); + open.push_front(GraphIndex::Input); /* index_map.get(&self.output) @@ -69,49 +86,68 @@ impl NeuraGraphPartial { } closed.insert(current); - result.push(current); + // Do not put 0 (the input) in result + if let GraphIndex::Node(index) = current { + result.push(index); + } - for output_index in reverse_graph[current].iter().copied() { - assert!(output_index > 0); + println!("{:?}", current); + for next_node in reverse_graph[&current].iter().copied() { // Ignore nodes that are already in the closed set - if closed.contains(&output_index) { + if closed.contains(&next_node) { continue; } - let inputs = self.nodes[output_index - 1].inputs(); + let GraphIndex::Node(node_index) = next_node else { + panic!("Unreachable: cannot have GraphIndex::Input as the output of a node"); + }; + + let inputs = self.nodes[node_index].inputs(); - // Only consider nodes whose inputs are in the closed set + // Only consider nodes whose inputs are in the closed set (meaning 
they would be ready to be evaluated) if !inputs .iter() .all(|input| closed.contains(&index_map[input])) { continue; } + + open.push_front(next_node); } } + if result.len() != self.nodes.len() { + // TODO: verify that if result.len() != self.nodes.len(), then there is a cyclic subgraph + + return Err(NeuraGraphErr::Cyclic); + } + Ok(result) } } +#[derive(Debug)] struct NeuraGraphNodeConstructed { - node: Box>, + node: Box>, inputs: Vec, output: usize, } +#[derive(Debug)] pub struct NeuraGraph { /// ## Class invariants /// /// - The order of nodes should match with the order of execution, ie. /// `forall (x, y), nodes = [..., x, ..., y, ...] => !(y in x.node.inputs)` /// - `nodes[0].inputs = [0]` - /// - `nodes[nodes.len() - 1].output = buffer.len() - 1` nodes: Vec>, input_shape: NeuraShape, output_shape: NeuraShape, + + output_index: usize, + buffer_size: usize, } impl NeuraShapedLayer for NeuraGraph { @@ -126,6 +162,188 @@ impl NeuraPartialLayer for NeuraGraphPartial { type Err = NeuraGraphErr; fn construct(self, input_shape: NeuraShape) -> Result { - todo!() + let index_map = self.get_index_map()?; + let reverse_graph = self.get_reverse_graph(&index_map)?; + + // List out the nodes in their execution order + let node_order = self.get_node_order(&index_map, &reverse_graph)?; + let mut new_index_map: HashMap = HashMap::from_iter( + node_order + .iter() + .map(|&i| (self.nodes[i].name().to_string(), i)), + ); + new_index_map.insert(self.input.clone(), 0); + + // TODO: filter out the nodes that are not necessary for computing the result (BFS from the output node back to the inputs) + // A temporary solution can be to trim the graph + let output_index = new_index_map + .get(&self.output) + .copied() + .ok_or_else(|| NeuraGraphErr::MissingNode(self.output.clone()))?; + + let mut nodes = Vec::with_capacity(self.nodes.len()); + let mut shapes: Vec> = vec![None; self.nodes.len() + 1]; + shapes[0] = Some(input_shape); + + for index in node_order.into_iter() { + let node = 
&*self.nodes[index]; + let node_inputs = node.inputs(); + let mut inputs = Vec::with_capacity(node_inputs.len()); + let mut input_shapes = Vec::with_capacity(node_inputs.len()); + + for input in node_inputs { + let input_index = new_index_map.get(input).copied().expect( + "Unreachable: new_index_map should contain all nodes defined and all nodes should have existing nodes as input" + ); + inputs.push(input_index); + input_shapes.push(shapes[input_index].expect( + "Unreachable: the order of execution should guarantee that all inputs have appeared before") + ); + } + + let (constructed, output_shape) = node + .construct(input_shapes) + .map_err(|e| NeuraGraphErr::LayerErr(e))?; + + shapes[index] = Some(output_shape); + + nodes.push(NeuraGraphNodeConstructed { + node: constructed, + inputs, + output: new_index_map + .get(node.name()) + .copied() + .unwrap_or_else(|| unreachable!()), + }); + } + + let output_shape = shapes[output_index].unwrap_or_else(|| unreachable!()); + + Ok(NeuraGraph { + nodes, + input_shape, + output_shape, + output_index, + buffer_size: self.nodes.len() + 1, + }) + } +} + +#[cfg(test)] +mod test { + use crate::network::residual::NeuraAxisAppend; + + use super::*; + + #[test] + fn test_construct_simple_graph() { + let graph = NeuraGraphPartial { + nodes: vec![NeuraGraphNode::new( + vec!["input".to_string()], + NeuraAxisAppend, + neura_layer!("dense", 10), + "output".to_string(), + ) + .as_boxed()], + output: "output".to_string(), + input: "input".to_string(), + }; + + let constructed = graph.construct(NeuraShape::Vector(5)); + + assert!(constructed.is_ok()); + } + + #[test] + fn test_construct_deep_graph() { + let graph = NeuraGraphPartial { + nodes: vec![ + // Node intentionally out of order + NeuraGraphNode::new( + vec!["inter".to_string(), "inter2".to_string()], + NeuraAxisAppend, + neura_layer!("dense", 2), + "output".to_string(), + ) + .as_boxed(), + NeuraGraphNode::new( + vec!["input".to_string()], + NeuraAxisAppend, + neura_layer!("dense", 
10), + "inter".to_string(), + ) + .as_boxed(), + NeuraGraphNode::new( + vec!["inter".to_string()], + NeuraAxisAppend, + neura_layer!("dense", 20), + "inter2".to_string(), + ) + .as_boxed(), + ], + output: "output".to_string(), + input: "input".to_string(), + }; + + let index_map = graph.get_index_map().unwrap(); + let reverse_graph = graph.get_reverse_graph(&index_map).unwrap(); + assert_eq!( + graph.get_node_order(&index_map, &reverse_graph), + Ok(vec![1, 2, 0]) + ); + + let constructed = graph.construct(NeuraShape::Vector(5)); + + assert!(constructed.is_ok()); + } + + #[test] + fn test_construct_cyclic_graph() { + let graph = NeuraGraphPartial { + nodes: vec![NeuraGraphNode::new( + vec!["input".to_string(), "output".to_string()], + NeuraAxisAppend, + neura_layer!("dense", 10), + "output".to_string(), + ) + .as_boxed()], + output: "output".to_string(), + input: "input".to_string(), + }; + + let constructed = graph.construct(NeuraShape::Vector(5)); + + assert_eq!(constructed.unwrap_err(), NeuraGraphErr::Cyclic); + } + + #[test] + fn test_construct_disjoint_graph() { + let graph = NeuraGraphPartial { + nodes: vec![ + NeuraGraphNode::new( + vec!["input".to_string()], + NeuraAxisAppend, + neura_layer!("dense", 10), + "inter".to_string(), + ) + .as_boxed(), + NeuraGraphNode::new( + vec!["missing".to_string()], + NeuraAxisAppend, + neura_layer!("dense", 10), + "output".to_string(), + ) + .as_boxed(), + ], + output: "output".to_string(), + input: "input".to_string(), + }; + + let constructed = graph.construct(NeuraShape::Vector(5)); + + assert_eq!( + constructed.unwrap_err(), + NeuraGraphErr::MissingNode(String::from("missing")) + ); } } diff --git a/src/network/graph/node.rs b/src/network/graph/node.rs index 5777b51..afd8d4b 100644 --- a/src/network/graph/node.rs +++ b/src/network/graph/node.rs @@ -1,12 +1,28 @@ -use crate::{layer::NeuraLayer, network::residual::NeuraSplitInputs}; +use dyn_clone::DynClone; -pub trait NeuraGraphNodeTrait { - fn eval<'a>(&'a self, inputs: 
&[Data]) -> Data; +use crate::{ + err::NeuraAxisErr, + layer::{NeuraLayer, NeuraShapedLayer}, + network::residual::{NeuraCombineInputs, NeuraSplitInputs}, + prelude::{NeuraPartialLayer, NeuraShape}, +}; +// TODO: split into two traits +pub trait NeuraGraphNodePartial: DynClone + std::fmt::Debug { fn inputs<'a>(&'a self) -> &'a [String]; fn name<'a>(&'a self) -> &'a str; + + fn construct( + &self, + input_shapes: Vec, + ) -> Result<(Box>, NeuraShape), String>; } +pub trait NeuraGraphNodeEval: DynClone + std::fmt::Debug { + fn eval<'a>(&'a self, inputs: &[Data]) -> Data; +} + +#[derive(Clone, Debug)] pub struct NeuraGraphNode { inputs: Vec, axis: Axis, @@ -25,10 +41,20 @@ impl NeuraGraphNode { } } - pub fn as_boxed(self) -> Box> + pub fn as_boxed(self) -> Box> where - Axis: NeuraSplitInputs + 'static, - Layer: NeuraLayer + 'static, + Axis: NeuraSplitInputs + + NeuraCombineInputs> + + Clone + + std::fmt::Debug + + 'static, + Layer: NeuraPartialLayer + Clone + std::fmt::Debug + 'static, + Layer::Constructed: NeuraShapedLayer + + NeuraLayer<>::Combined, Output = Data> + + Clone + + std::fmt::Debug + + 'static, + Layer::Err: std::fmt::Debug, { Box::new(self) } @@ -36,16 +62,36 @@ impl NeuraGraphNode { impl< Data: Clone, - Axis: NeuraSplitInputs, - Layer: NeuraLayer, - > NeuraGraphNodeTrait for NeuraGraphNode + Axis: NeuraSplitInputs + Clone + std::fmt::Debug, + Layer: NeuraLayer<>::Combined, Output = Data> + + Clone + + std::fmt::Debug, + > NeuraGraphNodeEval for NeuraGraphNode { fn eval<'a>(&'a self, inputs: &[Data]) -> Data { // TODO: use to_vec_in? 
let combined = self.axis.combine(inputs.to_vec()); self.layer.eval(&combined) } +} +impl< + Data: Clone, + Axis: NeuraSplitInputs + + NeuraCombineInputs> + + Clone + + std::fmt::Debug + + 'static, + Layer: NeuraPartialLayer + Clone + std::fmt::Debug, + > NeuraGraphNodePartial for NeuraGraphNode +where + Layer::Constructed: NeuraShapedLayer + + NeuraLayer<>::Combined, Output = Data> + + Clone + + std::fmt::Debug + + 'static, + Layer::Err: std::fmt::Debug, +{ fn inputs<'a>(&'a self) -> &'a [String] { &self.inputs } @@ -53,4 +99,31 @@ impl< fn name<'a>(&'a self) -> &'a str { &self.name } + + fn construct( + &self, + input_shapes: Vec, + ) -> Result<(Box>, NeuraShape), String> { + let combined = self + .axis + .combine(input_shapes) + .map_err(|err| format!("{:?}", err))?; + + let constructed_layer = self + .layer + .clone() + .construct(combined) + .map_err(|err| format!("{:?}", err))?; + let output_shape = constructed_layer.output_shape(); + + Ok(( + Box::new(NeuraGraphNode { + inputs: self.inputs.clone(), + axis: self.axis.clone(), + layer: constructed_layer, + name: self.name.clone(), + }), + output_shape, + )) + } }