Skip to content
This repository has been archived by the owner on Sep 11, 2024. It is now read-only.

Commit

Permalink
Merge branch 'main' into bugfix/numpy-dep
Browse files Browse the repository at this point in the history
Merge commit.
  • Loading branch information
Mark committed Mar 26, 2024
2 parents 9347a71 + 32beab1 commit a11c84f
Show file tree
Hide file tree
Showing 9 changed files with 398 additions and 225 deletions.
38 changes: 38 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,44 @@ request = SearchParameters(
)
```

### Search within families or documents

A subset of families or documents can be retrieved for search using their ids:
```python
request = SearchParameters(
query_string="forest fires",
family_ids=["CCLW.family.10121.0", "CCLW.family.4980.0"],
)
```

```python
request = SearchParameters(
query_string="forest fires",
document_ids=["CCLW.executive.10121.4637", "CCLW.legislative.4980.1745"],
)
```

### Types of query
The default search approach uses a nearest neighbour search ranking.

It's also possible to search for exact matches instead:

```python
request = SearchParameters(
query_string="forest fires",
exact_match=True,
)
```

Or to ignore the query string and search the whole database instead:
```python
request = SearchParameters(
year_range=(2020, 2024),
sort_by="date",
sort_order="descending",
)
```

### Continuing results

The response objects include continuation tokens, which can be used to get more results.
Expand Down
47 changes: 38 additions & 9 deletions src/cpr_data_access/models/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,17 @@
import re
from typing import List, Optional, Sequence

from pydantic import BaseModel, ConfigDict, field_validator
from pydantic import (
BaseModel,
computed_field,
ConfigDict,
field_validator,
model_validator,
)

from cpr_data_access.exceptions import QueryError

sort_orders = ["ascending", "descending"]
sort_orders = {"ascending": "+", "descending": "-"}

sort_fields = {
"date": "family_publication_ts",
Expand Down Expand Up @@ -53,8 +59,9 @@ def sanitise_filter_inputs(cls, field):
class SearchParameters(BaseModel):
"""Parameters for a search request"""

query_string: str
query_string: Optional[str] = None
exact_match: bool = False
all_results: bool = False
limit: int = 100
max_hits_per_family: int = 10

Expand All @@ -69,6 +76,13 @@ class SearchParameters(BaseModel):

continuation_tokens: Optional[Sequence[str]] = None

@model_validator(mode="after")
def validate(self):
    """Reject requests that enable both `exact_match` and `all_results`.

    Returns the model unchanged when at most one of the two flags is set;
    raises `QueryError` when both are set.
    """
    # NOTE(review): this method name presumably shadows pydantic's own
    # `BaseModel.validate` -- confirm whether a rename is wanted.
    # NOTE(review): `query_string_must_not_be_empty` can switch `all_results`
    # on afterwards; verify validator ordering still upholds this check.
    conflicting_modes = self.exact_match and self.all_results
    if not conflicting_modes:
        return self
    raise QueryError("`exact_match` and `all_results` are mutually exclusive")

@field_validator("continuation_tokens")
def continuation_tokens_must_be_upper_strings(cls, continuation_tokens):
"""Validate continuation_tokens match the expected format"""
Expand All @@ -86,12 +100,12 @@ def continuation_tokens_must_be_upper_strings(cls, continuation_tokens):
)
return continuation_tokens

@model_validator(mode="after")
def query_string_must_not_be_empty(self):
    """Fall back to an all-results search when no query string is given.

    Despite its name, this validator does not reject anything: when
    `query_string` is `None` or empty it sets `all_results` to True so the
    request is treated as a search over everything. The model is always
    returned.
    """
    if not self.query_string:
        # No text to search for, so switch the request to all-results mode.
        self.all_results = True
    return self

@field_validator("family_ids", "document_ids")
def ids_must_fit_pattern(cls, ids):
Expand Down Expand Up @@ -136,13 +150,26 @@ def sort_by_must_be_valid(cls, sort_by):
@field_validator("sort_order")
def sort_order_must_be_valid(cls, sort_order):
    """Validate that the sort order is one of the keys of `sort_orders`.

    Returns the value unchanged when it is valid; raises `QueryError`
    naming the accepted values otherwise.
    """
    if sort_order not in sort_orders:
        # `sort_orders` is a mapping of name -> symbol; show only the names
        # so the message lists what callers may pass, not internal symbols.
        valid_sort_orders = list(sort_orders)
        raise QueryError(
            f"Invalid sort order: {sort_order}. sort_order must be one of: "
            f"{valid_sort_orders}"
        )
    return sort_order

@computed_field
def vespa_sort_by(self) -> Optional[str]:
    """Map `sort_by` to its entry in `sort_fields` for use in a vespa query.

    Returns `None` when `sort_by` is unset or has no entry in `sort_fields`.
    """
    requested = self.sort_by
    return sort_fields.get(requested) if requested else None

@computed_field
def vespa_sort_order(self) -> Optional[str]:
    """Map `sort_order` to its symbol in `sort_orders` for use in a vespa query.

    Returns `None` when `sort_order` has no entry in `sort_orders`.
    """
    requested = self.sort_order
    return sort_orders.get(requested)


class Hit(BaseModel):
"""Common model for all search result hits."""
Expand Down Expand Up @@ -277,6 +304,7 @@ class Family(BaseModel):
hits: Sequence[Hit]
total_passage_hits: int = 0
continuation_token: Optional[str] = None
prev_continuation_token: Optional[str] = None


class SearchResponse(BaseModel):
Expand All @@ -289,3 +317,4 @@ class SearchResponse(BaseModel):
families: Sequence[Family]
continuation_token: Optional[str] = None
this_continuation_token: Optional[str] = None
prev_continuation_token: Optional[str] = None
20 changes: 9 additions & 11 deletions src/cpr_data_access/vespa.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
Hit,
SearchParameters,
SearchResponse,
sort_fields,
)
from cpr_data_access.embedding import Embedder
from cpr_data_access.exceptions import FetchError
Expand Down Expand Up @@ -83,7 +82,9 @@ def build_vespa_request_body(
"query_string": parameters.query_string,
}

if parameters.exact_match:
if parameters.all_results:
pass
elif parameters.exact_match:
vespa_request_body["ranking.profile"] = "exact"
elif sensitive:
vespa_request_body["ranking.profile"] = "hybrid_no_closeness"
Expand Down Expand Up @@ -122,6 +123,7 @@ def parse_vespa_response(
total_passage_hits = dig(family, "fields", "count()")
family_hits: List[Hit] = []
passages_continuation = dig(family, "children", 0, "continuation", "next")
prev_passages_continuation = dig(family, "children", 0, "continuation", "prev")
for hit in dig(family, "children", 0, "children", default=[]):
family_hits.append(Hit.from_vespa_response(response_hit=hit))
families.append(
Expand All @@ -130,21 +132,16 @@ def parse_vespa_response(
hits=family_hits,
total_passage_hits=total_passage_hits,
continuation_token=passages_continuation,
prev_continuation_token=prev_passages_continuation,
)
)

# For now, we can't sort our results natively in vespa because sort orders are
# applied _before_ grouping. We're sorting here instead.
if request.sort_by is not None:
sort_field = sort_fields[request.sort_by]
families.sort(
key=lambda f: getattr(f.hits[0], sort_field),
reverse=request.sort_order == "descending",
)

next_family_continuation = dig(
root, "children", 0, "children", 0, "continuation", "next"
)
prev_family_continuation = dig(
root, "children", 0, "children", 0, "continuation", "prev"
)
this_family_continuation = dig(root, "children", 0, "continuation", "this")
total_hits = dig(root, "fields", "totalCount", default=0)
total_family_hits = dig(root, "children", 0, "fields", "count()", default=0)
Expand All @@ -154,6 +151,7 @@ def parse_vespa_response(
families=families,
continuation_token=next_family_continuation,
this_continuation_token=this_family_continuation,
prev_continuation_token=prev_family_continuation,
query_time_ms=None,
total_time_ms=None,
)
Expand Down
13 changes: 13 additions & 0 deletions src/cpr_data_access/yql_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class YQLBuilder:
group(family_import_id)
output(count())
max($LIMIT)
$SORT
each(
output(count())
max($MAX_HITS_PER_FAMILY)
Expand All @@ -37,6 +38,8 @@ def __init__(self, params: SearchParameters, sensitive: bool = False) -> None:

def build_search_term(self) -> str:
"""Create the part of the query that matches a users search text"""
if self.params.all_results:
return "( true )"
if self.params.exact_match:
return """
(
Expand Down Expand Up @@ -141,6 +144,15 @@ def build_limit(self) -> int:
"""Create the part of the query limiting the number of families returned"""
return self.params.limit

def build_sort(self) -> str:
    """Build the `order(...)` clause substituted for `$SORT` in the query.

    Returns an empty string unless both a sort field and a sort order are
    available on the search parameters.
    """
    field = self.params.vespa_sort_by
    direction = self.params.vespa_sort_order

    if field and direction:
        return f"order({direction}max({field}))"
    return ""

def build_max_hits_per_family(self) -> int:
    """Return the cap on passages returned within each family.

    The value is read straight from the search parameters.
    """
    per_family_cap = self.params.max_hits_per_family
    return per_family_cap
Expand All @@ -151,6 +163,7 @@ def to_str(self) -> str:
WHERE_CLAUSE=self.build_where_clause(),
CONTINUATION=self.build_continuation(),
LIMIT=self.build_limit(),
SORT=self.build_sort(),
MAX_HITS_PER_FAMILY=self.build_max_hits_per_family(),
)
return " ".join(yql.split())
Expand Down
28 changes: 19 additions & 9 deletions tests/test_data/search_responses/search_response.json
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@
"relevance": 1.0,
"label": "family_import_id",
"continuation": {
"next": "BGAAABECBEBC"
"next": "BGAAABECBEBC",
"prev": "BGAAAAAABEBC"

},
"children": [
{
Expand All @@ -45,7 +47,8 @@
"relevance": 1.0,
"label": "hits",
"continuation": {
"next": "BKAAAAABGCBEBC"
"next": "BKAAAAABGCBEBC",
"prev": "BKAAAAAAAABEBC"
},
"children": [
{
Expand Down Expand Up @@ -680,7 +683,8 @@
"relevance": 1.0,
"label": "hits",
"continuation": {
"next": "BKAAABEABGCBEBC"
"next": "BKAAABEABGCBEBC",
"prev": "BKAAABEAAACBEBC"
},
"children": [
{
Expand Down Expand Up @@ -1499,7 +1503,8 @@
"relevance": 1.0,
"label": "hits",
"continuation": {
"next": "BKAAABIABGCBEBC"
"next": "BKAAABIABGCBEBC",
"prev": "BKAAABIAAAABEBC"
},
"children": [
{
Expand Down Expand Up @@ -1979,7 +1984,8 @@
"relevance": 1.0,
"label": "hits",
"continuation": {
"next": "BKAAABKABGCBEBC"
"next": "BKAAABKABGCBEBC",
"prev": "BKAAABKAAACBEBC"
},
"children": [
{
Expand Down Expand Up @@ -2459,7 +2465,8 @@
"relevance": 1.0,
"label": "hits",
"continuation": {
"next": "BKAAABMABGCBEBC"
"next": "BKAAABMABGCBEBC",
"prev": "BKAAABMABGCAAAA"
},
"children": [
{
Expand Down Expand Up @@ -2939,7 +2946,8 @@
"relevance": 1.0,
"label": "hits",
"continuation": {
"next": "BKAAABOABGCBEBC"
"next": "BKAAABOABGCBEBC",
"prev": "BKAAAAAABGCBEBC"
},
"children": [
{
Expand Down Expand Up @@ -3419,7 +3427,8 @@
"relevance": 1.0,
"label": "hits",
"continuation": {
"next": "BKAAACBAABGCBEBC"
"next": "BKAAACBAABGCBEBC",
"prev": "BKAAACBAABGAAEBC"
},
"children": [
{
Expand Down Expand Up @@ -3899,7 +3908,8 @@
"relevance": 1.0,
"label": "hits",
"continuation": {
"next": "BKAAACBCABGCBEBC"
"next": "BKAAACBCABGCBEBC",
"prev": "BKAAACBCABGCBEAA"
},
"children": [
{
Expand Down
Loading

0 comments on commit a11c84f

Please sign in to comment.