Implementation of Decision Trees Algorithm
1 parent 0bc1351 · commit b485243 · 2 changed files with 299 additions and 0 deletions
File 1 (new file, +205 lines): the decision tree implementation in module `ml`.
```v
module ml

import math

// Sample is a single labeled observation: a feature vector plus a class label.
pub struct Sample {
pub mut:
	features []f64
pub:
	label int
}

// Dataset is a collection of samples with a fixed number of features and classes.
pub struct Dataset {
pub mut:
	samples []Sample
pub:
	n_features int
	n_classes  int
}

// Node is one node of the tree. Internal nodes route on `feature <= threshold`;
// leaves carry a class label and have nil children.
struct Node {
mut:
	feature   int
	threshold f64
	label     int
	left      &Node
	right     &Node
}

pub struct DecisionTree {
mut:
	root              &Node
	max_depth         int
	min_samples_split int
}

// DecisionTree.new returns an untrained tree with the given stopping criteria.
pub fn DecisionTree.new(max_depth int, min_samples_split int) &DecisionTree {
	return &DecisionTree{
		root: &Node(unsafe { nil })
		max_depth: max_depth
		min_samples_split: min_samples_split
	}
}

// index_of_max returns the index of the largest element (used to pick the majority class).
pub fn index_of_max(arr []int) int {
	mut max_index := 0
	for i := 1; i < arr.len; i++ {
		if arr[i] > arr[max_index] {
			max_index = i
		}
	}
	return max_index
}

// create_dataset returns an empty dataset for n_features features and n_classes classes.
pub fn create_dataset(n_features int, n_classes int) &Dataset {
	return &Dataset{
		samples: []Sample{}
		n_features: n_features
		n_classes: n_classes
	}
}

// add_sample appends a copy of the sample, rejecting labels outside [0, n_classes).
pub fn (mut dataset Dataset) add_sample(features []f64, label int) bool {
	if label < 0 || label >= dataset.n_classes {
		return false
	}
	dataset.samples << Sample{
		features: features.clone()
		label: label
	}
	return true
}

// calculate_entropy returns the Shannon entropy of the class distribution, in bits.
pub fn (dataset &Dataset) calculate_entropy() f64 {
	mut class_counts := []int{len: dataset.n_classes, init: 0}
	for sample in dataset.samples {
		class_counts[sample.label]++
	}

	mut entropy := 0.0
	for count in class_counts {
		if count > 0 {
			p := f64(count) / f64(dataset.samples.len)
			entropy -= p * math.log2(p)
		}
	}
	return entropy
}
```
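Before the split search, note what calculate_entropy returns: the Shannon entropy H = −Σ_c p_c · log2(p_c) of the label distribution, in bits. A minimal sketch, assuming the file above is compiled as an importable `ml` module (the feature values are made up for illustration; only the labels matter):

```v
module main

import ml

fn main() {
	// Class counts [2, 1, 1] over 4 samples:
	// H = -(0.5*log2(0.5) + 2 * 0.25*log2(0.25)) = 0.5 + 1.0 = 1.5 bits
	mut ds := ml.create_dataset(1, 3)
	ds.add_sample([0.1], 0)
	ds.add_sample([0.2], 0)
	ds.add_sample([0.3], 1)
	ds.add_sample([0.4], 2)
	println(ds.calculate_entropy()) // 1.5
}
```

The file continues with the exhaustive split search: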
```v
// find_best_split tries every sample value of every feature as a candidate
// threshold and returns (feature, threshold, gain) for the split with the
// highest information gain. A gain of -1.0 signals that no valid split exists.
fn find_best_split(dataset &Dataset) (int, f64, f64) {
	mut best_gain := -1.0
	mut best_feature := 0
	mut best_threshold := 0.0

	for feature in 0 .. dataset.n_features {
		for sample in dataset.samples {
			threshold := sample.features[feature]
			mut left := create_dataset(dataset.n_features, dataset.n_classes)
			mut right := create_dataset(dataset.n_features, dataset.n_classes)

			for s in dataset.samples {
				if s.features[feature] <= threshold {
					left.add_sample(s.features, s.label)
				} else {
					right.add_sample(s.features, s.label)
				}
			}

			if left.samples.len > 0 && right.samples.len > 0 {
				p_left := f64(left.samples.len) / f64(dataset.samples.len)
				p_right := f64(right.samples.len) / f64(dataset.samples.len)
				gain := dataset.calculate_entropy() - (p_left * left.calculate_entropy() + p_right * right.calculate_entropy())

				if gain > best_gain {
					best_gain = gain
					best_feature = feature
					best_threshold = threshold
				}
			}
		}
	}

	return best_feature, best_threshold, best_gain
}
```
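The search is brute force: every sample value of every feature is a candidate threshold, and each candidate re-partitions the whole dataset, so training costs O(n_features · n_samples²) entropy evaluations. The quantity being maximized is the same information gain computed by the public calculate_information_gain helper at the end of the file; a hand-checkable sketch with invented values:

```v
module main

import ml

fn main() {
	// One feature, two classes, perfectly separable at threshold 2.0.
	mut parent := ml.create_dataset(1, 2)
	parent.add_sample([1.0], 0)
	parent.add_sample([2.0], 0)
	parent.add_sample([3.0], 1)
	parent.add_sample([4.0], 1)

	mut left := ml.create_dataset(1, 2)
	left.add_sample([1.0], 0)
	left.add_sample([2.0], 0)

	mut right := ml.create_dataset(1, 2)
	right.add_sample([3.0], 1)
	right.add_sample([4.0], 1)

	// Parent entropy is 1 bit and both children are pure, so the gain is 1.0.
	println(ml.calculate_information_gain(parent, left, right)) // 1.0
}
```

The rest of the file grows the tree recursively and exposes training and prediction: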
```v
// build_tree recursively grows the tree. A leaf holding the majority label is
// produced when the node has fewer than min_samples_split samples, the depth
// budget is exhausted, or no split yields positive information gain.
fn build_tree(dataset &Dataset, max_depth int, min_samples_split int) &Node {
	if dataset.samples.len < min_samples_split || max_depth == 0 {
		mut class_counts := []int{len: dataset.n_classes, init: 0}
		for sample in dataset.samples {
			class_counts[sample.label]++
		}
		label := index_of_max(class_counts)
		return &Node{
			feature: -1
			threshold: 0
			label: label
			left: &Node(unsafe { nil })
			right: &Node(unsafe { nil })
		}
	}

	best_feature, best_threshold, best_gain := find_best_split(dataset)

	if best_gain <= 0 {
		mut class_counts := []int{len: dataset.n_classes, init: 0}
		for sample in dataset.samples {
			class_counts[sample.label]++
		}
		label := index_of_max(class_counts)
		return &Node{
			feature: -1
			threshold: 0
			label: label
			left: &Node(unsafe { nil })
			right: &Node(unsafe { nil })
		}
	}

	mut left := create_dataset(dataset.n_features, dataset.n_classes)
	mut right := create_dataset(dataset.n_features, dataset.n_classes)

	for sample in dataset.samples {
		if sample.features[best_feature] <= best_threshold {
			left.add_sample(sample.features, sample.label)
		} else {
			right.add_sample(sample.features, sample.label)
		}
	}

	left_subtree := build_tree(left, max_depth - 1, min_samples_split)
	right_subtree := build_tree(right, max_depth - 1, min_samples_split)

	return &Node{
		feature: best_feature
		threshold: best_threshold
		label: -1
		left: left_subtree
		right: right_subtree
	}
}

// train fits the tree to the dataset.
pub fn (mut dt DecisionTree) train(dataset &Dataset) {
	dt.root = build_tree(dataset, dt.max_depth, dt.min_samples_split)
}

// predict returns the predicted class label for a feature vector.
pub fn (dt &DecisionTree) predict(features []f64) int {
	return predict_recursive(dt.root, features)
}

fn predict_recursive(node &Node, features []f64) int {
	// A node with no children is a leaf; return its label.
	if node.left == unsafe { nil } && node.right == unsafe { nil } {
		return node.label
	}

	if features[node.feature] <= node.threshold {
		return predict_recursive(node.left, features)
	} else {
		return predict_recursive(node.right, features)
	}
}

// calculate_information_gain returns the entropy reduction achieved by
// partitioning parent into left and right.
pub fn calculate_information_gain(parent &Dataset, left &Dataset, right &Dataset) f64 {
	p_left := f64(left.samples.len) / f64(parent.samples.len)
	p_right := f64(right.samples.len) / f64(parent.samples.len)
	return parent.calculate_entropy() - (p_left * left.calculate_entropy() + p_right * right.calculate_entropy())
}
```
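Putting it together, an end-to-end sketch (again assuming the file is importable as module `ml`; the toy data is hypothetical):

```v
module main

import ml

fn main() {
	// Two features, two classes; the class is 1 iff the first feature exceeds 2.0.
	mut ds := ml.create_dataset(2, 2)
	ds.add_sample([1.0, 0.0], 0)
	ds.add_sample([1.5, 1.0], 0)
	ds.add_sample([2.5, 0.0], 1)
	ds.add_sample([3.0, 1.0], 1)

	mut dt := ml.DecisionTree.new(3, 2)
	dt.train(ds)

	// The root splits on feature 0 and both children become pure leaves.
	println(dt.predict([2.8, 0.5])) // 1
	println(dt.predict([1.2, 0.5])) // 0
}
```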
File 2 (new file, +94 lines): unit tests covering dataset construction, entropy, end-to-end training, and information gain.
```v
module ml

import math

fn test_decision_tree_creation() {
	max_depth := 3
	min_samples_split := 2
	dt := DecisionTree.new(max_depth, min_samples_split)
	assert dt.max_depth == max_depth
	assert dt.min_samples_split == min_samples_split
}

fn test_dataset_creation() {
	n_features := 3
	n_classes := 4
	dataset := create_dataset(n_features, n_classes)
	assert dataset.n_features == n_features
	assert dataset.n_classes == n_classes
	assert dataset.samples.len == 0
}

fn test_add_sample() {
	mut dataset := create_dataset(3, 4)
	features := [1.0, 2.0, 3.0]
	label := 2
	assert dataset.add_sample(features, label) == true
	assert dataset.samples.len == 1
	assert dataset.samples[0].features == features
	assert dataset.samples[0].label == label

	// Labels outside [0, n_classes) are rejected.
	assert dataset.add_sample(features, 5) == false
	assert dataset.samples.len == 1
}

fn test_entropy_calculation() {
	mut dataset := create_dataset(3, 4)
	dataset.add_sample([1.0, 2.0, 0.5], 0)
	dataset.add_sample([2.0, 3.0, 1.0], 1)
	dataset.add_sample([3.0, 4.0, 1.5], 2)
	dataset.add_sample([4.0, 5.0, 2.0], 3)
	dataset.add_sample([2.5, 3.5, 1.2], 1)

	entropy := dataset.calculate_entropy()
	expected_entropy := 1.9219280948873623 // manually calculated; see the check below
	assert math.abs(entropy - expected_entropy) < 1e-6
}
```
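The expected value can be checked by hand. The class counts are [1, 2, 1, 1] over 5 samples, so

H = 3 · (1/5) · log2(5) + (2/5) · log2(5/2) ≈ 1.3931569 + 0.5287712 ≈ 1.9219281

which matches expected_entropy to the asserted tolerance. The remaining tests exercise end-to-end training and the information-gain helper: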
```v
fn test_decision_tree_training_and_prediction() {
	mut dataset := create_dataset(3, 4)
	dataset.add_sample([1.0, 2.0, 0.5], 0)
	dataset.add_sample([2.0, 3.0, 1.0], 1)
	dataset.add_sample([3.0, 4.0, 1.5], 2)
	dataset.add_sample([4.0, 5.0, 2.0], 3)
	dataset.add_sample([2.5, 3.5, 1.2], 1)

	mut dt := DecisionTree.new(3, 2)
	dt.train(dataset)

	// The query is nearly identical to the class-1 sample [2.5, 3.5, 1.2].
	assert dt.predict([2.5, 3.5, 1.3]) == 1
}

fn test_information_gain() {
	mut parent := create_dataset(3, 3)
	parent.add_sample([2.0, 3.5, 1.1], 0)
	parent.add_sample([3.0, 4.0, 1.5], 1)
	parent.add_sample([1.5, 2.0, 0.5], 0)
	parent.add_sample([2.5, 3.0, 1.0], 1)
	parent.add_sample([4.0, 5.0, 2.0], 2)

	mut left := create_dataset(3, 3)
	left.add_sample([2.0, 3.5, 1.1], 0)
	left.add_sample([1.5, 2.0, 0.5], 0)

	mut right := create_dataset(3, 3)
	right.add_sample([3.0, 4.0, 1.5], 1)
	right.add_sample([2.5, 3.0, 1.0], 1)
	right.add_sample([4.0, 5.0, 2.0], 2)

	info_gain := calculate_information_gain(parent, left, right)
	expected_gain := 0.9709505944546686 // manually calculated; see the check below
	assert math.abs(info_gain - expected_gain) < 1e-6
}
```
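Again checkable by hand: the parent's class counts are [2, 2, 1], the left partition is pure, and the right partition has counts [2, 1], so

H(parent) = 2 · (2/5) · log2(5/2) + (1/5) · log2(5) ≈ 1.5219281
H(left) = 0, H(right) = (2/3) · log2(3/2) + (1/3) · log2(3) ≈ 0.9182958
gain = H(parent) − ((2/5) · H(left) + (3/5) · H(right)) ≈ 1.5219281 − 0.5509775 ≈ 0.9709506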
Since this is a V test file, `v test` discovers and runs the `test_*` functions automatically, so no `main` function is needed (and a `fn main` cannot live in `module ml` in any case).