Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(feat) Next iteration of all xml-support #929

Merged
merged 18 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions tests/unit/test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def test_generate_billion_scale_services(self):
),
version="1.0",
)

# print(type(generated_services))
generated_xml = generated_services.to_xml()
# Validate against relaxng
self.assertTrue(validate_services(etree.fromstring(str(generated_xml))))
Expand All @@ -466,10 +466,9 @@ def test_generate_billion_scale_services(self):
# print(f"Generated: {generated.tag}, {generated.attrib}, {generated.text}")
self.assertEqual(original.tag, generated.tag)
self.assertEqual(original.attrib, generated.attrib)
self.assertEqual(
original.text.strip() if original.text else None,
generated.text.strip() if generated.text else None,
)
orig_text = original.text or ""
gen_text = generated.text or ""
self.assertEqual(orig_text.strip(), gen_text.strip())
Copy link
Contributor

@glebashnik glebashnik Sep 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The intention is for "" to equal None, why?



if __name__ == "__main__":
Expand Down
107 changes: 107 additions & 0 deletions tests/unit/test_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
DeploymentConfiguration,
Struct,
StructField,
ServicesConfiguration,
ApplicationConfiguration,
)
from vespa.configuration.vt import compare_xml


class TestField(unittest.TestCase):
Expand Down Expand Up @@ -768,6 +771,9 @@ def test_services_to_text(self):
)

self.assertEqual(self.app_package.services_to_text, expected_result)
self.assertTrue(
compare_xml(self.app_package.services_to_text_vt, expected_result)
)

def test_query_profile_to_text(self):
expected_result = (
Expand Down Expand Up @@ -851,6 +857,9 @@ def test_generated_services_uses_mode_streaming(self):
"</services>"
)
self.assertEqual(self.app_package.services_to_text, expected_result)
self.assertTrue(
compare_xml(self.app_package.services_to_text_vt, expected_result)
)


class TestSchemaInheritance(unittest.TestCase):
Expand Down Expand Up @@ -1030,6 +1039,9 @@ def test_services_to_text(self):
)

self.assertEqual(self.app_package.services_to_text, expected_result)
self.assertTrue(
compare_xml(self.app_package.services_to_text_vt, expected_result),
)


class TestSimplifiedApplicationPackage(unittest.TestCase):
Expand Down Expand Up @@ -1184,6 +1196,9 @@ def test_services_to_text(self):
)

self.assertEqual(self.app_package.services_to_text, expected_result)
self.assertTrue(
compare_xml(self.app_package.services_to_text_vt, expected_result),
)

def test_query_profile_to_text(self):
expected_result = (
Expand Down Expand Up @@ -1272,6 +1287,9 @@ def test_services_to_text(self):
"</services>"
)
self.assertEqual(self.app_package.services_to_text, expected_result)
self.assertTrue(
compare_xml(self.app_package.services_to_text_vt, expected_result),
)


class TestComponentSetup(unittest.TestCase):
Expand Down Expand Up @@ -1337,6 +1355,9 @@ def test_services_to_text(self):
"</services>"
)
self.assertEqual(self.app_package.services_to_text, expected_result)
self.assertTrue(
compare_xml(self.app_package.services_to_text_vt, expected_result),
)


class TestClientTokenSetup(unittest.TestCase):
Expand Down Expand Up @@ -1386,6 +1407,9 @@ def test_services_to_text(self):
)

self.assertEqual(self.app_package.services_to_text, expected_result)
self.assertTrue(
compare_xml(self.app_package.services_to_text_vt, expected_result),
)


class TestClientsWithCluster(unittest.TestCase):
Expand Down Expand Up @@ -1450,6 +1474,9 @@ def test_services_to_text(self):
"</services>"
)
self.assertEqual(self.app_package.services_to_text, expected_result)
self.assertTrue(
compare_xml(self.app_package.services_to_text_vt, expected_result),
)


class TestValidAppName(unittest.TestCase):
Expand Down Expand Up @@ -1581,6 +1608,9 @@ def test_services_to_text(self):
"</services>"
)
self.assertEqual(self.app_package.services_to_text, expected_result)
self.assertTrue(
compare_xml(self.app_package.services_to_text_vt, expected_result),
)


