You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

161 lines
4.4 KiB

use nalgebra::DVector;
use rand::Rng;
use rust_mnist::Mnist;
use neuramethyst::{
argmax, cycle_shuffling,
derivable::{
activation::{Linear, Logistic, Relu, Swish, Tanh},
loss::{CrossEntropy, Euclidean},
},
one_hot, plot_losses,
prelude::*,
};
// Number of training samples kept from the dataset (applied with `.take(TRAIN_SIZE)`);
// also passed to the trainers' `with_epochs` constructor.
const TRAIN_SIZE: usize = 50000;
// Number of test samples kept from the dataset; also the denominator of the reported accuracy.
const TEST_SIZE: usize = 1000;
// Image dimensions in pixels; every image is handled as a flat vector of `WIDTH * HEIGHT` values.
const WIDTH: usize = 28;
const HEIGHT: usize = 28;
// Output size of the narrowest dense layer of the encoder-decoder network.
const LATENT_SIZE: usize = 25;
pub fn main() {
let Mnist {
train_data: train_images,
train_labels,
test_data: test_images,
test_labels,
..
} = Mnist::new("data/");
let train_images = train_images
.into_iter()
.map(|raw| {
DVector::from_iterator(WIDTH * HEIGHT, raw.into_iter().map(|x| x as f32 / 255.0))
})
.take(TRAIN_SIZE);
let train_labels = train_labels
.into_iter()
.map(|x| one_hot(x as usize, 10))
.take(TRAIN_SIZE);
let test_images = test_images
.into_iter()
.map(|raw| {
DVector::from_iterator(WIDTH * HEIGHT, raw.into_iter().map(|x| x as f32 / 255.0))
})
.take(TEST_SIZE);
let test_labels = test_labels
.into_iter()
.map(|x| one_hot(x as usize, 10))
.take(TEST_SIZE);
let test_data = test_images
.clone()
.zip(test_images.clone())
.collect::<Vec<_>>();
// First, train an encoder-decoder network (unsupervised)
let mut network = neura_sequential![
neura_layer!("dense", 100).activation(Swish(Logistic)),
neura_layer!("dense", 50).activation(Swish(Logistic)),
neura_layer!("dense", LATENT_SIZE).activation(Tanh),
neura_layer!("dense", 100).activation(Swish(Logistic)),
neura_layer!("dense", WIDTH * HEIGHT).activation(Relu),
]
.construct(NeuraShape::Vector(WIDTH * HEIGHT))
.unwrap();
let mut trainer = NeuraBatchedTrainer::with_epochs(0.03, 200, 512, TRAIN_SIZE);
trainer.learning_momentum = 0.002;
// trainer.log_iterations = 1;
let mut rng = rand::thread_rng();
let losses = trainer.train(
&NeuraBackprop::new(Euclidean),
&mut network,
cycle_shuffling(train_images.clone(), rand::thread_rng()).map(move |input| {
let dx = rng.gen_range(-4..4);
let dy = rng.gen_range(-4..4);
let shifted = shift(&input, dx, dy);
(shifted.clone(), shifted)
}),
&test_data,
);
plot_losses(losses, 128, 48);
// Then, train a small network to decode the encoded data into the categories
let trimmed_network = network.clone().trim_tail().trim_tail();
let mut network = neura_sequential![
..trimmed_network.lock(),
neura_layer!("dense", LATENT_SIZE)
.activation(Tanh)
.construct(NeuraShape::Vector(LATENT_SIZE))
.unwrap(),
neura_layer!("dense", 10)
.activation(Linear)
.construct(NeuraShape::Vector(LATENT_SIZE))
.unwrap(),
neura_layer!("softmax")
];
let test_data = test_images
.clone()
.zip(test_labels.clone())
.collect::<Vec<_>>();
let trainer = NeuraBatchedTrainer::with_epochs(0.03, 10, 128, TRAIN_SIZE);
plot_losses(
trainer.train(
&NeuraBackprop::new(CrossEntropy),
&mut network,
cycle_shuffling(train_images.clone().zip(train_labels), rand::thread_rng()),
&test_data,
),
128,
48,
);
let mut correct = 0;
for (test_image, test_label) in test_images.zip(test_labels) {
let guess = network.eval(&test_image);
let guess = argmax(guess.as_slice());
let actual = argmax(test_label.as_slice());
if guess == actual {
correct += 1;
}
}
println!("");
println!(
"{} correct out of {}: {:.2}%",
correct,
TEST_SIZE,
(correct as f32 / TEST_SIZE as f32) * 100.0
);
}
/// Returns a copy of `image` translated by `dx` columns and `dy` rows.
///
/// `image` is read as a `WIDTH` x `HEIGHT` grid in which pixel `(x, y)` lives at index
/// `y * WIDTH + x`. Pixels whose destination falls outside the grid are dropped, and positions
/// that receive no pixel stay at `0.0`. The result has the same length as `image`.
///
/// # Panics
///
/// Panics on out-of-bounds indexing if `image` holds fewer values than a copied pixel's index
/// requires.
fn shift(image: &DVector<f32>, dx: i32, dy: i32) -> DVector<f32> {
    let (cols, rows) = (WIDTH as i32, HEIGHT as i32);
    let mut shifted = DVector::from_element(image.len(), 0.0);

    for src_y in 0..rows {
        for src_x in 0..cols {
            let (dst_x, dst_y) = (src_x + dx, src_y + dy);

            // Only pixels that land inside the grid are copied.
            if (0..rows).contains(&dst_y) && (0..cols).contains(&dst_x) {
                let src = (src_y * cols + src_x) as usize;
                let dst = (dst_y * cols + dst_x) as usize;
                shifted[dst] = image[src];
            }
        }
    }

    shifted
}