Skip to content

Commit

Permalink
add python test and cpp sql_router_test
Browse files Browse the repository at this point in the history
  • Loading branch information
vagetablechicken committed Jul 6, 2022
1 parent cb2f0bd commit 9aa6122
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 26 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/cicd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ jobs:
run: |
make test
- name: run sql_router_test
id: sql_router_test
run: |
- name: run sql_sdk_test
id: sql_sdk_test
run: |
Expand Down
56 changes: 32 additions & 24 deletions python/openmldb/dbapi/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# fmt:off
import sys
from pathlib import Path

# add parent directory
sys.path.append(Path(__file__).parent.parent.as_posix())
from sdk import sdk as sdk_module
Expand All @@ -26,6 +27,7 @@
from typing import List
from typing import Union
import re

# fmt:on

# Globals
Expand Down Expand Up @@ -218,48 +220,53 @@ def callproc(self, procname, parameters=()):
return self

@classmethod
def __add_row_to_builder(cls, row, hold_idxs, schema, builder, appendMap):
for i in range(len(hold_idxs)):
idx = hold_idxs[i]
name = schema.GetColumnName(idx)
colType = schema.GetColumnType(idx)
def __add_row_to_builder(cls, row, hole_idxes, schema, builder, append_map):
# hole idxes is in stmt order, we should use schema order to append
hole_pairs = [(hole_idxes[i], i) for i in range(len(hole_idxes))]
hole_pairs.sort(key=lambda elem: elem[0])

for i in range(len(hole_pairs)):
idx = hole_pairs[i][0]
name = schema.GetColumnName(idx)
col_type = schema.GetColumnType(idx)
row_idx = hole_pairs[i][1]
if isinstance(row, tuple):
ok = appendMap[colType](row[i])
ok = append_map[col_type](row[row_idx])
if not ok:
raise DatabaseError(
"error at append data seq {}".format(i))
"error at append data seq {}".format(row_idx))
elif isinstance(row, dict):
if row[name] is None:
builder.AppendNULL()
continue
ok = appendMap[colType](row[name])
ok = append_map[col_type](row[name])
if not ok:
raise DatabaseError(
"error at append data seq {}".format(i))
"error at append data seq {}".format(row_idx))
else:
raise DatabaseError(
"error at append data seq {} for unsupported type".format(i))
"error at append data seq {} for unsupported row type".format(row_idx))

def execute(self, operation, parameters=()):
command = operation.strip(' \t\n\r') if operation else None
if command is None:
raise Exception("None operation")
if insertRE.match(command):
questionMarkCount = command.count('?')
if questionMarkCount > 0:
if len(parameters) != questionMarkCount:
question_mark_count = command.count('?')
if question_mark_count > 0:
if len(parameters) != question_mark_count:
raise DatabaseError("parameters is not enough")
ok, builder = self.connection._sdk.getInsertBuilder(
self.db, command)
if not ok:
raise DatabaseError("get insert builder fail")
schema = builder.GetSchema()
holdIdxs = builder.GetHoleIdx()
appendMap = self.__get_append_map(
builder, parameters, holdIdxs, schema)
# holeIdxes is in stmt column order
hole_idxes = builder.GetHoleIdx()
append_map = self.__get_append_map(
builder, parameters, hole_idxes, schema)
self.__add_row_to_builder(
parameters, holdIdxs, schema, builder, appendMap)
parameters, hole_idxes, schema, builder, append_map)
ok, error = self.connection._sdk.executeInsert(
self.db, command, builder)
else:
Expand Down Expand Up @@ -289,10 +296,11 @@ def execute(self, operation, parameters=()):
return self

@classmethod
def __get_append_map(cls, builder, row, hold_idxs, schema):
def __get_append_map(cls, builder, row, hole_idxes, schema):
# calc str total length
str_size = 0
for i in range(len(hold_idxs)):
idx = hold_idxs[i]
for i in range(len(hole_idxes)):
idx = hole_idxes[i]
name = schema.GetColumnName(idx)
if isinstance(row, tuple):
if isinstance(row[i], str):
Expand All @@ -316,7 +324,7 @@ def __get_append_map(cls, builder, row, hold_idxs, schema):
raise DatabaseError(
"parameters type {} does not support: {}, should be tuple or dict".format(type(row), row))
builder.Init(str_size)
appendMap = {
append_map = {
sql_router_sdk.kTypeBool: builder.AppendBool,
sql_router_sdk.kTypeInt16: builder.AppendInt16,
sql_router_sdk.kTypeInt32: builder.AppendInt32,
Expand All @@ -329,15 +337,15 @@ def __get_append_map(cls, builder, row, hold_idxs, schema):
int(x.split("-")[0]), int(x.split("-")[1]), int(x.split("-")[2])),
sql_router_sdk.kTypeTimestamp: builder.AppendTimestamp
}
return appendMap
return append_map

def __insert_rows(self, rows: List[Union[tuple, dict]], hold_idxs, schema, rows_builder, command):
for row in rows:
tmp_builder = rows_builder.NewRow()
appendMap = self.__get_append_map(
append_map = self.__get_append_map(
tmp_builder, row, hold_idxs, schema)
self.__add_row_to_builder(
row, hold_idxs, schema, tmp_builder, appendMap)
row, hold_idxs, schema, tmp_builder, append_map)
ok, error = self.connection._sdk.executeInsert(
self.db, command, rows_builder)
if not ok:
Expand Down
2 changes: 1 addition & 1 deletion python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
'Programming Language :: Python :: 3',
],
install_requires=[
"sqlalchemy < 1.4.0",
"sqlalchemy <= 1.4.9",
"IPython",
"prettytable",
"pytest"
Expand Down
6 changes: 5 additions & 1 deletion python/test/dbapi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@


class TestOpenmldbDBAPI:

cursor = None

@classmethod
Expand Down Expand Up @@ -70,6 +69,11 @@ def test_select_conditioned(self):
assert 'second' in result
assert 200 in result

def test_custom_order_insert(self):
self.cursor.execute("insert into new_table (y, x) values(300, 'third');")
self.cursor.execute("insert into new_table (y, x) values(?, ?);", (300, 'third'))
self.cursor.execute("insert into new_table (y, x) values(?, ?);", {'x': 'third', 'y': 300})


if __name__ == "__main__":
sys.exit(pytest.main(["-vv", "dbapi_test.py"]))

0 comments on commit 9aa6122

Please sign in to comment.