Skip to content

Commit

Permalink
Improved spectator examples
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan-k committed Jan 19, 2024
1 parent 26705dd commit 8ec9317
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 132 deletions.
2 changes: 1 addition & 1 deletion examples/simulatedannealing/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ struct Rosenbrock {
lower_bound: Vec<f64>,
/// upper bound
upper_bound: Vec<f64>,
/// Random number generator. We use a `Arc<Mutex<_>>` here because `ArgminOperator` requires
/// Random number generator. We use a `Arc<Mutex<_>>` here because `Anneal` requires
/// `self` to be passed as an immutable reference. This gives us thread safe interior
/// mutability.
rng: Arc<Mutex<Xoshiro256PlusPlus>>,
Expand Down
43 changes: 4 additions & 39 deletions examples/spectator_basic/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ struct Rosenbrock {
lower_bound: Vec<f64>,
/// upper bound
upper_bound: Vec<f64>,
/// Random number generator. We use a `Arc<Mutex<_>>` here because `ArgminOperator` requires
/// Random number generator. We use a `Arc<Mutex<_>>` here because `Anneal` requires
/// `self` to be passed as an immutable reference. This gives us thread safe interior
/// mutability.
rng: Arc<Mutex<Xoshiro256PlusPlus>>,
Expand All @@ -48,7 +48,6 @@ impl CostFunction for Rosenbrock {
type Output = f64;

fn cost(&self, param: &Self::Param) -> Result<Self::Output, Error> {
// std::thread::sleep(std::time::Duration::from_millis(5));
Ok(rosenbrock(param, self.a, self.b))
}
}
Expand All @@ -63,19 +62,10 @@ impl Anneal for Rosenbrock {
let mut param_n = param.clone();
let mut rng = self.rng.lock().unwrap();
let distr = Uniform::from(0..param.len());
// Perform modifications to a degree proportional to the current temperature `temp`.
for _ in 0..(temp.floor() as u64 + 1) {
// Compute random index of the parameter vector using the supplied random number
// generator.
let idx = rng.sample(distr);

// Compute random number in [0.1, 0.1].
let val = rng.sample(Uniform::new_inclusive(-0.0001, 0.0001));

// modify previous parameter value at random position `idx` by `val`
param_n[idx] += val;

// check if bounds are violated. If yes, project onto bound.
param_n[idx] = param_n[idx].clamp(self.lower_bound[idx], self.upper_bound[idx]);
}
Ok(param_n)
Expand All @@ -95,34 +85,13 @@ fn run() -> Result<(), Error> {
let init_param: Vec<f64> = vec![-0.9; num];

// Define initial temperature
// let temp = 0.0001;
let temp = 1000.0;

// Set up simulated annealing solver
// An alternative random number generator (RNG) can be provided to `new_with_rng`:
// SimulatedAnnealing::new_with_rng(temp, Xoshiro256PlusPlus::from_entropy())?
let solver = SimulatedAnnealing::new(temp)?
// Optional: Define temperature function (defaults to `SATempFunc::TemperatureFast`)
.with_temp_func(SATempFunc::Boltzmann);
/////////////////////////
// Stopping criteria //
/////////////////////////
// Optional: stop if there was no new best solution after 1000 iterations
// .with_stall_best(1000)
// Optional: stop if there was no accepted solution after 1000 iterations
// .with_stall_accepted(1000)
/////////////////////////
// Reannealing //
/////////////////////////
// Optional: Reanneal after 1000 iterations (resets temperature to initial temperature)
// .with_reannealing_fixed(1000)
// Optional: Reanneal after no accepted solution has been found for `iter` iterations
// .with_reannealing_accepted(1000)
// Optional: Start reannealing after no new best solution has been found for 800 iterations
// .with_reannealing_best(1000);
let solver = SimulatedAnnealing::new(temp)?.with_temp_func(SATempFunc::Boltzmann);

let spectator = SpectatorBuilder::new()
// .with_name("something")
// .with_name("name_your_run")
.select(&["cost", "best_cost", "t"])
.build();

Expand All @@ -133,17 +102,13 @@ fn run() -> Result<(), Error> {
.configure(|state| {
state
.param(init_param)
// Optional: Set maximum number of iterations (defaults to `std::u64::MAX`)
.max_iters(1_000_000)
// Optional: Set target cost function value (defaults to `std::f64::NEG_INFINITY`)
.target_cost(0.0)
})
// Add spectator observer
.add_observer(spectator, ObserverMode::Always)
.run()?;

// Wait a second (lets the logger flush everything before printing again)
std::thread::sleep(std::time::Duration::from_secs(1));

// Print result
println!("{res}");
Ok(())
Expand Down
165 changes: 77 additions & 88 deletions examples/spectator_multiple/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ struct Rosenbrock {
lower_bound: Vec<f64>,
/// upper bound
upper_bound: Vec<f64>,
/// Random number generator. We use a `Arc<Mutex<_>>` here because `ArgminOperator` requires
/// Random number generator. We use a `Arc<Mutex<_>>` here because `Anneal` requires
/// `self` to be passed as an immutable reference. This gives us thread safe interior
/// mutability.
rng: Arc<Mutex<Xoshiro256PlusPlus>>,
Expand All @@ -49,6 +49,8 @@ impl CostFunction for Rosenbrock {
type Output = f64;

fn cost(&self, param: &Self::Param) -> Result<Self::Output, Error> {
// Artificially slow down computation of cost function
std::thread::sleep(std::time::Duration::from_millis(1));
Ok(rosenbrock(param, self.a, self.b))
}
}
Expand All @@ -63,104 +65,91 @@ impl Anneal for Rosenbrock {
let mut param_n = param.clone();
let mut rng = self.rng.lock().unwrap();
let distr = Uniform::from(0..param.len());
// Perform modifications to a degree proportional to the current temperature `temp`.
for _ in 0..(temp.floor() as u64 + 1) {
// Compute random index of the parameter vector using the supplied random number
// generator.
let idx = rng.sample(distr);

// Compute random number in [0.1, 0.1].
let val = rng.sample(Uniform::new_inclusive(-0.1, 0.1));

// modify previous parameter value at random position `idx` by `val`
param_n[idx] += val;

// check if bounds are violated. If yes, project onto bound.
param_n[idx] = param_n[idx].clamp(self.lower_bound[idx], self.upper_bound[idx]);
}
Ok(param_n)
}
}

fn run() -> Result<(), Error> {
// Define bounds
// let lower_bound: Vec<f64> = vec![-50.0, -50.0];
// let upper_bound: Vec<f64> = vec![50.0, 50.0];
let lower_bound: Vec<f64> = vec![-50.0; 20];
let upper_bound: Vec<f64> = vec![50.0; 20];

// Define cost function
let operator = Rosenbrock::new(1.0, 100.0, lower_bound, upper_bound);

// Define initial parameter vector
let init_param: Vec<f64> = vec![3.0; 20];

// Define initial temperature
let temp = 1500.0;

// Set up simulated annealing solver
// An alternative random number generator (RNG) can be provided to `new_with_rng`:
// SimulatedAnnealing::new_with_rng(temp, Xoshiro256PlusPlus::from_entropy())?
let solver = SimulatedAnnealing::new(temp)?
// Optional: Define temperature function (defaults to `SATempFunc::TemperatureFast`)
.with_temp_func(SATempFunc::Boltzmann)
/////////////////////////
// Stopping criteria //
/////////////////////////
// Optional: stop if there was no new best solution after 1000 iterations
.with_stall_best(1000)
// Optional: stop if there was no accepted solution after 1000 iterations
.with_stall_accepted(1000)
/////////////////////////
// Reannealing //
/////////////////////////
// Optional: Reanneal after 1000 iterations (resets temperature to initial temperature)
.with_reannealing_fixed(1000)
// Optional: Reanneal after no accepted solution has been found for `iter` iterations
.with_reannealing_accepted(500)
// Optional: Start reannealing after no new best solution has been found for 800 iterations
.with_reannealing_best(800);

/////////////////////////
// Run solver //
/////////////////////////
let observer = SpectatorBuilder::new().build();
let res = Executor::new(operator.clone(), solver.clone())
.configure(|state| {
state
.param(init_param.clone())
// Optional: Set maximum number of iterations (defaults to `std::u64::MAX`)
.max_iters(10_000)
// Optional: Set target cost function value (defaults to `std::f64::NEG_INFINITY`)
.target_cost(0.0)
})
.add_observer(observer, ObserverMode::Always)
.run()?;

// Wait a second (lets the logger flush everything before printing again)
std::thread::sleep(std::time::Duration::from_secs(1));

// Print result
println!("{res}");

let observer = SpectatorBuilder::new().build();
let res = Executor::new(operator, solver)
.configure(|state| {
state
.param(init_param)
// Optional: Set maximum number of iterations (defaults to `std::u64::MAX`)
.max_iters(10_000)
// Optional: Set target cost function value (defaults to `std::f64::NEG_INFINITY`)
.target_cost(0.0)
})
.add_observer(observer, ObserverMode::Always)
.run()?;

// Wait a second (lets the logger flush everything before printing again)
std::thread::sleep(std::time::Duration::from_secs(1));

// Print result
println!("{res}");
std::thread::scope(move |s| {
s.spawn(move || {
/////////////////////////
// Run solver 1 //
/////////////////////////

let lower_bound: Vec<f64> = vec![-50.0; 5];
let upper_bound: Vec<f64> = vec![50.0; 5];

let cost = Rosenbrock::new(1.0, 100.0, lower_bound, upper_bound);

// Define initial parameter vector
let init_param: Vec<f64> = vec![3.0; 5];

// Define initial temperature
let temp = 20.0;

// Set up simulated annealing solver
let solver = SimulatedAnnealing::new(temp)
.unwrap()
.with_temp_func(SATempFunc::Boltzmann);

let observer = SpectatorBuilder::new().build();
let res = Executor::new(cost, solver)
.configure(|state| {
state
.param(init_param.clone())
.max_iters(100_000)
.target_cost(0.0)
})
.add_observer(observer, ObserverMode::Always)
.run()
.unwrap();

// Print result
println!("{res}");
});

s.spawn(|| {
/////////////////////////
// Run solver 2 //
/////////////////////////
let lower_bound: Vec<f64> = vec![-50.0; 5];
let upper_bound: Vec<f64> = vec![50.0; 5];

let cost = Rosenbrock::new(1.0, 100.0, lower_bound, upper_bound);

// Define initial parameter vector
let init_param: Vec<f64> = vec![3.0; 5];

// Define initial temperature
let temp = 2.0;

// Set up simulated annealing solver
let solver = SimulatedAnnealing::new(temp)
.unwrap()
.with_temp_func(SATempFunc::Boltzmann);

let observer = SpectatorBuilder::new().build();
let res = Executor::new(cost, solver)
.configure(|state| {
state
.param(init_param.clone())
.max_iters(100_000)
.target_cost(0.0)
})
.add_observer(observer, ObserverMode::Always)
.run()
.unwrap();

// Print result
println!("{res}");
});
});
Ok(())
}

Expand Down
18 changes: 16 additions & 2 deletions observers/spectator/src/observer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use std::collections::HashSet;
use std::{collections::HashSet, thread::JoinHandle};

use anyhow::Error;
use argmin::core::{observers::Observe, ArgminFloat, State, KV};
Expand Down Expand Up @@ -234,13 +234,14 @@ impl SpectatorBuilder {
/// ```
pub fn build(self) -> Spectator {
let (tx, rx) = tokio::sync::mpsc::channel(self.capacity);
std::thread::spawn(move || sender(rx, self.host, self.port));
let thread_handle = std::thread::spawn(move || sender(rx, self.host, self.port));

Spectator {
tx,
name: self.name,
sending: true,
selected: self.selected,
thread_handle: Some(thread_handle),
}
}
}
Expand All @@ -253,6 +254,7 @@ pub struct Spectator {
name: String,
sending: bool,
selected: HashSet<String>,
thread_handle: Option<JoinHandle<Result<(), Error>>>,
}

impl Spectator {
Expand Down Expand Up @@ -376,6 +378,7 @@ where
Ok(())
}

/// Forwards termination reason to spectator
fn observe_final(&mut self, state: &I) -> Result<(), Error> {
let message = Message::Termination {
name: self.name.clone(),
Expand All @@ -385,3 +388,14 @@ where
Ok(())
}
}

impl Drop for Spectator {
fn drop(&mut self) {
self.thread_handle
.take()
.map(JoinHandle::join)
.unwrap()
.unwrap()
.unwrap();
}
}
6 changes: 4 additions & 2 deletions observers/spectator/src/sender.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ pub(crate) async fn sender(
if let Ok(stream) = TcpStream::connect(format!("{host}:{port}")).await {
let mut stream = Framed::new(stream, codec);
while let Some(msg) = rx.recv().await {
let msg = msg.pack()?;
stream.send(msg).await?;
stream.send(msg.pack()?).await?;
if let Message::Termination { .. } = msg {
return Ok(());
}
}
} else {
eprintln!("Can't connect to spectator on {host}:{port}");
Expand Down

0 comments on commit 8ec9317

Please sign in to comment.