Skip to content

Commit

Permalink
Merge pull request #217 from vespa-engine/tgm/default-query-model
Browse files Browse the repository at this point in the history
Allow default query model to be specified and define it for TextSearch
  • Loading branch information
lesters authored Oct 14, 2021
2 parents 90bbeff + adec27c commit 58eaad0
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 5 deletions.
9 changes: 9 additions & 0 deletions vespa/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,8 @@ def _build_query_body(
**kwargs
) -> Dict:
assert query is not None, "No 'query' specified."
if not query_model:
query_model = self.get_default_query_model()
assert query_model is not None, "No 'query_model' specified."
body = query_model.create_body(query=query)
if recall is not None:
Expand Down Expand Up @@ -852,6 +854,13 @@ def application_package(self):
else:
return self._application_package

def get_default_query_model(self):
try:
app_package = self.application_package
except ValueError:
return None
return app_package.default_query_model

def get_model_from_application_package(self, model_name: str):
"""Get model definition from application package, if available."""
app_package = self.application_package
Expand Down
12 changes: 10 additions & 2 deletions vespa/gallery.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
QueryProfileType,
QueryTypeField,
)
from vespa.query import QueryModel, AND, RankProfile as Ranking


class TextSearch(ApplicationPackage):
Expand Down Expand Up @@ -48,11 +49,18 @@ def __init__(
first_phase=" + ".join(["bm25({})".format(x) for x in text_fields]),
),
RankProfile(
name="native_rank", first_phase="nativeRank({})".format(",".join(text_fields))
name="native_rank",
first_phase="nativeRank({})".format(",".join(text_fields)),
),
],
)
super().__init__(name=name, schema=[schema])
super().__init__(
name=name,
schema=[schema],
default_query_model=QueryModel(
name="and_bm25", match_phase=AND(), rank_profile=Ranking(name="bm25")
),
)


class QuestionAnswering(ApplicationPackage):
Expand Down
4 changes: 4 additions & 0 deletions vespa/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from jinja2 import Environment, PackageLoader, select_autoescape

from vespa.json_serialization import ToJson, FromJson
from vespa.query import QueryModel


class HNSW(ToJson, FromJson["HNSW"]):
Expand Down Expand Up @@ -1153,6 +1154,7 @@ def __init__(
create_schema_by_default: bool = True,
create_query_profile_by_default: bool = True,
tasks: Optional[List[Task]] = None,
default_query_model: Optional[QueryModel] = None
) -> None:
"""
Create a Vespa Application Package.
Expand All @@ -1173,6 +1175,7 @@ def __init__(
:param create_query_profile_by_default: Include a default :class:`QueryProfile` and :class:`QueryProfileType`
in case it is not explicitly defined by the user in the `query_profile` and `query_profile_type` parameters.
:param tasks: List of tasks to be served.
:param default_query_model: Optional QueryModel to be used as default for the application.
The easiest way to get started is to create a default application package:
Expand Down Expand Up @@ -1200,6 +1203,7 @@ def __init__(
self.model_configs = {}
self.stateless_model_evaluation = stateless_model_evaluation
self.models = {} if not tasks else {model.model_id: model for model in tasks}
self.default_query_model = default_query_model

@property
def schemas(self) -> List[Schema]:
Expand Down
18 changes: 15 additions & 3 deletions vespa/test_integration_docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,12 @@ def execute_data_operations(
),
},
)
#
# Query with 'query' without QueryModel
#
with self.assertRaisesRegex(AssertionError, "No 'query_model' specified."):
_ = app.query(query="this should not work")

#
# Update data
#
Expand Down Expand Up @@ -1394,10 +1400,16 @@ def setUp(self) -> None:
#
self.app.feed_df(df)

def test_default_query_model(self):
result = self.app.query(query="what is finance?", debug_request=True)
expected_request_body = {
"yql": 'select * from sources * where (userInput("what is finance?"));',
"ranking": {"profile": "bm25", "listFeatures": "false"},
}
self.assertDictEqual(expected_request_body, result.request_body)

def test_query(self):
result = self.app.query(
query="what is finance?", query_model=QueryModel(match_phase=OR())
)
result = self.app.query(query="what is finance?")
for hit in result.hits:
self.assertIn("fields", hit)

Expand Down

0 comments on commit 58eaad0

Please sign in to comment.