Skip to content

Commit

Permalink
Refactor foreign key parsing completely into the _parse.py file
Browse files Browse the repository at this point in the history
Signed-off-by: Jesse Whitehouse <[email protected]>
  • Loading branch information
Jesse Whitehouse committed Oct 13, 2023
1 parent 7a4234a commit 3a73daf
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 82 deletions.
67 changes: 11 additions & 56 deletions src/databricks/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import re
from typing import Any, Optional, List
from typing import Any, Optional, List, Tuple

import sqlalchemy
from sqlalchemy import event, DDL
Expand All @@ -16,10 +16,10 @@
# This import is required to process our @compiles decorators
import databricks.sqlalchemy._types as dialect_type_impl
from databricks import sql
from databricks.sqlalchemy.utils import (
extract_identifier_groups_from_string,
from databricks.sqlalchemy._parse import (
build_fk_dict,
extract_identifiers_from_string,
extract_three_level_identifier_from_constraint_string
extract_three_level_identifier_from_constraint_string,
)

try:
Expand Down Expand Up @@ -109,56 +109,11 @@ def _extract_pk_from_dte_result(result: dict) -> ReflectedPrimaryKeyConstraint:
return {"constrained_columns": column_list, "name": name}


def _extract_single_fk_dict_from_dte_result_row(
table_name: str, schema_name: Optional[str], fk_name: str, fk_constraint_string: str
) -> dict:
"""
"""

# SQLAlchemy's ComponentReflectionTest::test_get_foreign_keys is strange in that it
# expects the `referred_schema` member of the outputted dictionary to be None if
# a `schema` argument was not passed to the dialect's `get_foreign_keys` method
referred_table_dict = extract_three_level_identifier_from_constraint_string(fk_constraint_string)
referred_table = referred_table_dict["table"]
if schema_name:
referred_schema = referred_table_dict["schema"]
else:
referred_schema = None

_extracted = extract_identifier_groups_from_string(fk_constraint_string)
constrained_columns_str, referred_columns_str = (
_extracted[0],
_extracted[1],
)

constrainted_columns = extract_identifiers_from_string(constrained_columns_str)
referred_columns = extract_identifiers_from_string(referred_columns_str)

return {
"constrained_columns": constrainted_columns,
"name": fk_name,
"referred_table": referred_table,
"referred_columns": referred_columns,
"referred_schema": referred_schema,
}


def _extract_fk_from_dte_result(
table_name: str, schema_name: Optional[str], result: dict
result: dict, schema_name: Optional[str]
) -> ReflectedForeignKeyConstraint:
"""Return a list of dictionaries with the keys:
constrained_columns
a list of column names that make up the foreign key
name
the name of the foreign key constraint
referred_table
the name of the table that the foreign key references
referred_columns
a list of column names that are referenced by the foreign key
"""Extract a list of foreign key information dictionaries from the result
of a DESCRIBE TABLE EXTENDED call.
Returns an empty list if no foreign key is defined.
Expand All @@ -173,7 +128,7 @@ def _extract_fk_from_dte_result(
"""

# find any rows that contain "FOREIGN_KEY" as the `data_type`
filtered_rows = [(k, v) for k, v in result.items() if "FOREIGN KEY" in v]
filtered_rows: List[Tuple] = [(k, v) for k, v in result.items() if "FOREIGN KEY" in v]

# bail if no foreign key was found
if not filtered_rows:
Expand All @@ -184,8 +139,8 @@ def _extract_fk_from_dte_result(
# target is a tuple of (constraint_name, constraint_string)
for target in filtered_rows:
_constraint_name, _constraint_string = target
this_constraint_dict = _extract_single_fk_dict_from_dte_result_row(
table_name, schema_name, _constraint_name, _constraint_string
this_constraint_dict = build_fk_dict(
_constraint_name, _constraint_string, schema_name
)
constraint_list.append(this_constraint_dict)

Expand Down Expand Up @@ -386,7 +341,7 @@ def get_foreign_keys(
schema_name=schema,
)

return _extract_fk_from_dte_result(table_name, schema, result)
return _extract_fk_from_dte_result(result, schema)

def get_indexes(self, connection, table_name, schema=None, **kw):
"""Return information about indexes in `table_name`.
Expand Down
145 changes: 145 additions & 0 deletions src/databricks/sqlalchemy/_parse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from typing import List, Optional
import re

"""
This module contains helper functions that can parse the contents
of DESCRIBE TABLE EXTENDED calls. Mostly wrappers around regexes.
"""

def extract_identifiers_from_string(input_str: str) -> List[str]:
"""For a string input resembling (`a`, `b`, `c`) return a list of identifiers ['a', 'b', 'c']"""

# This matches the valid character list contained in DatabricksIdentifierPreparer
pattern = re.compile(r"`([A-Za-z0-9_]+)`")
matches = pattern.findall(input_str)
return [i for i in matches]


def extract_identifier_groups_from_string(input_str: str) -> List[str]:
"""For a string input resembling :
FOREIGN KEY (`pname`, `pid`, `pattr`) REFERENCES `main`.`pysql_sqlalchemy`.`tb1` (`name`, `id`, `attr`)
Return ['(`pname`, `pid`, `pattr`)', '(`name`, `id`, `attr`)']
"""
pattern = re.compile(r"\([`A-Za-z0-9_,\s]*\)")
matches = pattern.findall(input_str)
return [i for i in matches]


def extract_three_level_identifier_from_constraint_string(input_str: str) -> dict:
"""For a string input resembling :
FOREIGN KEY (`parent_user_id`) REFERENCES `main`.`pysql_dialect_compliance`.`users` (`user_id`)
Return a dict like
{
"catalog": "main",
"schema": "pysql_dialect_compliance",
"table": "users"
}
"""
pat = re.compile(r"REFERENCES\s+(.*?)\s*\(")
matches = pat.findall(input_str)

if not matches:
return None

first_match = matches[0]
parts = first_match.split(".")

def strip_backticks(input:str):
return input.replace("`", "")

return {
"catalog": strip_backticks(parts[0]),
"schema": strip_backticks(parts[1]),
"table": strip_backticks(parts[2])
}

def _parse_fk_from_constraint_string(constraint_str: str) -> dict:
"""Build a dictionary of foreign key constraint information from a constraint string.
For example:
```
FOREIGN KEY (`pname`, `pid`, `pattr`) REFERENCES `main`.`pysql_dialect_compliance`.`tb1` (`name`, `id`, `attr`)
```
Return a dictionary like:
```
{
"constrained_columns": ["pname", "pid", "pattr"],
"referred_table": "tb1",
"referred_schema": "pysql_dialect_compliance",
"referred_columns": ["name", "id", "attr"]
}
```
Note that the constraint name doesn't appear in the constraint string so it will not
be present in the output of this function.
"""

referred_table_dict = extract_three_level_identifier_from_constraint_string(
constraint_str
)
referred_table = referred_table_dict["table"]
referred_schema = referred_table_dict["schema"]

# _extracted is a tuple of two lists of identifiers
# we assume the first immediately follows "FOREIGN KEY" and the second
# immediately follows REFERENCES $tableName
_extracted = extract_identifier_groups_from_string(constraint_str)
constrained_columns_str, referred_columns_str = (
_extracted[0],
_extracted[1],
)

constrained_columns = extract_identifiers_from_string(constrained_columns_str)
referred_columns = extract_identifiers_from_string(referred_columns_str)

return {
"constrained_columns": constrained_columns,
"referred_table": referred_table,
"referred_columns": referred_columns,
"referred_schema": referred_schema,
}

def build_fk_dict(
fk_name: str, fk_constraint_string: str, schema_name: Optional[str]
) -> dict:
"""
Given a foriegn key name and a foreign key constraint string, return a dictionary
with the following keys:
name
the name of the foreign key constraint
constrained_columns
a list of column names that make up the foreign key
referred_table
the name of the table that the foreign key references
referred_columns
a list of column names that are referenced by the foreign key
referred_schema
the name of the schema that the foreign key references.
referred schema will be None if the schema_name argument is None.
This is required by SQLAlchey's ComponentReflectionTest::test_get_foreign_keys
"""

# The foreign key name is not contained in the constraint string so we
# need to add it manually
base_fk_dict = _parse_fk_from_constraint_string(fk_constraint_string)

if not schema_name:
schema_override_dict = dict(referred_schema=None)
else:
schema_override_dict = {}

complete_foreign_key_dict = {
"name": fk_name,
**base_fk_dict,
**schema_override_dict,
}

return complete_foreign_key_dict
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import pytest
from databricks.sqlalchemy.utils import (
from databricks.sqlalchemy._parse import (
extract_identifiers_from_string,
extract_identifier_groups_from_string,
extract_three_level_identifier_from_constraint_string
extract_three_level_identifier_from_constraint_string,
build_fk_dict
)


Expand Down Expand Up @@ -47,4 +48,19 @@ def test_extract_3l_namespace_from_constraint_string():
"table": "users"
}

assert extract_three_level_identifier_from_constraint_string(input) == expected, "Failed to extract 3L namespace from constraint string"
assert extract_three_level_identifier_from_constraint_string(input) == expected, "Failed to extract 3L namespace from constraint string"

@pytest.mark.parametrize("schema", [None, "some_schema"])
def test_build_fk_dict(schema):
fk_constraint_string = "FOREIGN KEY (`parent_user_id`) REFERENCES `main`.`some_schema`.`users` (`user_id`)"

result = build_fk_dict("some_fk_name", fk_constraint_string, schema_name=schema)

assert result == {
"name": "some_fk_name",
"constrained_columns": ["parent_user_id"],
"referred_schema": schema,
"referred_table": "users",
"referred_columns": ["user_id"],
}

23 changes: 0 additions & 23 deletions src/databricks/sqlalchemy/utils.py

This file was deleted.

0 comments on commit 3a73daf

Please sign in to comment.