Release the GIL for BytesAhoCorasick when the haystack is bytes
In theory we should do that for every type that is guaranteed to be
immutable, in order to improve performance, but there's no way to tell
whether a type is immutable or not through the Buffer protocol, so we
only do this as a special case for `bytes`.

Fixes #94.
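
To illustrate the limitation described above, here is a minimal stdlib-only sketch. It is not part of this commit, and the variable names are illustrative. It shows that a buffer can report itself as read-only while the memory behind it is still changed by its owner, which is presumably why the read-only flag alone is not enough to justify releasing the GIL.

```python
# A read-only *view* does not imply an immutable *object*.
backing = bytearray(b"hello")
view = memoryview(backing).toreadonly()

# The view itself rejects writes...
readonly_flag = view.readonly

# ...but the owning bytearray can still mutate the shared memory.
backing[0] = ord("j")
seen_through_view = bytes(view)

# A bytes object also exports a read-only buffer, yet Python offers no
# way to mutate it, so it is the one case that can be special-cased safely.
bytes_view_readonly = memoryview(b"hello").readonly
```

Both views report `readonly`, so a consumer of the buffer protocol cannot distinguish the two cases from the flag; checking for the concrete `bytes` type is the conservative alternative.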
Isaac Garzon committed Jan 6, 2024
1 parent ab9d0b3 commit 8994049
Showing 3 changed files with 47 additions and 12 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -193,7 +193,7 @@ You can control the behavior by using the `store_patterns` keyword argument to `
 ## Implementation details <a name="implementation"></a>

 * Matching on strings releases the GIL, to enable concurrency.
-  Matching on bytes does not currently release the GIL, but see https://github.com/G-Research/ahocorasick_rs/issues/94 for a case where it could.
+  Matching on bytes does not currently release the GIL for memory-safety reasons, unless the haystack type is `bytes`.
 * Not all features from the underlying library are exposed; if you would like additional features, please [file an issue](https://github.com/g-research/ahocorasick_rs/issues/new) or submit a PR.

 ## Benchmarks <a name="benchmarks"></a>
28 changes: 17 additions & 11 deletions src/lib.rs
@@ -8,7 +8,7 @@ use pyo3::{
     buffer::{PyBuffer, ReadOnlyCell},
     exceptions::{PyTypeError, PyValueError},
     prelude::*,
-    types::{PyList, PyUnicode},
+    types::{PyBytes, PyList, PyUnicode},
 };

 /// Search for multiple pattern strings against a single haystack string.
@@ -408,16 +408,22 @@ impl PyBytesAhoCorasick {
         haystack: &PyAny,
         overlapping: bool,
     ) -> PyResult<Vec<(u64, usize, usize)>> {
-        let haystack = PyBufferBytes::try_from(haystack)?;
-        let matches = get_matches(&self_.ac_impl, haystack.as_ref(), overlapping)?;
-
-        // Note: we must collect here and not release the GIL or return an iterator
-        // from this function due to the safety caveat in the implementation of
-        // AsRef<[u8]> for PyBufferBytes, which is relevant here since the matches
-        // iterator is holding an AsRef reference on the haystack.
-        Ok(matches
-            .map(|m| (m.pattern().as_u64(), m.start(), m.end()))
-            .collect())
+        let haystack_buffer = PyBufferBytes::try_from(haystack)?;
+        let matches = get_matches(&self_.ac_impl, haystack_buffer.as_ref(), overlapping)?
+            .map(|m| (m.pattern().as_u64(), m.start(), m.end()));
+
+        if !haystack.is_instance_of::<PyBytes>() {
+            // Note: we must collect here and not release the GIL or return an iterator
+            // from this function due to the safety caveat in the implementation of
+            // AsRef<[u8]> for PyBufferBytes, which is relevant here since the matches
+            // iterator is holding an AsRef reference to the haystack.
+            Ok(matches.collect())
+        } else {
+            // However, if the haystack is a PyBytes, it's guaranteed to be immutable,
+            // so the safety caveat doesn't apply, and we can safely release the GIL
+            // while the matches iterator is holding a reference to the haystack.
+            haystack.py().allow_threads(|| Ok(matches.collect()))
+        }
     }
 }
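
From the caller's side, the branch above means only `bytes` haystacks are searched with the GIL released. The following is a rough sketch of how a caller might take advantage of that, not code from this commit. `search_concurrently` and its `find` parameter are hypothetical names, and copying non-`bytes` haystacks into `bytes` is an assumption that trades an extra copy for the chance of concurrent matching.

```python
from concurrent.futures import ThreadPoolExecutor


def search_concurrently(find, haystacks, max_workers=4):
    """Run `find` over each haystack in a thread pool (hypothetical helper).

    Haystacks that are not already `bytes` (for example bytearray or
    memoryview) are copied into `bytes` first, so a matcher that only
    releases the GIL for `bytes` input can do so for every haystack.
    """
    immutable = [h if isinstance(h, bytes) else bytes(h) for h in haystacks]
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map preserves input order in its results.
        return list(pool.map(find, immutable))
```

With a `BytesAhoCorasick` instance `ac`, `find` would be `ac.find_matches_as_indexes`. Whether the copy pays off likely depends on haystack size and thread count, so it is worth measuring before adopting.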

29 changes: 29 additions & 0 deletions tests/test_ac_bytes.py
@@ -71,6 +71,35 @@ def test_different_byte_objects_matching(
     assert [haystack[s:e] for (_, s, e) in index_matches] == expected


+@pytest.mark.parametrize(
+    "implementation",
+    [
+        None,
+        Implementation.NoncontiguousNFA,
+        Implementation.ContiguousNFA,
+        Implementation.DFA,
+    ],
+)
+@pytest.mark.parametrize("haystack_type", [bytes, bytearray, memoryview])
+def test_different_byte_haystacks_matching(
+    implementation: Optional[Implementation],
+    haystack_type: type[bytes | bytearray | memoryview],
+) -> None:
+    """
+    find_matches_as_indexes() returns matching patterns in the given byte string.
+    """
+    haystack = haystack_type(b"hello, world, hello again")
+    patterns = [b"hello", b"world"]
+    ac = BytesAhoCorasick(patterns, implementation=implementation)
+
+    expected = [b"hello", b"world", b"hello"]
+
+    # find_matches_as_indexes()
+    index_matches = ac.find_matches_as_indexes(haystack)
+    assert [patterns[i] for (i, _, _) in index_matches] == expected
+    assert [haystack[s:e] for (_, s, e) in index_matches] == expected
+
+
 def test_iterator_of_patterns() -> None:
     """
     It's possible to construct ``BytesAhoCorasick()`` with an iterator.