|
|
@ -3,14 +3,12 @@ use std::{any::Any, fmt::Debug};
|
|
|
|
|
|
|
|
|
|
|
|
use crate::{
|
|
|
|
use crate::{
|
|
|
|
algebra::NeuraDynVectorSpace,
|
|
|
|
algebra::NeuraDynVectorSpace,
|
|
|
|
err::NeuraAxisErr,
|
|
|
|
axis::{NeuraAxis, NeuraAxisDefault},
|
|
|
|
network::residual::{NeuraAxisDefault, NeuraCombineInputs, NeuraSplitInputs},
|
|
|
|
|
|
|
|
prelude::{NeuraPartialLayer, NeuraShape},
|
|
|
|
prelude::{NeuraPartialLayer, NeuraShape},
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
use super::*;
|
|
|
|
use super::*;
|
|
|
|
|
|
|
|
|
|
|
|
// TODO: split into two traits
|
|
|
|
|
|
|
|
pub trait NeuraGraphNodePartial<Data>: DynClone + Debug {
|
|
|
|
pub trait NeuraGraphNodePartial<Data>: DynClone + Debug {
|
|
|
|
fn inputs<'a>(&'a self) -> &'a [String];
|
|
|
|
fn inputs<'a>(&'a self) -> &'a [String];
|
|
|
|
fn name<'a>(&'a self) -> &'a str;
|
|
|
|
fn name<'a>(&'a self) -> &'a str;
|
|
|
@ -73,19 +71,11 @@ impl<Axis, Layer> NeuraGraphNode<Axis, Layer> {
|
|
|
|
|
|
|
|
|
|
|
|
pub fn as_boxed<Data: Clone>(self) -> Box<dyn NeuraGraphNodePartial<Data>>
|
|
|
|
pub fn as_boxed<Data: Clone>(self) -> Box<dyn NeuraGraphNodePartial<Data>>
|
|
|
|
where
|
|
|
|
where
|
|
|
|
Axis: NeuraSplitInputs<Data>
|
|
|
|
Axis: NeuraAxis<Data>,
|
|
|
|
+ NeuraCombineInputs<NeuraShape, Combined = Result<NeuraShape, NeuraAxisErr>>
|
|
|
|
|
|
|
|
+ Clone
|
|
|
|
|
|
|
|
+ Debug
|
|
|
|
|
|
|
|
+ 'static,
|
|
|
|
|
|
|
|
Layer: NeuraPartialLayer + Clone + Debug + 'static,
|
|
|
|
Layer: NeuraPartialLayer + Clone + Debug + 'static,
|
|
|
|
Layer::Constructed:
|
|
|
|
Layer::Constructed: NeuraLayer<Axis::Combined, Output = Data>,
|
|
|
|
NeuraLayer<<Axis as NeuraCombineInputs<Data>>::Combined, Output = Data>,
|
|
|
|
|
|
|
|
Layer::Err: Debug,
|
|
|
|
Layer::Err: Debug,
|
|
|
|
<Layer::Constructed as NeuraLayer<
|
|
|
|
<Layer::Constructed as NeuraLayer<Axis::Combined>>::IntermediaryRepr: 'static,
|
|
|
|
<Axis as NeuraCombineInputs<Data>>::Combined,
|
|
|
|
|
|
|
|
>>::IntermediaryRepr: 'static,
|
|
|
|
|
|
|
|
<Axis as NeuraCombineInputs<Data>>::Combined: 'static,
|
|
|
|
|
|
|
|
{
|
|
|
|
{
|
|
|
|
Box::new(self)
|
|
|
|
Box::new(self)
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -95,9 +85,8 @@ impl<Axis, Layer> NeuraGraphNode<Axis, Layer> {
|
|
|
|
intermediary: &'a dyn Any,
|
|
|
|
intermediary: &'a dyn Any,
|
|
|
|
) -> &'a Intermediary<Axis::Combined, Layer>
|
|
|
|
) -> &'a Intermediary<Axis::Combined, Layer>
|
|
|
|
where
|
|
|
|
where
|
|
|
|
Axis: NeuraCombineInputs<Data>,
|
|
|
|
Axis: NeuraAxis<Data>,
|
|
|
|
Layer: NeuraLayer<Axis::Combined>,
|
|
|
|
Layer: NeuraLayer<Axis::Combined>,
|
|
|
|
Axis::Combined: 'static,
|
|
|
|
|
|
|
|
{
|
|
|
|
{
|
|
|
|
intermediary
|
|
|
|
intermediary
|
|
|
|
.downcast_ref::<Intermediary<Axis::Combined, Layer>>()
|
|
|
|
.downcast_ref::<Intermediary<Axis::Combined, Layer>>()
|
|
|
@ -113,23 +102,16 @@ where
|
|
|
|
layer_intermediary: Layer::IntermediaryRepr,
|
|
|
|
layer_intermediary: Layer::IntermediaryRepr,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
impl<
|
|
|
|
impl<Data: Clone, Axis: NeuraAxis<Data>, Layer: NeuraLayer<Axis::Combined, Output = Data>>
|
|
|
|
Data: Clone,
|
|
|
|
NeuraGraphNodeEval<Data> for NeuraGraphNode<Axis, Layer>
|
|
|
|
Axis: NeuraSplitInputs<Data> + Clone + Debug,
|
|
|
|
|
|
|
|
Layer: NeuraLayer<<Axis as NeuraCombineInputs<Data>>::Combined, Output = Data>,
|
|
|
|
|
|
|
|
> NeuraGraphNodeEval<Data> for NeuraGraphNode<Axis, Layer>
|
|
|
|
|
|
|
|
where
|
|
|
|
|
|
|
|
Layer::IntermediaryRepr: 'static,
|
|
|
|
|
|
|
|
Axis::Combined: 'static,
|
|
|
|
|
|
|
|
{
|
|
|
|
{
|
|
|
|
fn eval<'a>(&'a self, inputs: &[Data]) -> Data {
|
|
|
|
fn eval<'a>(&'a self, inputs: &[Data]) -> Data {
|
|
|
|
// TODO: use to_vec_in?
|
|
|
|
let combined = self.axis.combine(inputs);
|
|
|
|
let combined = self.axis.combine(inputs.to_vec());
|
|
|
|
|
|
|
|
self.layer.eval(&combined)
|
|
|
|
self.layer.eval(&combined)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
fn eval_training<'a>(&self, inputs: &[Data]) -> (Data, Box<dyn Any>) {
|
|
|
|
fn eval_training<'a>(&self, inputs: &[Data]) -> (Data, Box<dyn Any>) {
|
|
|
|
let combined = self.axis.combine(inputs.to_vec());
|
|
|
|
let combined = self.axis.combine(inputs);
|
|
|
|
let (result, layer_intermediary) = self.layer.eval_training(&combined);
|
|
|
|
let (result, layer_intermediary) = self.layer.eval_training(&combined);
|
|
|
|
|
|
|
|
|
|
|
|
let intermediary: Intermediary<Axis::Combined, Layer> = Intermediary {
|
|
|
|
let intermediary: Intermediary<Axis::Combined, Layer> = Intermediary {
|
|
|
@ -177,20 +159,11 @@ where
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
impl<
|
|
|
|
impl<Data: Clone, Axis: NeuraAxis<Data>, Layer: NeuraPartialLayer + Clone + Debug>
|
|
|
|
Data: Clone,
|
|
|
|
NeuraGraphNodePartial<Data> for NeuraGraphNode<Axis, Layer>
|
|
|
|
Axis: NeuraSplitInputs<Data>
|
|
|
|
|
|
|
|
+ NeuraCombineInputs<NeuraShape, Combined = Result<NeuraShape, NeuraAxisErr>>
|
|
|
|
|
|
|
|
+ Clone
|
|
|
|
|
|
|
|
+ Debug
|
|
|
|
|
|
|
|
+ 'static,
|
|
|
|
|
|
|
|
Layer: NeuraPartialLayer + Clone + Debug,
|
|
|
|
|
|
|
|
> NeuraGraphNodePartial<Data> for NeuraGraphNode<Axis, Layer>
|
|
|
|
|
|
|
|
where
|
|
|
|
where
|
|
|
|
Layer::Constructed: NeuraLayer<<Axis as NeuraCombineInputs<Data>>::Combined, Output = Data>,
|
|
|
|
Layer::Constructed: NeuraLayer<Axis::Combined, Output = Data>,
|
|
|
|
Layer::Err: Debug,
|
|
|
|
Layer::Err: Debug,
|
|
|
|
<Layer::Constructed as NeuraLayer<<Axis as NeuraCombineInputs<Data>>::Combined>>::IntermediaryRepr: 'static,
|
|
|
|
|
|
|
|
<Axis as NeuraCombineInputs<Data>>::Combined: 'static,
|
|
|
|
|
|
|
|
{
|
|
|
|
{
|
|
|
|
fn inputs<'a>(&'a self) -> &'a [String] {
|
|
|
|
fn inputs<'a>(&'a self) -> &'a [String] {
|
|
|
|
&self.inputs
|
|
|
|
&self.inputs
|
|
|
@ -206,7 +179,7 @@ where
|
|
|
|
) -> Result<(Box<dyn NeuraGraphNodeEval<Data>>, NeuraShape), String> {
|
|
|
|
) -> Result<(Box<dyn NeuraGraphNodeEval<Data>>, NeuraShape), String> {
|
|
|
|
let combined = self
|
|
|
|
let combined = self
|
|
|
|
.axis
|
|
|
|
.axis
|
|
|
|
.combine(input_shapes.clone())
|
|
|
|
.shape(&input_shapes)
|
|
|
|
.map_err(|err| format!("{:?}", err))?;
|
|
|
|
.map_err(|err| format!("{:?}", err))?;
|
|
|
|
|
|
|
|
|
|
|
|
let constructed_layer = self
|
|
|
|
let constructed_layer = self
|
|
|
@ -222,7 +195,7 @@ where
|
|
|
|
axis: self.axis.clone(),
|
|
|
|
axis: self.axis.clone(),
|
|
|
|
layer: constructed_layer,
|
|
|
|
layer: constructed_layer,
|
|
|
|
name: self.name.clone(),
|
|
|
|
name: self.name.clone(),
|
|
|
|
input_shapes: Some(input_shapes)
|
|
|
|
input_shapes: Some(input_shapes),
|
|
|
|
}),
|
|
|
|
}),
|
|
|
|
output_shape,
|
|
|
|
output_shape,
|
|
|
|
))
|
|
|
|
))
|
|
|
|