From 77847f2d5f111158c9e0197bd86a92a4fb0aab89 Mon Sep 17 00:00:00 2001 From: Lian Hu Date: Tue, 3 Oct 2023 16:31:52 +0200 Subject: [PATCH] Updates based on PR review --- test/test_rohmu.py | 120 +++++++++++++++++++++++---------------------- 1 file changed, 61 insertions(+), 59 deletions(-) diff --git a/test/test_rohmu.py b/test/test_rohmu.py index 1aa0808b..f1795af1 100644 --- a/test/test_rohmu.py +++ b/test/test_rohmu.py @@ -1,25 +1,24 @@ import hashlib import logging import os +import pathlib from tempfile import NamedTemporaryFile import pytest from rohmu import get_transfer, rohmufile +from rohmu.object_storage.base import BaseTransfer from rohmu.rohmufile import create_sink_pipeline +from rohmu.typing import Metadata from .base import CONSTANT_TEST_RSA_PRIVATE_KEY, CONSTANT_TEST_RSA_PUBLIC_KEY +EMPTY_FILE_SHA1 = "da39a3ee5e6b4b0d3255bfef95601890afd80709" + log = logging.getLogger(__name__) -@pytest.mark.parametrize( - "compress_algorithm, file_size", - [("lzma", 0), ("snappy", 0), ("zstd", 0), ("lzma", 1), ("snappy", 1), ("zstd", 1)], - ids=[ - "test_lzma_0byte_file", "test_snappy_0byte_file", "test_zstd_0byte_file", "test_lzma_1byte_file", - "test_snappy_1byte_file", "test_zstd_1byte_file" - ], -) +@pytest.mark.parametrize("compress_algorithm", ["lzma", "snappy", "zstd"], ids=["test_lzma", "test_snappy", "test_zstd"]) +@pytest.mark.parametrize("file_size", [0, 1], ids=["0byte_file", "1byte_file"]) def test_rohmu_with_local_storage(compress_algorithm: str, file_size: int, tmp_path): hash_algorithm = "sha1" compression_level = 0 @@ -34,30 +33,20 @@ def test_rohmu_with_local_storage(compress_algorithm: str, file_size: int, tmp_p with open(orig_file, "rb") as file_in: assert file_in.read() == content - # 1 - Compressed the file + original_file_size = os.path.getsize(orig_file) + assert original_file_size == len(content) + + # 1 - Compress the file compressed_filepath = work_dir / "compressed" / "hello_compressed" compressed_filepath.parent.mkdir(exist_ok=True) hasher = hashlib.new(hash_algorithm) - input_obj = open(orig_file, "rb") - output_obj = NamedTemporaryFile( - dir=os.path.dirname(compressed_filepath), prefix=os.path.basename(compressed_filepath), suffix=".tmp-compress" - ) - with input_obj, output_obj: - original_file_size, compressed_file_size = rohmufile.write_file( - data_callback=hasher.update, - input_obj=input_obj, - output_obj=output_obj, - compression_algorithm=compress_algorithm, - compression_level=compression_level, - rsa_public_key=CONSTANT_TEST_RSA_PUBLIC_KEY, - log_func=log.debug, - ) - os.link(output_obj.name, compressed_filepath) - - log.info("original_file_size: %s, compressed_file_size: %s", original_file_size, compressed_file_size) - assert original_file_size == len(content) + compressed_file_size = _compress_file(orig_file, compressed_filepath, compress_algorithm, compression_level, hasher) file_hash = hasher.hexdigest() - log.info("original_file_hash: %s", file_hash) + + log.info( + "original_file_size: %s, original_file_hash: %s, compressed_file_size: %s", original_file_size, file_hash, + compressed_file_size + ) # 2 - Upload the compressed file upload_dir = work_dir / "uploaded" @@ -66,22 +55,16 @@ def test_rohmu_with_local_storage(compress_algorithm: str, file_size: int, tmp_p "directory": str(upload_dir), "storage_type": "local", } + storage = get_transfer(storage_config) + metadata = { "encryption-key-id": "No matter", "compression-algorithm": compress_algorithm, "compression-level": compression_level, + "Content-Length": str(compressed_file_size) } - storage = get_transfer(storage_config) - - metadata_copy = metadata.copy() - metadata_copy["Content-Length"] = str(compressed_file_size) file_key = "compressed/hello_compressed" - - def upload_progress_callback(n_bytes: int) -> None: - log.debug("File: '%s', uploaded %d bytes", file_key, n_bytes) - - with open(compressed_filepath, "rb") as f: - storage.store_file_object(file_key, f, metadata=metadata_copy, upload_progress_fn=upload_progress_callback) + _upload_compressed_file(storage=storage, file_to_upload=str(compressed_filepath), file_key=file_key, metadata=metadata) # 3 - Decrypt and decompress # 3.1 Use file downloading rohmu API @@ -90,12 +73,11 @@ def upload_progress_callback(n_bytes: int) -> None: decompressed_size = _download_and_decompress_with_file(storage, str(decompressed_filepath), file_key, metadata) assert len(content) == decompressed_size # Compare content - with open(decompressed_filepath, "rb") as file_in: - content_decrypted = file_in.read() - hasher = hashlib.new(hash_algorithm) - hasher.update(content_decrypted) - assert hasher.hexdigest() == file_hash - assert content_decrypted == content + content_decrypted = decompressed_filepath.read_bytes() + hasher = hashlib.new(hash_algorithm) + hasher.update(content_decrypted) + assert hasher.hexdigest() == file_hash + assert content_decrypted == content # 3.2 Use rohmu SinkIO API decompressed_filepath = work_dir / "hello_decompressed_2" @@ -103,24 +85,46 @@ def upload_progress_callback(n_bytes: int) -> None: assert len(content) == decompressed_size # Compare content - hasher.hexdigest() - with open(decompressed_filepath, "rb") as file_in: - content_decrypted = file_in.read() - hasher = hashlib.new(hash_algorithm) - hasher.update(content_decrypted) - assert hasher.hexdigest() == file_hash - assert content_decrypted == content + content_decrypted = decompressed_filepath.read_bytes() + hasher = hashlib.new(hash_algorithm) + hasher.update(content_decrypted) + assert hasher.hexdigest() == file_hash + assert content_decrypted == content if file_size == 0: - empty_file_sha1 = "da39a3ee5e6b4b0d3255bfef95601890afd80709" - assert empty_file_sha1 == hasher.hexdigest() + assert EMPTY_FILE_SHA1 == hasher.hexdigest() -def _key_lookup(key_id: str): # pylint: disable=unused-argument +def _key_lookup(key_id: str) -> str: # pylint: disable=unused-argument return CONSTANT_TEST_RSA_PRIVATE_KEY -def _download_and_decompress_with_sink(storage, output_path: str, file_key: str, metadata: dict): +def _compress_file(input_file: pathlib.Path, output_file: pathlib.Path, algorithm: str, compress_level: int, hasher) -> int: + with open(input_file, "rb") as input_obj, NamedTemporaryFile( + dir=output_file.parent, prefix=output_file.name, suffix=".tmp-compress" + ) as output_obj: + _, compressed_file_size = rohmufile.write_file( + data_callback=hasher.update, + input_obj=input_obj, + output_obj=output_obj, + compression_algorithm=algorithm, + compression_level=compress_level, + rsa_public_key=CONSTANT_TEST_RSA_PUBLIC_KEY, + log_func=log.debug, + ) + os.link(output_obj.name, output_file) + return compressed_file_size + + +def _upload_compressed_file(storage: BaseTransfer, file_to_upload: str, file_key: str, metadata: Metadata) -> None: + def upload_progress_callback(n_bytes: int) -> None: + log.debug("File: '%s', uploaded %d bytes", file_key, n_bytes) + + with open(file_to_upload, "rb") as f: + storage.store_file_object(file_key, f, metadata=metadata, upload_progress_fn=upload_progress_callback) + + +def _download_and_decompress_with_sink(storage: BaseTransfer, output_path: str, file_key: str, metadata: Metadata) -> int: data, _ = storage.get_contents_to_string(file_key) if isinstance(data, str): data = data.encode("latin1") @@ -135,7 +139,7 @@ def _download_and_decompress_with_sink(storage, output_path: str, file_key: str, return decompressed_size -def _download_and_decompress_with_file(storage, output_path: str, file_key: str, metadata: dict): +def _download_and_decompress_with_file(storage: BaseTransfer, output_path: str, file_key: str, metadata: Metadata) -> int: # Download the compressed file file_download_path = output_path + ".tmp" @@ -146,9 +150,7 @@ def download_progress_callback(bytes_written: int, input_size: int) -> None: storage.get_contents_to_fileobj(file_key, f, progress_callback=download_progress_callback) # Decrypt and decompress - input_obj = open(file_download_path, "rb") - output_obj = open(output_path, "wb") - with input_obj, output_obj: + with open(file_download_path, "rb") as input_obj, open(output_path, "wb") as output_obj: _, decompressed_size = rohmufile.read_file( input_obj=input_obj, output_obj=output_obj,