diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 6856fac..4dec0a9 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -43,9 +43,9 @@ jobs: pip install . - name: "Tests" run: | - flake8 tests - mypy --strict tests # indirect type annotation checking - black --check tests + flake8 pysrc tests + mypy --strict pysrc tests + black --check pysrc tests pytest tests - name: "Enable universal2 on Python >= 3.9 on macOS" if: ${{ startsWith(matrix.os, 'macos') && matrix.python-version != '3.8' }} diff --git a/.gitignore b/.gitignore index 6f7c05b..81055e0 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,10 @@ # Python bytecode files: *.pyc +*.pyd + +# Generated by Pytest +/.pytest_cache/ # Emacs junk: *~ \ No newline at end of file diff --git a/.hypothesis/examples/2841afa294c040ea/68e09c5bfbe3cdbe b/.hypothesis/examples/2841afa294c040ea/68e09c5bfbe3cdbe new file mode 100644 index 0000000..51d72fb Binary files /dev/null and b/.hypothesis/examples/2841afa294c040ea/68e09c5bfbe3cdbe differ diff --git a/.hypothesis/examples/4d7f7d596155940a/9a3c1667a531b1ac b/.hypothesis/examples/4d7f7d596155940a/9a3c1667a531b1ac new file mode 100644 index 0000000..654be01 Binary files /dev/null and b/.hypothesis/examples/4d7f7d596155940a/9a3c1667a531b1ac differ diff --git a/.hypothesis/examples/cc12e50281195de6/b9fc09fff40c6832 b/.hypothesis/examples/cc12e50281195de6/b9fc09fff40c6832 new file mode 100644 index 0000000..b1a3b69 Binary files /dev/null and b/.hypothesis/examples/cc12e50281195de6/b9fc09fff40c6832 differ diff --git a/.hypothesis/unicode_data/13.0.0/codec-utf-8.json.gz b/.hypothesis/unicode_data/13.0.0/codec-utf-8.json.gz index 35fdea6..54dc879 100644 Binary files a/.hypothesis/unicode_data/13.0.0/codec-utf-8.json.gz and b/.hypothesis/unicode_data/13.0.0/codec-utf-8.json.gz differ diff --git a/CHANGELOG.md b/CHANGELOG.md index f3f5c8a..011c6aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Changelog +## 0.21.0 + +* Added support for searching `bytes`, `bytearray`, `memoryview`, and similar objects using the `BytesAhoCorasick` class. + ## 0.20.0 * Added support for Python 3.12. diff --git a/Cargo.lock b/Cargo.lock index 3cb46c8..369627d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "aho-corasick" -version = "1.1.1" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea5d730647d4fadd988536d06fecce94b7b4f2a7efdae548f1cf4b63205518ab" +checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" dependencies = [ "memchr", ] @@ -67,15 +67,15 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.148" +version = "0.2.151" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cdc71e17332e86d2e1d38c1f99edcb6288ee11b815fb1a4b049eaa2114d369b" +checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" [[package]] name = "lock_api" -version = "0.4.10" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1cc9717a20b1bb222f333e6a92fd32f7d8a18ddc5a3191a11af45dcbf4dcd16" +checksum = "3c168f8615b12bc01f9c17e2eb0cc07dcae1940121185446edc3744920e8ef45" dependencies = [ "autocfg", "scopeguard", @@ -83,9 +83,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.6.3" +version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f232d6ef707e1956a43342693d2a31e72989554d58299d7a88738cc95b0d35c" +checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" [[package]] name = "memoffset" @@ -98,9 +98,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.18.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "parking_lot" @@ -114,9 +114,9 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.8" +version = "0.9.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447" +checksum = "4c42a9226546d68acdd9c0a280d17ce19bfe27a46bf68784e4066115788d008e" dependencies = [ "cfg-if", "libc", @@ -127,18 +127,18 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.67" +version = "1.0.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d433d9f1a3e8c1263d9456598b16fec66f4acc9a74dacffd35c7bb09b3a1328" +checksum = "907a61bd0f64c2f29cd1cf1dc34d05176426a3f504a78010f08416ddb7b13708" dependencies = [ "unicode-ident", ] [[package]] name = "pyo3" -version = "0.20.0" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04e8453b658fe480c3e70c8ed4e3d3ec33eb74988bd186561b0cc66b85c3bc4b" +checksum = "9a89dc7a5850d0e983be1ec2a463a171d20990487c3cfcd68b5363f1ee3d6fe0" dependencies = [ "cfg-if", "indoc", @@ -153,9 +153,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.20.0" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a96fe70b176a89cff78f2fa7b3c930081e163d5379b4dcdf993e3ae29ca662e5" +checksum = "07426f0d8fe5a601f26293f300afd1a7b1ed5e78b2a705870c5f30893c5163be" dependencies = [ "once_cell", "target-lexicon", @@ -163,9 +163,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.20.0" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "214929900fd25e6604661ed9cf349727c8920d47deff196c4e28165a6ef2a96b" +checksum = "dbb7dec17e17766b46bca4f1a4215a85006b4c2ecde122076c562dd058da6cf1" dependencies = [ "libc", "pyo3-build-config", @@ -173,9 +173,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.20.0" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dac53072f717aa1bfa4db832b39de8c875b7c7af4f4a6fe93cdbf9264cf8383b" +checksum = "05f738b4e40d50b5711957f142878cfa0f28e054aa0ebdfc3fd137a843f74ed3" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -185,9 +185,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.20.0" +version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7774b5a8282bd4f25f803b1f0d945120be959a36c72e08e7cd031c792fdfd424" +checksum = "0fc910d4851847827daf9d6cdd4a823fbdaab5b8818325c5e97a86da79e8881f" dependencies = [ "heck", "proc-macro2", @@ -197,18 +197,18 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.33" +version = "1.0.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" dependencies = [ "proc-macro2", ] [[package]] name = "redox_syscall" -version = "0.3.5" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" dependencies = [ "bitflags", ] @@ -221,15 +221,15 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "smallvec" -version = "1.11.1" +version = "1.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "942b4a808e05215192e39f4ab80813e599068285906cc91aa64f923db842bd5a" +checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" [[package]] name = "syn" -version = "2.0.38" +version = "2.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e96b79aaa137db8f61e26363a0c9b47d8b4ec75da28b7d1d614c2303e232408b" +checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" dependencies = [ "proc-macro2", "quote", @@ -238,9 +238,9 @@ dependencies = [ [[package]] name = "target-lexicon" -version = "0.12.11" +version = "0.12.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d0e916b1148c8e263850e1ebcbd046f333e0683c724876bb0da63ea4373dc8a" +checksum = "69758bda2e78f098e4ccb393021a0963bb3442eac05f135c30f61b7370bbafae" [[package]] name = "unicode-ident" diff --git a/README.md b/README.md index c8f56e6..64c4d82 100644 --- a/README.md +++ b/README.md @@ -14,14 +14,16 @@ Found any problems or have any questions? [File an issue on the GitHub project]( ## Quickstart -The `ahocorasick_rs` library allows you to search for multiple strings ("patterns") within a haystack. +The `ahocorasick_rs` library allows you to search for multiple strings ("patterns") within a haystack, or alternatively search multiple bytes. For example, let's install the library: ```shell-session $ pip install ahocorasick-rs ``` -Then, we can construct a `AhoCorasick` object: +### Searching strings + +We can construct a `AhoCorasick` object: ```python >>> import ahocorasick_rs @@ -58,11 +60,29 @@ You can construct a `AhoCorasick` object from any iterable (including generators ['hello', 'world', 'hello'] ``` +### Searching `bytes` and other similar objects + +You can also search `bytes`, `bytearray`, `memoryview`, and other objects supporting the Python buffer API. + +```python +>>> patterns = [b"hello", b"world"] +>>> ac = ahocorasick_rs.BytesAhoCorasick(patterns) +>>> haystack = b"hello world" +>>> ac.find_matches_as_indexes(b"hello world") +[(0, 0, 5), (1, 6, 11)] +>>> patterns[0], patterns[1] +(b'hello', b'world') +>>> haystack[0:5], haystack[6:11] +(b'hello', b'world') +``` + +The `find_matches_as_strings()` API is not supported by `BytesAhoCorasick`. + ## Choosing the matching algorithm ### Match kind -There are three ways you can configure matching in cases where multiple patterns overlap. +There are three ways you can configure matching in cases where multiple patterns overlap, supported by both `AhoCorasick` and `BytesAhoCorasick` objects. For a more in-depth explanation, see the [underlying Rust library's documentation of matching](https://docs.rs/aho-corasick/latest/aho_corasick/enum.MatchKind.html). Assume we have this starting point: @@ -127,7 +147,8 @@ This returns the leftmost-in-the-haystack matching pattern that is longest: ### Overlapping matches -You can get all overlapping matches, instead of just one of them, but only if you stick to the default matchkind, `MatchKind.Standard`: +You can get all overlapping matches, instead of just one of them, but only if you stick to the default matchkind, `MatchKind.Standard`. +Again, this is supported by both `AhoCorasick` and `BytesAhoCorasick`. ```python >>> from ahocorasick_rs import AhoCorasick @@ -139,7 +160,7 @@ You can get all overlapping matches, instead of just one of them, but only if yo ## Additional configuration: speed and memory usage tradeoffs -### Algorithm implementations: trading construction speed, memory, and performance +### Algorithm implementations: trading construction speed, memory, and performance (`AhoCorasick` and `BytesAhoCorasick`) You can choose the type of underlying automaton to use, with different performance tradeoffs. The short version: if you want maximum matching speed, and you don't have too many patterns, try the `Implementation.DFA` implementation and see if it helps. @@ -157,7 +178,7 @@ The underlying Rust library supports [four choices](https://docs.rs/aho-corasick >>> ac = AhoCorasick(["disco", "disc"], implementation=Implementation.DFA) ``` -### Trading memory for speed +### Trading memory for speed (`AhoCorasick` only) If you use ``find_matches_as_strings()``, there are two ways strings can be constructed: from the haystack, or by caching the patterns on the object. The former takes more work, the latter uses more memory if the patterns would otherwise have been garbage-collected. @@ -171,7 +192,8 @@ You can control the behavior by using the `store_patterns` keyword argument to ` ## Implementation details -* Matching releases the GIL, to enable concurrency. +* Matching on strings releases the GIL, to enable concurrency. + Matching on bytes does not currently release the GIL, but see https://github.com/G-Research/ahocorasick_rs/issues/94 for a case where it could. * Not all features from the underlying library are exposed; if you would like additional features, please [file an issue](https://github.com/g-research/ahocorasick_rs/issues/new) or submit a PR. ## Benchmarks diff --git a/justfile b/justfile index b3714e6..82738de 100644 --- a/justfile +++ b/justfile @@ -20,9 +20,9 @@ install-dev-dependencies: setup: venv install-dev-dependencies lint: - flake8 tests/ - black --check tests/ - mypy --strict tests + flake8 pysrc tests/ + black --check pysrc tests/ + mypy --strict pysrc tests test: pytest tests/ diff --git a/pyproject.toml b/pyproject.toml index 58826f9..9028e8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,14 @@ [build-system] -requires = ["maturin>=0.14,<0.15"] +requires = ["maturin>=1.0,<2.0"] build-backend = "maturin" [project] name = "ahocorasick_rs" -requires-python = ">=3.7" +requires-python = ">=3.8" +dependencies = [ + # Technically not necessary to run, only needed for type checking... + "typing_extensions >= 4.6.0 ; python_version < '3.12'" +] [tool.maturin] python-source = "pysrc/" \ No newline at end of file diff --git a/pysrc/ahocorasick_rs/__init__.py b/pysrc/ahocorasick_rs/__init__.py index 811077f..814534f 100644 --- a/pysrc/ahocorasick_rs/__init__.py +++ b/pysrc/ahocorasick_rs/__init__.py @@ -1,6 +1,7 @@ # Expose the Rust code: from .ahocorasick_rs import ( AhoCorasick, + BytesAhoCorasick, MatchKind, Implementation, ) @@ -12,6 +13,7 @@ __all__ = [ "AhoCorasick", + "BytesAhoCorasick", "MatchKind", "Implementation", # Deprecated: diff --git a/pysrc/ahocorasick_rs/ahocorasick_rs.pyi b/pysrc/ahocorasick_rs/ahocorasick_rs.pyi index c8c3760..48aa083 100644 --- a/pysrc/ahocorasick_rs/ahocorasick_rs.pyi +++ b/pysrc/ahocorasick_rs/ahocorasick_rs.pyi @@ -1,4 +1,12 @@ +from __future__ import annotations + from typing import Optional, Iterable +import sys + +if sys.version_info >= (3, 12): + from collections.abc import Buffer +else: + from typing_extensions import Buffer class Implementation: NoncontiguousNFA: Implementation @@ -24,3 +32,14 @@ class AhoCorasick: def find_matches_as_strings( self, haystack: str, overlapping: bool = False ) -> list[str]: ... + +class BytesAhoCorasick: + def __init__( + self, + patterns: Iterable[Buffer], + matchkind: MatchKind = MatchKind.Standard, + implementation: Optional[Implementation] = None, + ) -> None: ... + def find_matches_as_indexes( + self, haystack: Buffer, overlapping: bool = False + ) -> list[tuple[int, int, int]]: ... diff --git a/rust-toolchain.toml b/rust-toolchain.toml index b90b798..e77a43c 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,3 +1,3 @@ [toolchain] -channel = "1.73" +channel = "1.75" components = ["rustfmt", "clippy"] \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 7251982..a8b67e1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,7 +5,8 @@ use aho_corasick::{ }; use itertools::Itertools; use pyo3::{ - exceptions::PyValueError, + buffer::{PyBuffer, ReadOnlyCell}, + exceptions::{PyTypeError, PyValueError}, prelude::*, types::{PyList, PyUnicode}, }; @@ -36,27 +37,36 @@ fn match_error_to_pyerror(e: MatchError) -> PyErr { PyValueError::new_err(e.to_string()) } -impl PyAhoCorasick { - /// Return matches for a given haystack. - fn get_matches( - &self, - py: Python<'_>, - haystack: &str, - overlapping: bool, - ) -> PyResult> { - let ac_impl = &self.ac_impl; - py.allow_threads(|| { - if overlapping { - ac_impl - .try_find_overlapping_iter(haystack) - .map_err(match_error_to_pyerror) - .map(|it| it.collect()) - } else { - Ok(ac_impl.find_iter(haystack).collect()) - } - }) +/// Return matches for a given haystack. +fn get_matches<'a>( + ac_impl: &'a AhoCorasick, + haystack: &'a [u8], + overlapping: bool, +) -> PyResult + 'a> { + let mut overlapping_it = None; + let mut non_overlapping_it = None; + + if overlapping { + overlapping_it = Some( + ac_impl + .try_find_overlapping_iter(haystack) + .map_err(match_error_to_pyerror)?, + ); + } else { + non_overlapping_it = Some( + ac_impl + .try_find_iter(haystack) + .map_err(match_error_to_pyerror)?, + ); } + Ok(overlapping_it + .into_iter() + .flatten() + .chain(non_overlapping_it.into_iter().flatten())) +} + +impl PyAhoCorasick { /// Create mapping from byte index to Unicode code point (character) index /// in the haystack. fn get_byte_to_code_point(&self, haystack: &str) -> Vec { @@ -220,17 +230,18 @@ impl PyAhoCorasick { ) -> PyResult> { let byte_to_code_point = self_.get_byte_to_code_point(haystack); let py = self_.py(); - let matches = self_.get_matches(py, haystack, overlapping)?; - Ok(matches - .into_iter() - .map(|m| { - ( - m.pattern().as_u64(), - byte_to_code_point[m.start()], - byte_to_code_point[m.end()], - ) - }) - .collect()) + let matches = get_matches(&self_.ac_impl, haystack.as_bytes(), overlapping)?; + py.allow_threads(|| { + Ok(matches + .map(|m| { + ( + m.pattern().as_u64(), + byte_to_code_point[m.start()], + byte_to_code_point[m.end()], + ) + }) + .collect()) + }) } /// Return matches as list of patterns (i.e. strings). If ``overlapping`` is @@ -242,7 +253,8 @@ impl PyAhoCorasick { overlapping: bool, ) -> PyResult> { let py = self_.py(); - let matches = self_.get_matches(py, haystack, overlapping)?.into_iter(); + let matches = get_matches(&self_.ac_impl, haystack.as_bytes(), overlapping)?; + let matches = py.allow_threads(|| matches.collect::>().into_iter()); let result = if let Some(ref patterns) = self_.patterns { PyList::new(py, matches.map(|m| patterns[m.pattern()].clone_ref(py))) } else { @@ -255,11 +267,166 @@ impl PyAhoCorasick { } } +/// A wrapper around PyBuffer that can be passed directly to AhoCorasickBuilder. +struct PyBufferBytes<'a> { + py: Python<'a>, + buffer: PyBuffer, +} + +impl<'a> TryFrom<&'a PyAny> for PyBufferBytes<'a> { + type Error = PyErr; + + // Get a PyBufferBytes from a Python object + fn try_from(obj: &'a PyAny) -> PyResult { + let buffer = PyBuffer::::get(obj).map_err(PyErr::from)?; + + if buffer.dimensions() > 1 { + return Err(PyTypeError::new_err( + "Only one-dimensional sequences are supported", + )); + } + + // Make sure we can get a slice from the buffer + let py = obj.py(); + let _ = buffer + .as_slice(py) + .ok_or_else(|| PyTypeError::new_err("Must be a contiguous sequence of bytes"))?; + + Ok(PyBufferBytes { py, buffer }) + } +} + +impl<'a> AsRef<[u8]> for PyBufferBytes<'a> { + fn as_ref(&self) -> &[u8] { + // This already succeeded when we created PyBufferBytes, so just expect() + let slice = self + .buffer + .as_slice(self.py) + .expect("Failed to get a slice from a valid buffer?"); + + const _: () = assert!( + std::mem::size_of::>() == std::mem::size_of::(), + "ReadOnlyCell has a different size than u8" + ); + + // Safety: the slice is &[ReadOnlyCell] which has the same memory + // representation as &[u8] due to it being a #[repr(transparent)] newtype + // around the standard library UnsafeCell, which the documentation + // claims has the same representation as the underlying type. + // + // However, holding this reference while Python code is executing might + // result in the buffer getting mutated from under us, which is a violation + // of the &[u8] invariants (and having the .readonly() flag set on the + // PyBuffer unfortunately doesn't actually guarantee immutability). + // + // Because &[u8] is `Ungil`, we can't prevent a release of the GIL while + // this reference is being held (though even if it wasn't `Ungil`, we + // wouldn't be able to prevent calling back into Python while holding + // this reference, which might also result in a mutation). + // + // This effectively means that it's only safe to hold onto the reference + // returned from this function as long as we don't release the GIL and + // don't call back into Python code while the reference is alive. + // See also https://github.com/PyO3/pyo3/issues/2824 + unsafe { std::mem::transmute(slice) } + } +} + +/// Search for multiple pattern bytes against a single bytes haystack. In +/// addition to ``bytes``, you can use other objects that support the Python +/// buffer API, like ``memoryview`` and ``bytearray``. +/// +/// Takes three arguments: +/// +/// * ``patterns``: A list of bytes, the patterns to match against. Empty +/// patterns are not supported and will result in a ``ValueError`` exception +/// being raised. No references are kept to the patterns once construction is +/// finished. +/// * ``matchkind``: Defaults to ``"MATCHKING_STANDARD"``. +/// * ``implementation``: The underlying type of automaton to use for Aho-Corasick. +#[pyclass(name = "BytesAhoCorasick")] +struct PyBytesAhoCorasick { + ac_impl: AhoCorasick, +} + +/// Methods for PyBytesAhoCorasick. +#[pymethods] +impl PyBytesAhoCorasick { + /// __new__() implementation. + #[new] + #[pyo3(signature = (patterns, matchkind = PyMatchKind::Standard, implementation = None))] + fn new( + _py: Python, + patterns: &PyAny, + matchkind: PyMatchKind, + implementation: Option, + ) -> PyResult { + // If set, this means we had an error while parsing byte buffers from `patterns` + let patterns_error: Cell> = Cell::new(None); + + // Convert the `patterns` iterable into an Iterator over PyBufferBytes + let patterns_iter = + patterns + .iter()? + .map_while(|pat| match pat.and_then(PyBufferBytes::try_from) { + Ok(pat) => { + if pat.as_ref().is_empty() { + patterns_error.set(Some(PyValueError::new_err( + "You passed in an empty pattern", + ))); + None + } else { + Some(pat) + } + } + Err(e) => { + patterns_error.set(Some(e)); + None + } + }); + + let ac_impl = AhoCorasickBuilder::new() + .kind(implementation.map(|i| i.into())) + .match_kind(matchkind.into()) + .build(patterns_iter) + // TODO make sure this error is meaningful to Python users + .map_err(|e| PyValueError::new_err(e.to_string()))?; + + if let Some(err) = patterns_error.take() { + return Err(err); + } + + Ok(Self { ac_impl }) + } + + /// Return matches as tuple of (index_into_patterns, + /// start_index_in_haystack, end_index_in_haystack). If ``overlapping`` is + /// ``False`` (the default), don't include overlapping results. + #[pyo3(signature = (haystack, overlapping = false))] + fn find_matches_as_indexes( + self_: PyRef, + haystack: &PyAny, + overlapping: bool, + ) -> PyResult> { + let haystack = PyBufferBytes::try_from(haystack)?; + let matches = get_matches(&self_.ac_impl, haystack.as_ref(), overlapping)?; + + // Note: we must collect here and not release the GIL or return an iterator + // from this function due to the safety caveat in the implementation of + // AsRef<[u8]> for PyBufferBytes, which is relevant here since the matches + // iterator is holding an AsRef reference on the haystack. + Ok(matches + .map(|m| (m.pattern().as_u64(), m.start(), m.end())) + .collect()) + } +} + /// The main Python module. #[pymodule] fn ahocorasick_rs(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/tests/test_ac.py b/tests/test_ac.py index 9a73845..2643d38 100644 --- a/tests/test_ac.py +++ b/tests/test_ac.py @@ -94,7 +94,7 @@ def test_construction_extensive( Exercise the construction code paths, ensuring we end up using all patterns. """ - patterns = [f"@{p}@" for p in patterns] + patterns = [f"{p}_{i}_" for (i, p) in enumerate(patterns)] ac = AhoCorasick(patterns, store_patterns=store_patterns) for p in patterns: assert ac.find_matches_as_strings(p) == [p] diff --git a/tests/test_ac_bytes.py b/tests/test_ac_bytes.py new file mode 100644 index 0000000..f88063e --- /dev/null +++ b/tests/test_ac_bytes.py @@ -0,0 +1,268 @@ +"""Tests for ahocorasick_rs's bytes support.""" + +from __future__ import annotations + +from typing import Optional + +import pytest + +from hypothesis import strategies as st +from hypothesis import given, assume + +from ahocorasick_rs import ( + BytesAhoCorasick, + MATCHKIND_STANDARD, + MATCHKIND_LEFTMOST_FIRST, + MATCHKIND_LEFTMOST_LONGEST, + MatchKind, + Implementation, +) + + +@pytest.mark.parametrize( + "implementation", + [ + None, + Implementation.NoncontiguousNFA, + Implementation.ContiguousNFA, + Implementation.DFA, + ], +) +def test_basic_matching(implementation: Optional[Implementation]) -> None: + """ + find_matches_as_indexes() returns matching patterns in the given byte string. + """ + haystack = b"hello, world, hello again" + patterns = [b"hello", b"world"] + ac = BytesAhoCorasick(patterns, implementation=implementation) + + expected = [b"hello", b"world", b"hello"] + + # find_matches_as_indexes() + index_matches = ac.find_matches_as_indexes(haystack) + assert [patterns[i] for (i, _, _) in index_matches] == expected + assert [haystack[s:e] for (_, s, e) in index_matches] == expected + + +@pytest.mark.parametrize( + "implementation", + [ + None, + Implementation.NoncontiguousNFA, + Implementation.ContiguousNFA, + Implementation.DFA, + ], +) +def test_different_byte_objects_matching( + implementation: Optional[Implementation], +) -> None: + """ + find_matches_as_indexes() returns matching patterns in the given byte string. + """ + haystack = b"hello, world, hello again" + patterns = [memoryview(b"hello"), bytearray(b"world")] + ac = BytesAhoCorasick(patterns, implementation=implementation) + + expected = [b"hello", b"world", b"hello"] + + # find_matches_as_indexes() + index_matches = ac.find_matches_as_indexes(haystack) + assert [patterns[i] for (i, _, _) in index_matches] == expected + assert [haystack[s:e] for (_, s, e) in index_matches] == expected + + +def test_iterator_of_patterns() -> None: + """ + It's possible to construct ``BytesAhoCorasick()`` with an iterator. + """ + haystack = b"hello, world, hello again" + patterns = [b"hello", b"world"] + ac = BytesAhoCorasick(iter(patterns)) + + expected = [b"hello", b"world", b"hello"] + + index_matches = ac.find_matches_as_indexes(haystack) + assert [patterns[i] for (i, _, _) in index_matches] == expected + assert [haystack[s:e] for (_, s, e) in index_matches] == expected + + +def test_bad_iterators() -> None: + """ + When constructed with a bad iterator, the underlying Python error is raised. + """ + with pytest.raises(TypeError): + BytesAhoCorasick(None) # type: ignore + + with pytest.raises(TypeError): + BytesAhoCorasick([b"x", 12]) # type: ignore[list-item] + + # str doesn't implement the buffer API and can't be converted to bytes + with pytest.raises(TypeError): + BytesAhoCorasick([b"x", "y"]) # type: ignore[list-item] + + +@given( + st.lists(st.binary(min_size=3), min_size=1, max_size=30_000), +) +def test_construction_extensive(patterns: list[bytes]) -> None: + """ + Exercise the construction code paths, ensuring we end up using all + patterns. + """ + patterns = [b"%b_%i_" % (p, i) for (i, p) in enumerate(patterns)] + ac = BytesAhoCorasick(patterns) + for haystack in patterns: + assert [ + haystack[s:e] for (_, s, e) in ac.find_matches_as_indexes(haystack) + ] == [haystack] + + +@given(st.binary(), st.binary(min_size=1), st.binary()) +def test_random_bytes_extensive(prefix: bytes, pattern: bytes, suffix: bytes) -> None: + """ + Random bytes patterns still give correct results for + find_matches_as_indexes(), with property-testing. + """ + assume(pattern not in prefix) + assume(pattern not in suffix) + haystack = prefix + pattern + suffix + ac = BytesAhoCorasick([pattern]) + + index_matches = ac.find_matches_as_indexes(haystack) + expected = [pattern] + assert [i for (i, _, _) in index_matches] == [0] + assert [haystack[s:e] for (_, s, e) in index_matches] == expected + + +@pytest.mark.parametrize("bad_patterns", [[b""], [b"", b"xx"], [b"xx", b""]]) +def test_empty_patterns_are_not_legal(bad_patterns: list[bytes]) -> None: + """ + Passing in an empty pattern suggests a bug in user code, and the outputs + are bad when you do have that, so raise an error. + """ + with pytest.raises(ValueError) as e: + BytesAhoCorasick(bad_patterns) + assert "You passed in an empty pattern" in str(e.value) + + +@given(st.binary(min_size=1), st.binary()) +def test_bytes_totally_random(pattern: bytes, haystack: bytes) -> None: + """ + Catch more edge cases of patterns and haystacks. + """ + ac = BytesAhoCorasick([pattern]) + + index_matches = ac.find_matches_as_indexes(haystack) + + expected_index = haystack.find(pattern) + if expected_index == -1: + assert index_matches == [] + else: + assert index_matches[0][1] == expected_index + assert [haystack[s:e] for (_, s, e) in index_matches][0] == pattern + + +def test_matchkind() -> None: + """ + Different matchkinds give different results. + + The default, MATCHKIND_STANDARD finds overlapping matches. + + MATCHKIND_LEFTMOST_FIRST finds the leftmost match if there are overlapping + matches, choosing the earlier provided pattern. + + MATCHKIND_LEFTMOST_LONGEST finds the leftmost match if there are overlapping + matches, picking the longest one if there are multiple ones. + """ + haystack = b"This is the winter of my discontent" + patterns = [b"content", b"disco", b"disc", b"discontent", b"winter"] + + def get_strings(ac: BytesAhoCorasick) -> list[bytes]: + return [haystack[s:e] for (_, s, e) in ac.find_matches_as_indexes(haystack)] + + # Default is MATCHKIND_STANDARD: + assert get_strings(BytesAhoCorasick(patterns)) == [ + b"winter", + b"disc", + ] + + # Explicit MATCHKIND_STANDARD: + assert get_strings(BytesAhoCorasick(patterns, matchkind=MATCHKIND_STANDARD)) == [ + b"winter", + b"disc", + ] + assert get_strings(BytesAhoCorasick(patterns, matchkind=MatchKind.Standard)) == [ + b"winter", + b"disc", + ] + + # MATCHKIND_LEFTMOST_FIRST: + assert get_strings( + BytesAhoCorasick(patterns, matchkind=MATCHKIND_LEFTMOST_FIRST) + ) == [ + b"winter", + b"disco", + ] + assert get_strings( + BytesAhoCorasick(patterns, matchkind=MatchKind.LeftmostFirst) + ) == [ + b"winter", + b"disco", + ] + + # MATCHKIND_LEFTMOST_LONGEST: + assert get_strings( + BytesAhoCorasick(patterns, matchkind=MATCHKIND_LEFTMOST_LONGEST) + ) == [ + b"winter", + b"discontent", + ] + assert get_strings( + BytesAhoCorasick(patterns, matchkind=MatchKind.LeftmostLongest) + ) == [ + b"winter", + b"discontent", + ] + + +def test_overlapping() -> None: + """ + It's possible to get overlapping matches, but only with MATCHKIND_STANDARD. + """ + haystack = b"This is the winter of my discontent" + patterns = [b"content", b"disco", b"disc", b"discontent", b"winter"] + + def get_strings(ac: BytesAhoCorasick) -> list[bytes]: + assert ac.find_matches_as_indexes(haystack) == ac.find_matches_as_indexes( + haystack, overlapping=False + ) + return [ + haystack[s:e] + for (_, s, e) in ac.find_matches_as_indexes(haystack, overlapping=True) + ] + + def assert_no_overlapping(ac: BytesAhoCorasick) -> None: + with pytest.raises(ValueError): + ac.find_matches_as_indexes(haystack, overlapping=True) + + # Default is MatchKind.Standard: + expected = [ + b"winter", + b"disc", + b"disco", + b"discontent", + b"content", + ] + assert get_strings(BytesAhoCorasick(patterns)) == expected + + # Explicit MATCHKIND_STANDARD: + assert ( + get_strings(BytesAhoCorasick(patterns, matchkind=MatchKind.Standard)) + == expected + ) + + # Other matchkinds don't support overlapping. + assert_no_overlapping(BytesAhoCorasick(patterns, matchkind=MatchKind.LeftmostFirst)) + assert_no_overlapping( + BytesAhoCorasick(patterns, matchkind=MatchKind.LeftmostLongest) + )