remove chunking logic to have simple sentence labeler
MattGPT-ai committed Nov 22, 2024
1 parent 24cd023 commit c4b2a36
Showing 2 changed files with 6 additions and 48 deletions.
48 changes: 3 additions & 45 deletions flair/training_utils.py
@@ -467,10 +467,8 @@ def create_labeled_sentence_from_tokens(
 def create_sentence_chunks(
     text: str,
     entities: List[CharEntity],
-    token_limit: int = 512,
-    use_context: bool = True,
-    overlap: int = 0,  # TODO: implement overlap
-) -> List[Sentence]:
+    token_limit: float = inf,
+) -> Sentence:
     """Chunks and labels a text from a list of entity annotations.
 
     The function explicitly tokenizes the text and labels separately, ensuring entity labels are
@@ -481,48 +479,21 @@ def create_sentence_chunks(
         entities (list of tuples): Ordered non-overlapping entity annotations with each tuple in the
             format (start_char_index, end_char_index, entity_class, entity_text).
         token_limit: numerical value that determines the maximum size of a chunk. use inf to not perform chunking
-        use_context: whether to add context to the sentence
-        overlap: the size of overlap between chunks, repeating the last n tokens of previous chunk to preserve context
 
     Returns:
         A list of labeled Sentence objects representing the chunks of the original text
     """
-    chunks = []
-
     tokens: List[Token] = []
     current_index = 0
     token_entities: List[TokenEntity] = []
-    end_token_idx = 0
 
     for entity in entities:
-
         if entity.start_char_idx > current_index:  # add non-entity text
             non_entity_tokens = Sentence(text[current_index : entity.start_char_idx]).tokens
-            while end_token_idx + len(non_entity_tokens) > token_limit:
-                num_tokens = token_limit - len(tokens)
-                tokens.extend(non_entity_tokens[:num_tokens])
-                non_entity_tokens = non_entity_tokens[num_tokens:]
-                # skip any fully negative samples, they cause fine_tune to fail with
-                # `torch.cat(): expected a non-empty list of Tensors`
-                if len(token_entities) > 0:
-                    chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
-                tokens, token_entities = [], []
-                end_token_idx = 0
             tokens.extend(non_entity_tokens)
 
         # add new entity tokens
         start_token_idx = len(tokens)
         entity_sentence = Sentence(text[entity.start_char_idx : entity.end_char_idx])
         if len(entity_sentence) > token_limit:
             logger.warning(f"Entity length is greater than token limit! {len(entity_sentence)} > {token_limit}")
         end_token_idx = start_token_idx + len(entity_sentence)
 
-        if end_token_idx >= token_limit:  # create chunk from existing and add this entity to next chunk
-            chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
-
-            tokens, token_entities = [], []
-            start_token_idx, end_token_idx = 0, len(entity_sentence)
-
         token_entity = TokenEntity(start_token_idx, end_token_idx, entity.label, entity.value, entity.score)
         token_entities.append(token_entity)
         tokens.extend(entity_sentence)
@@ -532,19 +503,6 @@ def create_sentence_chunks(
     # add any remaining tokens to a new chunk
     if current_index < len(text):
         remaining_sentence = Sentence(text[current_index:])
-        if end_token_idx + len(remaining_sentence) > token_limit:
-            chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
-            tokens, token_entities = [], []
         tokens.extend(remaining_sentence)
 
-    if tokens:
-        chunks.append(create_labeled_sentence_from_tokens(tokens, token_entities))
-
-    for chunk in chunks:
-        if len(chunk) > token_limit:
-            logger.warning(f"Chunk size is longer than token limit: {len(chunk)} > {token_limit}")
-
-    if use_context:
-        Sentence.set_context_for_sentences(chunks)
-
-    return chunks
+    return create_labeled_sentence_from_tokens(tokens, token_entities)
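
For orientation, a minimal usage sketch of the simplified API after this change. The call signature and the CharEntity field order are taken from the diff above; the example text, labels, and the assumption that both names are importable from flair.training_utils are illustrative, not part of the commit.

# Illustrative only (not from the commit): label one text with two entities.
# Assumes CharEntity and create_sentence_chunks can be imported from flair.training_utils.
from flair.training_utils import CharEntity, create_sentence_chunks

text = "Jane Doe works at Acme Corp."
entities = [
    CharEntity(0, 8, "PERSON", "Jane Doe"),
    CharEntity(18, 27, "ORG", "Acme Corp"),
]

# With the chunking logic removed, the function returns a single labeled Sentence;
# token_limit defaults to inf, so the text is never split.
sentence = create_sentence_chunks(text, entities)
for label in sentence.get_labels():
    print(label)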
6 changes: 3 additions & 3 deletions tests/test_sentence_labeling.py
@@ -171,9 +171,9 @@ def test_long_text(self, test_text, entities):
             "Hello! Is your company hiring? I am available for employment. Contact me at 5:00 p.m.",
             [
                 CharEntity(0, 6, "LABEL", "Hello!"),
-                CharEntity(7, 31, "LABEL", "Is your company hiring?"),
-                CharEntity(32, 65, "LABEL", "I am available for employment."),
-                CharEntity(66, 86, "LABEL", "Contact me at 5:00 p.m."),
+                CharEntity(7, 30, "LABEL", "Is your company hiring?"),
+                CharEntity(31, 61, "LABEL", "I am available for employment."),
+                CharEntity(62, 85, "LABEL", "Contact me at 5:00 p.m."),
             ],
             [
                 "Hello ! Is your company hiring ? I",
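
As a quick sanity check (illustrative, not part of the commit), the updated offsets are exact [start, end) slices of the test text:

# Verify the corrected character offsets used in the parametrized test case.
text = "Hello! Is your company hiring? I am available for employment. Contact me at 5:00 p.m."
assert text[7:30] == "Is your company hiring?"
assert text[31:61] == "I am available for employment."
assert text[62:85] == "Contact me at 5:00 p.m."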
