From f18eaec2cc27bdf0535ce14b920fd7f6cb965642 Mon Sep 17 00:00:00 2001 From: "min.tian" Date: Mon, 28 Oct 2024 09:38:46 +0800 Subject: [PATCH] add key for plotly_chart Signed-off-by: min.tian --- .../components/check_results/charts.py | 15 +++++++++------ .../frontend/pages/quries_per_dollar.py | 18 +++++++++++++----- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/vectordb_bench/frontend/components/check_results/charts.py b/vectordb_bench/frontend/components/check_results/charts.py index c2b2813b8..9e869b479 100644 --- a/vectordb_bench/frontend/components/check_results/charts.py +++ b/vectordb_bench/frontend/components/check_results/charts.py @@ -1,5 +1,7 @@ from vectordb_bench.backend.cases import Case -from vectordb_bench.frontend.components.check_results.expanderStyle import initMainExpanderStyle +from vectordb_bench.frontend.components.check_results.expanderStyle import ( + initMainExpanderStyle, +) from vectordb_bench.metric import metricOrder, isLowerIsBetterMetric, metricUnitMap from vectordb_bench.frontend.config.styles import * from vectordb_bench.models import ResultLabel @@ -11,7 +13,7 @@ def drawCharts(st, allData, failedTasks, caseNames: list[str]): for caseName in caseNames: chartContainer = st.expander(caseName, True) data = [data for data in allData if data["case_name"] == caseName] - drawChart(data, chartContainer) + drawChart(data, chartContainer, key_prefix=caseName) errorDBs = failedTasks[caseName] showFailedDBs(chartContainer, errorDBs) @@ -35,7 +37,7 @@ def showFailedText(st, text, dbs): ) -def drawChart(data, st): +def drawChart(data, st, key_prefix: str): metricsSet = set() for d in data: metricsSet = metricsSet.union(d["metricsSet"]) @@ -43,7 +45,8 @@ def drawChart(data, st): for i, metric in enumerate(showMetrics): container = st.container() - drawMetricChart(data, metric, container) + key = f"{key_prefix}-{metric}" + drawMetricChart(data, metric, container, key=key) def getLabelToShapeMap(data): @@ -75,7 +78,7 @@ def getLabelToShapeMap(data): return labelToShapeMap -def drawMetricChart(data, metric, st): +def drawMetricChart(data, metric, st, key: str): dataWithMetric = [d for d in data if d.get(metric, 0) > 1e-7] # dataWithMetric = data if len(dataWithMetric) == 0: @@ -161,4 +164,4 @@ def drawMetricChart(data, metric, st): ), ) - chart.plotly_chart(fig, use_container_width=True) + chart.plotly_chart(fig, use_container_width=True, key=key) diff --git a/vectordb_bench/frontend/pages/quries_per_dollar.py b/vectordb_bench/frontend/pages/quries_per_dollar.py index 0bb05294b..4a45181de 100644 --- a/vectordb_bench/frontend/pages/quries_per_dollar.py +++ b/vectordb_bench/frontend/pages/quries_per_dollar.py @@ -1,10 +1,17 @@ import streamlit as st from vectordb_bench.frontend.components.check_results.footer import footer -from vectordb_bench.frontend.components.check_results.expanderStyle import initMainExpanderStyle +from vectordb_bench.frontend.components.check_results.expanderStyle import ( + initMainExpanderStyle, +) from vectordb_bench.frontend.components.check_results.priceTable import priceTable -from vectordb_bench.frontend.components.check_results.stPageConfig import initResultsPageConfig +from vectordb_bench.frontend.components.check_results.stPageConfig import ( + initResultsPageConfig, +) from vectordb_bench.frontend.components.check_results.headerIcon import drawHeaderIcon -from vectordb_bench.frontend.components.check_results.nav import NavToResults, NavToRunTest +from vectordb_bench.frontend.components.check_results.nav import ( + NavToResults, + NavToRunTest, +) from vectordb_bench.frontend.components.check_results.charts import drawMetricChart from vectordb_bench.frontend.components.check_results.filters import getshownData from vectordb_bench.frontend.components.get_results.saveAsImage import getResults @@ -16,7 +23,7 @@ def main(): # set page config initResultsPageConfig(st) - + # header drawHeaderIcon(st) @@ -57,7 +64,8 @@ def main(): dataWithMetric.append(d) if len(dataWithMetric) > 0: chartContainer = st.expander(caseName, True) - drawMetricChart(data, metric, chartContainer) + key = f"{caseName}-{metric}" + drawMetricChart(data, metric, chartContainer, key=key) # footer footer(st.container())