From c4b2a36b261a3a425ef59381aaee58be6d25593f Mon Sep 17 00:00:00 2001 From: Matt Buchovecky Date: Fri, 22 Nov 2024 11:53:14 -0800 Subject: [PATCH] remove chunking logic to have simple sentence labeler --- flair/training_utils.py | 48 +++------------------------------ tests/test_sentence_labeling.py | 6 ++--- 2 files changed, 6 insertions(+), 48 deletions(-) diff --git a/flair/training_utils.py b/flair/training_utils.py index 69430a17b3..6a3480b922 100644 --- a/flair/training_utils.py +++ b/flair/training_utils.py @@ -467,10 +467,8 @@ def create_labeled_sentence_from_tokens( def create_sentence_chunks( text: str, entities: List[CharEntity], - token_limit: int = 512, - use_context: bool = True, - overlap: int = 0, # TODO: implement overlap -) -> List[Sentence]: + token_limit: float = inf, +) -> Sentence: - """Chunks and labels a text from a list of entity annotations. + """Labels a text from a list of entity annotations. The function explicitly tokenizes the text and labels separately, ensuring entity labels are @@ -481,48 +479,21 @@ def create_sentence_chunks( entities (list of tuples): Ordered non-overlapping entity annotations with each tuple in the format (start_char_index, end_char_index, entity_class, entity_text). token_limit: numerical value that determines the maximum size of a chunk. 
use inf to not perform chunking - use_context: whether to add context to the sentence - overlap: the size of overlap between chunks, repeating the last n tokens of previous chunk to preserve context - Returns: A list of labeled Sentence objects representing the chunks of the original text + Returns: A labeled Sentence object representing the original text """ - chunks = [] - tokens: List[Token] = [] current_index = 0 token_entities: List[TokenEntity] = [] - end_token_idx = 0 for entity in entities: - if entity.start_char_idx > current_index: # add non-entity text - non_entity_tokens = Sentence(text[current_index : entity.start_char_idx]).tokens - while end_token_idx + len(non_entity_tokens) > token_limit: - num_tokens = token_limit - len(tokens) - tokens.extend(non_entity_tokens[:num_tokens]) - non_entity_tokens = non_entity_tokens[num_tokens:] - # skip any fully negative samples, they cause fine_tune to fail with - # `torch.cat(): expected a non-empty list of Tensors` - if len(token_entities) > 0: - chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities)) - tokens, token_entities = [], [] - end_token_idx = 0 - tokens.extend(non_entity_tokens) - # add new entity tokens start_token_idx = len(tokens) entity_sentence = Sentence(text[entity.start_char_idx : entity.end_char_idx]) - if len(entity_sentence) > token_limit: - logger.warning(f"Entity length is greater than token limit! 
{len(entity_sentence)} > {token_limit}") end_token_idx = start_token_idx + len(entity_sentence) - if end_token_idx >= token_limit: # create chunk from existing and add this entity to next chunk - chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities)) - - tokens, token_entities = [], [] - start_token_idx, end_token_idx = 0, len(entity_sentence) - token_entity = TokenEntity(start_token_idx, end_token_idx, entity.label, entity.value, entity.score) token_entities.append(token_entity) tokens.extend(entity_sentence) @@ -532,19 +503,6 @@ def create_sentence_chunks( # add any remaining tokens to a new chunk if current_index < len(text): remaining_sentence = Sentence(text[current_index:]) - if end_token_idx + len(remaining_sentence) > token_limit: - chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities)) - tokens, token_entities = [], [] tokens.extend(remaining_sentence) - if tokens: - chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities)) - - for chunk in chunks: - if len(chunk) > token_limit: - logger.warning(f"Chunk size is longer than token limit: {len(chunk)} > {token_limit}") - - if use_context: - Sentence.set_context_for_sentences(chunks) - - return chunks + return create_labeled_sentence_from_tokens(tokens, token_entities) diff --git a/tests/test_sentence_labeling.py b/tests/test_sentence_labeling.py index 810ae21038..ea238f42c3 100644 --- a/tests/test_sentence_labeling.py +++ b/tests/test_sentence_labeling.py @@ -171,9 +171,9 @@ def test_long_text(self, test_text, entities): "Hello! Is your company hiring? I am available for employment. 
Contact me at 5:00 p.m.", [ CharEntity(0, 6, "LABEL", "Hello!"), - CharEntity(7, 31, "LABEL", "Is your company hiring?"), - CharEntity(32, 65, "LABEL", "I am available for employment."), - CharEntity(66, 86, "LABEL", "Contact me at 5:00 p.m."), + CharEntity(7, 30, "LABEL", "Is your company hiring?"), + CharEntity(31, 61, "LABEL", "I am available for employment."), + CharEntity(62, 85, "LABEL", "Contact me at 5:00 p.m."), ], [ "Hello ! Is your company hiring ? I",