From ed47c0ceba6b5157a7f8cf1a05087ba380252942 Mon Sep 17 00:00:00 2001 From: Kamil Tagowski Date: Sun, 3 Dec 2023 18:33:20 +0100 Subject: [PATCH] feat: Improve qa datamodule (#292) * feat(qa_datamodule): Improve qa datamodule by adding cache and multithreading processing * fix(qa): Fix logging config args * styling: Fix styling * fix(tests): Fix tests * fix(qa_datamodule): Fix qa_datamodule * refactor: Fix mypy issues * feat(qa_preprocessing): Add multianswers support * feat(qa): Improve QA data split transformation * fix(transformations): Revert previous QA data split transformation * feat(qa_evaluator): Improve QA evaluator * feat(QA): Add fixes and improvements to QA pipelines --- embeddings/data/qa_datamodule.py | 84 ++++++++++++++----- embeddings/evaluator/evaluation_results.py | 10 +++ .../evaluator/question_answering_evaluator.py | 21 ++++- .../lightning_module/question_answering.py | 12 ++- embeddings/pipeline/lightning_pipeline.py | 21 +++-- .../pipeline/lightning_question_answering.py | 24 +++++- .../task/lightning_task/lightning_task.py | 78 +++++++++++------ .../task/lightning_task/question_answering.py | 14 +++- ...uestion_answering_output_transformation.py | 4 +- embeddings/utils/loggers.py | 76 +++++++++++------ embeddings/utils/model_exporter.py | 4 +- embeddings/utils/utils.py | 30 +++++++ ...t_lightning_question_answering_pipeline.py | 1 + tests/test_question_answering_evaluator.py | 5 +- 14 files changed, 298 insertions(+), 86 deletions(-) diff --git a/embeddings/data/qa_datamodule.py b/embeddings/data/qa_datamodule.py index 4c25369b..eea9dfaf 100644 --- a/embeddings/data/qa_datamodule.py +++ b/embeddings/data/qa_datamodule.py @@ -1,5 +1,7 @@ -import typing +import os +import pickle from copy import deepcopy +from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple import datasets @@ -11,6 +13,10 @@ from embeddings.data.datamodule import HuggingFaceDataModule, HuggingFaceDataset from embeddings.data.io import T_path +from embeddings.utils.loggers import get_logger +from embeddings.utils.utils import standardize_name + +_logger = get_logger(__name__) class CharToTokenMapper: @@ -95,6 +101,8 @@ def get_token_positions_train( class QuestionAnsweringDataModule(HuggingFaceDataModule): + CACHE_DEFAULT_DIR = Path(os.path.expanduser("~/.cache/embeddings/")) + def __init__( self, dataset_name_or_path: T_path, @@ -111,7 +119,8 @@ def __init__( load_dataset_kwargs: Optional[Dict[str, Any]] = None, dataloader_kwargs: Optional[Dict[str, Any]] = None, seed: int = 441, - **kwargs: Any + use_cache: bool = False, + **kwargs: Any, ) -> None: self.question_field = question_field self.context_field = context_field @@ -147,15 +156,55 @@ def __init__( self.splits: List[str] = list(self.dataset_raw.keys()) else: self.splits = ["train", "validation"] - self.process_data(stage="fit") + self.processed_data_cache_path = None + if use_cache: + datasets.disable_caching() + self.processed_data_cache_path = ( + QuestionAnsweringDataModule.CACHE_DEFAULT_DIR + / f"{standardize_name(str(dataset_name_or_path))}__{standardize_name(str(tokenizer_name_or_path))}" + ) + self.processed_data_cache_path.mkdir(parents=True, exist_ok=True) + _logger.warning( + f"Using embeddingsdatamodule caching. 
Cache path={self.processed_data_cache_path}" + ) - @typing.overload - def process_data(self) -> None: - pass + self.process_data_with_cache(stage="fit") + + def cache_datamodule(self, path: Path, stage: str) -> None: + self.dataset.save_to_disk(str(path)) + if stage != "fit": + with open(path / "overflow_to_sample_mapping", "wb") as f: + pickle.dump(obj=self.overflow_to_sample_mapping, file=f) + with open(path / "offset_mapping", "wb") as f: + pickle.dump(obj=self.offset_mapping, file=f) + + def load_cached_datamodule(self, path: Path, stage: str) -> None: + self.dataset = datasets.load_from_disk(dataset_path=str(path)) + if stage != "fit": + with open(path / "overflow_to_sample_mapping", "rb") as f: + self.overflow_to_sample_mapping = pickle.load(f) + with open(path / "offset_mapping", "rb") as f: + self.offset_mapping = pickle.load(f) + + def process_data_with_cache(self, stage: Optional[str] = None) -> None: + if stage is None: + return - @typing.overload - def process_data(self, stage: Optional[str] = None) -> None: - pass + if self.processed_data_cache_path: + data_cache_path = self.processed_data_cache_path / stage + if data_cache_path.exists(): + _logger.warning(f"Loading cached datamodule from path {data_cache_path}") + self.load_cached_datamodule(data_cache_path, stage=stage) + _logger.warning("Load completed!") + else: + _logger.warning( + f"Cached datamodule not found. Processing datamodule {data_cache_path}" + ) + self.process_data(stage=stage) + _logger.warning(f"Saving cached datamodule at path {data_cache_path}") + self.cache_datamodule(data_cache_path, stage=stage) + else: + self.process_data(stage=stage) def process_data(self, stage: Optional[str] = None) -> None: assert isinstance(self.dataset_raw, datasets.DatasetDict) @@ -165,9 +214,6 @@ def process_data(self, stage: Optional[str] = None) -> None: {k: v for k, v in self.dataset.items() if k in {"train", "validation"}} ) - if stage is None: - return - columns = [c for c in self.dataset["train"].column_names if c not in self.LOADER_COLUMNS] for split in self.dataset.keys(): @@ -185,11 +231,11 @@ def process_data(self, stage: Optional[str] = None) -> None: batch_size=self.processing_batch_size, remove_columns=columns, ) - - self.overflow_to_sample_mapping[split] = self.dataset[split][ - "overflow_to_sample_mapping" - ] - self.offset_mapping[split] = self.dataset[split]["offset_mapping"] + if stage != "fit": + self.overflow_to_sample_mapping[split] = self.dataset[split][ + "overflow_to_sample_mapping" + ] + self.offset_mapping[split] = self.dataset[split]["offset_mapping"] self.dataset[split] = self.dataset[split].remove_columns( ["offset_mapping", "overflow_to_sample_mapping"] ) @@ -233,7 +279,7 @@ def setup(self, stage: Optional[str] = None) -> None: self.dataset_raw["train"].info.version = Version("0.0.1") self.data_loader_has_setup = True if self.processed_data_stage and (self.processed_data_stage != stage): - self.process_data(stage=stage) + self.process_data_with_cache(stage=stage) self.processed_data_stage = stage @property @@ -248,5 +294,5 @@ def _class_encode_column(self, column_name: str) -> None: def test_dataloader(self) -> DataLoader[HuggingFaceDataset]: if "test" in self.splits and not "test" in self.dataset.keys(): - self.process_data(stage="test") + self.process_data_with_cache(stage="test") return super().test_dataloader() diff --git a/embeddings/evaluator/evaluation_results.py b/embeddings/evaluator/evaluation_results.py index ef160b18..c5260fbc 100644 --- a/embeddings/evaluator/evaluation_results.py +++ 
b/embeddings/evaluator/evaluation_results.py @@ -110,3 +110,13 @@ class QuestionAnsweringEvaluationResults(EvaluationResults): NoAns_f1: Optional[float] = None NoAns_total: Optional[float] = None data: Optional[Data] = None + golds_text: Optional[Union[List[List[str]], List[str]]] = None + predictions_text: Optional[List[str]] = None + + @property + def metrics(self) -> Dict[str, Any]: + result = asdict(self) + result.pop("data") + result.pop("golds_text") + result.pop("predictions_text") + return result diff --git a/embeddings/evaluator/question_answering_evaluator.py b/embeddings/evaluator/question_answering_evaluator.py index c37b90d1..69614be1 100644 --- a/embeddings/evaluator/question_answering_evaluator.py +++ b/embeddings/evaluator/question_answering_evaluator.py @@ -26,9 +26,22 @@ def __init__(self, no_answer_threshold: float = 1.0): def metrics( self, - ) -> Dict[str, Metric[Union[List[Any], nptyping.NDArray[Any], torch.Tensor], Dict[Any, Any]]]: + ) -> Dict[str, Metric[Union[List[Any], nptyping.NDArray[Any], torch.Tensor], Dict[Any, Any]],]: return {} + @staticmethod + def get_golds_text(references: List[QA_GOLD_ANSWER_TYPE]) -> Union[List[List[str]], List[str]]: + golds_text = [] + for ref in references: + answers = ref["answers"] + assert isinstance(answers, dict) + golds_text.append(answers["text"]) + return golds_text + + @staticmethod + def get_predictions_text(predictions: List[QA_PREDICTED_ANSWER_TYPE]) -> List[str]: + return [str(it["prediction_text"]) for it in predictions] + def evaluate( self, data: Union[Dict[str, nptyping.NDArray[Any]], Predictions, Dict[str, Any]] ) -> QuestionAnsweringEvaluationResults: @@ -51,5 +64,9 @@ def evaluate( {"id": it_id, **it["predicted_answer"]} for it_id, it in enumerate(outputs) ] metrics = SQUADv2Metric().calculate(predictions=predictions, references=references) + gold_texts = QuestionAnsweringEvaluator.get_golds_text(references) + predictions_text = QuestionAnsweringEvaluator.get_predictions_text(predictions) - return QuestionAnsweringEvaluationResults(data=outputs, **metrics) + return QuestionAnsweringEvaluationResults( + data=outputs, golds_text=gold_texts, predictions_text=predictions_text, **metrics + ) diff --git a/embeddings/model/lightning_module/question_answering.py b/embeddings/model/lightning_module/question_answering.py index 4ba662dd..52c7bd85 100644 --- a/embeddings/model/lightning_module/question_answering.py +++ b/embeddings/model/lightning_module/question_answering.py @@ -18,9 +18,17 @@ class QuestionAnsweringInferenceModule(pl.LightningModule): - def __init__(self, model_name: str, devices: str = "auto", accelerator: str = "auto") -> None: + def __init__( + self, + model_name: str, + devices: str = "auto", + accelerator: str = "auto", + use_auth_token: bool = False, + ) -> None: super().__init__() - self.model = AutoModelForQuestionAnswering.from_pretrained(model_name) + self.model = AutoModelForQuestionAnswering.from_pretrained( + model_name, use_auth_token=use_auth_token + ) self.trainer = pl.Trainer(devices=devices, accelerator=accelerator) def predict_step( diff --git a/embeddings/pipeline/lightning_pipeline.py b/embeddings/pipeline/lightning_pipeline.py index ddd5014c..6d854175 100644 --- a/embeddings/pipeline/lightning_pipeline.py +++ b/embeddings/pipeline/lightning_pipeline.py @@ -8,7 +8,7 @@ from embeddings.evaluator.evaluator import Evaluator from embeddings.model.model import Model from embeddings.pipeline.pipeline import Pipeline -from embeddings.utils.loggers import LightningLoggingConfig, WandbWrapper 
+from embeddings.utils.loggers import LightningLoggingConfig, LightningWandbWrapper from embeddings.utils.utils import get_installed_packages, standardize_name EvaluationResult = TypeVar("EvaluationResult") @@ -46,25 +46,34 @@ def __init__( self.pipeline_kwargs = pipeline_kwargs self.pipeline_kwargs.pop("self") self.pipeline_kwargs.pop("pipeline_kwargs") + self.result: Optional[EvaluationResult] = None def run(self, run_name: Optional[str] = None) -> EvaluationResult: if run_name: run_name = standardize_name(run_name) self._save_artifacts() model_result = self.model.execute(data=self.datamodule, run_name=run_name) - result = self.evaluator.evaluate(model_result) + self.result = self.evaluator.evaluate(model_result) + self._save_metrics() self._finish_logging() - return result + return self.result def _save_artifacts(self) -> None: srsly.write_json(self.output_path / "packages.json", get_installed_packages()) with open(self.output_path / "pipeline_config.yaml", "w") as f: yaml.dump(self.pipeline_kwargs, stream=f) + def _save_metrics(self) -> None: + metrics = getattr(self.result, "metrics") + with open(self.output_path / "metrics.yaml", "w") as f: + yaml.dump(metrics, stream=f) + def _finish_logging(self) -> None: if self.logging_config.use_wandb(): - logger = WandbWrapper() - logger.log_output( + wrapper = LightningWandbWrapper(self.logging_config) + wrapper.log_output( self.output_path, ignore={"wandb", "csv", "tensorboard", "checkpoints"} ) - logger.finish_logging() + metrics = getattr(self.result, "metrics") + wrapper.log_metrics(metrics) + wrapper.finish_logging() diff --git a/embeddings/pipeline/lightning_question_answering.py b/embeddings/pipeline/lightning_question_answering.py index b87a20c6..3718a518 100644 --- a/embeddings/pipeline/lightning_question_answering.py +++ b/embeddings/pipeline/lightning_question_answering.py @@ -2,6 +2,8 @@ from typing import Any, Dict, List, Optional, Union import datasets +import pandas as pd +import wandb from pytorch_lightning.accelerators import Accelerator from embeddings.config.lightning_config import LightningQABasicConfig, LightningQAConfig @@ -13,7 +15,8 @@ from embeddings.model.lightning_model import LightningModel from embeddings.pipeline.lightning_pipeline import LightningPipeline from embeddings.task.lightning_task import question_answering as qa -from embeddings.utils.loggers import LightningLoggingConfig +from embeddings.utils.loggers import LightningLoggingConfig, LightningWandbWrapper +from embeddings.utils.utils import convert_qa_df_to_bootstrap_html class LightningQuestionAnsweringPipeline( @@ -72,6 +75,7 @@ def __init__( early_stopping_kwargs=config.early_stopping_kwargs, model_checkpoint_kwargs=model_checkpoint_kwargs, compile_model_kwargs=compile_model_kwargs, + logging_config=logging_config, ) model: LightningModel[QuestionAnsweringDataModule, Dict[str, Any]] = LightningModel( task, predict_subset @@ -85,3 +89,21 @@ def __init__( logging_config, pipeline_kwargs=pipeline_kwargs, ) + + def _finish_logging(self) -> None: + if self.logging_config.use_wandb(): + wrapper = LightningWandbWrapper(self.logging_config) + wrapper.log_output( + self.output_path, ignore={"wandb", "csv", "tensorboard", "checkpoints"} + ) + metrics = getattr(self.result, "metrics") + wrapper.log_metrics(metrics) + + predictions_text = getattr(self.result, "predictions_text") + golds_text = getattr(self.result, "golds_text") + preditions_html = convert_qa_df_to_bootstrap_html( + pd.DataFrame({"predictions": predictions_text, "golds": golds_text}) + ) + + 
wrapper.wandb_logger.experiment.log({"predictions": wandb.Html(preditions_html)}) + wrapper.finish_logging() diff --git a/embeddings/task/lightning_task/lightning_task.py b/embeddings/task/lightning_task/lightning_task.py index 3d3c1939..528801cc 100644 --- a/embeddings/task/lightning_task/lightning_task.py +++ b/embeddings/task/lightning_task/lightning_task.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar, Union import pytorch_lightning as pl +from pytorch_lightning.accelerators import Accelerator from pytorch_lightning.callbacks import Callback, ModelCheckpoint from pytorch_lightning.callbacks.early_stopping import EarlyStopping from torch.utils.data import DataLoader @@ -52,6 +53,22 @@ def __init__( self.trainer: Optional[pl.Trainer] = None self.logging_config = logging_config self.tokenizer: Optional[AutoTokenizer] = None + self.callbacks: List[Callback] = [] + + self.inference_mode = ( + self.task_train_kwargs.pop("inference_mode") + if "inference_mode" in self.task_train_kwargs.keys() + else None + ) + if isinstance(self.compile_model_kwargs, dict): + _logger.warning( + "PyTorch 2.0 compile mode is turned on! Pass None to compile_model_kwargs if the behavior is unintended." + ) + if self.inference_mode or self.inference_mode is None: + _logger.warning( + "PyTorch 2.0 compile mode does not support inference_mode! Setting Lightning Trainer inference_mode to False!" + ) + self.inference_mode = False @property def best_epoch(self) -> Optional[float]: @@ -87,6 +104,32 @@ def _get_callbacks(self, dataset_subsets: Sequence[str]) -> List[Callback]: callbacks.append(EarlyStopping(**self.early_stopping_kwargs)) return callbacks + def setup_trainer( + self, + run_name: str, + accelerator: Optional[Union[str, Accelerator]] = None, + devices: Optional[Union[List[int], str, int]] = None, + ) -> None: + if self.trainer: + del self.trainer + cleanup_torch_model_artifacts() + + accelerator = accelerator if accelerator else self.task_train_kwargs["accelerator"] + devices = devices if devices else self.task_train_kwargs["devices"] + task_train_kwargs = { + k: v for k, v in self.task_train_kwargs.items() if k not in ("accelerator", "devices") + } + + self.trainer = pl.Trainer( + default_root_dir=str(self.output_path), + callbacks=self.callbacks, + logger=self.logging_config.get_lightning_loggers(run_name=run_name), + inference_mode=self.inference_mode, + accelerator=accelerator, + devices=devices, + **task_train_kwargs, + ) + def fit( self, data: LightningDataModule, @@ -95,31 +138,9 @@ def fit( if not self.model: raise self.MODEL_UNDEFINED_EXCEPTION self.tokenizer = data.tokenizer - - callbacks = self._get_callbacks(dataset_subsets=list(data.load_dataset().keys())) - - inference_mode = ( - self.task_train_kwargs.pop("inference_mode") - if "inference_mode" in self.task_train_kwargs.keys() - else None - ) - if isinstance(self.compile_model_kwargs, dict): - _logger.warning( - "PyTorch 2.0 compile mode is turned on! Pass None to compile_model_kwargs if the behavior is unintended." - ) - if inference_mode or inference_mode is None: - _logger.warning( - "PyTorch 2.0 compile mode does not support inference_mode! Setting Lightning Trainer inference_mode to False!" 
- ) - inference_mode = False - - self.trainer = pl.Trainer( - default_root_dir=str(self.output_path), - callbacks=callbacks, - logger=self.logging_config.get_lightning_loggers(self.output_path, run_name), - inference_mode=inference_mode, - **self.task_train_kwargs, - ) + self.callbacks = self._get_callbacks(dataset_subsets=list(data.load_dataset().keys())) + self.setup_trainer(run_name=run_name if run_name else "") + assert isinstance(self.trainer, pl.Trainer) try: self.trainer.fit(self.model, data) except Exception as e: @@ -200,6 +221,13 @@ def fit_predict( self.fit(data, run_name=run_name) dataloader = data.get_subset(subset=predict_subset) assert isinstance(dataloader, DataLoader) + assert isinstance(self.trainer, pl.Trainer) + if isinstance(self.trainer.strategy, pl.strategies.ddp.DDPStrategy): + self.setup_trainer( + run_name=run_name if run_name else "", + accelerator="gpu", + devices=[0], # made predict only on single gpu, + ) result = self.predict(dataloader=dataloader) return result diff --git a/embeddings/task/lightning_task/question_answering.py b/embeddings/task/lightning_task/question_answering.py index 2950911b..3eaeb23d 100644 --- a/embeddings/task/lightning_task/question_answering.py +++ b/embeddings/task/lightning_task/question_answering.py @@ -27,6 +27,7 @@ def __init__( task_train_kwargs: Dict[str, Any], early_stopping_kwargs: Dict[str, Any], model_checkpoint_kwargs: Dict[str, Any], + logging_config: LightningLoggingConfig, finetune_last_n_layers: int = -1, compile_model_kwargs: Optional[Dict[str, Any]] = None, ) -> None: @@ -42,7 +43,7 @@ def __init__( early_stopping_kwargs=early_stopping_kwargs, model_checkpoint_kwargs=model_checkpoint_kwargs, compile_model_kwargs=compile_model_kwargs, - logging_config=LightningLoggingConfig.from_flags(), + logging_config=logging_config, hf_task_name=HuggingFaceTaskName.question_answering, ) self.model_name_or_path = model_name_or_path @@ -63,7 +64,9 @@ def build_task_model(self) -> None: def predict(self, dataloader: Any, return_names: bool = True) -> Any: assert self.model is not None assert self.trainer is not None - return self.trainer.predict(model=self.model, dataloaders=dataloader) + return self.trainer.predict( + model=self.model, dataloaders=dataloader, return_predictions=True + ) @staticmethod def postprocess_outputs( @@ -91,6 +94,13 @@ def fit_predict( dataloader = data.get_subset(subset=predict_subset) assert isinstance(dataloader, DataLoader) + assert isinstance(self.trainer, pl.Trainer) + if isinstance(self.trainer.strategy, pl.strategies.ddp.DDPStrategy): + self.setup_trainer( + run_name=run_name if run_name else "", + accelerator="gpu", + devices=[0], # made predict only on single gpu, + ) model_outputs = self.predict(dataloader=dataloader) result = self.postprocess_outputs( model_outputs=model_outputs, data=data, predict_subset=predict_subset diff --git a/embeddings/transformation/lightning_transformation/question_answering_output_transformation.py b/embeddings/transformation/lightning_transformation/question_answering_output_transformation.py index 3493b412..46be5952 100644 --- a/embeddings/transformation/lightning_transformation/question_answering_output_transformation.py +++ b/embeddings/transformation/lightning_transformation/question_answering_output_transformation.py @@ -127,7 +127,7 @@ def _get_predicted_text_from_context( def _get_softmax_scores_with_sort(predictions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: scores = torch.from_numpy(np.array([pred.pop("score") for pred in predictions])) # Module 
torch.functional does not explicitly export attritube "F" - softmax_scores = torch.functional.F.softmax(scores) # type: ignore[attr-defined] + softmax_scores = torch.functional.F.softmax(scores, dim=0) # type: ignore[attr-defined] for prob, pred in zip(softmax_scores, predictions): pred["softmax_score"] = prob # mypy thinks the function only returns Any @@ -162,7 +162,7 @@ def _postprocess_example( end_logits=end_logits[1:], offset_mapping=offset_mappings[output_index], ) - # Argument 1 to "append" of "list" has incompatible type "Optional[Dict[str, object]]"; expected "Dict[str, Any]" + predictions.append(min_no_answer_score) # type: ignore[arg-type] # mypy thinks the function only returns Any predictions = sorted(predictions, key=lambda x: x["score"], reverse=True)[ # type: ignore[no-any-return] diff --git a/embeddings/utils/loggers.py b/embeddings/utils/loggers.py index 73a5e78a..48d672a0 100644 --- a/embeddings/utils/loggers.py +++ b/embeddings/utils/loggers.py @@ -7,6 +7,7 @@ import wandb from pytorch_lightning import loggers as pl_loggers +from pytorch_lightning.loggers.wandb import WandbLogger from typing_extensions import Literal from embeddings.data.io import T_path @@ -30,10 +31,12 @@ def get_logger(name: str, log_level: Union[str, int] = DEFAULT_LOG_LEVEL) -> log @dataclass class LightningLoggingConfig: + output_path: Union[Path, str] = "." loggers_names: List[Literal["wandb", "csv", "tensorboard"]] = field(default_factory=list) tracking_project_name: Optional[str] = None wandb_entity: Optional[str] = None wandb_logger_kwargs: Dict[str, Any] = field(default_factory=dict) + loggers: Optional[Dict[str, pl_loggers.Logger]] = field(init=False, default=None) def __post_init__(self) -> None: if "wandb" not in self.loggers_names and ( @@ -80,48 +83,41 @@ def use_tensorboard(self) -> bool: def get_lightning_loggers( self, - output_path: T_path, run_name: Optional[str] = None, ) -> List[pl_loggers.Logger]: """Based on configuration, provides pytorch-lightning loggers' callbacks.""" - output_path = Path(output_path) - loggers: List[pl_loggers.Logger] = [] + if not self.loggers: + self.output_path = Path(self.output_path) + self.loggers = {} - if self.use_tensorboard(): - loggers.append( - pl_loggers.TensorBoardLogger( + if self.use_tensorboard(): + self.loggers["tensorboard"] = pl_loggers.TensorBoardLogger( name=run_name, - save_dir=str(output_path.joinpath("tensorboard")), + save_dir=str(self.output_path / "tensorboard"), ) - ) - if self.use_wandb(): - if not self.tracking_project_name: - raise ValueError( - "Tracking project name is not passed. Pass tracking_project_name argument!" - ) - save_dir = output_path.joinpath("wandb") - save_dir.mkdir(exist_ok=True) - loggers.append( - pl_loggers.wandb.WandbLogger( + if self.use_wandb(): + if not self.tracking_project_name: + raise ValueError( + "Tracking project name is not passed. Pass tracking_project_name argument!" 
+ ) + save_dir = self.output_path / "wandb" + save_dir.mkdir(exist_ok=True, parents=True) + self.loggers["wandb"] = pl_loggers.wandb.WandbLogger( name=run_name, save_dir=str(save_dir), project=self.tracking_project_name, entity=self.wandb_entity, - reinit=True, **self.wandb_logger_kwargs ) - ) - if self.use_csv(): - loggers.append( - pl_loggers.CSVLogger( + if self.use_csv(): + self.loggers["csv"] = pl_loggers.CSVLogger( name=run_name if run_name else "", - save_dir=str(output_path.joinpath("csv")), + save_dir=self.output_path / "csv", ) - ) - return loggers + return list(self.loggers.values()) class ExperimentLogger(abc.ABC): @@ -170,3 +166,33 @@ def log_artifact(self, paths: Iterable[T_path], artifact_name: str, artifact_typ for path in paths: artifact.add_file(path) wandb.log_artifact(artifact) + + +class LightningWandbWrapper: + def __init__(self, logging_config: LightningLoggingConfig) -> None: + assert logging_config.use_wandb() + assert isinstance(logging_config.loggers, dict) + assert "wandb" in logging_config.loggers + assert isinstance(logging_config.loggers["wandb"], WandbLogger) + self.wandb_logger: WandbLogger = logging_config.loggers["wandb"] + + def log_output( + self, + output_path: T_path, + ignore: Optional[Iterable[str]] = None, + ) -> None: + for entry in os.scandir(output_path): + if not ignore or entry.name not in ignore: + self.wandb_logger.experiment.save(entry.path, output_path) + + def log_metrics(self, metrics: Dict[str, Any]) -> None: + self.wandb_logger.log_metrics(metrics) + + def finish_logging(self) -> None: + self.wandb_logger.experiment.finish() + + def log_artifact(self, paths: Iterable[T_path], artifact_name: str, artifact_type: str) -> None: + artifact = wandb.Artifact(name=artifact_name, type=artifact_type) + for path in paths: + artifact.add_file(path) + self.wandb_logger.experiment.log_artifact(artifact) diff --git a/embeddings/utils/model_exporter.py b/embeddings/utils/model_exporter.py index bc1acf4b..7e377054 100644 --- a/embeddings/utils/model_exporter.py +++ b/embeddings/utils/model_exporter.py @@ -38,7 +38,9 @@ class BaseModelExporter(abc.ABC): def __post_init__(self, path: T_path) -> None: self._export_path = pathlib.Path(path) if not isinstance(path, pathlib.Path) else path - self._export_path.mkdir(parents=True, exist_ok=False) + if self._export_path.exists(): + raise ValueError(f"Path {str(self._export_path)} already exists!") + self._export_path.mkdir(parents=True, exist_ok=True) @staticmethod def _check_tokenizer(tokenizer: Optional[AutoTokenizer]) -> None: diff --git a/embeddings/utils/utils.py b/embeddings/utils/utils.py index 923cd487..0bb8774f 100644 --- a/embeddings/utils/utils.py +++ b/embeddings/utils/utils.py @@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np +import pandas as pd import pkg_resources import requests import yaml @@ -152,3 +153,32 @@ def compress_and_remove(filepath: T_path) -> None: ) as arc: arc.write(filepath, arcname=filepath.name) filepath.unlink() + + +def convert_qa_df_to_bootstrap_html(df: pd.DataFrame) -> str: + boostrap_cdn = ( + '' + ) + + output = ( + "" + + "\n" + + "" + + "\n" + + "" + + "\n" + + boostrap_cdn + + "\n" + + '' + + "\n" + + "" + + "\n" + + "" + + "\n" + + df.to_html(classes=["table table-bordered table-striped table-hover"]) + + "\n" + + "" + ) + assert isinstance(output, str) + return output diff --git a/tests/test_lightning_question_answering_pipeline.py b/tests/test_lightning_question_answering_pipeline.py index 8dafe403..e0ea48f0 100644 --- 
a/tests/test_lightning_question_answering_pipeline.py +++ b/tests/test_lightning_question_answering_pipeline.py @@ -145,6 +145,7 @@ def hf_datamodule( load_dataset_kwargs={}, dataloader_kwargs=config.dataloader_kwargs, doc_stride=config.task_model_kwargs["doc_stride"], + use_cache=False, **config.datamodule_kwargs, ) dm.setup("test") diff --git a/tests/test_question_answering_evaluator.py b/tests/test_question_answering_evaluator.py index 9ae187bc..678c3050 100644 --- a/tests/test_question_answering_evaluator.py +++ b/tests/test_question_answering_evaluator.py @@ -11,6 +11,7 @@ from embeddings.data.qa_datamodule import QuestionAnsweringDataModule from embeddings.evaluator.question_answering_evaluator import QuestionAnsweringEvaluator from embeddings.task.lightning_task.question_answering import QuestionAnsweringTask +from embeddings.utils.loggers import LightningLoggingConfig from tests.fixtures.sample_qa_dataset import sample_question_answering_dataset @@ -39,6 +40,7 @@ def question_answering_data_module( "return_offsets_mapping": True, "return_overflowing_tokens": True, }, + use_cache=False, ) @@ -67,6 +69,7 @@ def question_answering_task() -> QuestionAnsweringTask: }, early_stopping_kwargs={"monitor": "val/Loss", "patience": 1, "mode": "min"}, model_checkpoint_kwargs={}, + logging_config=LightningLoggingConfig.from_flags(), ) @@ -81,7 +84,7 @@ def scores( datamodule.setup(stage="fit") task.build_task_model() task.fit(datamodule) - datamodule.process_data(stage="test") + datamodule.process_data_with_cache(stage="test") predict_dataloader = datamodule.predict_dataloader() dataloader = predict_dataloader[0]
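Note on the datamodule cache introduced in embeddings/data/qa_datamodule.py above: when use_cache is set, process_data_with_cache saves the tokenized DatasetDict with datasets' save_to_disk under ~/.cache/embeddings/<dataset>__<tokenizer>/<stage> and pickles the offset_mapping / overflow_to_sample_mapping next to it, so later stages reload instead of re-tokenizing. The snippet below is a minimal, self-contained sketch of that pattern, not the module itself; standardize_name and tokenize_splits are simplified stand-ins for the real helpers, and the cache layout merely mirrors the patch.

import pickle
from pathlib import Path

import datasets


def standardize_name(name: str) -> str:
    # Simplified stand-in for embeddings.utils.utils.standardize_name.
    return name.replace("/", "__").replace(" ", "_")


def tokenize_splits() -> datasets.DatasetDict:
    # Placeholder for the per-split tokenization done in process_data().
    return datasets.DatasetDict(
        {
            "validation": datasets.Dataset.from_dict(
                {"question": ["Who wrote it?"], "context": ["It was written by X."]}
            )
        }
    )


def process_data_with_cache(dataset_name: str, tokenizer_name: str, stage: str) -> datasets.DatasetDict:
    cache_dir = (
        Path.home()
        / ".cache"
        / "embeddings"
        / f"{standardize_name(dataset_name)}__{standardize_name(tokenizer_name)}"
        / stage
    )
    mapping_file = cache_dir / "offset_mapping"

    if cache_dir.exists():
        # Cache hit: reload the processed dataset and the pickled mapping.
        dataset = datasets.load_from_disk(str(cache_dir))
        with open(mapping_file, "rb") as f:
            offset_mapping = pickle.load(f)
    else:
        # Cache miss: tokenize, then persist both artifacts for the next run.
        # (The real module stores the mappings only for non-"fit" stages.)
        dataset = tokenize_splits()
        offset_mapping = {"validation": [[(0, 3), (4, 9)]]}  # stand-in for the real mapping
        dataset.save_to_disk(str(cache_dir))
        with open(mapping_file, "wb") as f:
            pickle.dump(offset_mapping, f)
    return dataset


# Illustrative names only; they just feed the cache key.
dm = process_data_with_cache("my_dataset", "my_tokenizer", stage="test")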
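The evaluator changes hang together as follows: QuestionAnsweringEvaluator now extracts golds_text and predictions_text from the SQuADv2-style references and predictions, and QuestionAnsweringEvaluationResults gains a metrics property that strips the non-metric fields before the values reach _save_metrics and the W&B logger. A reduced sketch of that property, with a cut-down field list and made-up numbers:

from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional


@dataclass
class QAEvaluationResults:
    exact: Optional[float] = None
    f1: Optional[float] = None
    data: Optional[List[Dict[str, Any]]] = None
    golds_text: Optional[List[List[str]]] = None
    predictions_text: Optional[List[str]] = None

    @property
    def metrics(self) -> Dict[str, Any]:
        # Flatten the dataclass and drop the non-scalar fields so only metric
        # values remain (this is what ends up in metrics.yaml and W&B).
        result = asdict(self)
        for non_metric in ("data", "golds_text", "predictions_text"):
            result.pop(non_metric)
        return result


results = QAEvaluationResults(
    exact=71.4,
    f1=80.2,
    data=[{"predicted_answer": {"prediction_text": "Warsaw"}}],
    golds_text=[["Warsaw"]],
    predictions_text=["Warsaw"],
)
print(results.metrics)  # {'exact': 71.4, 'f1': 80.2}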
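The small change in _get_softmax_scores_with_sort is not cosmetic: calling softmax without an explicit dim is deprecated in PyTorch and lets the dimension be chosen implicitly, while dim=0 normalises the scores of all candidate answer spans of an example against each other. Equivalent call via the public torch.nn.functional path, with illustrative values:

import torch

# Scores of the candidate answer spans for one example (illustrative values).
scores = torch.tensor([4.2, 1.3, 0.4, -2.1])
softmax_scores = torch.nn.functional.softmax(scores, dim=0)
print(softmax_scores)        # probabilities over the candidate spans
print(softmax_scores.sum())  # tensor(1.)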
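Finally, LightningQuestionAnsweringPipeline._finish_logging renders the gold and predicted answer texts to a Bootstrap-styled HTML table (convert_qa_df_to_bootstrap_html) and attaches it to the W&B run as a wandb.Html panel. A rough sketch of that flow, assuming pandas is available; the real helper also injects the Bootstrap CSS link into <head>, which is omitted here, and the W&B part is left commented out because it needs an active run (the project name is illustrative):

import pandas as pd

df = pd.DataFrame(
    {
        "predictions": ["Warsaw", "1918"],
        "golds": [["Warsaw"], ["1918", "in 1918"]],
    }
)
table_html = df.to_html(classes=["table table-bordered table-striped table-hover"])
html_page = "<html>\n<head>\n</head>\n<body>\n" + table_html + "\n</body>\n</html>"

# Logging the table to an active W&B run:
# import wandb
# run = wandb.init(project="qa-demo")
# run.log({"predictions": wandb.Html(html_page)})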