add decision tree and random forest #158
Closed
3 commits
@@ -0,0 +1,54 @@
module ml

import rand

pub struct RandomForest {
	name string
mut:
	n_trees           int
	trees             []DecisionTree
	data              &Data[f64]
	stat              &Stat[f64]
	min_samples_split int
	max_depth         int
	n_feats           int
}

pub fn init_forest(mut data Data[f64], n_trees int, trees []DecisionTree, min_samples_split int, max_depth int, n_feats int) RandomForest {
	forest_data := new_data[f64](data.nb_samples, n_feats, true, false) or {
		panic('could not create new data to initialise forest')
	}
	// observation statistics for the training data
	mut stat := stat_from_data(mut data, 'stat_')
	stat.update()
	return RandomForest{'${n_trees}-${min_samples_split}-${max_depth}-${n_feats}', n_trees, trees, forest_data, stat, min_samples_split, max_depth, n_feats}
}

fn (mut rf RandomForest) fit(x [][]f64, y []f64) {
	n_samples := x.len
	mut sample_list := []int{}
	for s in 0 .. n_samples {
		sample_list << s
	}
	for _ in 0 .. rf.n_trees {
		// sample x and y for this tree (note: rand.choose samples without replacement,
		// so this is a permutation rather than a true bootstrap sample)
		idxs := rand.choose(sample_list, n_samples) or {
			panic('could not choose random sample')
		}
		mut tree_data := data_from_raw_xy_sep(idxs.map(x[it]), idxs.map(y[it])) or {
			panic('could not create data for tree')
		}
		mut tree := init_tree(mut tree_data, rf.min_samples_split, rf.max_depth, rf.n_feats,
			'${rf.min_samples_split}-${rf.max_depth}-${rf.n_feats}')
		tree.train() or { panic('could not train tree') }
		rf.trees << tree
	}
}

fn (mut rf RandomForest) predict(x [][]f64) []f64 {
	// collect every tree's predictions, then majority-vote per sample
	mut tpreds := [][]f64{}
	for t in 0 .. rf.trees.len {
		tpreds << rf.trees[t].predict(x)
	}
	mut ypreds := []f64{}
	for s in 0 .. x.len {
		ypreds << most_common(tpreds.map(it[s]))
	}
	return ypreds
}
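A minimal usage sketch for the forest API above (not part of the diff). It is written as an in-module test because fit and predict are not exported, it assumes data_from_raw_xy_sep behaves as elsewhere in vsl.ml, and the sample values and labels are made up for illustration:

module ml

// Hypothetical in-module sketch of how the forest might be driven end to end.
fn test_random_forest_usage_sketch() {
	x := [[1.0, 2.0], [2.0, 3.0], [8.0, 9.0], [9.0, 10.0]]
	y := [1.0, 1.0, 2.0, 2.0]
	mut data := data_from_raw_xy_sep(x, y) or { panic('could not create data') }
	// 5 trees, split nodes with at least 2 samples, depth <= 3, 2 candidate features
	mut forest := init_forest(mut data, 5, []DecisionTree{}, 2, 3, 2)
	forest.fit(x, y)
	preds := forest.predict(x)
	assert preds.len == x.len
}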
@@ -0,0 +1,5 @@
module ml

fn test_init_forest() { assert 'init_forest' == 'init_forest' }
fn test_fit() { assert 'fit' == 'fit' }
fn test_predict() { assert 'predict' == 'predict' }
@@ -0,0 +1,248 @@
module ml

import arrays
import math
import rand

pub type Tree = Empty | Node

pub struct Empty {}

[heap]
pub struct Node {
mut:
	feature   int
	threshold f64
	left      Tree
	right     Tree
	value     f64
}

[heap]
pub struct DecisionTree {
	name string
mut:
	data              &Data[f64]
	stat              &Stat[f64]
	min_samples_split int
	max_depth         int
	n_feats           int
	root              Tree
}

// most_common returns the value that occurs most often in y.
pub fn most_common[T](y []T) T {
	if y.len == 0 {
		panic('y has no elements')
	}
	mut max_count := 0
	mut most_frequent := y[0]
	for i in 0 .. y.len {
		mut count := 0
		for j in 0 .. y.len {
			if y[i] == y[j] {
				count += 1
			}
		}
		if count > max_count {
			max_count = count
			most_frequent = y[i]
		}
	}
	return most_frequent
}

// entropy computes the Shannon entropy (in bits) of the labels in y.
fn entropy[T](y []T) f64 {
	mut hist := map[T]int{}
	for i in 0 .. y.len {
		hist[y[i]] += 1
	}
	probs := hist.values().map(f64(it) / f64(y.len))
	logits := probs.filter(it > 0).map(-it * math.log2(it))
	return arrays.sum(logits) or { panic('failed to sum array') }
}

// accuracy returns the fraction of predictions that match the true labels.
pub fn accuracy(y_true []f64, y_pred []f64) f64 {
	mut acc := 0.0
	for t in 0 .. math.min(y_true.len, y_pred.len) {
		if y_true[t] == y_pred[t] {
			acc += 1
		}
	}
	return acc / f64(y_true.len)
}

fn init_node(feature int, threshold f64, left Node, right Node, value f64) Tree {
	return Node{feature, threshold, left, right, value}
}

// is_leaf reports whether the node stores a prediction rather than a split.
// Checking the children (instead of value > 0) also works when a class label is 0.
fn (n Node) is_leaf() bool {
	return n.left is Empty && n.right is Empty
}

pub fn init_tree(mut data Data[f64], min_samples_split int, max_depth int, n_feats int, name string) DecisionTree {
	// observation statistics for the training data
	mut stat := stat_from_data(mut data, 'stat_')
	stat.update()
	return DecisionTree{name, data, stat, min_samples_split, max_depth, n_feats, Empty{}}
}

pub fn (mut dt DecisionTree) train() ? {
	if dt.n_feats > 0 {
		dt.n_feats = math.min(dt.n_feats, dt.data.x.n)
	} else {
		dt.n_feats = dt.data.x.n
	}
	dt.root = dt.grow_tree(dt.data, 0)
}

pub fn (mut dt DecisionTree) update(x [][]f64) []f64 {
	println('decision tree updated')
	return []f64{}
}

pub fn (mut dt DecisionTree) predict(x [][]f64) []f64 {
	mut predictions := []f64{}
	for datum in x {
		predictions << traverse(datum, dt.root)
	}
	return predictions
}

// traverse walks the tree for a single sample and returns the value stored in
// the leaf it reaches, or -1.0 for an empty (sub)tree.
fn traverse(x []f64, node Tree) f64 {
	match node {
		Empty {
			return -1.0
		}
		Node {
			if node.is_leaf() {
				return node.value
			}
			if x[node.feature] <= node.threshold {
				return traverse(x, node.left)
			}
			return traverse(x, node.right)
		}
	}
}

fn (dt DecisionTree) grow_tree(data Data[f64], depth int) Node {
	mut clone_x := data.clone_with_same_x() or { panic('failed to clone x data in grow tree') }
	clone_x.set_y(data.y) or { panic('failed to clone y data in grow tree') }
	mut x := clone_x.x.get_deep2()
	mut y := clone_x.y.clone()

	n_samples := x.len
	n_features := if n_samples == 0 { 0 } else { x[0].len }
	mut yuniq := map[f64]f64{}
	for yq in y {
		yuniq[yq] = yq
	}
	n_labels := yuniq.len

	// stopping criteria: max depth reached, pure node, or too few samples to split
	if depth >= dt.max_depth || n_labels == 1 || n_samples < dt.min_samples_split {
		leaf_value := most_common(y)
		return Node{0, 0.0, Empty{}, Empty{}, leaf_value}
	}

	mut n_feat_array := []int{}
	for n in 0 .. n_features {
		n_feat_array << n
	}
	// TODO: implement choose with replacement
	feature_indices := rand.choose[int](n_feat_array, dt.n_feats) or {
		panic('failed to create feat indices')
	}

	// greedily select the best split according to information gain
	best_feat, best_thresh := best_criteria[f64](x, y, feature_indices)
	// grow the children that result from splitting on the chosen feature column
	left_idxs, right_idxs := split[f64](x.map(it[best_feat]), best_thresh)
	xlix := left_idxs.map(x[it])
	xrix := right_idxs.map(x[it])
	ylix := left_idxs.map(y[it])
	yrix := right_idxs.map(y[it])
	mut xleft := data_from_raw_xy_sep(xlix, ylix) or {
		panic('could not create new data for left subtree')
	}
	mut xright := data_from_raw_xy_sep(xrix, yrix) or {
		panic('could not create new data for right subtree')
	}
	left := dt.grow_tree(xleft, depth + 1)
	right := dt.grow_tree(xright, depth + 1)
	return Node{best_feat, best_thresh, left, right, 0}
}

// split partitions the indices of x_column into those at or below and those
// above split_thresh.
fn split[T](x_column []T, split_thresh T) ([]int, []int) {
	mut left_idxs := []int{}
	mut right_idxs := []int{}
	for i in 0 .. x_column.len {
		if x_column[i] <= split_thresh {
			left_idxs << i
		} else {
			right_idxs << i
		}
	}

	return left_idxs, right_idxs
}

// unique returns the distinct values in all.
fn unique[T](all []T) []T {
	mut s := map[T]T{}
	for a in all {
		s[a] = a
	}
	return s.keys()
}

// best_criteria scans the candidate feature columns, using their unique values
// as thresholds, and returns the pair with the highest information gain.
fn best_criteria[T](x [][]T, y []T, feat_idxs []int) (int, T) {
	mut best_gain := -1.0
	mut split_idx := 0
	mut split_thresh := 0.0
	for feat_idx in feat_idxs {
		mut x_column := []T{}
		for xc in 0 .. x.len {
			x_column << x[xc][feat_idx]
		}
		thresholds := unique(x_column)
		for threshold in thresholds {
			gain := info_gain(y, x_column, threshold)
			if gain > best_gain {
				best_gain = gain
				split_idx = feat_idx
				split_thresh = threshold
			}
		}
	}
	return split_idx, split_thresh
}

// info_gain is the reduction in entropy obtained by splitting y on threshold.
fn info_gain[T](y []T, xcol []T, threshold T) f64 {
	parent_entropy := entropy(y)
	l, r := split(xcol, threshold)
	if l.len == 0 || r.len == 0 {
		return 0.0
	}

	ylen := y.len
	rlen := r.len
	llen := l.len
	y_left := l.map(y[it])
	y_right := r.map(y[it])
	lentropy := entropy(y_left)
	rentropy := entropy(y_right)
	// weight each child's entropy by the fraction of samples it receives
	child_entropy := (f64(llen) / f64(ylen)) * lentropy + (f64(rlen) / f64(ylen)) * rentropy

	return parent_entropy - child_entropy
}
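A similar hedged sketch for the tree on its own (again not part of the diff, in-module for the same reason, and with made-up numbers in which a single feature separates the two labels):

module ml

// Hypothetical in-module sketch of training and querying a single tree.
fn test_decision_tree_usage_sketch() {
	// one feature that separates the two labels cleanly at x <= 3.0
	x := [[1.0], [2.0], [3.0], [8.0], [9.0], [10.0]]
	y := [1.0, 1.0, 1.0, 2.0, 2.0, 2.0]
	mut data := data_from_raw_xy_sep(x, y) or { panic('could not create data') }
	mut tree := init_tree(mut data, 2, 3, 1, 'example_tree')
	tree.train() or { panic('could not train tree') }
	preds := tree.predict(x)
	assert preds.len == y.len
}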
@@ -0,0 +1,83 @@
module ml

import math

fn test_most_common() {
	mut vv := []int{}

	vv = [1]
	assert most_common(vv) == 1

	// when every value is equally common, the first maximal value encountered wins
	vv = [1, 2, 3, 4]
	assert most_common(vv) == 1

	vv = [1, 2, 3, 4, 1]
	assert most_common(vv) == 1
}

fn test_entropy() {
	// compare against the analytic entropies with a small tolerance, since the
	// result is accumulated in floating point
	tol := 1e-6

	y1 := [1, 2, 3, 4, 5, 6, 7, 8]
	expected_result1 := 3.0 // log2(8)
	assert math.abs(entropy(y1) - expected_result1) < tol

	y2 := [1, 1, 1, 1, 1, 1, 1, 1]
	expected_result2 := 0.0 // a single label carries no information
	assert math.abs(entropy(y2) - expected_result2) < tol

	y3 := [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
	expected_result3 := 3.3219280948873622 // log2(10)
	assert math.abs(entropy(y3) - expected_result3) < tol

	y4 := [1, 2, 2, 3, 3, 3, 4, 4, 4, 4]
	expected_result4 := 1.8464393446710154 // entropy of the distribution (0.1, 0.2, 0.3, 0.4)
	assert math.abs(entropy(y4) - expected_result4) < tol
}

fn test_accuracy() {
	y_true1 := [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
	y_pred1 := [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
	expected_result1 := 1.0
	assert expected_result1 == accuracy(y_true1, y_pred1)

	y_true2 := [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
	y_pred2 := [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
	expected_result2 := 1.0
	assert expected_result2 == accuracy(y_true2, y_pred2)

	y_true3 := [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
	y_pred3 := [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]
	expected_result3 := 1.0
	assert expected_result3 == accuracy(y_true3, y_pred3)

	y_true4 := [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0]
	y_pred4 := [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0]
	expected_result4 := 1.0
	assert expected_result4 == accuracy(y_true4, y_pred4)

	y_true5 := [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0]
	y_pred5 := [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 3.0]
	expected_result5 := 0.9
	assert expected_result5 == accuracy(y_true5, y_pred5)
}

fn test_init_node() {}

fn test_is_leaf() {}

fn test_traverse() {}

fn test_grow_tree() {}

fn test_split() {}

fn test_unique() {}

fn test_best_criteria() {}

fn test_info_gain() {}

fn test_init_tree() {}

fn test_train() {}

fn test_update() {}

fn test_predict() {}
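The arithmetic behind the split criterion is easy to check by hand: y = [1, 1, 2, 2] has entropy log2(2) = 1 bit, and a threshold of 2.0 on the column [1.0, 2.0, 8.0, 9.0] yields two pure children with entropy 0, so the information gain is 1 - (0.5 * 0 + 0.5 * 0) = 1. One hypothetical way to pin that down in the test_info_gain placeholder above (a sketch, not part of the PR):

module ml

import math

// Hypothetical test: a split that perfectly separates two equally frequent
// labels should recover the full parent entropy of one bit.
fn test_info_gain_perfect_split() {
	y := [1.0, 1.0, 2.0, 2.0]
	xcol := [1.0, 2.0, 8.0, 9.0]
	assert math.abs(info_gain(y, xcol, 2.0) - 1.0) < 1e-10
}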
I am not up to date with recent V, but a few months ago the Empty proxy variant in the Tree sum type above would have been considered a huge anti-pattern. IMHO this should be an option type: returning none guarantees at compile time that you can never forget to check that a variable of type Tree might be empty, and it is also (much) faster. Or maybe drop the auxiliary/proxy type entirely and use none directly (I do not know if this works, but IMHO it should).
Thank you! I agree 👌🏻 Let's wait until the PR is ready for review before adding more comments and suggestions 😊
I'll assign you as a reviewer as soon as @BMJHayward says it is ready.
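A rough illustration of what the reviewer's none-based suggestion might look like (the snippets in the original comment are not preserved in this extract, so this is only a sketch, and whether option fields such as ?&Node compile depends on the V version in use):

[heap]
pub struct Node {
mut:
	feature   int
	threshold f64
	left      ?&Node // none replaces the Empty proxy struct
	right     ?&Node
	value     f64
}

fn traverse(x []f64, node ?&Node) f64 {
	n := node or { return -1.0 } // empty (sub)tree: nothing to predict
	if x[n.feature] <= n.threshold {
		left := n.left or { return n.value } // no child on this side, so n is a leaf
		return traverse(x, left)
	}
	right := n.right or { return n.value }
	return traverse(x, right)
}

With an option in the type, forgetting to handle the empty case becomes a compile-time error, which is the guarantee the comment describes.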