Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature add sql formatter #84

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ __pycache__
/cli_helpers_dev
.idea/
.cache/
.vscode/
**/.ropeproject/
*.swp

1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ This project receives help from these awesome contributors:
- Mel Dafert
- Andrii Kohut
- Roland Walker
- Liu Zhao (astroshot)

Thanks
------
Expand Down
4 changes: 4 additions & 0 deletions CHANGELOG
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
Changelog
=========

Features
-------------
* Add a new formatter that exports query results as SQL statements (formats such as sql-insert and sql-update).

TBD
-------------
* don't escape newlines, etc. in ascii tables, and add ascii_escaped table format
Expand Down
108 changes: 108 additions & 0 deletions cli_helpers/tabular_output/sql_output_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# -*- coding: utf-8 -*-

# Output formats handled by this adapter:
#   sql-insert    one multi-row INSERT statement for the whole result
#   sql-update    one UPDATE statement per row; the first column goes into
#                 the WHERE clause and the remaining columns into SET
#   sql-update-N  like sql-update, but the first N columns go into WHERE
supported_formats = (
    "sql-insert",
    "sql-update",
    "sql-update-1",
    "sql-update-2",
)

# Preprocessors registered together with the adapter; none are used.
preprocessors = ()


def escape_for_sql_statement(value):
    """
    Render a single Python value as a SQL literal.

    Parameters:
        value: the cell value to render
    Returns:
        ``NULL`` for None, an ``X'..'`` hex literal for bytes/bytearray, and
        for anything else its ``str()`` form wrapped in single quotes with
        embedded single quotes doubled.
    """
    if value is None:
        # Quoting None would produce the four-character string 'None'.
        return "NULL"
    if isinstance(value, (bytes, bytearray)):
        return f"X'{value.hex()}'"
    # Double embedded quotes so the value cannot terminate the literal early.
    # NOTE: backslash sequences are left untouched.
    escaped = str(value).replace("'", "''")
    return f"'{escaped}'"


def adapter(data, headers, table_format=None, **kwargs):
    """
    Render a query result as SQL INSERT or UPDATE statements.

    This is a generator: output is yielded one line at a time, and nothing
    below (including the ValueError) runs until iteration starts.

    Parameters:
        data: query result, an iterable of rows
        headers: column names, in the same order as the values of each row
        table_format: values from supported_formats. "sql-insert" yields one
            multi-row INSERT; "sql-update" / "sql-update-N" yield one UPDATE
            per row with the first N columns (default 1) in the WHERE clause
            and the remaining columns in SET. Any other value yields nothing.
        kwargs:
            extract_tables: extract_tables function. For example, in
                pgcli.packages.parseutils.tables there is a function
                extract_tables. It is called with the query text and must
                return a sequence of table references whose first two items
                are (schema, name).
            delimiter: Character that surrounds a table or column name so it
                cannot conflict with sql keywords. For example, mysql uses `
                and postgres uses ". Defaults to " when not a string.

    Raises:
        ValueError: if no extract_tables function was supplied.
    """
    # NOTE(review): a reviewer suggested the CLI could instead pass the
    # tables themselves in kwargs.
    extract_table_func = kwargs.get("extract_tables")
    if not extract_table_func:
        raise ValueError("extract_tables function should be registered first")

    delimiter = kwargs.get("delimiter")
    if not isinstance(delimiter, str):
        delimiter = '"'

    def quote(identifier):
        return f"{delimiter}{identifier}{delimiter}"

    # `formatter` is the module-level global set by register_new_formatter.
    tables = extract_table_func(formatter.query)
    if tables:
        # NOTE(review): only the first table is used. Reviewers proposed
        # raising an error when the query references more than one table.
        schema, name = tables[0][:2]
        # Quote schema and table separately; quoting "schema.name" as a whole
        # would name a single identifier that contains a dot.
        table_name = f"{quote(schema)}.{quote(name)}" if schema else quote(name)
    else:
        # No table could be extracted: fall back to a placeholder name.
        table_name = quote("DUAL")

    if table_format == "sql-insert":
        columns = ", ".join(quote(column) for column in headers)
        wrote_header = False
        for row in data:
            if not wrote_header:
                yield f"INSERT INTO {table_name} ({columns}) VALUES"
            values = ", ".join(escape_for_sql_statement(value) for value in row)
            yield "{}({})".format(", " if wrote_header else " ", values)
            wrote_header = True
        # With no rows, emit nothing rather than an INSERT without values.
        if wrote_header:
            yield ";"
    elif table_format and table_format.startswith("sql-update"):
        parts = table_format.split("-")
        # "sql-update-N" carries the number of key columns; plain
        # "sql-update" uses one.
        keys = int(parts[-1]) if len(parts) > 2 else 1
        for row in data:
            yield f"UPDATE {table_name} SET"
            prefix = " "
            for column, value in zip(headers[keys:], row[keys:]):
                yield f"{prefix}{quote(column)} = {escape_for_sql_statement(value)}"
                prefix = ", "
            where = " AND ".join(
                f"{quote(column)} = {escape_for_sql_statement(value)}"
                for column, value in zip(headers[:keys], row[:keys])
            )
            yield f"WHERE {where};"


