diff --git a/src/cpr_data_access/search_adaptors.py b/src/cpr_data_access/search_adaptors.py index 731c497..f298700 100644 --- a/src/cpr_data_access/search_adaptors.py +++ b/src/cpr_data_access/search_adaptors.py @@ -3,18 +3,23 @@ from abc import ABC from pathlib import Path from typing import Any, Optional +import logging from cpr_data_access.embedding import Embedder -from cpr_data_access.exceptions import DocumentNotFoundError, FetchError +from cpr_data_access.exceptions import DocumentNotFoundError, FetchError, QueryError from cpr_data_access.models.search import Hit, SearchParameters, SearchResponse from cpr_data_access.vespa import ( build_yql, find_vespa_cert_paths, parse_vespa_response, split_document_id, + VespaErrorDetails, ) from requests.exceptions import HTTPError from vespa.application import Vespa +from vespa.exceptions import VespaError + +LOGGER = logging.getLogger(__name__) class SearchAdapter(ABC): @@ -91,7 +96,15 @@ def search(self, parameters: SearchParameters) -> SearchResponse: vespa_request_body["input.query(query_embedding)"] = embedding query_time_start = time.time() - vespa_response = self.client.query(body=vespa_request_body) + try: + vespa_response = self.client.query(body=vespa_request_body) + except VespaError as e: + err_details = VespaErrorDetails(e) + if err_details.is_invalid_query_parameter: + LOGGER.error(err_details.message) + raise QueryError(err_details.summary) + else: + raise e query_time_end = time.time() response = parse_vespa_response( diff --git a/src/cpr_data_access/vespa.py b/src/cpr_data_access/vespa.py index b5dec8c..091d7b3 100644 --- a/src/cpr_data_access/vespa.py +++ b/src/cpr_data_access/vespa.py @@ -3,6 +3,7 @@ import yaml from vespa.io import VespaResponse +from vespa.exceptions import VespaError from cpr_data_access.models.search import ( Family, @@ -65,7 +66,14 @@ def find_vespa_cert_paths() -> tuple[Path, Path]: def sanitize(user_input: str) -> str: """ - Sanitize user input strings to limit possible YQL 
class VespaErrorDetails:
    """
    Wrapper for VespaError that parses the arguments

    The wrapped exception is expected to carry, in its ``args``, a list of error
    dictionaries with ``code``, ``summary`` and ``message`` keys. Only the first
    such dictionary is read. ``code``, ``summary`` and ``message`` stay ``None``
    when the exception carries no error dictionary at all.
    """

    # Numeric code vespa uses for INVALID_QUERY_PARAMETER, see:
    # https://github.com/vespa-engine/vespa/blob/0c55dc92a3bf889c67fac1ca855e6e33e1994904/
    # container-core/src/main/java/com/yahoo/container/protect/Error.java
    INVALID_QUERY_PARAMETER = 4

    def __init__(self, e: "VespaError") -> None:
        """
        Store the error and extract the details of its first error entry

        Args:
            e (VespaError): An error from the vespa python sdk
        """
        self.e = e
        self.code = None
        self.summary = None
        self.message = None
        self.parse_args(self.e)

    def parse_args(self, e: "VespaError") -> None:
        """
        Gets the details of the first error

        Arguments that are not a list/tuple, and entries that are not
        dictionaries, are skipped rather than raising. This method is called
        while another exception is being handled, so failing here would hide
        the original vespa error.

        Args:
            e (VespaError): An error from the vespa python sdk
        """
        for arg in e.args:
            if not isinstance(arg, (list, tuple)):
                continue
            for error in arg:
                if not isinstance(error, dict):
                    continue
                self.code = error.get("code")
                self.summary = error.get("summary")
                self.message = error.get("message")
                # Stop at the very first error entry; a plain ``break`` would
                # only leave the inner loop and let later args overwrite it.
                return

    @property
    def is_invalid_query_parameter(self) -> bool:
        """
        Checks if an error is coming from vespa on query parameters, see:

        https://github.com/vespa-engine/vespa/blob/0c55dc92a3bf889c67fac1ca855e6e33e1994904/
        container-core/src/main/java/com/yahoo/container/protect/Error.java
        """
        return self.code == self.INVALID_QUERY_PARAMETER
def test_vespa_error_details():
    """VespaErrorDetails exposes the first error's fields and flags code 4."""
    # Payload carrying the invalid query parameter code
    invalid_param_payload = [
        {
            "code": 4,
            "summary": "test_summary",
            "message": "test_message",
            "stackTrace": None,
        }
    ]
    parsed = VespaErrorDetails(VespaError(invalid_param_payload))
    first_error = invalid_param_payload[0]

    assert parsed.code == first_error["code"]
    assert parsed.summary == first_error["summary"]
    assert parsed.message == first_error["message"]
    assert parsed.is_invalid_query_parameter

    # Any other code must not be reported as an invalid query parameter
    other_payload = [{"code": 1}]
    parsed_other = VespaErrorDetails(VespaError(other_payload))
    assert not parsed_other.is_invalid_query_parameter