forked from AtheMathmo/rusty-machine
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnnet-and_gate.rs
83 lines (69 loc) · 2.34 KB
/
nnet-and_gate.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
extern crate rusty_machine;
extern crate rand;
use rand::{random, Closed01};
use std::vec::Vec;
use rusty_machine::learning::nnet::{NeuralNet, BCECriterion};
use rusty_machine::learning::toolkit::regularization::Regularization;
use rusty_machine::learning::optim::grad_desc::StochasticGD;
use rusty_machine::linalg::Matrix;
use rusty_machine::learning::SupModel;
// AND gate
//
// Trains a neural network with a 2-unit input layer and a 1-unit output layer
// to behave like a soft AND gate. It then evaluates the network on the four
// corners of the unit square and prints hits, misses and accuracy.
fn main() {
    println!("AND gate learner sample:");

    // A training pair is labelled 1.0 only when BOTH inputs exceed this value.
    const THRESHOLD: f64 = 0.7;
    // Number of random training examples to generate.
    const SAMPLES: usize = 10000;

    println!("Generating {} training data and labels...", SAMPLES);

    // Each sample contributes two consecutive values (left, right) to
    // `input_data`, which is later handed to `Matrix::new(SAMPLES, 2, ..)`.
    let mut input_data = Vec::with_capacity(SAMPLES * 2);
    let mut label_data = Vec::with_capacity(SAMPLES);

    for _ in 0..SAMPLES {
        // The two inputs are "signals" between 0 and 1 (inclusive).
        let Closed01(left) = random::<Closed01<f64>>();
        let Closed01(right) = random::<Closed01<f64>>();
        input_data.push(left);
        input_data.push(right);

        let label = if left > THRESHOLD && right > THRESHOLD { 1.0 } else { 0.0 };
        label_data.push(label);
    }

    let inputs = Matrix::new(SAMPLES, 2, input_data);
    let targets = Matrix::new(SAMPLES, 1, label_data);

    // Layer sizes passed to the network: 2 inputs, 1 output.
    let layers = &[2, 1];
    // Binary cross-entropy criterion; the L2 regularization weight is 0.
    let criterion = BCECriterion::new(Regularization::L2(0.));
    let mut model = NeuralNet::new(layers, criterion, StochasticGD::default());

    println!("Training...");
    // `train` returns a Result; the example cannot continue if it fails.
    model.train(&inputs, &targets).expect("training the AND gate network failed");

    // Each row is (left input, right input, expected AND output). Keeping each
    // input pair next to its expected output means the two cannot drift apart.
    let cases: [(f64, f64, f64); 4] = [
        (0.0, 0.0, 0.0),
        (0.0, 1.0, 0.0),
        (1.0, 1.0, 1.0),
        (1.0, 0.0, 0.0),
    ];

    let mut test_cases = Vec::with_capacity(cases.len() * 2);
    let mut expected = Vec::with_capacity(cases.len());
    for &(left, right, output) in cases.iter() {
        test_cases.push(left);
        test_cases.push(right);
        expected.push(output);
    }

    let test_inputs = Matrix::new(cases.len(), 2, test_cases);
    let res = model
        .predict(&test_inputs)
        .expect("prediction on the AND gate test inputs failed");

    println!("Evaluation...");
    let mut hits = 0;
    let mut misses = 0;

    // Evaluation
    println!("Got\tExpected");
    for (prediction, target) in res.into_vec().iter().zip(expected.iter()) {
        println!("{:.2}\t{}", prediction, target);
        // A hit means the prediction and the target lie strictly on the same
        // side of 0.5. A prediction of exactly 0.5 counts as a miss.
        if (prediction - 0.5) * (target - 0.5) > 0. {
            hits += 1;
        } else {
            misses += 1;
        }
    }

    println!("Hits: {}, Misses: {}", hits, misses);
    let hits_f = hits as f64;
    let total = (hits + misses) as f64;
    println!("Accuracy: {}%", (hits_f / total) * 100.);
}