fix xlsx export
pa1ch committed Dec 3, 2024
1 parent 266f81f commit 83ba6e5
Showing 9 changed files with 170 additions and 66 deletions.
6 changes: 1 addition & 5 deletions superset/charts/data/api.py
@@ -416,6 +416,7 @@ def _send_chart_response(

if not result["queries"]:
return self.response_400(_("Empty query result"))

exportAsTime = form_data.get('exportAsTime')
column_config = form_data.get('column_config')
table_order_by = form_data.get('table_order_by')
@@ -426,9 +427,6 @@ def _send_chart_response(
g.user)
return self.response_403()

if not result["queries"]:
return self.response_400(_("Empty query result"))

if list_of_data := result["queries"]:
df = pd.DataFrame()
for data in list_of_data:
@@ -487,8 +485,6 @@ def _send_chart_response(
g.user)
return self.response_403()

if not result["queries"]:
return self.response_400(_("Empty query result"))
if list_of_data := result["queries"]:
df = pd.DataFrame()
for data in list_of_data:
2 changes: 1 addition & 1 deletion superset/config.py
@@ -374,7 +374,7 @@ def _try_json_readsha(filepath: str, length: int) -> str | None:
"ru": {"flag": "ru", "name": "Russian"},
}

XLSX_EXPORT = {"encoding": "utf-8", "index": False}
EXCEL_EXPORT = {"encoding": "utf-8", "index": False}
# Override the default d3 locale format
# Default values are equivalent to
# D3_FORMAT = {
48 changes: 30 additions & 18 deletions superset/sqllab/api.py
@@ -58,7 +58,7 @@
from superset.sqllab.validators import CanAccessQueryValidatorImpl
from superset.superset_typing import FlaskResponse
from superset.utils import core as utils
from superset.views.base import CsvResponse, generate_download_headers, json_success
from superset.views.base import CsvResponse, XlsxResponse, generate_download_headers, json_success
from superset.views.base_api import BaseSupersetApi, requires_json, statsd_metrics

config = app.config
@@ -141,28 +141,37 @@ def estimate_query_cost(self) -> Response:
@statsd_metrics
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
f".export_csv",
f".export_data",
log_to_statsd=False,
)
def export_csv(self, client_id: str) -> Union[CsvResponse, Response]:
"""Exports the SQL query results to a CSV
def export_data(self, client_id: str) -> Union[CsvResponse, XlsxResponse]:
"""Exports the SQL query results to a CSV or Excel file
---
get:
summary: >-
Exports the SQL query results to a CSV
Exports the SQL query results to a CSV or Excel file
parameters:
- in: path
schema:
type: integer
name: client_id
description: The SQL query result identifier
- in: query
name: result_format
schema:
type: string
enum: [csv, xlsx]
description: The output format (csv or xlsx)
responses:
200:
description: SQL query results
content:
text/csv:
schema:
type: string
application/vnd.openxmlformats-officedocument.spreadsheetml.sheet:
schema:
type: string
400:
$ref: '#/components/responses/400'
401:
@@ -176,31 +185,34 @@ def export_csv(self, client_id: str) -> Union[CsvResponse, Response]:
"""
result_format = request.args.get('result_format')
result = SqlResultExportCommand(client_id=client_id, result_format=result_format).run()
if result_format == ChartDataResultFormat.XLSX:
return send_file(path_or_file=result,
mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
as_attachment=True,
download_name="data.xlsx"
)

query, data, row_count = result["query"], result["data"], result["count"]
quoted_name = parse.quote(query.name)

if result_format == ChartDataResultFormat.XLSX:
response = XlsxResponse(
data,
headers=generate_download_headers("xlsx", quoted_name)
)
event_format = "xlsx"
else:
response = CsvResponse(
data,
headers=generate_download_headers("csv", quoted_name)
)
event_format = "csv"

quoted_csv_name = parse.quote(query.name)
response = CsvResponse(
data, headers=generate_download_headers("csv", quoted_csv_name)
)
event_info = {
"event_type": "data_export",
"client_id": client_id,
"row_count": row_count,
"database": query.database.name,
"schema": query.schema,
"sql": query.sql,
"exported_format": "csv",
"exported_format": event_format,
}
event_rep = repr(event_info)
logger.debug(
"CSV exported: %s", event_rep, extra={"superset_event": event_info}
"Data exported: %s", event_rep, extra={"superset_event": event_info}
)
return response

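For context, a minimal self-contained Flask sketch of the dispatch this endpoint now performs: CSV text or XLSX bytes, selected by the result_format query parameter. It is not the Superset endpoint itself; the route, names, and in-memory DataFrame are illustrative, and an Excel writer engine such as openpyxl is assumed to be installed.

from io import BytesIO

import pandas as pd
from flask import Flask, Response, request

app = Flask(__name__)

XLSX_MIMETYPE = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"


@app.route("/export/<int:client_id>")
def export_data(client_id: int) -> Response:
    # Stand-in for the query results the real export command would fetch.
    df = pd.DataFrame({"id": [1, 2], "name": ["a", "b"]})
    result_format = request.args.get("result_format", "csv")

    if result_format == "xlsx":
        buffer = BytesIO()
        df.to_excel(buffer, index=False)  # needs an Excel engine, e.g. openpyxl
        payload, mimetype, ext = buffer.getvalue(), XLSX_MIMETYPE, "xlsx"
    else:
        payload, mimetype, ext = df.to_csv(index=False), "text/csv", "csv"

    # Attachment header analogous to generate_download_headers in the diff above.
    return Response(
        payload,
        mimetype=mimetype,
        headers={
            "Content-Disposition": f"attachment; filename=query_{client_id}.{ext}"
        },
    )
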
19 changes: 11 additions & 8 deletions superset/sqllab/commands/export.py
@@ -31,7 +31,7 @@
from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery
from superset.sqllab.limiting_factor import LimitingFactor
from superset.utils import core as utils, csv
from superset.utils import core as utils, csv, excel
from superset.views.utils import _deserialize_results_payload

config = app.config
@@ -94,7 +94,7 @@ def run(
blob = None
if results_backend and self._query.results_key:
logger.info(
"Fetching CSV from results backend [%s]", self._query.results_key
"Fetching %s from results backend [%s]", self.result_format, self._query.results_key
)
blob = results_backend.get(self._query.results_key)
if blob:
@@ -112,9 +112,9 @@ def run(
columns=[c["name"] for c in obj["columns"]],
)

logger.info("Using pandas to convert to CSV")
logger.info("Using pandas to convert to %s format", self.result_format)
else:
logger.info("Running a query to turn into CSV")
logger.info("Running a query to turn into %s", self.result_format)
if self._query.select_sql:
sql = self._query.select_sql
limit = None
@@ -129,13 +129,16 @@ def run(
# remove extra row from `increased_limit`
limit -= 1
df = self._query.database.get_df(sql, self._query.schema)[:limit]

if self.result_format == ChartDataResultFormat.XLSX:
xlsx_data = csv.df_to_escaped_xlsx(df)
return xlsx_data
csv_data = csv.df_to_escaped_csv(df, **config["CSV_EXPORT"], from_sqllab=True)
data = excel.df_to_excel(df, **config["EXCEL_EXPORT"])
elif self.result_format == ChartDataResultFormat.CSV:
data = csv.df_to_escaped_csv(df, **config["CSV_EXPORT"], from_sqllab=True)
else:
raise ValueError(f"Unsupported result format: {self.result_format}")

return {
"query": self._query,
"count": len(df.index),
"data": csv_data,
"data": data,
}
31 changes: 5 additions & 26 deletions superset/utils/csv.py
@@ -14,11 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import io
import logging
import re
import urllib.request
from typing import Any, Optional
from typing import Any, Optional, Union
from urllib.error import URLError

import numpy as np
@@ -66,39 +65,19 @@ def escape_value(value: str) -> str:


def df_to_escaped_csv(df: pd.DataFrame, **kwargs: Any) -> Any:
escape_values = lambda v: escape_value(v) if isinstance(v, str) else v
def escape_values(v: Any) -> Union[str, Any]:
return escape_value(v) if isinstance(v, str) else v

# Escape csv headers
df = df.rename(columns=escape_values)

# Escape csv values
if kwargs.get("from_sqllab"):
kwargs.pop("from_sqllab")
return df.to_csv(**kwargs)
return df.to_dict(orient="records")


def df_to_escaped_xlsx(df: pd.DataFrame, **kwargs: Any) -> io.BytesIO:
escape_values = lambda v: escape_value(v) if isinstance(v, str) else v

# Escape xlsx headers
df = df.rename(columns=escape_values)

# Convert timezone-aware timestamps to timezone-naive
for col in df.select_dtypes(include=['datetime64[ns, UTC]', 'datetimetz']).columns:
df[col] = df[col].dt.tz_localize(None)

excel_writer = io.BytesIO()
# Escape xlsx values
for name, column in df.items():
if column.dtype == np.dtype(object):
for idx, value in enumerate(column.values):
if isinstance(value, str):
df.at[idx, name] = escape_value(value)
df.to_excel(excel_writer, startrow=0, merge_cells=False,
sheet_name="Sheet_1", index_label=None, index=False)
excel_writer.seek(0)
return excel_writer


def get_chart_csv_data(
chart_url: str, auth_cookies: Optional[dict[str, str]] = None
) -> Optional[bytes]:
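
The body of escape_value is not shown in this hunk; as a generic sketch of the technique it is named after (illustration only, not the Superset implementation), spreadsheet formula injection is usually neutralised by prefixing any value that starts with a formula character with a single quote so it opens as plain text:

def escape_value(value: str) -> str:
    # Quote anything a spreadsheet would evaluate as a formula.
    if value.startswith(("=", "+", "-", "@")):
        return f"'{value}"
    return value


assert escape_value("=SUM(A1:A2)") == "'=SUM(A1:A2)"
assert escape_value("plain text") == "plain text"
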
44 changes: 41 additions & 3 deletions superset/utils/excel.py
@@ -19,16 +19,54 @@

import pandas as pd

from superset.utils.core import GenericDataType


def quote_formulas(df: pd.DataFrame) -> pd.DataFrame:
"""
Make sure to quote any formulas for security reasons.
"""
formula_prefixes = {"=", "+", "-", "@"}

for col in df.select_dtypes(include="object").columns:
df[col] = df[col].apply(
lambda x: (
f"'{x}"
if isinstance(x, str) and len(x) and x[0] in formula_prefixes
else x
)
)

return df


def df_to_excel(df: pd.DataFrame, **kwargs: Any) -> Any:
output = io.BytesIO()

# timezones are not supported
for column in df.select_dtypes(include=["datetimetz"]).columns:
df[column] = df[column].astype(str)
# make sure formulas are quoted, to prevent malicious injections
df = quote_formulas(df)

# remove timezones from datetime columns
for col in df.select_dtypes(include=["datetimetz"]).columns:
df[col] = df[col].dt.tz_localize(None)

# pylint: disable=abstract-class-instantiated
with pd.ExcelWriter(output, engine="xlsxwriter") as writer:
df.to_excel(writer, **kwargs)

return output.getvalue()


def apply_column_types(
df: pd.DataFrame, column_types: list[GenericDataType]
) -> pd.DataFrame:
for column, column_type in zip(df.columns, column_types):
if column_type == GenericDataType.NUMERIC:
try:
df[column] = pd.to_numeric(df[column])
except ValueError:
df[column] = df[column].astype(str)
elif pd.api.types.is_datetime64tz_dtype(df[column]):
# timezones are not supported
df[column] = df[column].dt.tz_localize(None)
return df
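
A minimal usage sketch of the helper above, assuming a Superset checkout with its dependencies (including the xlsxwriter engine used by df_to_excel) is importable; the formula-like string gets a leading quote and the timezone-aware column is made naive before the workbook bytes are written:

import pandas as pd

from superset.utils import excel

df = pd.DataFrame(
    {
        "label": ["=SUM(A1:A2)", "plain text"],  # first value gets quoted
        "ts": pd.to_datetime(["2024-12-03", "2024-12-04"]).tz_localize("UTC"),
    }
)

xlsx_bytes = excel.df_to_excel(df, index=False)  # bytes, suitable for XlsxResponse
with open("export.xlsx", "wb") as fh:
    fh.write(xlsx_bytes)
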
4 changes: 2 additions & 2 deletions superset/views/base.py
@@ -687,10 +687,10 @@ class CsvResponse(Response):

class XlsxResponse(Response):
"""
Override Response to use xlsx mimetype
Override Response to take into account xlsx encoding from config.py
"""

charset = "utf-8"
charset = conf["EXCEL_EXPORT"].get("encoding", "utf-8")
default_mimetype = (
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
)
4 changes: 2 additions & 2 deletions superset/viz.py
@@ -67,7 +67,7 @@
VizData,
VizPayload,
)
from superset.utils import core as utils, csv
from superset.utils import core as utils, csv, excel
from superset.utils.cache import set_and_log_cache
from superset.utils.core import (
apply_max_row_limit,
@@ -695,7 +695,7 @@ def get_xlsx(self, mt_cl: dict = None) -> BytesIO:
for column_df in df.columns:
df.rename(columns={column_df: mt_cl.get(column_df) or column_df},
inplace=True)
return csv.df_to_escaped_xlsx(df)
return excel.df_to_excel(df)

@deprecated(deprecated_in="3.0")
def get_data(self, df: pd.DataFrame) -> VizData: # pylint: disable=no-self-use