parent
efaed91f83
commit
251e4d02d2
@ -0,0 +1,61 @@
|
|||||||
|
use crate::network::residual::{NeuraAxisDefault, NeuraSplitInputs};
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
trait FromSequential<Seq, Data> {
|
||||||
|
fn from_sequential(
|
||||||
|
seq: &Seq,
|
||||||
|
nodes: Vec<NeuraGraphNodeConstructed<Data>>,
|
||||||
|
output_shape: NeuraShape,
|
||||||
|
) -> Self;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Data> FromSequential<(), Data> for NeuraGraph<Data> {
|
||||||
|
fn from_sequential(
|
||||||
|
_seq: &(),
|
||||||
|
nodes: Vec<NeuraGraphNodeConstructed<Data>>,
|
||||||
|
output_shape: NeuraShape,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
output_index: nodes.len(),
|
||||||
|
buffer_size: nodes.len() + 1,
|
||||||
|
nodes: nodes,
|
||||||
|
output_shape,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<
|
||||||
|
Data: Clone,
|
||||||
|
Layer: NeuraLayer<Data, Output = Data> + Clone + std::fmt::Debug + 'static,
|
||||||
|
ChildNetwork,
|
||||||
|
> FromSequential<NeuraSequential<Layer, ChildNetwork>, Data> for NeuraGraph<Data>
|
||||||
|
where
|
||||||
|
NeuraGraph<Data>: FromSequential<ChildNetwork, Data>,
|
||||||
|
NeuraAxisDefault: NeuraSplitInputs<Data, Combined = Data>,
|
||||||
|
{
|
||||||
|
fn from_sequential(
|
||||||
|
seq: &NeuraSequential<Layer, ChildNetwork>,
|
||||||
|
mut nodes: Vec<NeuraGraphNodeConstructed<Data>>,
|
||||||
|
output_shape: NeuraShape,
|
||||||
|
) -> Self {
|
||||||
|
nodes.push(NeuraGraphNodeConstructed {
|
||||||
|
node: Box::new(NeuraGraphNode::from(seq.layer.clone())),
|
||||||
|
inputs: vec![nodes.len()],
|
||||||
|
output: nodes.len() + 1,
|
||||||
|
});
|
||||||
|
|
||||||
|
Self::from_sequential(&seq.child_network, nodes, output_shape)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Data, Layer, ChildNetwork> From<NeuraSequential<Layer, ChildNetwork>> for NeuraGraph<Data>
|
||||||
|
where
|
||||||
|
NeuraGraph<Data>: FromSequential<NeuraSequential<Layer, ChildNetwork>, Data>,
|
||||||
|
NeuraSequential<Layer, ChildNetwork>: NeuraShapedLayer,
|
||||||
|
{
|
||||||
|
fn from(network: NeuraSequential<Layer, ChildNetwork>) -> Self {
|
||||||
|
let output_shape = network.output_shape();
|
||||||
|
Self::from_sequential(&network, vec![], output_shape)
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,196 @@
|
|||||||
|
use crate::err::NeuraGraphErr;
|
||||||
|
use std::collections::{HashMap, HashSet, VecDeque};
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
pub struct NeuraGraphPartial<Data> {
|
||||||
|
pub nodes: Vec<Box<dyn NeuraGraphNodePartial<Data>>>,
|
||||||
|
pub output: String,
|
||||||
|
pub input: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, Hash)]
|
||||||
|
pub(crate) enum GraphIndex {
|
||||||
|
Input,
|
||||||
|
Node(usize),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Data> NeuraGraphPartial<Data> {
|
||||||
|
pub(crate) fn get_index_map(&self) -> Result<HashMap<String, GraphIndex>, NeuraGraphErr> {
|
||||||
|
let mut result = HashMap::with_capacity(self.nodes.len());
|
||||||
|
|
||||||
|
result.insert(self.input.clone(), GraphIndex::Input);
|
||||||
|
|
||||||
|
for (index, node) in self.nodes.iter().enumerate() {
|
||||||
|
if result.contains_key(node.name()) {
|
||||||
|
return Err(NeuraGraphErr::InvalidName(node.name().to_string()));
|
||||||
|
}
|
||||||
|
result.insert(node.name().to_string(), GraphIndex::Node(index));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn get_reverse_graph(
|
||||||
|
&self,
|
||||||
|
index_map: &HashMap<String, GraphIndex>,
|
||||||
|
) -> Result<HashMap<GraphIndex, HashSet<GraphIndex>>, NeuraGraphErr> {
|
||||||
|
let mut result = HashMap::new();
|
||||||
|
|
||||||
|
result.insert(GraphIndex::Input, HashSet::new());
|
||||||
|
|
||||||
|
for i in 0..self.nodes.len() {
|
||||||
|
result.insert(GraphIndex::Node(i), HashSet::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
for (index, node) in self.nodes.iter().enumerate() {
|
||||||
|
for input in node.inputs() {
|
||||||
|
let input_index = index_map
|
||||||
|
.get(input)
|
||||||
|
.copied()
|
||||||
|
.ok_or_else(|| NeuraGraphErr::MissingNode(input.clone()))?;
|
||||||
|
result
|
||||||
|
.get_mut(&input_index)
|
||||||
|
.expect("index_map returned invalid values")
|
||||||
|
.insert(GraphIndex::Node(index));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn get_node_order(
|
||||||
|
&self,
|
||||||
|
index_map: &HashMap<String, GraphIndex>,
|
||||||
|
reverse_graph: &HashMap<GraphIndex, HashSet<GraphIndex>>,
|
||||||
|
) -> Result<Vec<usize>, NeuraGraphErr> {
|
||||||
|
let mut result: Vec<usize> = Vec::new();
|
||||||
|
let mut closed: HashSet<GraphIndex> = HashSet::with_capacity(self.nodes.len());
|
||||||
|
let mut open = VecDeque::with_capacity(self.nodes.len());
|
||||||
|
open.push_front(GraphIndex::Input);
|
||||||
|
|
||||||
|
/*
|
||||||
|
index_map.get(&self.output)
|
||||||
|
.copied()
|
||||||
|
.ok_or_else(|| NeuraGraphErr::MissingNode(self.output.clone()))?
|
||||||
|
*/
|
||||||
|
|
||||||
|
while let Some(current) = open.pop_back() {
|
||||||
|
if closed.contains(¤t) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
closed.insert(current);
|
||||||
|
// Do not put 0 (the input) in result
|
||||||
|
if let GraphIndex::Node(index) = current {
|
||||||
|
result.push(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("{:?}", current);
|
||||||
|
|
||||||
|
for next_node in reverse_graph[¤t].iter().copied() {
|
||||||
|
// Ignore nodes that are already in the closed set
|
||||||
|
if closed.contains(&next_node) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let GraphIndex::Node(node_index) = next_node else {
|
||||||
|
panic!("Unreachable: cannot have GraphIndex::Input as the output of a node");
|
||||||
|
};
|
||||||
|
|
||||||
|
let inputs = self.nodes[node_index].inputs();
|
||||||
|
|
||||||
|
// Only consider nodes whose inputs are in the closed set (meaning they would be ready to be evaluated)
|
||||||
|
if !inputs
|
||||||
|
.iter()
|
||||||
|
.all(|input| closed.contains(&index_map[input]))
|
||||||
|
{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
open.push_front(next_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.len() != self.nodes.len() {
|
||||||
|
// TODO: verify that if result.len() != self.nodes.len(), then there is a cyclic subgraph
|
||||||
|
|
||||||
|
return Err(NeuraGraphErr::Cyclic);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Data> NeuraPartialLayer for NeuraGraphPartial<Data> {
|
||||||
|
type Constructed = NeuraGraph<Data>;
|
||||||
|
|
||||||
|
type Err = NeuraGraphErr;
|
||||||
|
|
||||||
|
fn construct(self, input_shape: NeuraShape) -> Result<Self::Constructed, Self::Err> {
|
||||||
|
let index_map = self.get_index_map()?;
|
||||||
|
let reverse_graph = self.get_reverse_graph(&index_map)?;
|
||||||
|
|
||||||
|
// List out the nodes in their execution order
|
||||||
|
let node_order = self.get_node_order(&index_map, &reverse_graph)?;
|
||||||
|
let mut new_index_map: HashMap<String, usize> = HashMap::from_iter(
|
||||||
|
node_order
|
||||||
|
.iter()
|
||||||
|
.map(|&i| (self.nodes[i].name().to_string(), i)),
|
||||||
|
);
|
||||||
|
new_index_map.insert(self.input.clone(), 0);
|
||||||
|
|
||||||
|
// TODO: filter out the nodes that are not necessary for computing the result (BFS from the output node back to the inputs)
|
||||||
|
// A temporary solution can be to trim the graph
|
||||||
|
let output_index = new_index_map
|
||||||
|
.get(&self.output)
|
||||||
|
.copied()
|
||||||
|
.ok_or_else(|| NeuraGraphErr::MissingNode(self.output.clone()))?;
|
||||||
|
|
||||||
|
let mut nodes = Vec::with_capacity(self.nodes.len());
|
||||||
|
let mut shapes: Vec<Option<NeuraShape>> = vec![None; self.nodes.len() + 1];
|
||||||
|
shapes[0] = Some(input_shape);
|
||||||
|
|
||||||
|
for index in node_order.into_iter() {
|
||||||
|
let node = &*self.nodes[index];
|
||||||
|
let node_inputs = node.inputs();
|
||||||
|
let mut inputs = Vec::with_capacity(node_inputs.len());
|
||||||
|
let mut input_shapes = Vec::with_capacity(node_inputs.len());
|
||||||
|
|
||||||
|
for input in node_inputs {
|
||||||
|
let input_index = new_index_map.get(input).copied().expect(
|
||||||
|
"Unreachable: new_index_map should contain all nodes defined and all nodes should have existing nodes as input"
|
||||||
|
);
|
||||||
|
inputs.push(input_index);
|
||||||
|
input_shapes.push(shapes[input_index].expect(
|
||||||
|
"Unreachable: the order of execution should guarantee that all inputs have appeared before")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let (constructed, output_shape) = node
|
||||||
|
.construct(input_shapes)
|
||||||
|
.map_err(|e| NeuraGraphErr::LayerErr(e))?;
|
||||||
|
|
||||||
|
shapes[index] = Some(output_shape);
|
||||||
|
|
||||||
|
nodes.push(NeuraGraphNodeConstructed {
|
||||||
|
node: constructed,
|
||||||
|
inputs,
|
||||||
|
output: new_index_map
|
||||||
|
.get(node.name())
|
||||||
|
.copied()
|
||||||
|
.unwrap_or_else(|| unreachable!()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let output_shape = shapes[output_index].unwrap_or_else(|| unreachable!());
|
||||||
|
|
||||||
|
Ok(NeuraGraph {
|
||||||
|
nodes,
|
||||||
|
// input_shape,
|
||||||
|
output_shape,
|
||||||
|
output_index,
|
||||||
|
buffer_size: self.nodes.len() + 1,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in new issue