diff --git a/vespa/ml.py b/vespa/ml.py index 28a908fe..abdb8afa 100644 --- a/vespa/ml.py +++ b/vespa/ml.py @@ -90,7 +90,7 @@ def export_to_onnx(self, output_path: str) -> None: pipeline, opset=11, output=Path(output_path), use_external_format=False ) - def predict(self, text: str): + def predict(self, text: str) -> List: """ Predict using a local instance of the model @@ -102,7 +102,7 @@ def predict(self, text: str): return [x["score"] for x in predictions] @staticmethod - def parse_vespa_prediction(prediction): + def parse_vespa_prediction(prediction) -> List: return [cell["value"] for cell in prediction["cells"]] def create_url_encoded_tokens(self, x):