Skip to content

Commit

Permalink
Capture compression and open binary if present. (#419)
Browse files Browse the repository at this point in the history
* Capture compression and open binary if present.

* Add non-utf8 encoded test
  • Loading branch information
delucchi-cmu authored Nov 14, 2024
1 parent e47eab1 commit c27133d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/hats/io/file_io/file_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,18 +107,19 @@ def load_csv_to_pandas(file_pointer: str | Path | UPath, **kwargs) -> pd.DataFra


def load_csv_to_pandas_generator(
file_pointer: str | Path | UPath, chunksize=10_000, **kwargs
file_pointer: str | Path | UPath, *, chunksize=10_000, compression=None, **kwargs
) -> Generator[pd.DataFrame]:
"""Load a csv file to a pandas dataframe
Args:
file_pointer: location of csv file to load
file_system: fsspec or pyarrow filesystem, default None
chunksize (int): number of rows to load per chunk
compression (str): for compressed CSVs, the manner of compression. e.g. 'gz', 'bzip'.
**kwargs: arguments to pass to pandas `read_csv` loading method
Returns:
pandas dataframe loaded from CSV
"""
file_pointer = get_upath(file_pointer)
with file_pointer.open("r", **kwargs) as csv_file:
with file_pointer.open(mode="rb", compression=compression, **kwargs) as csv_file:
with pd.read_csv(csv_file, chunksize=chunksize, **kwargs) as reader:
yield from reader

Expand Down
11 changes: 11 additions & 0 deletions tests/hats/io/file_io/test_file_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,17 @@ def test_load_csv_to_pandas_generator(small_sky_source_dir):
assert num_reads == 2


def test_load_csv_to_pandas_generator_encoding(tmp_path):
path = tmp_path / "koi8-r.csv"
with path.open(encoding="koi8-r", mode="w") as fh:
fh.write("col1,col2\nыыы,яяя\n")
num_reads = 0
for frame in load_csv_to_pandas_generator(path, chunksize=7, encoding="koi8-r"):
assert len(frame) == 1
num_reads += 1
assert num_reads == 1


def test_write_df_to_csv(tmp_path):
random_df = pd.DataFrame(np.random.randint(0, 100, size=(100, 4)), columns=list("ABCD"))
test_file_path = tmp_path / "test.csv"
Expand Down

0 comments on commit c27133d

Please sign in to comment.