editoast: fix batch simulation
We reduce the number of queries executed by the batch simulation function.
Previously, for `N` trains, we executed `N × 4` Postgres queries.
Now we run ~4 batch queries.
flomonster committed Jul 4, 2024
1 parent 2c70a65 commit dc11fc8
Showing 7 changed files with 265 additions and 42 deletions.
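The change hinges on a new `TrainScheduleProxy` shared across the batch: it is pre-filled from batched queries and memoizes per-train pathfinding results. Its module (`proxy.rs`) is not expanded on this page, so the following is only a minimal sketch of the read-through-cache idea it suggests, with simplified stand-in types rather than the real editoast definitions.

```rust
// Minimal sketch of the caching idea behind `TrainScheduleProxy`.
// All types here are simplified stand-ins, not the actual editoast definitions;
// the real proxy also falls back to database queries on cache misses.
use std::collections::HashMap;
use std::sync::RwLock;

#[derive(Clone, Debug)]
struct RollingStock {
    name: String,
}

#[derive(Clone, Debug)]
struct PathfindingResult(Vec<String>);

#[derive(Default)]
struct Proxy {
    // Pre-filled with one batch query per entity type.
    rolling_stocks: HashMap<String, RollingStock>,
    // Filled lazily, once per train, as paths are computed.
    pathfinding_results: RwLock<HashMap<i64, PathfindingResult>>,
}

impl Proxy {
    fn new(rolling_stocks: &[RollingStock]) -> Self {
        Self {
            rolling_stocks: rolling_stocks
                .iter()
                .map(|rs| (rs.name.clone(), rs.clone()))
                .collect(),
            pathfinding_results: RwLock::default(),
        }
    }

    fn get_rolling_stock(&self, name: &str) -> Option<RollingStock> {
        self.rolling_stocks.get(name).cloned()
    }

    fn get_pathfinding_result(&self, train_id: i64) -> Option<PathfindingResult> {
        self.pathfinding_results.read().unwrap().get(&train_id).cloned()
    }

    fn set_pathfinding_result(&self, train_id: i64, result: PathfindingResult) {
        self.pathfinding_results.write().unwrap().insert(train_id, result);
    }
}

fn main() {
    let proxy = Proxy::new(&[RollingStock { name: "TGV".into() }]);
    assert!(proxy.get_rolling_stock("TGV").is_some()); // hit, no extra query needed
    assert!(proxy.get_pathfinding_result(42).is_none()); // miss until computed
    proxy.set_pathfinding_result(42, PathfindingResult(vec!["block_a".into()]));
    assert!(proxy.get_pathfinding_result(42).is_some());
}
```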
13 changes: 11 additions & 2 deletions editoast/src/modelsv2/timetable.rs
@@ -8,9 +8,9 @@ use crate::error::Result;
use crate::modelsv2::Retrieve;
use editoast_models::DbConnection;

#[derive(Debug, Default, Clone, ModelV2)]
#[derive(Debug, Default, Clone, ModelV2, PartialEq)]
#[model(table = crate::tables::timetable_v2)]
#[cfg_attr(test, derive(serde::Deserialize, PartialEq))]
#[cfg_attr(test, derive(serde::Deserialize))]
pub struct Timetable {
pub id: i64,
pub electrical_profile_set_id: Option<i64>,
@@ -49,3 +49,12 @@ impl Retrieve<i64> for TimetableWithTrains {
}
}
}

impl From<TimetableWithTrains> for Timetable {
fn from(timetable_with_trains: TimetableWithTrains) -> Self {
Self {
id: timetable_with_trains.id,
electrical_profile_set_id: timetable_with_trains.electrical_profile_set_id,
}
}
}
26 changes: 22 additions & 4 deletions editoast/src/views/v2/path/pathfinding.rs
@@ -27,12 +27,12 @@ use crate::modelsv2::OperationalPointModel;
use crate::modelsv2::Retrieve;
use crate::modelsv2::RetrieveBatch;
use crate::modelsv2::RetrieveBatchUnchecked;
use crate::modelsv2::RollingStockModel;
use crate::modelsv2::TrackSectionModel;
use crate::redis_utils::RedisClient;
use crate::redis_utils::RedisConnection;
use crate::views::get_app_version;
use crate::views::v2::path::PathfindingError;
use crate::views::v2::train_schedule::TrainScheduleProxy;
use editoast_models::DbConnection;
use editoast_models::DbConnectionPoolV2;
use editoast_schemas::infra::OperationalPoint;
@@ -159,17 +159,29 @@ async fn pathfinding_blocks(
Ok(pathfinding_result)
}

/// Compute a path given a batch of trainschedule and an infrastructure
/// Compute a path for a train schedule on a given infrastructure.
///
/// ## Important
///
/// If this function has already been called with the same train schedule, the cached result is returned.
/// If you call it again with the same train schedule but a different infra, you must provide a fresh `proxy`.
pub async fn pathfinding_from_train(
conn: &mut DbConnection,
redis: &mut RedisConnection,
core: Arc<CoreClient>,
infra: &Infra,
train_schedule: TrainSchedule,
proxy: Arc<TrainScheduleProxy>,
) -> Result<PathfindingResult> {
if let Some(res) = proxy.get_pathfinding_result(train_schedule.id) {
return Ok(res);
}

// Retrieve rolling stock
let rolling_stock_name = train_schedule.rolling_stock_name.clone();
let Some(rolling_stock) = RollingStockModel::retrieve(conn, rolling_stock_name.clone()).await?
let Some(rolling_stock) = proxy
.get_rolling_stock(rolling_stock_name.clone(), conn)
.await?
else {
return Ok(PathfindingResult::RollingStockNotFound { rolling_stock_name });
};
@@ -187,7 +199,13 @@ pub async fn pathfinding_from_train(
.collect(),
};

pathfinding_blocks(conn, redis, core, infra, &path_input).await
match pathfinding_blocks(conn, redis, core, infra, &path_input).await {
Ok(res) => {
proxy.set_pathfinding_result(train_schedule.id, res.clone());
Ok(res)
}
err => err,
}
}

/// Generates a unique hash based on the pathfinding entries.
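The new body of `pathfinding_from_train` consults the proxy first and memoizes only successful results, so errors are returned as-is and retried on a later call. A small, generic sketch of that pattern (not editoast code):

```rust
use std::collections::HashMap;
use std::hash::Hash;

// Return the cached value if present; otherwise compute it, and memoize it
// only on success so failures stay retryable.
fn compute_or_cached<K, V, E>(
    cache: &mut HashMap<K, V>,
    key: K,
    compute: impl FnOnce() -> Result<V, E>,
) -> Result<V, E>
where
    K: Hash + Eq,
    V: Clone,
{
    if let Some(v) = cache.get(&key) {
        return Ok(v.clone());
    }
    match compute() {
        Ok(v) => {
            cache.insert(key, v.clone());
            Ok(v)
        }
        err => err,
    }
}

fn main() {
    let mut cache: HashMap<u64, String> = HashMap::new();
    let first = compute_or_cached(&mut cache, 1, || Ok::<_, ()>("path".to_string()));
    // The second call is served from the cache; the closure is never invoked.
    let second: Result<String, ()> =
        compute_or_cached(&mut cache, 1, || unreachable!("served from cache"));
    assert_eq!(first, second);
}
```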
20 changes: 18 additions & 2 deletions editoast/src/views/v2/timetable.rs
@@ -2,6 +2,7 @@ pub mod stdcm;

use std::collections::HashMap;
use std::ops::DerefMut as _;
use std::sync::Arc;

use actix_web::delete;
use actix_web::get;
@@ -34,11 +35,13 @@ use crate::modelsv2::timetable::TimetableWithTrains;
use crate::modelsv2::train_schedule::TrainSchedule;
use crate::modelsv2::train_schedule::TrainScheduleChangeset;
use crate::modelsv2::Infra;
use crate::modelsv2::RollingStockModel;
use crate::views::pagination::PaginatedList;
use crate::views::pagination::PaginationQueryParam;
use crate::views::pagination::PaginationStats;
use crate::views::v2::train_schedule::train_simulation_batch;
use crate::views::v2::train_schedule::TrainScheduleForm;
use crate::views::v2::train_schedule::TrainScheduleProxy;
use crate::views::v2::train_schedule::TrainScheduleResult;
use crate::CoreClient;
use crate::RedisClient;
@@ -333,24 +336,37 @@ pub async fn conflicts(
let infra_id = query.into_inner().infra_id;

// 1. Retrieve Timetable / Infra / Trains / Simulation
let timetable = TimetableWithTrains::retrieve_or_fail(conn, timetable_id, || {
let timetable_trains = TimetableWithTrains::retrieve_or_fail(conn, timetable_id, || {
TimetableError::NotFound { timetable_id }
})
.await?;
let timetable: Timetable = timetable_trains.clone().into();

let infra = Infra::retrieve_or_fail(conn, infra_id, || TimetableError::InfraNotFound {
infra_id,
})
.await?;

let (trains, _): (Vec<_>, _) = TrainSchedule::retrieve_batch(conn, timetable.train_ids).await?;
let (trains, _): (Vec<_>, _) =
TrainSchedule::retrieve_batch(conn, timetable_trains.train_ids).await?;

let (rolling_stocks, _): (Vec<_>, _) = RollingStockModel::retrieve_batch(
db_pool.get().await?.deref_mut(),
trains
.iter()
.map::<String, _>(|t| t.rolling_stock_name.clone()),
)
.await?;

let proxy = Arc::new(TrainScheduleProxy::new(&rolling_stocks, &[timetable]));

let simulations = train_simulation_batch(
db_pool.clone(),
redis_client.clone(),
core_client.clone(),
&trains,
&infra,
proxy,
)
.await?;

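`conflicts` now fetches every rolling stock referenced by the timetable's trains with a single `retrieve_batch` call and hands them to the proxy, instead of letting each train's simulation issue its own query. A generic sketch of that batched-lookup shape, with illustrative names rather than the editoast API:

```rust
// Collect the distinct keys first, fetch them in one round trip,
// then resolve each train from the in-memory map.
use std::collections::{HashMap, HashSet};

#[derive(Clone)]
struct Train {
    rolling_stock_name: String,
}

fn batch_lookup(
    trains: &[Train],
    // e.g. one SQL query with `WHERE name = ANY(...)`
    fetch_batch: impl Fn(&HashSet<String>) -> HashMap<String, u64>,
) -> Vec<Option<u64>> {
    let names: HashSet<String> = trains
        .iter()
        .map(|t| t.rolling_stock_name.clone())
        .collect();
    let by_name = fetch_batch(&names); // one round trip for all N trains
    trains
        .iter()
        .map(|t| by_name.get(&t.rolling_stock_name).copied())
        .collect()
}

fn main() {
    let trains = vec![
        Train { rolling_stock_name: "TGV".into() },
        Train { rolling_stock_name: "TGV".into() },
        Train { rolling_stock_name: "TER".into() },
    ];
    // Pretend the database returns an id per distinct name: 2 names -> 1 query, not 3.
    let ids = batch_lookup(&trains, |names| {
        names
            .iter()
            .enumerate()
            .map(|(i, n)| (n.clone(), i as u64))
            .collect()
    });
    println!("{ids:?}");
}
```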
31 changes: 27 additions & 4 deletions editoast/src/views/v2/timetable/stdcm.rs
@@ -13,6 +13,7 @@ use serde::Deserialize;
use serde::Serialize;
use std::cmp::max;
use std::collections::HashMap;
use std::ops::DerefMut;
use std::sync::Arc;
use thiserror::Error;
use utoipa::IntoParams;
@@ -26,13 +27,15 @@ use crate::core::v2::stdcm::{STDCMRequest, STDCMStepTimingData};
use crate::core::AsCoreRequest;
use crate::core::CoreClient;
use crate::error::Result;
use crate::modelsv2::timetable::Timetable;
use crate::modelsv2::timetable::TimetableWithTrains;
use crate::modelsv2::train_schedule::TrainSchedule;
use crate::modelsv2::work_schedules::WorkSchedule;
use crate::modelsv2::RollingStockModel;
use crate::modelsv2::{Infra, List};
use crate::views::v2::path::pathfinding::extract_location_from_path_items;
use crate::views::v2::path::pathfinding::TrackOffsetExtractionError;
use crate::views::v2::train_schedule::TrainScheduleProxy;
use crate::views::v2::train_schedule::{train_simulation, train_simulation_batch};
use crate::RedisClient;
use crate::Retrieve;
@@ -163,22 +166,35 @@ async fn stdcm(
let redis_client_inner = redis_client.into_inner();

// 1. Retrieve Timetable / Infra / Trains / Simulation / Rolling Stock
let timetable = TimetableWithTrains::retrieve_or_fail(conn, timetable_id, || {
let timetable_trains = TimetableWithTrains::retrieve_or_fail(conn, timetable_id, || {
STDCMError::TimetableNotFound { timetable_id }
})
.await?;
let timetable: Timetable = timetable_trains.clone().into();

let infra =
Infra::retrieve_or_fail(conn, infra_id, || STDCMError::InfraNotFound { infra_id }).await?;

let (trains, _): (Vec<_>, _) = TrainSchedule::retrieve_batch(conn, timetable.train_ids).await?;
let (trains, _): (Vec<_>, _) =
TrainSchedule::retrieve_batch(conn, timetable_trains.train_ids).await?;

let (rolling_stocks, _): (Vec<_>, _) = RollingStockModel::retrieve_batch(
db_pool.get().await?.deref_mut(),
trains
.iter()
.map::<String, _>(|t| t.rolling_stock_name.clone()),
)
.await?;

let proxy = Arc::new(TrainScheduleProxy::new(&rolling_stocks, &[timetable]));

let simulations = train_simulation_batch(
db_pool.clone(),
redis_client_inner.clone(),
core_client.clone(),
&trains,
&infra,
proxy,
)
.await?;

@@ -331,8 +347,15 @@ async fn get_maximum_run_time(
};

let conn = &mut db_pool.clone().get().await?;
let sim_result =
train_simulation(conn, redis_client, core_client, &train_schedule, infra).await?;
let sim_result = train_simulation(
conn,
redis_client,
core_client,
&train_schedule,
infra,
Arc::default(),
)
.await?;

let total_stop_time: u64 = data
.steps
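For the single-train call in `get_maximum_run_time`, a fresh, empty proxy is passed via `Arc::default()`; there is nothing to share across trains. This compiles because `Arc<T>` is `Default` whenever `T` is, which implies `TrainScheduleProxy` implements `Default`. A tiny illustration with a stand-in type:

```rust
use std::sync::Arc;

// Stand-in for the real TrainScheduleProxy, which must implement Default
// for the `Arc::default()` call sites in this commit to compile.
#[derive(Default, Debug)]
struct TrainScheduleProxy;

fn main() {
    // Equivalent to Arc::new(TrainScheduleProxy::default()): an empty cache for a one-off call.
    let proxy: Arc<TrainScheduleProxy> = Arc::default();
    println!("{proxy:?}");
}
```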
74 changes: 66 additions & 8 deletions editoast/src/views/v2/train_schedule.rs
@@ -1,4 +1,5 @@
mod projection;
mod proxy;

use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
@@ -49,6 +50,8 @@ use editoast_models::DbConnection;
use editoast_models::DbConnectionPool;
use editoast_models::DbConnectionPoolV2;

pub use proxy::TrainScheduleProxy;

const CACHE_SIMULATION_EXPIRATION: u64 = 604800; // 1 week

crate::routes! {
@@ -306,7 +309,15 @@ pub async fn simulation(
.await?;

Ok(Json(
train_simulation(conn, redis_client, core_client, &train_schedule, &infra).await?,
train_simulation(
conn,
redis_client,
core_client,
&train_schedule,
&infra,
Arc::default(),
)
.await?,
))
}

@@ -317,6 +328,7 @@ pub async fn train_simulation(
core: Arc<CoreClient>,
train_schedule: &TrainSchedule,
infra: &Infra,
proxy: Arc<TrainScheduleProxy>,
) -> Result<SimulationResponse> {
let mut redis_conn = redis_client.get_connection().await?;
// Compute path
@@ -326,6 +338,7 @@
core.clone(),
infra,
train_schedule.clone(),
proxy.clone(),
)
.await?;

@@ -350,8 +363,15 @@
};

// Build simulation request
let simulation_request =
build_simulation_request(conn, infra, train_schedule, &path_items_positions, path).await?;
let simulation_request = build_simulation_request(
conn,
infra,
train_schedule,
&path_items_positions,
path,
proxy,
)
.await?;

// Compute unique hash of the simulation input
let hash = train_simulation_input_hash(infra.id, &infra.version, &simulation_request);
@@ -382,15 +402,18 @@ async fn build_simulation_request(
train_schedule: &TrainSchedule,
path_items_position: &[u64],
path: SimulationPath,
proxy: Arc<TrainScheduleProxy>,
) -> Result<SimulationRequest> {
// Get rolling stock
let rolling_stock_name = train_schedule.rolling_stock_name.clone();
let rolling_stock = RollingStockModel::retrieve(conn, rolling_stock_name.clone())
let rolling_stock = proxy
.get_rolling_stock(rolling_stock_name, conn)
.await?
.expect("Rolling stock should exist since the pathfinding succeeded");
// Get electrical_profile_set_id
let timetable_id = train_schedule.timetable_id;
let timetable = Timetable::retrieve(conn, timetable_id)
let timetable = proxy
.get_timetable(timetable_id, conn)
.await?
.expect("Timetable should exist since it's a foreign key");

@@ -537,9 +560,33 @@ pub async fn simulation_summary(
},
)
.await?;
let (rolling_stocks, _): (Vec<_>, _) = RollingStockModel::retrieve_batch(
db_pool.get().await?.deref_mut(),
trains
.iter()
.map::<String, _>(|t| t.rolling_stock_name.clone()),
)
.await?;
let (timetables, _): (Vec<_>, _) = Timetable::retrieve_batch(
db_pool.get().await?.deref_mut(),
trains
.iter()
.map(|t| t.timetable_id)
.collect::<HashSet<_>>(),
)
.await?;

let simulations =
train_simulation_batch(db_pool.clone(), redis_client, core, &trains, &infra).await?;
let proxy = Arc::new(TrainScheduleProxy::new(&rolling_stocks, &timetables));

let simulations = train_simulation_batch(
db_pool.clone(),
redis_client,
core,
&trains,
&infra,
proxy.clone(),
)
.await?;

// Transform simulations to simulation summary
let mut simulation_summaries = HashMap::new();
@@ -590,6 +637,7 @@ pub async fn train_simulation_batch(
core_client: Arc<CoreClient>,
train_schedules: &[TrainSchedule],
infra: &Infra,
proxy: Arc<TrainScheduleProxy>,
) -> Result<Vec<SimulationResponse>> {
let pending_simulations =
train_schedules
@@ -598,6 +646,7 @@
.map(|(train_schedule, conn)| {
let redis_client = redis_client.clone();
let core_client = core_client.clone();
let cache = proxy.clone();
async move {
train_simulation(
conn.await
@@ -607,6 +656,7 @@
core_client,
train_schedule,
infra,
cache,
)
.await
}
@@ -654,7 +704,15 @@ async fn get_path(
})
.await?;
Ok(Json(
pathfinding_from_train(conn, &mut redis_conn, core, &infra, train_schedule).await?,
pathfinding_from_train(
conn,
&mut redis_conn,
core,
&infra,
train_schedule,
Arc::default(),
)
.await?,
))
}

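In `train_simulation_batch`, each train gets its own future; the shared handles (connection, Redis, core client, proxy) are cloned into every `async move` block and the futures are awaited together. A rough sketch of that fan-out, assuming the `futures` crate and placeholder types rather than the editoast API:

```rust
use std::sync::Arc;

struct Proxy; // placeholder for the shared TrainScheduleProxy

async fn simulate(train_id: i64, proxy: Arc<Proxy>) -> Result<i64, String> {
    let _ = proxy; // the shared cache is what avoids the per-train queries
    Ok(train_id)
}

async fn simulate_all(train_ids: &[i64], proxy: Arc<Proxy>) -> Result<Vec<i64>, String> {
    let pending = train_ids.iter().map(|&id| {
        let proxy = proxy.clone(); // cheap: only bumps a reference count
        async move { simulate(id, proxy).await }
    });
    // Await all simulations, failing fast on the first error.
    futures::future::try_join_all(pending).await
}

fn main() {
    let results = futures::executor::block_on(simulate_all(&[1, 2, 3], Arc::new(Proxy)));
    println!("{results:?}");
}
```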
(The diffs for the remaining changed files were not loaded on this page.)
