Skip to content

Commit

Permalink
feat(db): Implement db-layer user count cache (#1567)
Browse files Browse the repository at this point in the history
* Draft cache design

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes (attempt 2/3)

* Remove AnyCache, simplify Cache

* Move cache

* feat(db): Implement db-layer user count cache

* Fix errors from resolving conflicts

* Invalidate cache instead of updating

* Remove Cache::update

* Remove Cache::set

* Add test for user cache, fix bug

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
boxbeam and autofix-ci[bot] authored Feb 28, 2024
1 parent ebd2937 commit ebea511
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 49 deletions.
10 changes: 1 addition & 9 deletions ee/tabby-db/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::future::Future;

use tokio::sync::RwLock;

#[derive(Default)]
pub struct Cache<T> {
value: RwLock<Option<T>>,
}
Expand Down Expand Up @@ -33,13 +34,4 @@ impl<T> Cache<T> {
Ok(generated)
}
}

pub async fn update(&self, f: impl FnOnce(&mut T)) {
let mut lock = self.value.write().await;
lock.as_mut().map(f);
}

pub async fn set(&self, value: T) {
*self.value.write().await = Some(value);
}
}
15 changes: 13 additions & 2 deletions ee/tabby-db/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::ops::Deref;
use std::{ops::Deref, sync::Arc};

use anyhow::anyhow;
use cache::Cache;
use chrono::{DateTime, NaiveDateTime, Utc};
pub use email_setting::EmailSettingDAO;
pub use github_oauth_credential::GithubOAuthCredentialDAO;
Expand Down Expand Up @@ -34,9 +35,16 @@ pub trait DbEnum: Sized {
fn from_enum_str(s: &str) -> anyhow::Result<Self>;
}

#[derive(Default)]
pub struct DbCache {
pub active_user_count: Cache<usize>,
pub active_admin_count: Cache<usize>,
}

#[derive(Clone)]
pub struct DbConn {
pool: Pool<Sqlite>,
cache: Arc<DbCache>,
}

impl DbConn {
Expand Down Expand Up @@ -75,7 +83,10 @@ impl DbConn {
.execute(&pool)
.await?;

let conn = Self { pool };
let conn = Self {
pool,
cache: Default::default(),
};
conn.manual_users_active_migration().await?;
Ok(conn)
}
Expand Down
75 changes: 62 additions & 13 deletions ee/tabby-db/src/users.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ impl DbConn {
let res = res.unique_error("User already exists")?;
transaction.commit().await?;

self.cache.active_user_count.invalidate().await;
if is_admin {
self.cache.active_admin_count.invalidate().await;
}

Ok(res.last_insert_rowid() as i32)
}

Expand Down Expand Up @@ -179,10 +184,11 @@ impl DbConn {
.await?
.rows_affected();
if changed != 1 {
Err(anyhow!("user active status was not changed"))
} else {
Ok(())
return Err(anyhow!("user active status was not changed"));
}
self.cache.active_admin_count.invalidate().await;
self.cache.active_user_count.invalidate().await;
Ok(())
}

pub async fn update_user_role(&self, id: i32, is_admin: bool) -> Result<()> {
Expand All @@ -199,6 +205,7 @@ impl DbConn {
if changed != 1 {
Err(anyhow!("user admin status was not changed"))
} else {
self.cache.active_admin_count.invalidate().await;
Ok(())
}
}
Expand All @@ -214,19 +221,28 @@ impl DbConn {
Ok(())
}

// FIXME(boxbeam): Revisit if a caching layer should be put into DbConn for this query in future.
pub async fn count_active_users(&self) -> Result<usize> {
let users = query_scalar!("SELECT COUNT(1) FROM users WHERE active;")
.fetch_one(&self.pool)
.await?;
Ok(users as usize)
self.cache
.active_user_count
.get_or_refresh(|| async {
let users = query_scalar!("SELECT COUNT(1) FROM users WHERE active;")
.fetch_one(&self.pool)
.await?;
Ok(users as usize)
})
.await
}

pub async fn count_active_admin_users(&self) -> Result<usize> {
let users = query_scalar!("SELECT COUNT(1) FROM users WHERE active and is_admin;")
.fetch_one(&self.pool)
.await?;
Ok(users as usize)
self.cache
.active_admin_count
.get_or_refresh(|| async {
let users = query_scalar!("SELECT COUNT(1) FROM users WHERE active and is_admin;")
.fetch_one(&self.pool)
.await?;
Ok(users as usize)
})
.await
}
}

Expand Down Expand Up @@ -528,5 +544,38 @@ mod tests {
)
);
}