class TestAuthClientEquality(unittest.TestCase):
Expand Down Expand Up @@ -1698,3 +1728,80 @@ def test_schema_to_text(self):
"}"
)
self.assertEqual(self.app_package.schema.schema_to_text, expected_result)


class TestVTequality(unittest.TestCase):
    """Check that the VT-based XML rendering agrees with the direct rendering."""

    def test_application_configuration(self):
        """`to_vt().to_xml()` and `to_text` yield equivalent XML for a config."""
        config = ApplicationConfiguration(
            name="container.handler.observability.application-userdata",
            value={"version": "my-version"},
        )
        # Render through the VT tree first, then through the text property.
        xml_from_vt = str(config.to_vt().to_xml())
        xml_from_text = config.to_text
        self.assertTrue(compare_xml(xml_from_text, xml_from_vt))

    def test_cluster_configuration(self):
        """`to_vt().to_xml()` and `to_xml_string()` agree for both cluster kinds."""
        # Container cluster with node resources, one component and two auth clients.
        nodes = Nodes(
            count="1",
            parameters=[
                Parameter(
                    "resources",
                    {"vcpu": "4.0", "memory": "16Gb", "disk": "125Gb"},
                    [Parameter("gpu", {"count": "1", "memory": "16Gb"})],
                ),
            ],
        )
        embedder = Component(
            id="e5",
            type="hugging-face-embedder",
            parameters=[
                Parameter("transformer-model", {"path": "model/model.onnx"}),
                Parameter("tokenizer-model", {"path": "model/tokenizer.json"}),
            ],
        )
        mtls_client = AuthClient(
            id="mtls",
            permissions=["read", "write"],
            parameters=[Parameter("certificate", {"file": "security/clients.pem"})],
        )
        token_client = AuthClient(
            id="token",
            permissions=["read"],
            parameters=[Parameter("token", {"id": "accessToken"})],
        )
        container = ContainerCluster(
            id="test_container",
            nodes=nodes,
            components=[embedder],
            auth_clients=[mtls_client, token_client],
        )
        content = ContentCluster(id="test_content", document_name="test")

        for cluster in (container, content):
            xml_from_vt = str(cluster.to_vt().to_xml())
            xml_direct = cluster.to_xml_string()
            self.assertTrue(compare_xml(xml_direct, xml_from_vt))


class TestServiceConfig(unittest.TestCase):
    """Rendering of an application package built from a `ServicesConfiguration`."""

    def test_default_service_config_to_text(self):
        """A default services config renders a single container element."""
        # Show the complete diff if the string comparison fails.
        self.maxDiff = None
        app_name = "test"
        package = ApplicationPackage(
            name=app_name,
            services_config=ServicesConfiguration(application_name=app_name),
        )
        # Adjacent literals concatenate to exactly the expected document text.
        expected_xml = (
            '<?xml version="1.0" encoding="UTF-8" ?>\n'
            '<services version="1.0">\n'
            ' <container id="test_container" version="1.0"></container>\n'
            "</services>\n"
        )
        self.assertEqual(expected_xml, package.services_to_text)
        self.assertTrue(compare_xml(package.services_to_text_vt, expected_xml))
81 changes: 81 additions & 0 deletions tests/unit/test_vt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import unittest
from vespa.configuration.vt import compare_xml


