diff --git a/vespa/application.py b/vespa/application.py
index 5ec5fa5f..04378ca5 100644
--- a/vespa/application.py
+++ b/vespa/application.py
@@ -19,7 +19,6 @@
 from time import sleep
 
 from vespa.io import VespaQueryResponse, VespaResponse
-from vespa.ml import TextTask
 from vespa.query import QueryModel
 from vespa.evaluation import EvalMetric
 from vespa.package import ApplicationPackage
@@ -1166,7 +1165,7 @@ def predict(self, x, model_id, function_name="output_0"):
         model = self.get_model_from_application_package(model_id)
         encoded_tokens = model.create_url_encoded_tokens(x=x)
         with VespaSync(self) as sync_app:
-            return TextTask.parse_vespa_prediction(
+            return model.parse_vespa_prediction(
                 sync_app.predict(
                     model_id=model_id,
                     function_name=function_name,
diff --git a/vespa/ml.py b/vespa/ml.py
index 28a908fe..abdb8afa 100644
--- a/vespa/ml.py
+++ b/vespa/ml.py
@@ -90,7 +90,7 @@ def export_to_onnx(self, output_path: str) -> None:
             pipeline, opset=11, output=Path(output_path), use_external_format=False
         )
 
-    def predict(self, text: str):
+    def predict(self, text: str) -> List:
         """
         Predict using a local instance of the model
 
@@ -102,7 +102,7 @@ def predict(self, text: str):
         return [x["score"] for x in predictions]
 
     @staticmethod
-    def parse_vespa_prediction(prediction):
+    def parse_vespa_prediction(prediction) -> List:
         return [cell["value"] for cell in prediction["cells"]]
 
     def create_url_encoded_tokens(self, x):
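
A minimal sketch of the call pattern after this change: Vespa.predict no longer hard-codes the TextTask.parse_vespa_prediction static method, it asks whatever model object get_model_from_application_package returns to parse its own prediction tensor. The DummyTextModel class and the sample response payload below are assumptions made for illustration, not part of this diff; only parse_vespa_prediction mirrors the real code.

    from typing import List

    # Hypothetical stand-in for a model object returned by
    # get_model_from_application_package(); only the two methods
    # touched by this diff are sketched.
    class DummyTextModel:
        def create_url_encoded_tokens(self, x: str) -> str:
            # Real models URL-encode tokenized input here; this stub echoes it.
            return x

        @staticmethod
        def parse_vespa_prediction(prediction) -> List:
            # Same body as in vespa/ml.py: flatten the Vespa tensor cells
            # into a plain list of values.
            return [cell["value"] for cell in prediction["cells"]]

    # Example Vespa-style prediction payload (shape assumed for illustration).
    raw_prediction = {
        "cells": [
            {"address": {"d0": "0"}, "value": 0.12},
            {"address": {"d0": "1"}, "value": 0.88},
        ]
    }

    model = DummyTextModel()
    scores = model.parse_vespa_prediction(raw_prediction)
    print(scores)  # [0.12, 0.88]

Because parsing is dispatched through the model instance, a future model class can override parse_vespa_prediction with its own output format without touching vespa/application.py.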