From dfe80d605ff73d5c29d3b6e1af22e23045f75c70 Mon Sep 17 00:00:00 2001 From: kicha0 Date: Fri, 25 Aug 2023 11:52:52 -0700 Subject: [PATCH] Raise error on prediction failure. --- .../managers/error_analysis_manager.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/responsibleai_text/responsibleai_text/managers/error_analysis_manager.py b/responsibleai_text/responsibleai_text/managers/error_analysis_manager.py index 2c1fdb824b..21b7c27543 100644 --- a/responsibleai_text/responsibleai_text/managers/error_analysis_manager.py +++ b/responsibleai_text/responsibleai_text/managers/error_analysis_manager.py @@ -15,6 +15,7 @@ from erroranalysis._internal.error_report import as_error_report from responsibleai._tools.shared.state_directory_management import \ DirectoryManager +from responsibleai.exceptions import UserErrorException from responsibleai.managers.error_analysis_manager import \ ErrorAnalysisManager as BaseErrorAnalysisManager from responsibleai.managers.error_analysis_manager import as_error_config @@ -71,8 +72,12 @@ def __init__(self, model, dataset, is_multilabel, task_type, classes=None): ModelTask.MULTILABEL_TEXT_CLASSIFICATION] if self.task_type in classif_tasks: dataset = self.dataset.iloc[:, 0].tolist() - self.predictions = self.model.predict(dataset) - self.predict_proba = self.model.predict_proba(dataset) + self.predictions = self._raise_user_error_on_failure( + self.model.predict, dataset + ) + self.predict_proba = self._raise_user_error_on_failure( + self.model.predict_proba, dataset + ) elif self.task_type == ModelTask.QUESTION_ANSWERING: self.predictions = self.model.predict( self.dataset.loc[:, ['context', 'questions']]) @@ -121,6 +126,15 @@ def predict_proba(self, X): pred_proba = self.predict_proba[index] return pred_proba + def _raise_user_error_on_prediction_failure(func, *args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as ex: + raise UserErrorException( + "Unable to use user model to retrieve predictions" + f" from given dataset. Original exception: {ex}" + ) + class ErrorAnalysisManager(BaseErrorAnalysisManager):