add decision tree and random forest
BMJHayward committed Mar 10, 2023
1 parent ff01045 commit b8b88df
Showing 4 changed files with 367 additions and 0 deletions.
54 changes: 54 additions & 0 deletions ml/forest.v
@@ -0,0 +1,54 @@
module ml

import rand

pub struct RandomForest {
	name string
mut:
	n_trees           int
	trees             []DecisionTree
	data              &Data[f64]
	stat              &Stat[f64]
	min_samples_split int
	max_depth         int
	n_feats           int
}

pub fn init_forest(mut data Data[f64], n_trees int, trees []DecisionTree, min_samples_split int, max_depth int, n_feats int) RandomForest {
	forest_data := new_data[f64](data.nb_samples, n_feats, true, false) or {
		panic('could not create new data to initialise forest')
	}
	// statistics for the underlying data set
	mut stat := stat_from_data(mut data, 'stat_')
	stat.update()
	return RandomForest{'${n_trees}-${min_samples_split}-${max_depth}-${n_feats}', n_trees, trees, forest_data, stat, min_samples_split, max_depth, n_feats}
}

fn (mut rf RandomForest) fit(x [][]f64, y []f64) {
	n_samples := x.len
	mut sample_list := []int{}
	for s in 0 .. n_samples {
		sample_list << s
	}
	for _ in 0 .. rf.n_trees {
		// sample x and y, then train one tree on the subsample
		idxs := rand.choose(sample_list, n_samples) or {
			panic('could not choose random sample')
		}
		mut tree_data := data_from_raw_xy_sep(idxs.map(x[it]), idxs.map(y[it])) or {
			panic('could not create data for tree')
		}
		mut tree := init_tree(mut tree_data, rf.min_samples_split, rf.max_depth, rf.n_feats,
			'${rf.min_samples_split}-${rf.max_depth}-${rf.n_feats}')
		tree.train() or { panic('could not train tree') }
		rf.trees << tree
	}
}
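A note on the sampling above: V's rand.choose draws k distinct elements, so with k == n_samples each subsample is a shuffled copy of the data rather than a true bootstrap; sampling with replacement is still a TODO (see tree.v). A minimal sketch of the behaviour, assuming only the standard rand module:

import rand

fn main() {
	idxs := rand.choose([0, 1, 2, 3, 4], 5) or { panic(err) }
	println(idxs) // each index appears exactly once, in random order
}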

fn (mut rf RandomForest) predict(x [][]f64) []f64 {
	mut tpreds := [][]f64{}
	for t in 0 .. rf.trees.len {
		tpreds << rf.trees[t].predict(x)
	}
	// majority vote per sample: collect every tree's vote for sample s and take the mode
	mut ypreds := []f64{}
	for s in 0 .. x.len {
		ypreds << most_common(tpreds.map(it[s]))
	}
	return ypreds
}
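A minimal end-to-end sketch of the forest API, assuming vsl's data_from_raw_xy_sep helper; the toy data and hyperparameters are illustrative only:

x := [[0.0, 1.0], [0.1, 0.9], [0.9, 0.1], [1.0, 0.0]]
y := [1.0, 1.0, 2.0, 2.0]
mut data := data_from_raw_xy_sep(x, y) or { panic(err) }
mut rf := init_forest(mut data, 5, []DecisionTree{}, 2, 3, 2)
rf.fit(x, y)
println(rf.predict(x)) // should vote [1.0, 1.0, 2.0, 2.0] on this separable toy set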
5 changes: 5 additions & 0 deletions ml/forest_test.v
@@ -0,0 +1,5 @@
module ml

fn test_init_forest() {
	assert 'init_forest' == 'init_forest'
}

fn test_fit() {
	assert 'fit' == 'fit'
}

fn test_predict() {
	assert 'predict' == 'predict'
}
248 changes: 248 additions & 0 deletions ml/tree.v
@@ -0,0 +1,248 @@
module ml

import arrays
import math
import rand

pub type Tree = Empty | Node

pub struct Empty {}

[heap]
pub struct Node {
mut:
	feature   int
	threshold f64
	left      Tree
	right     Tree
	value     f64
}

[heap]
pub struct DecisionTree {
	name string
mut:
	data              &Data[f64]
	stat              &Stat[f64]
	min_samples_split int
	max_depth         int
	n_feats           int
	root              Tree
}

pub fn most_common[T](y []T) T {
	if y.len == 0 {
		panic('y has no elements')
	}
	mut max_count := 0
	mut most_frequent := y[0]
	for i in 0 .. y.len {
		mut count := 0
		for j in 0 .. y.len {
			if y[i] == y[j] {
				count += 1
			}
		}
		if count > max_count {
			max_count = count
			most_frequent = y[i]
		}
	}
	return most_frequent
}
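most_common is the vote aggregator; on ties the earliest-seen value wins, since a later candidate must strictly exceed max_count. For example:

votes := [2.0, 1.0, 2.0, 3.0, 2.0]
assert most_common(votes) == 2.0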

fn entropy[T](y []T) f64 {
	mut hist := map[T]int{}
	for i in 0 .. y.len {
		hist[y[i]] += 1
	}
	probs := hist.values().map(it / f64(y.len))
	logits := probs.filter(it > 0).map(-1 * it * math.log2(it))
	return arrays.sum(logits) or { panic('failed to sum array') }
}
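entropy computes the Shannon entropy H(y) = -sum(p_i * log2(p_i)) over the class histogram. With two equally likely classes the result is exactly one bit; in general, compare with a tolerance, since summation order over the map is not fixed:

assert entropy([1.0, 1.0, 2.0, 2.0]) == 1.0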

pub fn accuracy(y_true []f64, y_pred []f64) f64 {
	mut acc := 0.0
	for t in 0 .. math.min(y_true.len, y_pred.len) {
		if y_true[t] == y_pred[t] {
			acc += 1
		}
	}
	return acc / y_true.len
}
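For example, three matches out of four (note the denominator is always y_true.len, even if y_pred is shorter):

assert accuracy([1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 9.0]) == 0.75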

fn init_node(feature int, threshold f64, left Node, right Node, value f64) Tree {
	return Node{feature, threshold, left, right, value}
}

fn (n Node) is_leaf() bool {
	// internal nodes are created with value 0; leaves carry the predicted class
	return n.value > 0
}

pub fn init_tree(mut data Data[f64], min_samples_split int, max_depth int, n_feats int, name string) DecisionTree {
	// statistics for the underlying data set
	mut stat := stat_from_data(mut data, 'stat_')
	stat.update()
	return DecisionTree{name, data, stat, min_samples_split, max_depth, n_feats, Empty{}}
}

pub fn (mut dt DecisionTree) train() ? {
	if dt.n_feats > 0 {
		dt.n_feats = math.min(dt.n_feats, dt.data.x.n)
	} else {
		dt.n_feats = dt.data.x.n
	}
	dt.root = dt.grow_tree(dt.data, 0)
}

pub fn (mut dt DecisionTree) update(x [][]f64) []f64 {
	// TODO: incremental updates are not implemented yet; this is a stub
	println('decision tree updated')
	return []f64{}
}

pub fn (mut dt DecisionTree) predict(x [][]f64) []f64 {
	mut predictions := []f64{}
	for datum in x {
		predictions << traverse(datum, dt.root)
	}
	return predictions
}

fn traverse(x []f64, node Tree) f64 {
	match node {
		Empty {
			return -1.0
		}
		Node {
			if node.is_leaf() {
				return node.value
			}
			if x[node.feature] <= node.threshold {
				return traverse(x, node.left)
			}
			return traverse(x, node.right)
		}
	}
}
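A hand-built stump illustrates the routing; the literal nodes here are illustrative. Samples with x[0] <= 0.5 go left:

stump := Node{0, 0.5, Node{0, 0.0, Empty{}, Empty{}, 1.0}, Node{0, 0.0, Empty{}, Empty{}, 2.0}, 0.0}
assert traverse([0.2], stump) == 1.0
assert traverse([0.8], stump) == 2.0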

fn (dt DecisionTree) grow_tree(data Data[f64], depth int) Node {
	mut clone_x := data.clone_with_same_x() or { panic('failed to clone x data in grow tree') }
	clone_x.set_y(data.y) or { panic('failed to clone y data in grow tree') }
	x := clone_x.x.get_deep2()
	y := clone_x.y.clone()

	n_samples := x.len
	n_features := if n_samples == 0 { 0 } else { x[0].len }
	n_labels := unique(y).len

	// stopping criteria: maximum depth reached, node is pure, or too few samples to split
	if depth >= dt.max_depth || n_labels == 1 || n_samples < dt.min_samples_split {
		leaf_value := most_common(y)
		return Node{0, 0.0, Empty{}, Empty{}, leaf_value}
	}

	mut n_feat_array := []int{}
	for n in 0 .. n_features {
		n_feat_array << n
	}
	// TODO: implement choose with replacement
	feature_indices := rand.choose[int](n_feat_array, dt.n_feats) or {
		panic('failed to create feat indices')
	}

	// greedily select the best split according to information gain
	best_feat, best_thresh := best_criteria[f64](x, y, feature_indices)
	// grow the children that result from splitting column best_feat at best_thresh
	left_idxs, right_idxs := split[f64](x.map(it[best_feat]), best_thresh)
	xlix := left_idxs.map(x[it])
	xrix := right_idxs.map(x[it])
	ylix := left_idxs.map(y[it])
	yrix := right_idxs.map(y[it])
	mut xleft := data_from_raw_xy_sep(xlix, ylix) or {
		panic('could not create new data for left subtree')
	}
	mut xright := data_from_raw_xy_sep(xrix, yrix) or {
		panic('could not create new data for right subtree')
	}
	left := dt.grow_tree(xleft, depth + 1)
	right := dt.grow_tree(xright, depth + 1)
	return Node{best_feat, best_thresh, left, right, 0}
}
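Putting grow_tree together for a single tree, a sketch on a linearly separable toy set (data and hyperparameters illustrative):

x := [[0.0, 1.0], [0.1, 0.9], [0.9, 0.1], [1.0, 0.0]]
y := [1.0, 1.0, 2.0, 2.0]
mut data := data_from_raw_xy_sep(x, y) or { panic(err) }
mut dt := init_tree(mut data, 2, 3, 2, 'toy')
dt.train() or { panic('train failed') }
assert accuracy(y, dt.predict(x)) == 1.0 // a single split on either feature separates the classes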

fn split[T](x_column []T, split_thresh T) ([]int, []int) {
	mut left_idxs := []int{}
	mut right_idxs := []int{}
	for i in 0 .. x_column.len {
		if x_column[i] <= split_thresh {
			left_idxs << i
		} else {
			right_idxs << i
		}
	}
	return left_idxs, right_idxs
}
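split partitions indices rather than values. For example:

l, r := split([1.0, 5.0, 2.0, 8.0], 2.0)
assert l == [0, 2] // positions of values <= 2.0
assert r == [1, 3]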

fn unique[T](all []T) []T {
	mut s := map[T]T{}
	for a in all {
		s[a] = a
	}
	return s.keys()
}

fn best_criteria[T](x [][]T, y []T, feat_idxs []int) (int, T) {
	mut best_gain := -1.0
	mut split_idx := 0
	mut split_thresh := 0.0
	for feat_idx in feat_idxs {
		mut x_column := []T{}
		for xc in 0 .. x.len {
			x_column << x[xc][feat_idx]
		}
		thresholds := unique(x_column)
		for threshold in thresholds {
			gain := info_gain(y, x_column, threshold)
			if gain > best_gain {
				best_gain = gain
				split_idx = feat_idx
				split_thresh = threshold
			}
		}
	}
	return split_idx, split_thresh
}

fn info_gain[T](y []T, xcol []T, threshold T) f64 {
	parent_entropy := entropy(y)
	l, r := split(xcol, threshold)
	if l.len == 0 || r.len == 0 {
		// a one-sided split leaves the class distribution unchanged
		return 0.0
	}

	y_left := l.map(y[it])
	y_right := r.map(y[it])
	lentropy := entropy(y_left)
	rentropy := entropy(y_right)
	// weight each child's entropy by its share of samples; cast to f64 to avoid integer division
	child_entropy := (f64(l.len) / f64(y.len)) * lentropy + (f64(r.len) / f64(y.len)) * rentropy

	return parent_entropy - child_entropy
}
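A worked example: with a threshold that sends both 1s left and both 2s right, the children are pure (entropy 0), so the gain equals the parent entropy of one bit:

y := [1.0, 1.0, 2.0, 2.0]
xcol := [0.1, 0.2, 0.8, 0.9]
assert info_gain(y, xcol, 0.2) == 1.0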
60 changes: 60 additions & 0 deletions ml/tree_test.v
@@ -0,0 +1,60 @@
module ml

import math

fn test_most_common() {
	mut vv := []int{}

	vv = [1]
	assert most_common(vv) == 1

	vv = [1, 2, 3, 4]
	assert most_common(vv) == 1

	vv = [1, 2, 3, 4, 1]
	assert most_common(vv) == 1
}

fn test_entropy() {
	// compare with a tolerance: the summation order over the histogram is not fixed
	tol := 1e-9

	y1 := [1, 2, 3, 4, 5, 6, 7, 8] // 8 equally likely classes: log2(8) = 3 bits
	assert math.abs(entropy(y1) - 3.0) < tol

	y2 := [1, 1, 1, 1, 1, 1, 1, 1] // a pure set has zero entropy
	assert math.abs(entropy(y2) - 0.0) < tol

	y3 := [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] // log2(10)
	assert math.abs(entropy(y3) - 3.3219280948873622) < tol

	y4 := [1, 2, 2, 3, 3, 3, 4, 4, 4, 4] // -(0.1*lg(0.1) + 0.2*lg(0.2) + 0.3*lg(0.3) + 0.4*lg(0.4))
	assert math.abs(entropy(y4) - 1.8464393446710154) < tol
}

fn test_accuracy() {
	y_true1 := [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
	y_pred1 := [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
	assert accuracy(y_true1, y_pred1) == 1.0

	y_true2 := [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
	y_pred2 := [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
	assert accuracy(y_true2, y_pred2) == 1.0

	y_true3 := [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
	y_pred3 := [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
	assert accuracy(y_true3, y_pred3) == 1.0

	y_true4 := [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0]
	y_pred4 := [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0]
	assert accuracy(y_true4, y_pred4) == 1.0

	y_true5 := [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0]
	y_pred5 := [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 3.0]
	assert accuracy(y_true5, y_pred5) == 0.9
}
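Assuming the usual vsl layout, the new tests can be run from the repository root with V's built-in test runner:

v test ml/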
