Skip to content

Commit

Permalink
Merge pull request #89 from iscgar/cleanup
Browse files Browse the repository at this point in the history
Break early in case of a pattern error or an empty pattern
  • Loading branch information
itamarst authored Oct 6, 2023
2 parents 8d58c4b + b5fa51d commit f6fdac1
Showing 1 changed file with 57 additions and 70 deletions.
127 changes: 57 additions & 70 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use std::{sync::{
Arc, Mutex,
}, cell::Cell};
use std::cell::Cell;

use aho_corasick::{
AhoCorasick, AhoCorasickBuilder, AhoCorasickKind, Match, MatchError, MatchKind,
Expand All @@ -14,7 +12,7 @@ use pyo3::{

/// Search for multiple pattern strings against a single haystack string.
///
/// Takes three arguments:
/// Takes four arguments:
///
/// * ``patterns``: A list of strings, the patterns to match against. Empty
/// patterns are not supported and will result in a ``ValueError`` exception
Expand All @@ -25,6 +23,7 @@ use pyo3::{
/// ``False``, patterns will not be stored. By default uses a heuristic where
/// a short list of small strings (up to 4KB) results in ``True``, and
/// anything else results in ``False``.
/// * ``implementation``: The underlying type of automaton to use for Aho-Corasick.
#[pyclass(name = "AhoCorasick")]
struct PyAhoCorasick {
ac_impl: AhoCorasick,
Expand Down Expand Up @@ -80,7 +79,6 @@ impl PyAhoCorasick {

/// Python equivalent of MatchKind.
#[derive(Clone, Copy, Debug)]
#[allow(clippy::upper_case_acronyms)]
#[pyclass(name = "MatchKind")]
enum PyMatchKind {
Standard,
Expand Down Expand Up @@ -128,98 +126,87 @@ impl PyAhoCorasick {
py: Python,
patterns: &PyAny,
matchkind: PyMatchKind,
mut store_patterns: Option<bool>,
store_patterns: Option<bool>,
implementation: Option<Implementation>,
) -> PyResult<Self> {
// If set, this means we had an error while parsing the strings from the patterns iterable.
let patterns_error: Arc<Mutex<Option<PyErr>>> = Arc::new(Mutex::new(None));
let patterns_error: Cell<Option<PyErr>> = Cell::new(None);

// Convert the `patterns` iterable into an Iterator over &PyUnicode:
let mut patterns_iter = patterns
.iter()?
.map_while(|i_result| {
i_result
.and_then(|i| i.downcast::<PyUnicode>().map_err(PyErr::from))
.map_or_else(
|e| {
if let Ok(mut guard) = patterns_error.lock() {
*guard = Some(e);
}
None
},
Some,
)
})
.fuse();
let mut patterns_iter = patterns.iter()?.map_while(|pat| {
pat.and_then(|i| i.downcast::<PyUnicode>().map_err(PyErr::from))
.map_or_else(
|e| {
patterns_error.set(Some(e));
None
},
Some,
)
});

// If store_patterns is None (the default), use a heuristic to decide
// whether to store patterns.
let mut first_few_patterns: Vec<&PyUnicode> = vec![];
if store_patterns.is_none() {
let mut total = 0;
store_patterns = Some(true);
for s in patterns_iter.by_ref() {
// Highly unlikely that strings will fail to return length, so just expect().
total += s.len().expect("String doesn't have length?");
first_few_patterns.push(s);
if total > 4096 {
store_patterns = Some(false);
break;
let mut patterns: Vec<Py<PyUnicode>> = vec![];
let store_patterns = store_patterns
.unwrap_or_else(|| {
let mut total = 0;
let mut store_patterns = true;
for s in patterns_iter.by_ref() {
// Highly unlikely that strings will fail to return length, so just expect().
total += s.len().expect("String doesn't have length?");
patterns.push(s.into());
if total > 4096 {
store_patterns = false;
break;
}
}
}
}
let patterns = if matches!(store_patterns, Some(true)) {
let mut patterns = vec![];
store_patterns
});

if store_patterns {
for s in patterns_iter.by_ref() {
first_few_patterns.push(s);
}
for s in &first_few_patterns {
patterns.push((*s).into());
patterns.push(s.into());
}
Some(patterns)
} else {
None
};
}

let has_empty_patterns = Cell::new(false);
let ac_impl = AhoCorasickBuilder::new()
.kind(implementation.map(|i| i.into()))
.match_kind(matchkind.into())
.build(
first_few_patterns
.into_iter()
patterns
.iter()
.map(|i| i.as_ref(py))
.chain(patterns_iter)
.chunks(10 * 1024)
.into_iter()
.flat_map(|chunk| {
let result =
chunk
.filter_map(|s| s.extract::<String>().ok())
.inspect(|s| {
if s.is_empty() {
has_empty_patterns.set(true);
}
});
// Release the GIL in case some other thread wants to do work:
py.allow_threads(|| ());
result

chunk.map(|s| s.extract::<String>().ok())
})
.map_while(|s| {
s.and_then(|s| {
if s.is_empty() {
patterns_error.set(Some(PyValueError::new_err(
"You passed in an empty string as a pattern",
)));
None
} else {
Some(s)
}
})
}),
) // TODO make sure this error is menaingful to Python users
) // TODO make sure this error is meaningful to Python users
.map_err(|e| PyValueError::new_err(e.to_string()))?;

if has_empty_patterns.get() {
return Err(PyValueError::new_err(
"You passed in an empty string as a pattern",
));
if let Some(err) = patterns_error.take() {
return Err(err);
}

let result = Ok(Self { ac_impl, patterns });
if let Ok(mut guard) = patterns_error.lock() {
if let Some(err) = guard.take() {
return Err(err);
}
}
result
let patterns = if store_patterns { Some(patterns) } else { None };

Ok(Self { ac_impl, patterns })
}

/// Return matches as tuple of (index_into_patterns,
Expand Down

0 comments on commit f6fdac1

Please sign in to comment.