#[tokio::test]
async fn test_caching() {
let db = DbConn::new_in_memory().await.unwrap();

db.create_user("[email protected]".into(), "".into(), true)
.await
.unwrap();

assert_eq!(db.count_active_users().await.unwrap(), 1);
assert_eq!(db.count_active_admin_users().await.unwrap(), 1);

let user2_id = db
.create_user("[email protected]".into(), "".into(), false)
.await
.unwrap();
assert_eq!(db.count_active_users().await.unwrap(), 2);
assert_eq!(db.count_active_admin_users().await.unwrap(), 1);

db.update_user_active(user2_id, false).await.unwrap();
assert_eq!(db.count_active_users().await.unwrap(), 1);
assert_eq!(db.count_active_admin_users().await.unwrap(), 1);

let user3_id = db
.create_user("[email protected]".into(), "".into(), true)
.await
.unwrap();
assert_eq!(db.count_active_users().await.unwrap(), 2);
assert_eq!(db.count_active_admin_users().await.unwrap(), 2);

db.update_user_active(user3_id, false).await.unwrap();
assert_eq!(db.count_active_users().await.unwrap(), 1);
assert_eq!(db.count_active_admin_users().await.unwrap(), 1);
}
}
// FIXME(boxbeam): Revisit if a caching layer should be put into DbConn for this query in future.
30 changes: 5 additions & 25 deletions ee/tabby-webserver/src/service/license.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use anyhow::{anyhow, Context};
use async_trait::async_trait;
use chrono::{DateTime, Duration, NaiveDateTime, Utc};
use chrono::{DateTime, NaiveDateTime, Utc};
use jsonwebtoken as jwt;
use lazy_static::lazy_static;
use serde::Deserialize;
use tabby_db::DbConn;
use tokio::sync::RwLock;

use crate::schema::{
license::{LicenseInfo, LicenseService, LicenseStatus, LicenseType},
Expand Down Expand Up @@ -62,26 +61,11 @@ fn jwt_timestamp_to_utc(secs: i64) -> Result<DateTime<Utc>> {

struct LicenseServiceImpl {
db: DbConn,
seats: RwLock<(DateTime<Utc>, usize)>,
}

impl LicenseServiceImpl {
async fn read_used_seats(&self, force_refresh: bool) -> Result<usize> {
let now = Utc::now();
let (refreshed, mut seats) = {
let lock = self.seats.read().await;
*lock
};
if force_refresh || now - refreshed > Duration::seconds(15) {
let mut lock = self.seats.write().await;
seats = self.db.count_active_users().await?;
*lock = (now, seats);
}
Ok(seats)
}

async fn make_community_license(&self) -> Result<LicenseInfo> {
let seats_used = self.read_used_seats(false).await?;
let seats_used = self.db.count_active_users().await?;
let status = if seats_used > LicenseInfo::seat_limits_for_community_license() {
LicenseStatus::SeatsExceeded
} else {
Expand All @@ -101,11 +85,7 @@ impl LicenseServiceImpl {
}

pub async fn new_license_service(db: DbConn) -> Result<impl LicenseService> {
let seats = db.count_active_users().await?;
Ok(LicenseServiceImpl {
db,
seats: (Utc::now(), seats).into(),
})
Ok(LicenseServiceImpl { db })
}

fn license_info_from_raw(raw: LicenseJWTPayload, seats_used: usize) -> Result<LicenseInfo> {
Expand Down Expand Up @@ -140,15 +120,15 @@ impl LicenseService for LicenseServiceImpl {
};
let license =
validate_license(&license).map_err(|e| anyhow!("License is corrupt: {e:?}"))?;
let seats = self.read_used_seats(false).await?;
let seats = self.db.count_active_users().await?;
let license = license_info_from_raw(license, seats)?;

Ok(license)
}

async fn update_license(&self, license: String) -> Result<()> {
let raw = validate_license(&license).map_err(|_e| anyhow!("License is not valid"))?;
let seats = self.read_used_seats(true).await?;
let seats = self.db.count_active_users().await?;
match license_info_from_raw(raw, seats)?.status {
LicenseStatus::Ok => self.db.update_enterprise_license(Some(license)).await?,
LicenseStatus::Expired => return Err(anyhow!("License is expired").into()),
Expand Down

0 comments on commit ebea511

Please sign in to comment.