Skip to content

Commit

Permalink
Bump version to 0.9.2; enhance storage options handling and add profile support for AWS
Browse files Browse the repository at this point in the history
  • Loading branch information
legout committed Jan 20, 2025
1 parent 1e0cb41 commit cee12be
Show file tree
Hide file tree
Showing 5 changed files with 360 additions and 288 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ keywords = [
name = "FlowerPower"
readme = "README.md"
requires-python = ">= 3.11"
version = "0.9.1.1"
version = "0.9.2"

[project.scripts]
flowerpower = "flowerpower.cli:app"
Expand Down
178 changes: 98 additions & 80 deletions src/flowerpower/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,18 @@
import pyarrow.dataset as pds
from fsspec import AbstractFileSystem
from fsspec.utils import get_protocol
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict

from ..utils.filesystem import get_filesystem
from ..utils.polars import pl
from ..utils.sql import sql2polars_filter, sql2pyarrow_filter
from ..utils.storage_options import (
AwsStorageOptions,
AzureStorageOptions,
GcsStorageOptions,
GitHubStorageOptions,
GitLabStorageOptions,
# AwsStorageOptions,
# AzureStorageOptions,
# GcsStorageOptions,
# GitHubStorageOptions,
# GitLabStorageOptions,
StorageOptions,
)
import importlib

Expand Down Expand Up @@ -61,94 +62,111 @@ class BaseFileIO(BaseModel):

model_config = ConfigDict(arbitrary_types_allowed=True)
path: str | list[str]
storage_options: (
AwsStorageOptions
| GcsStorageOptions
| AzureStorageOptions
| GitHubStorageOptions
| GitLabStorageOptions
| dict[str, Any]
| None
) = None
fs: AbstractFileSystem | None = Field(default=None)
storage_options: StorageOptions | dict[str, Any] | None = None
fs: AbstractFileSystem | None = None
format: str | None = None

def model_post_init(self, __context):
# self._update_storage_options_from_env()
if isinstance(self.storage_options, dict):
if "protocol" not in self.storage_options:
self.storage_options["protocol"] = get_protocol(self.path)
self.storage_options = StorageOptions(
**self.storage_options
).storage_options

if self.fs is None:
self.fs = get_filesystem(
path=self.path if isinstance(self.path, str) else self.path[0],
storage_options=self.storage_options,
fs=self.fs,
dirfs=True,
)

if isinstance(self.path, str):
self.path = (
self.path.replace(f"**/*.{self.format}", "")
self.path.replace(self.storage_options.protocol, "")
.lstrip("://")
.replace(f"**/*.{self.format}", "")
.replace("**", "")
.replace("*", "")
.rstrip("//")
.rstrip("/")
)

self._update_fs()
# def _update_storage_options_from_aws_credentials(
# self,
# profile: str = "default",
# allow_invalid_certificates: bool = False,
# allow_http: bool = False,
# ) -> AwsStorageOptions:
# if isinstance(self.storage_options, AwsStorageOptions):
# self.storage_options = self.storage_options.to_aws_credentials(
# profile=profile,
# allow_invalid_certificates=allow_invalid_certificates,
# allow_http=allow_http,
# )
# self._update_fs()

# def _update_storage_options_from_env(
# self,
# ):
# protocol = get_protocol(self.path)
# if protocol == "s3":
# self.storage_options = AwsStorageOptions.from_env()
# elif protocol == "gs" or protocol == "gcs":
# self.storage_options = GcsStorageOptions.from_env()
# elif protocol == "az" or protocol == "abfs":
# self.storage_options = AzureStorageOptions.from_env()
# elif protocol == "github":
# self.storage_options = GitHubStorageOptions.from_env()
# elif protocol == "gitlab":
# self.storage_options = GitLabStorageOptions.from_env()
# self._update_fs()

# def _update_storage_options(self, **kwargs):
# self.storage_options = self.storage_options.model_copy(update=kwargs)
# self._update_fs()

# def _update_fs(self):
# if self.fs is None:
# self.fs = get_filesystem(
# path=self.path if isinstance(self.path, str) else self.path[0],
# storage_options=self.storage_options,
# fs=self.fs,
# dirfs=False,
# )

def _update_storage_options_from_aws_credentials(
self,
profile: str = "default",
allow_invalid_certificates: bool = False,
allow_http: bool = False,
) -> AwsStorageOptions:
if isinstance(self.storage_options, AwsStorageOptions):
self.storage_options = self.storage_options.to_aws_credentials(
profile=profile,
allow_invalid_certificates=allow_invalid_certificates,
allow_http=allow_http,
)
self._update_fs()

def _update_storage_options_from_env(
self,
):
protocol = get_protocol(self.path)
if protocol == "s3":
self.storage_options = AwsStorageOptions.from_env()
elif protocol == "gs" or protocol == "gcs":
self.storage_options = GcsStorageOptions.from_env()
elif protocol == "az" or protocol == "abfs":
self.storage_options = AzureStorageOptions.from_env()
elif protocol == "github":
self.storage_options = GitHubStorageOptions.from_env()
elif protocol == "gitlab":
self.storage_options = GitLabStorageOptions.from_env()
self._update_fs()

def _update_storage_options(self, **kwargs):
self.storage_options = self.storage_options.model_copy(update=kwargs)
self._update_fs()

def _update_fs(self):
self.fs = get_filesystem(
self.path if isinstance(self.path, str) else self.path[0],
self.storage_options,
dirfs=False,
)
@property
def _path(self):
    """Return ``self.path`` expressed relative to a directory filesystem root.

    When ``self.fs`` reports the ``"dir"`` protocol, the ``self.fs.path``
    prefix and any leading slashes are removed from ``self.path`` (from
    each entry when ``self.path`` is a list). For every other protocol
    ``self.path`` is returned unchanged.

    Returns:
        str | list[str]: The path(s) to hand to ``self.fs``.
    """
    if self.fs.protocol != "dir":
        return self.path

    # Compare without surrounding slashes so "/data" and "data" are treated
    # as the same root.
    root = self.fs.path.strip("/")

    def _relative(p):
        # NOTE: str.lstrip(chars) removes any leading characters contained
        # in *chars* (a character set), not a prefix, so it can also eat
        # the beginning of the relative part. Match the root as a real
        # prefix on a path-segment boundary instead.
        trimmed = p.lstrip("/")
        if not root:
            return trimmed
        if trimmed == root:
            return ""
        if trimmed.startswith(root + "/"):
            return trimmed[len(root) + 1:].lstrip("/")
        return trimmed

    if isinstance(self.path, list):
        return [_relative(p) for p in self.path]
    return _relative(self.path)

@property
def _glob_path(self):
if isinstance(self.path, list):
if isinstance(self._path, list):
return
else:
if self.format is not None:
if f"**/*.{self.format}" in self.path:
return self.path
if f"**/*.{self.format}" in self._path:
return self._path
else:
return f"{self.path}/**/*.{self.format}"
return f"{self._path}/**/*.{self.format}".lstrip("/")
else:
if "**" in self.path:
return self.path
if "**" in self._path:
return self._path
else:
return f"{self.path}/**"
return f"{self._path}/**".lstrip("/")

def list_files(self):
if isinstance(self.path, list):
return self.path
if isinstance(self._path, list):
return self._path

glob_path = self._glob_path
return self.fs.glob(glob_path)
return self.fs.glob(self._glob_path)


class BaseFileLoader(BaseFileIO):
Expand Down Expand Up @@ -261,7 +279,7 @@ def _to_polars_dataframe(self, **kwargs) -> pl.DataFrame | list[pl.DataFrame]:
Returns:
pl.DataFrame | list[pl.DataFrame]: Polars DataFrame or list of DataFrames.
"""
self._load(**kwargs)
self.load(**kwargs)
if isinstance(self._data, list):
df = [
df if isinstance(self._data, pl.DataFrame) else pl.from_arrow(df)
Expand Down Expand Up @@ -299,7 +317,7 @@ def _to_polars_lazyframe(self, **kwargs) -> pl.LazyFrame | list[pl.LazyFrame]:
Returns:
pl.LazyFrame | list[pl.LazyFrame]: Polars LazyFrame or list of LazyFrames.
"""
self._load(**kwargs)
self.load(**kwargs)
if not self.concat:
return [df.lazy() for df in self._to_polars_dataframe()]
return self._to_polars_dataframe.lazy()
Expand Down Expand Up @@ -345,7 +363,7 @@ def to_polars(
return self._to_polars_lazyframe(**kwargs)
return self._to_polars_dataframe(**kwargs)

def iter_polers(
def iter_polars(
self,
lazy: bool = False,
batch_size: int = 1,
Expand All @@ -362,7 +380,7 @@ def to_pyarrow_table(self, **kwargs) -> pa.Table | list[pa.Table]:
Returns:
pa.Table | list[pa.Table]: PyArrow Table or list of Tables.
"""
self._load(**kwargs)
self.load(**kwargs)
if isinstance(self._data, list):
df = [
df.to_arrow(**kwargs) if isinstance(df, pl.DataFrame) else df
Expand Down Expand Up @@ -587,23 +605,23 @@ def to_pyarrow_dataset(
"""
if self.format == ["csv", "arrow", "ipc"]:
self._dataset = self.fs.pyarrow_dataset(
self.path,
self._path,
format=self.format,
schema=self.schema_,
partitioning=self.partitioning,
**kwargs,
)
elif self.format == "parquet":
if self.fs.exists(os.path.join(self.path, "_metadata")):
if self.fs.exists(os.path.join(self._path, "_metadata")):
self._dataset = self.fs.parquet_dataset(
self.path,
self._path,
schema=self.schema_,
partitioning=self.partitioning,
**kwargs,
)
else:
self._dataset = self.fs.pyarrow_dataset(
self.path,
self._path,
format=self.format,
schema=self.schema_,
partitioning=self.partitioning,
Expand Down Expand Up @@ -669,7 +687,7 @@ def to_pydala_dataset(self, **kwargs) -> ParquetDataset:
if not hasattr(self, "conn"):
self.conn = duckdb.connect()
self._pydala_dataset = self.fs.pydala_dataset(
self.path,
self._path,
partitioning=self.partitioning,
ddb_con=self.conn,
**kwargs,
Expand Down Expand Up @@ -883,7 +901,7 @@ def write(
if not self.is_pydala_dataset:
self.fs.write_pyarrow_dataset(
data=data or self.data,
path=self.path,
path=self._path,
basename=basename or self.basename,
schema=schema or self.schema_,
partition_by=partition_by or self.partition_by,
Expand All @@ -899,7 +917,7 @@ def write(
else:
self.fs.write_pydala_dataset(
data=data or self.data,
path=self.path,
path=self._path,
mode=mode,
basename=basename or self.basename,
schema=schema or self.schema_,
Expand Down
10 changes: 8 additions & 2 deletions src/flowerpower/utils/filesystem/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def mscf_ls_p(self, path, detail=False, **kwargs):

def get_filesystem(
path: str | Path | None = None,
storage_options: (BaseStorageOptions | dict[str, str] | None) = None,
storage_options: BaseStorageOptions | dict[str, str] | None = None,
dirfs: bool = True,
cached: bool = False,
cache_storage: str | None = None,
Expand Down Expand Up @@ -282,7 +282,10 @@ def get_filesystem(

pp = infer_storage_options(str(path) if isinstance(path, Path) else path)
protocol = pp.get("protocol")
path = pp.get("host", "") + pp.get("path", "")
host = pp.get("host", "")
path = pp.get("path", "")
if host and host not in path:
path = os.path.join(host, path)

if protocol == "file" or protocol == "local":
fs = filesystem(protocol)
Expand All @@ -292,6 +295,9 @@ def get_filesystem(
fs.is_cache_fs = False
return fs

if isinstance(storage_options, dict):
storage_options = storage_options_from_dict(protocol, storage_options)

if storage_options is None:
storage_options = storage_options_from_dict(protocol, storage_options_kwargs)

Expand Down
Loading

0 comments on commit cee12be

Please sign in to comment.