diff --git a/registry/sqlx-data.json b/registry/sqlx-data.json index e4d6b9b9..898aadb5 100644 --- a/registry/sqlx-data.json +++ b/registry/sqlx-data.json @@ -626,6 +626,26 @@ }, "query": "SELECT id FROM extensions WHERE name = $1" }, + "6991d9aae9d130b3d209ea5816847f661979e07c07f754c15e9b05eee3d61cb6": { + "describe": { + "columns": [ + { + "name": "extension_id", + "ordinal": 0, + "type_info": "Int4" + } + ], + "nullable": [ + true + ], + "parameters": { + "Left": [ + "Text" + ] + } + }, + "query": "SELECT extension_id FROM versions WHERE extension_name = $1" + }, "69ef7c7c79e69f31731a41417a0047562f0806a7c73eb4bb98c9ed554fff3b7c": { "describe": { "columns": [], diff --git a/registry/src/routes/download.rs b/registry/src/routes/download.rs index 14b08eb4..516c2a0b 100644 --- a/registry/src/routes/download.rs +++ b/registry/src/routes/download.rs @@ -1,8 +1,8 @@ //! Functionality for downloading extensions and maintaining download counts +use crate::config::Config; use crate::download::{check_version, latest_version}; use crate::errors::Result; use crate::uploader::extension_location; -use crate::{config::Config, extensions::get_extension_id}; use actix_web::{get, web, HttpResponse}; use sqlx::{Pool, Postgres}; use tracing::info; @@ -16,7 +16,9 @@ pub async fn download( path: web::Path<(String, String)>, ) -> Result { let (name, mut version) = path.into_inner(); - let extension_id = get_extension_id(&name, conn.as_ref()).await?; + let Ok(extension_id) = get_extension_id_fallback(&name, &conn).await else { + return Ok(HttpResponse::NotFound().body("No extension with the given name was found")); + }; // Use latest version if 'latest' provided as version if version == "latest" { @@ -50,3 +52,25 @@ async fn increase_download_count(pool: &Pool, extension_id: i32) -> Re Ok(()) } + +/// Given an extension name, try to find it in the `extensions` table (more common scenario). +/// +/// If it's not found, try to find it in `versions` under `extension_name`. +pub async fn get_extension_id_fallback(extension_name: &str, conn: &Pool) -> Result { + if let Ok(record) = sqlx::query!("SELECT id FROM extensions WHERE name = $1", extension_name) + .fetch_one(conn) + .await + { + return Ok(record.id); + } + + let record = sqlx::query!( + "SELECT extension_id FROM versions WHERE extension_name = $1", + extension_name + ) + .fetch_one(conn) + .await?; + + // Safe unwrap: if `extension_name` is in versions, `extension_id` must be as well + Ok(record.extension_id.unwrap() as i64) +}