mvp sync deadpool with copy_many
jr1221 committed Jan 7, 2025
1 parent 86cf64c commit 53aea54
Showing 15 changed files with 235 additions and 402 deletions.
280 changes: 55 additions & 225 deletions scylla-server/Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion scylla-server/Cargo.toml
@@ -30,8 +30,8 @@ serde_json = "1.0.128"
diesel_migrations = { version = "2.2.0", features = ["postgres"] }
rangemap = "1.5.1"
axum-macros = "0.5.0"
diesel-async = { version = "0.5.2", features = ["postgres", "bb8", "async-connection-wrapper", "sync-connection-wrapper", "tokio"] }
rustc-hash = "2.1.0"
deadpool-diesel = { version = "0.6.1", features = ["rt_tokio_1", "postgres", "tracing"] }
[target.'cfg(not(target_env = "msvc"))'.dependencies]
tikv-jemallocator = "0.6"

4 changes: 2 additions & 2 deletions scylla-server/src/controllers/data_controller.rs
@@ -13,8 +13,8 @@ pub async fn get_data(
State(pool): State<PoolHandle>,
Path((data_type_name, run_id)): Path<(String, i32)>,
) -> Result<Json<Vec<PublicData>>, ScyllaError> {
let mut db = pool.get().await?;
let data = data_service::get_data(&mut db, data_type_name, run_id).await?;
let db = pool.get().await?;
let data = data_service::get_data(db, data_type_name, run_id).await?;

// map data to frontend data types according to the From func of the client struct
let mut transformed_data: Vec<PublicData> = data.into_iter().map(PublicData::from).collect();
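
Every controller hunk in this commit makes the same mechanical change: the pooled connection is now handed to the service layer by value instead of as `&mut`. That matches deadpool-diesel's model, where `pool.get()` yields a handle to a blocking `diesel::PgConnection` parked on a worker thread and all query work goes through an `interact` closure. The service files are not part of this diff, so the following is only a sketch of what `data_service::get_data` plausibly looks like after the switch; the `runId` column name and the `DbError` `From` impls are assumptions. The remaining controller diffs below follow the same pattern.

```rust
use diesel::prelude::*;

use crate::{models::Data, services::DbError, Database};

/// Hypothetical post-deadpool service: the connection is consumed by value
/// and the blocking diesel query runs inside `interact` on a worker thread.
pub async fn get_data(
    db: Database,
    data_type_name: String,
    run_id: i32,
) -> Result<Vec<Data>, DbError> {
    db.interact(move |conn| {
        use crate::schema::data;
        data::table
            .filter(data::dataTypeName.eq(data_type_name))
            .filter(data::runId.eq(run_id)) // column name is an assumption
            .load::<Data>(conn)
    })
    .await? // deadpool_diesel::InteractError -> DbError, assumed From impl
    .map_err(DbError::from) // diesel::result::Error -> DbError, assumed
}
```
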
4 changes: 2 additions & 2 deletions scylla-server/src/controllers/data_type_controller.rs
@@ -9,8 +9,8 @@ use crate::{
pub async fn get_all_data_types(
State(pool): State<PoolHandle>,
) -> Result<Json<Vec<PublicDataType>>, ScyllaError> {
let mut db = pool.get().await?;
let data_types = data_type_service::get_all_data_types(&mut db).await?;
let db = pool.get().await?;
let data_types = data_type_service::get_all_data_types(db).await?;

let transformed_data_types: Vec<PublicDataType> =
data_types.into_iter().map(PublicDataType::from).collect();
4 changes: 2 additions & 2 deletions scylla-server/src/controllers/file_insertion_controller.rs
@@ -21,9 +21,9 @@ pub async fn insert_file(
mut multipart: Multipart,
) -> Result<String, ScyllaError> {
// create a run ID cache
let mut db = pool.get().await?;
let db = pool.get().await?;
debug!("Warming up run ID map!");
let mut run_iter = run_service::get_all_runs(&mut db)
let mut run_iter = run_service::get_all_runs(db)
.await?
.into_iter()
.map(|f| (f.id, f.time.timestamp_micros() as u64))
12 changes: 6 additions & 6 deletions scylla-server/src/controllers/run_controller.rs
@@ -13,8 +13,8 @@ use crate::{
pub async fn get_all_runs(
State(pool): State<PoolHandle>,
) -> Result<Json<Vec<PublicRun>>, ScyllaError> {
let mut db = pool.get().await?;
let run_data = run_service::get_all_runs(&mut db).await?;
let db = pool.get().await?;
let run_data = run_service::get_all_runs(db).await?;

let transformed_run_data: Vec<PublicRun> = run_data.into_iter().map(PublicRun::from).collect();

@@ -26,8 +26,8 @@ pub async fn get_run_by_id(
State(pool): State<PoolHandle>,
Path(run_id): Path<i32>,
) -> Result<Json<PublicRun>, ScyllaError> {
let mut db = pool.get().await?;
let run_data = run_service::get_run_by_id(&mut db, run_id).await?;
let db = pool.get().await?;
let run_data = run_service::get_run_by_id(db, run_id).await?;

if run_data.is_none() {
return Err(ScyllaError::EmptyResult);
@@ -43,8 +43,8 @@ pub async fn get_run_by_id(
/// create a new run with an auto-incremented ID
/// note the new run must be updated so the channel passed in notifies the data processor to use the new run
pub async fn new_run(State(pool): State<PoolHandle>) -> Result<Json<PublicRun>, ScyllaError> {
let mut db = pool.get().await?;
let run_data = run_service::create_run(&mut db, chrono::offset::Utc::now()).await?;
let db = pool.get().await?;
let run_data = run_service::create_run(db, chrono::offset::Utc::now()).await?;

crate::RUN_ID.store(run_data.id, Ordering::Relaxed);
tracing::info!(
64 changes: 13 additions & 51 deletions scylla-server/src/db_handler.rs
@@ -24,35 +24,6 @@ pub struct DbHandler {
upload_interval: u64,
}

/// Chunks a vec into roughly equal vectors, all under size `max_chunk_size`.
/// This precomputes vec capacity but does, however, call to_vec(), reallocating the slices.
fn chunk_vec<T: Clone>(input: Vec<T>, max_chunk_size: usize) -> Vec<Vec<T>> {
if max_chunk_size == 0 {
panic!("Maximum chunk size must be greater than zero");
}

let len = input.len();
if len == 0 {
return Vec::new();
}

// Calculate the number of chunks
let num_chunks = len.div_ceil(max_chunk_size);

// Recompute a balanced chunk size
let chunk_size = usize::max(1, len.div_ceil(num_chunks));

let mut result = Vec::with_capacity(num_chunks);
let mut start = 0;

while start < len {
let end = usize::min(start + chunk_size, len);
result.push(input[start..end].to_vec());
start = end;
}
result
}

impl DbHandler {
/// Make a new db handler
/// * `recv` - the broadcast receiver over which ClientData will be sent
@@ -81,10 +52,6 @@
loop {
tokio::select! {
_ = cancel_token.cancelled() => {
let Ok(mut database) = pool.get().await else {
warn!("Could not get connection for cleanup");
break;
};
// cleanup all remaining messages if batches start backing up
while let Some(final_msgs) = batch_queue.recv().await {
info!("{} batches remaining!", batch_queue.len()+1);
@@ -93,15 +60,14 @@
debug!("A batch of zero messages was sent!");
continue;
}
let chunk_size = final_msgs.len() / ((final_msgs.len() / 8190) + 1);
let chunks = chunk_vec(final_msgs, chunk_size);
debug!("Batch uploading {} chunks in sequence", chunks.len());
for chunk in chunks {
let Ok(database) = pool.get().await else {
warn!("Could not get connection for cleanup");
break;
};
info!(
"A cleanup chunk uploaded: {:?}",
data_service::add_many(&mut database, chunk).await
data_service::copy_many(database, final_msgs).await.map_err(|_| "Error!")
);
}
}
info!("No more messages to cleanup.");
break;
@@ -115,12 +81,8 @@
continue;
}
let msg_len = msgs.len();
let chunk_size = msg_len / ((msg_len / 8190) + 1);
let chunks = chunk_vec(msgs, chunk_size);
info!("Batch uploading {} chunks in parrallel, {} messages.", chunks.len(), msg_len);
for chunk in chunks {
tokio::spawn(DbHandler::batch_upload(chunk, pool.clone()));
}
info!("Batch uploading {} messages.", msg_len);
tokio::spawn(DbHandler::batch_upload(msgs, pool.clone()));
debug!(
"DB send: {} of {}",
batch_queue.len(),
@@ -151,11 +113,11 @@

#[instrument(level = Level::DEBUG, skip(msg, pool))]
async fn batch_upload(msg: Vec<ClientData>, pool: PoolHandle) {
let Ok(mut database) = pool.get().await else {
let Ok(database) = pool.get().await else {
warn!("Could not get connection for batch upload!");
return;
};
match data_service::add_many(&mut database, msg).await {
match data_service::copy_many(database, msg).await {
Ok(count) => info!("Batch uploaded: {:?}", count),
Err(err) => warn!("Error in batch upload: {:?}", err),
}
@@ -209,13 +171,13 @@
);

if !self.datatype_list.contains(&msg.name) {
let Ok(mut database) = self.pool.get().await else {
let Ok(database) = self.pool.get().await else {
warn!("Could not get connection for dataType upsert");
return;
};
info!("Upserting data type: {}", msg.name);
if let Err(msg) = data_type_service::upsert_data_type(
&mut database,
database,
msg.name.clone(),
msg.unit.clone(),
msg.node.clone(),
@@ -230,15 +192,15 @@
// Check for GPS points, insert them into current run if available
if msg.name == "TPU/GPS/Location" {
debug!("Upserting run with location points!");
let Ok(mut database) = self.pool.get().await else {
let Ok(database) = self.pool.get().await else {
warn!("Could not get connection for db points update");
return;
};
// ensure lat AND long present in message, just a sanity check
if msg.values.len() < 2 {
warn!("GPS message found without both lat and long!");
} else if let Err(err) = run_service::update_run_with_coords(
&mut database,
database,
RUN_ID.load(std::sync::atomic::Ordering::Relaxed),
msg.values[0].into(),
msg.values[1].into(),
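
`copy_many`, which replaces `add_many` throughout this file, lives in the service layer and is not shown in this diff. Given the commit title and the `treat_none_as_default_value = false` attribute added to `DataInsert` in models.rs below, it most plausibly wraps diesel 2.2's Postgres `COPY FROM` support: `COPY` has no bind-parameter limit, which is exactly why the 8190-row `chunk_vec` splitting above could be deleted. A sketch under those assumptions (the `ClientData` to `DataInsert` conversion and the error plumbing are guesses):

```rust
use diesel::prelude::*;

use crate::{models::DataInsert, services::DbError, ClientData, Database};

/// Hypothetical `copy_many`: stream all rows through a single Postgres
/// COPY FROM instead of chunked multi-row INSERTs.
pub async fn copy_many(db: Database, data: Vec<ClientData>) -> Result<usize, DbError> {
    db.interact(move |conn| {
        let rows: Vec<DataInsert> = data.into_iter().map(DataInsert::from).collect();
        diesel::copy_from(crate::schema::data::table)
            .from_insertable(rows)
            .execute(conn) // number of rows copied
    })
    .await? // InteractError -> DbError, assumed From impl
    .map_err(DbError::from) // diesel::result::Error -> DbError, assumed
}
```
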
16 changes: 9 additions & 7 deletions scylla-server/src/error.rs
@@ -4,11 +4,13 @@ use axum::{
};
use tracing::warn;

use crate::services;

pub enum ScyllaError {
/// Diesel error
DbError(diesel::result::Error),
DbError(services::DbError),
/// Diesel db connection error
ConnError(diesel_async::pooled_connection::bb8::RunError),
ConnError(deadpool_diesel::PoolError),
/// An instruction was not encodable
InvalidEncoding(String),
/// Could not communicate to car
@@ -17,14 +19,14 @@
EmptyResult,
}

impl From<diesel::result::Error> for ScyllaError {
fn from(error: diesel::result::Error) -> Self {
impl From<services::DbError> for ScyllaError {
fn from(error: services::DbError) -> Self {
ScyllaError::DbError(error)
}
}

impl From<diesel_async::pooled_connection::bb8::RunError> for ScyllaError {
fn from(error: diesel_async::pooled_connection::bb8::RunError) -> Self {
impl From<deadpool_diesel::PoolError> for ScyllaError {
fn from(error: deadpool_diesel::PoolError) -> Self {
ScyllaError::ConnError(error)
}
}
@@ -39,7 +41,7 @@
),
ScyllaError::DbError(error) => (
StatusCode::BAD_REQUEST,
format!("Misc query error: {}", error),
format!("Misc query error: {:?}", error),
),
ScyllaError::InvalidEncoding(reason) => (StatusCode::UNPROCESSABLE_ENTITY, reason),
ScyllaError::CommFailure(reason) => (StatusCode::BAD_GATEWAY, reason),
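
`ScyllaError` now wraps a crate-local `services::DbError` instead of `diesel::result::Error`, and the handler formats it with `{:?}` rather than `{}`. `DbError`'s definition is not in this diff; a plausible shape is an enum covering the two ways a deadpool-diesel call can fail: the `interact` dispatch itself, and the diesel query inside the closure. A hedged sketch:

```rust
/// Hypothetical `services::DbError`; the real definition is not in this diff.
#[derive(Debug)]
pub enum DbError {
    /// The closure could not run on the worker thread (or panicked there).
    Interact(deadpool_diesel::InteractError),
    /// The query itself failed inside diesel.
    Query(diesel::result::Error),
}

impl From<deadpool_diesel::InteractError> for DbError {
    fn from(e: deadpool_diesel::InteractError) -> Self {
        DbError::Interact(e)
    }
}

impl From<diesel::result::Error> for DbError {
    fn from(e: diesel::result::Error) -> Self {
        DbError::Query(e)
    }
}
```

An enum like this gets `Debug` from the derive but no `Display`, which would explain the switch to `{:?}` in the response body.
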
5 changes: 2 additions & 3 deletions scylla-server/src/lib.rs
@@ -19,10 +19,9 @@ pub mod serverdata;
pub mod transformers;

/// The type descriptor of the database passed to the middlelayer through axum state
pub type Database<'a> =
diesel_async::pooled_connection::bb8::PooledConnection<'a, diesel_async::AsyncPgConnection>;
pub type Database = deadpool_diesel::postgres::Connection;

pub type PoolHandle = diesel_async::pooled_connection::bb8::Pool<diesel_async::AsyncPgConnection>;
pub type PoolHandle = deadpool_diesel::Pool<deadpool_diesel::Manager<diesel::PgConnection>>;

#[derive(clap::ValueEnum, Debug, PartialEq, Copy, Clone, Default)]
#[clap(rename_all = "kebab_case")]
52 changes: 21 additions & 31 deletions scylla-server/src/main.rs
@@ -10,11 +10,7 @@ use axum::{
Extension, Router,
};
use clap::Parser;
use diesel_async::async_connection_wrapper::AsyncConnectionWrapper;
use diesel_async::{
pooled_connection::{bb8::Pool, AsyncDieselConnectionManager},
AsyncConnection, AsyncPgConnection,
};
use deadpool_diesel::postgres::{Manager, Pool};
use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness};
use dotenvy::dotenv;
use rumqttc::v5::AsyncClient;
@@ -30,7 +26,7 @@ use scylla_server::{
use scylla_server::{
db_handler,
mqtt_processor::{MqttProcessor, MqttProcessorOptions},
ClientData, RUN_ID,
ClientData, PoolHandle, RUN_ID,
};
use socketioxide::{extract::SocketRef, SocketIo};
use tokio::{signal, sync::mpsc};
@@ -143,30 +139,25 @@
}

dotenv().ok();
let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be specified");
info!("Beginning DB migration...");
let manager: deadpool_diesel::Manager<diesel::PgConnection> = Manager::new(
std::env::var("DATABASE_URL").unwrap(),
deadpool_diesel::Runtime::Tokio1,
);
let pool: PoolHandle = Pool::builder(manager)
.build()
.expect("Could not build pool");

info!("Beginning DB migration w/ temporary connection...");
// it is best to create a temporary unmanaged connection to run the migrations
// a completely new set of connections is created by the pool manager because it cannot understand an already established connection
let conn: AsyncPgConnection = AsyncPgConnection::establish(&db_url).await?;
let mut async_wrapper: AsyncConnectionWrapper<AsyncPgConnection> =
AsyncConnectionWrapper::from(conn);
tokio::task::spawn_blocking(move || {
async_wrapper.run_pending_migrations(MIGRATIONS).unwrap();
})
.await?;
let conn = pool.get().await.unwrap();
let res = conn
.interact(|conn| conn.run_pending_migrations(MIGRATIONS).err())
.await
.expect("Could not migrate DB!");
if res.is_some() {
panic!("Could not migrate DB!")
}
info!("Successfully migrated DB!");

info!("Initializing database connections...");
let manager = AsyncDieselConnectionManager::<AsyncPgConnection>::new(db_url);
let pool: Pool<AsyncPgConnection> = Pool::builder()
.max_size(10)
.min_idle(Some(2))
.max_lifetime(Some(Duration::from_secs(60 * 60 * 24)))
.idle_timeout(Some(Duration::from_secs(60 * 2)))
.build(manager)
.await?;

// create the socket stuff
let (socket_layer, io) = SocketIo::builder()
.max_buffer_size(4096) // TODO tune values
@@ -210,10 +201,9 @@
}

// creates the initial run
let curr_run =
run_service::create_run(&mut pool.get().await.unwrap(), chrono::offset::Utc::now())
.await
.expect("Could not create initial run!");
let curr_run = run_service::create_run(pool.get().await.unwrap(), chrono::offset::Utc::now())
.await
.unwrap();
debug!("Configuring current run: {:?}", curr_run);

RUN_ID.store(curr_run.id, Ordering::Relaxed);
1 change: 1 addition & 0 deletions scylla-server/src/models.rs
@@ -23,6 +23,7 @@ pub struct Data {
#[derive(Insertable)]
#[diesel(table_name = crate::schema::data)]
#[diesel(belongs_to(DataType, foreign_key = dataTypeName))]
#[diesel(treat_none_as_default_value = false)]
#[diesel(check_for_backend(diesel::pg::Pg))]
#[diesel(primary_key(dataTypeName, time))]
pub struct DataInsert {
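
The one-line models.rs change is easy to miss but matters if `copy_many` is COPY-based: with diesel's default of `treat_none_as_default_value = true`, a `None` field is rendered as the SQL keyword `DEFAULT`, which only `INSERT` statements accept. `COPY` streams raw row data and can only carry `NULL`, so the attribute plausibly exists to keep `None` fields representable in the COPY stream. An illustrative, self-contained example (hypothetical table and columns, since `DataInsert`'s fields are elided here):

```rust
use diesel::prelude::*;

// Hypothetical table, for illustration only.
diesel::table! {
    example (id) {
        id -> Int4,
        note -> Nullable<Text>,
    }
}

#[derive(Insertable)]
#[diesel(table_name = example)]
#[diesel(treat_none_as_default_value = false)]
struct ExampleInsert {
    id: i32,
    // With the attribute set to false, None binds as SQL NULL. Under
    // diesel's default (true), None would render as the keyword DEFAULT,
    // which COPY's data stream cannot express.
    note: Option<String>,
}
```
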