🎨 Clean things up, add unit tests, add one hot layer

main
Shad Amethyst 2 years ago
parent bca56a5557
commit b4a97694a6

@ -10,3 +10,7 @@ ndarray = "^0.15"
# num-traits = "0.2.15"
rand = "^0.8"
rand_distr = "0.4.3"
[dev-dependencies]
image = "0.24.6"
viuer = "0.6.2"

@ -2,85 +2,112 @@
use std::io::Write;
use neuramethyst::derivable::activation::Linear;
#[allow(unused_imports)]
use neuramethyst::derivable::activation::{LeakyRelu, Relu, Tanh};
use neuramethyst::derivable::loss::Euclidean;
use neuramethyst::derivable::regularize::NeuraElastic;
use neuramethyst::derivable::activation::{LeakyRelu, Linear, Relu, Tanh};
use neuramethyst::derivable::loss::CrossEntropy;
use neuramethyst::derivable::regularize::NeuraL1;
use neuramethyst::prelude::*;
use rand::Rng;
fn main() {
let mut network = neura_network![
neura_layer!("dense", 2, 8; LeakyRelu(0.01)),
neura_layer!("dropout", 0.1),
neura_layer!("dense", 8; LeakyRelu(0.01), NeuraElastic::new(0.0001, 0.002)),
neura_layer!("dropout", 0.3),
neura_layer!("dense", 8; LeakyRelu(0.01), NeuraElastic::new(0.0001, 0.002)),
neura_layer!("dropout", 0.1),
neura_layer!("dense", 4; LeakyRelu(0.1), NeuraElastic::new(0.0001, 0.002)),
neura_layer!("dense", 2; Linear),
neura_layer!("dense", 2, 8; Relu, NeuraL1(0.001)),
neura_layer!("dropout", 0.25),
neura_layer!("dense", 2; Linear, NeuraL1(0.001)),
neura_layer!("softmax"),
];
// println!("{:#?}", network);
let mut rng = rand::thread_rng();
let inputs = (0..=1).cycle().map(move |category| {
let inputs = (0..1).cycle().map(move |_| {
let mut rng = rand::thread_rng(); // TODO: move out
let category = rng.gen_bool(0.5) as usize;
let (x, y) = if category == 0 {
let radius: f64 = rng.gen_range(0.0..1.0);
let radius = radius.sqrt();
let radius: f64 = rng.gen_range(0.0..2.0);
let angle = rng.gen_range(0.0..std::f64::consts::TAU);
(angle.cos() * radius, angle.sin() * radius)
} else {
let radius: f64 = rng.gen_range(1.0..2.0);
let radius: f64 = rng.gen_range(3.0..5.0);
let angle = rng.gen_range(0.0..std::f64::consts::TAU);
(angle.cos() * radius, angle.sin() * radius)
};
([x, y], one_hot::<2>(category))
([x, y], neuramethyst::one_hot::<2>(category))
});
let test_inputs: Vec<_> = inputs.clone().take(100).collect();
let test_inputs: Vec<_> = inputs.clone().take(10).collect();
let mut trainer = NeuraBatchedTrainer::new(0.25, 1000);
trainer.log_epochs = 50;
trainer.learning_momentum = 0.05;
trainer.batch_size = 2000;
if std::env::args().any(|arg| arg == "draw") {
for epoch in 0..200 {
let mut trainer = NeuraBatchedTrainer::new(0.03, 10);
trainer.batch_size = 10;
trainer.train(
NeuraBackprop::new(Euclidean),
&mut network,
inputs,
&test_inputs,
);
trainer.train(
NeuraBackprop::new(CrossEntropy),
&mut network,
inputs.clone(),
&test_inputs,
);
let network = network.clone();
draw_neuron_activation(|input| network.eval(&input).into_iter().collect(), 6.0);
println!("{}", epoch);
std::thread::sleep(std::time::Duration::new(0, 50_000_000));
}
} else {
let mut trainer = NeuraBatchedTrainer::new(0.03, 20 * 50);
trainer.batch_size = 10;
trainer.log_iterations = 20;
trainer.train(
NeuraBackprop::new(CrossEntropy),
&mut network,
inputs.clone(),
&test_inputs,
);
// println!("{}", String::from("\n").repeat(64));
// draw_neuron_activation(|input| network.eval(&input).into_iter().collect(), 6.0);
}
let mut file = std::fs::File::create("target/bivariate.csv").unwrap();
for (input, _target) in test_inputs {
let guess = argmax(&network.eval(&input));
let guess = neuramethyst::argmax(&network.eval(&input));
writeln!(&mut file, "{},{},{}", input[0], input[1], guess).unwrap();
// println!("{:?}", network.eval(&input));
}
// println!("{:#?}", network);
}
fn one_hot<const N: usize>(value: usize) -> [f64; N] {
let mut res = [0.0; N];
if value < N {
res[value] = 1.0;
// TODO: move this to the library?
fn draw_neuron_activation<F: Fn([f64; 2]) -> Vec<f64>>(callback: F, scale: f64) {
use viuer::Config;
const WIDTH: u32 = 64;
const HEIGHT: u32 = 64;
let mut image = image::RgbImage::new(WIDTH, HEIGHT);
fn sigmoid(x: f64) -> f64 {
1.0 / (1.0 + (-x * 3.0).exp())
}
res
}
fn argmax(array: &[f64]) -> usize {
let mut res = 0;
for y in 0..HEIGHT {
let y2 = 2.0 * y as f64 / HEIGHT as f64 - 1.0;
for x in 0..WIDTH {
let x2 = 2.0 * x as f64 / WIDTH as f64 - 1.0;
let activation = callback([x2 * scale, y2 * scale]);
let r = (sigmoid(activation.get(0).copied().unwrap_or(-1.0)) * 255.0).floor() as u8;
let g = (sigmoid(activation.get(1).copied().unwrap_or(-1.0)) * 255.0).floor() as u8;
let b = (sigmoid(activation.get(2).copied().unwrap_or(-1.0)) * 255.0).floor() as u8;
for n in 1..array.len() {
if array[n] > array[res] {
res = n;
*image.get_pixel_mut(x, y) = image::Rgb([r, g, b]);
}
}
res
let config = Config {
use_kitty: false,
// absolute_offset: false,
..Default::default()
};
viuer::print(&image::DynamicImage::ImageRgb8(image), &config).unwrap();
}

@ -29,7 +29,7 @@ fn main() {
let mut trainer = NeuraBatchedTrainer::new(0.05, 1000);
trainer.batch_size = 6;
trainer.log_epochs = 250;
trainer.log_iterations = 250;
trainer.learning_momentum = 0.01;
trainer.train(

@ -5,6 +5,8 @@ pub trait NeuraVectorSpace {
fn mul_assign(&mut self, by: f64);
fn zero() -> Self;
fn norm_squared(&self) -> f64;
}
impl NeuraVectorSpace for () {
@ -22,6 +24,10 @@ impl NeuraVectorSpace for () {
fn zero() -> Self {
()
}
fn norm_squared(&self) -> f64 {
0.0
}
}
impl<Left: NeuraVectorSpace, Right: NeuraVectorSpace> NeuraVectorSpace for (Left, Right) {
@ -38,6 +44,10 @@ impl<Left: NeuraVectorSpace, Right: NeuraVectorSpace> NeuraVectorSpace for (Left
fn zero() -> Self {
(Left::zero(), Right::zero())
}
fn norm_squared(&self) -> f64 {
self.0.norm_squared() + self.1.norm_squared()
}
}
impl<const N: usize, T: NeuraVectorSpace + Clone> NeuraVectorSpace for [T; N] {
@ -65,6 +75,10 @@ impl<const N: usize, T: NeuraVectorSpace + Clone> NeuraVectorSpace for [T; N] {
unreachable!()
})
}
fn norm_squared(&self) -> f64 {
self.iter().map(T::norm_squared).sum()
}
}
macro_rules! base {
@ -81,6 +95,10 @@ macro_rules! base {
fn zero() -> Self {
<Self as Default>::default()
}
fn norm_squared(&self) -> f64 {
(self * self) as f64
}
}
};
}

@ -1,135 +1,108 @@
use super::NeuraDerivable;
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Relu;
#![allow(unused_variables)]
impl NeuraDerivable<f64> for Relu {
#[inline(always)]
fn eval(&self, input: f64) -> f64 {
input.max(0.0)
}
use super::NeuraDerivable;
#[inline(always)]
fn derivate(&self, input: f64) -> f64 {
if input > 0.0 {
1.0
} else {
0.0
macro_rules! impl_derivable {
( $type_f32:ty, $type_f64:ty, $self:ident, $variable:ident, $eval:expr, $derivate:expr $(; $variance_hint:expr, $bias_hint:expr )? ) => {
impl NeuraDerivable<f32> for $type_f32 {
#[inline(always)]
fn eval($self: &Self, $variable: f32) -> f32 {
$eval
}
#[inline(always)]
fn derivate($self: &Self, $variable: f32) -> f32 {
$derivate
}
$(
#[inline(always)]
fn variance_hint($self: &Self) -> f64 {
$variance_hint
}
#[inline(always)]
fn bias_hint($self: &Self) -> f64 {
$bias_hint
}
)?
}
}
}
impl NeuraDerivable<f32> for Relu {
#[inline(always)]
fn eval(&self, input: f32) -> f32 {
input.max(0.0)
}
#[inline(always)]
fn derivate(&self, input: f32) -> f32 {
if input > 0.0 {
1.0
} else {
0.0
impl NeuraDerivable<f64> for $type_f64 {
#[inline(always)]
fn eval($self: &Self, $variable: f64) -> f64 {
$eval
}
#[inline(always)]
fn derivate($self: &Self, $variable: f64) -> f64 {
$derivate
}
$(
#[inline(always)]
fn variance_hint($self: &Self) -> f64 {
$variance_hint
}
#[inline(always)]
fn bias_hint($self: &Self) -> f64 {
$bias_hint
}
)?
}
}
};
( $type:ty, $variable:ident, $eval:expr, $derivate:expr $(; $variance_hint:expr, $bias_hint:expr )? ) => {
impl_derivable!($type, $type, self, $variable, $eval, $derivate $(; $variance_hint, $bias_hint)?);
};
}
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct LeakyRelu<F>(pub F);
pub struct Relu;
impl NeuraDerivable<f64> for LeakyRelu<f64> {
#[inline(always)]
fn eval(&self, input: f64) -> f64 {
if input > 0.0 {
input
} else {
self.0 * input
}
impl_derivable!(Relu, x, x.max(0.0), {
if x > 0.0 {
1.0
} else {
0.0
}
}; 2.0, 0.1);
#[inline(always)]
fn derivate(&self, input: f64) -> f64 {
if input > 0.0 {
1.0
} else {
self.0
}
}
}
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct LeakyRelu<F>(pub F);
impl NeuraDerivable<f32> for LeakyRelu<f32> {
#[inline(always)]
fn eval(&self, input: f32) -> f32 {
if input > 0.0 {
input
impl_derivable!(
LeakyRelu<f32>,
LeakyRelu<f64>,
self,
x,
{
if x > 0.0 {
x
} else {
self.0 * input
self.0 * x
}
}
#[inline(always)]
fn derivate(&self, input: f32) -> f32 {
if input > 0.0 {
},
{
if x > 0.0 {
1.0
} else {
self.0
}
}
}
};
2.0, 0.1
);
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Tanh;
impl NeuraDerivable<f64> for Tanh {
#[inline(always)]
fn eval(&self, input: f64) -> f64 {
0.5 * input.tanh() + 0.5
}
#[inline(always)]
fn derivate(&self, at: f64) -> f64 {
let tanh = at.tanh();
0.5 * (1.0 - tanh * tanh)
}
}
impl NeuraDerivable<f32> for Tanh {
#[inline(always)]
fn eval(&self, input: f32) -> f32 {
0.5 * input.tanh() + 0.5
}
#[inline(always)]
fn derivate(&self, at: f32) -> f32 {
let tanh = at.tanh();
0.5 * (1.0 - tanh * tanh)
}
}
impl_derivable!(Tanh, x, x.tanh(), {
let y = x.tanh();
1.0 - y * y
});
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Linear;
impl NeuraDerivable<f64> for Linear {
#[inline(always)]
fn eval(&self, input: f64) -> f64 {
input
}
#[inline(always)]
fn derivate(&self, _at: f64) -> f64 {
1.0
}
}
impl NeuraDerivable<f32> for Linear {
#[inline(always)]
fn eval(&self, input: f32) -> f32 {
input
}
#[inline(always)]
fn derivate(&self, _at: f32) -> f32 {
1.0
}
}
impl_derivable!(Linear, x, x, 1.0);

@ -30,3 +30,56 @@ impl<const N: usize> NeuraLoss for Euclidean<N> {
res
}
}
/// The cross-entropy loss function, defined as `L(y, ŷ) = -Σᵢ(yᵢ*ln(ŷᵢ))`.
///
/// This version of the cross-entropy function does not make assumptions about the target vector being one-hot encoded.
///
/// This function requires that `ŷ` (the output of the neural network) is in `[0; 1]^n`.
/// This guarantee is notably not given by the `Relu`, `LeakyRelu` and `Swish` activation functions,
/// so you should pick another activation on the last layer, or pass it into a `NeuraSoftmax` layer.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct CrossEntropy<const N: usize>;
const DERIVATIVE_CAP: f64 = 100.0;
const LOG_MIN: f64 = 0.00001;
impl<const N: usize> CrossEntropy<N> {
#[inline(always)]
pub fn eval_single(&self, target: f64, actual: f64) -> f64 {
-target * actual.max(LOG_MIN).log(std::f64::consts::E)
}
#[inline(always)]
pub fn derivate_single(&self, target: f64, actual: f64) -> f64 {
-(target / actual).min(DERIVATIVE_CAP)
}
}
impl<const N: usize> NeuraLoss for CrossEntropy<N> {
type Input = [f64; N];
type Target = [f64; N];
fn eval(&self, target: &Self::Target, actual: &Self::Input) -> f64 {
let mut result = 0.0;
for i in 0..N {
result += self.eval_single(target[i], actual[i]);
}
result
}
fn nabla(&self, target: &Self::Target, actual: &Self::Input) -> Self::Input {
let mut result = [0.0; N];
for i in 0..N {
result[i] = self.derivate_single(target[i], actual[i]);
}
result
}
}
// TODO: a one-hot encoded, CrossEntropy + Softmax loss function?
// It would be a lot more efficient than the current method

@ -7,6 +7,18 @@ pub trait NeuraDerivable<F> {
/// Should return the derivative of `self.eval(input)`
fn derivate(&self, at: F) -> F;
/// Should return a hint for how much the variance for a random initialization should be
#[inline(always)]
fn variance_hint(&self) -> f64 {
1.0
}
/// Should return a hint for what the default bias value should be
#[inline(always)]
fn bias_hint(&self) -> f64 {
0.0
}
}
pub trait NeuraLoss {

@ -73,7 +73,7 @@ pub struct NeuraL2<F>(pub F);
impl NeuraDerivable<f64> for NeuraL2<f64> {
#[inline(always)]
fn eval(&self, input: f64) -> f64 {
self.0 * (input * input)
(0.5 * self.0) * (input * input)
}
#[inline(always)]
@ -85,7 +85,7 @@ impl NeuraDerivable<f64> for NeuraL2<f64> {
impl NeuraDerivable<f32> for NeuraL2<f32> {
#[inline(always)]
fn eval(&self, input: f32) -> f32 {
self.0 * (input * input)
(0.5 * self.0) * (input * input)
}
#[inline(always)]

@ -46,8 +46,14 @@ impl<
pub fn from_rng(rng: &mut impl Rng, activation: Act, regularization: Reg) -> Self {
let mut weights = [[0.0; INPUT_LEN]; OUTPUT_LEN];
let distribution =
rand_distr::Normal::new(0.0, 2.0 / (INPUT_LEN as f64 + OUTPUT_LEN as f64)).unwrap();
// Use Xavier (or He) initialisation, using the harmonic mean
// Ref: https://www.deeplearning.ai/ai-notes/initialization/index.html
let distribution = rand_distr::Normal::new(
0.0,
activation.variance_hint() * 2.0 / (INPUT_LEN as f64 + OUTPUT_LEN as f64),
)
.unwrap();
// let distribution = rand_distr::Uniform::new(-0.5, 0.5);
for i in 0..OUTPUT_LEN {
for j in 0..INPUT_LEN {
@ -57,8 +63,8 @@ impl<
Self {
weights,
// Biases are zero-initialized, as this shouldn't cause any issues during training
bias: [0.0; OUTPUT_LEN],
// Biases are initialized based on the activation's hint
bias: [activation.bias_hint(); OUTPUT_LEN],
activation,
regularization,
}
@ -96,20 +102,22 @@ impl<
{
type Delta = ([[f64; INPUT_LEN]; OUTPUT_LEN], [f64; OUTPUT_LEN]);
// TODO: double-check the math in this
fn backpropagate(
&self,
input: &Self::Input,
epsilon: Self::Output,
) -> (Self::Input, Self::Delta) {
let evaluated = multiply_matrix_vector(&self.weights, input);
// Compute delta from epsilon, with `self.activation'(input) ° epsilon = delta`
// Compute delta (the input gradient of the neuron) from epsilon (the output gradient of the neuron),
// with `self.activation'(input) ° epsilon = delta`
let mut delta = epsilon.clone();
for i in 0..OUTPUT_LEN {
delta[i] *= self.activation.derivate(evaluated[i]);
}
// Compute the weight gradient
let weights_gradient = reverse_dot_product(&delta, input);
// According to https://datascience.stackexchange.com/questions/20139/gradients-for-bias-terms-in-backpropagation
// The gradient of the bias is equal to the delta term of the backpropagation algorithm
let bias_gradient = delta;

@ -7,6 +7,9 @@ pub use dropout::NeuraDropoutLayer;
mod softmax;
pub use softmax::NeuraSoftmaxLayer;
mod one_hot;
pub use one_hot::NeuraOneHotLayer;
pub trait NeuraLayer {
type Input;
type Output;
@ -46,4 +49,8 @@ macro_rules! neura_layer {
( "softmax", $length:expr ) => {
$crate::layer::NeuraSoftmaxLayer::new() as $crate::layer::NeuraSoftmaxLayer<$length>
};
( "one_hot" ) => {
$crate::layer::NeuraOneHotLayer as $crate::layer::NeuraOneHotLayer<2, _>
};
}

@ -0,0 +1,61 @@
use crate::train::NeuraTrainableLayer;
use super::NeuraLayer;
/// A special layer that allows you to split a vector into one-hot vectors
#[derive(Debug, Clone, PartialEq)]
pub struct NeuraOneHotLayer<const CATS: usize, const LENGTH: usize>;
impl<const CATS: usize, const LENGTH: usize> NeuraLayer for NeuraOneHotLayer<CATS, LENGTH>
where
[(); LENGTH * CATS]: Sized,
{
type Input = [f64; LENGTH];
type Output = [f64; LENGTH * CATS];
fn eval(&self, input: &Self::Input) -> Self::Output {
let mut res = [0.0; LENGTH * CATS];
for i in 0..LENGTH {
let cat_low = input[i].floor().max(0.0).min(CATS as f64 - 2.0);
let amount = (input[i] - cat_low).max(0.0).min(1.0);
let cat_low = cat_low as usize;
res[i * LENGTH + cat_low] = 1.0 - amount;
res[i * LENGTH + cat_low + 1] = amount;
}
res
}
}
impl<const CATS: usize, const LENGTH: usize> NeuraTrainableLayer for NeuraOneHotLayer<CATS, LENGTH>
where
[(); LENGTH * CATS]: Sized,
{
type Delta = ();
fn backpropagate(
&self,
input: &Self::Input,
epsilon: Self::Output,
) -> (Self::Input, Self::Delta) {
let mut res = [0.0; LENGTH];
for i in 0..LENGTH {
let cat_low = input[i].floor().max(0.0).min(CATS as f64 - 2.0) as usize;
let epsilon = -epsilon[i * LENGTH + cat_low] + epsilon[i * LENGTH + cat_low + 1];
// Scale epsilon by how many entries were ignored
res[i] = epsilon * CATS as f64 / 2.0;
}
(res, ())
}
fn regularize(&self) -> Self::Delta {
()
}
fn apply_gradient(&mut self, _gradient: &Self::Delta) {
// Noop
}
}

@ -1,4 +1,6 @@
#![feature(generic_arg_infer)]
#![feature(generic_associated_types)]
#![feature(generic_const_exprs)]
pub mod algebra;
pub mod derivable;
@ -8,13 +10,16 @@ pub mod train;
mod utils;
// TODO: move to a different file
pub use utils::{argmax, one_hot};
pub mod prelude {
// Macros
pub use crate::{neura_layer, neura_network};
// Structs and traits
pub use crate::layer::{NeuraDenseLayer, NeuraDropoutLayer, NeuraLayer};
pub use crate::network::NeuraNetwork;
pub use crate::network::{NeuraNetwork, NeuraNetworkTail};
pub use crate::train::{NeuraBackprop, NeuraBatchedTrainer};
pub use crate::utils::cycle_shuffling;
}

@ -6,8 +6,8 @@ use crate::{
#[derive(Clone, Debug)]
pub struct NeuraNetwork<Layer: NeuraLayer, ChildNetwork> {
layer: Layer,
child_network: ChildNetwork,
pub layer: Layer,
pub child_network: ChildNetwork,
}
impl<Layer: NeuraLayer, ChildNetwork> NeuraNetwork<Layer, ChildNetwork> {
@ -25,20 +25,66 @@ impl<Layer: NeuraLayer, ChildNetwork> NeuraNetwork<Layer, ChildNetwork> {
Self::new(layer, child_network)
}
pub fn child_network(&self) -> &ChildNetwork {
&self.child_network
pub fn trim_front(self) -> ChildNetwork {
self.child_network
}
pub fn layer(&self) -> &Layer {
&self.layer
pub fn push_front<T: NeuraLayer>(self, layer: T) -> NeuraNetwork<T, Self> {
NeuraNetwork {
layer: layer,
child_network: self,
}
}
}
impl<Layer: NeuraLayer> From<Layer> for NeuraNetwork<Layer, ()> {
fn from(layer: Layer) -> Self {
Self {
layer,
child_network: (),
/// Operations on the tail end of the network
pub trait NeuraNetworkTail {
type TailTrimmed;
type TailPushed<T: NeuraLayer>;
fn trim_tail(self) -> Self::TailTrimmed;
fn push_tail<T: NeuraLayer>(self, layer: T) -> Self::TailPushed<T>;
}
// Trimming the last layer returns an empty network
impl<Layer: NeuraLayer> NeuraNetworkTail for NeuraNetwork<Layer, ()> {
type TailTrimmed = ();
type TailPushed<T: NeuraLayer> = NeuraNetwork<Layer, NeuraNetwork<T, ()>>;
fn trim_tail(self) -> Self::TailTrimmed {
()
}
fn push_tail<T: NeuraLayer>(self, layer: T) -> Self::TailPushed<T> {
NeuraNetwork {
layer: self.layer,
child_network: NeuraNetwork {
layer,
child_network: (),
},
}
}
}
// Trimming another layer returns a network which calls trim recursively
impl<Layer: NeuraLayer, ChildNetwork: NeuraNetworkTail> NeuraNetworkTail
for NeuraNetwork<Layer, ChildNetwork>
{
type TailTrimmed = NeuraNetwork<Layer, <ChildNetwork as NeuraNetworkTail>::TailTrimmed>;
type TailPushed<T: NeuraLayer> =
NeuraNetwork<Layer, <ChildNetwork as NeuraNetworkTail>::TailPushed<T>>;
fn trim_tail(self) -> Self::TailTrimmed {
NeuraNetwork {
layer: self.layer,
child_network: self.child_network.trim_tail(),
}
}
fn push_tail<T: NeuraLayer>(self, layer: T) -> Self::TailPushed<T> {
NeuraNetwork {
layer: self.layer,
child_network: self.child_network.push_tail(layer),
}
}
}
@ -136,6 +182,15 @@ impl<Layer: NeuraTrainableLayer, ChildNetwork: NeuraTrainable<Input = Layer::Out
}
}
impl<Layer: NeuraLayer> From<Layer> for NeuraNetwork<Layer, ()> {
fn from(layer: Layer) -> Self {
Self {
layer,
child_network: (),
}
}
}
#[macro_export]
macro_rules! neura_network {
[] => {
@ -143,11 +198,11 @@ macro_rules! neura_network {
};
[ $layer:expr $(,)? ] => {
NeuraNetwork::from($layer)
$crate::network::NeuraNetwork::from($layer)
};
[ $first:expr, $($rest:expr),+ $(,)? ] => {
NeuraNetwork::new_match_output($first, neura_network![$($rest),+])
$crate::network::NeuraNetwork::new_match_output($first, neura_network![$($rest),+])
};
}
@ -159,8 +214,6 @@ mod test {
neura_layer,
};
use super::*;
#[test]
fn test_neura_network_macro() {
let mut rng = rand::thread_rng();

@ -2,7 +2,7 @@ use crate::{
algebra::NeuraVectorSpace, derivable::NeuraLoss, layer::NeuraLayer, network::NeuraNetwork,
};
// TODO: move this to layer/mod.rs
// TODO: move this trait to layer/mod.rs
pub trait NeuraTrainableLayer: NeuraLayer {
type Delta: NeuraVectorSpace;
@ -29,7 +29,7 @@ pub trait NeuraTrainableLayer: NeuraLayer {
/// Applies `δW_l` to the weights of the layer
fn apply_gradient(&mut self, gradient: &Self::Delta);
/// Called before an epoch begins, to allow the layer to set itself up for training.
/// Called before an iteration begins, to allow the layer to set itself up for training.
#[inline(always)]
fn prepare_epoch(&mut self) {}
@ -54,7 +54,7 @@ pub trait NeuraTrainable: NeuraLayer {
/// Should return the regularization gradient
fn regularize(&self) -> Self::Delta;
/// Called before an epoch begins, to allow the network to set itself up for training.
/// Called before an iteration begins, to allow the network to set itself up for training.
fn prepare_epoch(&mut self);
/// Called at the end of training, to allow the network to clean itself up
@ -116,7 +116,8 @@ impl<const N: usize, Loss: NeuraLoss<Input = [f64; N]> + Clone>
where
NeuraNetwork<Layer, ChildNetwork>: NeuraTrainable<Input = Layer::Input, Output = [f64; N]>,
{
self.loss.eval(target, &trainable.eval(&input))
let output = trainable.eval(&input);
self.loss.eval(target, &output)
}
}
@ -138,15 +139,17 @@ pub struct NeuraBatchedTrainer {
/// How many gradient computations to average before updating the weights
pub batch_size: usize,
/// How many batches to run for; if `epochs * batch_size` exceeds the input length, then training will stop.
/// How many batches to run for; if `iterations * batch_size` exceeds the input length, then training will stop.
/// You should use `cycle_shuffling` from the `prelude` module to avoid this.
pub epochs: usize,
///
/// Note that this is different from epochs, which count how many times the dataset has been fully iterated over.
pub iterations: usize,
/// The trainer will log progress at every multiple of `log_epochs` steps.
/// If `log_epochs` is zero (default), then no progress will be logged.
/// The trainer will log progress at every multiple of `log_iterations` iterations.
/// If `log_iterations` is zero (default), then no progress will be logged.
///
/// The test inputs is used to measure the score of the network.
pub log_epochs: usize,
pub log_iterations: usize,
}
impl Default for NeuraBatchedTrainer {
@ -155,17 +158,17 @@ impl Default for NeuraBatchedTrainer {
learning_rate: 0.1,
learning_momentum: 0.0,
batch_size: 100,
epochs: 100,
log_epochs: 0,
iterations: 100,
log_iterations: 0,
}
}
}
impl NeuraBatchedTrainer {
pub fn new(learning_rate: f64, epochs: usize) -> Self {
pub fn new(learning_rate: f64, iterations: usize) -> Self {
Self {
learning_rate,
epochs,
iterations,
..Default::default()
}
}
@ -195,7 +198,7 @@ impl NeuraBatchedTrainer {
// Contains `momentum_factor * factor * gradient_sum_previous_iter`
let mut previous_gradient_sum =
<NeuraNetwork<Layer, ChildNetwork> as NeuraTrainable>::Delta::zero();
'd: for epoch in 0..self.epochs {
'd: for iteration in 0..self.iterations {
let mut gradient_sum =
<NeuraNetwork<Layer, ChildNetwork> as NeuraTrainable>::Delta::zero();
network.prepare_epoch();
@ -211,7 +214,7 @@ impl NeuraBatchedTrainer {
gradient_sum.mul_assign(factor);
// Add regularization gradient (TODO: check if it can be factored out of momentum)
// Add regularization gradient
let mut reg_gradient = network.regularize();
reg_gradient.mul_assign(reg_factor);
gradient_sum.add_assign(&reg_gradient);
@ -224,14 +227,14 @@ impl NeuraBatchedTrainer {
previous_gradient_sum.mul_assign(momentum_factor);
}
if self.log_epochs > 0 && (epoch + 1) % self.log_epochs == 0 {
if self.log_iterations > 0 && (iteration + 1) % self.log_iterations == 0 {
network.cleanup();
let mut loss_sum = 0.0;
for (input, target) in test_inputs {
loss_sum += gradient_solver.score(&network, input, target);
}
loss_sum /= test_inputs.len() as f64;
println!("Epoch {}, Loss: {:.3}", epoch + 1, loss_sum);
println!("Iteration {}, Loss: {:.3}", iteration + 1, loss_sum);
}
}
@ -243,8 +246,11 @@ impl NeuraBatchedTrainer {
mod test {
use super::*;
use crate::{
assert_approx,
derivable::{activation::Linear, loss::Euclidean, regularize::NeuraL0},
layer::NeuraDenseLayer,
network::NeuraNetworkTail,
neura_network,
};
#[test]
@ -263,4 +269,40 @@ mod test {
}
}
}
#[test]
fn test_backpropagation_complex() {
const EPSILON: f64 = 0.00001;
// Test that we get the same values as https://hmkcode.com/ai/backpropagation-step-by-step/
let network = neura_network![
NeuraDenseLayer::new([[0.11, 0.21], [0.12, 0.08]], [0.0; 2], Linear, NeuraL0),
NeuraDenseLayer::new([[0.14, 0.15]], [0.0], Linear, NeuraL0)
];
let input = [2.0, 3.0];
let target = [1.0];
let intermediary = network.clone().trim_tail().eval(&input);
assert_approx!(0.85, intermediary[0], EPSILON);
assert_approx!(0.48, intermediary[1], EPSILON);
assert_approx!(0.191, network.eval(&input)[0], EPSILON);
assert_approx!(0.327, Euclidean.eval(&target, &network.eval(&input)), 0.001);
let delta = network.eval(&input)[0] - target[0];
let (gradient_first, gradient_second) =
NeuraBackprop::new(Euclidean).get_gradient(&network, &input, &target);
let gradient_first = gradient_first.0;
let gradient_second = gradient_second.0[0];
assert_approx!(gradient_second[0], intermediary[0] * delta, EPSILON);
assert_approx!(gradient_second[1], intermediary[1] * delta, EPSILON);
assert_approx!(gradient_first[0][0], input[0] * delta * 0.14, EPSILON);
assert_approx!(gradient_first[0][1], input[1] * delta * 0.14, EPSILON);
assert_approx!(gradient_first[1][0], input[0] * delta * 0.15, EPSILON);
assert_approx!(gradient_first[1][1], input[1] * delta * 0.15, EPSILON);
}
}

@ -175,3 +175,54 @@ pub(crate) fn uniform_vector<const LENGTH: usize>() -> [f64; LENGTH] {
res
}
pub fn one_hot<const N: usize>(value: usize) -> [f64; N] {
let mut res = [0.0; N];
if value < N {
res[value] = 1.0;
}
res
}
pub fn argmax(array: &[f64]) -> usize {
let mut res = 0;
for n in 1..array.len() {
if array[n] > array[res] {
res = n;
}
}
res
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_reverse_dot_product() {
let left = [2.0, 3.0, 5.0];
let right = [7.0, 11.0, 13.0, 17.0];
let expected = [
[14.0, 22.0, 26.0, 34.0],
[21.0, 33.0, 39.0, 51.0],
[35.0, 55.0, 65.0, 85.0],
];
assert_eq!(expected, reverse_dot_product(&left, &right));
}
}
#[cfg(test)]
#[macro_export]
macro_rules! assert_approx {
( $left:expr, $right:expr, $epsilon:expr ) => {
let left = $left;
let right = $right;
if (left - right).abs() >= $epsilon {
panic!("Expected {} to be approximately equal to {}", left, right);
}
};
}

Loading…
Cancel
Save