Skip to content
This repository has been archived by the owner on Sep 11, 2024. It is now read-only.

Commit

Permalink
Build vespa request body in dedicated function
Browse files Browse the repository at this point in the history
Separating responsibility for testability
  • Loading branch information
olaughter committed Mar 6, 2024
1 parent 76219e7 commit f452819
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 29 deletions.
30 changes: 3 additions & 27 deletions src/cpr_data_access/search_adaptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
import time
from abc import ABC
from pathlib import Path
from typing import Any, Optional
from typing import Optional
import logging

from cpr_data_access.embedding import Embedder
from cpr_data_access.exceptions import DocumentNotFoundError, FetchError, QueryError
from cpr_data_access.models.search import Hit, SearchParameters, SearchResponse
from cpr_data_access.utils import is_sensitive_query, load_sensitive_query_terms
from cpr_data_access.yql_builder import YQLBuilder
from cpr_data_access.vespa import (
build_vespa_request_body,
find_vespa_cert_paths,
parse_vespa_response,
split_document_id,
Expand All @@ -21,7 +20,6 @@
from vespa.exceptions import VespaError

LOGGER = logging.getLogger(__name__)
SENSITIVE_QUERY_TERMS = load_sensitive_query_terms()


class SearchAdapter(ABC):
Expand Down Expand Up @@ -82,29 +80,7 @@ def search(self, parameters: SearchParameters) -> SearchResponse:
:return SearchResponse: a list of families, with response metadata
"""
total_time_start = time.time()
sensitive = is_sensitive_query(parameters.query_string, SENSITIVE_QUERY_TERMS)

yql = YQLBuilder(params=parameters, sensitive=sensitive).to_str()
vespa_request_body: dict[str, Any] = {
"yql": yql,
"timeout": "20",
"ranking.softtimeout.factor": "0.7",
}

if parameters.exact_match:
vespa_request_body["ranking.profile"] = "exact"
elif sensitive:
vespa_request_body["ranking.profile"] = "hybrid_no_closeness"
embedding = self.embedder.embed(
parameters.query_string, normalize=False, show_progress_bar=False
)
else:
vespa_request_body["ranking.profile"] = "hybrid"
embedding = self.embedder.embed(
parameters.query_string, normalize=False, show_progress_bar=False
)
vespa_request_body["input.query(query_embedding)"] = embedding

vespa_request_body = build_vespa_request_body(parameters, self.embedder)
query_time_start = time.time()
try:
vespa_response = self.client.query(body=vespa_request_body)
Expand Down
37 changes: 36 additions & 1 deletion src/cpr_data_access/vespa.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import List
from typing import Any, List

import yaml
from vespa.io import VespaResponse
Expand All @@ -12,7 +12,13 @@
SearchResponse,
sort_fields,
)
from cpr_data_access.embedding import Embedder
from cpr_data_access.exceptions import FetchError
from cpr_data_access.utils import is_sensitive_query, load_sensitive_query_terms
from cpr_data_access.yql_builder import YQLBuilder


# Loaded once, when this module is imported; build_vespa_request_body checks
# each query string against these terms.
SENSITIVE_QUERY_TERMS = load_sensitive_query_terms()


def split_document_id(document_id: str) -> tuple[str, str, str]:
Expand Down Expand Up @@ -63,6 +69,35 @@ def find_vespa_cert_paths() -> tuple[Path, Path]:
return cert_path, key_path


def build_vespa_request_body(
    parameters: SearchParameters, embedder: Embedder
) -> dict[str, Any]:
    """
    Construct the payload for a vespa query.

    The ranking profile is chosen from the search parameters:

    - ``exact`` when ``parameters.exact_match`` is set
    - ``hybrid_no_closeness`` when the query string is flagged as sensitive
    - ``hybrid`` otherwise; this is the only profile for which the query
      embedding is added to the payload

    :param SearchParameters parameters: the search being requested
    :param Embedder embedder: embeds the query string for hybrid searches
    :return dict[str, Any]: the request body to send to vespa
    """
    sensitive = is_sensitive_query(parameters.query_string, SENSITIVE_QUERY_TERMS)

    vespa_request_body: dict[str, Any] = {
        "yql": YQLBuilder(params=parameters, sensitive=sensitive).to_str(),
        "timeout": "20",
        "ranking.softtimeout.factor": "0.7",
    }

    if parameters.exact_match:
        vespa_request_body["ranking.profile"] = "exact"
    elif sensitive:
        # No embedding is attached for this profile, so none is computed.
        vespa_request_body["ranking.profile"] = "hybrid_no_closeness"
    else:
        vespa_request_body["ranking.profile"] = "hybrid"
        vespa_request_body["input.query(query_embedding)"] = embedder.embed(
            parameters.query_string, normalize=False, show_progress_bar=False
        )
    return vespa_request_body


def parse_vespa_response(
request: SearchParameters,
vespa_response: VespaResponse,
Expand Down
24 changes: 23 additions & 1 deletion tests/test_search_requests.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,33 @@
from unittest.mock import patch

import pytest

from vespa.exceptions import VespaError

from cpr_data_access.models.search import SearchParameters
from cpr_data_access.vespa import VespaErrorDetails
from cpr_data_access.vespa import build_vespa_request_body, VespaErrorDetails
from cpr_data_access.yql_builder import YQLBuilder, sanitize
from cpr_data_access.exceptions import QueryError
from cpr_data_access.embedding import Embedder


@patch("cpr_data_access.vespa.SENSITIVE_QUERY_TERMS", {"sensitive"})
@pytest.mark.parametrize(
    "query_type, params",
    [
        ("hybrid", SearchParameters(query_string="test")),
        ("exact", SearchParameters(query_string="test", exact_match=True)),
        ("hybrid_no_closeness", SearchParameters(query_string="sensitive")),
    ],
)
def test_build_vespa_request_body(query_type, params):
    """Check the ranking profile picked for each kind of search, and that no value is empty"""
    request_body = build_vespa_request_body(parameters=params, embedder=Embedder())

    assert request_body["ranking.profile"] == query_type

    for key, value in request_body.items():
        is_populated = len(value) > 0
        assert (
            is_populated
        ), f"Query type: {query_type} has an empty value for {key}: {value}"


def test_whether_an_empty_query_string_raises_a_queryerror():
Expand Down

0 comments on commit f452819

Please sign in to comment.