diff --git a/README.md b/README.md
index 77a85d9..be7c567 100644
--- a/README.md
+++ b/README.md
@@ -1,68 +1,20 @@
-
-
-
-
-
+# iaraugen
+Data augmentation/generation utilities for Iara.
-# Title
+Can be used standalone (instructions below) or as part of other programs.
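+
+## Installation
+Install the Python dependencies first (a plain pip workflow is assumed; the package can also be installed with `pip install .` via the provided `setup.py`):
+```
+pip install -r requirements.txt
+```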
-Project description goes here. This description is usually two to three lines long. It should give an overview of what the project is, eg technology used, philosophy of existence, what problem it is trying to solve, etc. If you need to write more than 3 lines of description, create subsections.
-
-> **NOTICE:** put here a message that is very relevant to users of the project, if any.
-
-## ✨Features
-
-Here you can place screenshots of the project. Also describe your features using a list:
-
-* ✔️ Easy integration;
-* 🥢 Few dependencies;
-* 🎨 Beautiful template with a nice `README`;
-* 🖖 Great documentation and testing?
-
-## 🚀 Getting started
-
-### 1. First step to get started
-
-Usually the first step to get started is to install dependencies to run the project. Run:
-
-```
-apt get install dependency
+## Offline text augmentation
```
+help: ./txt_aug.py -h
-It is recommended to place each command on a different line:
-
+example usage:
+./txt_aug.py corpus_1br_10pt_15sept.tok --aug translate --maxs 10 --lang en --translate_mode local --append --output out.txt
```
-apt get install something else
-```
-
-This way users can copy and paste without reading the documentation (which is what usually happens).
-
-### 2. Other step(s)
-Usually the next steps teach you how to install and configure the project for use / development. Run:
-
-```
-git clone https://github.com/iarahealth/template template
+## Offline text generation
```
+help: ./txt_gen.py -h
-## 🤝 Contribute
-
-Your help is most welcome regardless of form! Check out the [CONTRIBUTING.md](CONTRIBUTING.md) file for all ways you can contribute to the project. For example, [suggest a new feature](https://github.com/iarahealth/template/issues/new?assignees=&labels=&title=), [report a problem/bug](https://github.com/iarahealth/template/issues/new?assignees=&labels=bug&title=), [submit a pull request](https://help.github.com/en/github/collaborating-with-issues-and-pull-requests/about-pull-requests), or simply use the project and comment your experience. You are encourage to participate as much as possible, but stay tuned to the [code of conduct](./CODE_OF_CONDUCT.md) before making any interaction with other community members.
-
-See the [ROADMAP.md](ROADMAP.md) file for an idea of how the project should evolve.
-
-## 🎫 License
-
-This project is proprietary and confidential. Unauthorized copying of any file in this repository, via any medium is strictly prohibited. Contact [legal@iarahealth.com](mailto:legal@iarahealth.com) for inquiries or reports.
-
-## 🧬 Changelog
-
-See all changes to this project in the [CHANGELOG.md](CHANGELOG.md) file.
-
-## 🧪 Similar projects
-
-Below is a list of interesting links and similar projects:
-
-* [Other project](https://github.com/project)
-* [Project inspiration](https://github.com/project)
-* [Similar tool](https://github.com/project)
+example usage:
+./txt_gen.py --input_file palavras.txt --context "radiologia médica" --num 2 --return_type "frases" --api_key "YOUR_OPENAI_API_KEY" --output query.txt
+```
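+
+## Use as a library
+A minimal sketch of calling the augmentation utilities from Python instead of the command line (it assumes the repository root is on the Python path; the module and function names below come from this repository):
+```
+from util.text_augmenter import SentenceAugmenter
+from util.text import post_process_sentences
+
+# Randomly delete words from each sentence, then convert punctuation back to dictated words.
+augmenter = SentenceAugmenter("random", action="delete")
+augmented = augmenter.augment_sentences(["Primeira frase de exemplo.", "Segunda frase de exemplo."])
+print(post_process_sentences(augmented))
+```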
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/requirements.txt b/requirements.txt
new file mode 100755
index 0000000..effe294
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,10 @@
+torch
+deep-translator
+transformers
+sentencepiece
+tenacity
+openai
+nlpaug
+nltk
+num2words
+tqdm
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..83c4342
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,23 @@
+from setuptools import setup, find_packages
+
+authors = ["Pedro Probst", "Bernardo Henz"]
+
+setup(
+ name="iaraugen",
+ version="1.0.0",
+ author=", ".join(authors),
+ description="Data augmentation/generation functions used at Iara Health (speech-to-text).",
+ packages=find_packages(),
+ install_requires=[
+ "torch",
+ "deep-translator",
+ "transformers",
+ "sentencepiece",
+ "tenacity",
+ "openai",
+ "nlpaug",
+ "nltk",
+ "num2words",
+ "tqdm",
+ ],
+)
diff --git a/txt_aug.py b/txt_aug.py
new file mode 100755
index 0000000..252f800
--- /dev/null
+++ b/txt_aug.py
@@ -0,0 +1,236 @@
+#!/usr/bin/env python3
+import argparse
+import random
+import torch
+from typing import List
+from deep_translator import GoogleTranslator
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
+from tqdm import tqdm
+from util.text_augmenter import SentenceAugmenter
+from util.text import (
+ post_process_sentences,
+ print_sentences_comparison,
+ remove_equal_sentences,
+)
+from util.files import read_sentences_corpus, append_sentences_to_file
+
+"""
+example usage:
+./txt_aug.py corpus.tok --aug translate random --action delete --maxs 10 --lang en --append
+"""
+
+
+def backtranslate_sentences_api(
+ sentences: List[str], source_lang: str, target_lang: str
+) -> List[str]:
+ """
+ Backtranslates a list of sentences from the source language to the target language
+ using the Google Translator API.
+
+ Args:
+ sentences (List[str]): The list of sentences to be backtranslated.
+ source_lang (str): The source language code (e.g., "pt" for Portuguese).
+ target_lang (str): The target language code (e.g., "en" for English).
+
+ Returns:
+ List[str]: A list of backtranslated sentences.
+ """
+ translator = GoogleTranslator(source=source_lang, target=target_lang)
+ translations = translator.translate_batch(sentences)
+ backtranslator = GoogleTranslator(source=target_lang, target=source_lang)
+ backtranslations = backtranslator.translate_batch(translations)
+
+ return backtranslations
+
+
+def backtranslate_sentences_local(
+ sentences: List[str], source_lang: str, target_lang: str, device: str = "cpu"
+) -> List[str]:
+ """
+ Backtranslates a list of sentences from the source language to the target language,
+ and then back to the source language using a local model.
+
+ Args:
+ sentences (List[str]): The list of sentences to be backtranslated.
+ source_lang (str): The source language code (e.g., "pt" for Portuguese).
+ target_lang (str): The target language code (e.g., "en" for English).
+ device (str): The device to run the model on (e.g., "cpu" or "cuda").
+
+ Returns:
+ List[str]: A list of backtranslated sentences.
+
+ Note:
+        nlpaug has a backtranslation module, but it only officially supports Helsinki-NLP
+        models, and there is no Helsinki model for Portuguese -> English. So we use the
+        unicamp-dl T5 translation models from HuggingFace directly.
+ """
+ tokenizer = AutoTokenizer.from_pretrained(
+ f"unicamp-dl/translation-{source_lang}-{target_lang}-t5"
+ )
+ model = AutoModelForSeq2SeqLM.from_pretrained(
+ f"unicamp-dl/translation-{source_lang}-{target_lang}-t5"
+ )
+ model.to(torch.device(device))
+ backtokenizer = AutoTokenizer.from_pretrained(
+ f"unicamp-dl/translation-{target_lang}-{source_lang}-t5"
+ )
+ backmodel = AutoModelForSeq2SeqLM.from_pretrained(
+ f"unicamp-dl/translation-{target_lang}-{source_lang}-t5"
+ )
+ backmodel.to(torch.device(device))
+ pten_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
+ enpt_pipeline = pipeline(
+ "text2text-generation", model=backmodel, tokenizer=backtokenizer
+ )
+
+ print(f"Backtranslating {len(sentences)} sentences...")
+ translations: List[str] = []
+ for sentence in tqdm(sentences):
+ transl = pten_pipeline("translate Portuguese to English: " + sentence)[0][
+ "generated_text"
+ ]
+ backtransl = enpt_pipeline("translate English to Portuguese: " + transl)[0][
+ "generated_text"
+ ]
+ translations.append(backtransl)
+
+ return translations
+
+
+def translation_pipeline(
+ sentences: List[str], translate_mode: str, lang: str, device: str
+) -> List[str]:
+ """
+ Runs the translation pipeline to backtranslate a list of sentences.
+
+ Args:
+ sentences (List[str]): The list of sentences to be translated.
+ translate_mode (str): Use local model or API to translate.
+ lang (str): The target language code (e.g., "en" for English).
+ device (str): The device to run the model on (e.g., "cpu" or "cuda").
+
+ Returns:
+ List[str]: A list of translated sentences.
+ """
+ augmented_sentences: List[str] = []
+ print(f"Backtranslating sentences pt->{lang}->pt...")
+ if translate_mode == "local":
+ augmented_sentences = backtranslate_sentences_local(
+ sentences, "pt", lang, device
+ )
+ elif translate_mode == "google":
+ augmented_sentences = backtranslate_sentences_api(sentences, "pt", lang)
+ assert len(augmented_sentences)
+ return augmented_sentences
+
+
+def create_augmentation_sequence(
+ augmentations: List[str], action: str, translate_mode: str, lang: str, device: str
+) -> List[callable]:
+ augmentation_sequence = []
+ for aug in augmentations:
+        if aug in ("random", "synonym"):
+            augmenter = SentenceAugmenter(aug, action=action)
+            # Bind this iteration's augmenter as a default argument to avoid late binding in the loop.
+            augmentation_sequence.append(
+                lambda x, augmenter=augmenter: augmenter.augment_sentences(x)
+            )
+ elif aug == "translate":
+ augmentation_sequence.append(
+ lambda x: translation_pipeline(x, translate_mode, lang, device)
+ )
+ return augmentation_sequence
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Sentence augmentation: back-translate, random delete, random swap, synonym replacement."
+ )
+ parser.add_argument("corpus", type=str, help="Input corpus file")
+ parser.add_argument(
+ "--aug",
+ nargs="+",
+ type=str,
+ required=True,
+ choices=["random", "translate", "synonym"],
+ help="Augmentation type to perform",
+ )
+ parser.add_argument(
+ "--action",
+ type=str,
+ choices=["delete", "swap"],
+ default="delete",
+ help="Action to perform",
+ )
+ parser.add_argument(
+ "--maxs",
+ type=str,
+ default="10",
+ help="Maximum number of sentences to process. Can be a percentage of the total, e.g., 5%% (default: 10)",
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=451,
+ help="Random seed (default: 451)",
+ )
+ parser.add_argument(
+ "--lang",
+ type=str,
+ default="en",
+ help="Target language for translation (default: en)",
+ )
+ parser.add_argument(
+ "--translate_mode",
+ type=str,
+ choices=["google", "local", "openai"],
+ default="local",
+ help="Target language for translation (default: local)",
+ )
+ parser.add_argument(
+ "--device",
+ type=str,
+ default="cpu",
+ help="Process on CPU or CUDA (default: cpu)",
+ )
+ parser.add_argument(
+ "--output",
+ type=str,
+ default=None,
+ help="Output file to write augmented sentences in addition to the input corpus",
+ )
+ parser.add_argument("--append", action="store_true", help="Append to corpus file")
+ args = parser.parse_args()
+
+ random.seed(args.seed)
+
+ sentences = read_sentences_corpus(args.corpus, max_sentences=args.maxs)
+ print(f"Read {len(sentences)} sentences from {args.corpus}")
+
+ augmentation_sequence = create_augmentation_sequence(
+ args.aug, args.action, args.translate_mode, args.lang, args.device
+ )
+
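+    # Apply each augmentation step in sequence, feeding the previous step's output into the next.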
+ augmented_sentences = sentences
+ for i, aug_fn in enumerate(augmentation_sequence):
+ print(f"Augmentation step {i + 1} of {len(augmentation_sequence)}:")
+ augmented_sentences = aug_fn(augmented_sentences)
+
+ augmented_sentences = post_process_sentences(augmented_sentences)
+ sentences = post_process_sentences(sentences)
+ print_sentences_comparison(sentences, augmented_sentences)
+
+ print("Removing equal sentences...")
+ augmented_sentences = remove_equal_sentences(sentences, augmented_sentences)
+
+ print("\nFinal results:")
+ print("-------------------")
+ for sentence in augmented_sentences:
+ print(sentence)
+ print(f"\nTotal: {len(augmented_sentences)} sentences")
+ print("-------------------\n")
+
+ if args.append:
+ print(f"Appending augmented sentences to {args.corpus}...")
+ append_sentences_to_file(args.corpus, augmented_sentences)
+
+ if args.output:
+ print(f"Appending augmented sentences to {args.output}...")
+ append_sentences_to_file(args.output, augmented_sentences)
diff --git a/txt_gen.py b/txt_gen.py
new file mode 100755
index 0000000..db2546a
--- /dev/null
+++ b/txt_gen.py
@@ -0,0 +1,173 @@
+#!/usr/bin/env python3
+import argparse
+import random
+import re
+import openai
+
+from typing import List
+from tqdm import tqdm
+from tenacity import (
+ retry,
+ stop_after_attempt,
+ wait_random_exponential,
+)
+from util.files import append_sentences_to_file, read_file
+from util.text import post_process_sentences
+
+MAX_TOKENS = {
+ # https://platform.openai.com/docs/models/gpt-4
+ # https://platform.openai.com/docs/models/gpt-3-5
+ "gpt-4": 8192,
+ "gpt-4-32k": 32768,
+ "gpt-3.5-turbo": 4097,
+ "gpt-3.5-turbo-16k": 16385,
+}
+
+
+# https://github.com/openai/openai-cookbook/blob/main/examples/How_to_handle_rate_limits.ipynb
+@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
+def make_chatgpt_query(
+ query: str, api_key: str, return_type: str, model: str = "gpt-3.5-turbo"
+) -> List[str]:
+ """
+ Makes a query to the ChatGPT model and returns the generated response.
+
+ Args:
+        query (str): The user's query.
+        api_key (str): The API key for accessing the OpenAI API.
+        return_type (str): The type of data being generated (currently unused here).
+        model (str): The name of the ChatGPT model (default: "gpt-3.5-turbo").
+
+ Returns:
+ List[str]: A list containing the model's response.
+ """
+ max_tokens = MAX_TOKENS[model]
+ openai.api_key = api_key
+    response = openai.ChatCompletion.create(
+        model=model, messages=[{"role": "user", "content": query}]
+    )
+    # If the request consumed the model's entire context window, the reply was likely cut off.
+    is_truncated = response["usage"]["total_tokens"] >= max_tokens
+ response = [
+ line
+ for line in response["choices"][0]["message"]["content"].split("\n")
+ if line.strip() != ""
+ ]
+ if is_truncated and len(response) > 0:
+ response.pop()
+ return response
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Sentence/word generation using ChatGPT"
+ )
+ parser.add_argument(
+ "--input_file", type=str, default=None, help="Input file with words"
+ )
+ parser.add_argument(
+ "--num", type=int, default=None, help="Number of sentences or words to generate"
+ )
+ parser.add_argument(
+ "--context",
+ type=str,
+ default="radiologia médica",
+ help="Context of the generated sentences",
+ )
+ parser.add_argument(
+ "--query",
+ type=str,
+ default=None,
+ help="A query to OpenAI's ChatGPT; the first number detected in the query will be replaced by the number of sentences to generate",
+ )
+ parser.add_argument(
+ "--return_type",
+ type=str,
+ default="frases",
+ help="Type of data to generate (default: frases)",
+ )
+ parser.add_argument(
+ "--api_key",
+ type=str,
+ default=None,
+ help="OpenAI API key",
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="gpt-3.5-turbo-16k",
+ help="ChatGPT model to use",
+ )
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=451,
+ help="Random seed (default: 451)",
+ )
+ parser.add_argument(
+ "--output",
+ type=str,
+ default=None,
+ help="Output file to write generated sentences",
+ )
+ args = parser.parse_args()
+
+ random.seed(args.seed)
+
+ if args.query is None:
+ if args.return_type == "frases":
+ args.query = f"No contexto de {args.context}, gere {args.num} {args.return_type} contendo o termo '[MASK]', separadas por nova linha."
+ else:
+ args.query = f"No contexto de {args.context}, gere {args.num} {args.return_type} separadas por nova linha."
+ else:
+ args.num = (
+ int(re.search(r"\d+", args.query).group())
+ if re.search(r"\d+", args.query)
+ else None
+ )
+
+ if args.input_file:
+ wordlist = read_file(args.input_file)
+ else:
+ if args.return_type == "frases" and "[MASK]" in args.query:
+ wordlist = []
+ while True:
+ word = input("Enter a word (or press Enter to finish): ")
+ if word == "":
+ break
+ wordlist.append(word)
+ else:
+ wordlist = [""]
+
+ response_sentences: List[str] = []
+ original_query = args.query
+ for word in tqdm(wordlist):
+ word = word.strip()
+ query = re.sub(r"\[MASK\]", word, original_query)
+ number_of_sentences_left = args.num
+
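+        # Keep querying until the requested number of sentences has been produced,
+        # rewriting the count in the query with how many sentences are still missing.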
+ while number_of_sentences_left > 0:
+ print(f"\nNumber of sentences left: {number_of_sentences_left}")
+ print(f"Querying OpenAI's {args.model} with '{query}'...")
+ query_response = make_chatgpt_query(
+ query, args.api_key, return_type=args.return_type, model=args.model
+ )
+ print(query_response)
+ response_sentences.extend(
+ [s.split(" ", 1)[1] if s[0].isdigit() else s for s in query_response]
+ )
+ number_of_sentences_left -= len(query_response)
+ query = re.sub(r"\d+", str(number_of_sentences_left), query)
+ print()
+
+ generated_sentences = post_process_sentences(response_sentences, modify=True)
+
+ print("\nFinal results:")
+ print("-------------------")
+ for sentence in generated_sentences:
+ print(sentence)
+ print(f"\nTotal: {len(generated_sentences)} sentences")
+ print("-------------------\n")
+
+ if args.output:
+ print(f"Appending generated sentences to {args.output}...")
+ append_sentences_to_file(args.output, generated_sentences)
diff --git a/util/__init__.py b/util/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/util/files.py b/util/files.py
new file mode 100644
index 0000000..715539c
--- /dev/null
+++ b/util/files.py
@@ -0,0 +1,63 @@
+#!/usr/bin/env python3
+from typing import List, Optional
+from .text import pre_process_sentences
+
+
+def read_sentences_corpus(
+ filename: str, max_sentences: Optional[str] = None
+) -> List[str]:
+ """
+ Reads sentences from a corpus file, and returns a list of sentences.
+
+ Args:
+ filename (str): The name of the input file.
+        max_sentences (str): Optional. The maximum number of sentences to read,
+            either an absolute number or a percentage of the corpus (e.g. "5%").
+
+ Returns:
+ List[str]: A list of sentences.
+
+ Note:
+ To make the augmentation useful, replace punctuation words with
+ punctuation marks, capitalize the sentences and add a dot at the end.
+ The sentences will be post-processed to change the punctuation marks
+ back to words again, unless stated otherwise.
+ """
+ sentences = []
+ with open(filename, "r", encoding="utf-8") as f:
+ line_count = sum(1 for _ in f)
+ f.seek(0)
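+        # Resolve max_sentences: either an absolute count or a percentage of the corpus size.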
+ if max_sentences not in [None, -1, "100%"]:
+ max_sentences = (
+ int(max_sentences)
+ if max_sentences[-1] != "%"
+ else round(line_count * (float(max_sentences[:-1]) / 100.0))
+ )
+ for line in f:
+ sentences.append(line.strip())
+ if max_sentences not in [None, -1, "100%"]:
+ if len(sentences) == max_sentences:
+ break
+
+ sentences = pre_process_sentences(sentences)
+ return sentences
+
+
+def append_sentences_to_file(filename: str, sentences: List[str]):
+ """
+ Appends sentences to a file.
+
+ Args:
+ filename (str): The name of the output file.
+ sentences (List[str]): The list of sentences to be written to the file.
+ """
+ with open(filename, "a", encoding="utf-8") as outfile:
+ for sentence in sentences:
+ outfile.write("\n" + sentence)
+
+
+def read_file(filename: str) -> List[str]:
+ with open(filename, "r", encoding="utf-8") as f:
+ return f.readlines()
diff --git a/util/text.py b/util/text.py
new file mode 100644
index 0000000..99d7db1
--- /dev/null
+++ b/util/text.py
@@ -0,0 +1,177 @@
+#!/usr/bin/env python3
+import random
+import re
+from typing import List
+from num2words import num2words
+
+
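+# Maps dictated punctuation words to punctuation marks; applied when pre-processing the corpus.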
+replacement_dict = {
+ # Warning: order matters!
+ "ponto de exclamação": "!",
+ "exclamação": "!",
+ "ponto de interrogação": "?",
+ "interrogação": "?",
+ "dois pontos": ":",
+ "reticências": "...",
+ "ponto final": ".",
+ "ponto": ".",
+ "vírgula": ",",
+ "ponto e vírgula": ";",
+ "ponto vírgula": ";",
+ "travessão": "--",
+ # From here on will be excluded from the final augmented sentences.
+ "nova linha": "",
+ "novo parágrafo": "",
+ "parágrafo": "",
+ "nova linha": "",
+ "hífen": "",
+ "abre aspas": '"',
+ "abrir aspas": '"',
+ "fecha aspas": '"',
+ "fechar aspas": '"',
+ "aspas": '"',
+ "aspa": '"',
+ "abre parênteses": "(",
+ "abre parêntese": "(",
+ "abrir parênteses": "(",
+ "abrir parêntese": "(",
+ "fechar parênteses": ")",
+ "fechar parêntese": ")",
+ "fecha parênteses": ")",
+ "fecha parêntese": ")",
+}
+
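+# Maps punctuation marks back to dictated words when post-processing augmented/generated sentences.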
+reverse_replacement_dict = {
+ # Warning: order matters!
+ "...": " reticências",
+ ".": " ponto",
+ ":": " dois pontos",
+ ",": " vírgula",
+ ";": " ponto e vírgula",
+ "\n": " parágrafo",
+ "‒-": " travessão",
+ "!": " ponto de exclamação",
+ "?": " ponto de interrogação",
+ "(": "abre parênteses ",
+ ")": " fecha parênteses",
+ "-": "",
+ "/": " barra ",
+ '"': "",
+}
+
+
+def add_period_and_capitalize(sentence: str) -> str:
+ """
+ Adds a period at the end of the sentence if it doesn't already have one,
+ capitalizes the sentences, and returns the modified sentence.
+
+ Args:
+ sentence (str): The input sentence.
+
+ Returns:
+ str: The modified sentence with added period and capitalized.
+ """
+ if sentence[-1] != ".":
+ sentence += "."
+ sentences = sentence.split(".")
+ return ". ".join(s.strip().capitalize() for s in sentences).strip()
+
+
+def pre_process_sentences(sentences: List[str]) -> List[str]:
+ sentences_processed = []
+ for s in sentences:
+ s = s.strip()
+ for word, punctuation in replacement_dict.items():
+ s = s.replace(word, punctuation)
+ if s.isspace() or s == "":
+ continue
+ s = s.strip()
+ s = add_period_and_capitalize(s)
+ # "Glue" the punctuation marks to the previous character.
+ s = re.sub(r"\s+([.,;!?:)]|(\.{3,}))", r"\1", s)
+ # "Glue" the ( to the next character.
+ s = re.sub(r"\(\s*(\w)", r"(\1", s)
+ sentences_processed.append(s)
+ return sentences_processed
+
+
+def post_process_sentences(sentences: List[str], modify=True) -> List[str]:
+ """
+ Post-processes a list of sentences by changing punctuation marks back to words
+ and applying additional modifications.
+
+ Args:
+ sentences (List[str]): The list of sentences to be post-processed.
+
+ Returns:
+ List[str]: A list of post-processed sentences.
+ """
+ post_processed_sentences = []
+ for sentence in sentences:
+ original_sentence = sentence
+ for punctuation, word in reverse_replacement_dict.items():
+ sentence = sentence.replace(punctuation, word)
+ sentence = re.sub(
+ r"\d+",
+ lambda x: num2words(int(x.group()), lang="pt_BR", to="cardinal"),
+ sentence,
+ ).replace(",", "")
+ sentence = sentence.lower().strip()
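+        # Randomly add or remove dictation commands ("ponto", "parágrafo", "nova linha")
+        # at the sentence boundaries to diversify the output.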
+ if modify and len(original_sentence.split()) > 1:
+ if sentence.endswith("ponto") and random.random() < 0.33:
+ sentence = sentence[:-6] # Remove "ponto" from the end
+ if random.random() < 0.25:
+ sentence = random.choice(
+ ["parágrafo " + sentence, "nova linha " + sentence]
+ )
+ elif random.random() < 0.25:
+ sentence = random.choice(
+ ["ponto parágrafo " + sentence, "ponto nova linha " + sentence]
+ )
+ if not sentence.endswith("ponto"):
+ if random.random() < 0.25:
+ sentence += random.choice([" ponto parágrafo", " ponto nova linha"])
+ else:
+ if random.random() < 0.25:
+ sentence += random.choice([" parágrafo", " nova linha"])
+ post_processed_sentences.append(sentence.strip())
+ return post_processed_sentences
+
+
+def print_sentences_comparison(sentences: List[str], augmented_sentences: List[str]):
+ """
+ Prints the original and augmented sentences for comparison.
+
+ Args:
+ sentences (List[str]): The original sentences.
+ augmented_sentences (List[str]): The augmented sentences.
+ """
+ print("\nResults:")
+ print("-------------")
+ for i, (original, augmented) in enumerate(zip(sentences, augmented_sentences)):
+ print(f"src {i + 1}: {original.strip()}")
+ print(f"aug {i + 1}: {augmented.strip()}\n")
+
+
+def remove_equal_sentences(
+ sentences: List[str], final_sentences: List[str]
+) -> List[str]:
+ """
+ Removes duplicate sentences from the final list of sentences.
+
+ Args:
+ sentences (List[str]): The original list of sentences.
+ final_sentences (List[str]): The final list of sentences.
+
+ Returns:
+ List[str]: A list of unique sentences.
+ """
+ sentences_set = set(sentences)
+ modified_sentences = []
+
+ for sentence in final_sentences:
+ if sentence in sentences_set:
+ continue
+ modified_sentences.append(sentence.strip())
+
+ return modified_sentences
diff --git a/util/text_augmenter.py b/util/text_augmenter.py
new file mode 100644
index 0000000..a4aed62
--- /dev/null
+++ b/util/text_augmenter.py
@@ -0,0 +1,41 @@
+#!/usr/bin/env python3
+from typing import List
+from nlpaug.augmenter.word import RandomWordAug, SynonymAug
+
+
+class SentenceAugmenter:
+ def __init__(
+ self,
+ augmenter_type: str,
+ lang: str = "por",
+ action: str = None,
+ aug_min: int = 1,
+ aug_max: int = 10,
+ aug_p: float = 0.3,
+ ):
+ self.lang = lang
+ self.action = action
+ self.aug_min = aug_min
+ self.aug_max = aug_max
+ self.aug_p = aug_p
+ if augmenter_type == "random":
+ self.augmenter = RandomWordAug(
+ action=self.action,
+ aug_min=self.aug_min,
+ aug_max=self.aug_max,
+ aug_p=self.aug_p,
+ )
+ elif augmenter_type == "synonym":
+ self.augmenter = SynonymAug(
+ aug_src="wordnet",
+ aug_min=self.aug_min,
+ aug_max=self.aug_max,
+ aug_p=self.aug_p,
+ lang=self.lang,
+ )
+ else:
+ raise ValueError("Invalid augmenter_type")
+
+ def augment_sentences(self, sentences: List[str]) -> List[str]:
+ print(f"Augmenting {len(sentences)} sentences with {self.augmenter}...")
+ return self.augmenter.augment(sentences)