Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

completer | dbmetadata: All calls to ODBC/SQL migrated away from completer #27

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion odbcli/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ def __init__(
# Loop over side-bar when moving past the element on the bottom
self.obj_list[len(self.obj_list) - 1].next_object = self.obj_list[0]
self._selected_object = self.obj_list[0]
self.completer = MssqlCompleter(smart_completion = True, get_conn = lambda: self.active_conn)
self.completer = MssqlCompleter(smart_completion = True,
get_conn = lambda: self.active_conn)

self.application = self._create_application()

Expand Down
171 changes: 46 additions & 125 deletions odbcli/completion/mssqlcompleter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from prompt_toolkit.completion import Completer, Completion, PathCompleter
from prompt_toolkit.document import Document
from ..conn import sqlConnection
from ..dbmetadata import DbMetadata
from .sqlcompletion import (Blank, FromClauseItem, suggest_type, Special, NamedQuery,
Database, Schema, Table, Function, Column, View,
Keyword, Datatype, Alias, Path, JoinCondition, Join)
Expand Down Expand Up @@ -129,12 +130,12 @@ def __init__(
self.logger.debug("Completer instantiated")

@property
def active_conn(self) -> sqlConnection:
return self._get_conn()
def dbmetadata(self) -> DbMetadata:
return self._get_conn().dbmetadata

def escape_name(self, name):
if self.active_conn is not None:
name = self.active_conn.escape_name(name)
if self.dbmetadata is not None:
name = self.dbmetadata.escape_name(name)

return name

Expand All @@ -143,21 +144,17 @@ def escape_schema(self, name):

def unescape_name(self, name):
""" Unquote a string."""
if self.active_conn is not None:
name = self.active_conn.unescape_name(name)
if self.dbmetadata is not None:
name = self.dbmetadata.unescape_name(name)

return name

def escape_names(self, names):
if self.active_conn is not None:
names = self.active_conn.escape_names(names)
if self.dbmetadata is not None:
names = self.dbmetadata.escape_names(names)

return names

def extend_database_names(self, databases):
databases = self.escape_names(databases)
self.databases.extend(databases)

def extend_keywords(self, additional_keywords):
self.keywords = self.keywords + additional_keywords
# OG: Unclear what the roll of all_completions is
Expand All @@ -179,8 +176,7 @@ def extend_functions(self, func_data):

# dbmetadata['schema_name']['functions']['function_name'] should return
# the function metadata namedtuple for the corresponding function
conn = self.active_conn
metadata = conn.dbmetadata.data
metadata = self.dbmetadata.data
submeta = metadata['function']

for f in func_data:
Expand All @@ -203,8 +199,7 @@ def _refresh_arg_list_cache(self):
# This is used when suggesting functions, to avoid the latency that would result
# if we'd recalculate the arg lists each time we suggest functions (in
# large DBs)
conn = self.active_conn
metadata = conn.dbmetadata.data
metadata = self.dbmetadata.data
self._arg_list_cache = {
usage: {
meta: self._arg_list(meta, usage)
Expand All @@ -226,8 +221,7 @@ def extend_foreignkeys(self, fk_data):
# These are added as a list of ForeignKey namedtuples to the
# ColumnMetadata namedtuple for both the child and parent
# OG: This needs catalog facelift
conn = self.active_conn
metadata = conn.dbmetadata.data
metadata = self.dbmetadata.data
submeta = metadata['table']

for fk in fk_data:
Expand All @@ -249,8 +243,7 @@ def extend_datatypes(self, type_data):
# dbmetadata['datatypes'][schema_name][type_name] should store type
# metadata, such as composite type field names. Currently, we're not
# storing any metadata beyond typename, so just store None
conn = self.active_conn
metadata = conn.dbmetadata.data
metadata = self.dbmetadata.data

for t in type_data:
schema, type_name = self.escape_names(t)
Expand All @@ -276,8 +269,7 @@ def reset_completions(self):
self.special_commands = []
# search_path at this point is not used
self.search_path = []
conn = self.active_conn
conn.dbmetadata.reset_metadata()
self.dbmetadata.reset_metadata()
# OG: Unclear what the roll of all_completions is
#self.all_completions = set(self.keywords + self.functions)

Expand Down Expand Up @@ -676,47 +668,20 @@ def filt(_):
return matches

def get_schema_matches(self, suggestion, word_before_cursor):
conn = self.active_conn
if suggestion.parent:
catalog_u = self.unescape_name(suggestion.parent)
else:
catalog_u = conn.current_catalog()
catalog_u = self.dbmetadata.current_catalog()

catalog_e = self.escape_name(catalog_u)
self.logger.debug("get_schema_matches: parent %s", suggestion.parent)
# OG: Note here, if there is even a single schema in [catalog_e].keys()
# we'll happily return a potentially incomplete result set.
schema_names_e = conn.dbmetadata.get_schemas(catalog = catalog_e)
schema_names_e = self.dbmetadata.get_schemas(catalog = catalog_e)

if schema_names_e is None:
# Asking for schema in a non-existant catalog
return []

if len(schema_names_e) == 0:
# Catalog exists in dbmetadata but is empty
if suggestion.parent:
# Looking for schemas in a specified catalog
schema_names = []
# Attempt list_schemas
schema_names = conn.list_schemas(
catalog = conn.sanitize_search_string(catalog_u))

if len(schema_names) < 1:
res = conn.find_tables(
catalog = conn.sanitize_search_string(catalog_u),
schema = "",
table = "",
type = "")
schema_names = [r.schema for r in res]
else:
# Looking for schemas in current catalog
schema_names = conn.list_schemas()

schema_names = set(schema_names)

schema_names_e = self.escape_names(schema_names)
conn.dbmetadata.extend_schemas(catalog = catalog_e, names = schema_names_e)

return self.find_matches(
word_before_cursor, schema_names_e, meta='schema')

Expand Down Expand Up @@ -836,11 +801,7 @@ def get_alias_matches(self, suggestion, word_before_cursor):
meta='table alias')

def get_database_matches(self, _, word_before_cursor):
conn = self.active_conn
catalogs_e = conn.dbmetadata.get_catalogs()
if catalogs_e is None and (conn.connected()):
catalogs_e = self.escape_names(conn.list_catalogs())
conn.dbmetadata.extend_catalogs(catalogs_e)
catalogs_e = self.dbmetadata.get_catalogs()

return self.find_matches(word_before_cursor, catalogs_e,
meta='catalog')
Expand Down Expand Up @@ -921,10 +882,9 @@ def populate_scoped_cols(self, scoped_tbls, local_tbls=()):
:return: {TableReference:{colname:ColumnMetaData}}

"""
conn = self.active_conn
ctes = dict((normalize_ref(t.name), t.columns) for t in local_tbls)
columns = OrderedDict()
metadata = conn.dbmetadata.data
metadata = self.dbmetadata.data

def addcols(schema, rel, alias, reltype, cols):
tbl = TableReference(schema, rel, alias, reltype == 'functions')
Expand Down Expand Up @@ -978,7 +938,7 @@ def addcols(catalog, schema, rel, alias, reltype, cols):
if tbl.catalog:
catalog_u = self.unescape_name(tbl.catalog)
else:
catalog_u = self.active_conn.current_catalog()
catalog_u = self.dbmetadata.current_catalog()

# TODO: What if no schema? Possible in some DBMS
if tbl.schema:
Expand Down Expand Up @@ -1007,21 +967,21 @@ def addcols(catalog, schema, rel, alias, reltype, cols):
# cols = func.fields()
# addcols(schema, relname, tbl.alias, 'functions', cols)
else:
conn = self.active_conn
# Per SQLColumns spec: CatalogName cannot contain a string search pattern
res = conn.find_columns(
catalog = catalog_u,
schema = conn.sanitize_search_string(schema_u),
table = conn.sanitize_search_string(relname_u),
column = "%")
if len(res):
cols = [ColumnMetadata(
name = col.column,
datatype = col.data_type,
has_default = col.default,
default = col.default
) for col in res]
addcols(catalog, schema, relname, tbl.alias, "table", cols)
for reltype in ("table", "view"):
res = self.dbmetadata.get_columns(
catalog = catalog,
schema = schema,
name = relname,
obj_type = reltype)

if res is not None and len(res):
cols = [ColumnMetadata(
name = col.column,
datatype = col.data_type,
has_default = col.default,
default = col.default
) for col in res]
addcols(catalog, schema, relname, tbl.alias, reltype, cols)

return columns

Expand All @@ -1032,8 +992,7 @@ def _get_schemas(self, obj_typ, schema):
:param schema is the schema qualification input by the user (if any)

"""
conn = self.active_conn
metadata = conn.dbmetadata.data
metadata = self.dbmetadata.data
submeta = metadata[obj_typ]
if schema:
schema = self.escape_name(schema)
Expand All @@ -1050,8 +1009,7 @@ def populate_schema_objects(self, schema, obj_type):
:param schema is the schema qualification input by the user (if any)

"""
conn = self.active_conn
metadata = conn.dbmetadata.data
metadata = self.dbmetadata.data

return [
SchemaObject(
Expand All @@ -1070,15 +1028,14 @@ def populate_objects(self, catalog, schema, obj_type):
"""
ret = []
obj_names = []
conn = self.active_conn
self.logger.debug("populate_objects(%s): Called for %s.%s",
obj_type, catalog, schema)
if catalog is None and schema is None:
catalog = ""
schema = ""
if catalog is None:
# Set to current catalog
catalog = conn.current_catalog()
catalog = self.dbmetadata.current_catalog()
if catalog is None or catalog == "":
# Don't allow "".[schema]
# Interpret this to mean [schema].""
Expand All @@ -1091,55 +1048,20 @@ def populate_objects(self, catalog, schema, obj_type):
# dbmetadata always escaped?
catalog_e = self.escape_name(self.unescape_name(catalog))
schema_e = self.escape_name(self.unescape_name(schema))
obj_names = conn.dbmetadata.get_objects(catalog = catalog_e, schema = schema_e, obj_type = obj_type)
obj_names = self.dbmetadata.get_objects(catalog = catalog_e, schema = schema_e, obj_type = obj_type)
if obj_names is None:
self.logger.debug("populate_objects(%s): Called for %s.%s, catalog/schema not found",
obj_type, catalog, schema)
return []

if len(obj_names) == 0:
# catalog.schema were found but dbmetadata had no information as to
# content. So let's attempt to query
obj_names = []
self.logger.debug("populate_objects(%s): Did not find %s.%s metadata. Will query.", obj_type, catalog_e, schema_e)
# Special case: Look for tables without catalog/schema
if catalog == "" and schema == "":
res = conn.find_tables(
catalog = "\x00",
schema = "\x00",
table = "",
type = obj_type)
else:
res = conn.find_tables(
catalog = conn.sanitize_search_string(
self.unescape_name(catalog)),
schema = conn.sanitize_search_string(
self.unescape_name(schema)),
table = "",
type = obj_type)
for r in res:
name_e = self.escape_name(r.name)
ret.append(
SchemaObject(
name=name_e,
schema=schema_e,
catalog=catalog_e
)
)
obj_names.append(name_e)
self.logger.debug("populate_objects(%s): Query complete %s.%s", obj_type, catalog_e, schema_e)
conn.dbmetadata.extend_objects(
catalog = catalog_e, schema = schema_e,
names = obj_names, obj_type = obj_type)
else:
for name_e in obj_names:
ret.append(
SchemaObject(
name=name_e,
schema=schema_e, #should this be r.schema
catalog=catalog_e #should this be r.catalog
)
for name_e in obj_names:
ret.append(
SchemaObject(
name=name_e,
schema=schema_e, #should this be r.schema
catalog=catalog_e #should this be r.catalog
)
)
return ret

def populate_functions(self, schema, filter_func):
Expand All @@ -1152,8 +1074,7 @@ def populate_functions(self, schema, filter_func):

"""

conn = self.active_conn
metadata = conn.dbmetadata.data
metadata = self.dbmetadata.data
# Because of multiple dispatch, we can have multiple functions
# with the same name, which is why `for meta in metas` is necessary
# in the comprehensions below
Expand Down
10 changes: 7 additions & 3 deletions odbcli/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
self.username = username
self.password = password
self.logger = getLogger(__name__)
self.dbmetadata = DbMetadata()
self.dbmetadata = DbMetadata(self)
self._quotechar = None
self._search_escapechar = None
self._search_escapepattern = None
Expand Down Expand Up @@ -450,8 +450,12 @@ def find_procedure_columns(

def current_catalog(self) -> str:
if self.conn.connected():
return self.conn.catalog_name
return None
with self._lock:
res = self.conn.catalog_name
else:
res = None

return res

def connected(self) -> bool:
return self.conn.connected()
Expand Down
Loading