🔥 WIP implementation of arbitrary neural network DAG

main
Shad Amethyst 2 years ago
parent 38bd61fed5
commit 2f9c334e62

@ -91,3 +91,9 @@ pub struct NeuraDimensionsMismatch {
    pub existing: usize,
    pub new: NeuraShape,
}
#[derive(Clone, Debug)]
pub enum NeuraGraphErr {
    /// A node references an input name that no node (nor the graph input) provides
    MissingNode(String),
    /// The same name was declared twice in the graph
    InvalidName(String),
}
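These two variants cover both failure modes of the graph assembly added below: `MissingNode` when a node lists an unknown input, `InvalidName` on a duplicate declaration. A minimal, hypothetical handling sketch (`partial` and `input_shape` are assumed to exist; the messages are illustrative):

    match partial.construct(input_shape) {
        Ok(graph) => { /* use the constructed NeuraGraph */ }
        Err(NeuraGraphErr::MissingNode(name)) => eprintln!("no node named `{name}`"),
        Err(NeuraGraphErr::InvalidName(name)) => eprintln!("duplicate node name `{name}`"),
    }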

@ -0,0 +1,131 @@
use std::collections::{HashMap, HashSet, VecDeque};

use crate::prelude::*;
use crate::{err::NeuraGraphErr, layer::NeuraShapedLayer};

mod node;
pub use node::*;

pub struct NeuraGraphPartial<Data> {
    pub nodes: Vec<Box<dyn NeuraGraphNodeTrait<Data>>>,
    pub output: String,
    pub input: String,
}
impl<Data> NeuraGraphPartial<Data> {
    fn get_index_map(&self) -> Result<HashMap<String, usize>, NeuraGraphErr> {
        let mut result = HashMap::with_capacity(self.nodes.len());

        result.insert(self.input.clone(), 0);

        for (index, node) in self.nodes.iter().enumerate() {
            if result.contains_key(node.name()) {
                return Err(NeuraGraphErr::InvalidName(node.name().to_string()));
            }
            result.insert(node.name().to_string(), index + 1);
        }

        Ok(result)
    }
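    // Indexing convention used throughout: the graph input gets index 0 and
    // node `i` of `self.nodes` gets index `i + 1`; `get_reverse_graph` and
    // `get_node_order` below both rely on this offset.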
    fn get_reverse_graph(
        &self,
        index_map: &HashMap<String, usize>,
    ) -> Result<Vec<HashSet<usize>>, NeuraGraphErr> {
        // One entry per graph index: 0 is the input, nodes occupy 1..=nodes.len()
        let mut result = vec![HashSet::new(); self.nodes.len() + 1];

        for (index, node) in self.nodes.iter().enumerate() {
            for input in node.inputs() {
                let input_index = index_map
                    .get(input)
                    .copied()
                    .ok_or_else(|| NeuraGraphErr::MissingNode(input.clone()))?;
                result[input_index].insert(index + 1);
            }
        }

        Ok(result)
    }
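    // After this, `reverse_graph[i]` holds the indices of every node that
    // consumes the output of graph index `i`: the edges are reversed from
    // "node lists its inputs" to "input lists its consumers".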
    fn get_node_order(
        &self,
        index_map: &HashMap<String, usize>,
        reverse_graph: &[HashSet<usize>],
    ) -> Result<Vec<usize>, NeuraGraphErr> {
        let mut result: Vec<usize> = Vec::new();
        let mut closed: HashSet<usize> = HashSet::with_capacity(self.nodes.len());
        let mut open = VecDeque::with_capacity(self.nodes.len());
        open.push_front(0usize);
        /* TODO: should the traversal start from the output node instead?
        index_map.get(&self.output)
            .copied()
            .ok_or_else(|| NeuraGraphErr::MissingNode(self.output.clone()))?
        */

        while let Some(current) = open.pop_back() {
            if closed.contains(&current) {
                continue;
            }
            closed.insert(current);
            result.push(current);

            for output_index in reverse_graph[current].iter().copied() {
                assert!(output_index > 0);

                // Ignore nodes that are already in the closed set
                if closed.contains(&output_index) {
                    continue;
                }

                let inputs = self.nodes[output_index - 1].inputs();

                // Only enqueue nodes whose inputs are all in the closed set;
                // without this push the loop would never visit anything past the input
                if !inputs
                    .iter()
                    .all(|input| closed.contains(&index_map[input]))
                {
                    continue;
                }

                open.push_front(output_index);
            }
        }

        Ok(result)
    }
}
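// Worked example for the helpers above, assuming nodes `a` and `b` both read
// the graph input and node `c` reads `[a, b]`:
//   index_map     = {input: 0, a: 1, b: 2, c: 3}
//   reverse_graph = [{1, 2}, {3}, {3}, {}]
//   node order    = e.g. [0, 1, 2, 3] (up to the relative order of `a` and
//   `b`); `c` is only enqueued once both `a` and `b` are in the closed set.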
struct NeuraGraphNodeConstructed<Data> {
    node: Box<dyn NeuraGraphNodeTrait<Data>>,
    inputs: Vec<usize>,
    output: usize,
}

pub struct NeuraGraph<Data> {
    /// ## Class invariants
    ///
    /// - The order of nodes must match the order of execution, i.e.
    ///   `forall (x, y), nodes = [..., x, ..., y, ...] => !(y in x.node.inputs)`
    /// - `nodes[0].inputs = [0]`
    /// - `nodes[nodes.len() - 1].output = buffer.len() - 1`
    nodes: Vec<NeuraGraphNodeConstructed<Data>>,
    input_shape: NeuraShape,
    output_shape: NeuraShape,
}
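// Note: `buffer` in the invariants above presumably refers to the
// evaluation-time value buffer (one slot per graph index) that evaluation
// will fill in; it is not a field of `NeuraGraph` itself.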
impl<Data> NeuraShapedLayer for NeuraGraph<Data> {
    fn output_shape(&self) -> NeuraShape {
        self.output_shape
    }
}

impl<Data> NeuraPartialLayer for NeuraGraphPartial<Data> {
    type Constructed = NeuraGraph<Data>;
    type Err = NeuraGraphErr;

    fn construct(self, input_shape: NeuraShape) -> Result<Self::Constructed, Self::Err> {
        todo!()
    }
}
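`construct` is left as a `todo!()` in this commit. A hedged sketch of how it could chain the helpers above (the shape-propagation step is an assumption, not the author's final design):

    fn construct(self, input_shape: NeuraShape) -> Result<Self::Constructed, Self::Err> {
        let index_map = self.get_index_map()?;
        let reverse_graph = self.get_reverse_graph(&index_map)?;
        let order = self.get_node_order(&index_map, &reverse_graph)?;
        // Remaining work: walk `order`, resolve each node's input shapes,
        // construct its layer, and record buffer indices to build the
        // `NeuraGraphNodeConstructed` list plus the final output shape.
        todo!()
    }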

@ -0,0 +1,56 @@
use crate::{layer::NeuraLayer, network::residual::NeuraSplitInputs};

pub trait NeuraGraphNodeTrait<Data> {
    fn eval(&self, inputs: &[Data]) -> Data;

    fn inputs(&self) -> &[String];

    fn name(&self) -> &str;
}
pub struct NeuraGraphNode<Axis, Layer> {
    inputs: Vec<String>,
    axis: Axis,
    layer: Layer,
    name: String,
}

impl<Axis, Layer> NeuraGraphNode<Axis, Layer> {
    pub fn new(inputs: Vec<String>, axis: Axis, layer: Layer, name: String) -> Self {
        // TODO: check that `name` does not appear in `inputs`?
        Self {
            inputs,
            axis,
            layer,
            name,
        }
    }

    pub fn as_boxed<Data: Clone>(self) -> Box<dyn NeuraGraphNodeTrait<Data>>
    where
        Axis: NeuraSplitInputs<Data> + 'static,
        Layer: NeuraLayer<Axis::Combined, Output = Data> + 'static,
    {
        Box::new(self)
    }
}
impl<
        Data: Clone,
        Axis: NeuraSplitInputs<Data>,
        Layer: NeuraLayer<Axis::Combined, Output = Data>,
    > NeuraGraphNodeTrait<Data> for NeuraGraphNode<Axis, Layer>
{
    fn eval(&self, inputs: &[Data]) -> Data {
        // TODO: use to_vec_in?
        // Combine the incoming values along `axis`, then run the layer on the result
        let combined = self.axis.combine(inputs.to_vec());
        self.layer.eval(&combined)
    }

    fn inputs(&self) -> &[String] {
        &self.inputs
    }

    fn name(&self) -> &str {
        &self.name
    }
}
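A hedged usage sketch for tying a node together (`axis` and `layer` are placeholders for any `NeuraSplitInputs`/`NeuraLayer` implementors; the node names are illustrative):

    let node: Box<dyn NeuraGraphNodeTrait<f64>> = NeuraGraphNode::new(
        vec!["input".to_string(), "dense1".to_string()],
        axis,  // placeholder: any NeuraSplitInputs<f64> implementor
        layer, // placeholder: any NeuraLayer<Axis::Combined, Output = f64>
        "concat1".to_string(),
    )
    .as_boxed();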

@ -1,3 +1,4 @@
pub mod graph;
pub mod residual;
pub mod sequential;
