From 24cd023922d288e817ac8e7c556a15eb6c076862 Mon Sep 17 00:00:00 2001 From: Matt Buchovecky Date: Fri, 2 Aug 2024 18:10:03 -0700 Subject: [PATCH 1/3] feat: add chunking function to allow sequence tagger training on sentences exceeding the token limit, including tests --- flair/class_utils.py | 9 +- flair/training_utils.py | 217 +++++++++++++++++---- requirements-dev.txt | 2 +- tests/resources/text_sequences/resume1.txt | 85 ++++++++ tests/test_sentence_labeling.py | 191 ++++++++++++++++++ 5 files changed, 461 insertions(+), 43 deletions(-) create mode 100644 tests/resources/text_sequences/resume1.txt create mode 100644 tests/test_sentence_labeling.py diff --git a/flair/class_utils.py b/flair/class_utils.py index 7e01f4ff42..ec6666c99f 100644 --- a/flair/class_utils.py +++ b/flair/class_utils.py @@ -2,12 +2,17 @@ import inspect from collections.abc import Iterable from types import ModuleType -from typing import Any, Optional, TypeVar, Union, overload +from typing import Any, Iterable, List, Optional, Protocol, Type, TypeVar, Union, overload + T = TypeVar("T") -def get_non_abstract_subclasses(cls: type[T]) -> Iterable[type[T]]: +class StringLike(Protocol): + def __str__(self) -> str: ... + + +def get_non_abstract_subclasses(cls: Type[T]) -> Iterable[Type[T]]: for subclass in cls.__subclasses__(): yield from get_non_abstract_subclasses(subclass) if inspect.isabstract(subclass): diff --git a/flair/training_utils.py b/flair/training_utils.py index 9b38ec1ddb..69430a17b3 100644 --- a/flair/training_utils.py +++ b/flair/training_utils.py @@ -1,22 +1,27 @@ import logging +import pathlib import random from collections import defaultdict from enum import Enum from functools import reduce from math import inf from pathlib import Path -from typing import Literal, Optional, Union +from typing import Dict, List, Literal, NamedTuple, Optional, Union +from numpy import ndarray from scipy.stats import pearsonr, spearmanr +from scipy.stats._stats_py import PearsonRResult, SignificanceResult from sklearn.metrics import mean_absolute_error, mean_squared_error from torch.optim import Optimizer from torch.utils.data import Dataset import flair -from flair.data import DT, Dictionary, Sentence, _iter_dataset +from flair.class_utils import StringLike +from flair.data import DT, Dictionary, Sentence, Token, _iter_dataset EmbeddingStorageMode = Literal["none", "cpu", "gpu"] -log = logging.getLogger("flair") +MinMax = Literal["min", "max"] +logger = logging.getLogger("flair") class Result: @@ -33,7 +38,7 @@ def __init__( self.main_score: float = main_score self.scores = scores self.detailed_results: str = detailed_results - self.classification_report = classification_report + self.classification_report = classification_report if classification_report is not None else {} @property def loss(self): @@ -44,40 +49,36 @@ def __str__(self) -> str: class MetricRegression: - def __init__(self, name) -> None: + def __init__(self, name: str) -> None: self.name = name self.true: list[float] = [] self.pred: list[float] = [] - def mean_squared_error(self): + def mean_squared_error(self) -> Union[float, ndarray]: return mean_squared_error(self.true, self.pred) def mean_absolute_error(self): return mean_absolute_error(self.true, self.pred) - def pearsonr(self): + def pearsonr(self) -> PearsonRResult: return pearsonr(self.true, self.pred)[0] - def spearmanr(self): + def spearmanr(self) -> SignificanceResult: return spearmanr(self.true, self.pred)[0] - # dummy return to fulfill trainer.train() needs - def micro_avg_f_score(self): - return self.mean_squared_error() - - def to_tsv(self): + def to_tsv(self) -> str: return f"{self.mean_squared_error()}\t{self.mean_absolute_error()}\t{self.pearsonr()}\t{self.spearmanr()}" @staticmethod - def tsv_header(prefix=None): + def tsv_header(prefix: StringLike = None) -> str: if prefix: return f"{prefix}_MEAN_SQUARED_ERROR\t{prefix}_MEAN_ABSOLUTE_ERROR\t{prefix}_PEARSON\t{prefix}_SPEARMAN" return "MEAN_SQUARED_ERROR\tMEAN_ABSOLUTE_ERROR\tPEARSON\tSPEARMAN" @staticmethod - def to_empty_tsv(): + def to_empty_tsv() -> str: return "\t_\t_\t_\t_" def __str__(self) -> str: @@ -101,13 +102,13 @@ def __init__(self, directory: Union[str, Path], number_of_weights: int = 10) -> self.weights_dict: dict[str, dict[int, list[float]]] = defaultdict(lambda: defaultdict(list)) self.number_of_weights = number_of_weights - def extract_weights(self, state_dict, iteration): + def extract_weights(self, state_dict: Dict, iteration: int) -> None: for key in state_dict: vec = state_dict[key] - # print(vec) try: weights_to_watch = min(self.number_of_weights, reduce(lambda x, y: x * y, list(vec.size()))) - except Exception: + except Exception as e: + logger.debug(e) continue if key not in self.weights_dict: @@ -195,15 +196,15 @@ class AnnealOnPlateau: def __init__( self, optimizer, - mode="min", - aux_mode="min", - factor=0.1, - patience=10, - initial_extra_patience=0, - verbose=False, - cooldown=0, - min_lr=0, - eps=1e-8, + mode: MinMax = "min", + aux_mode: MinMax = "min", + factor: float = 0.1, + patience: int = 10, + initial_extra_patience: int = 0, + verbose: bool = False, + cooldown: int = 0, + min_lr: float = 0.0, + eps: float = 1e-8, ) -> None: if factor >= 1.0: raise ValueError("Factor should be < 1.0.") @@ -214,6 +215,7 @@ def __init__( raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") self.optimizer = optimizer + self.min_lrs: List[float] if isinstance(min_lr, (list, tuple)): if len(min_lr) != len(optimizer.param_groups): raise ValueError(f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}") @@ -231,7 +233,7 @@ def __init__( self.best = None self.best_aux = None self.num_bad_epochs = None - self.mode_worse = None # the worse value for the chosen mode + self.mode_worse: Optional[float] = None # the worse value for the chosen mode self.eps = eps self.last_epoch = 0 self._init_is_better(mode=mode) @@ -258,7 +260,7 @@ def step(self, metric, auxiliary_metric=None) -> bool: if self.mode == "max" and current > self.best: is_better = True - if current == self.best and auxiliary_metric: + if current == self.best and auxiliary_metric is not None: current_aux = float(auxiliary_metric) if self.aux_mode == "min" and current_aux < self.best_aux: is_better = True @@ -289,20 +291,20 @@ def step(self, metric, auxiliary_metric=None) -> bool: return reduce_learning_rate - def _reduce_lr(self, epoch): + def _reduce_lr(self, epoch: int) -> None: for i, param_group in enumerate(self.optimizer.param_groups): old_lr = float(param_group["lr"]) new_lr = max(old_lr * self.factor, self.min_lrs[i]) if old_lr - new_lr > self.eps: param_group["lr"] = new_lr if self.verbose: - log.info(f" - reducing learning rate of group {epoch} to {new_lr}") + logger.info(f" - reducing learning rate of group {epoch} to {new_lr}") @property def in_cooldown(self): return self.cooldown_counter > 0 - def _init_is_better(self, mode): + def _init_is_better(self, mode: MinMax) -> None: if mode not in {"min", "max"}: raise ValueError("mode " + mode + " is unknown!") @@ -313,10 +315,10 @@ def _init_is_better(self, mode): self.mode = mode - def state_dict(self): + def state_dict(self) -> Dict: return {key: value for key, value in self.__dict__.items() if key != "optimizer"} - def load_state_dict(self, state_dict): + def load_state_dict(self, state_dict: Dict) -> None: self.__dict__.update(state_dict) self._init_is_better(mode=self.mode) @@ -350,11 +352,11 @@ def convert_labels_to_one_hot(label_list: list[list[str]], label_dict: Dictionar return [[1 if label in labels else 0 for label in label_dict.get_items()] for labels in label_list] -def log_line(log): +def log_line(log: logging.Logger) -> None: log.info("-" * 100, stacklevel=3) -def add_file_handler(log, output_file): +def add_file_handler(log: logging.Logger, output_file: pathlib.Path) -> logging.FileHandler: init_output_file(output_file.parents[0], output_file.name) fh = logging.FileHandler(output_file, mode="w", encoding="utf-8") fh.setLevel(logging.INFO) @@ -367,12 +369,20 @@ def add_file_handler(log, output_file): def store_embeddings( data_points: Union[list[DT], Dataset], storage_mode: EmbeddingStorageMode, - dynamic_embeddings: Optional[list[str]] = None, -): + dynamic_embeddings: Optional[List[str]] = None, +) -> None: + """Stores embeddings of data points in memory or on disk. + + Args: + data_points: a DataSet or list of DataPoints for which embeddings should be stored + storage_mode: store in either CPU or GPU memory, or delete them if set to 'none' + dynamic_embeddings: these are always deleted. If not passed, they are identified automatically. + """ + if isinstance(data_points, Dataset): data_points = list(_iter_dataset(data_points)) - # if memory mode option 'none' delete everything + # if storage mode option 'none' delete everything if storage_mode == "none": dynamic_embeddings = None @@ -391,7 +401,7 @@ def store_embeddings( data_point.to("cpu", pin_memory=pin_memory) -def identify_dynamic_embeddings(data_points: list[DT]) -> Optional[list[str]]: +def identify_dynamic_embeddings(data_points: List[DT]) -> Optional[List[str]]: dynamic_embeddings = [] all_embeddings = [] for data_point in data_points: @@ -411,3 +421,130 @@ def identify_dynamic_embeddings(data_points: list[DT]) -> Optional[list[str]]: if not all_embeddings: return None return list(set(dynamic_embeddings)) + + +class TokenEntity(NamedTuple): + """Entity represented by token indices.""" + + start_token_idx: int + end_token_idx: int + label: str + value: str = "" # text value of the entity + score: float = 1.0 + + +class CharEntity(NamedTuple): + """Entity represented by character indices.""" + + start_char_idx: int + end_char_idx: int + label: str + value: str + score: float = 1.0 + + +def create_labeled_sentence_from_tokens( + tokens: Union[List[Token]], token_entities: List[TokenEntity], type_name: str = "ner" +) -> Sentence: + """Creates a new Sentence object from a list of tokens or strings and applies entity labels. + + Tokens are recreated with the same text, but not attached to the previous sentence. + + Args: + tokens: a list of Token objects or strings - only the text is used, not any labels + token_entities: a list of TokenEntity objects representing entity annotations + type_name: the type of entity label to apply + Returns: + A labeled Sentence object + """ + tokens = [Token(token.text) for token in tokens] # create new tokens that do not already belong to a sentence + sentence = Sentence(tokens, use_tokenizer=True) + for entity in token_entities: + sentence[entity.start_token_idx : entity.end_token_idx].add_label(type_name, entity.label, score=entity.score) + return sentence + + +def create_sentence_chunks( + text: str, + entities: List[CharEntity], + token_limit: int = 512, + use_context: bool = True, + overlap: int = 0, # TODO: implement overlap +) -> List[Sentence]: + """Chunks and labels a text from a list of entity annotations. + + The function explicitly tokenizes the text and labels separately, ensuring entity labels are + not partially split across tokens. + + Args: + text (str): The full text to be tokenized and labeled. + entities (list of tuples): Ordered non-overlapping entity annotations with each tuple in the + format (start_char_index, end_char_index, entity_class, entity_text). + token_limit: numerical value that determines the maximum size of a chunk. use inf to not perform chunking + use_context: whether to add context to the sentence + overlap: the size of overlap between chunks, repeating the last n tokens of previous chunk to preserve context + + Returns: + A list of labeled Sentence objects representing the chunks of the original text + """ + chunks = [] + + tokens: List[Token] = [] + current_index = 0 + token_entities: List[TokenEntity] = [] + end_token_idx = 0 + + for entity in entities: + + if entity.start_char_idx > current_index: # add non-entity text + non_entity_tokens = Sentence(text[current_index : entity.start_char_idx]).tokens + while end_token_idx + len(non_entity_tokens) > token_limit: + num_tokens = token_limit - len(tokens) + tokens.extend(non_entity_tokens[:num_tokens]) + non_entity_tokens = non_entity_tokens[num_tokens:] + # skip any fully negative samples, they cause fine_tune to fail with + # `torch.cat(): expected a non-empty list of Tensors` + if len(token_entities) > 0: + chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities)) + tokens, token_entities = [], [] + end_token_idx = 0 + tokens.extend(non_entity_tokens) + + # add new entity tokens + start_token_idx = len(tokens) + entity_sentence = Sentence(text[entity.start_char_idx : entity.end_char_idx]) + if len(entity_sentence) > token_limit: + logger.warning(f"Entity length is greater than token limit! {len(entity_sentence)} > {token_limit}") + end_token_idx = start_token_idx + len(entity_sentence) + + if end_token_idx >= token_limit: # create chunk from existing and add this entity to next chunk + chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities)) + + tokens, token_entities = [], [] + start_token_idx, end_token_idx = 0, len(entity_sentence) + + token_entity = TokenEntity(start_token_idx, end_token_idx, entity.label, entity.value, entity.score) + token_entities.append(token_entity) + tokens.extend(entity_sentence) + + current_index = entity.end_char_idx + + # add any remaining tokens to a new chunk + if current_index < len(text): + remaining_sentence = Sentence(text[current_index:]) + if end_token_idx + len(remaining_sentence) > token_limit: + chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities)) + tokens, token_entities = [], [] + tokens.extend(remaining_sentence) + + if tokens: + chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities)) + + for chunk in chunks: + if len(chunk) > token_limit: + logger.warning(f"Chunk size is longer than token limit: {len(chunk)} > {token_limit}") + + if use_context: + Sentence.set_context_for_sentences(chunks) + + return chunks diff --git a/requirements-dev.txt b/requirements-dev.txt index 3b8fbde79c..8053d231b8 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -12,4 +12,4 @@ types-Deprecated>=1.2.9.2 types-requests>=2.28.11.17 types-tabulate>=0.9.0.2 pyab3p -transformers!=4.40.1,!=4.40.0 \ No newline at end of file +transformers!=4.40.1,!=4.40.0 diff --git a/tests/resources/text_sequences/resume1.txt b/tests/resources/text_sequences/resume1.txt new file mode 100644 index 0000000000..6be7107559 --- /dev/null +++ b/tests/resources/text_sequences/resume1.txt @@ -0,0 +1,85 @@ +Nate Diaz +_____________________________________________________________________ +         5151 Morada Lane Stockton, CA | (555) 445-6487| lkji8765@ziprecruiter.com + +Summary +____________________________________________________________________________ + +Reliable, service-focused nursing professional with excellent patient-care and charting skills gained through five years of experience as a CNA. Compassionate and technically skilled in attending to patients in diverse healthcare settings. BLS and CPR certified (current). + +Highlights +____________________________________________________________________________ + +· Patient Care & Safety +· Diagnostic Testing +· Customer loyalty +· Medical Terminology +· Privacy / HIPAA Regulations +Experience +____________________________________________________________________________ +ABC STAFFING COMPANY +Certified Nursing Assistant 2012- Present +Stockton, CA + +· Provide high-quality patient care as an in-demand per-diem CNA within surgical, acute-care, rehabilitation, home-healthcare and nursing-home settings. +· Preserve patient dignity and minimize discomfort while carrying out duties such as bedpan changes, diapering, emptying drainage bags and bathing. +· Commended for chart accuracy, effective team collaboration, patient relations and consistent delivery of empathetic care. + +DEF REHABILITATION COMPANY 2010-2012 +Certified Nursing Assistant Stockton, CA + +· Displayed strong clinical skills in assessing vital signs, performing lab draws and glucose checks, and providing pre- and post-operative care. +· Adhered to safety guidelines; completed hospital’s three-hour Patient Safety Training Program. +· Ensured the accurate, timely flow of information by maintaining thorough patient records and updating healthcare team on patients’ status. +· Complied with HIPAA standards in all patient documentation and interactions. + +Education + +____________________________________________________________________________ + +Nurse’s Aid Program- Completed 60 hours of clinical training +Arizona State University +Tempe, AZ + +Nate Diaz +_____________________________________________________________________ +         5151 Morada Lane Stockton, CA | (555) 445-6487| lkji8765@ziprecruiter.com + +Summary +____________________________________________________________________________ + +Reliable, service-focused nursing professional with excellent patient-care and charting skills gained through five years of experience as a CNA. Compassionate and technically skilled in attending to patients in diverse healthcare settings. BLS and CPR certified (current). + +Highlights +____________________________________________________________________________ + +· Patient Care & Safety +· Diagnostic Testing +· Customer loyalty +· Medical Terminology +· Privacy / HIPAA Regulations +Experience +____________________________________________________________________________ +ABC STAFFING COMPANY +Certified Nursing Assistant 2012- Present +Stockton, CA + +· Provide high-quality patient care as an in-demand per-diem CNA within surgical, acute-care, rehabilitation, home-healthcare and nursing-home settings. +· Preserve patient dignity and minimize discomfort while carrying out duties such as bedpan changes, diapering, emptying drainage bags and bathing. +· Commended for chart accuracy, effective team collaboration, patient relations and consistent delivery of empathetic care. + +DEF REHABILITATION COMPANY 2010-2012 +Certified Nursing Assistant Stockton, CA + +· Displayed strong clinical skills in assessing vital signs, performing lab draws and glucose checks, and providing pre- and post-operative care. +· Adhered to safety guidelines; completed hospital’s three-hour Patient Safety Training Program. +· Ensured the accurate, timely flow of information by maintaining thorough patient records and updating healthcare team on patients’ status. +· Complied with HIPAA standards in all patient documentation and interactions. + +Education + +____________________________________________________________________________ + +Nurse’s Aid Program- Completed 60 hours of clinical training +Arizona State University +Tempe, AZ \ No newline at end of file diff --git a/tests/test_sentence_labeling.py b/tests/test_sentence_labeling.py new file mode 100644 index 0000000000..810ae21038 --- /dev/null +++ b/tests/test_sentence_labeling.py @@ -0,0 +1,191 @@ +from typing import Dict, List + +import pytest + +from flair.data import Sentence +from flair.training_utils import CharEntity, create_sentence_chunks + + +@pytest.fixture(params=["resume1.txt"]) +def resume(request, resources_path) -> str: + filepath = resources_path / "text_sequences" / request.param + with open(filepath, encoding="utf8") as file: + text_content = file.read() + return text_content + + +@pytest.fixture +def parsed_resume_dict(resume) -> dict: + return { + "raw_text": resume, + "entities": [ + CharEntity(20, 40, "dummy_label1", "Dummy Text 1"), + CharEntity(250, 300, "dummy_label2", "Dummy Text 2"), + CharEntity(700, 810, "dummy_label3", "Dummy Text 3"), + CharEntity(3900, 4000, "dummy_label4", "Dummy Text 4"), + ], + } + + +@pytest.fixture +def small_token_limit_resume() -> Dict: + return { + "raw_text": "Professional Clown June 2020 - August 2021 Entertaining students of all ages. Blah Blah Blah " + "Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah " + "Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah " + "Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah " + "Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Gained " + "proficiency in juggling and scaring children.", + "entities": [ + CharEntity(0, 18, "EXPERIENCE.TITLE", ""), + CharEntity(19, 29, "DATE.START_DATE", ""), + CharEntity(31, 42, "DATE.END_DATE", ""), + CharEntity(450, 510, "EXPERIENCE.DESCRIPTION", ""), + ], + } + + +@pytest.fixture +def small_token_limit_response() -> List[Sentence]: + """Recreates expected response Sentences.""" + chunk0 = Sentence("Professional Clown June 2020 - August 2021 Entertaining students of") + chunk0[0:2].add_label("Professional Clown", "EXPERIENCE.TITLE") + chunk0[2:4].add_label("June 2020", "DATE.START_DATE") + chunk0[5:7].add_label("August 2021", "DATE.END_DATE") + + chunk1 = Sentence("Blah Blah Blah Blah Blah Blah Blah Bl") + + chunk2 = Sentence("ah Blah Gained proficiency in juggling and scaring children .") + chunk2[0:10].add_label("ah Blah Gained proficiency in juggling and scaring children .", "EXPERIENCE.DESCRIPTION") + + return [chunk0, chunk1, chunk2] + + +class TestChunking: + def test_empty_string(self): + sentences = create_sentence_chunks("", []) + assert len(sentences) == 0 + + def check_split_entities(self, entity_labels, chunks, max_token_limit): + """Ensure that no entities are split over chunks (except entities longer than the token limit).""" + chunk_intervals = [] + start_index = 0 + for chunk in chunks: + end_index = start_index + len(chunk.text) + chunk_intervals.append((start_index, end_index)) + start_index = end_index + + for entity in entity_labels: + entity_start, entity_end = entity.start_char_idx, entity.end_char_idx + entity_length = entity_end - entity_start + + # Skip the check if the entity itself is longer than the maximum token limit + if entity_length > max_token_limit: + continue + + assert any( + start <= entity_start and entity_end <= end for start, end in chunk_intervals + ), f"Entity {entity} is not within a single chunk interval" + + @pytest.mark.parametrize( + "test_text, expected_text", + [ + ("test text", "test text"), + ("a", "a"), + ("this ", "this"), + ], + ) + def test_short_text(self, test_text, expected_text): + """Short texts that should fit nicely into a single chunk.""" + chunks = create_sentence_chunks(test_text, []) + assert chunks[0].text == expected_text + + def test_create_flair_sentence(self, parsed_resume_dict): + chunks = create_sentence_chunks(parsed_resume_dict["raw_text"], parsed_resume_dict["entities"]) + assert len(chunks) == 2 + + max_token_limit = 512 # default + assert all(len(c) <= max_token_limit for c in chunks) + + self.check_split_entities(parsed_resume_dict["entities"], chunks, max_token_limit) + + def test_small_token_limit(self, small_token_limit_resume, small_token_limit_response): + max_token_limit = 10 # test a small max token limit + chunks = create_sentence_chunks( + small_token_limit_resume["raw_text"], small_token_limit_resume["entities"], token_limit=max_token_limit + ) + + for response, expected in zip(chunks, small_token_limit_response): + assert response.to_tagged_string() == expected.to_tagged_string() + + assert all(len(c) <= max_token_limit for c in chunks) + + self.check_split_entities(small_token_limit_resume["entities"], chunks, max_token_limit) + + @pytest.mark.parametrize( + "test_text, entities, expected_chunks", + [ + ( + "Led a team of five engineers. It's important to note the project's success. We've implemented state-of-the-art technologies. Co-ordinated efforts with cross-functional teams.", + [ + CharEntity(0, 25, "RESPONSIBILITY", "Led a team of five engineers"), + CharEntity(27, 72, "ACHIEVEMENT", "It's important to note the project's success"), + CharEntity(74, 117, "ACHIEVEMENT", "We've implemented state-of-the-art technologies"), + CharEntity(119, 168, "RESPONSIBILITY", "Co-ordinated efforts with cross-functional teams"), + ], + [ + "Led a team of five engine er s. It 's important to note the project 's succe ss", + ". We 've implemented state-of-the-art techno lo gies . Co-ordinated efforts with cross-functional teams .", + ], + ), + ], + ) + def test_contractions_and_hyphens(self, test_text, entities, expected_chunks): + max_token_limit = 20 + chunks = create_sentence_chunks(test_text, entities, max_token_limit) + for i, chunk in enumerate(expected_chunks): + assert chunks[i].text == chunk + self.check_split_entities(entities, chunks, max_token_limit) + + @pytest.mark.parametrize( + "test_text, entities", + [ + ( + "This is a long text. " * 100, + [CharEntity(0, 1000, "dummy_label1", "Dummy Text 1")], + ) + ], + ) + def test_long_text(self, test_text, entities): + """Test for handling long texts that should be split into multiple chunks.""" + max_token_limit = 512 + chunks = create_sentence_chunks(test_text, entities, max_token_limit) + assert len(chunks) > 1 + assert all(len(c) <= max_token_limit for c in chunks) + self.check_split_entities(entities, chunks, max_token_limit) + + @pytest.mark.parametrize( + "test_text, entities, expected_chunks", + [ + ( + "Hello! Is your company hiring? I am available for employment. Contact me at 5:00 p.m.", + [ + CharEntity(0, 6, "LABEL", "Hello!"), + CharEntity(7, 31, "LABEL", "Is your company hiring?"), + CharEntity(32, 65, "LABEL", "I am available for employment."), + CharEntity(66, 86, "LABEL", "Contact me at 5:00 p.m."), + ], + [ + "Hello ! Is your company hiring ? I", + "am available for employment . Con t", + "act me at 5:00 p.m .", + ], + ) + ], + ) + def test_text_with_punctuation(self, test_text, entities, expected_chunks): + max_token_limit = 10 + chunks = create_sentence_chunks(test_text, entities, max_token_limit) + for i, chunk in enumerate(expected_chunks): + assert chunks[i].text == chunk + self.check_split_entities(entities, chunks, max_token_limit) From a711cb6ac98e8b12b61df99d94b049924d3e72bf Mon Sep 17 00:00:00 2001 From: Matt Buchovecky Date: Fri, 22 Nov 2024 11:53:14 -0800 Subject: [PATCH 2/3] remove chunking logic to have simple sentence labeler. fix tests. --- flair/data.py | 2 +- flair/nn/model.py | 2 +- flair/training_utils.py | 82 +++++----------- tests/test_sentence_labeling.py | 168 +++++++++++++++++--------------- 4 files changed, 117 insertions(+), 137 deletions(-) diff --git a/flair/data.py b/flair/data.py index 56622b249c..92de6a0be2 100644 --- a/flair/data.py +++ b/flair/data.py @@ -540,7 +540,7 @@ def __init__( head_id: Optional[int] = None, whitespace_after: int = 1, start_position: int = 0, - sentence=None, + sentence: Optional["Sentence"] = None, ) -> None: super().__init__(sentence=sentence) diff --git a/flair/nn/model.py b/flair/nn/model.py index f670c969a0..a2a9cec98f 100644 --- a/flair/nn/model.py +++ b/flair/nn/model.py @@ -982,7 +982,7 @@ def _get_state_dict(self): state["locked_dropout"] = self.locked_dropout.dropout_rate state["multi_label"] = self.multi_label state["multi_label_threshold"] = self.multi_label_threshold - state["loss_weights"] = self.loss_weights + state["loss_weights"] = self.weight_dict state["train_on_gold_pairs_only"] = self.train_on_gold_pairs_only state["inverse_model"] = self.inverse_model if self._custom_decoder: diff --git a/flair/training_utils.py b/flair/training_utils.py index 69430a17b3..ea5b576f0c 100644 --- a/flair/training_utils.py +++ b/flair/training_utils.py @@ -6,7 +6,7 @@ from functools import reduce from math import inf from pathlib import Path -from typing import Dict, List, Literal, NamedTuple, Optional, Union +from typing import Literal, NamedTuple, Optional, Union from numpy import ndarray from scipy.stats import pearsonr, spearmanr @@ -102,7 +102,7 @@ def __init__(self, directory: Union[str, Path], number_of_weights: int = 10) -> self.weights_dict: dict[str, dict[int, list[float]]] = defaultdict(lambda: defaultdict(list)) self.number_of_weights = number_of_weights - def extract_weights(self, state_dict: Dict, iteration: int) -> None: + def extract_weights(self, state_dict: dict, iteration: int) -> None: for key in state_dict: vec = state_dict[key] try: @@ -215,7 +215,7 @@ def __init__( raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") self.optimizer = optimizer - self.min_lrs: List[float] + self.min_lrs: list[float] if isinstance(min_lr, (list, tuple)): if len(min_lr) != len(optimizer.param_groups): raise ValueError(f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}") @@ -315,10 +315,10 @@ def _init_is_better(self, mode: MinMax) -> None: self.mode = mode - def state_dict(self) -> Dict: + def state_dict(self) -> dict: return {key: value for key, value in self.__dict__.items() if key != "optimizer"} - def load_state_dict(self, state_dict: Dict) -> None: + def load_state_dict(self, state_dict: dict) -> None: self.__dict__.update(state_dict) self._init_is_better(mode=self.mode) @@ -369,7 +369,7 @@ def add_file_handler(log: logging.Logger, output_file: pathlib.Path) -> logging. def store_embeddings( data_points: Union[list[DT], Dataset], storage_mode: EmbeddingStorageMode, - dynamic_embeddings: Optional[List[str]] = None, + dynamic_embeddings: Optional[list[str]] = None, ) -> None: """Stores embeddings of data points in memory or on disk. @@ -401,7 +401,7 @@ def store_embeddings( data_point.to("cpu", pin_memory=pin_memory) -def identify_dynamic_embeddings(data_points: List[DT]) -> Optional[List[str]]: +def identify_dynamic_embeddings(data_points: list[DT]) -> Optional[list[str]]: dynamic_embeddings = [] all_embeddings = [] for data_point in data_points: @@ -444,7 +444,7 @@ class CharEntity(NamedTuple): def create_labeled_sentence_from_tokens( - tokens: Union[List[Token]], token_entities: List[TokenEntity], type_name: str = "ner" + tokens: Union[list[Token]], token_entities: list[TokenEntity], type_name: str = "ner" ) -> Sentence: """Creates a new Sentence object from a list of tokens or strings and applies entity labels. @@ -457,20 +457,18 @@ def create_labeled_sentence_from_tokens( Returns: A labeled Sentence object """ - tokens = [Token(token.text) for token in tokens] # create new tokens that do not already belong to a sentence - sentence = Sentence(tokens, use_tokenizer=True) + tokens_ = [token.text for token in tokens] # create new tokens that do not already belong to a sentence + sentence = Sentence(tokens_, use_tokenizer=True) for entity in token_entities: sentence[entity.start_token_idx : entity.end_token_idx].add_label(type_name, entity.label, score=entity.score) return sentence -def create_sentence_chunks( +def create_labeled_sentence( text: str, - entities: List[CharEntity], - token_limit: int = 512, - use_context: bool = True, - overlap: int = 0, # TODO: implement overlap -) -> List[Sentence]: + entities: list[CharEntity], + token_limit: float = inf, +) -> Sentence: """Chunks and labels a text from a list of entity annotations. The function explicitly tokenizes the text and labels separately, ensuring entity labels are @@ -481,48 +479,25 @@ def create_sentence_chunks( entities (list of tuples): Ordered non-overlapping entity annotations with each tuple in the format (start_char_index, end_char_index, entity_class, entity_text). token_limit: numerical value that determines the maximum size of a chunk. use inf to not perform chunking - use_context: whether to add context to the sentence - overlap: the size of overlap between chunks, repeating the last n tokens of previous chunk to preserve context Returns: A list of labeled Sentence objects representing the chunks of the original text """ - chunks = [] - - tokens: List[Token] = [] + tokens: list[Token] = [] current_index = 0 - token_entities: List[TokenEntity] = [] - end_token_idx = 0 + token_entities: list[TokenEntity] = [] for entity in entities: - - if entity.start_char_idx > current_index: # add non-entity text - non_entity_tokens = Sentence(text[current_index : entity.start_char_idx]).tokens - while end_token_idx + len(non_entity_tokens) > token_limit: - num_tokens = token_limit - len(tokens) - tokens.extend(non_entity_tokens[:num_tokens]) - non_entity_tokens = non_entity_tokens[num_tokens:] - # skip any fully negative samples, they cause fine_tune to fail with - # `torch.cat(): expected a non-empty list of Tensors` - if len(token_entities) > 0: - chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities)) - tokens, token_entities = [], [] - end_token_idx = 0 - tokens.extend(non_entity_tokens) + if current_index < entity.start_char_idx: + # add tokens before the entity + sentence = Sentence(text[current_index : entity.start_char_idx]) + tokens.extend(sentence) # add new entity tokens start_token_idx = len(tokens) entity_sentence = Sentence(text[entity.start_char_idx : entity.end_char_idx]) - if len(entity_sentence) > token_limit: - logger.warning(f"Entity length is greater than token limit! {len(entity_sentence)} > {token_limit}") end_token_idx = start_token_idx + len(entity_sentence) - if end_token_idx >= token_limit: # create chunk from existing and add this entity to next chunk - chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities)) - - tokens, token_entities = [], [] - start_token_idx, end_token_idx = 0, len(entity_sentence) - token_entity = TokenEntity(start_token_idx, end_token_idx, entity.label, entity.value, entity.score) token_entities.append(token_entity) tokens.extend(entity_sentence) @@ -532,19 +507,10 @@ def create_sentence_chunks( # add any remaining tokens to a new chunk if current_index < len(text): remaining_sentence = Sentence(text[current_index:]) - if end_token_idx + len(remaining_sentence) > token_limit: - chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities)) - tokens, token_entities = [], [] tokens.extend(remaining_sentence) - if tokens: - chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities)) - - for chunk in chunks: - if len(chunk) > token_limit: - logger.warning(f"Chunk size is longer than token limit: {len(chunk)} > {token_limit}") - - if use_context: - Sentence.set_context_for_sentences(chunks) + if isinstance(token_limit, int) and token_limit < len(tokens): + tokens = tokens[:token_limit] + token_entities = [entity for entity in token_entities if entity.end_token_idx <= token_limit] - return chunks + return create_labeled_sentence_from_tokens(tokens, token_entities) diff --git a/tests/test_sentence_labeling.py b/tests/test_sentence_labeling.py index 810ae21038..56742da4df 100644 --- a/tests/test_sentence_labeling.py +++ b/tests/test_sentence_labeling.py @@ -1,9 +1,9 @@ -from typing import Dict, List +from typing import cast import pytest from flair.data import Sentence -from flair.training_utils import CharEntity, create_sentence_chunks +from flair.training_utils import CharEntity, TokenEntity, create_labeled_sentence @pytest.fixture(params=["resume1.txt"]) @@ -28,7 +28,7 @@ def parsed_resume_dict(resume) -> dict: @pytest.fixture -def small_token_limit_resume() -> Dict: +def small_token_limit_resume() -> dict: return { "raw_text": "Professional Clown June 2020 - August 2021 Entertaining students of all ages. Blah Blah Blah " "Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah Blah " @@ -46,7 +46,7 @@ def small_token_limit_resume() -> Dict: @pytest.fixture -def small_token_limit_response() -> List[Sentence]: +def small_token_limit_response() -> list[Sentence]: """Recreates expected response Sentences.""" chunk0 = Sentence("Professional Clown June 2020 - August 2021 Entertaining students of") chunk0[0:2].add_label("Professional Clown", "EXPERIENCE.TITLE") @@ -63,28 +63,32 @@ def small_token_limit_response() -> List[Sentence]: class TestChunking: def test_empty_string(self): - sentences = create_sentence_chunks("", []) + sentences = create_labeled_sentence("", []) assert len(sentences) == 0 - def check_split_entities(self, entity_labels, chunks, max_token_limit): - """Ensure that no entities are split over chunks (except entities longer than the token limit).""" - chunk_intervals = [] - start_index = 0 - for chunk in chunks: - end_index = start_index + len(chunk.text) - chunk_intervals.append((start_index, end_index)) - start_index = end_index + def check_tokens(self, sentence: Sentence, expected_tokens: list[str]): + assert len(sentence.tokens) == len(expected_tokens) + assert [token.text for token in sentence.tokens] == expected_tokens + for token, expected_token in zip(sentence.tokens, expected_tokens): + assert token.text == expected_token - for entity in entity_labels: - entity_start, entity_end = entity.start_char_idx, entity.end_char_idx - entity_length = entity_end - entity_start + def check_token_entities(self, sentence: Sentence, expected_labels: list[TokenEntity]): + assert len(sentence.labels) == len(expected_labels) + for label, expected_label in zip(sentence.labels, expected_labels): - # Skip the check if the entity itself is longer than the maximum token limit - if entity_length > max_token_limit: - continue + assert label.value == expected_label.label + span = cast(Sentence, label.data_point) + assert span.tokens[0]._internal_index is not None + assert span.tokens[0]._internal_index - 1 == expected_label.start_token_idx + assert span.tokens[-1]._internal_index is not None + assert span.tokens[-1]._internal_index - 1 == expected_label.end_token_idx - assert any( - start <= entity_start and entity_end <= end for start, end in chunk_intervals + def check_split_entities(self, entity_labels, sentence: Sentence): + """Ensure that no entities are split over chunks (except entities longer than the token limit).""" + for entity in entity_labels: + entity_start, entity_end = entity.start_char_idx, entity.end_char_idx + assert entity_start >= 0 and entity_end <= len( + sentence ), f"Entity {entity} is not within a single chunk interval" @pytest.mark.parametrize( @@ -95,57 +99,71 @@ def check_split_entities(self, entity_labels, chunks, max_token_limit): ("this ", "this"), ], ) - def test_short_text(self, test_text, expected_text): + def test_short_text(self, test_text: str, expected_text: str): """Short texts that should fit nicely into a single chunk.""" - chunks = create_sentence_chunks(test_text, []) - assert chunks[0].text == expected_text - - def test_create_flair_sentence(self, parsed_resume_dict): - chunks = create_sentence_chunks(parsed_resume_dict["raw_text"], parsed_resume_dict["entities"]) - assert len(chunks) == 2 + chunks = create_labeled_sentence(test_text, []) + assert chunks.text == expected_text - max_token_limit = 512 # default - assert all(len(c) <= max_token_limit for c in chunks) - - self.check_split_entities(parsed_resume_dict["entities"], chunks, max_token_limit) - - def test_small_token_limit(self, small_token_limit_resume, small_token_limit_response): - max_token_limit = 10 # test a small max token limit - chunks = create_sentence_chunks( - small_token_limit_resume["raw_text"], small_token_limit_resume["entities"], token_limit=max_token_limit - ) - - for response, expected in zip(chunks, small_token_limit_response): - assert response.to_tagged_string() == expected.to_tagged_string() - - assert all(len(c) <= max_token_limit for c in chunks) - - self.check_split_entities(small_token_limit_resume["entities"], chunks, max_token_limit) + def test_create_labeled_sentence(self, parsed_resume_dict: dict): + create_labeled_sentence(parsed_resume_dict["raw_text"], parsed_resume_dict["entities"]) @pytest.mark.parametrize( - "test_text, entities, expected_chunks", + "test_text, entities, expected_tokens, expected_labels", [ ( "Led a team of five engineers. It's important to note the project's success. We've implemented state-of-the-art technologies. Co-ordinated efforts with cross-functional teams.", [ - CharEntity(0, 25, "RESPONSIBILITY", "Led a team of five engineers"), - CharEntity(27, 72, "ACHIEVEMENT", "It's important to note the project's success"), - CharEntity(74, 117, "ACHIEVEMENT", "We've implemented state-of-the-art technologies"), - CharEntity(119, 168, "RESPONSIBILITY", "Co-ordinated efforts with cross-functional teams"), + CharEntity(0, 28, "RESPONSIBILITY", "Led a team of five engineers"), + CharEntity(30, 74, "ACHIEVEMENT", "It's important to note the project's success"), + CharEntity(76, 123, "ACHIEVEMENT", "We've implemented state-of-the-art technologies"), + CharEntity(125, 173, "RESPONSIBILITY", "Co-ordinated efforts with cross-functional teams"), + ], + [ + "Led", + "a", + "team", + "of", + "five", + "engineers", + ".", + "It", + "'s", + "important", + "to", + "note", + "the", + "project", + "'s", + "success", + ".", + "We", + "'ve", + "implemented", + "state-of-the-art", + "technologies", + ".", + "Co-ordinated", + "efforts", + "with", + "cross-functional", + "teams", + ".", ], [ - "Led a team of five engine er s. It 's important to note the project 's succe ss", - ". We 've implemented state-of-the-art techno lo gies . Co-ordinated efforts with cross-functional teams .", + TokenEntity(0, 5, "RESPONSIBILITY"), + TokenEntity(7, 15, "ACHIEVEMENT"), + TokenEntity(17, 21, "ACHIEVEMENT"), + TokenEntity(23, 27, "RESPONSIBILITY"), ], ), ], ) - def test_contractions_and_hyphens(self, test_text, entities, expected_chunks): - max_token_limit = 20 - chunks = create_sentence_chunks(test_text, entities, max_token_limit) - for i, chunk in enumerate(expected_chunks): - assert chunks[i].text == chunk - self.check_split_entities(entities, chunks, max_token_limit) + def test_contractions_and_hyphens( + self, test_text: str, entities: list[CharEntity], expected_tokens: list[str], expected_labels: list[TokenEntity] + ): + sentence = create_labeled_sentence(test_text, entities) + self.check_tokens(sentence, expected_tokens) + self.check_token_entities(sentence, expected_labels) @pytest.mark.parametrize( "test_text, entities", @@ -156,36 +174,32 @@ def test_contractions_and_hyphens(self, test_text, entities, expected_chunks): ) ], ) - def test_long_text(self, test_text, entities): + def test_long_text(self, test_text: str, entities: list[CharEntity]): """Test for handling long texts that should be split into multiple chunks.""" - max_token_limit = 512 - chunks = create_sentence_chunks(test_text, entities, max_token_limit) - assert len(chunks) > 1 - assert all(len(c) <= max_token_limit for c in chunks) - self.check_split_entities(entities, chunks, max_token_limit) + create_labeled_sentence(test_text, entities) @pytest.mark.parametrize( - "test_text, entities, expected_chunks", + "test_text, entities, expected_labels", [ ( "Hello! Is your company hiring? I am available for employment. Contact me at 5:00 p.m.", [ CharEntity(0, 6, "LABEL", "Hello!"), - CharEntity(7, 31, "LABEL", "Is your company hiring?"), - CharEntity(32, 65, "LABEL", "I am available for employment."), - CharEntity(66, 86, "LABEL", "Contact me at 5:00 p.m."), + CharEntity(7, 30, "LABEL", "Is your company hiring?"), + CharEntity(31, 61, "LABEL", "I am available for employment."), + CharEntity(62, 85, "LABEL", "Contact me at 5:00 p.m."), ], [ - "Hello ! Is your company hiring ? I", - "am available for employment . Con t", - "act me at 5:00 p.m .", + TokenEntity(0, 1, "LABEL"), + TokenEntity(2, 6, "LABEL"), + TokenEntity(7, 12, "LABEL"), + TokenEntity(13, 18, "LABEL"), ], ) ], ) - def test_text_with_punctuation(self, test_text, entities, expected_chunks): - max_token_limit = 10 - chunks = create_sentence_chunks(test_text, entities, max_token_limit) - for i, chunk in enumerate(expected_chunks): - assert chunks[i].text == chunk - self.check_split_entities(entities, chunks, max_token_limit) + def test_text_with_punctuation( + self, test_text: str, entities: list[CharEntity], expected_labels: list[TokenEntity] + ): + sentence = create_labeled_sentence(test_text, entities) + self.check_token_entities(sentence, expected_labels) From 58f19308b89ee615dd6d394de1e4d98afd185b91 Mon Sep 17 00:00:00 2001 From: Matt Buchovecky Date: Fri, 6 Dec 2024 23:39:42 -0800 Subject: [PATCH 3/3] fix: remove type hints from private module --- flair/training_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/flair/training_utils.py b/flair/training_utils.py index ea5b576f0c..ce15bdb6e5 100644 --- a/flair/training_utils.py +++ b/flair/training_utils.py @@ -10,7 +10,6 @@ from numpy import ndarray from scipy.stats import pearsonr, spearmanr -from scipy.stats._stats_py import PearsonRResult, SignificanceResult from sklearn.metrics import mean_absolute_error, mean_squared_error from torch.optim import Optimizer from torch.utils.data import Dataset @@ -61,10 +60,10 @@ def mean_squared_error(self) -> Union[float, ndarray]: def mean_absolute_error(self): return mean_absolute_error(self.true, self.pred) - def pearsonr(self) -> PearsonRResult: + def pearsonr(self): return pearsonr(self.true, self.pred)[0] - def spearmanr(self) -> SignificanceResult: + def spearmanr(self): return spearmanr(self.true, self.pred)[0] def to_tsv(self) -> str: