diff --git a/examples/ml_kmeans_plot/main.v b/examples/ml_kmeans_plot/main.v index ce3cc0f73..638937147 100644 --- a/examples/ml_kmeans_plot/main.v +++ b/examples/ml_kmeans_plot/main.v @@ -30,5 +30,5 @@ fn main() { model.train(epochs: 6) // Plot the results using the new plot method - model.plot()! + model.get_plotter().show()! } diff --git a/examples/ml_knn_plot/main.v b/examples/ml_knn_plot/main.v index edb9bd670..8cbd96720 100644 --- a/examples/ml_knn_plot/main.v +++ b/examples/ml_knn_plot/main.v @@ -42,5 +42,5 @@ fn main() { println('Prediction: ${prediction}') // Plot the KNN model - knn.plot()! + knn.get_plotter().show()! } diff --git a/examples/ml_linreg_plot/main.v b/examples/ml_linreg_plot/main.v index b6847d9c9..cdb386e9b 100644 --- a/examples/ml_linreg_plot/main.v +++ b/examples/ml_linreg_plot/main.v @@ -30,5 +30,5 @@ fn main() { reg.train() - reg.plot()! + reg.get_plotter().show()! } diff --git a/ml/kmeans.v b/ml/kmeans.v index cba8d2c2e..53f78447a 100644 --- a/ml/kmeans.v +++ b/ml/kmeans.v @@ -149,8 +149,8 @@ pub fn (o &Kmeans) str() string { return res.join('\n') } -// plot method for visualizing the clustering -pub fn (o &Kmeans) plot() ! { +// get_plotter returns a plot.Plot struct for plotting +pub fn (o &Kmeans) get_plotter() &plot.Plot { mut plt := plot.Plot.new() plt.layout( title: 'K-means Clustering' @@ -194,5 +194,5 @@ pub fn (o &Kmeans) plot() ! { } ) - plt.show()! + return plt } diff --git a/ml/knn.v b/ml/knn.v index 3ab026899..f00d4efe6 100644 --- a/ml/knn.v +++ b/ml/knn.v @@ -204,8 +204,9 @@ pub fn (o &KNN) str() string { return res.join('\n') } -// plot method for visualizing the KNN model -pub fn (o &KNN) plot() ! { +// get_plotter returns a plot.Plot struct with the data needed to plot +// the KNN model. +pub fn (o &KNN) get_plotter() &plot.Plot { mut plt := plot.Plot.new() plt.layout( title: 'K-Nearest Neighbors' @@ -237,5 +238,5 @@ pub fn (o &KNN) plot() ! { ) } - plt.show()! + return plt } diff --git a/ml/linreg.v b/ml/linreg.v index cace7c402..9e742cd45 100644 --- a/ml/linreg.v +++ b/ml/linreg.v @@ -150,8 +150,8 @@ pub fn (o &LinReg) str() string { return res.join('\n') } -// plot plots the data and the linear regression model -pub fn (o &LinReg) plot() ! { +// get_plotter returns a plot.Plot struct for plotting the data and the linear regression model +pub fn (o &LinReg) get_plotter() &plot.Plot { // Get the minimum and maximum values of the features min_x := o.stat.min_x[0] max_x := o.stat.max_x[0] @@ -167,28 +167,18 @@ pub fn (o &LinReg) plot() ! { plt.layout( title: 'Linear Regression Example' ) - plt.scatter( name: 'dataset' x: o.data.x.get_col(0) y: o.data.y mode: 'markers' - colorscale: 'smoker' - marker: plot.Marker{ - size: []f64{len: o.data.y.len, init: 10.0} - } ) - plt.scatter( - name: 'linear regression' + name: 'prediction' x: x_values y: y_values mode: 'lines' - colorscale: 'smoker' - line: plot.Line{ - color: 'red' - } ) - plt.show()! + return plt } diff --git a/plot/plot.v b/plot/plot.v index de630c3e9..6f263b636 100644 --- a/plot/plot.v +++ b/plot/plot.v @@ -10,8 +10,8 @@ pub mut: layout Layout } -pub fn Plot.new() Plot { - return Plot{} +pub fn Plot.new() &Plot { + return &Plot{} } // add_trace adds a trace to the plot