def register_new_formatter(TabularOutputFormatter, **kwargs):
    """
    Register every format in supported_formats on the given formatter class.

    Also stores the class in the module-level global ``formatter``.

    Parameters:
        TabularOutputFormatter: default TabularOutputFormatter imported from cli_helpers
        kwargs: keyword arguments registered with each format, typically the
            delimiter and the extract_tables function.
            For example {"delimiter": "`", "extract_tables": extract_tables}
    """
    global formatter
    formatter = TabularOutputFormatter
    for sql_format in supported_formats:
        # Build a separate dict per format. Reusing one dict would hand the
        # same object to every registration, so if the formatter keeps a
        # reference, all formats would see the last table_format written.
        format_kwargs = dict(kwargs, table_format=sql_format)
        TabularOutputFormatter.register_new_formatter(
            sql_format, adapter, preprocessors, format_kwargs
        )
159 changes: 159 additions & 0 deletions tests/tabular_output/test_sql_output_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# -*- coding: utf-8 -*-

from collections import namedtuple

from cli_helpers.tabular_output import TabularOutputFormatter
from cli_helpers.tabular_output.sql_output_adapter import escape_for_sql_statement, adapter, register_new_formatter

TableReference = namedtuple(
    "TableReference", ["schema", "name", "alias", "is_function"]
)


def _table_ref(self):
    """Return the alias if set, else the name, double-quoted when needed.

    The name is returned as-is when it is all lowercase or already starts
    with a double quote; otherwise it is wrapped in double quotes.
    """
    if self.alias:
        return self.alias
    name = self.name
    if name.islower() or name[0] == '"':
        return name
    return '"' + name + '"'


TableReference.ref = property(_table_ref)


def test_escape_for_sql_statement_bytes():
    """A bytes value is rendered as an X'..' literal of its hex digits."""
    raw = b"837124ab3e8dc0f"
    assert escape_for_sql_statement(raw) == "X'383337313234616233653864633066'"


def __mock_extract_tables(sql):
    """
    Stand-in for a real extract_tables function.

    in mycli, pass `mycli.packages.parseutils.extract_tables`
    in pgcli, pass `pgcli.packages.parseutils.extract_tables`

    :param sql: sql query (not inspected by this mock)
    :return: a one-element tuple holding a TableReference for table "user"
    """
    return (
        TableReference(schema=None, name='user', alias='"user"', is_function=False),
    )


