diff --git a/Cargo.toml b/Cargo.toml
index fe4ca1f..4e78ae6 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -13,6 +13,7 @@ num = "^0.4"
 # num-traits = "0.2.15"
 rand = "^0.8"
 rand_distr = "0.4.3"
+textplots = "0.8.0"
 
 [dev-dependencies]
 image = "0.24.6"
diff --git a/examples/bivariate-forward.rs b/examples/bivariate-forward.rs
index 9fde389..2bf9cdb 100644
--- a/examples/bivariate-forward.rs
+++ b/examples/bivariate-forward.rs
@@ -5,7 +5,7 @@ use nalgebra::{dvector, DVector};
 use neuramethyst::derivable::activation::{LeakyRelu, Linear, Relu, Tanh};
 use neuramethyst::derivable::regularize::NeuraL1;
 use neuramethyst::gradient_solver::NeuraForwardForward;
-use neuramethyst::prelude::*;
+use neuramethyst::{plot_losses, prelude::*};
 
 use rand::Rng;
 
@@ -88,11 +88,15 @@ fn main() {
     trainer.batch_size = 10;
     trainer.log_iterations = 20;
 
-    trainer.train(
-        &NeuraForwardForward::new(Tanh, threshold as f64),
-        &mut network,
-        inputs.clone(),
-        &test_inputs,
+    plot_losses(
+        trainer.train(
+            &NeuraForwardForward::new(Tanh, threshold as f64),
+            &mut network,
+            inputs.clone(),
+            &test_inputs,
+        ),
+        128,
+        48,
     );
 
     // println!("{}", String::from("\n").repeat(64));
diff --git a/examples/bivariate.rs b/examples/bivariate.rs
index 82fd87b..4e8c90b 100644
--- a/examples/bivariate.rs
+++ b/examples/bivariate.rs
@@ -7,7 +7,7 @@ use nalgebra::{dvector, DVector};
 use neuramethyst::derivable::activation::{LeakyRelu, Linear, Relu, Tanh};
 use neuramethyst::derivable::loss::CrossEntropy;
 use neuramethyst::derivable::regularize::NeuraL1;
-use neuramethyst::prelude::*;
+use neuramethyst::{plot_losses, prelude::*};
 
 use rand::Rng;
 
@@ -77,11 +77,15 @@ fn main() {
     trainer.batch_size = 10;
     trainer.log_iterations = 20;
 
-    trainer.train(
-        &NeuraBackprop::new(CrossEntropy),
-        &mut network,
-        inputs.clone(),
-        &test_inputs,
+    plot_losses(
+        trainer.train(
+            &NeuraBackprop::new(CrossEntropy),
+            &mut network,
+            inputs.clone(),
+            &test_inputs,
+        ),
+        128,
+        48,
     );
 
     // println!("{}", String::from("\n").repeat(64));
diff --git a/src/lib.rs b/src/lib.rs
index e5abe7d..0bdf1cc 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -11,7 +11,7 @@ pub mod train;
 mod utils;
 
 // TODO: move to a different file
-pub use utils::{argmax, cycle_shuffling, one_hot};
+pub use utils::{argmax, cycle_shuffling, one_hot, plot_losses};
 
 pub mod prelude {
     // Macros
diff --git a/src/train.rs b/src/train.rs
index b782b5c..576acca 100644
--- a/src/train.rs
+++ b/src/train.rs
@@ -82,7 +82,8 @@ impl NeuraBatchedTrainer {
         network: &mut Network,
         inputs: Inputs,
         test_inputs: &[(Input, Target)],
-    ) {
+    ) -> Vec<(f64, f64)> {
+        let mut losses = Vec::new();
         let mut iter = inputs.into_iter();
         let factor = -self.learning_rate / (self.batch_size as f64);
         let momentum_factor = self.learning_momentum / self.learning_rate;
@@ -90,6 +91,7 @@
         // Contains `momentum_factor * factor * gradient_sum_previous_iter`
         let mut previous_gradient_sum = network.default_gradient();
+        let mut train_loss = 0.0;
 
         'd: for iteration in 0..self.iterations {
             let mut gradient_sum = network.default_gradient();
             network.prepare(true);
@@ -98,6 +100,8 @@
                 if let Some((input, target)) = iter.next() {
                     let gradient = gradient_solver.get_gradient(&network, &input, &target);
                     gradient_sum.add_assign(&gradient);
+
+                    train_loss += gradient_solver.score(&network, &input, &target);
                 } else {
                     break 'd;
                 }
@@ -120,16 +124,27 @@
 
             if self.log_iterations > 0 && (iteration + 1) % self.log_iterations == 0 {
                 network.prepare(false);
-                let mut loss_sum = 0.0;
+                let mut val_loss = 0.0;
                 for (input, target) in test_inputs {
-                    loss_sum += gradient_solver.score(&network, input, target);
+                    val_loss += gradient_solver.score(&network, input, target);
                 }
-                loss_sum /= test_inputs.len() as f64;
-                println!("Iteration {}, Loss: {:.3}", iteration + 1, loss_sum);
+                val_loss /= test_inputs.len() as f64;
+                train_loss /= (self.batch_size * self.log_iterations) as f64;
+                println!(
+                    "Iteration {}, Training loss: {:.3}, Validation loss: {:.3}",
+                    iteration + 1,
+                    train_loss,
+                    val_loss
+                );
+
+                losses.push((train_loss, val_loss));
+                train_loss = 0.0;
             }
         }
 
         network.prepare(false);
+
+        losses
     }
 }
 
diff --git a/src/utils.rs b/src/utils.rs
index 6a1d976..1b378fd 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -128,3 +128,24 @@ macro_rules! assert_approx {
         }
     };
 }
+
+// TODO: put this behind a feature
+pub fn plot_losses(losses: Vec<(f64, f64)>, width: u32, height: u32) {
+    use textplots::{Chart, ColorPlot, Plot, Shape};
+
+    let train_losses: Vec<_> = losses
+        .iter()
+        .enumerate()
+        .map(|(x, y)| (x as f32, y.0 as f32))
+        .collect();
+    let val_losses: Vec<_> = losses
+        .iter()
+        .enumerate()
+        .map(|(x, y)| (x as f32, y.1 as f32))
+        .collect();
+
+    Chart::new(width, height, 0.0, losses.len() as f32)
+        .lineplot(&Shape::Lines(&train_losses))
+        .linecolorplot(&Shape::Lines(&val_losses), (255, 0, 255).into())
+        .nice();
+}
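
For quick reference, a minimal standalone sketch of driving the newly exported `plot_losses` helper outside of the examples above; the loss values here are made up purely for illustration:

use neuramethyst::plot_losses;

fn main() {
    // Hypothetical (train_loss, val_loss) pairs, purely illustrative
    let losses: Vec<(f64, f64)> = (0..20)
        .map(|i| {
            let x = i as f64 + 1.0;
            (1.0 / x + 0.05, 1.2 / x + 0.08)
        })
        .collect();

    // Draws both curves in the terminal; per the helper above, the
    // validation curve is overlaid in magenta on the training curve.
    plot_losses(losses, 128, 48);
}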