Skip to content

Commit

Permalink
Merge pull request #342 from vespa-engine/tgm/remove-forced-ml-install
Browse files Browse the repository at this point in the history
Remove ml module dependency from application module
  • Loading branch information
thigm85 authored Jun 2, 2022
2 parents 3caa7b9 + 015feaf commit 1939be7
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
3 changes: 1 addition & 2 deletions vespa/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from time import sleep

from vespa.io import VespaQueryResponse, VespaResponse
from vespa.ml import TextTask
from vespa.query import QueryModel
from vespa.evaluation import EvalMetric
from vespa.package import ApplicationPackage
Expand Down Expand Up @@ -1166,7 +1165,7 @@ def predict(self, x, model_id, function_name="output_0"):
model = self.get_model_from_application_package(model_id)
encoded_tokens = model.create_url_encoded_tokens(x=x)
with VespaSync(self) as sync_app:
return TextTask.parse_vespa_prediction(
return model.parse_vespa_prediction(
sync_app.predict(
model_id=model_id,
function_name=function_name,
Expand Down
4 changes: 2 additions & 2 deletions vespa/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def export_to_onnx(self, output_path: str) -> None:
pipeline, opset=11, output=Path(output_path), use_external_format=False
)

def predict(self, text: str):
def predict(self, text: str) -> List:
"""
Predict using a local instance of the model
Expand All @@ -102,7 +102,7 @@ def predict(self, text: str):
return [x["score"] for x in predictions]

@staticmethod
def parse_vespa_prediction(prediction):
def parse_vespa_prediction(prediction) -> List:
return [cell["value"] for cell in prediction["cells"]]

def create_url_encoded_tokens(self, x):
Expand Down

0 comments on commit 1939be7

Please sign in to comment.