diff --git a/flake.lock b/flake.lock index 0ec5d933a6..99ba354fed 100644 --- a/flake.lock +++ b/flake.lock @@ -965,11 +965,11 @@ ] }, "locked": { - "lastModified": 1728836758, - "narHash": "sha256-Cw9V9AAwgYnrqkh9TdVB0VqHSSfuCvOsDuD3XTDbKjI=", + "lastModified": 1729095447, + "narHash": "sha256-VT6wHzGsonqB0brKczONncm+ezmKrCE/E7uJO+FuvmY=", "owner": "onekey-sec", "repo": "unblob-native", - "rev": "de90b4d1c77831f9e4a38efb4e2f0d56b7d07095", + "rev": "59ce177733cee85eb3f42e011c2139b0f003dfb2", "type": "github" }, "original": { diff --git a/poetry.lock b/poetry.lock index bfe686c848..1614046022 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1920,18 +1920,18 @@ lzallright = ">=0.2.1,<0.3.0" [[package]] name = "unblob-native" -version = "0.1.3" +version = "0.1.4" description = "Performance-critical functionality for Unblob" optional = false python-versions = ">=3.8" files = [ - {file = "unblob_native-0.1.3-cp38-abi3-macosx_10_7_x86_64.whl", hash = "sha256:11e27ee80c8c52f0c1d315d0c05695da3814a4b36da47cbb1c6a1dd7a0f81a12"}, - {file = "unblob_native-0.1.3-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:6c4adf1c8609f620449a6aa362548443569d683f84e73ef20615935d7a0ee403"}, - {file = "unblob_native-0.1.3-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b268206100cb656d86f747ffe444e0808c1983d4efabe4e6d43a735738f0ecfe"}, - {file = "unblob_native-0.1.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f455172f7025b088c3af0931614109c634f881f7e497c5343ea14849c0d3f4b"}, - {file = "unblob_native-0.1.3-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:4d7b8256cba2ab4d2afed40b4a57a560082009d58ce2a7e4927c9e957185f365"}, - {file = "unblob_native-0.1.3-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:cdba5a8b50fef3067f83aac44d809bab6a5f65037778c23b7abc04be0a995473"}, - {file = "unblob_native-0.1.3.tar.gz", hash = "sha256:92f300677d3f4328682a2a2bf90dd498c826652752c9f2ab2731a82e31903008"}, + {file = "unblob_native-0.1.4-cp38-abi3-macosx_10_7_x86_64.whl", hash = "sha256:3e7b8e976b3363fcf2c40328c963d3aab15a3233af78507ef846669735822010"}, + {file = "unblob_native-0.1.4-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:cce0bcae3fb28a91d2c8af3e3c407962f4a7eb442d9c47a8e84dec84b340b7d7"}, + {file = "unblob_native-0.1.4-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de416b6481534fe4ad329801ea0e8fd7afcecdbd7e73a8854d95497c7e1f3dc7"}, + {file = "unblob_native-0.1.4-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b29b7dbd319d27ccdab48fc4dfd3fc3b320343f62f794fd36ae8e9ca755fc67"}, + {file = "unblob_native-0.1.4-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:be11c409a6868646a28edfaf28f89ac7cba836975d0870694df83ecb906df0fe"}, + {file = "unblob_native-0.1.4-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:2ba8af660b81ec8ea823390406800c994548dc38f46bf96a466af3db6f592e47"}, + {file = "unblob_native-0.1.4.tar.gz", hash = "sha256:fdcc963090358617850ab0d0e2c9f69551212f3267d2b3e8ce8e2a25cf513af9"}, ] [[package]] @@ -2065,4 +2065,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "8ef93e585816ad6eb6641b456b1a945d7b3301e3cf3752c35412e97946a9e54b" +content-hash = "2b8414d8c00f7b73376eae4067c447fde2a91fb06c61ec2d1f9a2ad0f6033362" diff --git a/pyproject.toml b/pyproject.toml index 7edc3cbd6f..77dd2d8a91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ lz4 = "^4.3.2" lief = "^0.15.1" cryptography = ">=41.0,<44.0" treelib = "^1.7.0" -unblob-native = "^0.1.1" +unblob-native = "^0.1.4" jefferson = "^0.4.5" rich = "^13.3.5" pyfatfs = "^1.0.5" diff --git a/tests/test_cli.py b/tests/test_cli.py index 720015859a..fc62243de7 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Iterable, List, Optional, Type +from typing import Iterable, List, Optional, Tuple, Type from unittest import mock import pytest @@ -13,9 +13,11 @@ from unblob.processing import ( DEFAULT_DEPTH, DEFAULT_PROCESS_NUM, + DEFAULT_SKIP_EXTENSION, DEFAULT_SKIP_MAGIC, ExtractionConfig, ) +from unblob.testing import is_sandbox_available from unblob.ui import ( NullProgressReporter, ProgressReporter, @@ -310,16 +312,16 @@ def test_keep_extracted_chunks( @pytest.mark.parametrize( - "skip_extension, extracted_files_count", + "skip_extension, expected_skip_extensions", [ - pytest.param([], 5, id="skip-extension-empty"), - pytest.param([""], 5, id="skip-zip-extension-empty-suffix"), - pytest.param([".zip"], 1, id="skip-extension-zip"), - pytest.param([".rlib"], 5, id="skip-extension-rlib"), + pytest.param((), DEFAULT_SKIP_EXTENSION, id="skip-extension-empty"), + pytest.param(("",), ("",), id="skip-zip-extension-empty-suffix"), + pytest.param((".zip",), (".zip",), id="skip-extension-zip"), + pytest.param((".rlib",), (".rlib",), id="skip-extension-rlib"), ], ) def test_skip_extension( - skip_extension: List[str], extracted_files_count: int, tmp_path: Path + skip_extension: List[str], expected_skip_extensions: Tuple[str, ...], tmp_path: Path ): runner = CliRunner() in_path = ( @@ -335,8 +337,12 @@ def test_skip_extension( for suffix in skip_extension: args += ["--skip-extension", suffix] params = [*args, "--extract-dir", str(tmp_path), str(in_path)] - result = runner.invoke(unblob.cli.cli, params) - assert extracted_files_count == len(list(tmp_path.rglob("*"))) + process_file_mock = mock.MagicMock() + with mock.patch.object(unblob.cli, "process_file", process_file_mock): + result = runner.invoke(unblob.cli.cli, params) + assert ( + process_file_mock.call_args.args[0].skip_extension == expected_skip_extensions + ) assert result.exit_code == 0 @@ -420,3 +426,31 @@ def test_clear_skip_magics( assert sorted(process_file_mock.call_args.args[0].skip_magic) == sorted( skip_magic ), fail_message + + +@pytest.mark.skipif( + not is_sandbox_available(), reason="Sandboxing is only available on Linux" +) +def test_sandbox_escape(tmp_path: Path): + runner = CliRunner() + + in_path = tmp_path / "input" + in_path.touch() + extract_dir = tmp_path / "extract-dir" + params = ["--extract-dir", str(extract_dir), str(in_path)] + + unrelated_file = tmp_path / "unrelated" + assert not unrelated_file.exists() + + process_file_mock = mock.MagicMock( + side_effect=lambda *_args, **_kwargs: unrelated_file.write_text( + "sandbox escape" + ) + ) + with mock.patch.object(unblob.cli, "process_file", process_file_mock): + result = runner.invoke(unblob.cli.cli, params) + + assert result.exit_code != 0 + assert isinstance(result.exception, PermissionError) + assert not unrelated_file.exists() + process_file_mock.assert_called_once() diff --git a/tests/test_sandbox.py b/tests/test_sandbox.py new file mode 100644 index 0000000000..972ba8caa8 --- /dev/null +++ b/tests/test_sandbox.py @@ -0,0 +1,85 @@ +from pathlib import Path + +import pytest + +from unblob.processing import ExtractionConfig +from unblob.sandbox import Sandbox +from unblob.testing import is_sandbox_available + +pytestmark = pytest.mark.skipif( + not is_sandbox_available(), reason="Sandboxing only works on Linux" +) + + +@pytest.fixture +def log_path(tmp_path): + return tmp_path / "unblob.log" + + +@pytest.fixture +def extraction_config(extraction_config, tmp_path): + extraction_config.extract_root = tmp_path / "extract" / "root" + # parent has to exist + extraction_config.extract_root.parent.mkdir() + return extraction_config + + +@pytest.fixture +def sandbox(extraction_config: ExtractionConfig, log_path: Path): + return Sandbox(extraction_config, log_path, None) + + +def test_necessary_resources_can_be_created_in_sandbox( + sandbox: Sandbox, extraction_config: ExtractionConfig, log_path: Path +): + directory_in_extract_root = extraction_config.extract_root / "path" / "to" / "dir" + file_in_extract_root = directory_in_extract_root / "file" + + assert not extraction_config.extract_root.exists() + sandbox.run(extraction_config.extract_root.mkdir, parents=True) + assert extraction_config.extract_root.exists() + + assert not directory_in_extract_root.exists() + sandbox.run(directory_in_extract_root.mkdir, parents=True) + assert directory_in_extract_root.exists() + + assert not file_in_extract_root.exists() + sandbox.run(file_in_extract_root.touch) + assert file_in_extract_root.exists() + + sandbox.run(file_in_extract_root.write_text, "file content") + assert file_in_extract_root.read_text() == "file content" + + # log-file is already opened + log_path.touch() + sandbox.run(log_path.write_text, "log line") + assert log_path.read_text() == "log line" + + +def test_access_outside_sandbox_is_not_possible(sandbox: Sandbox, tmp_path: Path): + unrelated_dir = tmp_path / "unrelated" / "path" + unrelated_file = tmp_path / "unrelated-file" + + assert not unrelated_dir.exists() + with pytest.raises(PermissionError): + sandbox.run(unrelated_dir.mkdir, parents=True) + assert not unrelated_dir.exists() + + unrelated_dir.mkdir(parents=True) + with pytest.raises(PermissionError): + sandbox.run(unrelated_dir.rmdir) + assert unrelated_dir.exists() + + assert not unrelated_file.exists() + with pytest.raises(PermissionError): + sandbox.run(unrelated_file.touch) + assert not unrelated_file.exists() + + unrelated_file.write_text("file content") + with pytest.raises(PermissionError): + sandbox.run(unrelated_file.write_text, "overwrite attempt") + assert unrelated_file.read_text() == "file content" + + with pytest.raises(PermissionError): + sandbox.run(unrelated_file.unlink) + assert unrelated_file.exists() diff --git a/unblob/cli.py b/unblob/cli.py index cb275e809c..498644e078 100755 --- a/unblob/cli.py +++ b/unblob/cli.py @@ -33,6 +33,7 @@ ExtractionConfig, process_file, ) +from .sandbox import Sandbox from .ui import NullProgressReporter, RichConsoleProgressReporter logger = get_logger() @@ -301,7 +302,8 @@ def cli( ) logger.info("Start processing file", file=file) - process_results = process_file(config, file, report_file) + sandbox = Sandbox(config, log_path, report_file) + process_results = sandbox.run(process_file, config, file, report_file) if verbose == 0: if skip_extraction: print_scan_report(process_results) diff --git a/unblob/pool.py b/unblob/pool.py index 810011a209..4b06ea3e85 100644 --- a/unblob/pool.py +++ b/unblob/pool.py @@ -1,11 +1,13 @@ import abc +import contextlib import multiprocessing as mp import os import queue +import signal import sys import threading from multiprocessing.queues import JoinableQueue -from typing import Any, Callable, Union +from typing import Any, Callable, Set, Union from .logging import multiprocessing_breakpoint @@ -13,6 +15,10 @@ class PoolBase(abc.ABC): + def __init__(self): + with pools_lock: + pools.add(self) + @abc.abstractmethod def submit(self, args): pass @@ -24,15 +30,20 @@ def process_until_done(self): def start(self): pass - def close(self): - pass + def close(self, *, immediate=False): # noqa: ARG002 + with pools_lock: + pools.remove(self) def __enter__(self): self.start() return self - def __exit__(self, *args): - self.close() + def __exit__(self, exc_type, _exc_value, _tb): + self.close(immediate=exc_type is not None) + + +pools_lock = threading.Lock() +pools: Set[PoolBase] = set() class Queue(JoinableQueue): @@ -53,9 +64,15 @@ class _Sentinel: def _worker_process(handler, input_, output): - # Creates a new process group, making sure no signals are propagated from the main process to the worker processes. + # Creates a new process group, making sure no signals are + # propagated from the main process to the worker processes. os.setpgrp() + # Restore default signal handlers, otherwise workers would inherit + # them from main process + signal.signal(signal.SIGTERM, signal.SIG_DFL) + signal.signal(signal.SIGINT, signal.SIG_DFL) + sys.breakpointhook = multiprocessing_breakpoint while (args := input_.get()) is not _SENTINEL: result = handler(args) @@ -71,11 +88,14 @@ def __init__( *, result_callback: Callable[["MultiPool", Any], Any], ): + super().__init__() if process_num <= 0: raise ValueError("At process_num must be greater than 0") + self._running = False self._result_callback = result_callback self._input = Queue(ctx=mp.get_context()) + self._input.cancel_join_thread() self._output = mp.SimpleQueue() self._procs = [ mp.Process( @@ -87,14 +107,32 @@ def __init__( self._tid = threading.get_native_id() def start(self): + self._running = True for p in self._procs: p.start() - def close(self): - self._clear_input_queue() - self._request_workers_to_quit() - self._clear_output_queue() + def close(self, *, immediate=False): + if not self._running: + return + self._running = False + + if immediate: + self._terminate_workers() + else: + self._clear_input_queue() + self._request_workers_to_quit() + self._clear_output_queue() + self._wait_for_workers_to_quit() + super().close(immediate=immediate) + + def _terminate_workers(self): + for proc in self._procs: + proc.terminate() + + self._input.close() + if sys.version_info >= (3, 9): + self._output.close() def _clear_input_queue(self): try: @@ -129,14 +167,16 @@ def submit(self, args): self._input.put(args) def process_until_done(self): - while not self._input.is_empty(): - result = self._output.get() - self._result_callback(self, result) - self._input.task_done() + with contextlib.suppress(EOFError): + while not self._input.is_empty(): + result = self._output.get() + self._result_callback(self, result) + self._input.task_done() class SinglePool(PoolBase): def __init__(self, handler, *, result_callback): + super().__init__() self._handler = handler self._result_callback = result_callback @@ -157,3 +197,19 @@ def make_pool(process_num, handler, result_callback) -> Union[SinglePool, MultiP handler=handler, result_callback=result_callback, ) + + +orig_signal_handlers = {} + + +def _on_terminate(signum, frame): + pools_snapshot = list(pools) + for pool in pools_snapshot: + pool.close(immediate=True) + + if callable(orig_signal_handlers[signum]): + orig_signal_handlers[signum](signum, frame) + + +orig_signal_handlers[signal.SIGTERM] = signal.signal(signal.SIGTERM, _on_terminate) +orig_signal_handlers[signal.SIGINT] = signal.signal(signal.SIGINT, _on_terminate) diff --git a/unblob/processing.py b/unblob/processing.py index 4b95ea7651..ac93c20a1c 100644 --- a/unblob/processing.py +++ b/unblob/processing.py @@ -43,7 +43,6 @@ StatReport, UnknownError, ) -from .signals import terminate_gracefully from .ui import NullProgressReporter, ProgressReporter logger = get_logger() @@ -111,7 +110,6 @@ def get_extract_dir_for(self, path: Path) -> Path: return extract_dir.expanduser().resolve() -@terminate_gracefully def process_file( config: ExtractionConfig, input_path: Path, report_file: Optional[Path] = None ) -> ProcessResult: diff --git a/unblob/sandbox.py b/unblob/sandbox.py new file mode 100644 index 0000000000..3608b5bf7f --- /dev/null +++ b/unblob/sandbox.py @@ -0,0 +1,118 @@ +import ctypes +import sys +import threading +from pathlib import Path +from typing import Callable, Iterable, Optional, Type, TypeVar + +from structlog import get_logger +from unblob_native.sandbox import ( + AccessFS, + SandboxError, + restrict_access, +) + +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + +from unblob.processing import ExtractionConfig + +logger = get_logger() + +P = ParamSpec("P") +R = TypeVar("R") + + +class Sandbox: + """Configures restricted file-systems to run functions in. + + When calling ``run()``, a separate thread will be configured with + minimum required file-system permissions. All subprocesses spawned + from that thread will honor the restrictions. + """ + + def __init__( + self, + config: ExtractionConfig, + log_path: Path, + report_file: Optional[Path], + extra_restrictions: Iterable[AccessFS] = (), + ): + self.restrictions = [ + # Python, shared libraries, extractor binaries and so on + AccessFS.read("/"), + # Multiprocessing + AccessFS.read_write("/dev/shm"), # noqa: S108 + # Extracted contents + AccessFS.read_write(config.extract_root), + AccessFS.make_dir(config.extract_root.parent), + AccessFS.read_write(log_path), + *extra_restrictions, + ] + + if report_file: + self.restrictions += [ + AccessFS.read_write(report_file), + AccessFS.make_reg(report_file.parent), + ] + + def run(self, callback: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: + """Run callback with restricted filesystem access.""" + exception = None + result = None + + def _run_in_thread(callback, *args, **kwargs): + nonlocal exception, result + + self._try_enter_sandbox() + try: + result = callback(*args, **kwargs) + except BaseException as e: + exception = e + + thread = threading.Thread( + target=_run_in_thread, args=(callback, *args), kwargs=kwargs + ) + thread.start() + + try: + thread.join() + except KeyboardInterrupt: + raise_in_thread(thread, KeyboardInterrupt) + thread.join() + + if exception: + raise exception # pyright: ignore[reportGeneralTypeIssues] + return result # pyright: ignore[reportReturnType] + + def _try_enter_sandbox(self): + try: + restrict_access(*self.restrictions) + except SandboxError: + logger.warning( + "Sandboxing FS access is unavailable on this system, skipping." + ) + + +def raise_in_thread(thread: threading.Thread, exctype: Type) -> None: + if thread.ident is None: + raise RuntimeError("Thread is not started") + + res = ctypes.pythonapi.PyThreadState_SetAsyncExc( + ctypes.c_ulong(thread.ident), ctypes.py_object(exctype) + ) + + # success + if res == 1: + return + + # Need to revert the call to restore interpreter state + ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_ulong(thread.ident), None) + + # Thread could have exited since + if res == 0: + return + + # Something bad have happened + raise RuntimeError("Could not raise exception in thread", thread.ident) diff --git a/unblob/signals.py b/unblob/signals.py deleted file mode 100644 index 76b70a4dbe..0000000000 --- a/unblob/signals.py +++ /dev/null @@ -1,51 +0,0 @@ -import functools -import signal - -from structlog import get_logger - -logger = get_logger() - - -class ShutDownRequired(BaseException): - def __init__(self, signal: str): - super().__init__() - self.signal = signal - - -def terminate_gracefully(func): - @functools.wraps(func) - def decorator(*args, **kwargs): - signals_fired = [] - - def _handle_signal(signum: int, frame): - nonlocal signals_fired - signals_fired.append((signum, frame)) - raise ShutDownRequired(signal=signal.Signals(signum).name) - - original_signal_handlers = { - signal.SIGINT: signal.signal(signal.SIGINT, _handle_signal), - signal.SIGTERM: signal.signal(signal.SIGTERM, _handle_signal), - } - - logger.debug( - "Setting up signal handlers", - original_signal_handlers=original_signal_handlers, - _verbosity=2, - ) - - try: - return func(*args, **kwargs) - except ShutDownRequired as exc: - logger.warning("Shutting down", signal=exc.signal) - finally: - # Set back the original signal handlers - for sig, handler in original_signal_handlers.items(): - signal.signal(sig, handler) - - # Call the original signal handler with the fired and catched signal(s) - for sig, frame in signals_fired: - handler = original_signal_handlers.get(sig) - if callable(handler): - handler(sig, frame) - - return decorator diff --git a/unblob/testing.py b/unblob/testing.py index 75a786df34..b56c0f175a 100644 --- a/unblob/testing.py +++ b/unblob/testing.py @@ -1,6 +1,7 @@ import binascii import glob import io +import platform import shlex import subprocess from pathlib import Path @@ -10,6 +11,7 @@ from lark.lark import Lark from lark.visitors import Discard, Transformer from pytest_cov.embed import cleanup_on_sigterm +from unblob_native.sandbox import AccessFS, SandboxError, restrict_access from unblob.finder import build_hyperscan_database from unblob.logging import configure_logger @@ -217,3 +219,17 @@ def start(self, s): rv.write(line.data) return rv.getvalue() + + +def is_sandbox_available(): + is_sandbox_available = True + + try: + restrict_access(AccessFS.read_write("/")) + except SandboxError: + is_sandbox_available = False + + if platform.architecture == "x86_64" and platform.system == "linux": + assert is_sandbox_available, "Sandboxing should work at least on Linux-x86_64" + + return is_sandbox_available