Skip to content

Commit

Permalink
Raise error on prediction failure.
Browse files Browse the repository at this point in the history
  • Loading branch information
kicha0 authored Aug 25, 2023
1 parent bc4744e commit dfe80d6
Showing 1 changed file with 16 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from erroranalysis._internal.error_report import as_error_report
from responsibleai._tools.shared.state_directory_management import \
DirectoryManager
from responsibleai.exceptions import UserErrorException
from responsibleai.managers.error_analysis_manager import \
ErrorAnalysisManager as BaseErrorAnalysisManager
from responsibleai.managers.error_analysis_manager import as_error_config
Expand Down Expand Up @@ -71,8 +72,12 @@ def __init__(self, model, dataset, is_multilabel, task_type, classes=None):
ModelTask.MULTILABEL_TEXT_CLASSIFICATION]
if self.task_type in classif_tasks:
dataset = self.dataset.iloc[:, 0].tolist()
self.predictions = self.model.predict(dataset)
self.predict_proba = self.model.predict_proba(dataset)
self.predictions = self._raise_user_error_on_failure(
self.model.predict, dataset
)
self.predict_proba = self._raise_user_error_on_failure(
self.model.predict_proba, dataset
)
elif self.task_type == ModelTask.QUESTION_ANSWERING:
self.predictions = self.model.predict(
self.dataset.loc[:, ['context', 'questions']])
Expand Down Expand Up @@ -121,6 +126,15 @@ def predict_proba(self, X):
pred_proba = self.predict_proba[index]
return pred_proba

def _raise_user_error_on_prediction_failure(func, *args, **kwargs):
try:
return func(*args, **kwargs)
except Exception as ex:
raise UserErrorException(
"Unable to use user model to retrieve predictions"
f" from given dataset. Original exception: {ex}"
)


class ErrorAnalysisManager(BaseErrorAnalysisManager):

Expand Down

0 comments on commit dfe80d6

Please sign in to comment.