diff --git a/.gitignore b/.gitignore index 69ed2c6c..ccbcdab1 100644 --- a/.gitignore +++ b/.gitignore @@ -152,5 +152,5 @@ markdown_before/ *.pem *.crt -# Exclude markdown files from vespa/resources +# Exclude markdown files in vespa/resources - directory vespa/resources/*.md \ No newline at end of file diff --git a/tests/integration/test_integration_queries.py b/tests/integration/test_integration_queries.py new file mode 100644 index 00000000..80107f1c --- /dev/null +++ b/tests/integration/test_integration_queries.py @@ -0,0 +1,1041 @@ +import unittest +import requests +from vespa.deployment import VespaDocker +from vespa.package import ( + ApplicationPackage, + Schema, + Document, + Field, + FieldSet, + StructField, + Struct, + RankProfile, +) +from tests.unit.test_q import TestQueryBuilder + +qb = TestQueryBuilder() + + +class TestQueriesIntegration(unittest.TestCase): + @classmethod + def setUpClass(cls): + application_name = "querybuilder" + cls.application_name = application_name + schema_name = "sd1" + cls.schema_name = schema_name + # Define all fields used in the unit tests + # Schema 1 + fields = [ + Field( + name="weightedset_field", + type="weightedset", + indexing=["attribute"], + ), + Field(name="location_field", type="position", indexing=["attribute"]), + Field(name="f1", type="string", indexing=["attribute", "summary"]), + Field(name="f2", type="string", indexing=["attribute", "summary"]), + Field(name="f3", type="string", indexing=["attribute", "summary"]), + Field(name="f4", type="string", indexing=["attribute", "summary"]), + Field(name="age", type="int", indexing=["attribute", "summary"]), + Field(name="duration", type="int", indexing=["attribute", "summary"]), + Field(name="id", type="string", indexing=["attribute", "summary"]), + Field(name="text", type="string", indexing=["index", "summary"]), + Field(name="title", type="string", indexing=["index", "summary"]), + Field(name="description", type="string", indexing=["index", "summary"]), + Field(name="date", type="string", indexing=["attribute", "summary"]), + Field(name="status", type="string", indexing=["attribute", "summary"]), + Field(name="comments", type="string", indexing=["attribute", "summary"]), + Field( + name="embedding", + type="tensor(x[128])", + indexing=["attribute"], + ), + Field(name="tags", type="array", indexing=["attribute", "summary"]), + Field( + name="timestamp", + type="long", + indexing=["attribute", "summary"], + ), + Field(name="integer_field", type="int", indexing=["attribute", "summary"]), + Field( + name="predicate_field", + type="predicate", + indexing=["attribute", "summary"], + index="arity: 2", # This is required for predicate fields + ), + Field( + name="myStringAttribute", type="string", indexing=["index", "summary"] + ), + Field(name="myUrlField", type="string", indexing=["index", "summary"]), + Field(name="fieldName", type="string", indexing=["index", "summary"]), + Field( + name="dense_rep", + type="tensor(x[128])", + indexing=["attribute"], + ), + Field(name="artist", type="string", indexing=["attribute", "summary"]), + Field(name="subject", type="string", indexing=["attribute", "summary"]), + Field( + name="display_date", type="string", indexing=["attribute", "summary"] + ), + Field(name="price", type="double", indexing=["attribute", "summary"]), + Field(name="keywords", type="string", indexing=["index", "summary"]), + ] + email_struct = Struct( + name="email", + fields=[ + Field(name="sender", type="string"), + Field(name="recipient", type="string"), + Field(name="subject", 
type="string"), + Field(name="content", type="string"), + ], + ) + emails_field = Field( + name="emails", + type="array", + indexing=["summary"], + struct_fields=[ + StructField( + name="content", indexing=["attribute"], attribute=["fast-search"] + ) + ], + ) + person_struct = Struct( + name="person", + fields=[ + Field(name="first_name", type="string"), + Field(name="last_name", type="string"), + Field(name="year_of_birth", type="int"), + ], + ) + persons_field = Field( + name="persons", + type="array", + indexing=["summary"], + struct_fields=[ + StructField(name="first_name", indexing=["attribute"]), + StructField(name="last_name", indexing=["attribute"]), + StructField(name="year_of_birth", indexing=["attribute"]), + ], + ) + rank_profiles = [ + RankProfile( + name="dotproduct", + first_phase="rawScore(weightedset_field)", + summary_features=["rawScore(weightedset_field)"], + ), + RankProfile( + name="geolocation", + first_phase="distance(location_field)", + summary_features=["distance(location_field).km"], + ), + RankProfile( + name="bm25", first_phase="bm25(text)", summary_features=["bm25(text)"] + ), + ] + fieldset = FieldSet(name="default", fields=["text", "title", "description"]) + document = Document(fields=fields, structs=[email_struct, person_struct]) + schema = Schema( + name=schema_name, + document=document, + rank_profiles=rank_profiles, + fieldsets=[fieldset], + ) + schema.add_fields(emails_field, persons_field) + # Add purchase schema for grouping test + purchase_schema = Schema( + name="purchase", + document=Document( + fields=[ + Field(name="date", type="long", indexing=["summary", "attribute"]), + Field(name="price", type="int", indexing=["summary", "attribute"]), + Field(name="tax", type="double", indexing=["summary", "attribute"]), + Field( + name="item", type="string", indexing=["summary", "attribute"] + ), + Field( + name="customer", + type="string", + indexing=["summary", "attribute"], + ), + ] + ), + ) + # Create the application package + application_package = ApplicationPackage( + name=application_name, schema=[schema, purchase_schema] + ) + print(application_package.get_schema(schema_name).schema_to_text) + print(application_package.get_schema("purchase").schema_to_text) + # Deploy the application + cls.vespa_docker = VespaDocker(port=8089) + cls.app = cls.vespa_docker.deploy(application_package=application_package) + cls.app.wait_for_application_up() + + @classmethod + def tearDownClass(cls): + cls.vespa_docker.container.stop(timeout=5) + cls.vespa_docker.container.remove() + + @property + def sample_grouping_data(self): + sample_data = [ + { + "fields": { + "customer": "Smith", + "date": 1157526000, + "item": "Intake valve", + "price": "1000", + "tax": "0.24", + }, + "put": "id:purchase:purchase::0", + }, + { + "fields": { + "customer": "Smith", + "date": 1157616000, + "item": "Rocker arm", + "price": "1000", + "tax": "0.12", + }, + "put": "id:purchase:purchase::1", + }, + { + "fields": { + "customer": "Smith", + "date": 1157619600, + "item": "Spring", + "price": "2000", + "tax": "0.24", + }, + "put": "id:purchase:purchase::2", + }, + { + "fields": { + "customer": "Jones", + "date": 1157709600, + "item": "Valve cover", + "price": "3000", + "tax": "0.12", + }, + "put": "id:purchase:purchase::3", + }, + { + "fields": { + "customer": "Jones", + "date": 1157702400, + "item": "Intake port", + "price": "5000", + "tax": "0.24", + }, + "put": "id:purchase:purchase::4", + }, + { + "fields": { + "customer": "Brown", + "date": 1157706000, + "item": "Head", + "price": 
"8000", + "tax": "0.12", + }, + "put": "id:purchase:purchase::5", + }, + { + "fields": { + "customer": "Smith", + "date": 1157796000, + "item": "Coolant", + "price": "1300", + "tax": "0.24", + }, + "put": "id:purchase:purchase::6", + }, + { + "fields": { + "customer": "Jones", + "date": 1157788800, + "item": "Engine block", + "price": "2100", + "tax": "0.12", + }, + "put": "id:purchase:purchase::7", + }, + { + "fields": { + "customer": "Brown", + "date": 1157792400, + "item": "Oil pan", + "price": "3400", + "tax": "0.24", + }, + "put": "id:purchase:purchase::8", + }, + { + "fields": { + "customer": "Smith", + "date": 1157796000, + "item": "Oil sump", + "price": "5500", + "tax": "0.12", + }, + "put": "id:purchase:purchase::9", + }, + { + "fields": { + "customer": "Jones", + "date": 1157875200, + "item": "Camshaft", + "price": "8900", + "tax": "0.24", + }, + "put": "id:purchase:purchase::10", + }, + { + "fields": { + "customer": "Brown", + "date": 1157878800, + "item": "Exhaust valve", + "price": "1440", + "tax": "0.12", + }, + "put": "id:purchase:purchase::11", + }, + { + "fields": { + "customer": "Brown", + "date": 1157882400, + "item": "Rocker arm", + "price": "2330", + "tax": "0.24", + }, + "put": "id:purchase:purchase::12", + }, + { + "fields": { + "customer": "Brown", + "date": 1157875200, + "item": "Spring", + "price": "3770", + "tax": "0.12", + }, + "put": "id:purchase:purchase::13", + }, + { + "fields": { + "customer": "Smith", + "date": 1157878800, + "item": "Spark plug", + "price": "6100", + "tax": "0.24", + }, + "put": "id:purchase:purchase::14", + }, + { + "fields": { + "customer": "Jones", + "date": 1157968800, + "item": "Exhaust port", + "price": "9870", + "tax": "0.12", + }, + "put": "id:purchase:purchase::15", + }, + { + "fields": { + "customer": "Brown", + "date": 1157961600, + "item": "Piston", + "price": "1597", + "tax": "0.24", + }, + "put": "id:purchase:purchase::16", + }, + { + "fields": { + "customer": "Smith", + "date": 1157965200, + "item": "Connection rod", + "price": "2584", + "tax": "0.12", + }, + "put": "id:purchase:purchase::17", + }, + { + "fields": { + "customer": "Jones", + "date": 1157968800, + "item": "Rod bearing", + "price": "4181", + "tax": "0.24", + }, + "put": "id:purchase:purchase::18", + }, + { + "fields": { + "customer": "Jones", + "date": 1157972400, + "item": "Crankshaft", + "price": "6765", + "tax": "0.12", + }, + "put": "id:purchase:purchase::19", + }, + ] + docs = [ + {"fields": doc["fields"], "id": doc["put"].split("::")[-1]} + for doc in sample_data + ] + return docs + + def feed_grouping_data(self): + # Feed documents + self.app.feed_iterable(iter=self.sample_grouping_data, schema="purchase") + return + + def test_dotProduct_with_annotations(self): + # Feed a document with 'weightedset_field' + field = "weightedset_field" + fields = {field: {"feature1": 2, "feature2": 4}} + data_id = 1 + self.app.feed_data_point( + schema=self.schema_name, data_id=data_id, fields=fields + ) + q = qb.test_dotProduct_with_annotations() + with self.app.syncio() as sess: + result = sess.query(yql=q, ranking="dotproduct") + print(result.json) + self.assertEqual(len(result.hits), 1) + self.assertEqual( + result.hits[0]["id"], + f"id:{self.schema_name}:{self.schema_name}::{data_id}", + ) + self.assertEqual( + result.hits[0]["fields"]["summaryfeatures"]["rawScore(weightedset_field)"], + 10, + ) + + def test_geolocation_with_annotations(self): + # Feed a document with 'location_field' + field_name = "location_field" + fields = { + field_name: { + "lat": 37.77491, + 
"lng": -122.41941, + }, # 0.00001 degrees more than the query + } + data_id = 2 + self.app.feed_data_point( + schema=self.schema_name, data_id=data_id, fields=fields + ) + # Build and send the query + q = qb.test_geolocation_with_annotations() + with self.app.syncio() as sess: + result = sess.query(yql=q, ranking="geolocation") + # Check the result + self.assertEqual(len(result.hits), 1) + self.assertEqual( + result.hits[0]["id"], + f"id:{self.schema_name}:{self.schema_name}::{data_id}", + ) + self.assertAlmostEqual( + result.hits[0]["fields"]["summaryfeatures"]["distance(location_field).km"], + 0.001417364012462494, + ) + print(result.json) + + def test_basic_and_andnot_or_offset_limit_param_order_by_and_contains(self): + docs = [ + { # Should not match - f3 doesn't contain "v3" + "f1": "v1", + "f2": "v2", + "f3": "asdf", + "f4": "d", + "age": 10, + "duration": 100, + }, + { # Should match + "f1": "v1", + "f2": "v2", + "f3": "v3", + "f4": "d", + "age": 20, + "duration": 200, + }, + { # Should match + "f1": "v1", + "f2": "v2", + "f3": "v3", + "f4": "d", + "age": 30, + "duration": 300, + }, + { # Should not match - contains f4="v4" + "f1": "v1", + "f2": "v2", + "f3": "v3", + "f4": "v4", + "age": 40, + "duration": 400, + }, + ] + + # Feed documents + docs = [{"id": data_id, "fields": doc} for data_id, doc in enumerate(docs, 1)] + self.app.feed_iterable(iter=docs, schema=self.schema_name) + + # Build and send query + q = qb.test_basic_and_andnot_or_offset_limit_param_order_by_and_contains() + print(f"Executing query: {q}") + + with self.app.syncio() as sess: + result = sess.query(yql=q) + + # Verify results + self.assertEqual( + len(result.hits), 1 + ) # Should get 1 hit due to offset=1, limit=2 + + # The query orders by age desc, duration asc with offset 1 + # So we should get doc ID 2 (since doc ID 3 is skipped due to offset) + hit = result.hits[0] + self.assertEqual(hit["id"], f"id:{self.schema_name}:{self.schema_name}::2") + + # Verify the matching document has expected field values + self.assertEqual(hit["fields"]["age"], 20) + self.assertEqual(hit["fields"]["duration"], 200) + self.assertEqual(hit["fields"]["f1"], "v1") + self.assertEqual(hit["fields"]["f2"], "v2") + self.assertEqual(hit["fields"]["f3"], "v3") + self.assertEqual(hit["fields"]["f4"], "d") + + print(result.json) + + def test_matches(self): + # Matches is a regex (or substring) match + # Feed test documents + docs = [ + { # Doc 1: Should match - satisfies (f1="v1" AND f2="v2") and f4!="v4" + "f1": "v1", + "f2": "v2", + "f3": "other", + "f4": "nothing", + }, + { # Doc 2: Should not match - fails f4!="v4" condition + "f1": "v1", + "f2": "v2", + "f3": "v3", + "f4": "v4", + }, + { # Doc 3: Should match - satisfies f3="v3" and f4!="v4" + "f1": "other", + "f2": "other", + "f3": "v3", + "f4": "nothing", + }, + { # Doc 4: Should not match - fails all conditions + "f1": "other", + "f2": "other", + "f3": "other", + "f4": "v4", + }, + ] + + # Ensure fields are properly indexed for matching + docs = [ + { + "fields": doc, + "id": str(data_id), + } + for data_id, doc in enumerate(docs, 1) + ] + + # Feed documents + self.app.feed_iterable(iter=docs, schema=self.schema_name) + + # Build and send query + q = qb.test_matches() + # select * from sd1 where ((f1 matches "v1" and f2 matches "v2") or f3 matches "v3") and !(f4 matches "v4") + print(f"Executing query: {q}") + + with self.app.syncio() as sess: + result = sess.query(yql=q) + + # Check result count + self.assertEqual(len(result.hits), 2) + + # Verify specific matches + ids = 
sorted([hit["id"] for hit in result.hits]) + expected_ids = sorted( + [ + f"id:{self.schema_name}:{self.schema_name}::1", + f"id:{self.schema_name}:{self.schema_name}::3", + ] + ) + + self.assertEqual(ids, expected_ids) + print(result.json) + + def test_nested_queries(self): + # Contains is an exact match + # q = 'select * from sd1 where f1 contains "1" and (!((f2 contains "2" and f3 contains "3") or (f2 contains "4" and !(f3 contains "5"))))' + # Feed test documents + docs = [ + { # Doc 1: Should not match - satisfies f1 contains "1" but fails inner query + "f1": "1", + "f2": "2", + "f3": "3", + }, + { # Doc 2: Should match + "f1": "1", + "f2": "4", + "f3": "5", + }, + { # Doc 3: Should not match - fails f1 contains "1" + "f1": "other", + "f2": "2", + "f3": "3", + }, + { # Doc 4: Should not match + "f1": "1", + "f2": "4", + "f3": "other", + }, + ] + docs = [ + { + "fields": doc, + "id": str(data_id), + } + for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.schema_name) + q = qb.test_nested_queries() + print(f"Executing query: {q}") + with self.app.syncio() as sess: + result = sess.query(yql=q) + print(result.json) + self.assertEqual(len(result.hits), 1) + self.assertEqual( + result.hits[0]["id"], + f"id:{self.schema_name}:{self.schema_name}::2", + ) + + def test_userquery_defaultindex(self): + # 'select * from sd1 where ({"defaultIndex":"text"}userQuery())' + # Feed test documents + docs = [ + { # Doc 1: Should match + "description": "foo", + "text": "foo", + }, + { # Doc 2: Should match + "description": "foo", + "text": "bar", + }, + { # Doc 3: Should not match + "description": "bar", + "text": "baz", + }, + ] + + # Format and feed documents + docs = [ + {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.schema_name) + + # Execute query + q = qb.test_userquery() + query = "foo" + print(f"Executing query: {q}") + body = { + "yql": str(q), + "query": query, + } + with self.app.syncio() as sess: + result = sess.query(body=body) + self.assertEqual(len(result.hits), 2) + ids = sorted([hit["id"] for hit in result.hits]) + self.assertIn(f"id:{self.schema_name}:{self.schema_name}::1", ids) + self.assertIn(f"id:{self.schema_name}:{self.schema_name}::2", ids) + + def test_userquery_customindex(self): + # 'select * from sd1 where userQuery())' + # Feed test documents + docs = [ + { # Doc 1: Should match + "description": "foo", + "text": "foo", + }, + { # Doc 2: Should not match + "description": "foo", + "text": "bar", + }, + { # Doc 3: Should not match + "description": "bar", + "text": "baz", + }, + ] + + # Format and feed documents + docs = [ + {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.schema_name) + + # Execute query + q = qb.test_userquery() + query = "foo" + print(f"Executing query: {q}") + body = { + "yql": str(q), + "query": query, + "ranking": "bm25", + "model.defaultIndex": "text", # userQuery() needs this to set index, see https://docs.vespa.ai/en/query-api.html#using-a-fieldset + } + with self.app.syncio() as sess: + result = sess.query(body=body) + # Verify only one document matches both conditions + self.assertEqual(len(result.hits), 1) + self.assertEqual( + result.hits[0]["id"], f"id:{self.schema_name}:{self.schema_name}::1" + ) + + # Verify matching document has expected values + hit = result.hits[0] + self.assertEqual(hit["fields"]["description"], "foo") + 
self.assertEqual(hit["fields"]["text"], "foo") + + def test_userinput(self): + # 'select * from sd1 where userInput(@myvar)' + # Feed test documents + myvar = "panda" + docs = [ + { # Doc 1: Should match + "description": "a panda is a cute", + "text": "foo", + }, + { # Doc 2: Should match + "description": "foo", + "text": "you are a cool panda", + }, + { # Doc 3: Should not match + "description": "bar", + "text": "baz", + }, + ] + # Format and feed documents + docs = [ + {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.schema_name) + # Execute query + q = qb.test_userinput() + print(f"Executing query: {q}") + body = { + "yql": str(q), + "ranking": "bm25", + "myvar": myvar, + } + with self.app.syncio() as sess: + result = sess.query(body=body) + # Verify only two documents match + self.assertEqual(len(result.hits), 2) + # Verify matching documents have expected values + ids = sorted([hit["id"] for hit in result.hits]) + self.assertIn(f"id:{self.schema_name}:{self.schema_name}::1", ids) + self.assertIn(f"id:{self.schema_name}:{self.schema_name}::2", ids) + + def test_userinput_with_defaultindex(self): + # 'select * from sd1 where {defaultIndex:"text"}userInput(@myvar)' + # Feed test documents + myvar = "panda" + docs = [ + { # Doc 1: Should not match + "description": "a panda is a cute", + "text": "foo", + }, + { # Doc 2: Should match + "description": "foo", + "text": "you are a cool panda", + }, + { # Doc 3: Should not match + "description": "bar", + "text": "baz", + }, + ] + # Format and feed documents + docs = [ + {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.schema_name) + # Execute query + q = qb.test_userinput_with_defaultindex() + print(f"Executing query: {q}") + body = { + "yql": str(q), + "ranking": "bm25", + "myvar": myvar, + } + with self.app.syncio() as sess: + result = sess.query(body=body) + print(result.json) + # Verify only one document matches + self.assertEqual(len(result.hits), 1) + # Verify matching document has expected values + hit = result.hits[0] + self.assertEqual(hit["id"], f"id:{self.schema_name}:{self.schema_name}::2") + + def test_in_operator_intfield(self): + # 'select * from * where integer_field in (10, 20, 30)' + # We use age field for this test + # Feed test documents + docs = [ + { # Doc 1: Should match + "age": 10, + }, + { # Doc 2: Should match + "age": 20, + }, + { # Doc 3: Should not match + "age": 31, + }, + { # Doc 4: Should not match + "age": 40, + }, + ] + # Format and feed documents + docs = [ + {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.schema_name) + # Execute query + q = qb.test_in_operator_intfield() + print(f"Executing query: {q}") + with self.app.syncio() as sess: + result = sess.query(yql=q) + print(result.json) + # Verify only two documents match + self.assertEqual(len(result.hits), 2) + # Verify matching documents have expected values + ids = sorted([hit["id"] for hit in result.hits]) + self.assertIn(f"id:{self.schema_name}:{self.schema_name}::1", ids) + self.assertIn(f"id:{self.schema_name}:{self.schema_name}::2", ids) + + def test_in_operator_stringfield(self): + # 'select * from sd1 where status in ("active", "inactive")' + # Feed test documents + docs = [ + { # Doc 1: Should match + "status": "active", + }, + { # Doc 2: Should match + "status": "inactive", + }, + { # Doc 3: Should not match + 
"status": "foo", + }, + { # Doc 4: Should not match + "status": "bar", + }, + ] + # Format and feed documents + docs = [ + {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.schema_name) + # Execute query + q = qb.test_in_operator_stringfield() + print(f"Executing query: {q}") + with self.app.syncio() as sess: + result = sess.query(yql=q) + # Verify only two documents match + self.assertEqual(len(result.hits), 2) + # Verify matching documents have expected values + ids = sorted([hit["id"] for hit in result.hits]) + self.assertIn(f"id:{self.schema_name}:{self.schema_name}::1", ids) + self.assertIn(f"id:{self.schema_name}:{self.schema_name}::2", ids) + + def test_predicate(self): + # 'select * from sd1 where predicate(predicate_field,{"gender":"Female"},{"age":25L})' + # Feed test documents with predicate_field + docs = [ + { # Doc 1: Should match - satisfies both predicates + "predicate_field": 'gender in ["Female"] and age in [20..30]', + }, + { # Doc 2: Should not match - wrong gender + "predicate_field": 'gender in ["Male"] and age in [20..30]', + }, + { # Doc 3: Should not match - too young + "predicate_field": 'gender in ["Female"] and age in [30..40]', + }, + ] + + # Format and feed documents + docs = [ + {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.schema_name) + + # Execute query using predicate search + q = qb.test_predicate() + print(f"Executing query: {q}") + + with self.app.syncio() as sess: + result = sess.query(yql=q) + + # Verify only one document matches both predicates + self.assertEqual(len(result.hits), 1) + + # Verify matching document has expected id + hit = result.hits[0] + self.assertEqual(hit["id"], f"id:{self.schema_name}:{self.schema_name}::1") + + def test_fuzzy(self): + # 'select * from sd1 where f1 contains ({prefixLength:1,maxEditDistance:2}fuzzy("parantesis"))' + # Feed test documents + docs = [ + { # Doc 1: Should match + "f1": "parantesis", + }, + { # Doc 2: Should match - edit distance 1 + "f1": "paranthesis", + }, + { # Doc 3: Should not match - edit distance 3 + "f1": "parrenthesis", + }, + ] + # Format and feed documents + docs = [ + {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.schema_name) + # Execute query + q = qb.test_fuzzy() + print(f"Executing query: {q}") + with self.app.syncio() as sess: + result = sess.query(yql=q) + # Verify only two documents match + self.assertEqual(len(result.hits), 2) + # Verify matching documents have expected values + ids = sorted([hit["id"] for hit in result.hits]) + self.assertIn(f"id:{self.schema_name}:{self.schema_name}::1", ids) + self.assertIn(f"id:{self.schema_name}:{self.schema_name}::2", ids) + + def test_uri(self): + # 'select * from sd1 where myUrlField contains uri("vespa.ai/foo")' + # Feed test documents + docs = [ + { # Doc 1: Should match + "myUrlField": "https://vespa.ai/foo", + }, + { # Doc 2: Should not match - wrong path + "myUrlField": "https://vespa.ai/bar", + }, + { # Doc 3: Should not match - wrong domain + "myUrlField": "https://google.com/foo", + }, + ] + # Format and feed documents + docs = [ + {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.schema_name) + # Execute query + q = qb.test_uri() + print(f"Executing query: {q}") + with self.app.syncio() as sess: + result = 
sess.query(yql=q) + # Verify only one document matches + self.assertEqual(len(result.hits), 1) + # Verify matching document has expected values + hit = result.hits[0] + self.assertEqual(hit["id"], f"id:{self.schema_name}:{self.schema_name}::1") + + def test_same_element(self): + # 'select * from sd1 where persons contains sameElement(first_name contains "Joe", last_name contains "Smith", year_of_birth < 1940)' + # Feed test documents + docs = [ + { # Doc 1: Should match + "persons": [ + {"first_name": "Joe", "last_name": "Smith", "year_of_birth": 1930} + ], + }, + { # Doc 2: Should not match - wrong last name + "persons": [ + {"first_name": "Joe", "last_name": "Johnson", "year_of_birth": 1930} + ], + }, + { # Doc 3: Should not match - wrong year of birth + "persons": [ + {"first_name": "Joe", "last_name": "Smith", "year_of_birth": 1940} + ], + }, + ] + # Format and feed documents + docs = [ + {"fields": doc, "id": str(data_id)} for data_id, doc in enumerate(docs, 1) + ] + self.app.feed_iterable(iter=docs, schema=self.schema_name) + # Execute query + q = qb.test_same_element() + print(f"Executing query: {q}") + with self.app.syncio() as sess: + result = sess.query(yql=q) + # Verify only one document matches + self.assertEqual(len(result.hits), 1) + # Verify matching document has expected values + hit = result.hits[0] + self.assertEqual(hit["id"], f"id:{self.schema_name}:{self.schema_name}::1") + + def test_grouping_with_condition(self): + # "select * from purchase | all(group(customer) each(output(sum(price))))" + # Feed test documents + self.feed_grouping_data() + # Execute query + q = qb.test_grouping_with_condition() + print(f"Executing query: {q}") + with self.app.syncio() as sess: + result = sess.query(yql=q) + result_children = result.json["root"]["children"][0]["children"] + # also get result from https://api.search.vespa.ai/search/?yql=select%20*%20from%20purchase%20where%20true%20%7C%20all(%20group(customer)%20each(output(sum(price)))%20) + # to compare + api_resp = requests.get( + "https://api.search.vespa.ai/search/?yql=select%20*%20from%20purchase%20where%20true%20%7C%20all(%20group(customer)%20each(output(sum(price)))%20)", + ) + api_resp = api_resp.json() + api_children = api_resp["root"]["children"][0]["children"] + self.assertEqual(result_children, api_children) + # Verify the result + group_results = result_children[0]["children"] + self.assertEqual(group_results[0]["id"], "group:string:Brown") + self.assertEqual(group_results[0]["value"], "Brown") + self.assertEqual(group_results[0]["fields"]["sum(price)"], 20537) + self.assertEqual(group_results[1]["id"], "group:string:Jones") + self.assertEqual(group_results[1]["value"], "Jones") + self.assertEqual(group_results[1]["fields"]["sum(price)"], 39816) + self.assertEqual(group_results[2]["id"], "group:string:Smith") + self.assertEqual(group_results[2]["value"], "Smith") + self.assertEqual(group_results[2]["fields"]["sum(price)"], 19484) + + def test_grouping_with_ordering_and_limiting(self): + # "select * from purchase where true | all(group(customer) max(2) precision(12) order(-count()) each(output(sum(price))))" + # Feed test documents + self.feed_grouping_data() + # Execute query + q = qb.test_grouping_with_ordering_and_limiting() + print(f"Executing query: {q}") + with self.app.syncio() as sess: + result = sess.query(yql=q) + result_children = result.json["root"]["children"][0]["children"][0]["children"] + print(result_children) + # assert 2 groups + self.assertEqual(len(result_children), 2) + # assert the first group is 
Jones + self.assertEqual(result_children[0]["id"], "group:string:Jones") + self.assertEqual(result_children[0]["value"], "Jones") + self.assertEqual(result_children[0]["fields"]["sum(price)"], 39816) + # assert the second group is Brown + self.assertEqual(result_children[1]["id"], "group:string:Smith") + self.assertEqual(result_children[1]["value"], "Smith") + self.assertEqual(result_children[1]["fields"]["sum(price)"], 19484) diff --git a/tests/unit/test_q.py b/tests/unit/test_q.py new file mode 100644 index 00000000..be475a41 --- /dev/null +++ b/tests/unit/test_q.py @@ -0,0 +1,626 @@ +import unittest +from vespa.querybuilder import Grouping +import vespa.querybuilder as qb + + +class TestQueryBuilder(unittest.TestCase): + def test_dotProduct_with_annotations(self): + condition = qb.dotProduct( + "weightedset_field", + {"feature1": 1, "feature2": 2}, + annotations={"label": "myDotProduct"}, + ) + q = qb.select("*").from_("sd1").where(condition) + expected = 'select * from sd1 where ({label:"myDotProduct"}dotProduct(weightedset_field, {"feature1":1,"feature2":2}))' + self.assertEqual(q, expected) + return q + + def test_geolocation_with_annotations(self): + condition = qb.geoLocation( + "location_field", + 37.7749, + -122.4194, + "10km", + annotations={"targetHits": 100}, + ) + q = qb.select("*").from_("sd1").where(condition) + expected = 'select * from sd1 where ({targetHits:100}geoLocation(location_field, 37.7749, -122.4194, "10km"))' + self.assertEqual(q, expected) + return q + + def test_select_specific_fields(self): + f1 = qb.QueryField("f1") + condition = f1.contains("v1") + q = qb.select(["f1", "f2"]).from_("sd1").where(condition) + self.assertEqual(q, 'select f1, f2 from sd1 where f1 contains "v1"') + return q + + def test_select_from_specific_sources(self): + f1 = qb.QueryField("f1") + condition = f1.contains("v1") + q = qb.select("*").from_("sd1").where(condition) + self.assertEqual(q, 'select * from sd1 where f1 contains "v1"') + return q + + def test_select_from_multiples_sources(self): + f1 = qb.QueryField("f1") + condition = f1.contains("v1") + q = qb.select("*").from_("sd1", "sd2").where(condition) + self.assertEqual(q, 'select * from sd1, sd2 where f1 contains "v1"') + return q + + def test_basic_and_andnot_or_offset_limit_param_order_by_and_contains(self): + f1 = qb.QueryField("f1") + f2 = qb.QueryField("f2") + f3 = qb.QueryField("f3") + f4 = qb.QueryField("f4") + condition = ((f1.contains("v1") & f2.contains("v2")) | f3.contains("v3")) & ( + ~f4.contains("v4") + ) + q = ( + qb.select("*") + .from_("sd1") + .where(condition) + .set_offset(1) + .set_limit(2) + .set_timeout(3000) + .orderByDesc("age") + .orderByAsc("duration") + ) + + expected = 'select * from sd1 where ((f1 contains "v1" and f2 contains "v2") or f3 contains "v3") and !(f4 contains "v4") order by age desc, duration asc limit 2 offset 1 timeout 3000' + self.assertEqual(q, expected) + return q + + def test_timeout(self): + f1 = qb.QueryField("title") + condition = f1.contains("madonna") + q = qb.select("*").from_("sd1").where(condition).set_timeout(70) + expected = 'select * from sd1 where title contains "madonna" timeout 70' + self.assertEqual(q, expected) + return q + + def test_matches(self): + condition = ( + (qb.QueryField("f1").matches("v1") & qb.QueryField("f2").matches("v2")) + | qb.QueryField("f3").matches("v3") + ) & ~qb.QueryField("f4").matches("v4") + q = qb.select("*").from_("sd1").where(condition) + expected = 'select * from sd1 where ((f1 matches "v1" and f2 matches "v2") or f3 matches "v3") and 
!(f4 matches "v4")' + self.assertEqual(q, expected) + return q + + def test_nested_queries(self): + nested_query = ( + qb.QueryField("f2").contains("2") & qb.QueryField("f3").contains("3") + ) | (qb.QueryField("f2").contains("4") & ~qb.QueryField("f3").contains("5")) + condition = qb.QueryField("f1").contains("1") & ~nested_query + q = qb.select("*").from_("sd1").where(condition) + expected = 'select * from sd1 where f1 contains "1" and (!((f2 contains "2" and f3 contains "3") or (f2 contains "4" and !(f3 contains "5"))))' + self.assertEqual(q, expected) + return q + + def test_userquery(self): + condition = qb.userQuery() + q = qb.select("*").from_("sd1").where(condition) + expected = "select * from sd1 where userQuery()" + self.assertEqual(q, expected) + return q + + def test_fields_duration(self): + f1 = qb.QueryField("subject") + f2 = qb.QueryField("display_date") + f3 = qb.QueryField("duration") + q = qb.select([f1, f2]).from_("calendar").where(f3 > 0) + expected = "select subject, display_date from calendar where duration > 0" + self.assertEqual(q, expected) + return q + + def test_nearest_neighbor(self): + condition_uq = qb.userQuery() + condition_nn = qb.nearestNeighbor( + field="dense_rep", query_vector="q_dense", annotations={"targetHits": 10} + ) + q = qb.select(["id, text"]).from_("m").where(condition_uq | condition_nn) + expected = "select id, text from m where userQuery() or ({targetHits:10}nearestNeighbor(dense_rep, q_dense))" + self.assertEqual(q, expected) + return q + + def test_build_many_nn_operators(self): + self.maxDiff = None + conditions = [ + qb.nearestNeighbor( + field="colbert", + query_vector=f"binary_vector_{i}", + annotations={"targetHits": 100}, + ) + for i in range(32) + ] + # Use Condition.any to combine conditions with OR + q = qb.select("*").from_("doc").where(condition=qb.any(*conditions)) + expected = "select * from doc where " + " or ".join( + [ + f"({{targetHits:100}}nearestNeighbor(colbert, binary_vector_{i}))" + for i in range(32) + ] + ) + self.assertEqual(q, expected) + return q + + def test_field_comparison_operators(self): + f1 = qb.QueryField("age") + condition = (f1 > 30) & (f1 <= 50) + q = qb.select("*").from_("people").where(condition) + expected = "select * from people where age > 30 and age <= 50" + self.assertEqual(q, expected) + return q + + def test_field_in_range(self): + f1 = qb.QueryField("age") + condition = f1.in_range(18, 65) + q = qb.select("*").from_("people").where(condition) + expected = "select * from people where range(age, 18, 65)" + self.assertEqual(q, expected) + return q + + def test_field_annotation(self): + f1 = qb.QueryField("title") + annotations = {"highlight": True} + annotated_field = f1.annotate(annotations) + q = qb.select("*").from_("articles").where(annotated_field) + expected = "select * from articles where ({highlight:true})title" + self.assertEqual(q, expected) + return q + + def test_condition_annotation(self): + f1 = qb.QueryField("title") + condition = f1.contains("Python") + annotated_condition = condition.annotate({"filter": True}) + q = qb.select("*").from_("articles").where(annotated_condition) + expected = 'select * from articles where {filter:true}title contains "Python"' + self.assertEqual(q, expected) + return q + + def test_grouping_with_condition(self): + grouping = Grouping.all( + Grouping.group("customer"), + Grouping.each(Grouping.output(Grouping.sum("price"))), + ) + q = qb.select("*").from_("purchase").where(True).set_limit(0).groupby(grouping) + expected = "select * from purchase where 
true limit 0 | all(group(customer) each(output(sum(price))))" + self.assertEqual(q, expected) + return q + + def test_grouping_with_ordering_and_limiting(self): + self.maxDiff = None + grouping = Grouping.all( + Grouping.group("customer"), + Grouping.max(2), + Grouping.precision(12), + Grouping.order(-Grouping.count()), + Grouping.each(Grouping.output(Grouping.sum("price"))), + ) + q = qb.select("*").from_("purchase").where(True).groupby(grouping) + expected = "select * from purchase where true | all(group(customer) max(2) precision(12) order(-count()) each(output(sum(price))))" + self.assertEqual(q, expected) + return q + + def test_grouping_with_map_keys(self): + grouping = Grouping.all( + Grouping.group("mymap.key"), + Grouping.each( + Grouping.group("mymap.value"), + Grouping.each(Grouping.output(Grouping.count())), + ), + ) + q = qb.select("*").from_("purchase").where(True).groupby(grouping) + expected = "select * from purchase where true | all(group(mymap.key) each(group(mymap.value) each(output(count()))))" + self.assertEqual(q, expected) + return q + + def test_group_by_year(self): + grouping = Grouping.all( + Grouping.group("time.year(a)"), + Grouping.each(Grouping.output(Grouping.count())), + ) + q = qb.select("*").from_("purchase").where(True).groupby(grouping) + expected = "select * from purchase where true | all(group(time.year(a)) each(output(count())))" + self.assertEqual(q, expected) + return q + + def test_grouping_with_date_agg(self): + grouping = Grouping.all( + Grouping.group("time.year(a)"), + Grouping.each( + Grouping.output(Grouping.count()), + Grouping.all( + Grouping.group("time.monthofyear(a)"), + Grouping.each( + Grouping.output(Grouping.count()), + Grouping.all( + Grouping.group("time.dayofmonth(a)"), + Grouping.each( + Grouping.output(Grouping.count()), + Grouping.all( + Grouping.group("time.hourofday(a)"), + Grouping.each(Grouping.output(Grouping.count())), + ), + ), + ), + ), + ), + ), + ) + q = qb.select("*").from_("purchase").where(True).groupby(grouping) + expected = "select * from purchase where true | all(group(time.year(a)) each(output(count()) all(group(time.monthofyear(a)) each(output(count()) all(group(time.dayofmonth(a)) each(output(count()) all(group(time.hourofday(a)) each(output(count())))))))))" + # q select * from purchase where true | all(group(time.year(a)) each(output(count()) all(group(time.monthofyear(a)) each(output(count()) all(group(time.dayofmonth(a)) each(output(count()) all(group(time.hourofday(a)) each(output(count()))))))))) + # e select * from purchase where true | all(group(time.year(a)) each(output(count() all(group(time.monthofyear(a)) each(output(count()) all(group(time.dayofmonth(a)) each(output(count()) all(group(time.hourofday(a)) each(output(count()))))))))) + print(q) + print(expected) + self.assertEqual(q, expected) + return q + + def test_add_parameter(self): + f1 = qb.QueryField("title") + condition = f1.contains("Python") + q = ( + qb.select("*") + .from_("articles") + .where(condition) + .add_parameter("tracelevel", 1) + ) + expected = 'select * from articles where title contains "Python"&tracelevel=1' + self.assertEqual(q, expected) + return q + + def test_custom_ranking_expression(self): + condition = qb.rank( + qb.userQuery(), qb.dotProduct("embedding", {"feature1": 1, "feature2": 2}) + ) + q = qb.select("*").from_("documents").where(condition) + expected = 'select * from documents where rank(userQuery(), dotProduct(embedding, {"feature1":1,"feature2":2}))' + self.assertEqual(q, expected) + return q + + def 
test_wand(self): + condition = qb.wand("keywords", {"apple": 10, "banana": 20}) + q = qb.select("*").from_("fruits").where(condition) + expected = ( + 'select * from fruits where wand(keywords, {"apple":10, "banana":20})' + ) + self.assertEqual(q, expected) + return q + + def test_wand_numeric(self): + condition = qb.wand("description", [[11, 1], [37, 2]]) + q = qb.select("*").from_("fruits").where(condition) + expected = "select * from fruits where wand(description, [[11, 1], [37, 2]])" + self.assertEqual(q, expected) + return q + + def test_wand_annotations(self): + self.maxDiff = None + condition = qb.wand( + "description", + weights={"a": 1, "b": 2}, + annotations={"scoreThreshold": 0.13, "targetHits": 7}, + ) + q = qb.select("*").from_("fruits").where(condition) + expected = 'select * from fruits where ({scoreThreshold: 0.13, targetHits: 7}wand(description, {"a":1, "b":2}))' + print(q) + print(expected) + self.assertEqual(q, expected) + return q + + def test_weakand(self): + condition1 = qb.QueryField("title").contains("Python") + condition2 = qb.QueryField("description").contains("Programming") + condition = qb.weakAnd( + condition1, condition2, annotations={"targetNumHits": 100} + ) + q = qb.select("*").from_("articles").where(condition) + expected = 'select * from articles where ({"targetNumHits": 100}weakAnd(title contains "Python", description contains "Programming"))' + self.assertEqual(q, expected) + return q + + def test_geolocation(self): + condition = qb.geoLocation("location_field", 37.7749, -122.4194, "10km") + q = qb.select("*").from_("places").where(condition) + expected = 'select * from places where geoLocation(location_field, 37.7749, -122.4194, "10km")' + self.assertEqual(q, expected) + return q + + def test_condition_all_any(self): + c1 = qb.QueryField("f1").contains("v1") + c2 = qb.QueryField("f2").contains("v2") + c3 = qb.QueryField("f3").contains("v3") + condition = qb.all(c1, c2, qb.any(c3, ~c1)) + q = qb.select("*").from_("sd1").where(condition) + expected = 'select * from sd1 where f1 contains "v1" and f2 contains "v2" and (f3 contains "v3" or !(f1 contains "v1"))' + self.assertEqual(q, expected) + return q + + def test_order_by_with_annotations(self): + f1 = "relevance" + f2 = "price" + annotations = {"strength": 0.5} + q = qb.select("*").from_("products").orderByDesc(f1, annotations).orderByAsc(f2) + expected = ( + 'select * from products order by {"strength":0.5}relevance desc, price asc' + ) + self.assertEqual(q, expected) + return q + + def test_field_comparison_methods_builtins(self): + f1 = qb.QueryField("age") + condition = (f1 >= 18) & (f1 < 30) + q = qb.select("*").from_("users").where(condition) + expected = "select * from users where age >= 18 and age < 30" + self.assertEqual(q, expected) + return q + + def test_field_comparison_methods(self): + f1 = qb.QueryField("age") + condition = (f1.ge(18) & f1.lt(30)) | f1.eq(40) + q = qb.select("*").from_("users").where(condition) + expected = "select * from users where (age >= 18 and age < 30) or age = 40" + self.assertEqual(q, expected) + return q + + def test_filter_annotation(self): + f1 = qb.QueryField("title") + condition = f1.contains("Python").annotate({"filter": True}) + q = qb.select("*").from_("articles").where(condition) + expected = 'select * from articles where {filter:true}title contains "Python"' + self.assertEqual(q, expected) + return q + + def test_non_empty(self): + condition = qb.nonEmpty(qb.QueryField("comments").eq("any_value")) + q = qb.select("*").from_("posts").where(condition) + 
expected = 'select * from posts where nonEmpty(comments = "any_value")' + self.assertEqual(q, expected) + return q + + def test_dotproduct(self): + condition = qb.dotProduct("vector_field", {"feature1": 1, "feature2": 2}) + q = qb.select("*").from_("vectors").where(condition) + expected = 'select * from vectors where dotProduct(vector_field, {"feature1":1,"feature2":2})' + self.assertEqual(q, expected) + return q + + def test_in_range_string_values(self): + f1 = qb.QueryField("date") + condition = f1.in_range("2021-01-01", "2021-12-31") + q = qb.select("*").from_("events").where(condition) + expected = "select * from events where range(date, 2021-01-01, 2021-12-31)" + self.assertEqual(q, expected) + return q + + def test_condition_inversion(self): + f1 = qb.QueryField("status") + condition = ~f1.eq("inactive") + q = qb.select("*").from_("users").where(condition) + expected = 'select * from users where !(status = "inactive")' + self.assertEqual(q, expected) + return q + + def test_multiple_parameters(self): + f1 = qb.QueryField("title") + condition = f1.contains("Python") + q = ( + qb.select("*") + .from_("articles") + .where(condition) + .add_parameter("tracelevel", 1) + .add_parameter("language", "en") + ) + expected = 'select * from articles where title contains "Python"&tracelevel=1&language=en' + self.assertEqual(q, expected) + return q + + def test_multiple_groupings(self): + grouping = Grouping.all( + Grouping.group("category"), + Grouping.max(10), + Grouping.output(Grouping.count()), + Grouping.each( + Grouping.group("subcategory"), Grouping.output(Grouping.summary()) + ), + ) + q = qb.select("*").from_("products").groupby(grouping) + expected = "select * from products | all(group(category) max(10) output(count()) each(group(subcategory) output(summary())))" + self.assertEqual(q, expected) + return q + + def test_userquery_basic(self): + condition = qb.userQuery("search terms") + q = qb.select("*").from_("documents").where(condition) + expected = 'select * from documents where userQuery("search terms")' + self.assertEqual(q, expected) + return q + + def test_rank_multiple_conditions(self): + condition = qb.rank( + qb.userQuery(), + qb.dotProduct("embedding", {"feature1": 1}), + qb.weightedSet("tags", {"tag1": 2}), + ) + q = qb.select("*").from_("documents").where(condition) + expected = 'select * from documents where rank(userQuery(), dotProduct(embedding, {"feature1":1}), weightedSet(tags, {"tag1":2}))' + self.assertEqual(q, expected) + return q + + def test_non_empty_with_annotations(self): + annotated_field = qb.QueryField("comments").annotate({"filter": True}) + condition = qb.nonEmpty(annotated_field) + q = qb.select("*").from_("posts").where(condition) + expected = "select * from posts where nonEmpty(({filter:true})comments)" + self.assertEqual(q, expected) + return q + + def test_weight_annotation(self): + condition = qb.QueryField("title").contains( + "heads", annotations={"weight": 200} + ) + q = qb.select("*").from_("s1").where(condition) + expected = 'select * from s1 where title contains({weight:200}"heads")' + self.assertEqual(q, expected) + return q + + def test_nearest_neighbor_annotations(self): + condition = qb.nearestNeighbor( + field="dense_rep", query_vector="q_dense", annotations={"targetHits": 10} + ) + q = qb.select(["id, text"]).from_("m").where(condition) + expected = "select id, text from m where ({targetHits:10}nearestNeighbor(dense_rep, q_dense))" + self.assertEqual(q, expected) + return q + + def test_phrase(self): + text = qb.QueryField("text") + 
condition = text.contains(qb.phrase("st", "louis", "blues")) + query = qb.select("*").where(condition) + expected = 'select * from * where text contains phrase("st", "louis", "blues")' + self.assertEqual(query, expected) + return query + + def test_near(self): + title = qb.QueryField("title") + condition = title.contains(qb.near("madonna", "saint")) + query = qb.select("*").where(condition) + expected = 'select * from * where title contains near("madonna", "saint")' + self.assertEqual(query, expected) + return query + + def test_near_with_distance(self): + title = qb.QueryField("title") + condition = title.contains(qb.near("madonna", "saint", distance=10)) + query = qb.select("*").where(condition) + expected = 'select * from * where title contains ({distance:10}near("madonna", "saint"))' + self.assertEqual(query, expected) + return query + + def test_onear(self): + title = qb.QueryField("title") + condition = title.contains(qb.onear("madonna", "saint")) + query = qb.select("*").where(condition) + expected = 'select * from * where title contains onear("madonna", "saint")' + self.assertEqual(query, expected) + return query + + def test_onear_with_distance(self): + title = qb.QueryField("title") + condition = title.contains(qb.onear("madonna", "saint", distance=5)) + query = qb.select("*").where(condition) + expected = 'select * from * where title contains ({distance:5}onear("madonna", "saint"))' + self.assertEqual(query, expected) + return query + + def test_same_element(self): + persons = qb.QueryField("persons") + first_name = qb.QueryField("first_name") + last_name = qb.QueryField("last_name") + year_of_birth = qb.QueryField("year_of_birth") + condition = persons.contains( + qb.sameElement( + first_name.contains("Joe"), + last_name.contains("Smith"), + year_of_birth < 1940, + ) + ) + query = qb.select("*").from_("sd1").where(condition) + expected = 'select * from sd1 where persons contains sameElement(first_name contains "Joe", last_name contains "Smith", year_of_birth < 1940)' + self.assertEqual(query, expected) + return query + + def test_equiv(self): + fieldName = qb.QueryField("fieldName") + condition = fieldName.contains(qb.equiv("A", "B")) + query = qb.select("*").where(condition) + expected = 'select * from * where fieldName contains equiv("A", "B")' + self.assertEqual(query, expected) + return query + + def test_uri(self): + myUrlField = qb.QueryField("myUrlField") + condition = myUrlField.contains(qb.uri("vespa.ai/foo")) + query = qb.select("*").from_("sd1").where(condition) + expected = 'select * from sd1 where myUrlField contains uri("vespa.ai/foo")' + self.assertEqual(query, expected) + return query + + def test_fuzzy(self): + myStringAttribute = qb.QueryField("f1") + annotations = {"prefixLength": 1, "maxEditDistance": 2} + condition = myStringAttribute.contains( + qb.fuzzy("parantesis", annotations=annotations) + ) + query = qb.select("*").from_("sd1").where(condition) + expected = 'select * from sd1 where f1 contains ({prefixLength:1,maxEditDistance:2}fuzzy("parantesis"))' + self.assertEqual(query, expected) + return query + + def test_userinput(self): + condition = qb.userInput("@myvar") + query = qb.select("*").from_("sd1").where(condition) + expected = "select * from sd1 where userInput(@myvar)" + self.assertEqual(query, expected) + return query + + def test_userinput_param(self): + condition = qb.userInput("@animal") + query = qb.select("*").from_("sd1").where(condition).param("animal", "panda") + expected = "select * from sd1 where userInput(@animal)&animal=panda" + 
self.assertEqual(query, expected) + return query + + def test_userinput_with_defaultindex(self): + condition = qb.userInput("@myvar").annotate({"defaultIndex": "text"}) + query = qb.select("*").from_("sd1").where(condition) + expected = 'select * from sd1 where {defaultIndex:"text"}userInput(@myvar)' + self.assertEqual(query, expected) + return query + + def test_in_operator_intfield(self): + integer_field = qb.QueryField("age") + condition = integer_field.in_(10, 20, 30) + query = qb.select("*").from_("sd1").where(condition) + expected = "select * from sd1 where age in (10, 20, 30)" + self.assertEqual(query, expected) + return query + + def test_in_operator_stringfield(self): + string_field = qb.QueryField("status") + condition = string_field.in_("active", "inactive") + query = qb.select("*").from_("sd1").where(condition) + expected = 'select * from sd1 where status in ("active", "inactive")' + self.assertEqual(query, expected) + return query + + def test_predicate(self): + condition = qb.predicate( + "predicate_field", + attributes={"gender": "Female"}, + range_attributes={"age": "20L"}, + ) + query = qb.select("*").from_("sd1").where(condition) + expected = 'select * from sd1 where predicate(predicate_field,{"gender":"Female"},{"age":20L})' + self.assertEqual(query, expected) + return query + + def test_true(self): + query = qb.select("*").from_("sd1").where(True) + expected = "select * from sd1 where true" + self.assertEqual(query, expected) + return query + + def test_false(self): + query = qb.select("*").from_("sd1").where(False) + expected = "select * from sd1 where false" + self.assertEqual(query, expected) + return query + + +if __name__ == "__main__": + unittest.main() diff --git a/vespa/querybuilder/__init__.py b/vespa/querybuilder/__init__.py new file mode 100644 index 00000000..ecdad8bb --- /dev/null +++ b/vespa/querybuilder/__init__.py @@ -0,0 +1,34 @@ +from .builder.builder import Q, QueryField +from .grouping.grouping import Grouping +import inspect + +# Import original classes +# ...existing code... 
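+# Usage sketch (illustrative only; mirrors test_select_from_specific_sources in
+# tests/unit/test_q.py from this patch). The loop below re-exports Q's static
+# methods at module level, so callers can write:
+#
+#   import vespa.querybuilder as qb
+#
+#   condition = qb.QueryField("f1").contains("v1")
+#   q = qb.select("*").from_("sd1").where(condition)
+#   str(q)  # -> 'select * from sd1 where f1 contains "v1"'
+#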
+ +# Automatically expose all static methods from Q +for cls in [Q]: # do not expose G for now + for name, method in inspect.getmembers(cls, predicate=inspect.isfunction): + if not name.startswith("_"): + # Create function with same name and signature as the static method + globals()[name] = method + + +def get_function_members(cls): + return [ + name + for name, method in inspect.getmembers(cls, predicate=inspect.isfunction) + if not name.startswith("_") + ] + + +# Create __all__ list dynamically +__all__ = [ + # Classes + # "Query", + # "Q", + "QueryField", + "Grouping", + # "Condition", + # Add all exposed functions + *get_function_members(Q), +] diff --git a/vespa/querybuilder/builder/__init__.py b/vespa/querybuilder/builder/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vespa/querybuilder/builder/builder.py b/vespa/querybuilder/builder/builder.py new file mode 100644 index 00000000..bad58251 --- /dev/null +++ b/vespa/querybuilder/builder/builder.py @@ -0,0 +1,573 @@ +from dataclasses import dataclass +from typing import Any, List, Union, Optional, Dict + + +@dataclass +class QueryField: + name: str + + def __eq__(self, other: Any) -> "Condition": + return Condition(f"{self.name} = {self._format_value(other)}") + + def __ne__(self, other: Any) -> "Condition": + return Condition(f"{self.name} != {self._format_value(other)}") + + def __lt__(self, other: Any) -> "Condition": + return Condition(f"{self.name} < {self._format_value(other)}") + + def __le__(self, other: Any) -> "Condition": + return Condition(f"{self.name} <= {self._format_value(other)}") + + def __gt__(self, other: Any) -> "Condition": + return Condition(f"{self.name} > {self._format_value(other)}") + + def __ge__(self, other: Any) -> "Condition": + return Condition(f"{self.name} >= {self._format_value(other)}") + + def __and__(self, other: Any) -> "Condition": + return Condition(f"{self.name} and {self._format_value(other)}") + + def __or__(self, other: Any) -> "Condition": + return Condition(f"{self.name} or {self._format_value(other)}") + + # repr as str + def __repr__(self) -> str: + return self.name + + def contains( + self, value: Any, annotations: Optional[Dict[str, Any]] = None + ) -> "Condition": + value_str = self._format_value(value) + if annotations: + annotations_str = ",".join( + f"{k}:{self._format_annotation_value(v)}" + for k, v in annotations.items() + ) + return Condition(f"{self.name} contains({{{annotations_str}}}{value_str})") + else: + return Condition(f"{self.name} contains {value_str}") + + def matches( + self, value: Any, annotations: Optional[Dict[str, Any]] = None + ) -> "Condition": + value_str = self._format_value(value) + if annotations: + annotations_str = ",".join( + f"{k}:{self._format_annotation_value(v)}" + for k, v in annotations.items() + ) + return Condition(f"{self.name} matches({{{annotations_str}}}{value_str})") + else: + return Condition(f"{self.name} matches {value_str}") + + def in_(self, *values) -> "Condition": + values_str = ", ".join( + f'"{v}"' if isinstance(v, str) else str(v) for v in values + ) + return Condition(f"{self.name} in ({values_str})") + + def in_range( + self, start: Any, end: Any, annotations: Optional[Dict[str, Any]] = None + ) -> "Condition": + if annotations: + annotations_str = ",".join( + f"{k}:{self._format_annotation_value(v)}" + for k, v in annotations.items() + ) + return Condition( + f"({{{annotations_str}}}range({self.name}, {start}, {end}))" + ) + else: + return Condition(f"range({self.name}, {start}, {end})") + + def 
le(self, value: Any) -> "Condition": + return self.__le__(value) + + def lt(self, value: Any) -> "Condition": + return self.__lt__(value) + + def ge(self, value: Any) -> "Condition": + return self.__ge__(value) + + def gt(self, value: Any) -> "Condition": + return self.__gt__(value) + + def eq(self, value: Any) -> "Condition": + return self.__eq__(value) + + def __str__(self) -> str: + return self.name + + @staticmethod + def _format_value(value: Any) -> str: + if isinstance(value, str): + return f'"{value}"' + elif isinstance(value, Condition): + return value.build() + else: + return str(value) + + @staticmethod + def _format_annotation_value(value: Any) -> str: + if isinstance(value, str): + return f'"{value}"' + elif isinstance(value, bool): + return str(value).lower() + elif isinstance(value, dict): + return ( + "{" + + ",".join( + f'"{k}":{QueryField._format_annotation_value(v)}' + for k, v in value.items() + ) + + "}" + ) + elif isinstance(value, list): + return ( + "[" + + ",".join(f"{QueryField._format_annotation_value(v)}" for v in value) + + "]" + ) + else: + return str(value) + + def annotate(self, annotations: Dict[str, Any]) -> "Condition": + annotations_str = ",".join( + f"{k}:{self._format_annotation_value(v)}" for k, v in annotations.items() + ) + return Condition(f"({{{annotations_str}}}){self.name}") + + +@dataclass +class Condition: + expression: str + + def __and__(self, other: "Condition") -> "Condition": + left = self.expression + right = other.expression + + # Adjust parentheses based on operator precedence + left = f"({left})" if " or " in left else left + right = f"({right})" if " or " in right else right + + return Condition(f"{left} and {right}") + + def __or__(self, other: "Condition") -> "Condition": + left = self.expression + right = other.expression + + # Always add parentheses if 'and' or 'or' is in the expressions + left = f"({left})" if " and " in left or " or " in left else left + right = f"({right})" if " and " in right or " or " in right else right + + return Condition(f"{left} or {right}") + + def __invert__(self) -> "Condition": + return Condition(f"!({self.expression})") + + def annotate(self, annotations: Dict[str, Any]) -> "Condition": + annotations_str = ",".join( + f"{k}:{QueryField._format_annotation_value(v)}" + for k, v in annotations.items() + ) + return Condition(f"{{{annotations_str}}}{self.expression}") + + def build(self) -> str: + return self.expression + + @classmethod + def all(cls, *conditions: "Condition") -> "Condition": + """Combine multiple conditions using logical AND.""" + expressions = [] + for cond in conditions: + expr = cond.expression + # Wrap expressions with 'or' in parentheses + if " or " in expr: + expr = f"({expr})" + expressions.append(expr) + combined_expression = " and ".join(expressions) + return Condition(combined_expression) + + @classmethod + def any(cls, *conditions: "Condition") -> "Condition": + """Combine multiple conditions using logical OR.""" + expressions = [] + for cond in conditions: + expr = cond.expression + # Wrap expressions with 'and' or 'or' in parentheses + if " and " in expr or " or " in expr: + expr = f"({expr})" + expressions.append(expr) + combined_expression = " or ".join(expressions) + return Condition(combined_expression) + + +class Query: + def __init__( + self, select_fields: Union[str, List[str], List[QueryField]], prepend_yql=False + ): + self.select_fields = ( + ", ".join(select_fields) + if isinstance(select_fields, List) + and all(isinstance(f, str) for f in select_fields) + else 
", ".join(str(f) for f in select_fields) + ) + self.sources = "*" + self.condition = None + self.order_by_clauses = [] + self.limit_value = None + self.offset_value = None + self.timeout_value = None + self.parameters = {} + self.grouping = None + self.prepend_yql = prepend_yql + + def __str__(self) -> str: + return self.build(self.prepend_yql) + + def __eq__(self, other: Any) -> bool: + return self.build() == other + + def __ne__(self, other: Any) -> bool: + return self.build() != other + + def __repr__(self) -> str: + return str(self) + + def from_(self, *sources: str) -> "Query": + self.sources = ", ".join(sources) + return self + + def where(self, condition: Union[Condition, QueryField, bool]) -> "Query": + if isinstance(condition, QueryField): + self.condition = condition + elif isinstance(condition, bool): + self.condition = Condition("true") if condition else Condition("false") + else: + self.condition = condition + return self + + def order_by_field( + self, + field: str, + ascending: bool = True, + annotations: Optional[Dict[str, Any]] = None, + ) -> "Query": + direction = "asc" if ascending else "desc" + if annotations: + annotations_str = ",".join( + f'"{k}":{QueryField._format_annotation_value(v)}' + for k, v in annotations.items() + ) + self.order_by_clauses.append(f"{{{annotations_str}}}{field} {direction}") + else: + self.order_by_clauses.append(f"{field} {direction}") + return self + + def orderByAsc( + self, field: str, annotations: Optional[Dict[str, Any]] = None + ) -> "Query": + return self.order_by_field(field, True, annotations) + + def orderByDesc( + self, field: str, annotations: Optional[Dict[str, Any]] = None + ) -> "Query": + return self.order_by_field(field, False, annotations) + + def set_limit(self, limit: int) -> "Query": + self.limit_value = limit + return self + + def set_offset(self, offset: int) -> "Query": + self.offset_value = offset + return self + + def set_timeout(self, timeout: int) -> "Query": + self.timeout_value = timeout + return self + + def add_parameter(self, key: str, value: Any) -> "Query": + self.parameters[key] = value + return self + + def param(self, key: str, value: Any) -> "Query": + return self.add_parameter(key, value) + + def groupby(self, group_expression: str) -> "Query": + self.grouping = group_expression + return self + + def build(self, prepend_yql=False) -> str: + query = f"select {self.select_fields} from {self.sources}" + if prepend_yql: + query = f"yql={query}" + if self.condition: + query += f" where {self.condition.build()}" + if self.order_by_clauses: + query += " order by " + ", ".join(self.order_by_clauses) + if self.limit_value is not None: + query += f" limit {self.limit_value}" + if self.offset_value is not None: + query += f" offset {self.offset_value}" + if self.timeout_value is not None: + query += f" timeout {self.timeout_value}" + if self.grouping: + query += f" | {self.grouping}" + if self.parameters: + params = "&" + "&".join(f"{k}={v}" for k, v in self.parameters.items()) + query += params + return query + + +class Q: + @staticmethod + def select(fields): + return Query(select_fields=fields) + + @staticmethod + def any(*conditions): + return Condition.any(*conditions) + + @staticmethod + def all(*conditions): + return Condition.all(*conditions) + + @staticmethod + def userQuery(value: str = "") -> Condition: + return Condition(f'userQuery("{value}")') if value else Condition("userQuery()") + + @staticmethod + def dotProduct( + field: str, vector: Dict[str, int], annotations: Optional[Dict[str, Any]] = None 
+ ) -> Condition: + vector_str = "{" + ",".join(f'"{k}":{v}' for k, v in vector.items()) + "}" + expr = f"dotProduct({field}, {vector_str})" + if annotations: + annotations_str = ",".join( + f"{k}:{QueryField._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def weightedSet( + field: str, vector: Dict[str, int], annotations: Optional[Dict[str, Any]] = None + ) -> Condition: + vector_str = "{" + ",".join(f'"{k}":{v}' for k, v in vector.items()) + "}" + expr = f"weightedSet({field}, {vector_str})" + if annotations: + annotations_str = ",".join( + f"{k}:{QueryField._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def nonEmpty(condition: Union[Condition, QueryField]) -> Condition: + if isinstance(condition, QueryField): + expr = str(condition) + else: + expr = condition.build() + return Condition(f"nonEmpty({expr})") + + @staticmethod + def wand( + field: str, weights, annotations: Optional[Dict[str, Any]] = None + ) -> Condition: + if isinstance(weights, list): + weights_str = "[" + ", ".join(str(item) for item in weights) + "]" + elif isinstance(weights, dict): + weights_str = ( + "{" + ", ".join(f'"{k}":{v}' for k, v in weights.items()) + "}" + ) + else: + raise ValueError("Invalid weights for wand") + expr = f"wand({field}, {weights_str})" + if annotations: + annotations_str = ", ".join( + f"{k}: {QueryField._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def weakAnd(*conditions, annotations: Dict[str, Any] = None) -> Condition: + conditions_str = ", ".join(cond.build() for cond in conditions) + expr = f"weakAnd({conditions_str})" + if annotations: + annotations_str = ",".join( + f'"{k}": {QueryField._format_annotation_value(v)}' + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def geoLocation( + field: str, + lat: float, + lng: float, + radius: str, + annotations: Optional[Dict[str, Any]] = None, + ) -> Condition: + expr = f'geoLocation({field}, {lat}, {lng}, "{radius}")' + if annotations: + annotations_str = ",".join( + f"{k}:{QueryField._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def nearestNeighbor( + field: str, query_vector: str, annotations: Dict[str, Any] + ) -> Condition: + if "targetHits" not in annotations: + raise ValueError("targetHits annotation is required") + annotations_str = ",".join( + f"{k}:{QueryField._format_annotation_value(v)}" + for k, v in annotations.items() + ) + return Condition( + f"({{{annotations_str}}}nearestNeighbor({field}, {query_vector}))" + ) + + @staticmethod + def rank(*queries) -> Condition: + queries_str = ", ".join(query.build() for query in queries) + return Condition(f"rank({queries_str})") + + @staticmethod + def phrase(*terms, annotations: Optional[Dict[str, Any]] = None) -> Condition: + terms_str = ", ".join(f'"{term}"' for term in terms) + expr = f"phrase({terms_str})" + if annotations: + annotations_str = ",".join( + f"{k}:{QueryField._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def near( + *terms, annotations: Optional[Dict[str, Any]] = 
None, **kwargs + ) -> Condition: + terms_str = ", ".join(f'"{term}"' for term in terms) + expr = f"near({terms_str})" + # if kwargs - add to annotations + if kwargs: + if not annotations: + annotations = {} + annotations.update(kwargs) + if annotations: + annotations_str = ", ".join( + f"{k}:{QueryField._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def onear( + *terms, annotations: Optional[Dict[str, Any]] = None, **kwargs + ) -> Condition: + terms_str = ", ".join(f'"{term}"' for term in terms) + expr = f"onear({terms_str})" + # if kwargs - add to annotations + if kwargs: + if not annotations: + annotations = {} + annotations.update(kwargs) + if annotations: + annotations_str = ",".join( + f"{k}:{QueryField._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def sameElement(*conditions) -> Condition: + conditions_str = ", ".join(cond.build() for cond in conditions) + expr = f"sameElement({conditions_str})" + return Condition(expr) + + @staticmethod + def equiv(*terms) -> Condition: + terms_str = ", ".join(f'"{term}"' for term in terms) + expr = f"equiv({terms_str})" + return Condition(expr) + + @staticmethod + def uri(value: str, annotations: Optional[Dict[str, Any]] = None) -> Condition: + expr = f'uri("{value}")' + if annotations: + annotations_str = ",".join( + f"{k}:{QueryField._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def fuzzy(value: str, annotations: Optional[Dict[str, Any]] = None) -> Condition: + expr = f'fuzzy("{value}")' + if annotations: + annotations_str = ",".join( + f"{k}:{QueryField._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def userInput( + value: Optional[str] = None, annotations: Optional[Dict[str, Any]] = None + ) -> Condition: + if value is None: + expr = "userInput()" + elif value.startswith("@"): + expr = f"userInput({value})" + else: + expr = f'userInput("{value}")' + if annotations: + annotations_str = ",".join( + f"{k}:{QueryField._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def predicate( + field: str, + attributes: Optional[Dict[str, Any]] = None, + range_attributes: Optional[Dict[str, Any]] = None, + ) -> Condition: + if attributes is None: + attributes_str = "0" + else: + attributes_str = ( + "{" + ",".join(f'"{k}":"{v}"' for k, v in attributes.items()) + "}" + ) + if range_attributes is None: + range_attributes_str = "0" + else: + range_attributes_str = ( + "{" + ",".join(f'"{k}":{v}' for k, v in range_attributes.items()) + "}" + ) + expr = f"predicate({field},{attributes_str},{range_attributes_str})" + return Condition(expr) + + @staticmethod + def true() -> Condition: + return Condition("true") + + @staticmethod + def false() -> Condition: + return Condition("false") diff --git a/vespa/querybuilder/grouping/__init__.py b/vespa/querybuilder/grouping/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vespa/querybuilder/grouping/grouping.py b/vespa/querybuilder/grouping/grouping.py new file mode 100644 index 00000000..ea75570f --- /dev/null +++ b/vespa/querybuilder/grouping/grouping.py @@ -0,0 +1,67 @@ 
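Before the grouping helpers that follow, here is a minimal usage sketch of the query-builder API added in builder.py above. It is illustrative only and not part of the diff: the field names (title, price), the source name (sd1) and the expected output string are assumptions based on the code in this PR, with the import path taken from the module layout shown here.

from vespa.querybuilder.builder.builder import Q, QueryField

# QueryField comparisons and .contains() produce Condition objects,
# which combine with & / | into a single YQL where-clause.
title = QueryField("title")
price = QueryField("price")
condition = title.contains("valve") & (price < 3000)

query = (
    Q.select(["title", "price"])   # builder.py's Q.select takes a list of field names
    .from_("sd1")
    .where(condition)
    .set_limit(10)
)
print(query)
# select title, price from sd1 where title contains "valve" and price < 3000 limit 10

Since the package __init__.py shown earlier re-exports Q's static methods into the vespa.querybuilder namespace, an entry point such as "import vespa.querybuilder as qb; qb.select(...)" should work equivalently, assuming the top of that __init__.py (not shown in this section) imports Q.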
+from typing import Union + + +class Grouping: + @staticmethod + def all(*args) -> str: + return "all(" + " ".join(args) + ")" + + @staticmethod + def group(field: str) -> str: + return f"group({field})" + + @staticmethod + def max(value: Union[int, float]) -> str: + return f"max({value})" + + @staticmethod + def precision(value: int) -> str: + return f"precision({value})" + + @staticmethod + def min(value: Union[int, float]) -> str: + return f"min({value})" + + @staticmethod + def sum(value: Union[int, float]) -> str: + return f"sum({value})" + + @staticmethod + def avg(value: Union[int, float]) -> str: + return f"avg({value})" + + @staticmethod + def stddev(value: Union[int, float]) -> str: + return f"stddev({value})" + + @staticmethod + def xor(value: str) -> str: + return f"xor({value})" + + @staticmethod + def each(*args) -> str: + return "each(" + " ".join(args) + ")" + + @staticmethod + def output(output_func: str) -> str: + return f"output({output_func})" + + @staticmethod + def count() -> str: + # Also need to handle negative count + class MaybenegativeCount(str): + def __new__(cls, value): + return super().__new__(cls, value) + + def __neg__(self): + return f"-{self}" + + return MaybenegativeCount("count()") + + @staticmethod + def order(value: str) -> str: + return f"order({value})" + + @staticmethod + def summary() -> str: + return "summary()" diff --git a/vespa/querybuilder/main.py b/vespa/querybuilder/main.py new file mode 100644 index 00000000..de92f382 --- /dev/null +++ b/vespa/querybuilder/main.py @@ -0,0 +1,571 @@ +from dataclasses import dataclass +from typing import Any, List, Union, Optional, Dict + + +@dataclass +class Queryfield: + name: str + + def __eq__(self, other: Any) -> "Condition": + return Condition(f"{self.name} = {self._format_value(other)}") + + def __ne__(self, other: Any) -> "Condition": + return Condition(f"{self.name} != {self._format_value(other)}") + + def __lt__(self, other: Any) -> "Condition": + return Condition(f"{self.name} < {self._format_value(other)}") + + def __le__(self, other: Any) -> "Condition": + return Condition(f"{self.name} <= {self._format_value(other)}") + + def __gt__(self, other: Any) -> "Condition": + return Condition(f"{self.name} > {self._format_value(other)}") + + def __ge__(self, other: Any) -> "Condition": + return Condition(f"{self.name} >= {self._format_value(other)}") + + def __and__(self, other: Any) -> "Condition": + return Condition(f"{self.name} and {self._format_value(other)}") + + def __or__(self, other: Any) -> "Condition": + return Condition(f"{self.name} or {self._format_value(other)}") + + def contains( + self, value: Any, annotations: Optional[Dict[str, Any]] = None + ) -> "Condition": + value_str = self._format_value(value) + if annotations: + annotations_str = ",".join( + f"{k}:{self._format_annotation_value(v)}" + for k, v in annotations.items() + ) + return Condition(f"{self.name} contains({{{annotations_str}}}{value_str})") + else: + return Condition(f"{self.name} contains {value_str}") + + def matches( + self, value: Any, annotations: Optional[Dict[str, Any]] = None + ) -> "Condition": + value_str = self._format_value(value) + if annotations: + annotations_str = ",".join( + f"{k}:{self._format_annotation_value(v)}" + for k, v in annotations.items() + ) + return Condition(f"{self.name} matches({{{annotations_str}}}{value_str})") + else: + return Condition(f"{self.name} matches {value_str}") + + def in_(self, *values) -> "Condition": + values_str = ", ".join( + f'"{v}"' if isinstance(v, str) else str(v) 
for v in values + ) + return Condition(f"{self.name} in ({values_str})") + + def in_range( + self, start: Any, end: Any, annotations: Optional[Dict[str, Any]] = None + ) -> "Condition": + if annotations: + annotations_str = ",".join( + f"{k}:{self._format_annotation_value(v)}" + for k, v in annotations.items() + ) + return Condition( + f"({{{annotations_str}}}range({self.name}, {start}, {end}))" + ) + else: + return Condition(f"range({self.name}, {start}, {end})") + + def le(self, value: Any) -> "Condition": + return self.__le__(value) + + def lt(self, value: Any) -> "Condition": + return self.__lt__(value) + + def ge(self, value: Any) -> "Condition": + return self.__ge__(value) + + def gt(self, value: Any) -> "Condition": + return self.__gt__(value) + + def eq(self, value: Any) -> "Condition": + return self.__eq__(value) + + def __str__(self) -> str: + return self.name + + @staticmethod + def _format_value(value: Any) -> str: + if isinstance(value, str): + return f'"{value}"' + elif isinstance(value, Condition): + return value.build() + else: + return str(value) + + @staticmethod + def _format_annotation_value(value: Any) -> str: + if isinstance(value, str): + return f'"{value}"' + elif isinstance(value, bool): + return str(value).lower() + elif isinstance(value, dict): + return ( + "{" + + ",".join( + f'"{k}":{Queryfield._format_annotation_value(v)}' + for k, v in value.items() + ) + + "}" + ) + elif isinstance(value, list): + return ( + "[" + + ",".join(f"{Queryfield._format_annotation_value(v)}" for v in value) + + "]" + ) + else: + return str(value) + + def annotate(self, annotations: Dict[str, Any]) -> "Condition": + annotations_str = ",".join( + f"{k}:{self._format_annotation_value(v)}" for k, v in annotations.items() + ) + return Condition(f"({{{annotations_str}}}){self.name}") + + +@dataclass +class Condition: + expression: str + + def __and__(self, other: "Condition") -> "Condition": + left = self.expression + right = other.expression + + # Adjust parentheses based on operator precedence + left = f"({left})" if " or " in left else left + right = f"({right})" if " or " in right else right + + return Condition(f"{left} and {right}") + + def __or__(self, other: "Condition") -> "Condition": + left = self.expression + right = other.expression + + # Always add parentheses if 'and' or 'or' is in the expressions + left = f"({left})" if " and " in left or " or " in left else left + right = f"({right})" if " and " in right or " or " in right else right + + return Condition(f"{left} or {right}") + + def __invert__(self) -> "Condition": + return Condition(f"!({self.expression})") + + def annotate(self, annotations: Dict[str, Any]) -> "Condition": + annotations_str = ",".join( + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + return Condition(f"{{{annotations_str}}}{self.expression}") + + def build(self) -> str: + return self.expression + + @classmethod + def all(cls, *conditions: "Condition") -> "Condition": + """Combine multiple conditions using logical AND.""" + expressions = [] + for cond in conditions: + expr = cond.expression + # Wrap expressions with 'or' in parentheses + if " or " in expr: + expr = f"({expr})" + expressions.append(expr) + combined_expression = " and ".join(expressions) + return Condition(combined_expression) + + @classmethod + def any(cls, *conditions: "Condition") -> "Condition": + """Combine multiple conditions using logical OR.""" + expressions = [] + for cond in conditions: + expr = cond.expression + # Wrap expressions with 
'and' or 'or' in parentheses + if " and " in expr or " or " in expr: + expr = f"({expr})" + expressions.append(expr) + combined_expression = " or ".join(expressions) + return Condition(combined_expression) + + +class Query: + def __init__( + self, select_fields: Union[str, List[str], List[Queryfield]], prepend_yql=False + ): + self.select_fields = ( + ", ".join(select_fields) + if isinstance(select_fields, List) + and all(isinstance(f, str) for f in select_fields) + else ", ".join(str(f) for f in select_fields) + ) + self.sources = "*" + self.condition = None + self.order_by_clauses = [] + self.limit_value = None + self.offset_value = None + self.timeout_value = None + self.parameters = {} + self.grouping = None + self.prepend_yql = prepend_yql + + def __str__(self) -> str: + return self.build(self.prepend_yql) + + def __eq__(self, other: Any) -> bool: + return self.build() == other + + def __ne__(self, other: Any) -> bool: + return self.build() != other + + def __repr__(self) -> str: + return str(self) + + def from_(self, *sources: str) -> "Query": + self.sources = ", ".join(sources) + return self + + def where(self, condition: Union[Condition, Queryfield, bool]) -> "Query": + if isinstance(condition, Queryfield): + self.condition = condition + elif isinstance(condition, bool): + self.condition = Condition("true") if condition else Condition("false") + else: + self.condition = condition + return self + + def order_by_field( + self, + field: str, + ascending: bool = True, + annotations: Optional[Dict[str, Any]] = None, + ) -> "Query": + direction = "asc" if ascending else "desc" + if annotations: + annotations_str = ",".join( + f'"{k}":{Queryfield._format_annotation_value(v)}' + for k, v in annotations.items() + ) + self.order_by_clauses.append(f"{{{annotations_str}}}{field} {direction}") + else: + self.order_by_clauses.append(f"{field} {direction}") + return self + + def orderByAsc( + self, field: str, annotations: Optional[Dict[str, Any]] = None + ) -> "Query": + return self.order_by_field(field, True, annotations) + + def orderByDesc( + self, field: str, annotations: Optional[Dict[str, Any]] = None + ) -> "Query": + return self.order_by_field(field, False, annotations) + + def set_limit(self, limit: int) -> "Query": + self.limit_value = limit + return self + + def set_offset(self, offset: int) -> "Query": + self.offset_value = offset + return self + + def set_timeout(self, timeout: int) -> "Query": + self.timeout_value = timeout + return self + + def add_parameter(self, key: str, value: Any) -> "Query": + self.parameters[key] = value + return self + + def param(self, key: str, value: Any) -> "Query": + return self.add_parameter(key, value) + + def groupby(self, group_expression: str) -> "Query": + self.grouping = group_expression + return self + + def build(self, prepend_yql=False) -> str: + query = f"select {self.select_fields} from {self.sources}" + if prepend_yql: + query = f"yql={query}" + if self.condition: + query += f" where {self.condition.build()}" + if self.order_by_clauses: + query += " order by " + ", ".join(self.order_by_clauses) + if self.limit_value is not None: + query += f" limit {self.limit_value}" + if self.offset_value is not None: + query += f" offset {self.offset_value}" + if self.timeout_value is not None: + query += f" timeout {self.timeout_value}" + if self.grouping: + query += f" | {self.grouping}" + if self.parameters: + params = "&" + "&".join(f"{k}={v}" for k, v in self.parameters.items()) + query += params + return query + + +class Q: + @staticmethod + def 
select(*fields): + return Query(select_fields=list(fields)) + + @staticmethod + def p(*args): + if not args: + return Condition("") + else: + condition = args[0] + for arg in args[1:]: + condition = condition & arg + return condition + + @staticmethod + def userQuery(value: str = "") -> Condition: + return Condition(f'userQuery("{value}")') if value else Condition("userQuery()") + + @staticmethod + def dotProduct( + field: str, vector: Dict[str, int], annotations: Optional[Dict[str, Any]] = None + ) -> Condition: + vector_str = "{" + ",".join(f'"{k}":{v}' for k, v in vector.items()) + "}" + expr = f"dotProduct({field}, {vector_str})" + if annotations: + annotations_str = ",".join( + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def weightedSet( + field: str, vector: Dict[str, int], annotations: Optional[Dict[str, Any]] = None + ) -> Condition: + vector_str = "{" + ",".join(f'"{k}":{v}' for k, v in vector.items()) + "}" + expr = f"weightedSet({field}, {vector_str})" + if annotations: + annotations_str = ",".join( + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def nonEmpty(condition: Union[Condition, Queryfield]) -> Condition: + if isinstance(condition, Queryfield): + expr = str(condition) + else: + expr = condition.build() + return Condition(f"nonEmpty({expr})") + + @staticmethod + def wand( + field: str, weights, annotations: Optional[Dict[str, Any]] = None + ) -> Condition: + if isinstance(weights, list): + weights_str = "[" + ", ".join(str(item) for item in weights) + "]" + elif isinstance(weights, dict): + weights_str = ( + "{" + ", ".join(f'"{k}":{v}' for k, v in weights.items()) + "}" + ) + else: + raise ValueError("Invalid weights for wand") + expr = f"wand({field}, {weights_str})" + if annotations: + annotations_str = ", ".join( + f"{k}: {Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def weakAnd(*conditions, annotations: Dict[str, Any] = None) -> Condition: + conditions_str = ", ".join(cond.build() for cond in conditions) + expr = f"weakAnd({conditions_str})" + if annotations: + annotations_str = ",".join( + f'"{k}": {Queryfield._format_annotation_value(v)}' + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def geoLocation( + field: str, + lat: float, + lng: float, + radius: str, + annotations: Optional[Dict[str, Any]] = None, + ) -> Condition: + expr = f'geoLocation({field}, {lat}, {lng}, "{radius}")' + if annotations: + annotations_str = ",".join( + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def nearestNeighbor( + field: str, query_vector: str, annotations: Dict[str, Any] + ) -> Condition: + if "targetHits" not in annotations: + raise ValueError("targetHits annotation is required") + annotations_str = ",".join( + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + return Condition( + f"({{{annotations_str}}}nearestNeighbor({field}, {query_vector}))" + ) + + @staticmethod + def rank(*queries) -> Condition: + queries_str = ", ".join(query.build() for query in queries) + return 
Condition(f"rank({queries_str})") + + @staticmethod + def phrase(*terms, annotations: Optional[Dict[str, Any]] = None) -> Condition: + terms_str = ", ".join(f'"{term}"' for term in terms) + expr = f"phrase({terms_str})" + if annotations: + annotations_str = ",".join( + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def near( + *terms, annotations: Optional[Dict[str, Any]] = None, **kwargs + ) -> Condition: + terms_str = ", ".join(f'"{term}"' for term in terms) + expr = f"near({terms_str})" + # if kwargs - add to annotations + if kwargs: + if not annotations: + annotations = {} + annotations.update(kwargs) + if annotations: + annotations_str = ", ".join( + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def onear( + *terms, annotations: Optional[Dict[str, Any]] = None, **kwargs + ) -> Condition: + terms_str = ", ".join(f'"{term}"' for term in terms) + expr = f"onear({terms_str})" + # if kwargs - add to annotations + if kwargs: + if not annotations: + annotations = {} + annotations.update(kwargs) + if annotations: + annotations_str = ",".join( + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def sameElement(*conditions) -> Condition: + conditions_str = ", ".join(cond.build() for cond in conditions) + expr = f"sameElement({conditions_str})" + return Condition(expr) + + @staticmethod + def equiv(*terms) -> Condition: + terms_str = ", ".join(f'"{term}"' for term in terms) + expr = f"equiv({terms_str})" + return Condition(expr) + + @staticmethod + def uri(value: str, annotations: Optional[Dict[str, Any]] = None) -> Condition: + expr = f'uri("{value}")' + if annotations: + annotations_str = ",".join( + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def fuzzy(value: str, annotations: Optional[Dict[str, Any]] = None) -> Condition: + expr = f'fuzzy("{value}")' + if annotations: + annotations_str = ",".join( + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def userInput( + value: Optional[str] = None, annotations: Optional[Dict[str, Any]] = None + ) -> Condition: + if value is None: + expr = "userInput()" + elif value.startswith("@"): + expr = f"userInput({value})" + else: + expr = f'userInput("{value}")' + if annotations: + annotations_str = ",".join( + f"{k}:{Queryfield._format_annotation_value(v)}" + for k, v in annotations.items() + ) + expr = f"({{{annotations_str}}}{expr})" + return Condition(expr) + + @staticmethod + def predicate( + field: str, + attributes: Optional[Dict[str, Any]] = None, + range_attributes: Optional[Dict[str, Any]] = None, + ) -> Condition: + if attributes is None: + attributes_str = "0" + else: + attributes_str = ( + "{" + ",".join(f'"{k}":"{v}"' for k, v in attributes.items()) + "}" + ) + if range_attributes is None: + range_attributes_str = "0" + else: + range_attributes_str = ( + "{" + ",".join(f'"{k}":{v}' for k, v in range_attributes.items()) + "}" + ) + expr = f"predicate({field},{attributes_str},{range_attributes_str})" + return Condition(expr) + + 
@staticmethod + def true() -> Condition: + return Condition("true") + + @staticmethod + def false() -> Condition: + return Condition("false")
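To close, a hedged sketch of how the Grouping helpers from grouping.py compose with Query.groupby from builder.py. The customer and price fields and the purchase source are assumptions chosen for illustration; the printed string follows directly from the build() logic above.

from vespa.querybuilder.builder.builder import Q
from vespa.querybuilder.grouping.grouping import Grouping as G

# The Grouping helpers return plain strings, so expressions nest by composition.
grouping = G.all(
    G.group("customer"),
    G.each(G.output(G.sum("price"))),
)

# where(True) maps to the literal condition "true"; groupby appends "| <expression>".
query = Q.select("*").from_("purchase").where(True).groupby(grouping)
print(query)
# select * from purchase where true | all(group(customer) each(output(sum(price))))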