Skip to content

Commit

Permalink
Support all tar archives in create and extract
Browse files Browse the repository at this point in the history
  • Loading branch information
hagenw committed Jan 7, 2025
1 parent 4592c91 commit 6c16737
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 51 deletions.
109 changes: 58 additions & 51 deletions audeer/core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def create_archive(
*,
verbose: bool = False,
):
r"""Create ZIP or TAR.GZ archive.
r"""Create ZIP or TAR archive.
If a list with ``files`` is provided,
only those files will be included to the archive.
Expand Down Expand Up @@ -129,8 +129,9 @@ def create_archive(
Raises:
FileNotFoundError: if ``root`` or a file in ``files`` is not found
NotADirectoryError: if ``root`` is not a directory
RuntimeError: if archive does not end with ``zip`` or ``tar.gz``
or a file in ``files`` is not below ``root``
RuntimeError: if archive does not end with ``zip``,
``tar``, ``tar.gz``, ``tar.bz2``, ``tar.xz``
RuntimeError: if a file in ``files`` is not below ``root``
Examples:
>>> file_a = audeer.touch("a.txt")
Expand Down Expand Up @@ -211,21 +212,31 @@ def create_archive(
)
disable = not verbose

if archive.endswith("zip"):
with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as zf:
for file in progress_bar(files, desc=desc, disable=disable):
full_file = safe_path(root, file)
zf.write(full_file, arcname=file)
elif archive.endswith("tar.gz"):
with tarfile.open(archive, "w:gz") as tf:
for file in progress_bar(files, desc=desc, disable=disable):
full_file = safe_path(root, file)
tf.add(full_file, file)
else:
archive_handlers = {
"zip": lambda path: zipfile.ZipFile(path, "w", zipfile.ZIP_DEFLATED),
"tar": lambda path: tarfile.open(path, "w:"),
"tar.gz": lambda path: tarfile.open(path, "w:gz"),
"tar.bz2": lambda path: tarfile.open(path, "w:bz2"),
"tar.xz": lambda path: tarfile.open(path, "w:xz"),
}

# Get the archive extension
extension = next((ext for ext in archive_handlers if archive.endswith(ext)), None)
if extension is None:
supported = ", ".join(archive_handlers.keys())
raise RuntimeError(
f"You can only create a ZIP or TAR.GZ archive, " f"not {archive}"
f"Unsupported archive format. Supported formats: {supported}"
)

# Create and populate the archive
with archive_handlers[extension](archive) as archive_file:
for file in progress_bar(files, desc=desc, disable=disable):
full_file = safe_path(root, file)
if extension == "zip":
archive_file.write(full_file, arcname=file)
else:
archive_file.add(full_file, file)


def download_url(
url: str,
Expand Down Expand Up @@ -282,10 +293,10 @@ def extract_archive(
keep_archive: bool = True,
verbose: bool = False,
) -> typing.List[str]:
r"""Extract ZIP or TAR.GZ file.
r"""Extract ZIP or TAR file.
Args:
archive: path to ZIP or TAR.GZ file
archive: path to ZIP or TAR file
destination: folder where the files will be extracted.
If the folder does not exists,
it will be created
Expand All @@ -300,7 +311,7 @@ def extract_archive(
FileNotFoundError: if ``archive`` is not found
IsADirectoryError: if ``archive`` is a directory
NotADirectoryError: if ``destination`` is not a directory
RuntimeError: if ``archive`` is not a ZIP or TAR.GZ file
RuntimeError: if ``archive`` is not a ZIP or TAR file
RuntimeError: if ``archive`` is malformed
Examples:
Expand Down Expand Up @@ -351,39 +362,35 @@ def extract_archive(
)
disable = not verbose

def extract_zip(archive: str) -> list:
with zipfile.ZipFile(archive, "r") as zf:
members = zf.infolist()
for member in progress_bar(members, desc=desc, disable=disable):
zf.extract(member, destination)
return [m.filename for m in members]

def extract_tar(archive: str) -> list:
with tarfile.open(archive, "r") as tf:
members = tf.getmembers()
for member in progress_bar(members, desc=desc, disable=disable):
# In Python 3.12 the `filter` argument was introduced,
# and it will be set automatically in Python 3.14,
# see
# https://docs.python.org/3.12/library/tarfile.html#tarfile-extraction-filter
# noqa: E501
kwargs = {"numeric_owner": True}
if sys.version_info >= (3, 12): # pragma: no cover
kwargs = kwargs | {"filter": "tar"}
tf.extract(member, destination, **kwargs)
return [m.name for m in members]

try:
if archive.endswith("zip"):
with zipfile.ZipFile(archive, "r") as zf:
members = zf.infolist()
for member in progress_bar(
members,
desc=desc,
disable=disable,
):
zf.extract(member, destination)
files = [m.filename for m in members]
elif archive.endswith("tar.gz"):
with tarfile.open(archive, "r") as tf:
members = tf.getmembers()
for member in progress_bar(
members,
desc=desc,
disable=disable,
):
# In Python 3.12 the `filter` argument was introduced,
# and it will be set automatically in Python 3.14,
# see
# https://docs.python.org/3.12/library/tarfile.html#tarfile-extraction-filter
# noqa: E501
kwargs = {"numeric_owner": True}
if sys.version_info >= (3, 12): # pragma: no cover
kwargs = kwargs | {"filter": "tar"}
tf.extract(member, destination, **kwargs)
files = [m.name for m in members]
files = extract_zip(archive)
elif tarfile.is_tarfile(archive):
files = extract_tar(archive)
else:
raise RuntimeError(
f"You can only extract ZIP and TAR.GZ files, " f"not {archive}"
)
raise RuntimeError(f"You can only extract ZIP and TAR files, not {archive}")
except (EOFError, zipfile.BadZipFile, tarfile.ReadError):
raise RuntimeError(f"Broken archive: {archive}")
except (KeyboardInterrupt, Exception): # pragma: no cover
Expand All @@ -410,10 +417,10 @@ def extract_archives(
keep_archive: bool = True,
verbose: bool = False,
) -> typing.List[str]:
r"""Extract multiple ZIP or TAR.GZ archives at once.
r"""Extract multiple ZIP or TAR archives at once.
Args:
archives: paths of ZIP or TAR.GZ files
archives: paths of ZIP or TAR files
destination: folder where the files will be extracted.
If the folder does not exists,
it will be created
Expand All @@ -428,7 +435,7 @@ def extract_archives(
FileNotFoundError: if an archive is not found
IsADirectoryError: if an archive is a directory
NotADirectoryError: if ``destination`` is not a directory
RuntimeError: if an archive is not a ZIP or TAR.GZ file
RuntimeError: if an archive is not a ZIP or TAR file
RuntimeError: if an archive file is malformed
Examples:
Expand Down
27 changes: 27 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,15 @@ def tree(tmpdir, request):
".",
[".hidden", "file.txt", "sub/a/b/file.txt"],
),
( # tar
["file.txt", "sub/a/b/file.txt"],
".",
["sub/a/b/file.txt", "file.txt"],
"archive.tar",
"archive.tar",
".",
["sub/a/b/file.txt", "file.txt"],
),
( # tar.gz
["file.txt", "sub/a/b/file.txt"],
".",
Expand All @@ -125,6 +134,24 @@ def tree(tmpdir, request):
".",
["sub/a/b/file.txt", "file.txt"],
),
( # tar.bz2
["file.txt", "sub/a/b/file.txt"],
".",
["sub/a/b/file.txt", "file.txt"],
"archive.tar.bz2",
"archive.tar.bz2",
".",
["sub/a/b/file.txt", "file.txt"],
),
( # tar.xz
["file.txt", "sub/a/b/file.txt"],
".",
["sub/a/b/file.txt", "file.txt"],
"archive.tar.xz",
"archive.tar.xz",
".",
["sub/a/b/file.txt", "file.txt"],
),
( # root is sub folder
["sub/file.txt"],
"./sub",
Expand Down

0 comments on commit 6c16737

Please sign in to comment.