Skip to content

Commit

Permalink
initial rate limit
Browse files Browse the repository at this point in the history
  • Loading branch information
jr1221 committed Sep 20, 2024
1 parent 20a408f commit 82c46e0
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 27 deletions.
7 changes: 7 additions & 0 deletions compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ services:
- SOURCE_DATABASE_URL=postgresql://postgres:password@odyssey-timescale:5432/postgres
# - PROD_SIREN_HOST_URL=siren:1883
- SCYLLA_PROD=true
#- SCYLLA_SATURATE_BATCH=false
#-SCYLLA_DATA_UPLOAD_DISABLE=false
#-SCYLLA_SIREN_HOST_URL=localhost:1883
#-SCYLLA_BATCH_UPSERT_TIME=10
#-SCYLLA_RATE_LIMIT_MODE=none
#-SCYLLA_STATIC_RATE_LIMIT_VALUE=100
#-SCYLLA_SOCKET_DISCARD_PERCENT=0
- RUST_LOG=warn,scylla_server=debug
cpu_shares: 1024
stop_grace_period: 2m
Expand Down
10 changes: 10 additions & 0 deletions scylla-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,13 @@ pub mod serverdata;

/// The type descriptor of the database passed to the middlelayer through axum state
pub type Database = std::sync::Arc<prisma::PrismaClient>;

#[derive(clap::ValueEnum, Debug, PartialEq, Copy, Clone, Default)]
#[clap(rename_all = "kebab_case")]
pub enum RateLimitMode {
/// static rate limiting based on a set value
Static,
/// no rate limiting
#[default]
None,
}
23 changes: 22 additions & 1 deletion scylla-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use scylla_server::{
db_handler, mock_processor::MockProcessor, mqtt_processor::MqttProcessor, ClientData,
},
services::run_service::{self, public_run},
Database,
Database, RateLimitMode,
};
use socketioxide::{extract::SocketRef, SocketIo};
use tokio::{signal, sync::mpsc};
Expand Down Expand Up @@ -70,6 +70,25 @@ struct ScyllaArgs {
)]
batch_upsert_time: u64,

/// The rate limit mode to use
#[arg(
short = 'm',
long,
env = "SCYLLA_RATE_LIMIT_MODE",
default_value_t = RateLimitMode::None,
value_enum,
)]
rate_limit_mode: RateLimitMode,

/// The static rate limit number to use in ms
#[arg(
short = 'v',
long,
env = "SCYLLA_STATIC_RATE_LIMIT_VALUE",
default_value = "100"
)]
static_rate_limit_value: u64,

/// The percent of messages discarded when sent from the socket
#[arg(
short = 'd',
Expand Down Expand Up @@ -189,6 +208,8 @@ async fn main() {
curr_run.id,
io,
token.clone(),
cli.static_rate_limit_value,
cli.rate_limit_mode,
((cli.socketio_discard_percent as f32 / 100.0) * 255.0) as u8,
);
let (client, eventloop) = AsyncClient::new(opts, 600);
Expand Down
80 changes: 54 additions & 26 deletions scylla-server/src/processors/mqtt_processor.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use core::fmt;
use std::{sync::Arc, time::Duration};
use std::{collections::HashMap, sync::Arc, time::Duration};

use prisma_client_rust::{bigdecimal::ToPrimitive, chrono, serde_json};
use protobuf::Message;
Expand All @@ -9,13 +8,16 @@ use rumqttc::v5::{
AsyncClient, Event, EventLoop, MqttOptions,
};
use socketioxide::SocketIo;
use tokio::sync::mpsc::{Receiver, Sender};
use tokio::{
sync::mpsc::{Receiver, Sender},
time::Instant,
};
use tokio_util::sync::CancellationToken;
use tracing::{debug, instrument, trace, warn, Level};

use crate::{
controllers::car_command_controller::CALYPSO_BIDIR_CMD_PREFIX, serverdata,
services::run_service,
services::run_service, RateLimitMode,
};

use super::ClientData;
Expand All @@ -27,8 +29,14 @@ pub struct MqttProcessor {
curr_run: i32,
io: SocketIo,
cancel_token: CancellationToken,
/// Upload ratio, below is not uploaded above is uploaded
/// Upload ratio, below is not socket sent above is socket sent
upload_ratio: u8,
/// static rate limiter
rate_limiter: HashMap<String, Instant>,
/// time to rate limit in ms
rate_limit_time: u64,
/// rate limit mode
rate_limit_mode: RateLimitMode,
}

