Add an optional `attribute_filter` callback to `nh3.clean`.

The callback is called with (element, attribute, value) and returns either a
replacement string or None.  A non-callable argument is rejected up front with
TypeError.  If the callback raises, or returns something other than str/None,
the Rust closure restores the Python error and keeps the original attribute
value; the test below records that this currently surfaces as SystemError.

NOTE(review): the string literals in tests/test_nh3.py ("Home", and the
expected value in test_clean_text) look like they lost HTML markup/entities
somewhere upstream of this patch text -- confirm against the repository
before applying.

diff --git a/nh3.pyi b/nh3.pyi
index a815f32..a272512 100644
--- a/nh3.pyi
+++ b/nh3.pyi
@@ -1,9 +1,10 @@
-from typing import Dict, Optional, Set
+from typing import Callable, Dict, Optional, Set
 
 def clean(
     html: str,
     tags: Optional[Set[str]] = None,
     attributes: Optional[Dict[str, Set[str]]] = None,
+    attribute_filter: Optional[Callable[[str, str, str], Optional[str]]] = None,
     strip_comments: bool = True,
     link_rel: Optional[str] = "noopener noreferrer",
 ) -> str: ...
diff --git a/src/lib.rs b/src/lib.rs
index 589df7c..69413f3 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,12 +1,16 @@
+use std::borrow::Cow;
 use std::collections::{HashMap, HashSet};
 
+use pyo3::exceptions::PyTypeError;
 use pyo3::prelude::*;
+use pyo3::types::{PyString, PyTuple};
 
 /// Clean HTML with a conservative set of defaults
 #[pyfunction(signature = (
     html,
     tags = None,
     attributes = None,
+    attribute_filter = None,
     strip_comments = true,
     link_rel = "noopener noreferrer",
 ))]
@@ -15,12 +19,20 @@ fn clean(
     html: &str,
     tags: Option<HashSet<&str>>,
     attributes: Option<HashMap<&str, HashSet<&str>>>,
+    attribute_filter: Option<PyObject>,
     strip_comments: bool,
     link_rel: Option<&str>,
-) -> String {
-    py.allow_threads(|| {
+) -> PyResult<String> {
+    if let Some(callback) = attribute_filter.as_ref() {
+        if !callback.as_ref(py).is_callable() {
+            return Err(PyTypeError::new_err("attribute_filter must be callable"));
+        }
+    }
+
+    let cleaned = py.allow_threads(|| {
         if tags.is_some()
             || attributes.is_some()
+            || attribute_filter.is_some()
             || !strip_comments
             || link_rel != Some("noopener noreferrer")
         {
@@ -34,13 +46,52 @@ fn clean(
                 }
                 cleaner.tag_attributes(attrs);
             }
+            if let Some(callback) = attribute_filter {
+                cleaner.attribute_filter(move |element, attribute, value| {
+                    Python::with_gil(|py| {
+                        let res = callback.call(
+                            py,
+                            PyTuple::new(
+                                py,
+                                [
+                                    PyString::new(py, element),
+                                    PyString::new(py, attribute),
+                                    PyString::new(py, value),
+                                ],
+                            ),
+                            None,
+                        );
+                        let err = match res {
+                            Ok(val) => {
+                                if val.is_none(py) {
+                                    return None;
+                                } else if let Ok(s) = val.downcast::<PyString>(py) {
+                                    match s.to_str() {
+                                        Ok(s) => return Some(Cow::<str>::Owned(s.to_string())),
+                                        Err(err) => err,
+                                    }
+                                } else {
+                                    PyTypeError::new_err(
+                                        "expected attribute_filter to return str or None",
+                                    )
+                                }
+                            }
+                            Err(err) => err,
+                        };
+                        err.restore(py);
+                        Some(value.into())
+                    })
+                });
+            }
             cleaner.strip_comments(strip_comments);
             cleaner.link_rel(link_rel);
             cleaner.clean(html).to_string()
         } else {
             ammonia::clean(html)
         }
-    })
+    });
+
+    Ok(cleaned)
 }
 
 /// Turn an arbitrary string into unformatted HTML
diff --git a/tests/test_nh3.py b/tests/test_nh3.py
index 55603c1..96a5362 100644
--- a/tests/test_nh3.py
+++ b/tests/test_nh3.py
@@ -1,4 +1,5 @@
 import nh3
+import pytest
 
 
 def test_clean():
@@ -19,6 +20,28 @@ def test_clean():
     )
 
 
+def test_clean_with_attribute_filter():
+    html = "Home"
+
+    def attribute_filter(element, attribute, value):
+        if element == "img" and attribute == "src":
+            return None
+        return value
+
+    assert (
+        nh3.clean(html, attribute_filter=attribute_filter, link_rel=None)
+        == 'Home'
+    )
+
+    with pytest.raises(TypeError):
+        nh3.clean(html, attribute_filter="not a callable")
+
+    with pytest.raises(SystemError):
+        # FIXME: attribute_filter may raise exception, but it's an infallible API
+        # which causes Python to raise SystemError instead of the intended TypeError
+        nh3.clean(html, attribute_filter=lambda _element, _attribute, _value: True)
+
+
 def test_clean_text():
     res = nh3.clean_text('Robert"); abuse();//')
     assert res == "Robert"); abuse();//"