|
|
|
@ -82,7 +82,8 @@ impl NeuraBatchedTrainer {
|
|
|
|
|
network: &mut Network,
|
|
|
|
|
inputs: Inputs,
|
|
|
|
|
test_inputs: &[(Input, Target)],
|
|
|
|
|
) {
|
|
|
|
|
) -> Vec<(f64, f64)> {
|
|
|
|
|
let mut losses = Vec::new();
|
|
|
|
|
let mut iter = inputs.into_iter();
|
|
|
|
|
let factor = -self.learning_rate / (self.batch_size as f64);
|
|
|
|
|
let momentum_factor = self.learning_momentum / self.learning_rate;
|
|
|
|
@ -90,6 +91,7 @@ impl NeuraBatchedTrainer {
|
|
|
|
|
|
|
|
|
|
// Contains `momentum_factor * factor * gradient_sum_previous_iter`
|
|
|
|
|
let mut previous_gradient_sum = network.default_gradient();
|
|
|
|
|
let mut train_loss = 0.0;
|
|
|
|
|
'd: for iteration in 0..self.iterations {
|
|
|
|
|
let mut gradient_sum = network.default_gradient();
|
|
|
|
|
network.prepare(true);
|
|
|
|
@ -98,6 +100,8 @@ impl NeuraBatchedTrainer {
|
|
|
|
|
if let Some((input, target)) = iter.next() {
|
|
|
|
|
let gradient = gradient_solver.get_gradient(&network, &input, &target);
|
|
|
|
|
gradient_sum.add_assign(&gradient);
|
|
|
|
|
|
|
|
|
|
train_loss += gradient_solver.score(&network, &input, &target);
|
|
|
|
|
} else {
|
|
|
|
|
break 'd;
|
|
|
|
|
}
|
|
|
|
@ -120,16 +124,27 @@ impl NeuraBatchedTrainer {
|
|
|
|
|
|
|
|
|
|
if self.log_iterations > 0 && (iteration + 1) % self.log_iterations == 0 {
|
|
|
|
|
network.prepare(false);
|
|
|
|
|
let mut loss_sum = 0.0;
|
|
|
|
|
let mut val_loss = 0.0;
|
|
|
|
|
for (input, target) in test_inputs {
|
|
|
|
|
loss_sum += gradient_solver.score(&network, input, target);
|
|
|
|
|
val_loss += gradient_solver.score(&network, input, target);
|
|
|
|
|
}
|
|
|
|
|
loss_sum /= test_inputs.len() as f64;
|
|
|
|
|
println!("Iteration {}, Loss: {:.3}", iteration + 1, loss_sum);
|
|
|
|
|
val_loss /= test_inputs.len() as f64;
|
|
|
|
|
train_loss /= (self.batch_size * self.log_iterations) as f64;
|
|
|
|
|
println!(
|
|
|
|
|
"Iteration {}, Training loss: {:.3}, Validation loss: {:.3}",
|
|
|
|
|
iteration + 1,
|
|
|
|
|
train_loss,
|
|
|
|
|
val_loss
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
losses.push((train_loss, val_loss));
|
|
|
|
|
train_loss = 0.0;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
network.prepare(false);
|
|
|
|
|
|
|
|
|
|
losses
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|