impl MqttProcessor {
Expand All @@ -46,11 +54,13 @@ impl MqttProcessor {
initial_run: i32,
io: SocketIo,
cancel_token: CancellationToken,
static_rate_limit_time: u64,
rate_limit_mode: RateLimitMode,
upload_ratio: u8,
) -> (MqttProcessor, MqttOptions) {
// create the mqtt client and configure it
let mut mqtt_opts = MqttOptions::new(
"ScyllaServer",
format!("ScyllaServer-{:?}", Instant::now()),
mqtt_path.split_once(':').expect("Invalid Siren URL").0,
mqtt_path
.split_once(':')
Expand All @@ -66,6 +76,8 @@ impl MqttProcessor {
.set_session_expiry_interval(Some(u32::MAX))
.set_topic_alias_max(Some(600));

let rate_map: HashMap<String, Instant> = HashMap::new();

// TODO mess with incoming message cap if db, etc. cannot keep up

(
Expand All @@ -76,6 +88,9 @@ impl MqttProcessor {
io,
cancel_token,
upload_ratio,
rate_limiter: rate_map,
rate_limit_time: static_rate_limit_time,
rate_limit_mode,
},
mqtt_opts,
)
Expand Down Expand Up @@ -110,11 +125,8 @@ impl MqttProcessor {
trace!("Received mqtt message: {:?}", msg);
// parse the message into the data and the node name it falls under
let msg = match self.parse_msg(msg) {
Ok(msg) => msg,
Err(err) => {
warn!("Message parse error: {:?}", err);
continue;
}
Some(msg) => msg,
None => continue
};
latency_ringbuffer.push(chrono::offset::Utc::now().timestamp_millis() - msg.timestamp);
self.send_db_msg(msg.clone()).await;
Expand Down Expand Up @@ -170,26 +182,41 @@ impl MqttProcessor {
/// * `msg` - The mqtt message to parse
/// returns the ClientData, or the Err of something that can be debug printed
#[instrument(skip(self), level = Level::TRACE)]
fn parse_msg(&self, msg: Publish) -> Result<ClientData, impl fmt::Debug> {
let topic = std::str::from_utf8(&msg.topic)
.map_err(|f| format!("Could not parse topic: {}, topic: {:?}", f, msg.topic))?;
fn parse_msg(&mut self, msg: Publish) -> Option<ClientData> {
let Ok(topic) = std::str::from_utf8(&msg.topic) else {
warn!("Could not parse topic, topic: {:?}", msg.topic);
return None;
};

// ignore command messages, less confusing in logs than just failing to decode protobuf
if topic.starts_with(CALYPSO_BIDIR_CMD_PREFIX) {
return Err(format!("Skipping command message: {}", topic));
debug!("Skipping command message: {}", topic);
return None;
}

let split = topic
.split_once('/')
.ok_or(&format!("Could not parse nesting: {:?}", msg.topic))?;
if self.rate_limit_mode == RateLimitMode::Static {
if let Some(old) = self.rate_limiter.get(topic) {
if old.elapsed() < Duration::from_millis(self.rate_limit_time) {
trace!("Static rate limit skipping message with topic {}", topic);
return None;
} else {
self.rate_limiter.insert(topic.to_string(), Instant::now());
}
} else {
self.rate_limiter.insert(topic.to_string(), Instant::now());
}
}

let Some(split) = topic.split_once('/') else {
warn!("Could not parse nesting: {:?}", msg.topic);
return None;
};

// look at data after topic as if we dont have a topic the protobuf is useless anyways
let data = serverdata::ServerData::parse_from_bytes(&msg.payload).map_err(|f| {
format!(
"Could not parse message payload:{:?} error: {}",
msg.topic, f
)
})?;
let Ok(data) = serverdata::ServerData::parse_from_bytes(&msg.payload) else {
warn!("Could not parse message payload:{:?}", msg.topic);
return None;
};

// get the node and datatype from the topic extracted at the beginning
let node = split.0;
Expand Down Expand Up @@ -225,14 +252,15 @@ impl MqttProcessor {
debug!("Timestamp before year 2000: {}", unix_time);
let sys_time = chrono::offset::Utc::now().timestamp_millis();
if sys_time < 963014966000 {
return Err("System has no good time, discarding message!".to_string());
warn!("System has no good time, discarding message!");
return None;
}
sys_time
} else {
unix_time
};

Ok(ClientData {
Some(ClientData {
run_id: self.curr_run,
name: data_type,
unit: data.unit,
Expand Down

0 comments on commit 82c46e0

Please sign in to comment.