parent
8ac82e20e2
commit
220c61ff6b
@ -0,0 +1,87 @@
|
||||
use rand::Rng;
|
||||
|
||||
use crate::train::NeuraTrainableLayer;
|
||||
|
||||
use super::NeuraLayer;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct NeuraDropoutLayer<const LENGTH: usize, R: Rng> {
|
||||
pub dropout_probability: f64,
|
||||
multiplier: f64,
|
||||
mask: [bool; LENGTH],
|
||||
rng: R,
|
||||
}
|
||||
|
||||
impl<const LENGTH: usize, R: Rng> NeuraDropoutLayer<LENGTH, R> {
|
||||
pub fn new(dropout_probability: f64, rng: R) -> Self {
|
||||
Self {
|
||||
dropout_probability,
|
||||
multiplier: 1.0,
|
||||
mask: [false; LENGTH],
|
||||
rng,
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_dropout(&self, vector: &mut [f64; LENGTH]) {
|
||||
for (index, &dropout) in self.mask.iter().enumerate() {
|
||||
if dropout {
|
||||
vector[index] = 0.0;
|
||||
} else {
|
||||
vector[index] *= self.multiplier;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LENGTH: usize, R: Rng> NeuraLayer for NeuraDropoutLayer<LENGTH, R> {
|
||||
type Input = [f64; LENGTH];
|
||||
type Output = [f64; LENGTH];
|
||||
|
||||
fn eval(&self, input: &Self::Input) -> Self::Output {
|
||||
let mut result = input.clone();
|
||||
|
||||
self.apply_dropout(&mut result);
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl<const LENGTH: usize, R: Rng> NeuraTrainableLayer for NeuraDropoutLayer<LENGTH, R> {
|
||||
type Delta = ();
|
||||
|
||||
fn backpropagate(
|
||||
&self,
|
||||
_input: &Self::Input,
|
||||
mut epsilon: Self::Output,
|
||||
) -> (Self::Input, Self::Delta) {
|
||||
self.apply_dropout(&mut epsilon);
|
||||
|
||||
(epsilon, ())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn apply_gradient(&mut self, _gradient: &Self::Delta) {
|
||||
// Noop
|
||||
}
|
||||
|
||||
fn prepare_epoch(&mut self) {
|
||||
// Rejection sampling to prevent all the inputs from being dropped out
|
||||
loop {
|
||||
let mut sum = 0;
|
||||
for i in 0..LENGTH {
|
||||
self.mask[i] = self.rng.gen_bool(self.dropout_probability);
|
||||
sum += (!self.mask[i]) as usize;
|
||||
}
|
||||
|
||||
if sum < LENGTH {
|
||||
self.multiplier = LENGTH as f64 / sum as f64;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn cleanup(&mut self) {
|
||||
self.mask = [false; LENGTH];
|
||||
self.multiplier = 1.0;
|
||||
}
|
||||
}
|
Loading…
Reference in new issue