def test_output_sql_insert():
    """sql-insert yields one INSERT statement quoted with the given delimiter."""
    register_new_formatter(TabularOutputFormatter)
    data = [
        [
            1,
            "Jackson",
            "[email protected]",
            "132454789",
            "",
            "2022-09-09 19:44:32.712343+08",
            "2022-09-09 19:44:32.712343+08",
        ]
    ]
    header = ["id", "name", "email", "phone", "description", "created_at", "updated_at"]
    table_format = "sql-insert"
    kwargs = {
        "column_types": [int, str, str, str, str, str, str],
        "sep_title": "RECORD {n}",
        "sep_character": "-",
        "sep_length": (1, 25),
        "missing_value": "<null>",
        "integer_format": "",
        "float_format": "",
        "disable_numparse": True,
        "preserve_whitespace": True,
        "max_field_width": 500,
        "extract_tables": __mock_extract_tables,
    }
    # The adapter reads the query text from the registered formatter class.
    TabularOutputFormatter.query = 'SELECT * FROM "user";'

    values_line = (
        " ('1', 'Jackson', '[email protected]', '132454789', '', "
        + "'2022-09-09 19:44:32.712343+08', '2022-09-09 19:44:32.712343+08')"
    )
    cases = (
        # For postgresql
        (
            '"',
            [
                'INSERT INTO "user" ("id", "name", "email", "phone", "description", "created_at", "updated_at") VALUES',
                values_line,
                ";",
            ],
        ),
        # For mysql
        (
            "`",
            [
                'INSERT INTO `user` (`id`, `name`, `email`, `phone`, `description`, `created_at`, `updated_at`) VALUES',
                values_line,
                ";",
            ],
        ),
    )
    for delimiter, expected in cases:
        kwargs["delimiter"] = delimiter
        output = adapter(data, header, table_format=table_format, **kwargs)
        assert expected == list(output)


def test_output_sql_update_pg():
    """sql-update yields one UPDATE per row keyed on the first column.

    Despite the name, both the postgresql and the mysql delimiter are checked.
    """
    register_new_formatter(TabularOutputFormatter)
    data = [
        [
            1,
            "Jackson",
            "[email protected]",
            "132454789",
            "",
            "2022-09-09 19:44:32.712343+08",
            "2022-09-09 19:44:32.712343+08",
        ]
    ]
    header = ["id", "name", "email", "phone", "description", "created_at", "updated_at"]
    table_format = "sql-update"
    kwargs = {
        "column_types": [int, str, str, str, str, str, str],
        "sep_title": "RECORD {n}",
        "sep_character": "-",
        "sep_length": (1, 25),
        "missing_value": "<null>",
        "integer_format": "",
        "float_format": "",
        "disable_numparse": True,
        "preserve_whitespace": True,
        "max_field_width": 500,
        "extract_tables": __mock_extract_tables,
    }
    # The adapter reads the query text from the registered formatter class.
    TabularOutputFormatter.query = 'SELECT * FROM "user";'

    cases = (
        # For postgresql
        (
            '"',
            [
                'UPDATE "user" SET',
                ' "name" = \'Jackson\'',
                ', "email" = \'[email protected]\'',
                ', "phone" = \'132454789\'',
                ', "description" = \'\'',
                ', "created_at" = \'2022-09-09 19:44:32.712343+08\'',
                ', "updated_at" = \'2022-09-09 19:44:32.712343+08\'',
                'WHERE "id" = \'1\';',
            ],
        ),
        # For mysql
        (
            "`",
            [
                'UPDATE `user` SET',
                " `name` = 'Jackson'",
                ", `email` = '[email protected]'",
                ", `phone` = '132454789'",
                ", `description` = ''",
                ", `created_at` = '2022-09-09 19:44:32.712343+08'",
                ", `updated_at` = '2022-09-09 19:44:32.712343+08'",
                "WHERE `id` = '1';",
            ],
        ),
    )
    for delimiter, expected in cases:
        kwargs["delimiter"] = delimiter
        output = adapter(data, header, table_format=table_format, **kwargs)
        assert expected == list(output)