diff --git a/src/file_reader.rs b/src/file_reader.rs index 914a2d9..df6410c 100644 --- a/src/file_reader.rs +++ b/src/file_reader.rs @@ -6,7 +6,7 @@ use object_store::{ObjectStore, PutPayload}; use opendal::{services::Http, services::S3, Operator}; use web_sys::{js_sys, Url}; -use crate::{ParquetTable, INMEMORY_STORE}; +use crate::{ParquetTable, INMEMORY_STORE, SESSION_CTX}; const S3_ENDPOINT_KEY: &str = "s3_endpoint"; const S3_ACCESS_KEY_ID_KEY: &str = "s3_access_key_id"; @@ -36,10 +36,18 @@ async fn update_file( parquet_table: ParquetTable, parquet_table_setter: WriteSignal>, ) { + let ctx = SESSION_CTX.as_ref(); let object_store = &*INMEMORY_STORE; let path = Path::parse(&parquet_table.table_name).unwrap(); let payload = PutPayload::from_bytes(parquet_table.bytes.clone()); object_store.put(&path, payload).await.unwrap(); + ctx.register_parquet( + &parquet_table.table_name, + &format!("mem:///{}", parquet_table.table_name), + Default::default(), + ) + .await + .unwrap(); parquet_table_setter.set(Some(parquet_table)); } diff --git a/src/main.rs b/src/main.rs index 3cb1dd7..e5200bb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,11 @@ mod schema; -use datafusion::physical_plan::ExecutionPlan; -use file_reader::{get_stored_value, FileReader}; +use datafusion::{ + datasource::MemTable, + execution::object_store::ObjectStoreUrl, + physical_plan::ExecutionPlan, + prelude::{SessionConfig, SessionContext}, +}; +use file_reader::FileReader; use leptos_router::{ components::Router, hooks::{query_signal, use_query_map}, @@ -32,11 +37,21 @@ mod query_input; use query_input::{execute_query_inner, QueryInput}; mod settings; -use settings::{Settings, ANTHROPIC_API_KEY}; +use settings::Settings; pub(crate) static INMEMORY_STORE: LazyLock> = LazyLock::new(|| Arc::new(InMemory::new())); +pub(crate) static SESSION_CTX: LazyLock> = LazyLock::new(|| { + let mut config = SessionConfig::new(); + config.options_mut().sql_parser.dialect = "PostgreSQL".to_string(); + let ctx = Arc::new(SessionContext::new_with_config(config)); + let object_store_url = ObjectStoreUrl::parse("mem://").unwrap(); + let object_store = INMEMORY_STORE.clone(); + ctx.register_object_store(object_store_url.as_ref(), object_store); + ctx +}); + #[derive(Debug, Clone, PartialEq)] pub(crate) struct ParquetReader { parquet_table: ParquetTable, @@ -190,10 +205,9 @@ impl std::fmt::Display for ParquetInfo { } async fn execute_query_async( - query: String, - table_name: String, + query: &str, ) -> Result<(Vec, Arc), String> { - let (results, physical_plan) = execute_query_inner(&table_name, &query) + let (results, physical_plan) = execute_query_inner(query) .await .map_err(|e| format!("Failed to execute query: {}", e))?; @@ -269,15 +283,7 @@ fn App() -> impl IntoView { let Some(parquet_reader) = parquet_reader.get() else { return; }; - let api_key = get_stored_value(ANTHROPIC_API_KEY, ""); - let sql = match query_input::user_input_to_sql( - &user_input, - &parquet_reader.info().schema, - parquet_reader.table_name(), - &api_key, - ) - .await - { + let sql = match query_input::user_input_to_sql(&user_input, &parquet_reader).await { Ok(response) => response, Err(e) => { set_error_message.set(Some(e)); @@ -301,13 +307,12 @@ fn App() -> impl IntoView { return; } - if let Some(parquet_table) = bytes_opt { + if let Some(_parquet_table) = bytes_opt { let query = query.clone(); let export_to = export_to.clone(); - let table_name = parquet_table.table_name; leptos::task::spawn_local(async move { - match execute_query_async(query.clone(), table_name).await { + match execute_query_async(&query).await { Ok((results, physical_plan)) => { if let Some(export_to) = export_to { if export_to == "csv" { @@ -316,8 +321,18 @@ fn App() -> impl IntoView { export_to_parquet_inner(&results); } } + set_query_results.update(|r| { let id = r.len(); + if let Some(first_batch) = results.first() { + let schema = first_batch.schema(); + let mem_table = + MemTable::try_new(schema, vec![results.clone()]).unwrap(); + SESSION_CTX + .as_ref() + .register_table(format!("view_{}", id), Arc::new(mem_table)) + .unwrap(); + } r.push(QueryResult::new( id, query, diff --git a/src/query_input.rs b/src/query_input.rs index 38bf501..a8c9dc9 100644 --- a/src/query_input.rs +++ b/src/query_input.rs @@ -4,9 +4,7 @@ use arrow_array::RecordBatch; use arrow_schema::SchemaRef; use datafusion::{ error::DataFusionError, - execution::object_store::ObjectStoreUrl, physical_plan::{collect, ExecutionPlan}, - prelude::{ParquetReadOptions, SessionConfig}, }; use leptos::{logging, prelude::*}; use leptos::{ @@ -17,27 +15,15 @@ use serde_json::json; use wasm_bindgen_futures::JsFuture; use web_sys::{js_sys, Headers, Request, RequestInit, RequestMode, Response}; -use crate::INMEMORY_STORE; +use crate::{ + settings::{get_stored_value, ANTHROPIC_API_KEY}, + ParquetReader, SESSION_CTX, +}; pub(crate) async fn execute_query_inner( - table_name: &str, query: &str, ) -> Result<(Vec, Arc), DataFusionError> { - let mut config = SessionConfig::new(); - config.options_mut().sql_parser.dialect = "PostgreSQL".to_string(); - - let ctx = datafusion::prelude::SessionContext::new_with_config(config); - - let object_store_url = ObjectStoreUrl::parse("mem://").unwrap(); - let object_store = INMEMORY_STORE.clone(); - ctx.register_object_store(object_store_url.as_ref(), object_store); - ctx.register_parquet( - table_name, - &format!("mem:///{}", table_name), - ParquetReadOptions::default(), - ) - .await?; - + let ctx = SESSION_CTX.as_ref(); let plan = ctx.sql(query).await?; let (state, plan) = plan.into_parts(); @@ -137,9 +123,7 @@ pub fn QueryInput( pub(crate) async fn user_input_to_sql( input: &str, - schema: &SchemaRef, - file_name: &str, - api_key: &str, + parquet_reader: &ParquetReader, ) -> Result { // if the input seems to be a SQL query, return it as is if input.starts_with("select") || input.starts_with("SELECT") { @@ -148,6 +132,9 @@ pub(crate) async fn user_input_to_sql( // otherwise, treat it as some natural language + let schema = &parquet_reader.info().schema; + let file_name = parquet_reader.table_name(); + let api_key = get_stored_value(ANTHROPIC_API_KEY, ""); let schema_str = schema_to_brief_str(schema); logging::log!("Processing user input: {}", input); @@ -157,7 +144,7 @@ pub(crate) async fn user_input_to_sql( ); logging::log!("{}", prompt); - let sql = match generate_sql_via_claude(&prompt, api_key).await { + let sql = match generate_sql_via_claude(&prompt, &api_key).await { Ok(response) => response, Err(e) => { logging::log!("{}", e); diff --git a/src/query_results.rs b/src/query_results.rs index 6b6a643..5db9a83 100644 --- a/src/query_results.rs +++ b/src/query_results.rs @@ -144,6 +144,12 @@ pub fn QueryResultView(result: QueryResult) -> impl IntoView {
+
+ + {format!("SELECT * FROM view_{}", result.id())} + + {format!("view_{}", result.id())} +