feat: Improve qa datamodule (#292)
* feat(qa_datamodule): Improve qa datamodule by adding cache and multithreading processing

* fix(qa): Fix logging config args

* styling: Fix styling

* fix(tests): Fix tests

* fix(qa_datamodule): Fix qa_datamodule

* refactor: Fix mypy issues

* feat(qa_preprocessing): Add multianswers support

* feat(qa): Improve QA data split transformation

* fix(transformations): Revert previous QA data split transformation

* feat(qa_evaluator): Improve QA evaluator

* feat(QA): Add fixes and improvements to QA pipelines
ktagowski authored Dec 3, 2023
1 parent f8e6c35 commit ed47c0c
Showing 14 changed files with 298 additions and 86 deletions.
84 changes: 65 additions & 19 deletions embeddings/data/qa_datamodule.py
@@ -1,5 +1,7 @@
import typing
import os
import pickle
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

import datasets
@@ -11,6 +13,10 @@

from embeddings.data.datamodule import HuggingFaceDataModule, HuggingFaceDataset
from embeddings.data.io import T_path
from embeddings.utils.loggers import get_logger
from embeddings.utils.utils import standardize_name

_logger = get_logger(__name__)


class CharToTokenMapper:
@@ -95,6 +101,8 @@ def get_token_positions_train(


class QuestionAnsweringDataModule(HuggingFaceDataModule):
CACHE_DEFAULT_DIR = Path(os.path.expanduser("~/.cache/embeddings/"))

def __init__(
self,
dataset_name_or_path: T_path,
@@ -111,7 +119,8 @@ def __init__(
load_dataset_kwargs: Optional[Dict[str, Any]] = None,
dataloader_kwargs: Optional[Dict[str, Any]] = None,
seed: int = 441,
**kwargs: Any
use_cache: bool = False,
**kwargs: Any,
) -> None:
self.question_field = question_field
self.context_field = context_field
@@ -147,15 +156,55 @@ def __init__(
self.splits: List[str] = list(self.dataset_raw.keys())
else:
self.splits = ["train", "validation"]
self.process_data(stage="fit")
self.processed_data_cache_path = None
if use_cache:
datasets.disable_caching()
self.processed_data_cache_path = (
QuestionAnsweringDataModule.CACHE_DEFAULT_DIR
/ f"{standardize_name(str(dataset_name_or_path))}__{standardize_name(str(tokenizer_name_or_path))}"
)
self.processed_data_cache_path.mkdir(parents=True, exist_ok=True)
_logger.warning(
f"Using embeddingsdatamodule caching. Cache path={self.processed_data_cache_path}"
)

@typing.overload
def process_data(self) -> None:
pass
self.process_data_with_cache(stage="fit")

def cache_datamodule(self, path: Path, stage: str) -> None:
self.dataset.save_to_disk(str(path))
if stage != "fit":
with open(path / "overflow_to_sample_mapping", "wb") as f:
pickle.dump(obj=self.overflow_to_sample_mapping, file=f)
with open(path / "offset_mapping", "wb") as f:
pickle.dump(obj=self.offset_mapping, file=f)

def load_cached_datamodule(self, path: Path, stage: str) -> None:
self.dataset = datasets.load_from_disk(dataset_path=str(path))
if stage != "fit":
with open(path / "overflow_to_sample_mapping", "rb") as f:
self.overflow_to_sample_mapping = pickle.load(f)
with open(path / "offset_mapping", "rb") as f:
self.offset_mapping = pickle.load(f)

def process_data_with_cache(self, stage: Optional[str] = None) -> None:
if stage is None:
return

@typing.overload
def process_data(self, stage: Optional[str] = None) -> None:
pass
if self.processed_data_cache_path:
data_cache_path = self.processed_data_cache_path / stage
if data_cache_path.exists():
_logger.warning(f"Loading cached datamodule from path {data_cache_path}")
self.load_cached_datamodule(data_cache_path, stage=stage)
_logger.warning("Load completed!")
else:
_logger.warning(
f"Cached datamodule not found. Processing datamodule {data_cache_path}"
)
self.process_data(stage=stage)
_logger.warning(f"Saving cached datamodule at path {data_cache_path}")
self.cache_datamodule(data_cache_path, stage=stage)
else:
self.process_data(stage=stage)

def process_data(self, stage: Optional[str] = None) -> None:
assert isinstance(self.dataset_raw, datasets.DatasetDict)
@@ -165,9 +214,6 @@ def process_data(self, stage: Optional[str] = None) -> None:
{k: v for k, v in self.dataset.items() if k in {"train", "validation"}}
)

if stage is None:
return

columns = [c for c in self.dataset["train"].column_names if c not in self.LOADER_COLUMNS]

for split in self.dataset.keys():
@@ -185,11 +231,11 @@ def process_data(self, stage: Optional[str] = None) -> None:
batch_size=self.processing_batch_size,
remove_columns=columns,
)

self.overflow_to_sample_mapping[split] = self.dataset[split][
"overflow_to_sample_mapping"
]
self.offset_mapping[split] = self.dataset[split]["offset_mapping"]
if stage != "fit":
self.overflow_to_sample_mapping[split] = self.dataset[split][
"overflow_to_sample_mapping"
]
self.offset_mapping[split] = self.dataset[split]["offset_mapping"]
self.dataset[split] = self.dataset[split].remove_columns(
["offset_mapping", "overflow_to_sample_mapping"]
)
@@ -233,7 +279,7 @@ def setup(self, stage: Optional[str] = None) -> None:
self.dataset_raw["train"].info.version = Version("0.0.1")
self.data_loader_has_setup = True
if self.processed_data_stage and (self.processed_data_stage != stage):
self.process_data(stage=stage)
self.process_data_with_cache(stage=stage)
self.processed_data_stage = stage

@property
@@ -248,5 +294,5 @@ def _class_encode_column(self, column_name: str) -> None:

def test_dataloader(self) -> DataLoader[HuggingFaceDataset]:
if "test" in self.splits and not "test" in self.dataset.keys():
self.process_data(stage="test")
self.process_data_with_cache(stage="test")
return super().test_dataloader()
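For reference, a minimal standalone sketch of how the new per-dataset/tokenizer cache location is derived, following the constructor logic above; `standardize_name` is simplified here and the dataset/model ids are hypothetical.

```python
from pathlib import Path

CACHE_DEFAULT_DIR = Path("~/.cache/embeddings/").expanduser()


def standardize_name(name: str) -> str:
    # Simplified stand-in for embeddings.utils.utils.standardize_name:
    # turn an arbitrary dataset/model id into a filesystem-safe slug.
    return name.replace("/", "_").replace(" ", "_").lower()


def processed_data_cache_path(dataset_name_or_path: str, tokenizer_name_or_path: str) -> Path:
    # Mirrors the path construction in QuestionAnsweringDataModule.__init__ above.
    return (
        CACHE_DEFAULT_DIR
        / f"{standardize_name(dataset_name_or_path)}__{standardize_name(tokenizer_name_or_path)}"
    )


# Hypothetical dataset and tokenizer ids:
print(processed_data_cache_path("clarin-pl/poquad", "allegro/herbert-base-cased"))
# -> <home>/.cache/embeddings/clarin-pl_poquad__allegro_herbert-base-cased
```

On a second run with the same pair, `process_data_with_cache` finds this directory and loads the processed splits (plus the pickled offset mappings for non-fit stages) instead of re-tokenizing.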
10 changes: 10 additions & 0 deletions embeddings/evaluator/evaluation_results.py
@@ -110,3 +110,13 @@ class QuestionAnsweringEvaluationResults(EvaluationResults):
NoAns_f1: Optional[float] = None
NoAns_total: Optional[float] = None
data: Optional[Data] = None
golds_text: Optional[Union[List[List[str]], List[str]]] = None
predictions_text: Optional[List[str]] = None

@property
def metrics(self) -> Dict[str, Any]:
result = asdict(self)
result.pop("data")
result.pop("golds_text")
result.pop("predictions_text")
return result
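A self-contained sketch of what the new `metrics` property returns, using a reduced stand-in dataclass with only the fields visible in this hunk; the real results class carries additional metric fields.

```python
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional


@dataclass
class QAResultsSketch:
    # Simplified stand-in for QuestionAnsweringEvaluationResults.
    NoAns_f1: Optional[float] = None
    NoAns_total: Optional[float] = None
    data: Optional[Any] = None
    golds_text: Optional[List[List[str]]] = None
    predictions_text: Optional[List[str]] = None

    @property
    def metrics(self) -> Dict[str, Any]:
        # Serialize the dataclass and drop raw outputs and answer texts,
        # keeping only scalar metric values.
        result = asdict(self)
        for key in ("data", "golds_text", "predictions_text"):
            result.pop(key)
        return result


print(QAResultsSketch(NoAns_f1=80.5, NoAns_total=120, data=[{"id": 0}]).metrics)
# {'NoAns_f1': 80.5, 'NoAns_total': 120}
```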
21 changes: 19 additions & 2 deletions embeddings/evaluator/question_answering_evaluator.py
@@ -26,9 +26,22 @@ def __init__(self, no_answer_threshold: float = 1.0):

def metrics(
self,
) -> Dict[str, Metric[Union[List[Any], nptyping.NDArray[Any], torch.Tensor], Dict[Any, Any]]]:
) -> Dict[str, Metric[Union[List[Any], nptyping.NDArray[Any], torch.Tensor], Dict[Any, Any]],]:
return {}

@staticmethod
def get_golds_text(references: List[QA_GOLD_ANSWER_TYPE]) -> Union[List[List[str]], List[str]]:
golds_text = []
for ref in references:
answers = ref["answers"]
assert isinstance(answers, dict)
golds_text.append(answers["text"])
return golds_text

@staticmethod
def get_predictions_text(predictions: List[QA_PREDICTED_ANSWER_TYPE]) -> List[str]:
return [str(it["prediction_text"]) for it in predictions]

def evaluate(
self, data: Union[Dict[str, nptyping.NDArray[Any]], Predictions, Dict[str, Any]]
) -> QuestionAnsweringEvaluationResults:
@@ -51,5 +64,9 @@ def evaluate(
{"id": it_id, **it["predicted_answer"]} for it_id, it in enumerate(outputs)
]
metrics = SQUADv2Metric().calculate(predictions=predictions, references=references)
gold_texts = QuestionAnsweringEvaluator.get_golds_text(references)
predictions_text = QuestionAnsweringEvaluator.get_predictions_text(predictions)

return QuestionAnsweringEvaluationResults(data=outputs, **metrics)
return QuestionAnsweringEvaluationResults(
data=outputs, golds_text=gold_texts, predictions_text=predictions_text, **metrics
)
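An illustration of the two new static helpers on made-up SQuAD-v2-style records; the actual `QA_GOLD_ANSWER_TYPE`/`QA_PREDICTED_ANSWER_TYPE` entries may carry additional fields.

```python
# Made-up SQuAD-v2-style records for illustration only.
references = [
    {"id": 0, "answers": {"text": ["Warsaw"], "answer_start": [17]}},
    {"id": 1, "answers": {"text": [], "answer_start": []}},  # unanswerable question
]
predictions = [
    {"id": 0, "prediction_text": "Warsaw", "no_answer_probability": 0.01},
    {"id": 1, "prediction_text": "", "no_answer_probability": 0.97},
]

# Equivalent to QuestionAnsweringEvaluator.get_golds_text(references):
golds_text = [ref["answers"]["text"] for ref in references]  # [['Warsaw'], []]
# Equivalent to QuestionAnsweringEvaluator.get_predictions_text(predictions):
predictions_text = [str(p["prediction_text"]) for p in predictions]  # ['Warsaw', '']
```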
12 changes: 10 additions & 2 deletions embeddings/model/lightning_module/question_answering.py
@@ -18,9 +18,17 @@


class QuestionAnsweringInferenceModule(pl.LightningModule):
def __init__(self, model_name: str, devices: str = "auto", accelerator: str = "auto") -> None:
def __init__(
self,
model_name: str,
devices: str = "auto",
accelerator: str = "auto",
use_auth_token: bool = False,
) -> None:
super().__init__()
self.model = AutoModelForQuestionAnswering.from_pretrained(model_name)
self.model = AutoModelForQuestionAnswering.from_pretrained(
model_name, use_auth_token=use_auth_token
)
self.trainer = pl.Trainer(devices=devices, accelerator=accelerator)

def predict_step(
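A short usage sketch of the extended constructor, assuming the package is installed; the checkpoint name is hypothetical, and `use_auth_token=True` would require a configured Hugging Face token.

```python
from embeddings.model.lightning_module.question_answering import (
    QuestionAnsweringInferenceModule,
)

module = QuestionAnsweringInferenceModule(
    model_name="deepset/roberta-base-squad2",  # hypothetical public QA checkpoint
    devices="auto",
    accelerator="auto",
    use_auth_token=False,  # set to True only for gated/private checkpoints
)
```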
21 changes: 15 additions & 6 deletions embeddings/pipeline/lightning_pipeline.py
@@ -8,7 +8,7 @@
from embeddings.evaluator.evaluator import Evaluator
from embeddings.model.model import Model
from embeddings.pipeline.pipeline import Pipeline
from embeddings.utils.loggers import LightningLoggingConfig, WandbWrapper
from embeddings.utils.loggers import LightningLoggingConfig, LightningWandbWrapper
from embeddings.utils.utils import get_installed_packages, standardize_name

EvaluationResult = TypeVar("EvaluationResult")
@@ -46,25 +46,34 @@ def __init__(
self.pipeline_kwargs = pipeline_kwargs
self.pipeline_kwargs.pop("self")
self.pipeline_kwargs.pop("pipeline_kwargs")
self.result: Optional[EvaluationResult] = None

def run(self, run_name: Optional[str] = None) -> EvaluationResult:
if run_name:
run_name = standardize_name(run_name)
self._save_artifacts()
model_result = self.model.execute(data=self.datamodule, run_name=run_name)
result = self.evaluator.evaluate(model_result)
self.result = self.evaluator.evaluate(model_result)
self._save_metrics()
self._finish_logging()
return result
return self.result

def _save_artifacts(self) -> None:
srsly.write_json(self.output_path / "packages.json", get_installed_packages())
with open(self.output_path / "pipeline_config.yaml", "w") as f:
yaml.dump(self.pipeline_kwargs, stream=f)

def _save_metrics(self) -> None:
metrics = getattr(self.result, "metrics")
with open(self.output_path / "metrics.yaml", "w") as f:
yaml.dump(metrics, stream=f)

def _finish_logging(self) -> None:
if self.logging_config.use_wandb():
logger = WandbWrapper()
logger.log_output(
wrapper = LightningWandbWrapper(self.logging_config)
wrapper.log_output(
self.output_path, ignore={"wandb", "csv", "tensorboard", "checkpoints"}
)
logger.finish_logging()
metrics = getattr(self.result, "metrics")
wrapper.log_metrics(metrics)
wrapper.finish_logging()
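With `self.result` now kept on the pipeline, the metrics written by `_save_metrics` can be inspected after `run()`; a sketch of reading them back, where the output directory is hypothetical.

```python
import yaml

# Hypothetical output directory; the file name follows _save_metrics above.
with open("output/metrics.yaml") as f:
    metrics = yaml.safe_load(f)

print(metrics)  # e.g. a dict of scalar QA metrics such as NoAns_f1 and NoAns_total
```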
24 changes: 23 additions & 1 deletion embeddings/pipeline/lightning_question_answering.py
@@ -2,6 +2,8 @@
from typing import Any, Dict, List, Optional, Union

import datasets
import pandas as pd
import wandb
from pytorch_lightning.accelerators import Accelerator

from embeddings.config.lightning_config import LightningQABasicConfig, LightningQAConfig
@@ -13,7 +15,8 @@
from embeddings.model.lightning_model import LightningModel
from embeddings.pipeline.lightning_pipeline import LightningPipeline
from embeddings.task.lightning_task import question_answering as qa
from embeddings.utils.loggers import LightningLoggingConfig
from embeddings.utils.loggers import LightningLoggingConfig, LightningWandbWrapper
from embeddings.utils.utils import convert_qa_df_to_bootstrap_html


class LightningQuestionAnsweringPipeline(
@@ -72,6 +75,7 @@ def __init__(
early_stopping_kwargs=config.early_stopping_kwargs,
model_checkpoint_kwargs=model_checkpoint_kwargs,
compile_model_kwargs=compile_model_kwargs,
logging_config=logging_config,
)
model: LightningModel[QuestionAnsweringDataModule, Dict[str, Any]] = LightningModel(
task, predict_subset
@@ -85,3 +89,21 @@ def __init__(
logging_config,
pipeline_kwargs=pipeline_kwargs,
)

def _finish_logging(self) -> None:
if self.logging_config.use_wandb():
wrapper = LightningWandbWrapper(self.logging_config)
wrapper.log_output(
self.output_path, ignore={"wandb", "csv", "tensorboard", "checkpoints"}
)
metrics = getattr(self.result, "metrics")
wrapper.log_metrics(metrics)

predictions_text = getattr(self.result, "predictions_text")
golds_text = getattr(self.result, "golds_text")
predictions_html = convert_qa_df_to_bootstrap_html(
    pd.DataFrame({"predictions": predictions_text, "golds": golds_text})
)

wrapper.wandb_logger.experiment.log({"predictions": wandb.Html(predictions_html)})
wrapper.finish_logging()
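A simplified sketch of the predictions-vs-golds table upload added in `_finish_logging`, using a plain W&B run instead of the pipeline's `LightningWandbWrapper`; the project name and sample rows are made up, and `convert_qa_df_to_bootstrap_html` is assumed to return an HTML string for the frame.

```python
import pandas as pd
import wandb

from embeddings.utils.utils import convert_qa_df_to_bootstrap_html

run = wandb.init(project="qa-eval")  # hypothetical project name
df = pd.DataFrame({"predictions": ["Warsaw", ""], "golds": [["Warsaw"], []]})
html = convert_qa_df_to_bootstrap_html(df)  # assumed to return an HTML string
run.log({"predictions": wandb.Html(html)})
run.finish()
```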