parent
38bd61fed5
commit
2f9c334e62
@ -0,0 +1,131 @@
|
||||
use std::collections::{HashMap, HashSet, VecDeque};
|
||||
|
||||
use crate::prelude::*;
|
||||
use crate::{err::NeuraGraphErr, layer::NeuraShapedLayer};
|
||||
|
||||
mod node;
|
||||
pub use node::*;
|
||||
|
||||
pub struct NeuraGraphPartial<Data> {
|
||||
pub nodes: Vec<Box<dyn NeuraGraphNodeTrait<Data>>>,
|
||||
pub output: String,
|
||||
pub input: String,
|
||||
}
|
||||
|
||||
impl<Data> NeuraGraphPartial<Data> {
|
||||
fn get_index_map(&self) -> Result<HashMap<String, usize>, NeuraGraphErr> {
|
||||
let mut result = HashMap::with_capacity(self.nodes.len());
|
||||
|
||||
result.insert(self.input.clone(), 0);
|
||||
|
||||
for (index, node) in self.nodes.iter().enumerate() {
|
||||
if result.contains_key(node.name()) {
|
||||
return Err(NeuraGraphErr::InvalidName(node.name().to_string()));
|
||||
}
|
||||
result.insert(node.name().to_string(), index + 1);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn get_reverse_graph(
|
||||
&self,
|
||||
index_map: &HashMap<String, usize>,
|
||||
) -> Result<Vec<HashSet<usize>>, NeuraGraphErr> {
|
||||
let mut result = vec![HashSet::new(); self.nodes.len()];
|
||||
|
||||
for (index, node) in self.nodes.iter().enumerate() {
|
||||
for input in node.inputs() {
|
||||
let input_index = index_map
|
||||
.get(input)
|
||||
.copied()
|
||||
.ok_or_else(|| NeuraGraphErr::MissingNode(input.clone()))?;
|
||||
result[input_index].insert(index + 1);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn get_node_order(
|
||||
&self,
|
||||
index_map: &HashMap<String, usize>,
|
||||
reverse_graph: &Vec<HashSet<usize>>,
|
||||
) -> Result<Vec<usize>, NeuraGraphErr> {
|
||||
let mut result: Vec<usize> = Vec::new();
|
||||
let mut closed: HashSet<usize> = HashSet::with_capacity(self.nodes.len());
|
||||
let mut open = VecDeque::with_capacity(self.nodes.len());
|
||||
open.push_front(0usize);
|
||||
|
||||
/*
|
||||
index_map.get(&self.output)
|
||||
.copied()
|
||||
.ok_or_else(|| NeuraGraphErr::MissingNode(self.output.clone()))?
|
||||
*/
|
||||
|
||||
while let Some(current) = open.pop_back() {
|
||||
if closed.contains(¤t) {
|
||||
continue;
|
||||
}
|
||||
|
||||
closed.insert(current);
|
||||
result.push(current);
|
||||
|
||||
for output_index in reverse_graph[current].iter().copied() {
|
||||
assert!(output_index > 0);
|
||||
|
||||
// Ignore nodes that are already in the closed set
|
||||
if closed.contains(&output_index) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let inputs = self.nodes[output_index - 1].inputs();
|
||||
|
||||
// Only consider nodes whose inputs are in the closed set
|
||||
if !inputs
|
||||
.iter()
|
||||
.all(|input| closed.contains(&index_map[input]))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
struct NeuraGraphNodeConstructed<Data> {
|
||||
node: Box<dyn NeuraGraphNodeTrait<Data>>,
|
||||
inputs: Vec<usize>,
|
||||
output: usize,
|
||||
}
|
||||
|
||||
pub struct NeuraGraph<Data> {
|
||||
/// ## Class invariants
|
||||
///
|
||||
/// - The order of nodes should match with the order of execution, ie.
|
||||
/// `forall (x, y), nodes = [..., x, ..., y, ...] => !(y in x.node.inputs)`
|
||||
/// - `nodes[0].inputs = [0]`
|
||||
/// - `nodes[nodes.len() - 1].output = buffer.len() - 1`
|
||||
nodes: Vec<NeuraGraphNodeConstructed<Data>>,
|
||||
|
||||
input_shape: NeuraShape,
|
||||
output_shape: NeuraShape,
|
||||
}
|
||||
|
||||
impl<Data> NeuraShapedLayer for NeuraGraph<Data> {
|
||||
fn output_shape(&self) -> NeuraShape {
|
||||
self.output_shape
|
||||
}
|
||||
}
|
||||
|
||||
impl<Data> NeuraPartialLayer for NeuraGraphPartial<Data> {
|
||||
type Constructed = NeuraGraph<Data>;
|
||||
|
||||
type Err = NeuraGraphErr;
|
||||
|
||||
fn construct(self, input_shape: NeuraShape) -> Result<Self::Constructed, Self::Err> {
|
||||
todo!()
|
||||
}
|
||||
}
|
@ -0,0 +1,56 @@
|
||||
use crate::{layer::NeuraLayer, network::residual::NeuraSplitInputs};
|
||||
|
||||
pub trait NeuraGraphNodeTrait<Data> {
|
||||
fn eval<'a>(&'a self, inputs: &[Data]) -> Data;
|
||||
|
||||
fn inputs<'a>(&'a self) -> &'a [String];
|
||||
fn name<'a>(&'a self) -> &'a str;
|
||||
}
|
||||
|
||||
pub struct NeuraGraphNode<Axis, Layer> {
|
||||
inputs: Vec<String>,
|
||||
axis: Axis,
|
||||
layer: Layer,
|
||||
name: String,
|
||||
}
|
||||
|
||||
impl<Axis, Layer> NeuraGraphNode<Axis, Layer> {
|
||||
pub fn new(inputs: Vec<String>, axis: Axis, layer: Layer, name: String) -> Self {
|
||||
// Check that `name not in inputs` ?
|
||||
Self {
|
||||
inputs,
|
||||
axis,
|
||||
layer,
|
||||
name,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_boxed<Data: Clone>(self) -> Box<dyn NeuraGraphNodeTrait<Data>>
|
||||
where
|
||||
Axis: NeuraSplitInputs<Data> + 'static,
|
||||
Layer: NeuraLayer<Axis::Combined, Output = Data> + 'static,
|
||||
{
|
||||
Box::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
Data: Clone,
|
||||
Axis: NeuraSplitInputs<Data>,
|
||||
Layer: NeuraLayer<Axis::Combined, Output = Data>,
|
||||
> NeuraGraphNodeTrait<Data> for NeuraGraphNode<Axis, Layer>
|
||||
{
|
||||
fn eval<'a>(&'a self, inputs: &[Data]) -> Data {
|
||||
// TODO: use to_vec_in?
|
||||
let combined = self.axis.combine(inputs.to_vec());
|
||||
self.layer.eval(&combined)
|
||||
}
|
||||
|
||||
fn inputs<'a>(&'a self) -> &'a [String] {
|
||||
&self.inputs
|
||||
}
|
||||
|
||||
fn name<'a>(&'a self) -> &'a str {
|
||||
&self.name
|
||||
}
|
||||
}
|
Loading…
Reference in new issue