Return training and validation losses in train(), plot them out

main
Shad Amethyst 2 years ago
parent a5237a8ef1
commit b3b97f76bd

@ -13,6 +13,7 @@ num = "^0.4"
# num-traits = "0.2.15"
rand = "^0.8"
rand_distr = "0.4.3"
textplots = "0.8.0"
[dev-dependencies]
image = "0.24.6"

@ -5,7 +5,7 @@ use nalgebra::{dvector, DVector};
use neuramethyst::derivable::activation::{LeakyRelu, Linear, Relu, Tanh};
use neuramethyst::derivable::regularize::NeuraL1;
use neuramethyst::gradient_solver::NeuraForwardForward;
use neuramethyst::prelude::*;
use neuramethyst::{plot_losses, prelude::*};
use rand::Rng;
@ -88,11 +88,15 @@ fn main() {
trainer.batch_size = 10;
trainer.log_iterations = 20;
trainer.train(
&NeuraForwardForward::new(Tanh, threshold as f64),
&mut network,
inputs.clone(),
&test_inputs,
plot_losses(
trainer.train(
&NeuraForwardForward::new(Tanh, threshold as f64),
&mut network,
inputs.clone(),
&test_inputs,
),
128,
48,
);
// println!("{}", String::from("\n").repeat(64));

@ -7,7 +7,7 @@ use nalgebra::{dvector, DVector};
use neuramethyst::derivable::activation::{LeakyRelu, Linear, Relu, Tanh};
use neuramethyst::derivable::loss::CrossEntropy;
use neuramethyst::derivable::regularize::NeuraL1;
use neuramethyst::prelude::*;
use neuramethyst::{plot_losses, prelude::*};
use rand::Rng;
@ -77,11 +77,15 @@ fn main() {
trainer.batch_size = 10;
trainer.log_iterations = 20;
trainer.train(
&NeuraBackprop::new(CrossEntropy),
&mut network,
inputs.clone(),
&test_inputs,
plot_losses(
trainer.train(
&NeuraBackprop::new(CrossEntropy),
&mut network,
inputs.clone(),
&test_inputs,
),
128,
48,
);
// println!("{}", String::from("\n").repeat(64));

@ -11,7 +11,7 @@ pub mod train;
mod utils;
// TODO: move to a different file
pub use utils::{argmax, cycle_shuffling, one_hot};
pub use utils::{argmax, cycle_shuffling, one_hot, plot_losses};
pub mod prelude {
// Macros

@ -82,7 +82,8 @@ impl NeuraBatchedTrainer {
network: &mut Network,
inputs: Inputs,
test_inputs: &[(Input, Target)],
) {
) -> Vec<(f64, f64)> {
let mut losses = Vec::new();
let mut iter = inputs.into_iter();
let factor = -self.learning_rate / (self.batch_size as f64);
let momentum_factor = self.learning_momentum / self.learning_rate;
@ -90,6 +91,7 @@ impl NeuraBatchedTrainer {
// Contains `momentum_factor * factor * gradient_sum_previous_iter`
let mut previous_gradient_sum = network.default_gradient();
let mut train_loss = 0.0;
'd: for iteration in 0..self.iterations {
let mut gradient_sum = network.default_gradient();
network.prepare(true);
@ -98,6 +100,8 @@ impl NeuraBatchedTrainer {
if let Some((input, target)) = iter.next() {
let gradient = gradient_solver.get_gradient(&network, &input, &target);
gradient_sum.add_assign(&gradient);
train_loss += gradient_solver.score(&network, &input, &target);
} else {
break 'd;
}
@ -120,16 +124,27 @@ impl NeuraBatchedTrainer {
if self.log_iterations > 0 && (iteration + 1) % self.log_iterations == 0 {
network.prepare(false);
let mut loss_sum = 0.0;
let mut val_loss = 0.0;
for (input, target) in test_inputs {
loss_sum += gradient_solver.score(&network, input, target);
val_loss += gradient_solver.score(&network, input, target);
}
loss_sum /= test_inputs.len() as f64;
println!("Iteration {}, Loss: {:.3}", iteration + 1, loss_sum);
val_loss /= test_inputs.len() as f64;
train_loss /= (self.batch_size * self.log_iterations) as f64;
println!(
"Iteration {}, Training loss: {:.3}, Validation loss: {:.3}",
iteration + 1,
train_loss,
val_loss
);
losses.push((train_loss, val_loss));
train_loss = 0.0;
}
}
network.prepare(false);
losses
}
}

@ -128,3 +128,24 @@ macro_rules! assert_approx {
}
};
}
// TODO: put this behind a feature
pub fn plot_losses(losses: Vec<(f64, f64)>, width: u32, height: u32) {
use textplots::{Chart, ColorPlot, Plot, Shape};
let train_losses: Vec<_> = losses
.iter()
.enumerate()
.map(|(x, y)| (x as f32, y.0 as f32))
.collect();
let val_losses: Vec<_> = losses
.iter()
.enumerate()
.map(|(x, y)| (x as f32, y.1 as f32))
.collect();
Chart::new(width, height, 0.0, losses.len() as f32)
.lineplot(&Shape::Lines(&train_losses))
.linecolorplot(&Shape::Lines(&val_losses), (255, 0, 255).into())
.nice();
}

Loading…
Cancel
Save