diff --git a/learntorank/test_ml.py b/learntorank/test_ml.py index 61557f0..977362c 100644 --- a/learntorank/test_ml.py +++ b/learntorank/test_ml.py @@ -111,7 +111,7 @@ def _predict_with_onnx(onnx_file_path, model_inputs): os.environ[ "KMP_DUPLICATE_LIB_OK" ] = "True" # required to run on mac https://stackoverflow.com/a/53014308 - m = InferenceSession(onnx_file_path) + m = InferenceSession(onnx_file_path, providers=['CPUExecutionProvider']) (out,) = m.run(input_feed=model_inputs, output_names=["output_0"]) return out