class TestXMLComparison(unittest.TestCase):
    """Behavioural tests for `compare_xml`."""

    # Reference document that most cases compare against.
    PLAIN = "<root><child>Text</child></root>"

    def test_equal_simple(self):
        """Two identical documents compare equal."""
        self.assertTrue(compare_xml(self.PLAIN, "<root><child>Text</child></root>"))

    def test_whitespace_differences(self):
        """Indentation and newlines between elements do not matter."""
        indented = "<root>\n <child>Text</child>\n</root>"
        self.assertTrue(compare_xml(self.PLAIN, indented))

    def test_attribute_order(self):
        """Attribute order is irrelevant."""
        b_first = '<root><child b="2" a="1">Text</child></root>'
        a_first = '<root><child a="1" b="2">Text</child></root>'
        self.assertTrue(compare_xml(b_first, a_first))

    def test_text_whitespace(self):
        """Leading and trailing whitespace in element text is ignored."""
        padded = "<root><child> Text </child></root>"
        self.assertTrue(compare_xml(padded, self.PLAIN))

    def test_different_text(self):
        """Different element text makes documents unequal."""
        first = "<root><child>Text1</child></root>"
        second = "<root><child>Text2</child></root>"
        self.assertFalse(compare_xml(first, second))

    def test_different_structure(self):
        """An extra nesting level makes documents unequal."""
        nested = "<root><child><subchild>Text</subchild></child></root>"
        self.assertFalse(compare_xml(self.PLAIN, nested))

    def test_namespace_handling(self):
        """A default namespace on one side makes documents unequal."""
        namespaced = '<root xmlns="namespace"><child>Text</child></root>'
        # Namespaces are considered in the tag comparison
        self.assertFalse(compare_xml(namespaced, self.PLAIN))

    def test_comments_ignored(self):
        """XML comments do not influence the comparison."""
        commented = "<root><!-- A comment --><child>Text</child></root>"
        # Comments are not part of the element tree; they are ignored
        self.assertTrue(compare_xml(commented, self.PLAIN))

    def test_processing_instructions(self):
        """A leading `<?xml ...?>` declaration does not influence the comparison."""
        with_declaration = "<?xml version='1.0'?><root><child>Text</child></root>"
        self.assertTrue(compare_xml(with_declaration, self.PLAIN))

    def test_different_attributes(self):
        """Different attribute values make documents unequal."""
        first = '<root><child a="1">Text</child></root>'
        second = '<root><child a="2">Text</child></root>'
        self.assertFalse(compare_xml(first, second))

    def test_additional_attributes(self):
        """An attribute present on only one side makes documents unequal."""
        two_attrs = '<root><child a="1" b="2">Text</child></root>'
        one_attr = '<root><child a="1">Text</child></root>'
        self.assertFalse(compare_xml(two_attrs, one_attr))

    def test_multiple_children_order(self):
        """Sibling order is irrelevant."""
        ascending = "<root><child>1</child><child>2</child></root>"
        descending = "<root><child>2</child><child>1</child></root>"
        self.assertTrue(compare_xml(ascending, descending))

    def test_namespace_prefixes(self):
        """A prefixed, namespaced child differs from an unqualified one."""
        prefixed = '<root xmlns:ns="namespace"><ns:child>Text</ns:child></root>'
        # Different namespaces make the tags different
        self.assertFalse(compare_xml(prefixed, self.PLAIN))

    def test_cdata_handling(self):
        """CDATA content compares equal to the same plain text."""
        cdata = "<root><child><![CDATA[Text]]></child></root>"
        self.assertTrue(compare_xml(cdata, self.PLAIN))


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions vespa/configuration/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@
"removed-db",
"max-wait-after-coverage-factor",
"maxsize",
"clients",
"client",
]

# Fail if any tag is duplicated. Provide feedback of which tags are duplicated.
Expand Down
61 changes: 58 additions & 3 deletions vespa/configuration/vt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import types
from xml.sax.saxutils import escape
from fastcore.utils import patch
import xml.etree.ElementTree as ET

