Skip to content

Commit

Permalink
Added split method for data
Browse files Browse the repository at this point in the history
  • Loading branch information
ulises-jeremias committed Oct 17, 2023
1 parent 42ac97c commit 78f987b
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 6 deletions.
34 changes: 34 additions & 0 deletions examples/ml_multilinreg_plot/main.v
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
module main

import vsl.ml

fn main() {
xy := [
[0.99, 90.01],
[1.02, 89.05],
[1.15, 91.43],
[1.29, 93.74],
[1.46, 96.73],
[1.36, 94.45],
[0.87, 87.59],
[1.23, 91.77],
[1.55, 99.42],
[1.40, 93.65],
[1.19, 93.54],
[1.15, 92.52],
[0.98, 90.56],
[1.01, 89.54],
[1.11, 89.85],
[1.20, 90.39],
[1.26, 93.25],
[1.32, 93.41],
[1.43, 94.98],
[0.95, 87.33],
]
mut data := ml.data_from_raw_xy(xy)!
mut reg := ml.new_multi_lin_reg(mut data, 3, 'linear regression')

reg.train()

reg.plot()!
}
46 changes: 42 additions & 4 deletions la/matrix.v
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,9 @@ pub fn (o &Matrix[T]) get_col(j int) []T {
// extract_cols returns columns from j=start to j=endp1-1
// start -- first column
// endp1 -- "end-plus-one", the number of the last requested column + 1
pub fn (o &Matrix[T]) extract_cols(start int, endp1 int) &Matrix[T] {
pub fn (o &Matrix[T]) extract_cols(start int, endp1 int) !&Matrix[T] {
if endp1 <= start {
errors.vsl_panic("endp1 'end-plus-one' must be greater than start. start=${start}, endp1=${endp1} invalid",
return errors.error("endp1 'end-plus-one' must be greater than start. start=${start}, endp1=${endp1} invalid",
.efailed)
}
ncol := endp1 - start
Expand All @@ -235,9 +235,9 @@ pub fn (o &Matrix[T]) extract_cols(start int, endp1 int) &Matrix[T] {
// extract_rows returns rows from i=start to i=endp1-1
// start -- first column
// endp1 -- "end-plus-one", the number of the last requested column + 1
pub fn (o &Matrix[T]) extract_rows(start int, endp1 int) &Matrix[T] {
pub fn (o &Matrix[T]) extract_rows(start int, endp1 int) !&Matrix[T] {
if endp1 <= start {
errors.vsl_panic("endp1 'end-plus-one' must be greater than start. start=${start}, endp1=${endp1} invalid",
return errors.error("endp1 'end-plus-one' must be greater than start. start=${start}, endp1=${endp1} invalid",
.efailed)
}
nrow := endp1 - start
Expand All @@ -260,6 +260,44 @@ pub fn (mut o Matrix[T]) set_col(j int, value T) {
}
}

// split_by_col splits this matrix into two matrices at column j
// j -- column index
pub fn (o &Matrix[T]) split_by_col(j int) !(&Matrix[T], &Matrix[T]) {
if j < 0 || j >= o.n {
return errors.error('j=${j} must be in range [0, ${o.n})', .efailed)
}
mut left := new_matrix[T](o.m, j)
mut right := new_matrix[T](o.m, o.n - j)
for i in 0 .. o.m {
for k := 0; k < j; k++ {
left.set(i, k, o.get(i, k))
}
for k := j; k < o.n; k++ {
right.set(i, k - j, o.get(i, k))
}
}
return left, right
}

// split_by_row splits this matrix into two matrices at row i
// i -- row index
pub fn (o &Matrix[T]) split_by_row(i int) !(&Matrix[T], &Matrix[T]) {
if i < 0 || i >= o.m {
return errors.error('i=${i} must be in range [0, ${o.m})', .efailed)
}
mut top := new_matrix[T](i, o.n)
mut bottom := new_matrix[T](o.m - i, o.n)
for j in 0 .. o.n {
for k := 0; k < i; k++ {
top.set(k, j, o.get(k, j))
}
for k := i; k < o.m; k++ {
bottom.set(k - i, j, o.get(k, j))
}
}
return top, bottom
}

// norm_frob returns the Frobenius norm of this matrix
// nrm := ‖a‖_F = sqrt(Σ_i Σ_j a[ij]⋅a[ij]) = ‖a‖_2
pub fn (o &Matrix[T]) norm_frob() T {
Expand Down
2 changes: 1 addition & 1 deletion la/matrix_test.v
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ fn test_extract_cols_and_set_col() {
[31.0, 32.0, 33.0],
])
println(mat.data)
mat_e := mat.extract_cols(1, 3)
mat_e := mat.extract_cols(1, 3)!
assert mat_e.m == 3
assert mat_e.n == 2
assert mat_e.data == [12.0, 13, 22, 23, 32, 33]
Expand Down
33 changes: 32 additions & 1 deletion ml/data.v
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module ml

import math
import vsl.util
import vsl.la
import vsl.errors
Expand Down Expand Up @@ -38,7 +39,11 @@ pub mut:
// Output:
// new object
pub fn new_data[T](nb_samples int, nb_features int, use_y bool, allocate bool) !&Data[T] {
x := if allocate { la.new_matrix[T](nb_samples, nb_features) } else { la.new_matrix[T](0, 0) }
x := if allocate {
la.new_matrix[T](nb_samples, nb_features)
} else {
&la.Matrix[T](unsafe { nil })
}
mut y := []T{}
if allocate && use_y {
y = []T{len: nb_samples}
Expand Down Expand Up @@ -186,3 +191,29 @@ pub fn (mut o Data[T]) add_observer(obs util.Observer) {
pub fn (mut o Data[T]) notify_update() {
o.observable.notify_update()
}

// split returns a new object with data split into two parts
// Input:
// ratio -- ratio of samples to be put in the first part
// Output:
// new object
pub fn (o &Data[T]) split(ratio f64) !(&Data[T], &Data[T]) {
if ratio <= 0.0 || ratio >= 1.0 {
return errors.error('ratio must be between 0 and 1', .efailed)
}
nb_features := o.nb_features
nb_samples := o.nb_samples

nb_samples1 := int(math.floor((ratio * nb_samples)))
nb_samples2 := nb_samples - nb_samples1

m1, m2 := o.x.split_by_row(nb_samples1)!

mut o1 := new_data[T](nb_samples1, nb_features, false, false)!
mut o2 := new_data[T](nb_samples2, nb_features, false, false)!

o1.set(m1, o.y[..nb_samples1])!
o2.set(m2, o.y[nb_samples1..])!

return o1, o2
}
16 changes: 16 additions & 0 deletions ml/data_test.v
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,19 @@ fn test_data_01() {
assert data_backup.nb_features == 3
assert data_backup.nb_samples == 5
}

fn test_split() {
data := data_from_raw_xy([
[-1.0, 0, -3, 0],
[-2.0, 3, 3, 1],
[3.0, 1, 4, 1],
[-4.0, 5, 0, 0],
[1.0, -8, 5, 1],
[-1.0, 0, -3, 1],
])!
data1, data2 := data.split(0.2)!
assert data1.nb_samples == 1
assert data2.nb_samples == 5
assert data1.nb_features == 3
assert data2.nb_features == 3
}

0 comments on commit 78f987b

Please sign in to comment.