From 8e0c6b19c5328f953919c33a4330e57cd4e1452e Mon Sep 17 00:00:00 2001 From: Pasha Stetsenko Date: Mon, 20 Nov 2023 16:49:30 -0800 Subject: [PATCH] Fix crash in dt_linearmodel --- src/core/models/dt_linearmodel.cc | 32 +++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/src/core/models/dt_linearmodel.cc b/src/core/models/dt_linearmodel.cc index 8671171332..5c3a08f27e 100644 --- a/src/core/models/dt_linearmodel.cc +++ b/src/core/models/dt_linearmodel.cc @@ -164,15 +164,25 @@ LinearModelFitOutput LinearModel::fit_impl() { [&]() { // Each thread gets a private storage for observations and feature importances. tptr x = tptr(new T[nfeatures_]); - dtptr dt_model; + std::vector> data_container; + std::vector betas; + size_t ncols = dt_model_->ncols(); + size_t nrows = dt_model_->nrows(); + data_container.resize(ncols); + betas.resize(ncols); + for (size_t i = 0; i < ncols; i++) { + data_container[i].resize(nrows); + betas[i] = data_container[i].data(); + } for (size_t iter = 0; iter < niterations; ++iter) { // Each thread gets its own copy of the model - std::vector betas; - { - PythonLock pylock; - dt_model = dtptr(new DataTable(*dt_model_)); - betas = get_model_data(dt_model); + for (size_t i = 0; i < ncols; i++) { + const auto* data = static_cast( + dt_model_->get_column(i).get_data_readonly()); + for (size_t j = 0; j < nrows; j++) { + betas[i][j] = data[j]; + } } size_t iteration_start = iter * iteration_nrows; @@ -233,8 +243,8 @@ LinearModelFitOutput LinearModel::fit_impl() { { std::lock_guard lock(m); auto nth = static_cast(dt::num_threads_in_team()); - for (size_t i = 0; i < dt_model->ncols(); ++i) { - for (size_t j = 0; j < dt_model->nrows(); ++j) { + for (size_t i = 0; i < ncols; ++i) { + for (size_t j = 0; j < nrows; ++j) { betas_[i][j] += betas[i][j] / nth; } } @@ -300,12 +310,6 @@ LinearModelFitOutput LinearModel::fit_impl() { } // End validation } // End iteration - - { - PythonLock pylock; - dt_model = nullptr; - } - } ); job.done();