Working NeuraGraph construction

main
Shad Amethyst 2 years ago
parent 2f9c334e62
commit efaed91f83

@@ -19,6 +19,7 @@ rand_distr = "0.4.3"
 textplots = "0.8.0"
 image = { version = "0.24.6", optional = true }
 viuer = { version = "0.6.2", optional = true }
+dyn-clone = "1.0.11"

 [dev-dependencies]
 image = "0.24.6"

@@ -92,8 +92,10 @@ pub struct NeuraDimensionsMismatch {
     pub new: NeuraShape,
 }

-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, PartialEq, Eq)]
 pub enum NeuraGraphErr {
     MissingNode(String),
     InvalidName(String),
+    LayerErr(String),
+    Cyclic,
 }

@@ -1,3 +1,5 @@
+#![allow(dead_code)] // TODO: remove this
+
 use std::collections::{HashMap, HashSet, VecDeque};

 use crate::prelude::*;
@@ -7,22 +9,28 @@ mod node;
 pub use node::*;

 pub struct NeuraGraphPartial<Data> {
-    pub nodes: Vec<Box<dyn NeuraGraphNodeTrait<Data>>>,
+    pub nodes: Vec<Box<dyn NeuraGraphNodePartial<Data>>>,
     pub output: String,
     pub input: String,
 }

+#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)]
+enum GraphIndex {
+    Input,
+    Node(usize),
+}
+
 impl<Data> NeuraGraphPartial<Data> {
-    fn get_index_map(&self) -> Result<HashMap<String, usize>, NeuraGraphErr> {
+    fn get_index_map(&self) -> Result<HashMap<String, GraphIndex>, NeuraGraphErr> {
         let mut result = HashMap::with_capacity(self.nodes.len());

-        result.insert(self.input.clone(), 0);
+        result.insert(self.input.clone(), GraphIndex::Input);

         for (index, node) in self.nodes.iter().enumerate() {
             if result.contains_key(node.name()) {
                 return Err(NeuraGraphErr::InvalidName(node.name().to_string()));
             }
-            result.insert(node.name().to_string(), index + 1);
+            result.insert(node.name().to_string(), GraphIndex::Node(index));
         }

         Ok(result)
@@ -30,9 +38,15 @@ impl<Data> NeuraGraphPartial<Data> {
     fn get_reverse_graph(
         &self,
-        index_map: &HashMap<String, usize>,
-    ) -> Result<Vec<HashSet<usize>>, NeuraGraphErr> {
-        let mut result = vec![HashSet::new(); self.nodes.len()];
+        index_map: &HashMap<String, GraphIndex>,
+    ) -> Result<HashMap<GraphIndex, HashSet<GraphIndex>>, NeuraGraphErr> {
+        let mut result = HashMap::new();
+        result.insert(GraphIndex::Input, HashSet::new());
+
+        for i in 0..self.nodes.len() {
+            result.insert(GraphIndex::Node(i), HashSet::new());
+        }

         for (index, node) in self.nodes.iter().enumerate() {
             for input in node.inputs() {
@@ -40,7 +54,10 @@ impl<Data> NeuraGraphPartial<Data> {
                     .get(input)
                     .copied()
                     .ok_or_else(|| NeuraGraphErr::MissingNode(input.clone()))?;
-                result[input_index].insert(index + 1);
+                result
+                    .get_mut(&input_index)
+                    .expect("index_map returned invalid values")
+                    .insert(GraphIndex::Node(index));
             }
         }
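
To make the reverse graph concrete: it maps every producer (the input, or a node) to the set of nodes that consume its output, which is what the ordering pass below walks forward along. A minimal standalone sketch of the same idea, using plain strings in place of `GraphIndex` (the node names mirror the `test_construct_deep_graph` case further down and are illustrative only):

```rust
use std::collections::{HashMap, HashSet};

fn main() {
    // Forward description: node -> its inputs (as in NeuraGraphPartial).
    let inputs: HashMap<&str, Vec<&str>> = HashMap::from([
        ("inter", vec!["input"]),
        ("inter2", vec!["inter"]),
        ("output", vec!["inter", "inter2"]),
    ]);

    // Reverse graph: producer -> consumers, as get_reverse_graph builds it.
    let mut reverse: HashMap<&str, HashSet<&str>> = HashMap::new();
    for (&node, node_inputs) in &inputs {
        for &input in node_inputs {
            reverse.entry(input).or_default().insert(node);
        }
    }

    // "inter" feeds both "inter2" and "output".
    assert_eq!(reverse["inter"], HashSet::from(["inter2", "output"]));
}
```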
@@ -49,13 +66,13 @@ impl<Data> NeuraGraphPartial<Data> {
     fn get_node_order(
         &self,
-        index_map: &HashMap<String, usize>,
-        reverse_graph: &Vec<HashSet<usize>>,
+        index_map: &HashMap<String, GraphIndex>,
+        reverse_graph: &HashMap<GraphIndex, HashSet<GraphIndex>>,
     ) -> Result<Vec<usize>, NeuraGraphErr> {
         let mut result: Vec<usize> = Vec::new();
-        let mut closed: HashSet<usize> = HashSet::with_capacity(self.nodes.len());
+        let mut closed: HashSet<GraphIndex> = HashSet::with_capacity(self.nodes.len());
         let mut open = VecDeque::with_capacity(self.nodes.len());
-        open.push_front(0usize);
+        open.push_front(GraphIndex::Input);

         /*
         index_map.get(&self.output)
@@ -69,49 +86,68 @@ impl<Data> NeuraGraphPartial<Data> {
             }
             closed.insert(current);
-            result.push(current);
+            // Do not put 0 (the input) in result
+            if let GraphIndex::Node(index) = current {
+                result.push(index);
+            }

-            for output_index in reverse_graph[current].iter().copied() {
-                assert!(output_index > 0);
+            println!("{:?}", current);
+
+            for next_node in reverse_graph[&current].iter().copied() {
                 // Ignore nodes that are already in the closed set
-                if closed.contains(&output_index) {
+                if closed.contains(&next_node) {
                     continue;
                 }

-                let inputs = self.nodes[output_index - 1].inputs();
+                let GraphIndex::Node(node_index) = next_node else {
+                    panic!("Unreachable: cannot have GraphIndex::Input as the output of a node");
+                };
+
+                let inputs = self.nodes[node_index].inputs();

-                // Only consider nodes whose inputs are in the closed set
+                // Only consider nodes whose inputs are in the closed set (meaning they would be ready to be evaluated)
                 if !inputs
                     .iter()
                     .all(|input| closed.contains(&index_map[input]))
                 {
                     continue;
                 }

+                open.push_front(next_node);
             }
         }

+        if result.len() != self.nodes.len() {
+            // TODO: verify that if result.len() != self.nodes.len(), then there is a cyclic subgraph
+            return Err(NeuraGraphErr::Cyclic);
+        }
+
         Ok(result)
     }
 }
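
The ordering pass above is essentially a breadth-first variant of Kahn's topological sort: a node enters `open` only once every one of its inputs is already in `closed`, and anything still unvisited when the queue drains must lie on or behind a cycle, hence the `NeuraGraphErr::Cyclic` check. A self-contained sketch of that technique outside the library's types (function and variable names are illustrative, not the crate's API):

```rust
use std::collections::{HashMap, HashSet, VecDeque};

/// Orders nodes so each appears after all of its inputs; None signals a cycle.
/// `inputs` maps every node (except `start`) to the names it reads from.
fn node_order(inputs: &HashMap<&str, Vec<&str>>, start: &str) -> Option<Vec<String>> {
    // Reverse graph: producer -> consumers.
    let mut reverse: HashMap<&str, Vec<&str>> = HashMap::new();
    for (&node, ins) in inputs {
        for &input in ins {
            reverse.entry(input).or_default().push(node);
        }
    }

    let mut closed: HashSet<&str> = HashSet::new();
    let mut open = VecDeque::from([start]);
    let mut result = Vec::new();

    while let Some(current) = open.pop_back() {
        if !closed.insert(current) {
            continue;
        }
        if current != start {
            result.push(current.to_string());
        }
        for &next in reverse.get(current).into_iter().flatten() {
            // Schedule a consumer only once all of its inputs are evaluated.
            if !closed.contains(next) && inputs[next].iter().all(|i| closed.contains(i)) {
                open.push_front(next);
            }
        }
    }

    // Unvisited nodes remain exactly when the graph has a cycle.
    (result.len() == inputs.len()).then_some(result)
}
```

On the deep-graph test below, this should yield `["inter", "inter2", "output"]`, the same ordering the test asserts as indices `[1, 2, 0]`.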
+#[derive(Debug)]
 struct NeuraGraphNodeConstructed<Data> {
-    node: Box<dyn NeuraGraphNodeTrait<Data>>,
+    node: Box<dyn NeuraGraphNodeEval<Data>>,
     inputs: Vec<usize>,
     output: usize,
 }

+#[derive(Debug)]
 pub struct NeuraGraph<Data> {
     /// ## Class invariants
     ///
     /// - The order of nodes should match with the order of execution, ie.
     ///   `forall (x, y), nodes = [..., x, ..., y, ...] => !(y in x.node.inputs)`
     /// - `nodes[0].inputs = [0]`
+    /// - `nodes[nodes.len() - 1].output = buffer.len() - 1`
     nodes: Vec<NeuraGraphNodeConstructed<Data>>,

     input_shape: NeuraShape,
     output_shape: NeuraShape,
+    output_index: usize,
+    buffer_size: usize,
 }

 impl<Data> NeuraShapedLayer for NeuraGraph<Data> {
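
The new `output_index` and `buffer_size` fields, together with the added `nodes[nodes.len() - 1].output = buffer.len() - 1` invariant, suggest the evaluation scheme this is building toward: every node writes its result into one slot of a flat buffer, with slot 0 holding the network input. Evaluation itself is not part of this commit, so the following is only a hypothetical sketch of that scheme with simplified stand-in types (none of these names are the library's):

```rust
// Stand-ins for the real node/graph types; `eval` is a plain closure here.
struct Node<Data> {
    eval: Box<dyn Fn(&[Data]) -> Data>,
    inputs: Vec<usize>, // buffer slots this node reads
    output: usize,      // buffer slot this node writes
}

struct Graph<Data> {
    nodes: Vec<Node<Data>>, // already sorted into execution order
    output_index: usize,
    buffer_size: usize,
}

impl<Data: Clone> Graph<Data> {
    fn eval(&self, input: &Data) -> Data {
        // Slot 0 holds the input, matching the `nodes[0].inputs = [0]` invariant.
        let mut buffer: Vec<Option<Data>> = vec![None; self.buffer_size];
        buffer[0] = Some(input.clone());
        for node in &self.nodes {
            // Execution order guarantees every input slot is already filled.
            let args: Vec<Data> = node
                .inputs
                .iter()
                .map(|&i| buffer[i].clone().expect("input filled before use"))
                .collect();
            buffer[node.output] = Some((node.eval)(&args));
        }
        buffer[self.output_index].take().expect("output was computed")
    }
}
```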
@@ -126,6 +162,188 @@ impl<Data> NeuraPartialLayer for NeuraGraphPartial<Data> {
     type Err = NeuraGraphErr;

     fn construct(self, input_shape: NeuraShape) -> Result<Self::Constructed, Self::Err> {
-        todo!()
+        let index_map = self.get_index_map()?;
+        let reverse_graph = self.get_reverse_graph(&index_map)?;
+
+        // List out the nodes in their execution order
+        let node_order = self.get_node_order(&index_map, &reverse_graph)?;
+        let mut new_index_map: HashMap<String, usize> = HashMap::from_iter(
+            node_order
+                .iter()
+                .map(|&i| (self.nodes[i].name().to_string(), i)),
+        );
+        new_index_map.insert(self.input.clone(), 0);
+
+        // TODO: filter out the nodes that are not necessary for computing the result (BFS from the output node back to the inputs)
+        // A temporary solution can be to trim the graph
+        let output_index = new_index_map
+            .get(&self.output)
+            .copied()
+            .ok_or_else(|| NeuraGraphErr::MissingNode(self.output.clone()))?;
+
+        let mut nodes = Vec::with_capacity(self.nodes.len());
+        let mut shapes: Vec<Option<NeuraShape>> = vec![None; self.nodes.len() + 1];
+        shapes[0] = Some(input_shape);
+
+        for index in node_order.into_iter() {
+            let node = &*self.nodes[index];
+            let node_inputs = node.inputs();
+            let mut inputs = Vec::with_capacity(node_inputs.len());
+            let mut input_shapes = Vec::with_capacity(node_inputs.len());
+
+            for input in node_inputs {
+                let input_index = new_index_map.get(input).copied().expect(
+                    "Unreachable: new_index_map should contain all nodes defined and all nodes should have existing nodes as input"
+                );
+                inputs.push(input_index);
+                input_shapes.push(shapes[input_index].expect(
+                    "Unreachable: the order of execution should guarantee that all inputs have appeared before"
+                ));
+            }
+
+            let (constructed, output_shape) = node
+                .construct(input_shapes)
+                .map_err(|e| NeuraGraphErr::LayerErr(e))?;
+
+            shapes[index] = Some(output_shape);
+
+            nodes.push(NeuraGraphNodeConstructed {
+                node: constructed,
+                inputs,
+                output: new_index_map
+                    .get(node.name())
+                    .copied()
+                    .unwrap_or_else(|| unreachable!()),
+            });
+        }
+
+        let output_shape = shapes[output_index].unwrap_or_else(|| unreachable!());
+
+        Ok(NeuraGraph {
+            nodes,
+            input_shape,
+            output_shape,
+            output_index,
+            buffer_size: self.nodes.len() + 1,
+        })
     }
 }
+
+#[cfg(test)]
+mod test {
+    use crate::network::residual::NeuraAxisAppend;
+
+    use super::*;
+
+    #[test]
+    fn test_construct_simple_graph() {
+        let graph = NeuraGraphPartial {
+            nodes: vec![NeuraGraphNode::new(
+                vec!["input".to_string()],
+                NeuraAxisAppend,
+                neura_layer!("dense", 10),
+                "output".to_string(),
+            )
+            .as_boxed()],
+            output: "output".to_string(),
+            input: "input".to_string(),
+        };
+
+        let constructed = graph.construct(NeuraShape::Vector(5));
+        assert!(constructed.is_ok());
+    }
+
+    #[test]
+    fn test_construct_deep_graph() {
+        let graph = NeuraGraphPartial {
+            nodes: vec![
+                // Node intentionally out of order
+                NeuraGraphNode::new(
+                    vec!["inter".to_string(), "inter2".to_string()],
+                    NeuraAxisAppend,
+                    neura_layer!("dense", 2),
+                    "output".to_string(),
+                )
+                .as_boxed(),
+                NeuraGraphNode::new(
+                    vec!["input".to_string()],
+                    NeuraAxisAppend,
+                    neura_layer!("dense", 10),
+                    "inter".to_string(),
+                )
+                .as_boxed(),
+                NeuraGraphNode::new(
+                    vec!["inter".to_string()],
+                    NeuraAxisAppend,
+                    neura_layer!("dense", 20),
+                    "inter2".to_string(),
+                )
+                .as_boxed(),
+            ],
+            output: "output".to_string(),
+            input: "input".to_string(),
+        };
+
+        let index_map = graph.get_index_map().unwrap();
+        let reverse_graph = graph.get_reverse_graph(&index_map).unwrap();
+        assert_eq!(
+            graph.get_node_order(&index_map, &reverse_graph),
+            Ok(vec![1, 2, 0])
+        );
+
+        let constructed = graph.construct(NeuraShape::Vector(5));
+        assert!(constructed.is_ok());
+    }
+
+    #[test]
+    fn test_construct_cyclic_graph() {
+        let graph = NeuraGraphPartial {
+            nodes: vec![NeuraGraphNode::new(
+                vec!["input".to_string(), "output".to_string()],
+                NeuraAxisAppend,
+                neura_layer!("dense", 10),
+                "output".to_string(),
+            )
+            .as_boxed()],
+            output: "output".to_string(),
+            input: "input".to_string(),
+        };
+
+        let constructed = graph.construct(NeuraShape::Vector(5));
+        assert_eq!(constructed.unwrap_err(), NeuraGraphErr::Cyclic);
+    }
+
+    #[test]
+    fn test_construct_disjoint_graph() {
+        let graph = NeuraGraphPartial {
+            nodes: vec![
+                NeuraGraphNode::new(
+                    vec!["input".to_string()],
+                    NeuraAxisAppend,
+                    neura_layer!("dense", 10),
+                    "inter".to_string(),
+                )
+                .as_boxed(),
+                NeuraGraphNode::new(
+                    vec!["missing".to_string()],
+                    NeuraAxisAppend,
+                    neura_layer!("dense", 10),
+                    "output".to_string(),
+                )
+                .as_boxed(),
+            ],
+            output: "output".to_string(),
+            input: "input".to_string(),
+        };
+
+        let constructed = graph.construct(NeuraShape::Vector(5));
+        assert_eq!(
+            constructed.unwrap_err(),
+            NeuraGraphErr::MissingNode(String::from("missing"))
+        );
+    }
+}

@@ -1,12 +1,28 @@
-use crate::{layer::NeuraLayer, network::residual::NeuraSplitInputs};
+use dyn_clone::DynClone;
+
+use crate::{
+    err::NeuraAxisErr,
+    layer::{NeuraLayer, NeuraShapedLayer},
+    network::residual::{NeuraCombineInputs, NeuraSplitInputs},
+    prelude::{NeuraPartialLayer, NeuraShape},
+};

-pub trait NeuraGraphNodeTrait<Data> {
-    fn eval<'a>(&'a self, inputs: &[Data]) -> Data;
-
+// TODO: split into two traits
+pub trait NeuraGraphNodePartial<Data>: DynClone + std::fmt::Debug {
     fn inputs<'a>(&'a self) -> &'a [String];
     fn name<'a>(&'a self) -> &'a str;
+
+    fn construct(
+        &self,
+        input_shapes: Vec<NeuraShape>,
+    ) -> Result<(Box<dyn NeuraGraphNodeEval<Data>>, NeuraShape), String>;
 }

+pub trait NeuraGraphNodeEval<Data>: DynClone + std::fmt::Debug {
+    fn eval<'a>(&'a self, inputs: &[Data]) -> Data;
+}
+
+#[derive(Clone, Debug)]
 pub struct NeuraGraphNode<Axis, Layer> {
     inputs: Vec<String>,
     axis: Axis,
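
The `DynClone` supertrait on both new traits comes from the `dyn-clone` crate added to Cargo.toml above; it exists because plain `Clone` is not object-safe, so a `Box<dyn NeuraGraphNodePartial<Data>>` could not otherwise be cloned. A minimal illustration of the crate's usual pattern on a toy trait (not the actual traits above, which would need the macro's generic form):

```rust
use dyn_clone::DynClone;

trait Node: DynClone {
    fn name(&self) -> &str;
}

// Generates `impl Clone for Box<dyn Node>` by delegating to DynClone.
dyn_clone::clone_trait_object!(Node);

#[derive(Clone)]
struct Dense {
    name: String,
}

impl Node for Dense {
    fn name(&self) -> &str {
        &self.name
    }
}

fn main() {
    let a: Box<dyn Node> = Box::new(Dense { name: "dense".into() });
    let b = a.clone(); // works thanks to clone_trait_object!
    assert_eq!(a.name(), b.name());
}
```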
@@ -25,10 +41,20 @@ impl<Axis, Layer> NeuraGraphNode<Axis, Layer> {
         }
     }

-    pub fn as_boxed<Data: Clone>(self) -> Box<dyn NeuraGraphNodeTrait<Data>>
+    pub fn as_boxed<Data: Clone>(self) -> Box<dyn NeuraGraphNodePartial<Data>>
     where
-        Axis: NeuraSplitInputs<Data> + 'static,
-        Layer: NeuraLayer<Axis::Combined, Output = Data> + 'static,
+        Axis: NeuraSplitInputs<Data>
+            + NeuraCombineInputs<NeuraShape, Combined = Result<NeuraShape, NeuraAxisErr>>
+            + Clone
+            + std::fmt::Debug
+            + 'static,
+        Layer: NeuraPartialLayer + Clone + std::fmt::Debug + 'static,
+        Layer::Constructed: NeuraShapedLayer
+            + NeuraLayer<<Axis as NeuraCombineInputs<Data>>::Combined, Output = Data>
+            + Clone
+            + std::fmt::Debug
+            + 'static,
+        Layer::Err: std::fmt::Debug,
     {
         Box::new(self)
     }
@ -36,16 +62,36 @@ impl<Axis, Layer> NeuraGraphNode<Axis, Layer> {
impl< impl<
Data: Clone, Data: Clone,
Axis: NeuraSplitInputs<Data>, Axis: NeuraSplitInputs<Data> + Clone + std::fmt::Debug,
Layer: NeuraLayer<Axis::Combined, Output = Data>, Layer: NeuraLayer<<Axis as NeuraCombineInputs<Data>>::Combined, Output = Data>
> NeuraGraphNodeTrait<Data> for NeuraGraphNode<Axis, Layer> + Clone
+ std::fmt::Debug,
> NeuraGraphNodeEval<Data> for NeuraGraphNode<Axis, Layer>
{ {
fn eval<'a>(&'a self, inputs: &[Data]) -> Data { fn eval<'a>(&'a self, inputs: &[Data]) -> Data {
// TODO: use to_vec_in? // TODO: use to_vec_in?
let combined = self.axis.combine(inputs.to_vec()); let combined = self.axis.combine(inputs.to_vec());
self.layer.eval(&combined) self.layer.eval(&combined)
} }
}
impl<
Data: Clone,
Axis: NeuraSplitInputs<Data>
+ NeuraCombineInputs<NeuraShape, Combined = Result<NeuraShape, NeuraAxisErr>>
+ Clone
+ std::fmt::Debug
+ 'static,
Layer: NeuraPartialLayer + Clone + std::fmt::Debug,
> NeuraGraphNodePartial<Data> for NeuraGraphNode<Axis, Layer>
where
Layer::Constructed: NeuraShapedLayer
+ NeuraLayer<<Axis as NeuraCombineInputs<Data>>::Combined, Output = Data>
+ Clone
+ std::fmt::Debug
+ 'static,
Layer::Err: std::fmt::Debug,
{
fn inputs<'a>(&'a self) -> &'a [String] { fn inputs<'a>(&'a self) -> &'a [String] {
&self.inputs &self.inputs
} }
@@ -53,4 +99,31 @@ impl<
     fn name<'a>(&'a self) -> &'a str {
         &self.name
     }
+
+    fn construct(
+        &self,
+        input_shapes: Vec<NeuraShape>,
+    ) -> Result<(Box<dyn NeuraGraphNodeEval<Data>>, NeuraShape), String> {
+        let combined = self
+            .axis
+            .combine(input_shapes)
+            .map_err(|err| format!("{:?}", err))?;
+
+        let constructed_layer = self
+            .layer
+            .clone()
+            .construct(combined)
+            .map_err(|err| format!("{:?}", err))?;
+        let output_shape = constructed_layer.output_shape();
+
+        Ok((
+            Box::new(NeuraGraphNode {
+                inputs: self.inputs.clone(),
+                axis: self.axis.clone(),
+                layer: constructed_layer,
+                name: self.name.clone(),
+            }),
+            output_shape,
+        ))
+    }
 }
