Fix mypy errors (#277)
* Fix mypy errors

* Fix run_locally

* Fix noqa
neubig authored Aug 24, 2023
1 parent 3f86001 commit 31e1e8b
Showing 11 changed files with 32 additions and 22 deletions.
4 changes: 3 additions & 1 deletion prompt2model/dataset_processor/base.py
@@ -1,5 +1,7 @@
 """A base class for dataset processor."""
 
+from __future__ import annotations  # noqa FI58
+
 from abc import ABC, abstractmethod
 from functools import partial
 
@@ -9,7 +11,7 @@
 class BaseProcessor(ABC):
     """A base class for post-processing datasets."""
 
-    def __init__(self, has_encoder: bool, eos_token: str) -> None:
+    def __init__(self, has_encoder: bool, eos_token: str | None = None) -> None:
         """Initialize the `BaseProcessor`.
 
         Args:
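Note (not part of the commit): the mypy error here most likely came from decoder-only callers constructing a processor without an EOS token, while the parameter was typed plainly as `str`. Making it `str | None` with a `None` default admits both cases, and the `from __future__ import annotations` import added above lets the `X | Y` union syntax be used in annotations on Python versions before 3.10. A minimal sketch of the pattern, with a simplified stand-in class:

from __future__ import annotations  # allows `str | None` annotations pre-3.10


class Processor:
    # A `None` default means callers without an EOS token type-check cleanly.
    def __init__(self, has_encoder: bool, eos_token: str | None = None) -> None:
        self.has_encoder = has_encoder
        self.eos_token = eos_token


decoder_only = Processor(has_encoder=False)  # OK: eos_token defaults to None
encoder_decoder = Processor(has_encoder=True, eos_token="</s>")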
10 changes: 9 additions & 1 deletion prompt2model/dataset_processor/mock.py
@@ -23,7 +23,15 @@ def process_dataset_dict(
         _ = instruction
         return dataset_dicts
 
-    def post_process_example(example: dict, instruction: str, task_id: int) -> dict:
+    @staticmethod
+    def post_process_example(
+        example: dict,
+        instruction: str,
+        task_id: int,
+        has_encoder: bool,
+        dataset_split: str,
+        eos_token: str,
+    ) -> dict:
         """A mock function that modifies a given example dictionary.
 
         Args:
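Note (not part of the commit): widening the mock's parameter list so it mirrors the base class is presumably what keeps mypy's override check quiet, since an override whose signature diverges from the supertype is reported as incompatible. A generic sketch of that rule, with simplified stand-in classes rather than the repository's actual ones:

from abc import ABC, abstractmethod


class Base(ABC):
    @staticmethod
    @abstractmethod
    def post_process_example(example: dict, instruction: str, task_id: int) -> dict:
        """Transform a single example."""


class Mock(Base):
    @staticmethod
    def post_process_example(example: dict, instruction: str, task_id: int) -> dict:
        # Keeping the signature identical to the base avoids mypy's
        # "incompatible with supertype" error on the override.
        return example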
6 changes: 3 additions & 3 deletions prompt2model/dataset_retriever/hf_dataset_retriever.py
@@ -284,11 +284,11 @@ def retrieve_dataset_dict(
             self.dataset_infos[dataset_idx].score = dataset_score
             top_dataset_infos.append(self.dataset_infos[dataset_idx])
 
-        ranked_list = sorted(top_dataset_infos, key=lambda x: x.score, reverse=True)[
+        sorted_list = sorted(top_dataset_infos, key=lambda x: x.score, reverse=True)[
             : self.max_search_depth
         ]
-        assert len(ranked_list) > 0, "No datasets retrieved from search index."
-        top_dataset_name = self.choose_dataset(ranked_list)
+        assert len(sorted_list) > 0, "No datasets retrieved from search index."
+        top_dataset_name = self.choose_dataset(sorted_list)
         if top_dataset_name is None:
             return None
         return self.canonicalize_dataset(top_dataset_name)
4 changes: 1 addition & 3 deletions prompt2model/dataset_retriever/run_dataset_retriever.py
@@ -8,6 +8,4 @@
 prompt_spec._instruction = prompt
 
 retriever = DescriptionDatasetRetriever()
-retriever.retrieve_dataset_dict(
-    prompt_spec, blocklist=["squad", "stanford question answering"]
-)
+retriever.retrieve_dataset_dict(prompt_spec)
7 changes: 4 additions & 3 deletions prompt2model/model_evaluator/base.py
@@ -7,9 +7,9 @@
 from typing import Any
 
 import datasets
+import evaluate
 
 from prompt2model.model_executor import ModelOutput
-from prompt2model.prompt_parser import PromptSpec
 
 
 class ModelEvaluator(ABC):
@@ -21,8 +21,9 @@ def evaluate_model(
         dataset: datasets.Dataset,
         gt_column: str,
         predictions: list[ModelOutput],
-        metrics: list[datasets.Metric] | None = None,
-        prompt_spec: PromptSpec | None = None,
+        model_input_column: str | None = None,
+        metrics: list[evaluate.Metric] | None = None,
+        encoder_model_name: str = "xlm-roberta-base",
     ) -> dict[str, Any]:
         """Evaluate a model on a test set..
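Note (not part of the commit): `datasets.Metric` was deprecated in favor of the standalone `evaluate` package, which is why the annotation moves to `evaluate.Metric`. A small usage sketch, assuming the `evaluate` package is installed:

import evaluate

# evaluate.load() returns metric objects compatible with the new
# `list[evaluate.Metric]` annotation.
metrics = [evaluate.load("exact_match")]

result = metrics[0].compute(
    predictions=["paris", "berlin"],
    references=["paris", "berlin"],
)
print(result)  # {'exact_match': 1.0}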
7 changes: 4 additions & 3 deletions prompt2model/model_evaluator/mock.py
@@ -4,10 +4,10 @@
 from typing import Any
 
 import datasets
+import evaluate
 
 from prompt2model.model_evaluator.base import ModelEvaluator
 from prompt2model.model_executor import ModelOutput
-from prompt2model.prompt_parser import PromptSpec
 
 
 class MockEvaluator(ModelEvaluator):
@@ -21,8 +21,9 @@ def evaluate_model(
         dataset: datasets.Dataset,
         gt_column: str,
         predictions: list[ModelOutput],
-        metrics: list[datasets.Metric] | None = None,
-        prompt_spec: PromptSpec | None = None,
+        model_input_column: str | None = None,
+        metrics: list[evaluate.Metric] | None = None,
+        encoder_model_name: str = "xlm-roberta-base",
     ) -> dict[str, Any]:
         """Return empty metrics dictionary.
4 changes: 2 additions & 2 deletions prompt2model/model_retriever/base.py
@@ -13,12 +13,12 @@ class ModelRetriever(ABC):
     def retrieve(
         self,
         prompt: PromptSpec,
-    ) -> str:
+    ) -> list[str]:
         """Retrieve relevant models from HuggingFace.
 
         Args:
             prompt: A prompt to use to select relevant models.
 
         Return:
-            A relevant model's HuggingFace name.
+            A list of relevant models' HuggingFace names.
         """
4 changes: 2 additions & 2 deletions prompt2model/model_retriever/mock.py
@@ -14,7 +14,7 @@ def __init__(self, fixed_model_name: str):
     def retrieve(
         self,
         prompt: PromptSpec,
-    ) -> str:
+    ) -> list[str]:
         """Select an arbitrary, fixed model from HuggingFace.
 
         Args:
@@ -23,4 +23,4 @@ def retrieve(
         Return:
             A relevant model's HuggingFace name.
         """
-        return self.fixed_model_name
+        return [self.fixed_model_name]
2 changes: 1 addition & 1 deletion prompt2model/model_trainer/generate.py
@@ -83,7 +83,7 @@ def __init__(
 
         # self.validation_callback is used for evaluate the model on
         # the validation dataset after each epoch.
-        self.validation_callback = None
+        self.validation_callback: ValidationCallback | None = None
        self.training_seed = seed_generator.get_seed()
 
     def get_left_padding_length(cls, input_list, padding_token_id):
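Note (not part of the commit): when an attribute is first assigned `None` without an annotation, mypy infers its type as `None` and rejects a later assignment of a real callback object; declaring it as `ValidationCallback | None` up front resolves that. A generic sketch of the pattern, with `ValidationCallback` as a stand-in class:

from __future__ import annotations


class ValidationCallback:
    """Stand-in for the callback that evaluates after each epoch."""


class Trainer:
    def __init__(self) -> None:
        # Without the annotation, mypy would infer `None` as the attribute's
        # type and flag the assignment in train() below.
        self.validation_callback: ValidationCallback | None = None

    def train(self) -> None:
        self.validation_callback = ValidationCallback()  # now type-checks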
2 changes: 1 addition & 1 deletion prompt2model/param_selector/mock.py
@@ -46,7 +46,7 @@ def select_from_hyperparameters(
             A model and tokenizer (trained using default hyperparameters).
         """
         single_model = self.trainer.train_model(
-            training_sets, self._example_hyperparameter_choices()
+            self._example_hyperparameter_choices(), training_sets
         )
         return single_model
 
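Note (not part of the commit): the mock had been passing the training sets where the hyperparameter dictionary belongs; swapping the positional arguments restores the trainer's expected order, which mypy can check from the parameter types. Passing them by keyword avoids the ordering trap entirely; the parameter names below are assumptions for illustration, not the real train_model signature:

def train_model(hyperparameter_choices: dict, training_datasets: list) -> str:
    # Simplified stand-in: a real trainer would fit and return a model here.
    return f"trained with {hyperparameter_choices} on {len(training_datasets)} dataset(s)"


# Keyword arguments make the intended order explicit and type-checkable.
print(train_model(hyperparameter_choices={"lr": 3e-4}, training_datasets=[["example"]]))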
4 changes: 2 additions & 2 deletions prompt2model/run_locally.py
@@ -94,7 +94,7 @@ def run_skeleton(prompt_tokens: list[str], metrics_output_path: str) -> None:
     model_retriever = MockModelRetriever("cardiffnlp/twitter-roberta-base-sentiment")
     retrieved_model_name = model_retriever.retrieve(prompt_spec)
 
-    trainer = MockTrainer(retrieved_model_name)
+    trainer = MockTrainer(retrieved_model_name[0])
     selector = MockParamSelector(trainer)
     model, tokenizer = selector.select_from_hyperparameters(
         all_training, validation, {}
@@ -105,7 +105,7 @@ def run_skeleton(prompt_tokens: list[str], metrics_output_path: str) -> None:
 
     evaluator = MockEvaluator()
     metrics_dict = evaluator.evaluate_model(
-        testing, "output_col", predictions, [], prompt_spec
+        testing, "output_col", predictions, "input_col", []
     )
     evaluator.write_metrics(metrics_dict, metrics_output_path)
     mock_gradio_create(model_executor, prompt_spec)
