1D convolution layer

Nailed it on the first try :3c

(or not, and I'll regret writing this in a few years)
main
Shad Amethyst 2 years ago
parent 6c1d6874d7
commit d7eb6de34e

@@ -1,5 +1,5 @@
#![feature(generic_arg_infer)]
// #![feature(generic_const_exprs)]
#![feature(generic_const_exprs)]
use neuramethyst::algebra::NeuraVector;
use rust_mnist::Mnist;
@@ -9,7 +9,7 @@ use neuramethyst::derivable::loss::CrossEntropy;
use neuramethyst::{cycle_shuffling, one_hot, prelude::*};
fn main() {
const TRAIN_SIZE: usize = 100;
const TRAIN_SIZE: usize = 1000;
let Mnist {
train_data: train_images,
@@ -53,19 +53,18 @@ fn main() {
let test_inputs: Vec<_> = test_images.zip(test_labels.into_iter()).collect();
let mut network = neura_sequential![
neura_layer!("dense", { 28 * 28 }, 200; Relu),
neura_layer!("dropout", 0.5),
neura_layer!("dense", 100; Relu),
neura_layer!("dropout", 0.5),
neura_layer!("unstable_reshape", 28, 28),
neura_layer!("conv1d_pad", 3; neura_layer!("dense", {28 * 3}, 10; Relu)),
neura_layer!("unstable_flatten"),
// neura_layer!("dense", 100; Relu),
// neura_layer!("dropout", 0.5),
neura_layer!("dense", 30; Relu),
neura_layer!("dropout", 0.5),
neura_layer!("dense", 10; Linear),
neura_layer!("softmax")
];
let mut trainer = NeuraBatchedTrainer::new(0.03, TRAIN_SIZE * 10);
trainer.log_iterations = (TRAIN_SIZE / 128).max(1);
trainer.batch_size = 128;
let mut trainer = NeuraBatchedTrainer::with_epochs(0.03, 100, 128, TRAIN_SIZE);
trainer.learning_momentum = 0.001;
trainer.train(

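For orientation, here is the shape bookkeeping behind the new pipeline, as a standalone sketch (the constant names are illustrative, not part of neuramethyst): the 28×28 image is reshaped into 28 rows of 28 features, each `conv1d_pad` window of size 3 feeds the inner dense layer 3 × 28 = 84 values and gets 10 back per position, and `unstable_flatten` then yields 28 × 10 = 280 values for the remaining dense layers.

```rust
// Standalone shape check for the network above; the constant names are
// illustrative and not part of the neuramethyst API.
const LENGTH: usize = 28; // rows after "unstable_reshape"
const IN_FEATS: usize = 28; // features per row
const WINDOW: usize = 3; // "conv1d_pad" window size
const OUT_FEATS: usize = 10; // output size of the inner dense layer

fn main() {
    assert_eq!(IN_FEATS * WINDOW, 28 * 3); // input size of the inner dense layer
    assert_eq!(LENGTH * OUT_FEATS, 280); // vector size after "unstable_flatten"
}
```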
@@ -31,6 +31,24 @@ impl<const WIDTH: usize, const HEIGHT: usize, F> NeuraMatrix<WIDTH, HEIGHT, F> {
Some(&self.data[y][x])
}
#[inline]
pub fn set_row(&mut self, y: usize, row: impl Borrow<[F; WIDTH]>)
where
F: Clone,
{
if y >= HEIGHT {
panic!(
"Cannot set row {} of NeuraMatrix<{}, {}, _>: row index out of bound",
y, WIDTH, HEIGHT
);
}
let row = row.borrow();
for j in 0..WIDTH {
self.data[y][j] = row[j].clone();
}
}
}
impl<const WIDTH: usize, const HEIGHT: usize, F: Float> NeuraMatrix<WIDTH, HEIGHT, F> {

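A quick illustration of the new `set_row` helper; a sketch that assumes `NeuraMatrix` is exported from `neuramethyst::algebra` (as `NeuraVector` is in the example above) and that indexing a matrix yields a row, as the convolution code below relies on:

```rust
use neuramethyst::algebra::NeuraMatrix;

fn main() {
    // NeuraMatrix<WIDTH, HEIGHT, F>: 3 columns, 2 rows of f64.
    let mut matrix: NeuraMatrix<3, 2, f64> = NeuraMatrix::default();
    matrix.set_row(0, [1.0, 2.0, 3.0]); // anything implementing Borrow<[f64; WIDTH]> works
    matrix.set_row(1, &[4.0, 5.0, 6.0]);
    assert_eq!(matrix[1][0], 4.0);
    // matrix.set_row(2, [0.0; 3]); // would panic: row index out of bounds
}
```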
@@ -0,0 +1,138 @@
use crate::algebra::{NeuraMatrix, NeuraVector};
use super::*;
/// A 1-dimensional convolutional layer, which applies `inner_layer` to a sliding window over its input, padding out-of-range positions with `pad_with`.
#[derive(Clone, Debug)]
pub struct NeuraConv1DPad<
const LENGTH: usize,
const IN_FEATS: usize,
const WINDOW: usize,
Layer: NeuraLayer<Input = NeuraVector<{ IN_FEATS * WINDOW }, f64>>,
> {
inner_layer: Layer,
pad_with: NeuraVector<IN_FEATS, f64>,
}
impl<
const LENGTH: usize,
const IN_FEATS: usize,
const WINDOW: usize,
Layer: NeuraLayer<Input = NeuraVector<{ IN_FEATS * WINDOW }, f64>>,
> NeuraConv1DPad<LENGTH, IN_FEATS, WINDOW, Layer>
where
[u8; IN_FEATS * WINDOW]: Sized,
{
pub fn new(inner_layer: Layer, pad_with: NeuraVector<IN_FEATS, f64>) -> Self {
Self {
inner_layer,
pad_with,
}
}
fn iterate_windows<'a>(
&'a self,
input: &'a NeuraMatrix<IN_FEATS, LENGTH, f64>,
) -> impl Iterator<Item = (usize, Layer::Input)> + 'a {
(0..LENGTH).map(move |window_center| {
let mut virtual_input: NeuraVector<{ IN_FEATS * WINDOW }, f64> = NeuraVector::default();
for i in 0..WINDOW {
let input_index = i as isize + window_center as isize - (WINDOW - 1) as isize / 2;
if input_index < 0 || input_index >= LENGTH as isize {
for j in 0..IN_FEATS {
virtual_input[i * IN_FEATS + j] = self.pad_with[j];
}
} else {
for j in 0..IN_FEATS {
virtual_input[i * IN_FEATS + j] = input[input_index as usize][j];
}
}
}
(window_center, virtual_input)
})
}
}
impl<
const LENGTH: usize,
const IN_FEATS: usize,
const OUT_FEATS: usize,
const WINDOW: usize,
Layer: NeuraLayer<
Input = NeuraVector<{ IN_FEATS * WINDOW }, f64>,
Output = NeuraVector<OUT_FEATS, f64>,
>,
> NeuraLayer for NeuraConv1DPad<LENGTH, IN_FEATS, WINDOW, Layer>
{
type Input = NeuraMatrix<IN_FEATS, LENGTH, f64>;
type Output = NeuraMatrix<OUT_FEATS, LENGTH, f64>;
fn eval(&self, input: &Self::Input) -> Self::Output {
let mut res = NeuraMatrix::default();
for (window_center, virtual_input) in self.iterate_windows(input) {
res.set_row(window_center, self.inner_layer.eval(&virtual_input));
}
res
}
}
impl<
const LENGTH: usize,
const IN_FEATS: usize,
const OUT_FEATS: usize,
const WINDOW: usize,
Layer: NeuraLayer<
Input = NeuraVector<{ IN_FEATS * WINDOW }, f64>,
Output = NeuraVector<OUT_FEATS, f64>,
>,
> NeuraTrainableLayer for NeuraConv1DPad<LENGTH, IN_FEATS, WINDOW, Layer>
where
Layer: NeuraTrainableLayer,
{
type Delta = <Layer as NeuraTrainableLayer>::Delta;
fn backpropagate(
&self,
input: &Self::Input,
epsilon: Self::Output,
) -> (Self::Input, Self::Delta) {
let mut next_epsilon = Self::Input::default();
let mut weights_gradient_sum = Self::Delta::zero();
// TODO: consume epsilon efficiently
for (window_center, virtual_input) in self.iterate_windows(input) {
let epsilon = NeuraVector::from(&epsilon[window_center]);
let (layer_next_epsilon, weights_gradient) =
self.inner_layer.backpropagate(&virtual_input, epsilon);
weights_gradient_sum.add_assign(&weights_gradient);
for i in 0..WINDOW {
// Re-compute the positions in `input` matching the positions in `layer_next_epsilon` and `virtual_input`
let input_index = window_center as isize + i as isize - (WINDOW - 1) as isize / 2;
if input_index < 0 || input_index >= LENGTH as isize {
continue;
}
let input_index = input_index as usize;
for j in 0..IN_FEATS {
next_epsilon[input_index][j] += layer_next_epsilon[i * IN_FEATS + j];
}
}
}
(next_epsilon, weights_gradient_sum)
}
fn regularize(&self) -> Self::Delta {
self.inner_layer.regularize()
}
fn apply_gradient(&mut self, gradient: &Self::Delta) {
self.inner_layer.apply_gradient(gradient);
}
}

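A minimal standalone sketch of the centered window indexing that `iterate_windows` and `backpropagate` share (a hypothetical helper written for clarity, not taken from the crate): slots that fall outside `0..LENGTH` are exactly the ones filled from `pad_with`.

```rust
// Hypothetical helper mirroring the index math above: for each of the WINDOW
// slots around `center`, return Some(input index), or None when the slot
// falls outside the input and must be filled from `pad_with`.
fn window_indices(length: usize, window: usize, center: usize) -> Vec<Option<usize>> {
    (0..window)
        .map(|i| {
            let index = i as isize + center as isize - (window - 1) as isize / 2;
            if index < 0 || index >= length as isize {
                None
            } else {
                Some(index as usize)
            }
        })
        .collect()
}

fn main() {
    // LENGTH = 5, WINDOW = 3: the first and last windows each pad one slot.
    assert_eq!(window_indices(5, 3, 0), vec![None, Some(0), Some(1)]);
    assert_eq!(window_indices(5, 3, 2), vec![Some(1), Some(2), Some(3)]);
    assert_eq!(window_indices(5, 3, 4), vec![Some(3), Some(4), None]);
}
```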
@@ -1,6 +1,9 @@
mod dense;
pub use dense::NeuraDenseLayer;
mod convolution;
pub use convolution::NeuraConv1DPad;
mod dropout;
pub use dropout::NeuraDropoutLayer;
@@ -10,15 +13,12 @@ pub use softmax::NeuraSoftmaxLayer;
mod one_hot;
pub use one_hot::NeuraOneHotLayer;
// mod reshape;
// pub use reshape::{
// NeuraFlattenLayer,
// NeuraReshapeLayer
// };
mod lock;
pub use lock::NeuraLockLayer;
mod reshape;
pub use reshape::{NeuraFlattenLayer, NeuraReshapeLayer};
use crate::algebra::NeuraVectorSpace;
pub trait NeuraLayer {
@@ -105,19 +105,27 @@ macro_rules! neura_layer {
$crate::layer::NeuraLockLayer($layer)
};
// ( "flatten" ) => {
// $crate::layer::NeuraFlattenLayer::new() as $crate::layer::NeuraFlattenLayer<_, _, f64>
// };
( "conv1d_pad", $length:expr, $feats:expr, $window:expr; $layer:expr ) => {
$crate::layer::NeuraConv1DPad::new($layer, Default::default()) as $crate::layer::NeuraConv1DPad<$length, $feats, $window, _>
};
// ( "flatten", $width:expr, $height:expr ) => {
// $crate::layer::NeuraFlattenLayer::new() as $crate::layer::NeuraFlattenLayer<$width, $height, f64>
// };
( "conv1d_pad", $window:expr; $layer:expr ) => {
$crate::layer::NeuraConv1DPad::new($layer, Default::default()) as $crate::layer::NeuraConv1DPad<_, _, $window, _>
};
// ( "reshape", $height:expr ) => {
// $crate::layer::NeuraReshapeLayer::new() as $crate::layer::NeuraReshapeLayer<_, $height, f64>
// };
( "unstable_flatten" ) => {
$crate::layer::NeuraFlattenLayer::new() as $crate::layer::NeuraFlattenLayer<_, _, f64>
};
// ( "reshape", $width:expr, $height:expr ) => {
// $crate::layer::NeuraReshapeLayer::new() as $crate::layer::NeuraReshapeLayer<$width, $height, f64>
// };
( "unstable_flatten", $width:expr, $height:expr ) => {
$crate::layer::NeuraFlattenLayer::new() as $crate::layer::NeuraFlattenLayer<$width, $height, f64>
};
( "unstable_reshape", $height:expr ) => {
$crate::layer::NeuraReshapeLayer::new() as $crate::layer::NeuraReshapeLayer<_, $height, f64>
};
( "unstable_reshape", $width:expr, $height:expr ) => {
$crate::layer::NeuraReshapeLayer::new() as $crate::layer::NeuraReshapeLayer<$width, $height, f64>
};
}

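For reference, the two new `conv1d_pad` arms compare like this; a sketch assuming the crate's prelude and the `Relu` activation are in scope, as in the example above:

```rust
// Explicit arm: LENGTH = 28 rows, IN_FEATS = 28, WINDOW = 3.
let conv = neura_layer!("conv1d_pad", 28, 28, 3;
    neura_layer!("dense", {28 * 3}, 10; Relu));

// Inferred arm, as in the MNIST example above: only WINDOW is given;
// LENGTH and IN_FEATS are deduced from the neighbouring layers.
let network = neura_sequential![
    neura_layer!("unstable_reshape", 28, 28),
    neura_layer!("conv1d_pad", 3; neura_layer!("dense", {28 * 3}, 10; Relu)),
    neura_layer!("unstable_flatten")
];
```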
@@ -1,32 +1,35 @@
//! This module is currently disabled, as it relies on `generic_const_exprs`, which is too unstable to use as of now
//! This module requires the `generic_const_exprs` feature to be enabled,
//! which is still quite unstable as of this writing.
use std::borrow::Borrow;
use crate::algebra::{NeuraMatrix, NeuraVector};
use super::{NeuraLayer, NeuraTrainableLayer};
/// Converts a `[[T; WIDTH]; HEIGHT]` into a `[T; WIDTH * HEIGHT]`.
/// Converts a `[[T; WIDTH]; HEIGHT]` into a `NeuraVector<{WIDTH * HEIGHT}, T>`.
/// Requires the `#![feature(generic_const_exprs)]` feature to be enabled.
pub struct NeuraFlattenLayer<const WIDTH: usize, const HEIGHT: usize, T> {
phantom: std::marker::PhantomData<T>,
}
/// Converts a `[T; WIDTH * HEIGHT]` into a `[[T; WIDTH]; HEIGHT]`.
/// Converts a `NeuraVector<{WIDTH * HEIGHT}, T>` into a `[[T; WIDTH]; HEIGHT]`.
/// Requires the `#![feature(generic_const_exprs)]` feature to be enabled.
pub struct NeuraReshapeLayer<const WIDTH: usize, const HEIGHT: usize, T> {
phantom: std::marker::PhantomData<T>,
}
#[inline(always)]
fn flatten<const WIDTH: usize, const HEIGHT: usize, T: Copy + Default>(
input: &[[T; WIDTH]; HEIGHT],
) -> [T; WIDTH * HEIGHT]
where
[T; WIDTH * HEIGHT]: Sized,
{
let mut res = [T::default(); WIDTH * HEIGHT];
fn flatten<const WIDTH: usize, const HEIGHT: usize, T: Clone + Default>(
input: impl Borrow<[[T; WIDTH]; HEIGHT]>,
) -> NeuraVector<{ WIDTH * HEIGHT }, T> {
let mut res = NeuraVector::default();
let input = input.borrow();
// Hopefully the optimizer realizes this can all be optimized away
for i in 0..HEIGHT {
for j in 0..WIDTH {
res[i * WIDTH + j] = input[i][j];
res[i * WIDTH + j] = input[i][j].clone();
}
}
@@ -34,18 +37,16 @@ where
}
#[inline(always)]
fn reshape<const WIDTH: usize, const HEIGHT: usize, T: Copy + Default>(
input: &[T; WIDTH * HEIGHT],
) -> [[T; WIDTH]; HEIGHT]
where
[T; WIDTH * HEIGHT]: Sized,
{
let mut res = [[T::default(); WIDTH]; HEIGHT];
fn reshape<const WIDTH: usize, const HEIGHT: usize, T: Clone + Default>(
input: impl Borrow<[T; WIDTH * HEIGHT]>,
) -> NeuraMatrix<WIDTH, HEIGHT, T> {
let input = input.borrow();
let mut res = NeuraMatrix::default();
// Hopefully the optimizer realizes this can all be optimized away
for i in 0..HEIGHT {
for j in 0..WIDTH {
res[i][j] = input[i * WIDTH + j];
res[i][j] = input[i * WIDTH + j].clone();
}
}
@@ -71,26 +72,26 @@ impl<const WIDTH: usize, const HEIGHT: usize, T> NeuraReshapeLayer<WIDTH, HEIGHT
impl<const WIDTH: usize, const HEIGHT: usize, T: Copy + Default> NeuraLayer
for NeuraFlattenLayer<WIDTH, HEIGHT, T>
where
[T; WIDTH * HEIGHT]: Sized,
NeuraVector<{ WIDTH * HEIGHT }, T>: Sized,
{
type Input = [[T; WIDTH]; HEIGHT];
type Input = NeuraMatrix<WIDTH, HEIGHT, T>;
type Output = [T; WIDTH * HEIGHT];
type Output = NeuraVector<{ WIDTH * HEIGHT }, T>;
#[inline(always)]
fn eval(&self, input: &Self::Input) -> Self::Output {
flatten(input)
flatten(input.as_ref())
}
}
impl<const WIDTH: usize, const HEIGHT: usize, T: Copy + Default> NeuraLayer
for NeuraReshapeLayer<WIDTH, HEIGHT, T>
where
[T; WIDTH * HEIGHT]: Sized,
NeuraVector<{ WIDTH * HEIGHT }, T>: Sized,
{
type Input = [T; WIDTH * HEIGHT];
type Input = NeuraVector<{ WIDTH * HEIGHT }, T>;
type Output = [[T; WIDTH]; HEIGHT];
type Output = NeuraMatrix<WIDTH, HEIGHT, T>;
#[inline(always)]
fn eval(&self, input: &Self::Input) -> Self::Output {
@@ -101,7 +102,7 @@ where
impl<const WIDTH: usize, const HEIGHT: usize, T: Copy + Default> NeuraTrainableLayer
for NeuraFlattenLayer<WIDTH, HEIGHT, T>
where
[T; WIDTH * HEIGHT]: Sized,
NeuraVector<{ WIDTH * HEIGHT }, T>: Sized,
{
type Delta = ();
@@ -114,7 +115,7 @@ where
}
fn regularize(&self) -> Self::Delta {
todo!()
()
}
fn apply_gradient(&mut self, _gradient: &Self::Delta) {
@@ -125,7 +126,7 @@ where
impl<const WIDTH: usize, const HEIGHT: usize, T: Copy + Default> NeuraTrainableLayer
for NeuraReshapeLayer<WIDTH, HEIGHT, T>
where
[T; WIDTH * HEIGHT]: Sized,
NeuraVector<{ WIDTH * HEIGHT }, T>: Sized,
{
type Delta = ();
@@ -134,11 +135,11 @@ where
_input: &Self::Input,
epsilon: Self::Output,
) -> (Self::Input, Self::Delta) {
(flatten(&epsilon), ())
(flatten(epsilon), ())
}
fn regularize(&self) -> Self::Delta {
todo!()
()
}
fn apply_gradient(&mut self, _gradient: &Self::Delta) {

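`flatten` and `reshape` use the same row-major index math; a standalone round-trip sketch with plain arrays:

```rust
// Row-major round trip: flatten stores element (i, j) at i * WIDTH + j,
// and reshape reads it back from the same position.
fn main() {
    const WIDTH: usize = 3;
    const HEIGHT: usize = 2;
    let input = [[1, 2, 3], [4, 5, 6]];

    let mut flat = [0; WIDTH * HEIGHT];
    for i in 0..HEIGHT {
        for j in 0..WIDTH {
            flat[i * WIDTH + j] = input[i][j];
        }
    }
    assert_eq!(flat, [1, 2, 3, 4, 5, 6]);

    let mut back = [[0; WIDTH]; HEIGHT];
    for i in 0..HEIGHT {
        for j in 0..WIDTH {
            back[i][j] = flat[i * WIDTH + j];
        }
    }
    assert_eq!(back, input);
}
```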
@@ -1,6 +1,5 @@
#![feature(generic_arg_infer)]
#![feature(generic_associated_types)]
// #![feature(generic_const_exprs)]
#![feature(generic_const_exprs)]
pub mod algebra;
pub mod derivable;

@@ -121,6 +121,21 @@ impl NeuraBatchedTrainer {
}
}
pub fn with_epochs(
learning_rate: f64,
epochs: usize,
batch_size: usize,
training_size: usize,
) -> Self {
Self {
learning_rate,
iterations: (training_size * epochs / batch_size).max(1),
log_iterations: (training_size / batch_size).max(1),
batch_size,
..Default::default()
}
}
pub fn train<
Output,
Target: Clone,

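As a sanity check, this is what `with_epochs` computes for the call in the example above (`TRAIN_SIZE = 1000`, 100 epochs, batch size 128), reproduced as a standalone sketch of the same integer arithmetic:

```rust
// Same arithmetic as with_epochs: plain integer division, clamped to >= 1.
fn main() {
    let (epochs, batch_size, training_size) = (100usize, 128usize, 1000usize);
    let iterations = (training_size * epochs / batch_size).max(1);
    let log_iterations = (training_size / batch_size).max(1);
    assert_eq!(iterations, 781); // 1000 * 100 / 128, rounded down
    assert_eq!(log_iterations, 7); // logs roughly once per epoch
}
```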