# If the vespa tags correspond to reserved Python keywords, they are replaced with the following:
replace_reserved = {
Expand Down Expand Up @@ -64,6 +65,7 @@ def __iter__(self):


def attrmap(o):
o = dict(_global="global").get(o, o)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the purpose of this?

return o.lstrip("_").replace("_", "-")


Expand Down Expand Up @@ -98,16 +100,21 @@ def vt(
**kw,
):
"Create an `VT` structure for `to_xml()`"
return VT(
tag.lower(), *_preproc(c, kw, attrmap=attrmap, valmap=valmap), void_=void_
)
# NB! fastcore.xml uses tag.lower() for tag names. This is not done here.
return VT(tag, *_preproc(c, kw, attrmap=attrmap, valmap=valmap), void_=void_)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does preproc do?
A docstring for preproc would be nice.



# Names of XML tags that are rendered as void (self-closing) elements.
# TODO: Add self-closing tags for Vespa configuration
# Currently empty: no tag is registered as void.
voids = set()


def Xml(*c, version="1.0", encoding="UTF-8", **kwargs) -> VT:
    """Build the top-level `?xml` tag with `version`, `encoding` and children `c`.

    The tag is created with ``void_="?"``. Extra keyword arguments are
    accepted but are not forwarded to the tag.
    """
    return vt("?xml", *c, version=version, encoding=encoding, void_="?")


# Replace the 'partial' based tag creation
def create_tag_function(tag, void_):
def tag_function(*c, **kwargs):
Expand Down Expand Up @@ -157,6 +164,8 @@ def _to_xml(elm, lvl, indent, do_escape):

# Handle void (self-closing) tags
if elm.void_:
if isinstance(elm.void_, str):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For context, void_ is a bool?
Why check for str?

return f"{sp}<{stag}{attr_str} {elm.void_}>{nl}"
return f"{sp}<{stag}{attr_str} />{nl}"

# Handle non-void tags with children or no children
Expand Down Expand Up @@ -214,3 +223,49 @@ def __call__(self: VT, *c, **kw):
if kw:
self.attrs = {**self.attrs, **kw}
return self


def _canonical_sort_key(element):
    """Return a sort key that fully describes an already canonicalized element.

    The key starts with ``(tag, stripped text)`` and then adds the sorted
    attributes, the stripped tail and the keys of all children, so siblings
    that share tag and text are still ordered deterministically.
    """
    return (
        element.tag,
        (element.text or "").strip(),
        tuple(sorted(element.attrib.items())),
        (element.tail or "").strip(),
        tuple(_canonical_sort_key(child) for child in element),
    )


def canonicalize(element):
    """Recursively sort attributes and children to canonicalize the element.

    The element is modified in place:

    * attributes are re-inserted in sorted key order,
    * children are canonicalized first and then sorted,
    * non-empty ``text`` and ``tail`` values are stripped of surrounding
      whitespace.

    Children are ordered by tag and text first; ties are broken by
    attributes, tail and descendants. Without the tie-breakers, siblings that
    differ only in attributes or nested content would keep their input order,
    so two documents differing only in sibling order could canonicalize
    differently.

    Args:
        element: An ``xml.etree.ElementTree.Element`` to canonicalize in place.

    Returns:
        None.
    """
    # Sort attributes
    if element.attrib:
        element.attrib = dict(sorted(element.attrib.items()))
    # Canonicalize children bottom-up so their sort keys are already stable
    children = list(element)
    for child in children:
        canonicalize(child)
    element[:] = sorted(children, key=_canonical_sort_key)
    # Strip whitespace from text and tail
    if element.text:
        element.text = element.text.strip()
    if element.tail:
        element.tail = element.tail.strip()


def elements_equal(e1, e2):
    """Return True when two elements and their subtrees match.

    Tag, attributes (order-insensitive), stripped text, stripped tail and the
    number of children must all agree; children are then compared pairwise in
    document order.
    """

    def _stripped(value):
        # Treat a missing text/tail the same as an empty string.
        return (value or "").strip()

    same_node = (
        e1.tag == e2.tag
        and sorted(e1.attrib.items()) == sorted(e2.attrib.items())
        and _stripped(e1.text) == _stripped(e2.text)
        and _stripped(e1.tail) == _stripped(e2.tail)
        and len(e1) == len(e2)
    )
    if not same_node:
        return False
    for left_child, right_child in zip(e1, e2):
        if not elements_equal(left_child, right_child):
            return False
    return True


def compare_xml(xml_str1, xml_str2):
    """Return True when two XML strings describe equivalent documents.

    Both strings are parsed, canonicalized in place with `canonicalize` and
    then compared with `elements_equal`. If either string cannot be parsed,
    the result is False.
    """
    try:
        first_root = ET.fromstring(xml_str1)
        second_root = ET.fromstring(xml_str2)
    except ET.ParseError:
        return False
    for root in (first_root, second_root):
        canonicalize(root)
    return elements_equal(first_root, second_root)
Loading
Loading