
Commit

fix infer task from model_name if model from sentence transformer
eaidova committed Jan 8, 2025
1 parent 72498dd commit 163e4a8
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions optimum/exporters/tasks.py
@@ -1770,6 +1770,7 @@ def _infer_task_from_model_name_or_path(
         revision: Optional[str] = None,
         cache_dir: str = HUGGINGFACE_HUB_CACHE,
         token: Optional[Union[bool, str]] = None,
+        library_name: Optional[str] = None,
     ) -> str:
         inferred_task_name = None
 
@@ -1791,13 +1792,14 @@
                 raise RuntimeError(
                     f"Hugging Face Hub is not reachable and we cannot infer the task from a cached model. Make sure you are not offline, or otherwise please specify the `task` (or `--task` in command-line) argument ({', '.join(TasksManager.get_all_tasks())})."
                 )
-            library_name = cls.infer_library_from_model(
-                model_name_or_path,
-                subfolder=subfolder,
-                revision=revision,
-                cache_dir=cache_dir,
-                token=token,
-            )
+            if library_name is None:
+                library_name = cls.infer_library_from_model(
+                    model_name_or_path,
+                    subfolder=subfolder,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    token=token,
+                )
 
             if library_name == "timm":
                 inferred_task_name = "image-classification"
@@ -1816,6 +1818,8 @@
                                     break
                             if inferred_task_name is not None:
                                 break
+            elif library_name == "sentence_transformers":
+                inferred_task_name = "feature-extraction"
             elif library_name == "transformers":
                 pipeline_tag = model_info.pipeline_tag
                 transformers_info = model_info.transformersInfo

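For context, here is a minimal sketch (not part of the diff) of how the patched classmethod could be exercised after this change. The model id is only an illustrative sentence-transformers repository, and calling the private helper directly is for demonstration; callers would normally go through the public TasksManager entry points, which are assumed to forward library_name.

from optimum.exporters.tasks import TasksManager

# With library_name supplied by the caller, infer_library_from_model is skipped
# and the new elif branch maps sentence_transformers to "feature-extraction".
task = TasksManager._infer_task_from_model_name_or_path(
    "sentence-transformers/all-MiniLM-L6-v2",  # illustrative model id
    library_name="sentence_transformers",
)
print(task)  # expected: "feature-extraction"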