You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
64 lines
1.9 KiB
64 lines
1.9 KiB
use std::fs::File;
|
|
|
|
use approx::assert_relative_eq;
|
|
use nalgebra::{dvector, DMatrix, DVector};
|
|
use neuramethyst::{
|
|
derivable::{
|
|
activation::{Relu, Tanh},
|
|
loss::Euclidean,
|
|
regularize::NeuraL0,
|
|
},
|
|
layer::dense::NeuraDenseLayer,
|
|
prelude::*,
|
|
};
|
|
|
|
fn load_test_data() -> Vec<(DMatrix<f64>, DVector<f64>, DMatrix<f64>, DVector<f64>)> {
|
|
let file = File::open("tests/xor.json").unwrap();
|
|
let data: Vec<(DMatrix<f64>, DVector<f64>, DMatrix<f64>, DVector<f64>)> =
|
|
serde_json::from_reader(&file).unwrap();
|
|
|
|
data
|
|
}
|
|
|
|
/// Regression test for single-sample training on the XOR problem.
///
/// Builds a two-layer network (ReLU, then tanh) from the first snapshot returned
/// by `load_test_data`, trains it on one XOR sample per step, and after each
/// step compares both layers' weights and biases with the next snapshot.
#[test]
fn test_xor_training() {
    let snapshots = load_test_data();

    // Snapshot 0 holds the starting parameters of both layers.
    let (first_weights, first_bias, second_weights, second_bias) = snapshots[0].clone();
    let mut network = neura_sequential![
        NeuraDenseLayer::new(first_weights, first_bias, Relu, NeuraL0),
        NeuraDenseLayer::new(second_weights, second_bias, Tanh, NeuraL0),
    ];

    // The four XOR samples as (input, target) pairs.
    let samples = [
        (dvector![0.0, 0.0], dvector![0.0]),
        (dvector![0.0, 1.0], dvector![1.0]),
        (dvector![1.0, 0.0], dvector![1.0]),
        (dvector![1.0, 1.0], dvector![0.0]),
    ];

    let mut trainer = NeuraBatchedTrainer::new()
        .learning_rate(0.05)
        .iterations(1);
    trainer.batch_size = 1;

    for step in 0..samples.len() {
        // Feed exactly one sample (the `step`-th) to the trainer.
        let current_sample = samples.iter().skip(step).take(1).cloned();

        trainer.train(
            &NeuraBackprop::new(Euclidean),
            &mut network,
            current_sample,
            &samples,
        );

        // Snapshot `step + 1` is the expected state after this training step.
        let (want_w0, want_b0, want_w1, want_b1) = &snapshots[step + 1];
        let first_layer = &network.layer;
        let second_layer = &network.child_network.layer;

        assert_relative_eq!(want_w0.as_slice(), first_layer.weights.as_slice());
        assert_relative_eq!(want_b0.as_slice(), first_layer.bias.as_slice());
        assert_relative_eq!(want_w1.as_slice(), second_layer.weights.as_slice());
        assert_relative_eq!(want_b1.as_slice(), second_layer.bias.as_slice());
    }
}
|