Add read_files_as_table and make bulk_upload_files atomic (#20)
azorej authored Oct 10, 2024
1 parent 3347240 commit 6f34f4d
Showing 9 changed files with 801 additions and 611 deletions.
2 changes: 1 addition & 1 deletion dbxio/__init__.py
@@ -4,4 +4,4 @@
from dbxio.utils import * # noqa: F403
from dbxio.volume import * # noqa: F403

__version__ = '0.4.6' # single source of truth
__version__ = '0.5.0' # single source of truth
4 changes: 2 additions & 2 deletions dbxio/core/cloud/azure/object_storage.py
@@ -70,7 +70,7 @@ def download_blob_to_file(self, blob_name: str, file_path: Union[str, Path]) ->

def break_lease(self, blob_name: str) -> None:
blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name)
BlobLeaseClient(client=blob_client).break_lease() # type: ignore
BlobLeaseClient(client=blob_client).break_lease()

def lock_blob(self, blob_name: str, force: bool = False):
blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name)
@@ -90,7 +90,7 @@ def upload_blob(self, blob_name: str, data: Union[bytes, IOBase, BinaryIO], over

def try_delete_blob(self, blob_name: str) -> None:
blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name)
lease_client = BlobLeaseClient(client=blob_client) # type: ignore
lease_client = BlobLeaseClient(client=blob_client)

try:
lease_client.break_lease()
2 changes: 2 additions & 0 deletions dbxio/delta/__init__.py
@@ -9,6 +9,7 @@
get_comment_on_table,
get_tags_on_table,
merge_table,
read_files_as_table,
read_table,
save_table_to_files,
set_comment_on_table,
@@ -33,6 +34,7 @@
'exists_table',
'infer_schema',
'merge_table',
'read_files_as_table',
'read_table',
'write_table',
'save_table_to_files',
2 changes: 1 addition & 1 deletion dbxio/delta/table.py
@@ -84,7 +84,7 @@ class Table:
default=Materialization.Table,
validator=attrs.validators.instance_of(Materialization),
)
schema: Optional[Union[dict[str, BaseType], list[dict[str, BaseType]], TableSchema]] = attrs.field(
schema: Optional[TableSchema] = attrs.field(
default=None,
converter=_table_schema_converter,
)
147 changes: 111 additions & 36 deletions dbxio/delta/table_commands.py
@@ -37,24 +37,39 @@ def exists_table(table: Union[str, Table], client: 'DbxIOClient') -> bool:
return False


def create_table(table: Union[str, Table], client: 'DbxIOClient') -> _FutureBaseResult:
"""
Creates a table in the catalog.
If a table already exists, it does nothing.
Query pattern:
CREATE TABLE IF NOT EXISTS <table_identifier> (col1 type1, col2 type2, ...)
[USING <table_format> LOCATION <location>]
[PARTITIONED BY (col1, col2, ...)]
"""
def _create_table_query(table: Union[str, Table], replace: bool, if_not_exists: bool, include_schema: bool) -> str:
dbxio_table = Table.from_obj(table)

schema_sql = ','.join([f'`{col_name}` {col_type}' for col_name, col_type in dbxio_table.schema.as_dict().items()])
query = f'CREATE TABLE IF NOT EXISTS {dbxio_table.safe_table_identifier} ({schema_sql})'
if include_schema and dbxio_table.schema:
schema_sql = f'({dbxio_table.schema.as_sql()})'
else:
schema_sql = ''
if replace:
query = 'CREATE OR REPLACE TABLE'
else:
query = 'CREATE TABLE'
if if_not_exists:
query += ' IF NOT EXISTS'
query += f' {dbxio_table.safe_table_identifier} {schema_sql}'
if loc := dbxio_table.attributes.location:
query += f" USING {dbxio_table.table_format.name} LOCATION '{loc}'"
if part := dbxio_table.attributes.partitioned_by:
query += f" PARTITIONED BY ({','.join(part)})"

return query


def create_table(table: Union[str, Table], client: 'DbxIOClient', replace: bool = False) -> _FutureBaseResult:
"""
Creates a table in the catalog.
If replace is False and the table already exists, it does nothing.
If replace is True, the table is created or replaced.
Query pattern:
CREATE [OR REPLACE] TABLE [IF NOT EXISTS] <table_identifier> (col1 type1, col2 type2, ...)
[USING <table_format> LOCATION <location>]
[PARTITIONED BY (col1, col2, ...)]
"""
query = _create_table_query(table, replace, if_not_exists=True, include_schema=True)
return client.sql(query)
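
For illustration, here is the shape of the SQL that _create_table_query emits for its two callers in this diff (create_table and read_files_as_table). The table identifier and columns are invented; the exact quoting comes from safe_table_identifier and TableSchema.as_sql, and PARTITIONED BY only appears when attributes.partitioned_by is set:

    # create_table(..., replace=False) -> if_not_exists=True, include_schema=True:
    #   CREATE TABLE IF NOT EXISTS `catalog`.`schema`.`events` (`id` BIGINT, `ts` TIMESTAMP)
    #   PARTITIONED BY (id)
    #
    # read_files_as_table(..., replace=True) -> if_not_exists=False, include_schema=False:
    #   CREATE OR REPLACE TABLE `catalog`.`schema`.`events`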


@@ -200,6 +215,50 @@ def copy_into_table(
client.sql(sql_copy_into_query).wait()


def read_files_as_table(
client: 'DbxIOClient',
table: Table,
blob_path: str,
table_format: TableFormat,
abs_name: str,
abs_container_name: str,
include_files_pattern: bool = False,
replace: bool = False,
force_schema: bool = True,
) -> None:
"""
Copies data from blob storage into a table. All files that match the pattern *.{table_format} will be copied.
If force_schema is False, the schemaHints option is used instead of schema, so the provided schema is treated as a hint rather than enforced.
"""
create_query = _create_table_query(table, replace, if_not_exists=False, include_schema=False)
options = {
'format': f"'{table_format.value.lower()}'",
}
if include_files_pattern:
options['fileNamePattern'] = f"'*.{table_format.value.lower()}'"
if table.schema:
sql_schema = f"'{table.schema.as_sql()}'"
columns_exp = ', '.join(table.schema.columns)
if force_schema:
options['schema'] = sql_schema
else:
options['schemaHints'] = sql_schema
else:
columns_exp = '*'
options['mergeSchema'] = 'true'

options_query = ',\n'.join([f'{k} => {v}' for k, v in options.items()])
select_query = dedent(f"""
AS SELECT {columns_exp}
FROM read_files(
'abfss://{abs_container_name}@{abs_name}.dfs.core.windows.net/{blob_path}',
{options_query}
)
""")
query = ConstDatabricksQuery(f'{create_query} {select_query}')
client.sql(query).wait()
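
A hypothetical call, to show how the pieces combine; the storage account, container, and table name below are invented, and Table.from_obj is assumed to accept a dotted identifier as it does elsewhere in this diff:

    # Atomically (re)create catalog.schema.events from Parquet files in blob storage.
    read_files_as_table(
        client=client,
        table=Table.from_obj('catalog.schema.events'),
        blob_path='landing/events',
        table_format=TableFormat.PARQUET,
        abs_name='mystorageaccount',
        abs_container_name='mycontainer',
        replace=True,
    )
    # With no schema on the table, this renders roughly:
    #   CREATE OR REPLACE TABLE `catalog`.`schema`.`events`
    #   AS SELECT *
    #   FROM read_files(
    #       'abfss://mycontainer@mystorageaccount.dfs.core.windows.net/landing/events',
    #       format => 'parquet',
    #       mergeSchema => true
    #   )

This single CREATE OR REPLACE ... AS SELECT statement is what makes the non-append paths of bulk_write_table and bulk_write_local_files below atomic: the previous drop_table/create_table/copy_into_table sequence could leave a dropped or empty table behind if a later step failed.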


def bulk_write_table(
table: Union[str, Table],
new_records: Union[Iterator[Dict], List[Dict]],
@@ -241,17 +300,24 @@ def bulk_write_table(
retrying=client.retrying,
) as tmp_path:
if not append:
drop_table(dbxio_table, client=client, force=True).wait()
create_table(dbxio_table, client=client).wait()

copy_into_table(
client=client,
table=dbxio_table,
table_format=TableFormat.PARQUET,
blob_path=tmp_path,
abs_name=abs_name,
abs_container_name=abs_container_name,
)
read_files_as_table(
client=client,
table=dbxio_table,
table_format=TableFormat.PARQUET,
blob_path=tmp_path,
abs_name=abs_name,
abs_container_name=abs_container_name,
replace=True,
)
else:
copy_into_table(
client=client,
table=dbxio_table,
table_format=TableFormat.PARQUET,
blob_path=tmp_path,
abs_name=abs_name,
abs_container_name=abs_container_name,
)


def bulk_write_local_files(
@@ -270,7 +336,7 @@
"""
assert table.schema, 'Table schema is required for bulk_write_local_files function'

p = Path(path)
p = Path(path).expanduser()
files = p.glob(f'*.{table_format.value.lower()}') if p.is_dir() else [path]

operation_uuid = str(uuid.uuid4())
@@ -294,21 +360,30 @@
force=force,
)

if not append:
drop_table(table, client=client, force=True).wait()
create_table(table, client=client).wait()

common_blob_path = str(os.path.commonpath(blobs))
include_files_pattern = len(blobs) > 1
copy_into_table(
client=client,
table=table,
table_format=table_format,
blob_path=common_blob_path,
include_files_pattern=include_files_pattern,
abs_name=abs_name,
abs_container_name=abs_container_name,
)

if not append:
read_files_as_table(
client=client,
table=table,
table_format=table_format,
blob_path=common_blob_path,
include_files_pattern=include_files_pattern,
abs_name=abs_name,
abs_container_name=abs_container_name,
replace=True,
)
else:
copy_into_table(
client=client,
table=table,
table_format=table_format,
blob_path=common_blob_path,
include_files_pattern=include_files_pattern,
abs_name=abs_name,
abs_container_name=abs_container_name,
)


def merge_table(
6 changes: 6 additions & 0 deletions dbxio/delta/table_schema.py
@@ -73,5 +73,11 @@ def columns(self):
def as_dict(self) -> dict[str, BaseType]:
return {col_spec.name: col_spec.type for col_spec in self._columns}

@cache
def as_sql(self) -> str:
return ', '.join([
f'`{name}` {type_}' for name, type_ in self.as_dict().items()
])
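
For example (a sketch, assuming each BaseType renders its Databricks SQL type name when formatted into the string), a schema with columns id: BIGINT and name: STRING would serialize as:

    # schema.as_dict() -> {'id': BigIntType(), 'name': StringType()}
    # schema.as_sql()  -> '`id` BIGINT, `name` STRING'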

def apply(self, record: dict[str, Any]) -> dict[str, Any]:
return {key: self.as_dict()[key].deserialize(val) for key, val in record.items()}
