From 3fe21b915e2c3ffa9c9a3b141e17157393e1deef Mon Sep 17 00:00:00 2001 From: Markus Date: Sun, 15 Dec 2024 15:42:52 +0100 Subject: [PATCH 01/30] Parameters is a new struct to hold and add parameters --- Cargo.toml | 2 - benches/compare.rs | 79 +++------- examples/bimodal_ke/config.toml | 4 - examples/bimodal_ke/main.rs | 12 +- examples/toml.rs | 11 -- src/algorithms/npag.rs | 13 +- src/algorithms/npod.rs | 15 +- src/algorithms/postprob.rs | 2 +- src/algorithms/routines/initialization/mod.rs | 6 +- src/algorithms/routines/output.rs | 4 +- src/algorithms/routines/settings.rs | 147 ++++++------------ 11 files changed, 101 insertions(+), 194 deletions(-) delete mode 100644 examples/toml.rs diff --git a/Cargo.toml b/Cargo.toml index 121cb5de3..4be1b0439 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,8 +38,6 @@ faer = "0.19.3" faer-ext = { version = "0.2.0", features = ["nalgebra", "ndarray"] } # pharmsol = "0.7.1" pharmsol = {path="../pharmsol"} -# pharmsol = { git = "https://github.com/LAPKB/pharmsol.git", branch = "dev"} -toml = "0.8.14" #REMOVE rand = "0.8.5" anyhow = "1.0.86" nalgebra = "0.33.0" diff --git a/benches/compare.rs b/benches/compare.rs index 767296798..a3b993e9f 100644 --- a/benches/compare.rs +++ b/benches/compare.rs @@ -2,7 +2,6 @@ use pmcore::prelude::*; use diol::prelude::*; use settings::{Log, *}; -use toml::Table; fn main() -> std::io::Result<()> { let mut bench = Bench::new(BenchConfig::from_args()?); @@ -150,44 +149,23 @@ fn tel_settings() -> Settings { }, convergence: Default::default(), advanced: Default::default(), - random: Random { - parameters: Table::from( - [ - ( - "Ka".to_string(), - toml::Value::Array(vec![toml::Value::Float(0.1), toml::Value::Float(0.9)]), - ), - ( - "Ke".to_string(), - toml::Value::Array(vec![ - toml::Value::Float(0.001), - toml::Value::Float(0.1), - ]), - ), - ( - "Tlag1".to_string(), - toml::Value::Array(vec![toml::Value::Float(0.0), toml::Value::Float(4.0)]), - ), - ( - "V".to_string(), - toml::Value::Array(vec![ - toml::Value::Float(30.0), - toml::Value::Float(120.0), - ]), - ), - ] - .iter() - .cloned() - .collect(), - ), - }, - fixed: None, - constant: None, error: Error { value: 5.0, class: "proportional".to_string(), poly: (0.02, 0.05, -2e-04, 0.0), }, + parameters: { + Parameters::new() + .add("Ka".to_string(), 0.1, 0.3, false) + .unwrap() + .add("Ke".to_string(), 0.001, 0.1, false) + .unwrap() + .add("Tlag1".to_string(), 0.0, 4.00, false) + .unwrap() + .add("V".to_string(), 30.0, 120.0, false) + .unwrap() + .to_owned() + }, }; settings.validate().unwrap(); settings @@ -220,36 +198,19 @@ fn bke_settings() -> Settings { }, convergence: Convergence::default(), advanced: Advanced::default(), - random: Random { - parameters: Table::from( - [ - ( - "Ke".to_string(), - toml::Value::Array(vec![ - toml::Value::Float(0.001), - toml::Value::Float(3.0), - ]), - ), - ( - "V".to_string(), - toml::Value::Array(vec![ - toml::Value::Float(25.0), - toml::Value::Float(250.0), - ]), - ), - ] - .iter() - .cloned() - .collect(), - ), - }, - fixed: None, - constant: None, error: Error { value: 0.0, class: "additive".to_string(), poly: (0.0, 0.05, 0.0, 0.0), }, + parameters: { + Parameters::new() + .add("Ke".to_string(), 0.001, 0.1, false) + .unwrap() + .add("V".to_string(), 25.0, 250.0, false) + .unwrap() + .to_owned() + }, }; settings.validate().unwrap(); settings diff --git a/examples/bimodal_ke/config.toml b/examples/bimodal_ke/config.toml index 132512664..faffb49b2 100644 --- a/examples/bimodal_ke/config.toml +++ b/examples/bimodal_ke/config.toml @@ -3,10 +3,6 @@ cycles = 1024 algorithm = "NPAG" cache = true -[random] -ke = [0.001, 3.0] -v = [25.0, 250.0] - [error] value = 0.0 class = "additive" diff --git a/examples/bimodal_ke/main.rs b/examples/bimodal_ke/main.rs index 42c1b186e..45d52e27d 100644 --- a/examples/bimodal_ke/main.rs +++ b/examples/bimodal_ke/main.rs @@ -1,6 +1,8 @@ +use anyhow::Result; use logger::setup_log; use pmcore::prelude::*; -fn main() { +use settings::Parameters; +fn main() -> Result<()> { let eq = equation::ODE::new( |x, p, _t, dx, rateiv, _cov| { // fetch_cov!(cov, t, wt); @@ -39,7 +41,12 @@ fn main() { // (1, 1), // ); - let settings = settings::read("examples/bimodal_ke/config.toml").unwrap(); + let mut settings = settings::read("examples/bimodal_ke/config.toml").unwrap(); + let parameters = Parameters::new() + .add("ke", 0.001, 3.0, false)? + .add("v", 25.0, 250.0, false)? + .to_owned(); + settings.parameters = parameters; setup_log(&settings).unwrap(); let data = data::read_pmetrics("examples/bimodal_ke/bimodal_ke.csv").unwrap(); let mut algorithm = dispatch_algorithm(settings, eq, data).unwrap(); @@ -50,4 +57,5 @@ fn main() { result.write_outputs().unwrap(); // println!("{:?}", result); // let _result = fit(eq, data, settings); + Ok(()) } diff --git a/examples/toml.rs b/examples/toml.rs deleted file mode 100644 index f51b0347e..000000000 --- a/examples/toml.rs +++ /dev/null @@ -1,11 +0,0 @@ -use pmcore::prelude::*; - -fn main() { - let path = "examples/bimodal_ke/config.toml".to_string(); - for i in 0..10 { - let s = settings::read(path.clone()).unwrap(); - let keys: Vec<&String> = s.random.parameters.keys().collect(); - // let values = s.random.parameters.values(); - println!("{}: {:?}", i, keys); - } -} diff --git a/src/algorithms/npag.rs b/src/algorithms/npag.rs index d4c47c7f6..4e83c88c9 100644 --- a/src/algorithms/npag.rs +++ b/src/algorithms/npag.rs @@ -28,7 +28,6 @@ const THETA_D: f64 = 1e-4; #[derive(Debug)] pub struct NPAG { equation: E, - ranges: Vec<(f64, f64)>, psi: Array2, theta: Array2, lambda: Array1, @@ -53,7 +52,6 @@ impl Algorithm for NPAG { fn new(settings: Settings, equation: E, data: Data) -> Result, anyhow::Error> { Ok(Box::new(Self { equation, - ranges: settings.random.ranges(), psi: Array2::default((0, 0)), theta: Array2::zeros((0, 0)), lambda: Array1::default(0), @@ -175,7 +173,7 @@ impl Algorithm for NPAG { } fn evaluation(&mut self) -> Result<()> { - let theta = Theta::new(self.theta.clone(), self.settings.random.names()); + let theta = Theta::new(self.theta.clone(), self.settings.parameters.names()); self.psi = psi( &self.equation, @@ -262,7 +260,7 @@ impl Algorithm for NPAG { let gamma_up = self.gamma * (1.0 + self.gamma_delta); let gamma_down = self.gamma / (1.0 + self.gamma_delta); - let theta = Theta::new(self.theta.clone(), self.settings.random.names()); + let theta = Theta::new(self.theta.clone(), self.settings.parameters.names()); let psi_up = psi( &self.equation, @@ -338,7 +336,12 @@ impl Algorithm for NPAG { } fn expansion(&mut self) -> Result<()> { - adaptative_grid(&mut self.theta, self.eps, &self.ranges, THETA_D); + adaptative_grid( + &mut self.theta, + self.eps, + &self.settings.parameters.ranges(), + THETA_D, + ); Ok(()) } } diff --git a/src/algorithms/npod.rs b/src/algorithms/npod.rs index 097c3e501..b53671670 100644 --- a/src/algorithms/npod.rs +++ b/src/algorithms/npod.rs @@ -28,7 +28,6 @@ const THETA_D: f64 = 1e-4; pub struct NPOD { equation: E, - ranges: Vec<(f64, f64)>, psi: Array2, theta: Array2, lambda: Array1, @@ -50,7 +49,6 @@ impl Algorithm for NPOD { fn new(settings: Settings, equation: E, data: Data) -> Result, anyhow::Error> { Ok(Box::new(Self { equation, - ranges: settings.random.ranges(), psi: Array2::default((0, 0)), theta: Array2::zeros((0, 0)), lambda: Array1::default(0), @@ -159,7 +157,7 @@ impl Algorithm for NPOD { } fn evaluation(&mut self) -> Result<()> { - let theta = Theta::new(self.theta.clone(), self.settings.random.names()); + let theta = Theta::new(self.theta.clone(), self.settings.parameters.names()); self.psi = psi( &self.equation, @@ -245,7 +243,7 @@ impl Algorithm for NPOD { // TODO: Move this to e.g. /evaluation/error.rs let gamma_up = self.gamma * (1.0 + self.gamma_delta); let gamma_down = self.gamma / (1.0 + self.gamma_delta); - let theta = Theta::new(self.theta.clone(), self.settings.random.names()); + let theta = Theta::new(self.theta.clone(), self.settings.parameters.names()); let psi_up = psi( &self.equation, @@ -336,7 +334,7 @@ impl Algorithm for NPOD { &self.data, &sigma, &pyl, - self.settings.random.names(), + self.settings.parameters.names(), ); let candidate_point = optimizer.optimize_point(spp.to_owned()).unwrap(); *spp = candidate_point; @@ -347,7 +345,12 @@ impl Algorithm for NPOD { // re-define a new optimization }); for cp in candididate_points { - prune(&mut self.theta, cp, &self.ranges, THETA_D); + prune( + &mut self.theta, + cp, + &self.settings.parameters.ranges(), + THETA_D, + ); } Ok(()) } diff --git a/src/algorithms/postprob.rs b/src/algorithms/postprob.rs index dcde79431..f7373deff 100644 --- a/src/algorithms/postprob.rs +++ b/src/algorithms/postprob.rs @@ -111,7 +111,7 @@ impl Algorithm for POSTPROB { } fn evaluation(&mut self) -> Result<()> { - let theta = Theta::new(self.theta.clone(), self.settings.random.names()); + let theta = Theta::new(self.theta.clone(), self.settings.parameters.names()); self.psi = psi( &self.equation, &self.data, diff --git a/src/algorithms/routines/initialization/mod.rs b/src/algorithms/routines/initialization/mod.rs index 67defd897..20a4ff63d 100644 --- a/src/algorithms/routines/initialization/mod.rs +++ b/src/algorithms/routines/initialization/mod.rs @@ -14,14 +14,14 @@ pub mod sobol; /// This function generates the grid of support points according to the sampler specified in the [Settings] pub fn sample_space(settings: &Settings, data: &Data, eqn: &impl Equation) -> Result> { // Get the ranges of the random parameters - let ranges = settings.random.ranges(); - let parameters = settings.random.names(); + let ranges = settings.parameters.ranges(); + let parameters = settings.parameters.names(); // If a prior file is provided, read it and return if settings.prior.file.is_some() { let prior = parse_prior( settings.prior.file.as_ref().unwrap(), - &settings.random.names(), + &settings.parameters.names(), )?; return Ok(prior); } diff --git a/src/algorithms/routines/output.rs b/src/algorithms/routines/output.rs index c5bcbdc79..83c250ac7 100644 --- a/src/algorithms/routines/output.rs +++ b/src/algorithms/routines/output.rs @@ -43,7 +43,7 @@ impl NPResult { ) -> Self { // TODO: Add support for fixed and constant parameters - let par_names = settings.random.names(); + let par_names = settings.parameters.names(); Self { equation, @@ -622,7 +622,7 @@ impl CycleLog { writer.write_field("gamlam")?; writer.write_field("nspp")?; - let parameter_names = settings.random.names(); + let parameter_names = settings.parameters.names(); for param_name in ¶meter_names { writer.write_field(format!("{}.mean", param_name))?; writer.write_field(format!("{}.median", param_name))?; diff --git a/src/algorithms/routines/settings.rs b/src/algorithms/routines/settings.rs index 9699045c1..b3c268451 100644 --- a/src/algorithms/routines/settings.rs +++ b/src/algorithms/routines/settings.rs @@ -7,8 +7,6 @@ use pharmsol::prelude::data::ErrorType; use serde::Deserialize; use serde_derive::Serialize; use serde_json; -use std::collections::HashMap; -use toml::Table; /// Contains all settings for PMcore #[derive(Debug, Deserialize, Clone, Serialize)] @@ -16,12 +14,8 @@ use toml::Table; pub struct Settings { /// General configuration settings pub config: Config, - /// Random parameters to be estimated - pub random: Random, - /// Parameters which are estimated, but fixed for the population - pub fixed: Option, - /// Parameters which are held constant - pub constant: Option, + /// Parameters to be estimated + pub parameters: Parameters, /// Defines the error model and polynomial to be used pub error: Error, /// Configuration for predictions @@ -44,9 +38,7 @@ impl Default for Settings { fn default() -> Self { Settings { config: Config::default(), - random: Random::default(), - fixed: None, - constant: None, + parameters: Parameters::new(), error: Error::default(), predictions: Predictions::default(), log: Log::default(), @@ -61,7 +53,6 @@ impl Default for Settings { impl Settings { /// Validate the settings pub fn validate(&self) -> Result<()> { - self.random.validate()?; self.error.validate()?; self.predictions.validate()?; Ok(()) @@ -100,106 +91,64 @@ impl Default for Config { } } -/// Random parameters to be estimated -/// -/// This struct contains the random parameters to be estimated. The parameters are specified as a hashmap, where the key is the name of the parameter, and the value is a tuple containing the upper and lower bounds of the parameter. -/// -/// # Example -/// -/// ```toml -/// [random] -/// alpha = [0.0, 1.0] -/// beta = [0.0, 1.0] -/// ``` -#[derive(Debug, Deserialize, Clone, Serialize)] -#[serde(default)] -pub struct Random { - #[serde(flatten)] - pub parameters: Table, +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Parameter { + name: String, + lower: f64, + upper: f64, + fixed: bool, } -impl Default for Random { - fn default() -> Self { - Random { - parameters: Table::new(), +impl Parameter { + pub fn new(name: impl Into, lower: f64, upper: f64, fixed: bool) -> Result { + if lower >= upper { + bail!(format!( + "In key '{}', lower bound ({}) is not less than upper bound ({})", + name.into(), + lower, + upper + )); } - } -} - -impl Random { - /// Get the upper and lower bounds of a random parameter from its key - pub fn get(&self, key: &str) -> Option<(f64, f64)> { - self.parameters - .get(key) - .and_then(|v| v.as_array()) - .map(|v| { - let lower = v[0].as_float().unwrap(); - let upper = v[1].as_float().unwrap(); - (lower, upper) - }) - } - - /// Returns a vector of the names of the random parameters - pub fn names(&self) -> Vec { - self.parameters.keys().cloned().collect() - } - - /// Returns a vector of the upper and lower bounds of the random parameters - pub fn ranges(&self) -> Vec<(f64, f64)> { - self.parameters - .values() - .map(|v| { - let lower = v.as_array().unwrap()[0].as_float().unwrap(); - let upper = v.as_array().unwrap()[1].as_float().unwrap(); - (lower, upper) - }) - .collect() - } - /// Validate the boundaries of the random parameters - pub fn validate(&self) -> Result<()> { - for (key, range) in &self.parameters { - let range = range.as_array().unwrap(); - let lower = range[0].as_float().unwrap(); - let upper = range[1].as_float().unwrap(); - if lower >= upper { - bail!(format!( - "In key '{}', lower bound ({}) is not less than upper bound ({})", - key, lower, upper - )); - } - } - Ok(()) + Ok(Self { + name: name.into(), + lower, + upper, + fixed, + }) } } -/// Parameters which are estimated, but fixed for the population -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct Fixed { - #[serde(flatten)] - pub parameters: HashMap, +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Parameters { + parameters: Vec, } -impl Default for Fixed { - fn default() -> Self { - Fixed { - parameters: HashMap::new(), +impl Parameters { + pub fn new() -> Self { + Parameters { + parameters: Vec::new(), } } -} -/// Parameters which are held constant -#[derive(Debug, Deserialize, Clone, Serialize)] -pub struct Constant { - #[serde(flatten)] - pub parameters: HashMap, -} + pub fn add( + &mut self, + name: impl Into, + lower: f64, + upper: f64, + fixed: bool, + ) -> Result<&mut Self> { + let parameter = Parameter::new(name, lower, upper, fixed)?; + self.parameters.push(parameter); + Ok(self) + } -impl Default for Constant { - fn default() -> Self { - Constant { - parameters: HashMap::new(), - } + pub fn names(&self) -> Vec { + self.parameters.iter().map(|p| p.name.clone()).collect() + } + + pub fn ranges(&self) -> Vec<(f64, f64)> { + self.parameters.iter().map(|p| (p.lower, p.upper)).collect() } } From fb8b28c433f6ce83587162666f99d52c11c5f1db Mon Sep 17 00:00:00 2001 From: Markus Date: Sun, 15 Dec 2024 15:47:45 +0100 Subject: [PATCH 02/30] Documentation --- src/algorithms/routines/settings.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/algorithms/routines/settings.rs b/src/algorithms/routines/settings.rs index b3c268451..c6d96dd02 100644 --- a/src/algorithms/routines/settings.rs +++ b/src/algorithms/routines/settings.rs @@ -91,6 +91,10 @@ impl Default for Config { } } +/// Defines a parameter to be estimated +/// +/// In non-parametric algorithms, parameters must be bounded. The lower and upper bounds are defined by the `lower` and `upper` fields, respectively. +/// Fixed parameters are unknown, but common among all subjects. #[derive(Debug, Clone, Deserialize, Serialize)] pub struct Parameter { name: String, @@ -100,6 +104,7 @@ pub struct Parameter { } impl Parameter { + /// Create a new parameter pub fn new(name: impl Into, lower: f64, upper: f64, fixed: bool) -> Result { if lower >= upper { bail!(format!( @@ -119,18 +124,21 @@ impl Parameter { } } -#[derive(Debug, Clone, Deserialize, Serialize)] +/// This structure contains information on all [Parameter]s to be estimated +#[derive(Debug, Clone, Deserialize, Serialize, Default)] pub struct Parameters { parameters: Vec, } impl Parameters { + /// Create a new set of parameters pub fn new() -> Self { Parameters { parameters: Vec::new(), } } + /// Add a parameter to the set pub fn add( &mut self, name: impl Into, @@ -143,10 +151,12 @@ impl Parameters { Ok(self) } + /// Get the names of the parameters pub fn names(&self) -> Vec { self.parameters.iter().map(|p| p.name.clone()).collect() } + /// Get the ranges of the parameters pub fn ranges(&self) -> Vec<(f64, f64)> { self.parameters.iter().map(|p| (p.lower, p.upper)).collect() } From 7166398c01370f3d23a1cfc308027fe68ef2d6d8 Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Mon, 16 Dec 2024 14:22:42 +0100 Subject: [PATCH 03/30] Improved builder pattern --- examples/bimodal_ke/main.rs | 3 +-- src/algorithms/routines/settings.rs | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/bimodal_ke/main.rs b/examples/bimodal_ke/main.rs index 45d52e27d..cc9b862ff 100644 --- a/examples/bimodal_ke/main.rs +++ b/examples/bimodal_ke/main.rs @@ -44,8 +44,7 @@ fn main() -> Result<()> { let mut settings = settings::read("examples/bimodal_ke/config.toml").unwrap(); let parameters = Parameters::new() .add("ke", 0.001, 3.0, false)? - .add("v", 25.0, 250.0, false)? - .to_owned(); + .add("v", 25.0, 250.0, false)?; settings.parameters = parameters; setup_log(&settings).unwrap(); let data = data::read_pmetrics("examples/bimodal_ke/bimodal_ke.csv").unwrap(); diff --git a/src/algorithms/routines/settings.rs b/src/algorithms/routines/settings.rs index c6d96dd02..a02228b68 100644 --- a/src/algorithms/routines/settings.rs +++ b/src/algorithms/routines/settings.rs @@ -140,12 +140,12 @@ impl Parameters { /// Add a parameter to the set pub fn add( - &mut self, + mut self, name: impl Into, lower: f64, upper: f64, fixed: bool, - ) -> Result<&mut Self> { + ) -> Result { let parameter = Parameter::new(name, lower, upper, fixed)?; self.parameters.push(parameter); Ok(self) From 0cc0dae3d5f51847206b77148679430754ce04dd Mon Sep 17 00:00:00 2001 From: Markus Date: Sat, 21 Dec 2024 15:03:29 +0100 Subject: [PATCH 04/30] Update BKE --- examples/bimodal_ke/main.rs | 18 +++++++++--------- src/algorithms/routines/settings.rs | 10 +++++++--- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/examples/bimodal_ke/main.rs b/examples/bimodal_ke/main.rs index cc9b862ff..bd5d1d437 100644 --- a/examples/bimodal_ke/main.rs +++ b/examples/bimodal_ke/main.rs @@ -1,7 +1,7 @@ use anyhow::Result; use logger::setup_log; use pmcore::prelude::*; -use settings::Parameters; +use settings::{Parameters, Settings}; fn main() -> Result<()> { let eq = equation::ODE::new( |x, p, _t, dx, rateiv, _cov| { @@ -41,20 +41,20 @@ fn main() -> Result<()> { // (1, 1), // ); - let mut settings = settings::read("examples/bimodal_ke/config.toml").unwrap(); - let parameters = Parameters::new() + let mut settings = Settings::new(); + settings.parameters = Parameters::new() .add("ke", 0.001, 3.0, false)? .add("v", 25.0, 250.0, false)?; - settings.parameters = parameters; + settings.config.cycles = 1024; + settings.error.poly = (0.0, 0.05, 0.0, 0.0); + settings.output.write = true; + settings.output.path = "examples/bimodal_ke/output".to_string(); + setup_log(&settings).unwrap(); let data = data::read_pmetrics("examples/bimodal_ke/bimodal_ke.csv").unwrap(); let mut algorithm = dispatch_algorithm(settings, eq, data).unwrap(); let result = algorithm.fit().unwrap(); - // algorithm.initialize().unwrap(); - // while !algorithm.next_cycle().unwrap() {} - // let result = algorithm.into_npresult(); result.write_outputs().unwrap(); - // println!("{:?}", result); - // let _result = fit(eq, data, settings); + Ok(()) } diff --git a/src/algorithms/routines/settings.rs b/src/algorithms/routines/settings.rs index a02228b68..b0794e44f 100644 --- a/src/algorithms/routines/settings.rs +++ b/src/algorithms/routines/settings.rs @@ -19,8 +19,6 @@ pub struct Settings { /// Defines the error model and polynomial to be used pub error: Error, /// Configuration for predictions - /// - /// This struct contains the interval at which to generate predictions, and the time after dose to generate predictions to pub predictions: Predictions, /// Configuration for logging pub log: Log, @@ -84,7 +82,7 @@ impl Default for Config { Config { cycles: 100, algorithm: "NPAG".to_string(), - cache: false, + cache: true, include: None, exclude: None, } @@ -151,6 +149,12 @@ impl Parameters { Ok(self) } + // Get a parameter by name + pub fn get(&self, name: impl Into) -> Option<&Parameter> { + let name = name.into(); + self.parameters.iter().find(|p| p.name == name) + } + /// Get the names of the parameters pub fn names(&self) -> Vec { self.parameters.iter().map(|p| p.name.clone()).collect() From d6d326808b5ce4fae9919aa80f9c9d6944db5a68 Mon Sep 17 00:00:00 2001 From: Markus Date: Sun, 22 Dec 2024 13:02:59 +0100 Subject: [PATCH 05/30] Refactoring Work in progress, TBC --- benches/compare.rs | 5 +++-- examples/two_eq_lag/config.toml | 7 ------- examples/two_eq_lag/main.rs | 14 +++++++++++++- src/algorithms/{postprob.rs => map.rs} | 4 ++-- src/algorithms/mod.rs | 23 +++++++++++++++-------- src/algorithms/routines/settings.rs | 6 ++++-- 6 files changed, 37 insertions(+), 22 deletions(-) rename src/algorithms/{postprob.rs => map.rs} (97%) diff --git a/benches/compare.rs b/benches/compare.rs index a3b993e9f..97fde3b15 100644 --- a/benches/compare.rs +++ b/benches/compare.rs @@ -1,3 +1,4 @@ +use algorithms::AlgorithmType; use pmcore::prelude::*; use diol::prelude::*; @@ -127,7 +128,7 @@ fn tel_settings() -> Settings { let settings = Settings { config: Config { cycles: 1000, - algorithm: "NPAG".to_string(), + algorithm: AlgorithmType::NPAG, cache: true, ..Default::default() }, @@ -175,7 +176,7 @@ fn bke_settings() -> Settings { let settings = Settings { config: Config { cycles: 1024, - algorithm: "NPAG".to_string(), + algorithm: AlgorithmType::NPAG, cache: true, include: None, exclude: None, diff --git a/examples/two_eq_lag/config.toml b/examples/two_eq_lag/config.toml index 60814bf1c..288b5ad77 100644 --- a/examples/two_eq_lag/config.toml +++ b/examples/two_eq_lag/config.toml @@ -1,15 +1,8 @@ [config] cycles = 1000 algorithm = "NPAG" - cache = true -[random] -Ka = [0.1, 0.9] -Ke = [0.001, 0.1] -Tlag1 = [0.0, 4.0] -V = [30.0, 120.0] - [error] value = 5 class = "proportional" diff --git a/examples/two_eq_lag/main.rs b/examples/two_eq_lag/main.rs index 6191f5492..0ddf92e1e 100644 --- a/examples/two_eq_lag/main.rs +++ b/examples/two_eq_lag/main.rs @@ -1,12 +1,14 @@ #![allow(dead_code)] #![allow(unused_variables)] #![allow(unused_imports)] +use core::panic; use std::path::Path; use data::read_pmetrics; use logger::setup_log; use ndarray::Array2; use pmcore::prelude::{models::one_compartment_with_absorption, simulator::Equation, *}; +use settings::Parameters; fn main() { let eq = equation::ODE::new( @@ -71,7 +73,17 @@ fn main() { // (2, 1), // ); - let settings = settings::read("examples/two_eq_lag/config.toml").unwrap(); + let mut settings = settings::read("examples/two_eq_lag/config.toml").unwrap(); + settings.parameters = Parameters::new() + .add("ka", 0.1, 0.9, false) + .unwrap() + .add("ke", 0.001, 0.1, false) + .unwrap() + .add("tlag", 0.0, 4.0, false) + .unwrap() + .add("v", 30.0, 120.0, false) + .unwrap(); + setup_log(&settings).unwrap(); let data = data::read_pmetrics("examples/two_eq_lag/two_eq_lag.csv").unwrap(); let mut algorithm = dispatch_algorithm(settings, eq, data).unwrap(); diff --git a/src/algorithms/postprob.rs b/src/algorithms/map.rs similarity index 97% rename from src/algorithms/postprob.rs rename to src/algorithms/map.rs index f7373deff..cecf16d33 100644 --- a/src/algorithms/postprob.rs +++ b/src/algorithms/map.rs @@ -14,7 +14,7 @@ use super::{initialization, output::CycleLog}; /// Posterior probability algorithm /// Reweights the prior probabilities to the observed data and error model -pub struct POSTPROB { +pub struct MAP { equation: E, psi: Array2, theta: Array2, @@ -31,7 +31,7 @@ pub struct POSTPROB { cyclelog: CycleLog, } -impl Algorithm for POSTPROB { +impl Algorithm for MAP { fn new(settings: Settings, equation: E, data: Data) -> Result, anyhow::Error> { Ok(Box::new(Self { equation, diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index 38df58900..0bcb736e2 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -3,22 +3,30 @@ use std::path::Path; use crate::prelude::{self, settings::Settings}; -use anyhow::{bail, Result}; +use anyhow::Result; use anyhow::{Context, Error}; +use map::MAP; use ndarray::Array2; use npag::*; use npod::NPOD; use output::NPResult; use pharmsol::prelude::{data::Data, simulator::Equation}; -use postprob::POSTPROB; use prelude::*; +use serde::{Deserialize, Serialize}; // use self::{data::Subject, simulator::Equation}; +pub mod map; pub mod npag; pub mod npod; -pub mod postprob; pub mod routines; +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum AlgorithmType { + NPAG, + NPOD, + MAP, +} + pub trait Algorithm { fn new(config: Settings, equation: E, data: Data) -> Result, Error> where @@ -102,10 +110,9 @@ pub fn dispatch_algorithm( equation: E, data: Data, ) -> Result>, Error> { - match settings.config.algorithm.as_str() { - "NPAG" => Ok(NPAG::new(settings, equation, data)?), - "NPOD" => Ok(NPOD::new(settings, equation, data)?), - "POSTPROB" => Ok(POSTPROB::new(settings, equation, data)?), - alg => bail!("Algorithm {} not implemented", alg), + match settings.config.algorithm { + AlgorithmType::NPAG => Ok(NPAG::new(settings, equation, data)?), + AlgorithmType::NPOD => Ok(NPOD::new(settings, equation, data)?), + AlgorithmType::MAP => Ok(MAP::new(settings, equation, data)?), } } diff --git a/src/algorithms/routines/settings.rs b/src/algorithms/routines/settings.rs index b0794e44f..5eaad844d 100644 --- a/src/algorithms/routines/settings.rs +++ b/src/algorithms/routines/settings.rs @@ -1,5 +1,7 @@ #![allow(dead_code)] +use crate::algorithms::AlgorithmType; + use super::output::OutputFile; use anyhow::{bail, Result}; use config::Config as eConfig; @@ -68,7 +70,7 @@ pub struct Config { /// Maximum number of cycles to run pub cycles: usize, /// Denotes the algorithm to use - pub algorithm: String, + pub algorithm: AlgorithmType, /// If true (default), cache predicted values pub cache: bool, /// Vector of IDs to include @@ -81,7 +83,7 @@ impl Default for Config { fn default() -> Self { Config { cycles: 100, - algorithm: "NPAG".to_string(), + algorithm: AlgorithmType::NPAG, cache: true, include: None, exclude: None, From bc5d4f2023b4bd2e029cd4368b9be6d9d1f0e18b Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Mon, 23 Dec 2024 08:54:39 +0100 Subject: [PATCH 06/30] Error handling in examples --- examples/bimodal_ke/main.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/bimodal_ke/main.rs b/examples/bimodal_ke/main.rs index bd5d1d437..73d853eb5 100644 --- a/examples/bimodal_ke/main.rs +++ b/examples/bimodal_ke/main.rs @@ -50,11 +50,11 @@ fn main() -> Result<()> { settings.output.write = true; settings.output.path = "examples/bimodal_ke/output".to_string(); - setup_log(&settings).unwrap(); - let data = data::read_pmetrics("examples/bimodal_ke/bimodal_ke.csv").unwrap(); - let mut algorithm = dispatch_algorithm(settings, eq, data).unwrap(); - let result = algorithm.fit().unwrap(); - result.write_outputs().unwrap(); + setup_log(&settings)?; + let data = data::read_pmetrics("examples/bimodal_ke/bimodal_ke.csv")?; + let mut algorithm = dispatch_algorithm(settings, eq, data)?; + let result = algorithm.fit()?; + result.write_outputs()?; Ok(()) } From 382d1783adf1153afb6cce080838de9058d57140 Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Mon, 23 Dec 2024 09:26:12 +0100 Subject: [PATCH 07/30] Setters and getters --- examples/bimodal_ke/main.rs | 22 +++-- examples/two_eq_lag/main.rs | 11 +-- src/algorithms/map.rs | 8 +- src/algorithms/mod.rs | 2 +- src/algorithms/npag.rs | 16 ++-- src/algorithms/npod.rs | 16 ++-- src/algorithms/routines/initialization/mod.rs | 18 ++-- src/algorithms/routines/output.rs | 24 ++--- src/algorithms/routines/settings.rs | 93 +++++++++++++++++-- src/logger.rs | 4 +- 10 files changed, 143 insertions(+), 71 deletions(-) diff --git a/examples/bimodal_ke/main.rs b/examples/bimodal_ke/main.rs index 73d853eb5..7524a458f 100644 --- a/examples/bimodal_ke/main.rs +++ b/examples/bimodal_ke/main.rs @@ -1,7 +1,8 @@ +use algorithms::AlgorithmType; use anyhow::Result; use logger::setup_log; use pmcore::prelude::*; -use settings::{Parameters, Settings}; +use settings::{Config, Parameters, Settings}; fn main() -> Result<()> { let eq = equation::ODE::new( |x, p, _t, dx, rateiv, _cov| { @@ -42,13 +43,18 @@ fn main() -> Result<()> { // ); let mut settings = Settings::new(); - settings.parameters = Parameters::new() - .add("ke", 0.001, 3.0, false)? - .add("v", 25.0, 250.0, false)?; - settings.config.cycles = 1024; - settings.error.poly = (0.0, 0.05, 0.0, 0.0); - settings.output.write = true; - settings.output.path = "examples/bimodal_ke/output".to_string(); + + let params = Parameters::new() + .add("ke", 0.001, 3.0, true)? + .add("v", 25.0, 250.0, true)?; + + settings.set_parameters(params); + settings.set_config(Config { + cycles: 1000, + algorithm: AlgorithmType::NPAG, + cache: true, + ..Default::default() + }); setup_log(&settings)?; let data = data::read_pmetrics("examples/bimodal_ke/bimodal_ke.csv")?; diff --git a/examples/two_eq_lag/main.rs b/examples/two_eq_lag/main.rs index 0ddf92e1e..e5dbc2cfb 100644 --- a/examples/two_eq_lag/main.rs +++ b/examples/two_eq_lag/main.rs @@ -73,16 +73,7 @@ fn main() { // (2, 1), // ); - let mut settings = settings::read("examples/two_eq_lag/config.toml").unwrap(); - settings.parameters = Parameters::new() - .add("ka", 0.1, 0.9, false) - .unwrap() - .add("ke", 0.001, 0.1, false) - .unwrap() - .add("tlag", 0.0, 4.0, false) - .unwrap() - .add("v", 30.0, 120.0, false) - .unwrap(); + let settings = settings::read("examples/two_eq_lag/config.toml").unwrap(); setup_log(&settings).unwrap(); let data = data::read_pmetrics("examples/two_eq_lag/two_eq_lag.csv").unwrap(); diff --git a/src/algorithms/map.rs b/src/algorithms/map.rs index cecf16d33..c68688679 100644 --- a/src/algorithms/map.rs +++ b/src/algorithms/map.rs @@ -41,13 +41,13 @@ impl Algorithm for MAP { objf: f64::INFINITY, cycle: 0, converged: false, - gamma: settings.error.value, - error_type: match settings.error.class.as_str() { + gamma: settings.error().value, + error_type: match settings.error().class.as_str() { "additive" => ErrorType::Add, "proportional" => ErrorType::Prop, _ => panic!("Error type not supported"), }, - c: settings.error.poly, + c: settings.error().poly, settings, data, @@ -111,7 +111,7 @@ impl Algorithm for MAP { } fn evaluation(&mut self) -> Result<()> { - let theta = Theta::new(self.theta.clone(), self.settings.parameters.names()); + let theta = Theta::new(self.theta.clone(), self.settings.parameters().names()); self.psi = psi( &self.equation, &self.data, diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index 0bcb736e2..f982f5370 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -110,7 +110,7 @@ pub fn dispatch_algorithm( equation: E, data: Data, ) -> Result>, Error> { - match settings.config.algorithm { + match settings.config().algorithm { AlgorithmType::NPAG => Ok(NPAG::new(settings, equation, data)?), AlgorithmType::NPOD => Ok(NPOD::new(settings, equation, data)?), AlgorithmType::MAP => Ok(MAP::new(settings, equation, data)?), diff --git a/src/algorithms/npag.rs b/src/algorithms/npag.rs index 4e83c88c9..74bf2d69b 100644 --- a/src/algorithms/npag.rs +++ b/src/algorithms/npag.rs @@ -63,11 +63,11 @@ impl Algorithm for NPAG { f1: f64::default(), cycle: 0, gamma_delta: 0.1, - gamma: settings.error.value, - error_type: settings.error.error_type(), + gamma: settings.error().value, + error_type: settings.error().error_type(), converged: false, cycle_log: CycleLog::new(), - c: settings.error.poly, + c: settings.error().poly, settings, data, })) @@ -141,7 +141,7 @@ impl Algorithm for NPAG { } // Stop if we have reached maximum number of cycles - if self.cycle >= self.settings.config.cycles { + if self.cycle >= self.settings.config().cycles { tracing::warn!("Maximum number of cycles reached"); self.converged = true; } @@ -173,14 +173,14 @@ impl Algorithm for NPAG { } fn evaluation(&mut self) -> Result<()> { - let theta = Theta::new(self.theta.clone(), self.settings.parameters.names()); + let theta = Theta::new(self.theta.clone(), self.settings.parameters().names()); self.psi = psi( &self.equation, &self.data, &theta, &ErrorModel::new(self.c, self.gamma, &self.error_type), - self.cycle == 1 && self.settings.log.write, + self.cycle == 1 && self.settings.log().write, self.cycle != 1, ); @@ -260,7 +260,7 @@ impl Algorithm for NPAG { let gamma_up = self.gamma * (1.0 + self.gamma_delta); let gamma_down = self.gamma / (1.0 + self.gamma_delta); - let theta = Theta::new(self.theta.clone(), self.settings.parameters.names()); + let theta = Theta::new(self.theta.clone(), self.settings.parameters().names()); let psi_up = psi( &self.equation, @@ -339,7 +339,7 @@ impl Algorithm for NPAG { adaptative_grid( &mut self.theta, self.eps, - &self.settings.parameters.ranges(), + &self.settings.parameters().ranges(), THETA_D, ); Ok(()) diff --git a/src/algorithms/npod.rs b/src/algorithms/npod.rs index b53671670..5d031a18d 100644 --- a/src/algorithms/npod.rs +++ b/src/algorithms/npod.rs @@ -57,11 +57,11 @@ impl Algorithm for NPOD { objf: f64::NEG_INFINITY, cycle: 0, gamma_delta: 0.1, - gamma: settings.error.value, - error_type: settings.error.error_type(), + gamma: settings.error().value, + error_type: settings.error().error_type(), converged: false, cycle_log: CycleLog::new(), - c: settings.error.poly, + c: settings.error().poly, settings, data, })) @@ -125,7 +125,7 @@ impl Algorithm for NPOD { } // Stop if we have reached maximum number of cycles - if self.cycle >= self.settings.config.cycles { + if self.cycle >= self.settings.config().cycles { tracing::warn!("Maximum number of cycles reached"); self.converged = true; } @@ -157,7 +157,7 @@ impl Algorithm for NPOD { } fn evaluation(&mut self) -> Result<()> { - let theta = Theta::new(self.theta.clone(), self.settings.parameters.names()); + let theta = Theta::new(self.theta.clone(), self.settings.parameters().names()); self.psi = psi( &self.equation, @@ -243,7 +243,7 @@ impl Algorithm for NPOD { // TODO: Move this to e.g. /evaluation/error.rs let gamma_up = self.gamma * (1.0 + self.gamma_delta); let gamma_down = self.gamma / (1.0 + self.gamma_delta); - let theta = Theta::new(self.theta.clone(), self.settings.parameters.names()); + let theta = Theta::new(self.theta.clone(), self.settings.parameters().names()); let psi_up = psi( &self.equation, @@ -334,7 +334,7 @@ impl Algorithm for NPOD { &self.data, &sigma, &pyl, - self.settings.parameters.names(), + self.settings.parameters().names(), ); let candidate_point = optimizer.optimize_point(spp.to_owned()).unwrap(); *spp = candidate_point; @@ -348,7 +348,7 @@ impl Algorithm for NPOD { prune( &mut self.theta, cp, - &self.settings.parameters.ranges(), + &self.settings.parameters().ranges(), THETA_D, ); } diff --git a/src/algorithms/routines/initialization/mod.rs b/src/algorithms/routines/initialization/mod.rs index 20a4ff63d..e8d959ad2 100644 --- a/src/algorithms/routines/initialization/mod.rs +++ b/src/algorithms/routines/initialization/mod.rs @@ -14,22 +14,22 @@ pub mod sobol; /// This function generates the grid of support points according to the sampler specified in the [Settings] pub fn sample_space(settings: &Settings, data: &Data, eqn: &impl Equation) -> Result> { // Get the ranges of the random parameters - let ranges = settings.parameters.ranges(); - let parameters = settings.parameters.names(); + let ranges = settings.parameters().ranges(); + let parameters = settings.parameters().names(); // If a prior file is provided, read it and return - if settings.prior.file.is_some() { + if settings.prior().file.is_some() { let prior = parse_prior( - settings.prior.file.as_ref().unwrap(), - &settings.parameters.names(), + settings.prior().file.as_ref().unwrap(), + &settings.parameters().names(), )?; return Ok(prior); } // Otherwise, parse the sampler type and generate the grid - let prior = match settings.prior.sampler.as_str() { - "sobol" => sobol::generate(settings.prior.points, &ranges, settings.prior.seed)?, - "latin" => latin::generate(settings.prior.points, &ranges, settings.prior.seed)?, + let prior = match settings.prior().sampler.as_str() { + "sobol" => sobol::generate(settings.prior().points, &ranges, settings.prior().seed)?, + "latin" => latin::generate(settings.prior().points, &ranges, settings.prior().seed)?, "osat" => { let mut point = vec![]; for range in ranges { @@ -41,7 +41,7 @@ pub fn sample_space(settings: &Settings, data: &Data, eqn: &impl Equation) -> Re _ => { bail!( "Unknown sampler specified in settings: {}", - settings.prior.sampler + settings.prior().sampler ); } }; diff --git a/src/algorithms/routines/output.rs b/src/algorithms/routines/output.rs index 83c250ac7..4b228a336 100644 --- a/src/algorithms/routines/output.rs +++ b/src/algorithms/routines/output.rs @@ -43,7 +43,7 @@ impl NPResult { ) -> Self { // TODO: Add support for fixed and constant parameters - let par_names = settings.parameters.names(); + let par_names = settings.parameters().names(); Self { equation, @@ -85,9 +85,9 @@ impl NPResult { } pub fn write_outputs(&self) -> Result<()> { - if self.settings.output.write { - let idelta: f64 = self.settings.predictions.idelta; - let tad = self.settings.predictions.tad; + if self.settings.output().write { + let idelta: f64 = self.settings.predictions().idelta; + let tad = self.settings.predictions().tad; self.cyclelog.write(&self.settings)?; self.write_obs().context("Failed to write observations")?; self.write_theta().context("Failed to write theta")?; @@ -141,7 +141,7 @@ impl NPResult { ); } - let outputfile = OutputFile::new(&self.settings.output.path, "op.csv")?; + let outputfile = OutputFile::new(&self.settings.output().path, "op.csv")?; let mut writer = WriterBuilder::new() .has_headers(true) .from_writer(&outputfile.file); @@ -253,7 +253,7 @@ impl NPResult { self.w.clone() }; - let outputfile = OutputFile::new(&self.settings.output.path, "theta.csv") + let outputfile = OutputFile::new(&self.settings.output().path, "theta.csv") .context("Failed to create output file for theta")?; let mut writer = WriterBuilder::new() @@ -297,7 +297,7 @@ impl NPResult { }; // Create the output folder if it doesn't exist - let outputfile = match OutputFile::new(&self.settings.output.path, "posterior.csv") { + let outputfile = match OutputFile::new(&self.settings.output().path, "posterior.csv") { Ok(of) => of, Err(e) => { tracing::error!("Failed to create output file: {}", e); @@ -343,7 +343,7 @@ impl NPResult { /// Write the observations, which is the reformatted input data pub fn write_obs(&self) -> Result<()> { tracing::debug!("Writing observations..."); - let outputfile = OutputFile::new(&self.settings.output.path, "obs.csv")?; + let outputfile = OutputFile::new(&self.settings.output().path, "obs.csv")?; write_pmetrics_observations(&self.data, &outputfile.file)?; tracing::info!( "Observations written to {:?}", @@ -372,7 +372,7 @@ impl NPResult { bail!("Number of subjects and number of posterior means do not match"); } - let outputfile = OutputFile::new(&self.settings.output.path, "pred.csv")?; + let outputfile = OutputFile::new(&self.settings.output().path, "pred.csv")?; let mut writer = WriterBuilder::new() .has_headers(true) .from_writer(&outputfile.file); @@ -468,7 +468,7 @@ impl NPResult { /// Writes the covariates pub fn write_covs(&self) -> Result<()> { tracing::debug!("Writing covariates..."); - let outputfile = OutputFile::new(&self.settings.output.path, "covs.csv")?; + let outputfile = OutputFile::new(&self.settings.output().path, "covs.csv")?; let mut writer = WriterBuilder::new() .has_headers(true) .from_writer(&outputfile.file); @@ -610,7 +610,7 @@ impl CycleLog { pub fn write(&self, settings: &Settings) -> Result<()> { tracing::debug!("Writing cycles..."); - let outputfile = OutputFile::new(&settings.output.path, "cycles.csv")?; + let outputfile = OutputFile::new(&settings.output().path, "cycles.csv")?; let mut writer = WriterBuilder::new() .has_headers(false) .from_writer(&outputfile.file); @@ -622,7 +622,7 @@ impl CycleLog { writer.write_field("gamlam")?; writer.write_field("nspp")?; - let parameter_names = settings.parameters.names(); + let parameter_names = settings.parameters().names(); for param_name in ¶meter_names { writer.write_field(format!("{}.mean", param_name))?; writer.write_field(format!("{}.median", param_name))?; diff --git a/src/algorithms/routines/settings.rs b/src/algorithms/routines/settings.rs index 5eaad844d..4584aa513 100644 --- a/src/algorithms/routines/settings.rs +++ b/src/algorithms/routines/settings.rs @@ -15,23 +15,23 @@ use serde_json; #[serde(deny_unknown_fields, default)] pub struct Settings { /// General configuration settings - pub config: Config, + config: Config, /// Parameters to be estimated - pub parameters: Parameters, + parameters: Parameters, /// Defines the error model and polynomial to be used - pub error: Error, + error: Error, /// Configuration for predictions - pub predictions: Predictions, + predictions: Predictions, /// Configuration for logging - pub log: Log, + log: Log, /// Configuration for (optional) prior - pub prior: Prior, + prior: Prior, /// Configuration for the output files - pub output: Output, + output: Output, /// Configuration for the convergence criteria - pub convergence: Convergence, + convergence: Convergence, /// Advanced options, mostly hyperparameters, for the algorithm(s) - pub advanced: Advanced, + advanced: Advanced, } impl Default for Settings { @@ -58,9 +58,84 @@ impl Settings { Ok(()) } + /// Create a new settings object with default values pub fn new() -> Self { Settings::default() } + + pub fn set_config(&mut self, config: Config) { + self.config = config; + } + + pub fn config(&self) -> &Config { + &self.config + } + + pub fn set_parameters(&mut self, parameters: Parameters) { + self.parameters = parameters; + } + + pub fn parameters(&self) -> &Parameters { + &self.parameters + } + + pub fn set_error(mut self, error: Error) { + self.error = error; + } + + pub fn error(&self) -> &Error { + &self.error + } + + pub fn set_predictions(mut self, predictions: Predictions) { + self.predictions = predictions; + } + + pub fn predictions(&self) -> &Predictions { + &self.predictions + } + + pub fn set_log(mut self, log: Log) { + self.log = log; + } + + pub fn log(&self) -> &Log { + &self.log + } + + pub fn set_prior(mut self, prior: Prior) { + self.prior = prior; + } + + pub fn prior(&self) -> &Prior { + &self.prior + } + + pub fn set_output(mut self, output: Output) { + self.output = output; + } + + pub fn output(&self) -> &Output { + &self.output + } + + pub fn set_convergence(mut self, convergence: Convergence) -> Self { + self.convergence = convergence; + self + } + + pub fn convergence(&self) -> &Convergence { + &self.convergence + } + + pub fn set_advanced(mut self, advanced: Advanced) -> Self { + self.advanced = advanced; + self + } + + pub fn advanced(&self) -> &Advanced { + &self.advanced + } } /// General configuration settings diff --git a/src/logger.rs b/src/logger.rs index 3c73f6555..735ffe256 100644 --- a/src/logger.rs +++ b/src/logger.rs @@ -23,7 +23,7 @@ use tracing_subscriber::EnvFilter; /// If not, the log messages are written to stdout. pub fn setup_log(settings: &Settings) -> Result<()> { // Use the log level defined in configuration file - let log_level = settings.log.level.as_str(); + let log_level = settings.log().level.as_str(); let env_filter = EnvFilter::new(log_level); let timestamper = CompactTimestamp { @@ -34,7 +34,7 @@ pub fn setup_log(settings: &Settings) -> Result<()> { let subscriber = Registry::default().with(env_filter); // Define outputfile - let outputfile = OutputFile::new(&settings.output.path, &settings.log.file)?; + let outputfile = OutputFile::new(&settings.output().path, &settings.log().file)?; // Define layer for file let file_layer = fmt::layer() From ee3c9edd8a3a0f89314d6e301e629471c154af67 Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Mon, 23 Dec 2024 09:27:11 +0100 Subject: [PATCH 08/30] Update settings.rs --- src/algorithms/routines/settings.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/algorithms/routines/settings.rs b/src/algorithms/routines/settings.rs index 4584aa513..2022c0a74 100644 --- a/src/algorithms/routines/settings.rs +++ b/src/algorithms/routines/settings.rs @@ -119,18 +119,16 @@ impl Settings { &self.output } - pub fn set_convergence(mut self, convergence: Convergence) -> Self { + pub fn set_convergence(mut self, convergence: Convergence) { self.convergence = convergence; - self } pub fn convergence(&self) -> &Convergence { &self.convergence } - pub fn set_advanced(mut self, advanced: Advanced) -> Self { + pub fn set_advanced(mut self, advanced: Advanced) { self.advanced = advanced; - self } pub fn advanced(&self) -> &Advanced { From a45e45ba4c6e73cab1415453a8c0a66fda7ee4a1 Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Mon, 23 Dec 2024 10:48:34 +0100 Subject: [PATCH 09/30] Setters takes reference to mut self --- benches/compare.rs | 140 +++++++++++----------------- examples/bimodal_ke/main.rs | 3 - src/algorithms/routines/settings.rs | 14 +-- 3 files changed, 62 insertions(+), 95 deletions(-) diff --git a/benches/compare.rs b/benches/compare.rs index 97fde3b15..9c265ebb2 100644 --- a/benches/compare.rs +++ b/benches/compare.rs @@ -1,8 +1,7 @@ -use algorithms::AlgorithmType; use pmcore::prelude::*; use diol::prelude::*; -use settings::{Log, *}; +use settings::*; fn main() -> std::io::Result<()> { let mut bench = Bench::new(BenchConfig::from_args()?); @@ -125,94 +124,65 @@ fn ode_tel(bencher: Bencher, len: usize) { } fn tel_settings() -> Settings { - let settings = Settings { - config: Config { - cycles: 1000, - algorithm: AlgorithmType::NPAG, - cache: true, - ..Default::default() - }, - predictions: settings::Predictions::default(), - log: Log { - level: "warn".to_string(), - file: "".to_string(), - write: false, - }, - prior: Prior { - file: None, - sampler: "sobol".to_string(), - points: 2129, - seed: 347, - }, - output: Output { - write: false, - ..Default::default() - }, - convergence: Default::default(), - advanced: Default::default(), - error: Error { - value: 5.0, - class: "proportional".to_string(), - poly: (0.02, 0.05, -2e-04, 0.0), - }, - parameters: { - Parameters::new() - .add("Ka".to_string(), 0.1, 0.3, false) - .unwrap() - .add("Ke".to_string(), 0.001, 0.1, false) - .unwrap() - .add("Tlag1".to_string(), 0.0, 4.00, false) - .unwrap() - .add("V".to_string(), 30.0, 120.0, false) - .unwrap() - .to_owned() - }, - }; + let mut settings = Settings::new(); + + let parameters = Parameters::new() + .add("Ka", 0.1, 0.3, false) + .unwrap() + .add("Ke", 0.001, 0.1, false) + .unwrap() + .add("Tlag1", 0.0, 4.00, false) + .unwrap() + .add("V", 30.0, 120.0, false) + .unwrap() + .to_owned(); + + settings.set_parameters(parameters); + settings.set_config(Config { + cycles: 1000, + ..Default::default() + }); + settings.set_prior(Prior { + file: None, + sampler: "sobol".to_string(), + points: 2129, + seed: 347, + }); + settings.set_error(Error { + value: 5.0, + class: "proportional".to_string(), + poly: (0.02, 0.05, -2e-04, 0.0), + }); + settings.validate().unwrap(); settings } fn bke_settings() -> Settings { - let settings = Settings { - config: Config { - cycles: 1024, - algorithm: AlgorithmType::NPAG, - cache: true, - include: None, - exclude: None, - }, - predictions: settings::Predictions::default(), - log: Log { - level: "warn".to_string(), - file: "".to_string(), - write: false, - }, - prior: Prior { - file: None, - points: settings::Prior::default().points, - sampler: "sobol".to_string(), - ..Default::default() - }, - output: Output { - write: false, - path: "output".to_string(), - }, - convergence: Convergence::default(), - advanced: Advanced::default(), - error: Error { - value: 0.0, - class: "additive".to_string(), - poly: (0.0, 0.05, 0.0, 0.0), - }, - parameters: { - Parameters::new() - .add("Ke".to_string(), 0.001, 0.1, false) - .unwrap() - .add("V".to_string(), 25.0, 250.0, false) - .unwrap() - .to_owned() - }, - }; + let mut settings = Settings::new(); + + let parameters = Parameters::new() + .add("ke", 0.001, 3.0, true) + .unwrap() + .add("v", 25.0, 250.0, true) + .unwrap() + .to_owned(); + + settings.set_parameters(parameters); + settings.set_config(Config { + cycles: 1000, + ..Default::default() + }); + settings.set_output(Output { + write: false, + path: "".to_string(), + }); + settings.set_error(Error { + value: 0.0, + class: "additive".to_string(), + poly: (0.00, 0.05, 0.0, 0.0), + }); + settings.validate().unwrap(); settings } diff --git a/examples/bimodal_ke/main.rs b/examples/bimodal_ke/main.rs index 7524a458f..d299708f0 100644 --- a/examples/bimodal_ke/main.rs +++ b/examples/bimodal_ke/main.rs @@ -1,4 +1,3 @@ -use algorithms::AlgorithmType; use anyhow::Result; use logger::setup_log; use pmcore::prelude::*; @@ -51,8 +50,6 @@ fn main() -> Result<()> { settings.set_parameters(params); settings.set_config(Config { cycles: 1000, - algorithm: AlgorithmType::NPAG, - cache: true, ..Default::default() }); diff --git a/src/algorithms/routines/settings.rs b/src/algorithms/routines/settings.rs index 2022c0a74..337dfe104 100644 --- a/src/algorithms/routines/settings.rs +++ b/src/algorithms/routines/settings.rs @@ -79,7 +79,7 @@ impl Settings { &self.parameters } - pub fn set_error(mut self, error: Error) { + pub fn set_error(&mut self, error: Error) { self.error = error; } @@ -87,7 +87,7 @@ impl Settings { &self.error } - pub fn set_predictions(mut self, predictions: Predictions) { + pub fn set_predictions(&mut self, predictions: Predictions) { self.predictions = predictions; } @@ -95,7 +95,7 @@ impl Settings { &self.predictions } - pub fn set_log(mut self, log: Log) { + pub fn set_log(&mut self, log: Log) { self.log = log; } @@ -103,7 +103,7 @@ impl Settings { &self.log } - pub fn set_prior(mut self, prior: Prior) { + pub fn set_prior(&mut self, prior: Prior) { self.prior = prior; } @@ -111,7 +111,7 @@ impl Settings { &self.prior } - pub fn set_output(mut self, output: Output) { + pub fn set_output(&mut self, output: Output) { self.output = output; } @@ -119,7 +119,7 @@ impl Settings { &self.output } - pub fn set_convergence(mut self, convergence: Convergence) { + pub fn set_convergence(&mut self, convergence: Convergence) { self.convergence = convergence; } @@ -127,7 +127,7 @@ impl Settings { &self.convergence } - pub fn set_advanced(mut self, advanced: Advanced) { + pub fn set_advanced(&mut self, advanced: Advanced) { self.advanced = advanced; } From 1a32d995b1f14583b4f2d5d1f89a8abc14e9c4e7 Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Mon, 23 Dec 2024 10:57:08 +0100 Subject: [PATCH 10/30] Use BTreeMap instead of vector for parameters --- src/algorithms/routines/settings.rs | 32 ++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/src/algorithms/routines/settings.rs b/src/algorithms/routines/settings.rs index 337dfe104..77a0a1c4e 100644 --- a/src/algorithms/routines/settings.rs +++ b/src/algorithms/routines/settings.rs @@ -1,14 +1,13 @@ #![allow(dead_code)] - -use crate::algorithms::AlgorithmType; - use super::output::OutputFile; +use crate::algorithms::AlgorithmType; use anyhow::{bail, Result}; use config::Config as eConfig; use pharmsol::prelude::data::ErrorType; use serde::Deserialize; use serde_derive::Serialize; use serde_json; +use std::collections::BTreeMap; /// Contains all settings for PMcore #[derive(Debug, Deserialize, Clone, Serialize)] @@ -200,14 +199,14 @@ impl Parameter { /// This structure contains information on all [Parameter]s to be estimated #[derive(Debug, Clone, Deserialize, Serialize, Default)] pub struct Parameters { - parameters: Vec, + parameters: BTreeMap, } impl Parameters { /// Create a new set of parameters pub fn new() -> Self { Parameters { - parameters: Vec::new(), + parameters: BTreeMap::new(), } } @@ -220,24 +219,37 @@ impl Parameters { fixed: bool, ) -> Result { let parameter = Parameter::new(name, lower, upper, fixed)?; - self.parameters.push(parameter); + self.parameters.insert(parameter.name.clone(), parameter); Ok(self) } // Get a parameter by name pub fn get(&self, name: impl Into) -> Option<&Parameter> { - let name = name.into(); - self.parameters.iter().find(|p| p.name == name) + self.parameters.get(name.into().as_str()) } /// Get the names of the parameters pub fn names(&self) -> Vec { - self.parameters.iter().map(|p| p.name.clone()).collect() + self.parameters.keys().cloned().collect() } /// Get the ranges of the parameters + /// + /// Returns a vector of tuples, where each tuple contains the lower and upper bounds of the parameter pub fn ranges(&self) -> Vec<(f64, f64)> { - self.parameters.iter().map(|p| (p.lower, p.upper)).collect() + self.parameters + .values() + .map(|p| (p.lower, p.upper)) + .collect() + } +} + +impl IntoIterator for Parameters { + type Item = (String, Parameter); + type IntoIter = std::collections::btree_map::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.parameters.into_iter() } } From 2545fcc422a36e33520a23bc8e0a5eca216cc3fd Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Mon, 23 Dec 2024 11:15:51 +0100 Subject: [PATCH 11/30] Minor typos --- src/algorithms/npag.rs | 4 ++-- .../{adaptative_grid.rs => adaptive_grid.rs} | 18 +++++++++--------- src/algorithms/routines/mod.rs | 2 +- 3 files changed, 12 insertions(+), 12 deletions(-) rename src/algorithms/routines/expansion/{adaptative_grid.rs => adaptive_grid.rs} (91%) diff --git a/src/algorithms/npag.rs b/src/algorithms/npag.rs index 74bf2d69b..635d4722a 100644 --- a/src/algorithms/npag.rs +++ b/src/algorithms/npag.rs @@ -18,7 +18,7 @@ use pharmsol::{ use ndarray::{Array, Array1, Array2, ArrayBase, Axis, Dim, OwnedRepr}; use ndarray_stats::{DeviationExt, QuantileExt}; -use super::{adaptative_grid::adaptative_grid, initialization}; +use super::{adaptive_grid::adaptive_grid, initialization}; const THETA_E: f64 = 1e-4; // Convergence criteria const THETA_G: f64 = 1e-4; // Objective function convergence criteria @@ -336,7 +336,7 @@ impl Algorithm for NPAG { } fn expansion(&mut self) -> Result<()> { - adaptative_grid( + adaptive_grid( &mut self.theta, self.eps, &self.settings.parameters().ranges(), diff --git a/src/algorithms/routines/expansion/adaptative_grid.rs b/src/algorithms/routines/expansion/adaptive_grid.rs similarity index 91% rename from src/algorithms/routines/expansion/adaptative_grid.rs rename to src/algorithms/routines/expansion/adaptive_grid.rs index 35ccb048c..15b2df897 100644 --- a/src/algorithms/routines/expansion/adaptative_grid.rs +++ b/src/algorithms/routines/expansion/adaptive_grid.rs @@ -19,7 +19,7 @@ use crate::algorithms::routines::condensation::prune::prune; /// /// A 2D array containing the updated support points after the adaptive grid expansion. /// -pub fn adaptative_grid( +pub fn adaptive_grid( theta: &mut Array2, eps: f64, ranges: &[(f64, f64)], @@ -82,8 +82,8 @@ mod tests { let ranges = [(0.0, 1.0)]; let min_dist = 0.05; - // Call adaptative_grid - let new_theta = adaptative_grid(&mut theta, eps, &ranges, min_dist); + // Call adaptive_grid + let new_theta = adaptive_grid(&mut theta, eps, &ranges, min_dist); // Expected theta: [[0.5], [0.6], [0.4]] let expected_theta = array![[0.5], [0.6], [0.4]]; @@ -100,8 +100,8 @@ mod tests { let ranges = [(0.0, 1.0), (0.0, 1.0)]; let min_dist = 0.05; - // Call adaptative_grid - let new_theta = adaptative_grid(&mut theta, eps, &ranges, min_dist); + // Call adaptive_grid + let new_theta = adaptive_grid(&mut theta, eps, &ranges, min_dist); // Expected new points are: // For dimension 0: [0.6, 0.5], [0.4, 0.5] @@ -121,8 +121,8 @@ mod tests { let ranges = [(0.0, 1.0)]; let min_dist = 0.2; - // Call adaptative_grid - let new_theta = adaptative_grid(&mut theta, eps, &ranges, min_dist); + // Call adaptive_grid + let new_theta = adaptive_grid(&mut theta, eps, &ranges, min_dist); // Since min_dist is 0.2, the new points at 0.6 and 0.4 are too close to 0.5 (distance 0.1) // So no new points should be added @@ -140,8 +140,8 @@ mod tests { let ranges = [(0.0, 1.0)]; let min_dist = 0.05; - // Call adaptative_grid - let new_theta = adaptative_grid(&mut theta, eps, &ranges, min_dist); + // Call adaptive_grid + let new_theta = adaptive_grid(&mut theta, eps, &ranges, min_dist); // val + l = 0.95 + 0.1 = 1.05 > 1.0, so point at 1.05 is out of bounds and should not be added // val - l = 0.95 - 0.1 = 0.85, which is within range diff --git a/src/algorithms/routines/mod.rs b/src/algorithms/routines/mod.rs index 138982aac..d649979fe 100644 --- a/src/algorithms/routines/mod.rs +++ b/src/algorithms/routines/mod.rs @@ -13,7 +13,7 @@ pub mod condensation { } /// Routines for expanding grids pub mod expansion { - pub mod adaptative_grid; + pub mod adaptive_grid; } /// Provides routines for reading and parsing settings From d56b501f53810d20150e5845fee7685eed865b85 Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Mon, 23 Dec 2024 13:43:24 +0100 Subject: [PATCH 12/30] Update config.toml --- examples/bimodal_ke/config.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/bimodal_ke/config.toml b/examples/bimodal_ke/config.toml index faffb49b2..8acc032d4 100644 --- a/examples/bimodal_ke/config.toml +++ b/examples/bimodal_ke/config.toml @@ -1,3 +1,5 @@ +# Currently not in use - settings are defined in main.rs + [config] cycles = 1024 algorithm = "NPAG" From 7c01ccfe8200f0e74978f1315b31ef182f7b3e5f Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Mon, 23 Dec 2024 14:26:04 +0100 Subject: [PATCH 13/30] More getters/setters --- src/algorithms/routines/settings.rs | 171 +++++++++++++++++++++++----- src/logger.rs | 2 +- 2 files changed, 146 insertions(+), 27 deletions(-) diff --git a/src/algorithms/routines/settings.rs b/src/algorithms/routines/settings.rs index 77a0a1c4e..f3d15ddbf 100644 --- a/src/algorithms/routines/settings.rs +++ b/src/algorithms/routines/settings.rs @@ -4,10 +4,10 @@ use crate::algorithms::AlgorithmType; use anyhow::{bail, Result}; use config::Config as eConfig; use pharmsol::prelude::data::ErrorType; -use serde::Deserialize; -use serde_derive::Serialize; +use serde::{Deserialize, Serialize}; use serde_json; use std::collections::BTreeMap; +use std::fmt::Display; /// Contains all settings for PMcore #[derive(Debug, Deserialize, Clone, Serialize)] @@ -133,6 +133,90 @@ impl Settings { pub fn advanced(&self) -> &Advanced { &self.advanced } + + pub fn set_cycles(&mut self, cycles: usize) { + self.config.cycles = cycles; + } + + pub fn set_algorithm(&mut self, algorithm: AlgorithmType) { + self.config.algorithm = algorithm; + } + + pub fn set_cache(&mut self, cache: bool) { + self.config.cache = cache; + } + + pub fn set_include(&mut self, include: Option>) { + self.config.include = include; + } + + pub fn set_exclude(&mut self, exclude: Option>) { + self.config.exclude = exclude; + } + + pub fn set_gamlam(&mut self, value: f64) { + self.error.value = value; + } + + pub fn set_error_type(&mut self, class: ErrorType) { + self.error.class = class; + } + + pub fn set_error_poly(&mut self, poly: (f64, f64, f64, f64)) { + self.error.poly = poly; + } + + pub fn set_idelta(&mut self, idelta: f64) { + self.predictions.idelta = idelta; + } + + pub fn set_tad(&mut self, tad: f64) { + self.predictions.tad = tad; + } + + pub fn set_log_level(&mut self, level: LogLevel) { + self.log.level = level; + } + + pub fn set_log_file(&mut self, file: String) { + self.log.file = file; + } + + pub fn set_prior_sampler(&mut self, sampler: String) { + self.prior.sampler = sampler; + } + + pub fn set_prior_points(&mut self, points: usize) { + self.prior.points = points; + } + + pub fn set_prior_seed(&mut self, seed: usize) { + self.prior.seed = seed; + } + + pub fn set_prior_file(&mut self, file: Option) { + self.prior.file = file; + } + + pub fn set_output_write(&mut self, write: bool) { + self.output.write = write; + } + + pub fn set_output_path(&mut self, path: String) { + self.output.path = path; + } + + /// Writes a copy of the parsed settings to file + /// The is written to output folder specified in the [Output] and is named `settings.json`. + pub fn write(&self) -> Result<()> { + let serialized = serde_json::to_string_pretty(self) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; + + let outputfile = OutputFile::new(self.output.path.as_str(), "settings.json")?; + let mut file = outputfile.file; + std::io::Write::write_all(&mut file, serialized.as_bytes())?; + Ok(()) + } } /// General configuration settings @@ -260,7 +344,7 @@ pub struct Error { /// The initial value of `gamma` or `lambda` pub value: f64, /// The error class, either `additive` or `proportional` - pub class: String, + pub class: ErrorType, /// The assay error polynomial pub poly: (f64, f64, f64, f64), } @@ -269,7 +353,7 @@ impl Default for Error { fn default() -> Self { Error { value: 0.0, - class: "additive".to_string(), + class: ErrorType::Additive, poly: (0.0, 0.1, 0.0, 0.0), } } @@ -287,11 +371,7 @@ impl Error { } pub fn error_type(&self) -> ErrorType { - match self.class.to_lowercase().as_str() { - "additive" | "l" | "lambda" => ErrorType::Add, - "proportional" | "g" | "gamma" => ErrorType::Prop, - _ => panic!("Error class '{}' not supported. Possible classes are 'gamma' (proportional) or 'lambda' (additive)", self.class), - } + self.class } } @@ -389,6 +469,59 @@ impl Predictions { } } +/// The log level, which can be one of the following: +/// - `TRACE` +/// - `DEBUG` +/// - `INFO` +/// - `WARN` +/// - `ERROR` +/// +/// The default log level is `info` +#[derive(Debug, Deserialize, Clone, Serialize)] +pub enum LogLevel { + TRACE, + DEBUG, + INFO, + WARN, + ERROR, +} + +impl Into for LogLevel { + fn into(self) -> tracing::Level { + match self { + LogLevel::TRACE => tracing::Level::TRACE, + LogLevel::DEBUG => tracing::Level::DEBUG, + LogLevel::INFO => tracing::Level::INFO, + LogLevel::WARN => tracing::Level::WARN, + LogLevel::ERROR => tracing::Level::ERROR, + } + } +} + +impl AsRef for LogLevel { + fn as_ref(&self) -> &str { + match self { + LogLevel::TRACE => "trace", + LogLevel::DEBUG => "debug", + LogLevel::INFO => "info", + LogLevel::WARN => "warn", + LogLevel::ERROR => "error", + } + } +} + +impl Display for LogLevel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.as_ref()) + } +} + +impl Default for LogLevel { + fn default() -> Self { + LogLevel::INFO + } +} + #[derive(Debug, Deserialize, Clone, Serialize)] #[serde(deny_unknown_fields, default)] pub struct Log { @@ -400,7 +533,7 @@ pub struct Log { /// - `info` /// - `warn` /// - `error` - pub level: String, + pub level: LogLevel, /// The file to write the log to pub file: String, /// Whether to write logs @@ -413,7 +546,7 @@ pub struct Log { impl Default for Log { fn default() -> Self { Log { - level: String::from("info"), + level: LogLevel::INFO, file: String::from("log.txt"), write: true, } @@ -524,24 +657,10 @@ pub fn read(path: impl Into) -> Result { // Write a copy of the settings to file if output is enabled if settings.output.write { - if let Err(error) = write_settings_to_file(&settings) { + if let Err(error) = settings.write() { bail!("Could not write settings to file: {}", error); } } Ok(settings) // Return the settings wrapped in Ok } - -/// Writes a copy of the parsed settings to file -/// -/// This function writes a copy of the parsed settings to file. -/// The file is written to output folder specified in the [settings](crate::routines::settings::Settings::paths), and is named `settings.json`. -pub fn write_settings_to_file(settings: &Settings) -> Result<()> { - let serialized = serde_json::to_string_pretty(settings) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?; - - let outputfile = OutputFile::new(settings.output.path.as_str(), "settings.json")?; - let mut file = outputfile.file; - std::io::Write::write_all(&mut file, serialized.as_bytes())?; - Ok(()) -} diff --git a/src/logger.rs b/src/logger.rs index 735ffe256..9ad11ae1d 100644 --- a/src/logger.rs +++ b/src/logger.rs @@ -23,7 +23,7 @@ use tracing_subscriber::EnvFilter; /// If not, the log messages are written to stdout. pub fn setup_log(settings: &Settings) -> Result<()> { // Use the log level defined in configuration file - let log_level = settings.log().level.as_str(); + let log_level: String = settings.log().level.to_string(); let env_filter = EnvFilter::new(log_level); let timestamper = CompactTimestamp { From 8b3dca10021c3d2ec726b705a1e399c2bb8c07e3 Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Mon, 23 Dec 2024 14:26:31 +0100 Subject: [PATCH 14/30] Rename POSTPROB to MAP Also updated documentation to match --- src/algorithms/map.rs | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/algorithms/map.rs b/src/algorithms/map.rs index c68688679..afff50ce7 100644 --- a/src/algorithms/map.rs +++ b/src/algorithms/map.rs @@ -12,8 +12,9 @@ use ndarray::{Array1, Array2}; use super::{initialization, output::CycleLog}; -/// Posterior probability algorithm -/// Reweights the prior probabilities to the observed data and error model +/// Maximim a posteriori (MAP) estimation +/// +/// Calculate the MAP estimate of the parameters of the model given the data. pub struct MAP { equation: E, psi: Array2, @@ -42,11 +43,7 @@ impl Algorithm for MAP { cycle: 0, converged: false, gamma: settings.error().value, - error_type: match settings.error().class.as_str() { - "additive" => ErrorType::Add, - "proportional" => ErrorType::Prop, - _ => panic!("Error type not supported"), - }, + error_type: settings.error().error_type(), c: settings.error().poly, settings, data, From b51bdd49cc4e913d6e15592137881b93b2cceff4 Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Mon, 23 Dec 2024 14:44:16 +0100 Subject: [PATCH 15/30] Update examples --- benches/compare.rs | 4 ++-- examples/bimodal_ke/main.rs | 10 +++++----- examples/two_eq_lag/main.rs | 26 ++++++++++++++++++++++---- src/algorithms/routines/settings.rs | 4 ++-- 4 files changed, 31 insertions(+), 13 deletions(-) diff --git a/benches/compare.rs b/benches/compare.rs index 9c265ebb2..838ef2acb 100644 --- a/benches/compare.rs +++ b/benches/compare.rs @@ -150,7 +150,7 @@ fn tel_settings() -> Settings { }); settings.set_error(Error { value: 5.0, - class: "proportional".to_string(), + class: ErrorType::Proportional, poly: (0.02, 0.05, -2e-04, 0.0), }); @@ -179,7 +179,7 @@ fn bke_settings() -> Settings { }); settings.set_error(Error { value: 0.0, - class: "additive".to_string(), + class: ErrorType::Additive, poly: (0.00, 0.05, 0.0, 0.0), }); diff --git a/examples/bimodal_ke/main.rs b/examples/bimodal_ke/main.rs index d299708f0..aadc24556 100644 --- a/examples/bimodal_ke/main.rs +++ b/examples/bimodal_ke/main.rs @@ -1,7 +1,7 @@ use anyhow::Result; use logger::setup_log; use pmcore::prelude::*; -use settings::{Config, Parameters, Settings}; +use settings::{Parameters, Settings}; fn main() -> Result<()> { let eq = equation::ODE::new( |x, p, _t, dx, rateiv, _cov| { @@ -48,10 +48,10 @@ fn main() -> Result<()> { .add("v", 25.0, 250.0, true)?; settings.set_parameters(params); - settings.set_config(Config { - cycles: 1000, - ..Default::default() - }); + settings.set_cycles(1000); + settings.set_error_poly((0.0, 0.5, 0.0, 0.0)); + settings.set_error_type(ErrorType::Additive); + settings.set_output_path("examples/bimodal_ke/output"); setup_log(&settings)?; let data = data::read_pmetrics("examples/bimodal_ke/bimodal_ke.csv")?; diff --git a/examples/two_eq_lag/main.rs b/examples/two_eq_lag/main.rs index e5dbc2cfb..eac1266d2 100644 --- a/examples/two_eq_lag/main.rs +++ b/examples/two_eq_lag/main.rs @@ -9,23 +9,24 @@ use logger::setup_log; use ndarray::Array2; use pmcore::prelude::{models::one_compartment_with_absorption, simulator::Equation, *}; use settings::Parameters; +use settings::Settings; fn main() { let eq = equation::ODE::new( |x, p, _t, dx, rateiv, _cov| { fetch_cov!(cov, t,); - fetch_params!(p, ka, ke, _tlag, _v); + fetch_params!(p, ka, ke, tlag, v); dx[0] = -ka * x[0]; dx[1] = ka * x[0] - ke * x[1]; }, |p| { - fetch_params!(p, _ka, _ke, tlag, _v); + fetch_params!(p, ka, ke, tlag, v); lag! {0=>tlag} }, |_p| fa! {}, |_p, _t, _cov, _x| {}, |x, p, _t, _cov, y| { - fetch_params!(p, _ka, _ke, _tlag, v); + fetch_params!(p, ka, ke, tlag, v); y[0] = x[1] / v; }, (2, 1), @@ -73,7 +74,24 @@ fn main() { // (2, 1), // ); - let settings = settings::read("examples/two_eq_lag/config.toml").unwrap(); + let mut settings = Settings::new(); + + let parameters = Parameters::new() + .add("ka", 0.1, 0.3, false) + .unwrap() + .add("ke", 0.001, 0.1, false) + .unwrap() + .add("tlag", 0.0, 4.00, false) + .unwrap() + .add("v", 30.0, 120.0, false) + .unwrap(); + + settings.set_parameters(parameters); + + settings.set_gamlam(5.0); + settings.set_error_poly((0.02, 0.05, -2e-04, 0.0)); + settings.set_prior_points(2129); + settings.set_cycles(1000); setup_log(&settings).unwrap(); let data = data::read_pmetrics("examples/two_eq_lag/two_eq_lag.csv").unwrap(); diff --git a/src/algorithms/routines/settings.rs b/src/algorithms/routines/settings.rs index f3d15ddbf..eeb8edf56 100644 --- a/src/algorithms/routines/settings.rs +++ b/src/algorithms/routines/settings.rs @@ -202,8 +202,8 @@ impl Settings { self.output.write = write; } - pub fn set_output_path(&mut self, path: String) { - self.output.path = path; + pub fn set_output_path(&mut self, path: impl Into) { + self.output.path = path.into(); } /// Writes a copy of the parsed settings to file From a89c1d1e68ef3b80b2bac4fa98b3034c509df242 Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Mon, 23 Dec 2024 16:17:02 +0100 Subject: [PATCH 16/30] Structure name change --- src/algorithms/map.rs | 6 +++--- src/algorithms/mod.rs | 18 ++++++++++++------ src/algorithms/npag.rs | 5 ++--- src/algorithms/npod.rs | 8 +++++--- src/algorithms/routines/settings.rs | 10 +++++----- 5 files changed, 27 insertions(+), 20 deletions(-) diff --git a/src/algorithms/map.rs b/src/algorithms/map.rs index afff50ce7..aa93a6a52 100644 --- a/src/algorithms/map.rs +++ b/src/algorithms/map.rs @@ -1,4 +1,4 @@ -use crate::prelude::{algorithms::Algorithm, ipm::burke, output::NPResult, settings::Settings}; +use crate::prelude::{ipm::burke, output::NPResult, settings::Settings}; use anyhow::Result; use pharmsol::{ prelude::{ @@ -10,7 +10,7 @@ use pharmsol::{ use ndarray::{Array1, Array2}; -use super::{initialization, output::CycleLog}; +use super::{initialization, output::CycleLog, NonParametric}; /// Maximim a posteriori (MAP) estimation /// @@ -32,7 +32,7 @@ pub struct MAP { cyclelog: CycleLog, } -impl Algorithm for MAP { +impl NonParametric for MAP { fn new(settings: Settings, equation: E, data: Data) -> Result, anyhow::Error> { Ok(Box::new(Self { equation, diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index f982f5370..30ab94021 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -20,14 +20,20 @@ pub mod npag; pub mod npod; pub mod routines; +/// Supported algorithms by `PMcore` +/// +/// - `NPAG`: Non-Parametric Adaptive Grid +/// - `NPOD`: Non-Parametric Optimal Design +/// - `MAP`: Maximum A Posteriori #[derive(Debug, Clone, Serialize, Deserialize)] -pub enum AlgorithmType { +pub enum Algorithm { NPAG, NPOD, MAP, } -pub trait Algorithm { +/// This traint defines the methods for non-parametric (NP) algorithms +pub trait NonParametric { fn new(config: Settings, equation: E, data: Data) -> Result, Error> where Self: Sized; @@ -109,10 +115,10 @@ pub fn dispatch_algorithm( settings: Settings, equation: E, data: Data, -) -> Result>, Error> { +) -> Result>, Error> { match settings.config().algorithm { - AlgorithmType::NPAG => Ok(NPAG::new(settings, equation, data)?), - AlgorithmType::NPOD => Ok(NPOD::new(settings, equation, data)?), - AlgorithmType::MAP => Ok(MAP::new(settings, equation, data)?), + Algorithm::NPAG => Ok(NPAG::new(settings, equation, data)?), + Algorithm::NPOD => Ok(NPOD::new(settings, equation, data)?), + Algorithm::MAP => Ok(MAP::new(settings, equation, data)?), } } diff --git a/src/algorithms/npag.rs b/src/algorithms/npag.rs index 635d4722a..6141e7740 100644 --- a/src/algorithms/npag.rs +++ b/src/algorithms/npag.rs @@ -1,5 +1,4 @@ use crate::prelude::{ - algorithms::Algorithm, ipm::burke, output::{CycleLog, NPCycle, NPResult}, qr, @@ -18,7 +17,7 @@ use pharmsol::{ use ndarray::{Array, Array1, Array2, ArrayBase, Axis, Dim, OwnedRepr}; use ndarray_stats::{DeviationExt, QuantileExt}; -use super::{adaptive_grid::adaptive_grid, initialization}; +use super::{adaptive_grid::adaptive_grid, initialization, NonParametric}; const THETA_E: f64 = 1e-4; // Convergence criteria const THETA_G: f64 = 1e-4; // Objective function convergence criteria @@ -48,7 +47,7 @@ pub struct NPAG { settings: Settings, } -impl Algorithm for NPAG { +impl NonParametric for NPAG { fn new(settings: Settings, equation: E, data: Data) -> Result, anyhow::Error> { Ok(Box::new(Self { equation, diff --git a/src/algorithms/npod.rs b/src/algorithms/npod.rs index 5d031a18d..123ebd0ae 100644 --- a/src/algorithms/npod.rs +++ b/src/algorithms/npod.rs @@ -1,5 +1,4 @@ use crate::prelude::{ - algorithms::Algorithm, ipm::burke, output::{CycleLog, NPCycle, NPResult}, qr, @@ -21,7 +20,10 @@ use ndarray::{ }; use ndarray_stats::{DeviationExt, QuantileExt}; -use super::{condensation::prune::prune, initialization, optimization::d_optimizer::SppOptimizer}; +use super::{ + condensation::prune::prune, initialization, optimization::d_optimizer::SppOptimizer, + NonParametric, +}; const THETA_F: f64 = 1e-2; const THETA_D: f64 = 1e-4; @@ -45,7 +47,7 @@ pub struct NPOD { settings: Settings, } -impl Algorithm for NPOD { +impl NonParametric for NPOD { fn new(settings: Settings, equation: E, data: Data) -> Result, anyhow::Error> { Ok(Box::new(Self { equation, diff --git a/src/algorithms/routines/settings.rs b/src/algorithms/routines/settings.rs index eeb8edf56..c06f3f207 100644 --- a/src/algorithms/routines/settings.rs +++ b/src/algorithms/routines/settings.rs @@ -1,6 +1,6 @@ #![allow(dead_code)] use super::output::OutputFile; -use crate::algorithms::AlgorithmType; +use crate::algorithms::Algorithm; use anyhow::{bail, Result}; use config::Config as eConfig; use pharmsol::prelude::data::ErrorType; @@ -138,7 +138,7 @@ impl Settings { self.config.cycles = cycles; } - pub fn set_algorithm(&mut self, algorithm: AlgorithmType) { + pub fn set_algorithm(&mut self, algorithm: Algorithm) { self.config.algorithm = algorithm; } @@ -226,7 +226,7 @@ pub struct Config { /// Maximum number of cycles to run pub cycles: usize, /// Denotes the algorithm to use - pub algorithm: AlgorithmType, + pub algorithm: Algorithm, /// If true (default), cache predicted values pub cache: bool, /// Vector of IDs to include @@ -239,7 +239,7 @@ impl Default for Config { fn default() -> Self { Config { cycles: 100, - algorithm: AlgorithmType::NPAG, + algorithm: Algorithm::NPAG, cache: true, include: None, exclude: None, @@ -476,7 +476,7 @@ impl Predictions { /// - `WARN` /// - `ERROR` /// -/// The default log level is `info` +/// The default log level is `INFO` #[derive(Debug, Deserialize, Clone, Serialize)] pub enum LogLevel { TRACE, From c578733c199acc64421c334a5c61c31dfa02a48f Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Mon, 23 Dec 2024 16:26:18 +0100 Subject: [PATCH 17/30] clippy --- examples/debug.rs | 2 +- src/algorithms/mod.rs | 1 + .../routines/evaluation/ipm_faer.rs | 2 +- .../routines/initialization/latin.rs | 2 +- src/algorithms/routines/initialization/mod.rs | 4 +-- .../routines/initialization/sobol.rs | 2 +- src/algorithms/routines/output.rs | 26 ++++++++++++------- src/algorithms/routines/settings.rs | 17 +++++------- 8 files changed, 29 insertions(+), 27 deletions(-) diff --git a/examples/debug.rs b/examples/debug.rs index 842e9146c..303856d56 100644 --- a/examples/debug.rs +++ b/examples/debug.rs @@ -1,4 +1,4 @@ -use algorithms::{npag::NPAG, Algorithm}; +use algorithms::{npag::NPAG, NonParametric}; use ipm::burke; use logger::setup_log; use pmcore::prelude::*; diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index 30ab94021..0a4adb789 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -108,6 +108,7 @@ pub trait NonParametric { } {} Ok(self.into_npresult()) } + #[allow(clippy::wrong_self_convention)] fn into_npresult(&self) -> NPResult; } diff --git a/src/algorithms/routines/evaluation/ipm_faer.rs b/src/algorithms/routines/evaluation/ipm_faer.rs index 04dce592e..d189d6119 100644 --- a/src/algorithms/routines/evaluation/ipm_faer.rs +++ b/src/algorithms/routines/evaluation/ipm_faer.rs @@ -94,7 +94,7 @@ pub fn burke( let inner = zipped!(lam.as_ref(), y.as_ref()).map(|unzipped!(lam_i, y_i)| *lam_i / *y_i); let w_plam = zipped!(plam.as_ref(), w.as_ref()).map(|unzipped!(plam_i, w_i)| *plam_i / *w_i); - let h = (&psi * inner.as_ref().col(0).column_vector_as_diagonal()) * &psi.transpose(); + let h = (&psi * inner.as_ref().col(0).column_vector_as_diagonal()) * psi.transpose(); let mut aux: Mat = Mat::zeros(row, row); for i in 0..row { let diag = aux.get_mut(i, i); diff --git a/src/algorithms/routines/initialization/latin.rs b/src/algorithms/routines/initialization/latin.rs index f19d0aa26..14fce1323 100644 --- a/src/algorithms/routines/initialization/latin.rs +++ b/src/algorithms/routines/initialization/latin.rs @@ -23,7 +23,7 @@ use rand::SeedableRng; /// pub fn generate( n_points: usize, - range_params: &Vec<(f64, f64)>, + range_params: &[(f64, f64)], seed: usize, ) -> Result, Dim<[usize; 2]>>> { let n_params = range_params.len(); diff --git a/src/algorithms/routines/initialization/mod.rs b/src/algorithms/routines/initialization/mod.rs index e8d959ad2..b168c37c8 100644 --- a/src/algorithms/routines/initialization/mod.rs +++ b/src/algorithms/routines/initialization/mod.rs @@ -49,7 +49,7 @@ pub fn sample_space(settings: &Settings, data: &Data, eqn: &impl Equation) -> Re } /// This function reads the prior distribution from a file -pub fn parse_prior(path: &String, names: &Vec) -> Result> { +pub fn parse_prior(path: &String, names: &[String]) -> Result> { tracing::info!("Reading prior from {}", path); let file = File::open(path).context(format!("Unable to open the prior file '{}'", path))?; let mut reader = csv::ReaderBuilder::new() @@ -69,7 +69,7 @@ pub fn parse_prior(path: &String, names: &Vec) -> Result> { } // Check and reorder parameters to match names in settings.parsed.random - let random_names: Vec = names.clone(); + let random_names: Vec = names.to_owned(); let mut reordered_indices: Vec = Vec::new(); for random_name in &random_names { diff --git a/src/algorithms/routines/initialization/sobol.rs b/src/algorithms/routines/initialization/sobol.rs index a3d2e47a0..97b238361 100644 --- a/src/algorithms/routines/initialization/sobol.rs +++ b/src/algorithms/routines/initialization/sobol.rs @@ -23,7 +23,7 @@ use sobol_burley::sample; /// pub fn generate( n_points: usize, - range_params: &Vec<(f64, f64)>, + range_params: &[(f64, f64)], seed: usize, ) -> Result, Dim<[usize; 2]>>> { let n_params = range_params.len(); diff --git a/src/algorithms/routines/output.rs b/src/algorithms/routines/output.rs index 4b228a336..5958694d0 100644 --- a/src/algorithms/routines/output.rs +++ b/src/algorithms/routines/output.rs @@ -96,7 +96,7 @@ impl NPResult { self.write_pred(idelta, tad) .context("Failed to write predictions")?; self.write_covs().context("Failed to write covariates")?; - if self.w.len() > 0 { + if !self.w.is_empty() { //TODO: find a better way to indicate that the run failed self.write_posterior() .context("Failed to write posterior")?; @@ -322,12 +322,12 @@ impl NPResult { let subjects = self.data.get_subjects(); for (sub, row) in posterior.axis_iter(Axis(0)).enumerate() { for (spp, elem) in row.axis_iter(Axis(0)).enumerate() { - writer.write_field(&subjects.get(sub).unwrap().id())?; + writer.write_field(subjects.get(sub).unwrap().id())?; writer.write_field(format!("{}", spp))?; for param in theta.row(spp) { - writer.write_field(&format!("{param}"))?; + writer.write_field(format!("{param}"))?; } - writer.write_field(&format!("{elem:.10}"))?; + writer.write_field(format!("{elem:.10}"))?; writer.write_record(None::<&[u8]>)?; } } @@ -479,7 +479,7 @@ impl NPResult { for occasion in subject.occasions() { if let Some(cov) = occasion.get_covariates() { let covmap = cov.covariates(); - for (cov_name, _) in &covmap { + for cov_name in covmap.keys() { covariate_names.insert(cov_name.clone()); } } @@ -657,6 +657,12 @@ impl CycleLog { } } +impl Default for CycleLog { + fn default() -> Self { + Self::new() + } +} + pub fn posterior(psi: &Array2, w: &Array1) -> Result> { let py = psi.dot(w); let mut post: Array2 = Array2::zeros((psi.nrows(), psi.ncols())); @@ -739,7 +745,7 @@ pub fn population_mean_median( theta: &Array2, w: &Array1, ) -> Result<(Array1, Array1)> { - let w = if w.len() == 0 { + let w = if w.is_empty() { tracing::warn!("w.len() == 0, setting all weights to 1/n"); Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64) } else { @@ -785,7 +791,7 @@ pub fn posterior_mean_median( let mut mean = Array2::zeros((0, theta.ncols())); let mut median = Array2::zeros((0, theta.ncols())); - let w = if w.len() == 0 { + let w = if w.is_empty() { tracing::warn!("w.len() == 0, setting all weights to 1/n"); Array1::from_elem(theta.nrows(), 1.0 / theta.nrows() as f64) } else { @@ -879,15 +885,15 @@ impl OutputFile { pub fn write_pmetrics_observations(data: &Data, file: &std::fs::File) -> Result<()> { let mut writer = WriterBuilder::new().has_headers(true).from_writer(file); - writer.write_record(&["id", "block", "time", "out", "outeq"])?; + writer.write_record(["id", "block", "time", "out", "outeq"])?; for subject in data.get_subjects() { for occasion in subject.occasions() { for event in occasion.get_events(&None, &None, false) { match event { Event::Observation(obs) => { // Write each field individually - writer.write_record(&[ - &subject.id(), + writer.write_record([ + subject.id(), &occasion.index().to_string(), &obs.time().to_string(), &obs.value().to_string(), diff --git a/src/algorithms/routines/settings.rs b/src/algorithms/routines/settings.rs index c06f3f207..1734bfdcd 100644 --- a/src/algorithms/routines/settings.rs +++ b/src/algorithms/routines/settings.rs @@ -477,18 +477,19 @@ impl Predictions { /// - `ERROR` /// /// The default log level is `INFO` -#[derive(Debug, Deserialize, Clone, Serialize)] +#[derive(Debug, Deserialize, Clone, Serialize, Default)] pub enum LogLevel { TRACE, DEBUG, + #[default] INFO, WARN, ERROR, } -impl Into for LogLevel { - fn into(self) -> tracing::Level { - match self { +impl From for tracing::Level { + fn from(log_level: LogLevel) -> tracing::Level { + match log_level { LogLevel::TRACE => tracing::Level::TRACE, LogLevel::DEBUG => tracing::Level::DEBUG, LogLevel::INFO => tracing::Level::INFO, @@ -516,12 +517,6 @@ impl Display for LogLevel { } } -impl Default for LogLevel { - fn default() -> Self { - LogLevel::INFO - } -} - #[derive(Debug, Deserialize, Clone, Serialize)] #[serde(deny_unknown_fields, default)] pub struct Log { @@ -605,7 +600,7 @@ impl Output { //// /// If a `#` symbol is found, it will automatically increment the number by one. pub fn parse_output_folder(&mut self) -> Result<()> { - if self.path.is_empty() || self.path == "" { + if self.path.is_empty() || self.path.is_empty() { // Set a default path if none is provided self.path = Output::default().path; } From 88b5aacad18398c030c0ae420988d91e545cc652 Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Mon, 23 Dec 2024 16:27:52 +0100 Subject: [PATCH 18/30] clippy --- examples/drusano/main.rs | 2 +- .../routines/initialization/latin.rs | 2 +- .../routines/initialization/sobol.rs | 4 ++-- src/algorithms/routines/output.rs | 22 +++++++++---------- 4 files changed, 14 insertions(+), 16 deletions(-) diff --git a/examples/drusano/main.rs b/examples/drusano/main.rs index aa3e68744..7e19f6488 100644 --- a/examples/drusano/main.rs +++ b/examples/drusano/main.rs @@ -157,7 +157,7 @@ fn find_m0(ufinal: f64, v: f64, alpha: f64, h1: f64, h2: f64) -> f64 { let b3 = alpha * v * u * hh / xm.powf(hh + 1.0); let xmp = top / (b1 + b2 + b3); - xm = xm + xmp * delu; + xm += xmp * delu; if xm <= 0.0 { return -1.0; // Greco equation is not solvable diff --git a/src/algorithms/routines/initialization/latin.rs b/src/algorithms/routines/initialization/latin.rs index 14fce1323..9a1a923e2 100644 --- a/src/algorithms/routines/initialization/latin.rs +++ b/src/algorithms/routines/initialization/latin.rs @@ -50,7 +50,7 @@ mod tests { #[test] fn test_generate_lhs() { - let result = generate(5, &vec![(0., 1.), (0., 100.), (0., 1000.)], 42).unwrap(); + let result = generate(5, &[(0., 1.), (0., 100.), (0., 1000.)], 42).unwrap(); assert_eq!(result.shape(), &[5, 3]); assert_eq!( result, diff --git a/src/algorithms/routines/initialization/sobol.rs b/src/algorithms/routines/initialization/sobol.rs index 97b238361..48ca574ab 100644 --- a/src/algorithms/routines/initialization/sobol.rs +++ b/src/algorithms/routines/initialization/sobol.rs @@ -50,7 +50,7 @@ use crate::prelude::*; #[test] fn basic_sobol() { assert_eq!( - initialization::sobol::generate(5, &vec![(0., 1.), (0., 1.), (0., 1.)], 347).unwrap(), + initialization::sobol::generate(5, &[(0., 1.), (0., 1.), (0., 1.)], 347).unwrap(), ndarray::array![ [0.10731887817382813, 0.14647412300109863, 0.5851038694381714], [0.9840304851531982, 0.7633365392684937, 0.19097506999969482], @@ -64,7 +64,7 @@ fn basic_sobol() { #[test] fn scaled_sobol() { assert_eq!( - initialization::sobol::generate(5, &vec![(0., 1.), (0., 2.), (-1., 1.)], 347).unwrap(), + initialization::sobol::generate(5, &[(0., 1.), (0., 2.), (-1., 1.)], 347).unwrap(), ndarray::array![ [ 0.10731887817382813, diff --git a/src/algorithms/routines/output.rs b/src/algorithms/routines/output.rs index 5958694d0..657fac760 100644 --- a/src/algorithms/routines/output.rs +++ b/src/algorithms/routines/output.rs @@ -27,6 +27,7 @@ pub struct NPResult { cyclelog: CycleLog, } +#[allow(clippy::too_many_arguments)] impl NPResult { /// Create a new NPResult object pub fn new( @@ -889,18 +890,15 @@ pub fn write_pmetrics_observations(data: &Data, file: &std::fs::File) -> Result< for subject in data.get_subjects() { for occasion in subject.occasions() { for event in occasion.get_events(&None, &None, false) { - match event { - Event::Observation(obs) => { - // Write each field individually - writer.write_record([ - subject.id(), - &occasion.index().to_string(), - &obs.time().to_string(), - &obs.value().to_string(), - &obs.outeq().to_string(), - ])?; - } - _ => {} + if let Event::Observation(obs) = event { + // Write each field individually + writer.write_record([ + subject.id(), + &occasion.index().to_string(), + &obs.time().to_string(), + &obs.value().to_string(), + &obs.outeq().to_string(), + ])?; } } } From 9abaa71d04416709d5cf40931851562f65d957d3 Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Mon, 23 Dec 2024 16:32:01 +0100 Subject: [PATCH 19/30] Clean up NPAG --- src/algorithms/npag.rs | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/algorithms/npag.rs b/src/algorithms/npag.rs index 6141e7740..96c687e3e 100644 --- a/src/algorithms/npag.rs +++ b/src/algorithms/npag.rs @@ -8,7 +8,7 @@ use anyhow::bail; use anyhow::Result; use pharmsol::{ prelude::{ - data::{Data, ErrorModel, ErrorType}, + data::{Data, ErrorModel}, simulator::{psi, Equation}, }, Subject, Theta, @@ -39,11 +39,9 @@ pub struct NPAG { cycle: usize, gamma_delta: f64, gamma: f64, - error_type: ErrorType, converged: bool, cycle_log: CycleLog, data: Data, - c: (f64, f64, f64, f64), settings: Settings, } @@ -63,10 +61,8 @@ impl NonParametric for NPAG { cycle: 0, gamma_delta: 0.1, gamma: settings.error().value, - error_type: settings.error().error_type(), converged: false, cycle_log: CycleLog::new(), - c: settings.error().poly, settings, data, })) @@ -178,7 +174,11 @@ impl NonParametric for NPAG { &self.equation, &self.data, &theta, - &ErrorModel::new(self.c, self.gamma, &self.error_type), + &ErrorModel::new( + self.settings.error().poly, + self.gamma, + &self.settings.error().class, + ), self.cycle == 1 && self.settings.log().write, self.cycle != 1, ); @@ -265,7 +265,11 @@ impl NonParametric for NPAG { &self.equation, &self.data, &theta, - &ErrorModel::new(self.c, gamma_up, &self.error_type), + &ErrorModel::new( + self.settings.error().poly, + gamma_up, + &self.settings.error().class, + ), false, true, ); @@ -273,7 +277,11 @@ impl NonParametric for NPAG { &self.equation, &self.data, &theta, - &ErrorModel::new(self.c, gamma_down, &self.error_type), + &ErrorModel::new( + self.settings.error().poly, + gamma_down, + &self.settings.error().class, + ), false, true, ); From aeff351c4f4e82257c65144dc62c1de0c4191312 Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Mon, 23 Dec 2024 16:34:21 +0100 Subject: [PATCH 20/30] Clean up NPOD --- src/algorithms/npod.rs | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/src/algorithms/npod.rs b/src/algorithms/npod.rs index 123ebd0ae..916bdac37 100644 --- a/src/algorithms/npod.rs +++ b/src/algorithms/npod.rs @@ -8,7 +8,7 @@ use anyhow::bail; use anyhow::Result; use pharmsol::{ prelude::{ - data::{Data, ErrorModel, ErrorType}, + data::{Data, ErrorModel}, simulator::{psi, Equation}, }, Subject, Theta, @@ -39,11 +39,9 @@ pub struct NPOD { cycle: usize, gamma_delta: f64, gamma: f64, - error_type: ErrorType, converged: bool, cycle_log: CycleLog, data: Data, - c: (f64, f64, f64, f64), settings: Settings, } @@ -60,10 +58,8 @@ impl NonParametric for NPOD { cycle: 0, gamma_delta: 0.1, gamma: settings.error().value, - error_type: settings.error().error_type(), converged: false, cycle_log: CycleLog::new(), - c: settings.error().poly, settings, data, })) @@ -165,7 +161,11 @@ impl NonParametric for NPOD { &self.equation, &self.data, &theta, - &ErrorModel::new(self.c, self.gamma, &self.error_type), + &ErrorModel::new( + self.settings.error().poly, + self.gamma, + &self.settings.error().class, + ), self.cycle == 1, self.cycle != 1, ); @@ -251,7 +251,11 @@ impl NonParametric for NPOD { &self.equation, &self.data, &theta, - &ErrorModel::new(self.c, gamma_up, &self.error_type), + &ErrorModel::new( + self.settings.error().poly, + self.gamma, + &self.settings.error().class, + ), false, true, ); @@ -259,7 +263,11 @@ impl NonParametric for NPOD { &self.equation, &self.data, &theta, - &ErrorModel::new(self.c, gamma_down, &self.error_type), + &ErrorModel::new( + self.settings.error().poly, + self.gamma, + &self.settings.error().class, + ), false, true, ); @@ -324,7 +332,11 @@ impl NonParametric for NPOD { let pyl = self.psi.dot(&self.w); // Add new point to theta based on the optimization of the D function - let sigma = ErrorModel::new(self.c, self.gamma, &self.error_type); + let sigma = ErrorModel::new( + self.settings.error().poly, + self.gamma, + &self.settings.error().class, + ); let mut candididate_points: Vec> = Vec::default(); for spp in self.theta.clone().rows() { From 419cad4195d2759bbff726b7d1ddcf0ac889573a Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Tue, 24 Dec 2024 12:49:52 +0100 Subject: [PATCH 21/30] Typo --- src/algorithms/map.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/algorithms/map.rs b/src/algorithms/map.rs index aa93a6a52..c98010f58 100644 --- a/src/algorithms/map.rs +++ b/src/algorithms/map.rs @@ -12,7 +12,7 @@ use ndarray::{Array1, Array2}; use super::{initialization, output::CycleLog, NonParametric}; -/// Maximim a posteriori (MAP) estimation +/// Maximum a posteriori (MAP) estimation /// /// Calculate the MAP estimate of the parameters of the model given the data. pub struct MAP { From 2cabcf860bcac40a99ed2c3b47ec09fc187a7484 Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Tue, 24 Dec 2024 13:44:04 +0100 Subject: [PATCH 22/30] Support typestate builder for settings --- examples/debug.rs | 2 +- src/algorithms/map.rs | 4 +- src/algorithms/mod.rs | 33 ++++-- src/algorithms/npag.rs | 4 +- src/algorithms/npod.rs | 4 +- src/algorithms/routines/settings.rs | 177 +++++++++++++++++++++++++++- 6 files changed, 205 insertions(+), 19 deletions(-) diff --git a/examples/debug.rs b/examples/debug.rs index 303856d56..925e1acab 100644 --- a/examples/debug.rs +++ b/examples/debug.rs @@ -1,4 +1,4 @@ -use algorithms::{npag::NPAG, NonParametric}; +use algorithms::{npag::NPAG, NonParametricAlgorithm}; use ipm::burke; use logger::setup_log; use pmcore::prelude::*; diff --git a/src/algorithms/map.rs b/src/algorithms/map.rs index c98010f58..8a336cff1 100644 --- a/src/algorithms/map.rs +++ b/src/algorithms/map.rs @@ -10,7 +10,7 @@ use pharmsol::{ use ndarray::{Array1, Array2}; -use super::{initialization, output::CycleLog, NonParametric}; +use super::{initialization, output::CycleLog, NonParametricAlgorithm}; /// Maximum a posteriori (MAP) estimation /// @@ -32,7 +32,7 @@ pub struct MAP { cyclelog: CycleLog, } -impl NonParametric for MAP { +impl NonParametricAlgorithm for MAP { fn new(settings: Settings, equation: E, data: Data) -> Result, anyhow::Error> { Ok(Box::new(Self { equation, diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index 0a4adb789..12b41b336 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -3,7 +3,7 @@ use std::path::Path; use crate::prelude::{self, settings::Settings}; -use anyhow::Result; +use anyhow::{bail, Result}; use anyhow::{Context, Error}; use map::MAP; use ndarray::Array2; @@ -21,19 +21,29 @@ pub mod npod; pub mod routines; /// Supported algorithms by `PMcore` -/// -/// - `NPAG`: Non-Parametric Adaptive Grid -/// - `NPOD`: Non-Parametric Optimal Design -/// - `MAP`: Maximum A Posteriori -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd)] pub enum Algorithm { + NonParametric(NonParametric), + Parametric(Parametric), +} + +/// Supported non-parametric algorithms +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd)] +pub enum NonParametric { NPAG, NPOD, MAP, } +/// Supported parametric algorithms +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd)] +pub enum Parametric { + FOCE, + NPSA, +} + /// This traint defines the methods for non-parametric (NP) algorithms -pub trait NonParametric { +pub trait NonParametricAlgorithm { fn new(config: Settings, equation: E, data: Data) -> Result, Error> where Self: Sized; @@ -116,10 +126,11 @@ pub fn dispatch_algorithm( settings: Settings, equation: E, data: Data, -) -> Result>, Error> { +) -> Result>> { match settings.config().algorithm { - Algorithm::NPAG => Ok(NPAG::new(settings, equation, data)?), - Algorithm::NPOD => Ok(NPOD::new(settings, equation, data)?), - Algorithm::MAP => Ok(MAP::new(settings, equation, data)?), + Algorithm::NonParametric(NonParametric::NPAG) => Ok(NPAG::new(settings, equation, data)?), + Algorithm::NonParametric(NonParametric::NPOD) => Ok(NPOD::new(settings, equation, data)?), + Algorithm::NonParametric(NonParametric::MAP) => Ok(MAP::new(settings, equation, data)?), + _ => bail!("Unsupported algorithm"), } } diff --git a/src/algorithms/npag.rs b/src/algorithms/npag.rs index 96c687e3e..95390fc1e 100644 --- a/src/algorithms/npag.rs +++ b/src/algorithms/npag.rs @@ -17,7 +17,7 @@ use pharmsol::{ use ndarray::{Array, Array1, Array2, ArrayBase, Axis, Dim, OwnedRepr}; use ndarray_stats::{DeviationExt, QuantileExt}; -use super::{adaptive_grid::adaptive_grid, initialization, NonParametric}; +use super::{adaptive_grid::adaptive_grid, initialization, NonParametricAlgorithm}; const THETA_E: f64 = 1e-4; // Convergence criteria const THETA_G: f64 = 1e-4; // Objective function convergence criteria @@ -45,7 +45,7 @@ pub struct NPAG { settings: Settings, } -impl NonParametric for NPAG { +impl NonParametricAlgorithm for NPAG { fn new(settings: Settings, equation: E, data: Data) -> Result, anyhow::Error> { Ok(Box::new(Self { equation, diff --git a/src/algorithms/npod.rs b/src/algorithms/npod.rs index 916bdac37..3f80a4df1 100644 --- a/src/algorithms/npod.rs +++ b/src/algorithms/npod.rs @@ -22,7 +22,7 @@ use ndarray_stats::{DeviationExt, QuantileExt}; use super::{ condensation::prune::prune, initialization, optimization::d_optimizer::SppOptimizer, - NonParametric, + NonParametricAlgorithm, }; const THETA_F: f64 = 1e-2; @@ -45,7 +45,7 @@ pub struct NPOD { settings: Settings, } -impl NonParametric for NPOD { +impl NonParametricAlgorithm for NPOD { fn new(settings: Settings, equation: E, data: Data) -> Result, anyhow::Error> { Ok(Box::new(Self { equation, diff --git a/src/algorithms/routines/settings.rs b/src/algorithms/routines/settings.rs index 1734bfdcd..5b2d6b2ce 100644 --- a/src/algorithms/routines/settings.rs +++ b/src/algorithms/routines/settings.rs @@ -239,7 +239,7 @@ impl Default for Config { fn default() -> Self { Config { cycles: 100, - algorithm: Algorithm::NPAG, + algorithm: Algorithm::NonParametric(crate::algorithms::NonParametric::NPAG), cache: true, include: None, exclude: None, @@ -659,3 +659,178 @@ pub fn read(path: impl Into) -> Result { Ok(settings) // Return the settings wrapped in Ok } + +pub struct SettingsBuilder { + config: Option, + parameters: Option, + error: Option, + predictions: Option, + log: Option, + prior: Option, + output: Option, + convergence: Option, + advanced: Option, + _marker: std::marker::PhantomData, +} + +// Marker traits for builder states +pub trait AlgorithmDefined {} +pub trait ParametersDefined {} +pub trait ErrorModelDefined {} + +// Implement marker traits for PhantomData states +pub struct InitialState; +pub struct AlgorithmSet; +pub struct ParametersSet; +pub struct ErrorSet; + +// Initial state: no algorithm set yet +impl SettingsBuilder { + pub fn new() -> Self { + SettingsBuilder { + config: None, + parameters: None, + error: None, + predictions: None, + log: None, + prior: None, + output: None, + convergence: None, + advanced: None, + _marker: std::marker::PhantomData, + } + } + + pub fn set_algorithm(self, algorithm: Algorithm) -> SettingsBuilder { + SettingsBuilder { + config: Some(Config { + algorithm, + ..Config::default() + }), + parameters: self.parameters, + error: self.error, + predictions: self.predictions, + log: self.log, + prior: self.prior, + output: self.output, + convergence: self.convergence, + advanced: self.advanced, + _marker: std::marker::PhantomData, + } + } +} + +// Algorithm is set, move to defining parameters +impl SettingsBuilder { + pub fn set_parameters(self, parameters: Parameters) -> SettingsBuilder { + SettingsBuilder { + config: self.config, + parameters: Some(parameters), + error: self.error, + predictions: self.predictions, + log: self.log, + prior: self.prior, + output: self.output, + convergence: self.convergence, + advanced: self.advanced, + _marker: std::marker::PhantomData, + } + } +} + +// Parameters are set, move to defining error model +impl SettingsBuilder { + pub fn set_error_model(self, error: Error) -> SettingsBuilder { + SettingsBuilder { + config: self.config, + parameters: self.parameters, + error: Some(error), + predictions: self.predictions, + log: self.log, + prior: self.prior, + output: self.output, + convergence: self.convergence, + advanced: self.advanced, + _marker: std::marker::PhantomData, + } + } +} + +// Error model is set, allow optional settings and final build +impl SettingsBuilder { + pub fn set_predictions(mut self, predictions: Predictions) -> Self { + self.predictions = Some(predictions); + self + } + + pub fn set_log(mut self, log: Log) -> Self { + self.log = Some(log); + self + } + + pub fn set_prior(mut self, prior: Prior) -> Self { + self.prior = Some(prior); + self + } + + pub fn set_output(mut self, output: Output) -> Self { + self.output = Some(output); + self + } + + pub fn set_convergence(mut self, convergence: Convergence) -> Self { + self.convergence = Some(convergence); + self + } + + pub fn set_advanced(mut self, advanced: Advanced) -> Self { + self.advanced = Some(advanced); + self + } + + pub fn build(self) -> Settings { + Settings { + config: self.config.unwrap(), + parameters: self.parameters.unwrap(), + error: self.error.unwrap(), + predictions: self.predictions.unwrap_or_default(), + log: self.log.unwrap_or_default(), + prior: self.prior.unwrap_or_default(), + output: self.output.unwrap_or_default(), + convergence: self.convergence.unwrap_or_default(), + advanced: self.advanced.unwrap_or_default(), + } + } +} + +#[cfg(test)] + +mod tests { + use super::*; + use crate::algorithms::{Algorithm, NonParametric}; + use pharmsol::prelude::data::ErrorType; + + #[test] + fn test_builder() { + let parameters = Parameters::new() + .add("Ke", 0.0, 5.0, false) + .unwrap() + .add("V", 10.0, 200.0, true) + .unwrap(); + + let settings = SettingsBuilder::new() + .set_algorithm(Algorithm::NonParametric(NonParametric::NPAG)) // Step 1: Define algorithm + .set_parameters(parameters) // Step 2: Define parameters + .set_error_model(Error { + value: 0.1, + class: ErrorType::Additive, + poly: (0.0, 0.1, 0.0, 0.0), + }) // Step 3: Define error model + .build(); // Final step + + assert_eq!( + settings.config.algorithm, + Algorithm::NonParametric(NonParametric::NPAG,) + ); + } +} From 9b0d809712e5a65ab04e980bb137d3be69f577a2 Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Tue, 24 Dec 2024 13:48:19 +0100 Subject: [PATCH 23/30] Remove include/exclude from settings These are already available directly on the pharmsol::data::Data structure --- src/algorithms/routines/settings.rs | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/src/algorithms/routines/settings.rs b/src/algorithms/routines/settings.rs index 5b2d6b2ce..500dac152 100644 --- a/src/algorithms/routines/settings.rs +++ b/src/algorithms/routines/settings.rs @@ -146,14 +146,6 @@ impl Settings { self.config.cache = cache; } - pub fn set_include(&mut self, include: Option>) { - self.config.include = include; - } - - pub fn set_exclude(&mut self, exclude: Option>) { - self.config.exclude = exclude; - } - pub fn set_gamlam(&mut self, value: f64) { self.error.value = value; } @@ -229,10 +221,6 @@ pub struct Config { pub algorithm: Algorithm, /// If true (default), cache predicted values pub cache: bool, - /// Vector of IDs to include - pub include: Option>, - /// Vector of IDs to exclude - pub exclude: Option>, } impl Default for Config { @@ -241,8 +229,6 @@ impl Default for Config { cycles: 100, algorithm: Algorithm::NonParametric(crate::algorithms::NonParametric::NPAG), cache: true, - include: None, - exclude: None, } } } From 96e84f732b76f45c61fa4ceecc0cba2d08bf341c Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Tue, 24 Dec 2024 14:50:13 +0100 Subject: [PATCH 24/30] Experimenting with algorithm choice and dispatch --- src/algorithms/map.rs | 1 + src/algorithms/mod.rs | 33 ++++++++++++----------------- src/algorithms/npag.rs | 2 +- src/algorithms/npod.rs | 1 + src/algorithms/routines/settings.rs | 11 ++++------ 5 files changed, 20 insertions(+), 28 deletions(-) diff --git a/src/algorithms/map.rs b/src/algorithms/map.rs index 8a336cff1..77463b096 100644 --- a/src/algorithms/map.rs +++ b/src/algorithms/map.rs @@ -15,6 +15,7 @@ use super::{initialization, output::CycleLog, NonParametricAlgorithm}; /// Maximum a posteriori (MAP) estimation /// /// Calculate the MAP estimate of the parameters of the model given the data. +#[derive(Debug, Clone)] pub struct MAP { equation: E, psi: Array2, diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index 12b41b336..f1fbb0cf1 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -3,7 +3,7 @@ use std::path::Path; use crate::prelude::{self, settings::Settings}; -use anyhow::{bail, Result}; +use anyhow::Result; use anyhow::{Context, Error}; use map::MAP; use ndarray::Array2; @@ -23,26 +23,14 @@ pub mod routines; /// Supported algorithms by `PMcore` #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd)] pub enum Algorithm { - NonParametric(NonParametric), - Parametric(Parametric), -} - -/// Supported non-parametric algorithms -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd)] -pub enum NonParametric { + // Non-parametric algorithms NPAG, NPOD, MAP, + // Parametric algorithms } -/// Supported parametric algorithms -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd)] -pub enum Parametric { - FOCE, - NPSA, -} - -/// This traint defines the methods for non-parametric (NP) algorithms +/// This trait defines the methods for non-parametric (NP) algorithms pub trait NonParametricAlgorithm { fn new(config: Settings, equation: E, data: Data) -> Result, Error> where @@ -122,15 +110,20 @@ pub trait NonParametricAlgorithm { fn into_npresult(&self) -> NPResult; } +pub trait ParametricAlgorithm { + fn fit(&mut self) -> Result<()> { + unimplemented!() + } +} + pub fn dispatch_algorithm( settings: Settings, equation: E, data: Data, ) -> Result>> { match settings.config().algorithm { - Algorithm::NonParametric(NonParametric::NPAG) => Ok(NPAG::new(settings, equation, data)?), - Algorithm::NonParametric(NonParametric::NPOD) => Ok(NPOD::new(settings, equation, data)?), - Algorithm::NonParametric(NonParametric::MAP) => Ok(MAP::new(settings, equation, data)?), - _ => bail!("Unsupported algorithm"), + Algorithm::NPAG => Ok(NPAG::new(settings, equation, data)?), + Algorithm::NPOD => Ok(NPOD::new(settings, equation, data)?), + Algorithm::MAP => Ok(MAP::new(settings, equation, data)?), } } diff --git a/src/algorithms/npag.rs b/src/algorithms/npag.rs index 95390fc1e..8b0f71cfb 100644 --- a/src/algorithms/npag.rs +++ b/src/algorithms/npag.rs @@ -24,7 +24,7 @@ const THETA_G: f64 = 1e-4; // Objective function convergence criteria const THETA_F: f64 = 1e-2; const THETA_D: f64 = 1e-4; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct NPAG { equation: E, psi: Array2, diff --git a/src/algorithms/npod.rs b/src/algorithms/npod.rs index 3f80a4df1..f3817c4df 100644 --- a/src/algorithms/npod.rs +++ b/src/algorithms/npod.rs @@ -28,6 +28,7 @@ use super::{ const THETA_F: f64 = 1e-2; const THETA_D: f64 = 1e-4; +#[derive(Debug, Clone)] pub struct NPOD { equation: E, psi: Array2, diff --git a/src/algorithms/routines/settings.rs b/src/algorithms/routines/settings.rs index 500dac152..d6ec5fdf4 100644 --- a/src/algorithms/routines/settings.rs +++ b/src/algorithms/routines/settings.rs @@ -227,7 +227,7 @@ impl Default for Config { fn default() -> Self { Config { cycles: 100, - algorithm: Algorithm::NonParametric(crate::algorithms::NonParametric::NPAG), + algorithm: Algorithm::NPAG, cache: true, } } @@ -793,7 +793,7 @@ impl SettingsBuilder { mod tests { use super::*; - use crate::algorithms::{Algorithm, NonParametric}; + use crate::algorithms::Algorithm; use pharmsol::prelude::data::ErrorType; #[test] @@ -805,7 +805,7 @@ mod tests { .unwrap(); let settings = SettingsBuilder::new() - .set_algorithm(Algorithm::NonParametric(NonParametric::NPAG)) // Step 1: Define algorithm + .set_algorithm(Algorithm::NPAG) // Step 1: Define algorithm .set_parameters(parameters) // Step 2: Define parameters .set_error_model(Error { value: 0.1, @@ -814,9 +814,6 @@ mod tests { }) // Step 3: Define error model .build(); // Final step - assert_eq!( - settings.config.algorithm, - Algorithm::NonParametric(NonParametric::NPAG,) - ); + assert_eq!(settings.config.algorithm, Algorithm::NPAG); } } From dbfe41cdce9d9bd9fc1d28c296f1a92993e318dd Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Wed, 25 Dec 2024 16:12:45 +0100 Subject: [PATCH 25/30] Clean up arguments to MAP --- src/algorithms/map.rs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/algorithms/map.rs b/src/algorithms/map.rs index 77463b096..e34e76c4b 100644 --- a/src/algorithms/map.rs +++ b/src/algorithms/map.rs @@ -2,7 +2,7 @@ use crate::prelude::{ipm::burke, output::NPResult, settings::Settings}; use anyhow::Result; use pharmsol::{ prelude::{ - data::{Data, ErrorModel, ErrorType}, + data::{Data, ErrorModel}, simulator::{psi, Equation}, }, Theta, @@ -25,10 +25,7 @@ pub struct MAP { cycle: usize, converged: bool, gamma: f64, - error_type: ErrorType, data: Data, - c: (f64, f64, f64, f64), - #[allow(dead_code)] settings: Settings, cyclelog: CycleLog, } @@ -44,8 +41,6 @@ impl NonParametricAlgorithm for MAP { cycle: 0, converged: false, gamma: settings.error().value, - error_type: settings.error().error_type(), - c: settings.error().poly, settings, data, @@ -114,7 +109,11 @@ impl NonParametricAlgorithm for MAP { &self.equation, &self.data, &theta, - &ErrorModel::new(self.c, self.gamma, &self.error_type), + &ErrorModel::new( + self.settings.error().poly, + self.gamma, + &self.settings.error().error_type(), + ), false, false, ); From 09b9b342bf1dc6eb754d423e6fed7ecc3c6e58ca Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Wed, 25 Dec 2024 16:13:03 +0100 Subject: [PATCH 26/30] Minor changes to logger helpers --- src/logger.rs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/logger.rs b/src/logger.rs index 9ad11ae1d..4605658ed 100644 --- a/src/logger.rs +++ b/src/logger.rs @@ -26,9 +26,7 @@ pub fn setup_log(settings: &Settings) -> Result<()> { let log_level: String = settings.log().level.to_string(); let env_filter = EnvFilter::new(log_level); - let timestamper = CompactTimestamp { - start: Instant::now(), - }; + let timestamper = CompactTimestamp::new(); // Define a registry with that level as an environment filter let subscriber = Registry::default().with(env_filter); @@ -60,6 +58,14 @@ struct CompactTimestamp { start: Instant, } +impl CompactTimestamp { + fn new() -> Self { + Self { + start: Instant::now(), + } + } +} + impl FormatTime for CompactTimestamp { fn format_time( &self, From b683b723434555f50f3f0cb56764d8c1ed03c537 Mon Sep 17 00:00:00 2001 From: Markus Date: Sat, 28 Dec 2024 12:43:00 +0100 Subject: [PATCH 27/30] Minor docs update --- src/algorithms/mod.rs | 5 +++-- src/algorithms/npag.rs | 1 + src/algorithms/npod.rs | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index f1fbb0cf1..454945a3e 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -23,11 +23,12 @@ pub mod routines; /// Supported algorithms by `PMcore` #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd)] pub enum Algorithm { - // Non-parametric algorithms + /// Non-parametric adaptive grid (NPAG), see [NPAG] NPAG, + /// Non-parametric adaptive grid (NPAG), see [NPOD] NPOD, + /// Maximum a posteriori estimation, see [MAP] MAP, - // Parametric algorithms } /// This trait defines the methods for non-parametric (NP) algorithms diff --git a/src/algorithms/npag.rs b/src/algorithms/npag.rs index 8b0f71cfb..f5df6673f 100644 --- a/src/algorithms/npag.rs +++ b/src/algorithms/npag.rs @@ -24,6 +24,7 @@ const THETA_G: f64 = 1e-4; // Objective function convergence criteria const THETA_F: f64 = 1e-2; const THETA_D: f64 = 1e-4; +/// Non-parametric adaptive grid (NPAG) algorithm #[derive(Debug, Clone)] pub struct NPAG { equation: E, diff --git a/src/algorithms/npod.rs b/src/algorithms/npod.rs index f3817c4df..a039e8beb 100644 --- a/src/algorithms/npod.rs +++ b/src/algorithms/npod.rs @@ -28,6 +28,7 @@ use super::{ const THETA_F: f64 = 1e-2; const THETA_D: f64 = 1e-4; +/// Non-parametric optimal design (NPOD) algorithm #[derive(Debug, Clone)] pub struct NPOD { equation: E, From 712cb5b10cdf6d8db0b737acb6afda6044a0b713 Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Wed, 8 Jan 2025 13:47:06 +0100 Subject: [PATCH 28/30] Add more methods --- src/algorithms/routines/settings.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/algorithms/routines/settings.rs b/src/algorithms/routines/settings.rs index d6ec5fdf4..679ca1d81 100644 --- a/src/algorithms/routines/settings.rs +++ b/src/algorithms/routines/settings.rs @@ -744,6 +744,16 @@ impl SettingsBuilder { // Error model is set, allow optional settings and final build impl SettingsBuilder { + pub fn set_cycles(mut self, cycles: usize) -> Self { + self.config.as_mut().unwrap().cycles = cycles; + self + } + + pub fn set_cache(mut self, cache: bool) -> Self { + self.config.as_mut().unwrap().cache = cache; + self + } + pub fn set_predictions(mut self, predictions: Predictions) -> Self { self.predictions = Some(predictions); self From cc10fc8389a6b542988ae5dc8c501f0c5e76a497 Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Wed, 8 Jan 2025 13:48:42 +0100 Subject: [PATCH 29/30] Clippy --- src/algorithms/routines/settings.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/algorithms/routines/settings.rs b/src/algorithms/routines/settings.rs index 679ca1d81..b8a335e76 100644 --- a/src/algorithms/routines/settings.rs +++ b/src/algorithms/routines/settings.rs @@ -706,6 +706,12 @@ impl SettingsBuilder { } } +impl Default for SettingsBuilder { + fn default() -> Self { + SettingsBuilder::new() + } +} + // Algorithm is set, move to defining parameters impl SettingsBuilder { pub fn set_parameters(self, parameters: Parameters) -> SettingsBuilder { From e294ae8e82aa9b76c81a7ef32b65eb70a14c977a Mon Sep 17 00:00:00 2001 From: Markus <66058642+mhovd@users.noreply.github.com> Date: Wed, 8 Jan 2025 13:51:10 +0100 Subject: [PATCH 30/30] Expand test --- src/algorithms/routines/settings.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/algorithms/routines/settings.rs b/src/algorithms/routines/settings.rs index b8a335e76..6ea2363e6 100644 --- a/src/algorithms/routines/settings.rs +++ b/src/algorithms/routines/settings.rs @@ -828,8 +828,13 @@ mod tests { class: ErrorType::Additive, poly: (0.0, 0.1, 0.0, 0.0), }) // Step 3: Define error model + .set_cycles(100) // Optional: Set cycles + .set_cache(true) // Optional: Set cache .build(); // Final step assert_eq!(settings.config.algorithm, Algorithm::NPAG); + assert_eq!(settings.config.cycles, 100); + assert_eq!(settings.config.cache, true); + assert_eq!(settings.parameters().names(), vec!["Ke", "V"]); } }