diff --git a/CHANGELOG.md b/CHANGELOG.md index 825c32f..e048265 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1 +1,20 @@ -# Changelog +### Changelog + +All notable changes to this project will be documented in this file. Dates are displayed in UTC. + +Generated by [`auto-changelog`](https://github.com/CookPete/auto-changelog). + +#### [v0.0.5](https://github.com/pszemraj/textsum/compare/v0.0.1...v0.0.5) + +> 16 January 2023 + +- Summarization Pipeline CLI [`#2`](https://github.com/pszemraj/textsum/pull/2) + +#### v0.0.1 + +> 20 December 2022 + +- min working example [`#1`](https://github.com/pszemraj/textsum/pull/1) +- 🚚 migrate docsum space files [`a33b00c`](https://github.com/pszemraj/textsum/commit/a33b00c676add7db63a163b37f6ca6dba61d646b) +- 🎉 add pyscaffold skeleton [`cacaea3`](https://github.com/pszemraj/textsum/commit/cacaea3840ac620dedfcbdce8f92ae023fbf161b) +- Initial commit [`ec48913`](https://github.com/pszemraj/textsum/commit/ec48913456d314908838db7574183e21e698a066) diff --git a/README.md b/README.md index da904a9..694253c 100644 --- a/README.md +++ b/README.md @@ -14,24 +14,24 @@ > utility for using transformers summarization models on text docs -An extension/generalization of the [document summarization]() space on huggingface. The purpose of this package is to provide a simple interface for using summarization models on text documents of arbitrary length. +The purpose of this package is to provide a simple interface (python API, CLI, gradio web UI) for using summarization models on text documents of arbitrary length. ⚠️ **WARNING**: _This package is a WIP and is not ready for production use. Some things may not work yet._ ⚠️ ## Installation -Install the package using pip: +Install using pip: ```bash # create a virtual environment (optional) pip install git+https://github.com/pszemraj/textsum.git ``` -The textsum package is now installed in your virtual environment. You can now use the CLI or UI demo (see [Usage](#usage)). +The `textsum` package is now installed in your virtual environment. You can now use the CLI or python API to summarize text docs see the [Usage](#usage) section for more details. -### Full Installation _(PDF OCR, gradio UI demo)_ +### Full Installation -To install all the dependencies _(includes PDF OCR, gradio UI demo)_, run: +To install all the dependencies _(includes PDF OCR, gradio UI demo, optimum, etc)_, run: ```bash git clone https://github.com/pszemraj/textsum.git @@ -42,6 +42,31 @@ pip install -e .[all] ## Usage +There are three ways to use this package: + +1. [python API](#python-api) +2. [CLI](#cli) +3. [Demo App](#demo-app) + +### Python API + +```python +from textsum.summarize import Summarizer + +summarizer = Summarizer() # loads default model and parameters + +# summarize a long string +out_str = summarizer.summarize_string('This is a long string of text that will be summarized.') +print(f'summary: {out_str}') +``` + +you can also directly summarize a file: + +```python +out_path = summarizer.summarize_file('/path/to/file.txt') +print(f'summary saved to {out_path}') +``` + ### CLI To summarize a directory of text files, run the following command: @@ -66,27 +91,36 @@ For more information, run: textsum-dir --help ``` -### UI Demo +### Demo App + +For convenience, a UI demo[^1] is provided using [gradio](https://gradio.app/). To ensure you have the dependencies installed, clone the repo and run the following command: + +```bash +pip install -e .[app] +``` -For convenience, a UI demo is provided using [gradio](https://gradio.app/). 
To run the demo, run the following command: +To run the demo, run the following command: ```bash textsum-ui ``` -This is currently a minimal demo, but it will be expanded in the future to accept other arguments and options. +This will start a local server that you can access in your browser & a shareable link will be printed to the console. + +[^1]: The demo is currently minimal, but will be expanded in the future to accept other arguments and options. --- ## Roadmap -- [ ] add argparse CLI for UI demo - [x] add CLI for summarization of all text files in a directory -- [ ] python API for summarization of text docs -- [ ] optimum inference integration -- [ ] better documentation, details on improving performance (speed, quality, memory usage, etc.) +- [x] python API for summarization of text docs +- [ ] add argparse CLI for UI demo +- [ ] put on pypi +- [ ] optimum inference integration, LLM.int8 inference +- [ ] better documentation [in the wiki](https://github.com/pszemraj/textsum/wiki), details on improving performance (speed, quality, memory usage, etc.) -and other things I haven't thought of yet +_Other ideas? Open an issue or PR!_ --- diff --git a/setup.cfg b/setup.cfg index f72c610..6513343 100644 --- a/setup.cfg +++ b/setup.cfg @@ -70,9 +70,10 @@ optimum = optimum PDF = python-doctr[torch] pyspellchecker -app = gradio -all = +app = + gradio %(PDF)s +all = %(app)s %(optimum)s %(8bit)s diff --git a/src/textsum/__init__.py b/src/textsum/__init__.py index 243e7cf..45361f0 100644 --- a/src/textsum/__init__.py +++ b/src/textsum/__init__.py @@ -4,7 +4,7 @@ """ import sys -from . import cli, utils +from . import summarize, utils if sys.version_info[:2] >= (3, 8): # TODO: Import directly (no need for conditional) when `python_requires = >= 3.8` diff --git a/src/textsum/app.py b/src/textsum/app.py index faab016..b3a566f 100644 --- a/src/textsum/app.py +++ b/src/textsum/app.py @@ -10,6 +10,10 @@ from pathlib import Path os.environ["USE_TORCH"] = "1" +os.environ["DEMO_MAX_INPUT_WORDS"] = "2048" # number of words to truncate input to +os.environ["DEMO_MAX_INPUT_PAGES"] = "20" # number of pages to truncate PDFs to +os.environ["TOKENIZERS_PARALLELISM"] = "false" # parallelism is buggy with gradio + logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" ) @@ -17,34 +21,31 @@ import gradio as gr import nltk from cleantext import clean -from doctr.io import DocumentFile from doctr.models import ocr_predictor from textsum.pdf2text import convert_PDF_to_Text -from textsum.summarize import load_model_and_tokenizer, summarize_via_tokenbatches -from textsum.utils import save_summary, truncate_word_count +from textsum.summarize import Summarizer +from textsum.utils import truncate_word_count, get_timestamp -_here = Path(__file__).parent +_here = Path.cwd() nltk.download("stopwords") # TODO=find where this requirement originates from def proc_submission( input_text: str, - model_size: str, - num_beams, - token_batch_length, - length_penalty, - repetition_penalty, - no_repeat_ngram_size, - max_input_length: int = 1024, + num_beams: int, + token_batch_length: int, + length_penalty: float, + repetition_penalty: float, + no_repeat_ngram_size: int, + max_input_words: int = 1024, ): """ proc_submission - a helper function for the gradio module to process submissions Args: input_text (str): the input text to summarize - model_size (str): the size of the model to use num_beams (int): the number of beams to use token_batch_length (int): the length of the token batches to use 
length_penalty (float): the length penalty to use @@ -55,17 +56,13 @@ def proc_submission( Returns: str in HTML format, string of the summary, str of score """ - global model, tokenizer, model_sm, tokenizer_sm - # assert that the model is loaded and accessible - if "model" not in globals(): - model, tokenizer = load_model_and_tokenizer( - "pszemraj/pegasus-x-large-book-summary" - ) - if "model_sm" not in globals(): - model_sm, tokenizer_sm = load_model_and_tokenizer( - "pszemraj/long-t5-tglobal-base-16384-book-summary" - ) + global summarizer + max_input_words = ( + int(os.environ["DEMO_MAX_INPUT_WORDS"]) + if int(os.environ["DEMO_MAX_INPUT_WORDS"]) > 0 + else max_input_words + ) settings = { "length_penalty": float(length_penalty), "repetition_penalty": float(repetition_penalty), @@ -77,22 +74,33 @@ def proc_submission( "early_stopping": True, "do_sample": False, } + + if "summarizer" not in globals(): + logging.info("model not loaded, reloading now") + summarizer = Summarizer( + use_cuda=True, + token_batch_length=token_batch_length, + **settings, + ) + st = time.perf_counter() history = {} clean_text = clean(input_text, lower=False) - max_input_length = 2048 if "base" in model_size.lower() else max_input_length - processed = truncate_word_count(clean_text, max_input_length) + processed = truncate_word_count( + clean_text, + max_words=max_input_words, + ) if processed["was_truncated"]: tr_in = processed["truncated_text"] - # create elaborate HTML warning input_wc = re.split(r"\s+", input_text) + msg = f"""

Warning
- Input text was truncated to {max_input_length} words. That's about {100*max_input_length/len(input_wc):.2f}% of the submission.
+ Input text was truncated to {max_input_words} words. That's about {100*max_input_words/len(input_wc):.2f}% of the submission.

- """ + """ # create elaborate HTML warning message logging.warning(msg) history["WARNING"] = msg else: @@ -100,38 +108,38 @@ def proc_submission( msg = None if len(input_text) < 50: - # this is essentially a different case from the above msg = f"""

Warning
Input text is too short to summarize. Detected {len(input_text)} characters. Please load text by selecting an example from the dropdown menu or by pasting text into the text box.

- """ + """ # no-input warning logging.warning(msg) logging.warning("RETURNING EMPTY STRING") history["WARNING"] = msg return msg, "", [] - _summaries = summarize_via_tokenbatches( - tr_in, - model_sm if model_size == "LongT5-base" else model, - tokenizer_sm if model_size == "LongT5-base" else tokenizer, + processed_outputs = summarizer.summarize_via_tokenbatches( + input_text=tr_in, batch_length=token_batch_length, - **settings, - ) - sum_text = [f"Section {i}: " + s["summary"][0] for i, s in enumerate(_summaries)] + ) # get the summaries + + # reformat output + history["Summary Scores"] = "

" + sum_text = [ + f"\tSection {i}: " + s["summary"][0] for i, s in enumerate(processed_outputs) + ] sum_scores = [ - f" - Section {i}: {round(s['summary_score'],4)}" - for i, s in enumerate(_summaries) + f"\tSection {i}: {round(s['summary_score'],4)}" + for i, s in enumerate(processed_outputs) ] sum_text_out = "\n".join(sum_text) - history["Summary Scores"] = "

" scores_out = "\n".join(sum_scores) rt = round((time.perf_counter() - st) / 60, 2) - print(f"Runtime: {rt} minutes") + logging.info(f"Runtime: {rt} minutes") html = "" html += f"

Runtime: {rt} minutes on CPU

" if msg is not None: @@ -139,27 +147,31 @@ def proc_submission( html += "" - # save to file - saved_file = save_summary(_summaries) + summary_file = _here / f"summarized_{get_timestamp()}.txt" + summarizer.save_summary( + summary_data=processed_outputs, + target_file=summary_file, + ) - return html, sum_text_out, scores_out, saved_file + return html, sum_text_out, scores_out, summary_file -def load_uploaded_file(file_obj, max_pages=20): +def load_uploaded_file(file_obj, max_pages=20) -> str: """ - load_uploaded_file - process an uploaded file + load_uploaded_file - loads a file added by the user - Args: - file_obj (POTENTIALLY list): Gradio file object inside a list - - Returns: - str, the uploaded file contents + :param file_obj: a file object from gr.File() + :param int max_pages: the maximum number of pages to convert from a PDF + :return str: the text from the file """ - # file_path = Path(file_obj[0].name) - - # check if mysterious file object is a list global ocr_model + max_pages = ( + int(os.environ["DEMO_MAX_INPUT_PAGES"]) + if int(os.environ["DEMO_MAX_INPUT_PAGES"]) > 0 + else max_pages + ) + logging.info(f"Loading file, truncating to {max_pages} pages for PDFs") if isinstance(file_obj, list): file_obj = file_obj[0] file_path = Path(file_obj.name) @@ -187,19 +199,9 @@ def load_uploaded_file(file_obj, max_pages=20): def main(): - logging.info("Starting app instance") - os.environ[ - "TOKENIZERS_PARALLELISM" - ] = "false" # parallelism on tokenizers is buggy with gradio - logging.info("Loading summ models") - with contextlib.redirect_stdout(None): - model, tokenizer = load_model_and_tokenizer( - "pszemraj/pegasus-x-large-book-summary" - ) - model_sm, tokenizer_sm = load_model_and_tokenizer( - "pszemraj/long-t5-tglobal-base-16384-book-summary" - ) - # ensure that the models are global variables + logging.info(f"Starting app instance. Files will be saved to {str(_here)}") + + summarizer = Summarizer() logging.info("Loading OCR model") with contextlib.redirect_stdout(None): @@ -208,13 +210,18 @@ def main(): "crnn_mobilenet_v3_large", pretrained=True, assume_straight_pages=True, - ) + ) # mostly to pre-download the model + demo = gr.Blocks() with demo: - gr.Markdown("# Document Summarization with Long-Document Transformers") + gr.Markdown("# Summarization UI with `textsum`") gr.Markdown( - "This is an example use case for fine-tuned long document transformers. The model is trained on book summaries (via the BookSum dataset). The models in this demo are [LongT5-base](https://huggingface.co/pszemraj/long-t5-tglobal-base-16384-book-summary) and [Pegasus-X-Large](https://huggingface.co/pszemraj/pegasus-x-large-book-summary)." + f""" + This is an example use case for fine-tuned long document transformers. + - Model: `{summarizer.model_name_or_path}` + - this demo created with the [textsum](https://github.com/pszemraj/textsum) library + gradio. + """ ) with gr.Column(): @@ -225,11 +232,6 @@ def main(): with gr.Row(variant="compact"): with gr.Column(scale=0.5, variant="compact"): - model_size = gr.Radio( - choices=["LongT5-base", "Pegasus-X-large"], - label="Model Variant", - value="LongT5-base", - ) num_beams = gr.Radio( choices=[2, 3, 4], label="Beam Search: # of Beams", @@ -314,7 +316,7 @@ def main(): with gr.Column(): gr.Markdown("### About the Model") gr.Markdown( - "These models are fine-tuned on the [BookSum dataset](https://arxiv.org/abs/2105.08209).The goal was to create a model that can generalize well and is useful in summarizing lots of text in academic and daily usage." 
+ "Model(s) are fine-tuned on the [BookSum dataset](https://arxiv.org/abs/2105.08209).The goal was to create a model that can generalize well and is useful in summarizing lots of text in academic and daily usage." ) gr.Markdown("---") @@ -326,7 +328,6 @@ def main(): fn=proc_submission, inputs=[ input_text, - model_size, num_beams, token_batch_length, length_penalty, @@ -336,10 +337,13 @@ def main(): outputs=[output_text, summary_text, summary_scores, text_file], ) - demo.launch(enable_queue=True) + demo.launch(enable_queue=True, share=True) def run(): + """ + run - main entry point for the app + """ main() diff --git a/src/textsum/cli.py b/src/textsum/cli.py index b181536..a4ac7c5 100644 --- a/src/textsum/cli.py +++ b/src/textsum/cli.py @@ -1,6 +1,5 @@ """ cli.py - a module containing functions for the command line interface (to run the summarization on a directory of files) - #TODO: add a function to summarize a single file usage: textsum-dir [-h] [-o OUTPUT_DIR] [-m MODEL_NAME] [-batch BATCH_LENGTH] [-stride BATCH_STRIDE] [-nb NUM_BEAMS] [-l2 LENGTH_PENALTY] [-r2 REPETITION_PENALTY] [--no_cuda] [-length_ratio MAX_LENGTH_RATIO] [-ml MIN_LENGTH] @@ -19,113 +18,12 @@ import pprint as pp import random import sys -import warnings from pathlib import Path -import torch -from cleantext import clean from tqdm.auto import tqdm -from textsum.summarize import ( - load_model_and_tokenizer, - save_params, - summarize_via_tokenbatches, -) -from textsum.utils import get_mem_footprint, postprocess_booksummary, setup_logging - - -def summarize_text_file( - file_path: str or Path, - model, - tokenizer, - batch_length: int = 4096, - batch_stride: int = 16, - lowercase: bool = False, - **kwargs, -) -> dict: - """ - summarize_text_file - given a file path, summarize the text in the file - - :param str or Path file_path: the path to the file to summarize - :param model: the model to use for summarization - :param tokenizer: the tokenizer to use for summarization - :param int batch_length: length of each batch in tokens to summarize, defaults to 4096 - :param int batch_stride: stride between batches in tokens, defaults to 16 - :param bool lowercase: whether to lowercase the text before summarizing, defaults to False - :return: a dictionary containing the summary and other information - """ - file_path = Path(file_path) - ALLOWED_EXTENSIONS = [".txt", ".md", ".rst", ".py", ".ipynb"] - assert ( - file_path.exists() and file_path.suffix in ALLOWED_EXTENSIONS - ), f"File {file_path} does not exist or is not a text file" - - logging.info(f"Summarizing {file_path}") - with open(file_path, "r", encoding="utf-8", errors="ignore") as f: - text = clean(f.read(), lower=lowercase, no_line_breaks=True) - logging.debug( - f"Text length: {len(text)}. 
batch length: {batch_length} batch stride: {batch_stride}" - ) - summary_data = summarize_via_tokenbatches( - input_text=text, - model=model, - tokenizer=tokenizer, - batch_length=batch_length, - batch_stride=batch_stride, - **kwargs, - ) - logging.info(f"Finished summarizing {file_path}") - return summary_data - - -def process_summarization( - summary_data: dict, - target_file: str or Path, - custom_phrases: list = None, - save_scores: bool = True, -) -> None: - """ - process_summarization - given a dictionary of summary data, save the summary to a file - - :param dict summary_data: a dictionary containing the summary and other information (output from summarize_text_file) - :param str or Path target_file: the path to the file to save the summary to - :param list custom_phrases: a list of custom phrases to remove from each summary (relevant for dataset specific repeated phrases) - :param bool save_scores: whether to write the scores to a file - """ - target_file = Path(target_file).resolve() - if target_file.exists(): - warnings.warn(f"File {target_file} exists, overwriting") - - sum_text = [ - postprocess_booksummary( - s["summary"][0], - custom_phrases=custom_phrases, - ) - for s in summary_data - ] - sum_scores = [f"\n - {round(s['summary_score'],4)}" for s in summary_data] - scores_text = "\n".join(sum_scores) - full_summary = "\n\t".join(sum_text) - - with open( - target_file, - "w", - ) as fo: - - fo.writelines(full_summary) - - if save_scores: - with open( - target_file, - "a", - ) as fo: - - fo.write("\n" * 3) - fo.write(f"\n\nSection Scores for {target_file.stem}:\n") - fo.writelines(scores_text) - fo.write("\n\n---\n") - - logging.info(f"Saved summary to {target_file.resolve()}") +from textsum.summarize import Summarizer +from textsum.utils import setup_logging def get_parser(): @@ -290,22 +188,8 @@ def main(args): logging.info("starting summarization") logging.info(f"args: {pp.pformat(args)}") - device = torch.device( - "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" - ) - logging.info(f"using device: {device}") - # load the model and tokenizer - model, tokenizer = load_model_and_tokenizer( - args.model_name, use_cuda=not args.no_cuda - ) - - logging.info(f"model size: {get_mem_footprint(model)}") - # move the model to the device - model.to(device) - params = { "min_length": args.min_length, - "max_length": int(args.max_length_ratio * args.batch_length), "encoder_no_repeat_ngram_size": args.encoder_no_repeat_ngram_size, "no_repeat_ngram_size": args.no_repeat_ngram_size, "repetition_penalty": args.repetition_penalty, @@ -315,8 +199,19 @@ def main(args): "early_stopping": args.early_stopping, "do_sample": False, } + + summarizer = Summarizer( + model_name_or_path=args.model_name, + use_cuda=not args.no_cuda, + token_batch_length=args.batch_length, + batch_stride=args.batch_stride, + max_length_ratio=args.max_length_ratio, + **params, + ) + # get the input files input_files = list(Path(args.input_dir).glob("*.txt")) + logging.info(f"found {len(input_files)} input files") if args.shuffle: logging.info("shuffling input files") @@ -333,23 +228,12 @@ def main(args): # get the batches for f in tqdm(input_files): - outpath = output_dir / f"{f.stem}.summary.txt" - summary_data = summarize_text_file( - file_path=f, - model=model, - tokenizer=tokenizer, - batch_length=args.batch_length, - batch_stride=args.batch_stride, - lowercase=args.lowercase, - **params, - ) - process_summarization( - summary_data=summary_data, target_file=outpath, save_scores=True + _ = 
summarizer.summarize_file( + file_path=f, output_dir=output_dir, lowercase=args.lowercase ) logging.info(f"finished summarization loop - output dir: {output_dir.resolve()}") - save_params(params=params, output_dir=output_dir, hf_tag=args.model_name) - + summarizer.save_params(output_dir=output_dir, hf_tag=args.model_name) logging.info("finished summarizing files") diff --git a/src/textsum/pdf2text.py b/src/textsum/pdf2text.py index cbdf31e..a967a03 100644 --- a/src/textsum/pdf2text.py +++ b/src/textsum/pdf2text.py @@ -1,6 +1,9 @@ # -*- coding: utf-8 -*- """ pdf2text.py - convert pdf files to text files (OCR). helper functions for textsum + + + #TODO: rewrite this to a class with methods """ import logging diff --git a/src/textsum/summarize.py b/src/textsum/summarize.py index 3036d91..e015f27 100644 --- a/src/textsum/summarize.py +++ b/src/textsum/summarize.py @@ -3,190 +3,436 @@ """ import json import logging +import warnings from pathlib import Path import torch +from cleantext import clean from tqdm.auto import tqdm from transformers import AutoModelForSeq2SeqLM, AutoTokenizer -from textsum.utils import get_timestamp +from textsum.utils import get_timestamp, postprocess_booksummary -def load_model_and_tokenizer(model_name: str, use_cuda: bool = True): +class Summarizer: """ - load_model_and_tokenizer - a function that loads a model and tokenizer from huggingface - - Args: - model_name (str): the name of the model to load from huggingface - use_cuda (bool, optional): whether to use cuda. Defaults to True. - Returns: - AutoModelForSeq2SeqLM: the model - AutoTokenizer: the tokenizer + Summarizer - a class that contains functions for summarizing text with a transformers model """ - logger = logging.getLogger(__name__) - device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu" - logger.debug(f"loading model {model_name} to {device}") - model = AutoModelForSeq2SeqLM.from_pretrained( - model_name, - ).to(device) - tokenizer = AutoTokenizer.from_pretrained(model_name) - logger.info(f"Loaded model {model_name} to {device}") - return model, tokenizer - - -def summarize_and_score( - ids, mask, model, tokenizer, is_general_attention_model=True, **kwargs -): - """ - summarize_and_score - given a batch of ids and a mask, return a summary and a score for the summary - - Args: - ids (): the batch of ids - mask (): the attention mask for the batch - model (): the model to use for summarization - tokenizer (): the tokenizer to use for summarization - is_general_attention_model (bool, optional): whether the model is a general attention model. Defaults to True. 
- - Returns: - str: the summary of the batch - """ - - ids = ids[None, :] - mask = mask[None, :] - - input_ids = ids.to("cuda") if torch.cuda.is_available() else ids - attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask + def __init__( + self, + model_name_or_path: str = "pszemraj/long-t5-tglobal-base-16384-book-summary", + use_cuda: bool = True, + is_general_attention_model: bool = True, + token_batch_length: int = 2048, + batch_stride: int = 16, + max_length_ratio: float = 0.25, + **kwargs, + ): + """ + __init__ - initialize the Summarizer class + + :param str model_name_or_path: the name or path of the model to load, defaults to "pszemraj/long-t5-tglobal-base-16384-book-summary" + :param bool use_cuda: whether to use cuda, defaults to True + :param bool is_general_attention_model: whether the model is a general attention model, defaults to True + :param int token_batch_length: the amount of tokens to process in a batch, defaults to 2048 + :param int batch_stride: the amount of tokens to stride the batch by, defaults to 16 + :param float max_length_ratio: the ratio of the token_batch_length to use as the max_length for the model, defaults to 0.25 + :param kwargs: additional keyword arguments to pass to the model as inference parameters + """ + self.logger = logging.getLogger(__name__) + + self.model_name_or_path = model_name_or_path + self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu" + self.logger.debug(f"loading model {model_name_or_path} to {self.device}") + self.model = AutoModelForSeq2SeqLM.from_pretrained( + self.model_name_or_path, + ).to(self.device) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path) + self.is_general_attention_model = ( + is_general_attention_model # TODO: add a check later + ) - global_attention_mask = torch.zeros_like(attention_mask) - # put global attention on token - global_attention_mask[:, 0] = 1 + self.logger.info(f"Loaded model {model_name_or_path} to {self.device}") + + # set batch processing parameters + self.token_batch_length = token_batch_length + self.batch_stride = batch_stride + self.max_len_ratio = max_length_ratio + + self.settable_inference_params = [ + "min_length", + "max_length", + "no_repeat_ngram_size", + "encoder_no_repeat_ngram_size", + "repetition_penalty", + "num_beams", + "num_beam_groups", + "length_penalty", + "early_stopping", + "do_sample", + ] # list of inference parameters that can be set + self.inference_params = { + "min_length": 8, + "max_length": int(token_batch_length * max_length_ratio), + "no_repeat_ngram_size": 3, + "encoder_no_repeat_ngram_size": 4, + "repetition_penalty": 2.5, + "num_beams": 4, + "num_beam_groups": 1, + "length_penalty": 0.8, + "early_stopping": True, + "do_sample": False, + } # default inference parameters + + for key, value in kwargs.items(): + if key in self.settable_inference_params: + self.inference_params[key] = value + else: + self.logger.warning( + f"{key} is not a supported inference parameter, ignoring" + ) + + def set_inference_params( + self, + new_params: dict = None, + config_file: str or Path = None, + config_metadata_id: str = "META_", + ): + """ + set_inference_params - update the inference parameters to use when summarizing text + + :param dict new_params: a dictionary of new inference parameters to use, defaults to None + :param str or Path config_file: a path to a json file containing inference parameters, defaults to None + + NOTE: if both new_params and config_file are provided, entries in the config_file will 
overwrite entries in new_params if they have the same key + """ + + assert ( + new_params or config_file + ), "must provide new_params or config_file to set inference parameters" + + new_params = new_params or {} + # load from config file if provided + if config_file: + with open(config_file, "r") as f: + config_params = json.load(f) + config_params = { + k: v + for k, v in config_params.items() + if k in self.settable_inference_params + } # remove key:value pairs that start with config_metadata_id + new_params.update(config_params) + self.logger.info(f"loaded inference parameters from {config_file}") + self.logger.debug(f"inference parameters: {new_params}") + + for key, value in new_params.items(): + if key in self.settable_inference_params: + self.inference_params[key] = value + else: + self.logger.warning( + f"{key} is not a valid inference parameter, ignoring" + ) + + def get_inference_params(self): + """get the inference parameters currently being used""" + return self.inference_params + + def summarize_and_score(self, ids, mask, **kwargs): + """ + summarize_and_score - summarize a batch of text and return the summary and output scores + + :param ids: the token ids of the tokenized batch to summarize + :param mask: the attention mask of the tokenized batch to summarize + :return tuple: a tuple containing the summary and output scores + """ + + ids = ids[None, :] + mask = mask[None, :] + + input_ids = ids.to("cuda") if torch.cuda.is_available() else ids + attention_mask = mask.to("cuda") if torch.cuda.is_available() else mask + + global_attention_mask = torch.zeros_like(attention_mask) + # put global attention on token + global_attention_mask[:, 0] = 1 + + if self.is_general_attention_model: + summary_pred_ids = self.model.generate( + input_ids, + attention_mask=attention_mask, + output_scores=True, + return_dict_in_generate=True, + **kwargs, + ) + else: + # this is for LED etc. 
+ summary_pred_ids = self.model.generate( + input_ids, + attention_mask=attention_mask, + global_attention_mask=global_attention_mask, + output_scores=True, + return_dict_in_generate=True, + **kwargs, + ) + summary = self.tokenizer.batch_decode( + summary_pred_ids.sequences, + skip_special_tokens=True, + remove_invalid_values=True, + ) + score = round(summary_pred_ids.sequences_scores.cpu().numpy()[0], 4) + + return summary, score + + def summarize_via_tokenbatches( + self, + input_text: str, + batch_length: int = None, + batch_stride: int = None, + **kwargs, + ): + """ + summarize_via_tokenbatches - given a string of text, split it into batches of tokens and summarize each batch + + :param str input_text: the text to summarize + :param int batch_length: number of tokens to include in each input batch, default None (self.token_batch_length) + :param int batch_stride: number of tokens to stride between batches, default None (self.token_batch_stride) + :return: a list of summaries, a list of scores, and a list of the input text for each batch + """ + + logger = logging.getLogger(__name__) + # log all input parameters + if batch_length and batch_length < 512: + logger.warning( + "WARNING: entered batch_length was too low at {batch_length}, resetting to 512" + ) + batch_length = 512 + + logger.debug( + f"batch_length: {batch_length} batch_stride: {batch_stride}, kwargs: {kwargs}" + ) + if kwargs: + # if received kwargs, update inference params + self.set_inference_params(**kwargs) + + params = self.get_inference_params() + + encoded_input = self.tokenizer( + input_text, + padding="max_length", + truncation=True, + max_length=batch_length or self.token_batch_length, + stride=batch_stride or self.batch_stride, + return_overflowing_tokens=True, + add_special_tokens=False, + return_tensors="pt", + ) + in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask + + gen_summaries = [] + pbar = tqdm(total=len(in_id_arr), desc="Generating Summaries") + + for _id, _mask in zip(in_id_arr, att_arr): + + result, score = self.summarize_and_score( + ids=_id, + mask=_mask, + **params, + ) + score = round(float(score), 4) + _sum = { + "input_tokens": _id, + "summary": result, + "summary_score": score, + } + gen_summaries.append(_sum) + logger.debug(f"\n\t{result[0]}\nScore:\t{score}") + pbar.update() + + pbar.close() + + return gen_summaries + + def save_summary( + self, + summary_data: dict, + target_file: str or Path = None, + postprocess: bool = True, + custom_phrases: list = None, + save_scores: bool = True, + return_string: bool = False, + ): + """ + save_summary - a function that takes the output of summarize_via_tokenbatches and saves it to a file after postprocessing + + :param dict summary_data: output of summarize_via_tokenbatches containing the summary and score for each batch + :param str or Path target_file: the file to save the summary to, defaults to None + :param bool postprocess: whether to postprocess the summary, defaults to True + :param list custom_phrases: a list of custom phrases to use in postprocessing, defaults to None + :param bool save_scores: whether to save the scores for each batch, defaults to True + :param bool return_string: whether to return the summary as a string, defaults to False + + :return: None or str if return_string is True + """ + assert ( + target_file or return_string + ), "Must specify a target file or return_string=True" + + if postprocess: + sum_text = [ + postprocess_booksummary( + s["summary"][0], + custom_phrases=custom_phrases, + ) + for s 
in summary_data + ] + else: + sum_text = [s["summary"][0] for s in summary_data] + + sum_scores = [f"\n - {round(s['summary_score'],4)}" for s in summary_data] + scores_text = "\n".join(sum_scores) + full_summary = "\n\t".join(sum_text) + + if return_string: + return full_summary + + target_file = Path(target_file) + if not target_file.parent.exists(): + logging.info(f"Creating directory {target_file.parent}") + target_file.parent.mkdir(parents=True) + if target_file.exists(): + warnings.warn(f"File {target_file} exists, overwriting") + + with open( + target_file, + "w", + encoding="utf-8", + errors="ignore", + ) as fo: + + fo.writelines(full_summary) + + if save_scores: + with open( + target_file, + "a", + encoding="utf-8", + errors="ignore", + ) as fo: + + fo.write("\n" * 3) + fo.write(f"\n\nSection Scores for {target_file.stem}:\n") + fo.writelines(scores_text) + fo.write("\n\n---\n") + + self.logger.info(f"Saved summary to {target_file.resolve()}") + + def summarize_string( + self, + input_text: str, + batch_length: int = None, + batch_stride: int = None, + **kwargs, + ) -> str: + """ + summarize_string - generate a summary for a string of text + + :param str input_text: the text to summarize + :param int batch_length: number of tokens to use in each batch, defaults to None (self.token_batch_length) + :param int batch_stride: number of tokens to stride between batches, defaults to None (self.batch_stride) + :return str: the summary + """ + + logger = logging.getLogger(__name__) + # log all input parameters + if batch_length and batch_length < 512: + logger.warning( + "WARNING: entered batch_length was too low at {batch_length}, resetting to 512" + ) + batch_length = 512 + + logger.debug( + f"batch_length: {batch_length} batch_stride: {batch_stride}, kwargs: {kwargs}" + ) - if is_general_attention_model: - summary_pred_ids = model.generate( - input_ids, - attention_mask=attention_mask, - output_scores=True, - return_dict_in_generate=True, + gen_summaries = self.summarize_via_tokenbatches( + input_text, + batch_length=batch_length, + batch_stride=batch_stride, **kwargs, ) - else: - # this is for LED etc. 
- summary_pred_ids = model.generate( - input_ids, - attention_mask=attention_mask, - global_attention_mask=global_attention_mask, - output_scores=True, - return_dict_in_generate=True, + + return self.save_summary(summary_data=gen_summaries, return_string=True) + + def summarize_file( + self, + file_path: str or Path, + output_dir: str or Path = None, + batch_length=None, + batch_stride=None, + lowercase: bool = False, + **kwargs, + ) -> Path: + """ + summarize_file - a function that takes a text file and returns a summary + + :param str or Path file_path: the path to the text file + :param str or Path output_dir: the directory to save the summary to, defaults to None (current working directory) + :param bool lowercase: whether to lowercase the text prior to summarization, defaults to False + + :return Path: the path to the summary file + """ + + file_path = Path(file_path) + output_dir = Path(output_dir) if output_dir is not None else Path.cwd() + output_file = output_dir / f"{file_path.stem}_summary.txt" + + with open(file_path, "r") as f: + text = clean(f.read(), lower=lowercase) + + gen_summaries = self.summarize_via_tokenbatches( + text, + batch_length=batch_length, + batch_stride=batch_stride, **kwargs, ) - summary = tokenizer.batch_decode( - summary_pred_ids.sequences, - skip_special_tokens=True, - remove_invalid_values=True, - ) - score = round(summary_pred_ids.sequences_scores.cpu().numpy()[0], 4) - - return summary, score - - -def summarize_via_tokenbatches( - input_text: str, - model, - tokenizer, - batch_length=4096, - batch_stride=16, - **kwargs, -): - """ - summarize_via_tokenbatches - a function that takes a string and returns a summary - - Args: - input_text (str): the text to summarize - model (): the model to use for summarizationz - tokenizer (): the tokenizer to use for summarization - batch_length (int, optional): the length of each batch. Defaults to 4096. - batch_stride (int, optional): the stride of each batch. Defaults to 16. The stride is the number of tokens that overlap between batches. 
- Returns: - str: the summary - """ + self.save_summary( + gen_summaries, + output_file, + ) - logger = logging.getLogger(__name__) - # log all input parameters - if batch_length < 512: - batch_length = 512 - logger.warning("WARNING: batch_length was set to 512") - logger.debug( - f"batch_length: {batch_length} batch_stride: {batch_stride}, kwargs: {kwargs}" - ) - encoded_input = tokenizer( - input_text, - padding="max_length", - truncation=True, - max_length=batch_length, - stride=batch_stride, - return_overflowing_tokens=True, - add_special_tokens=False, - return_tensors="pt", - ) - - in_id_arr, att_arr = encoded_input.input_ids, encoded_input.attention_mask - gen_summaries = [] - - pbar = tqdm(total=len(in_id_arr)) - - for _id, _mask in zip(in_id_arr, att_arr): - - result, score = summarize_and_score( - ids=_id, - mask=_mask, - model=model, - tokenizer=tokenizer, - **kwargs, + return output_file + + def save_params( + self, + output_dir: str or Path = None, + hf_tag: str = None, + verbose: bool = False, + ) -> None: + """ + save_params - save the parameters of the run to a json file + + :param dict params: parameters to save + :param str or Path output_dir: directory to save the parameters to + :param str hf_tag: the model tag on huggingface (will be used instead of self.model_name_or_path) + :param bool verbose: whether to log the parameters + + :return: None + """ + output_dir = Path(output_dir) if output_dir is not None else Path.cwd() + metadata_path = output_dir / "summarization_parameters.json" + + exported_params = self.get_inference_params().copy() + exported_params["META_huggingface_model"] = ( + self.model_name_or_path if hf_tag is None else hf_tag ) - score = round(float(score), 4) - _sum = { - "input_tokens": _id, - "summary": result, - "summary_score": score, - } - gen_summaries.append(_sum) - logger.debug(f"\n\t{result[0]}\nScore:\t{score}") - pbar.update() - - pbar.close() - - return gen_summaries - - -def save_params( - params: dict, - output_dir: str or Path, - hf_tag: str = None, - verbose: bool = False, -) -> None: - """ - save_params - save the parameters of the run to a json file + exported_params["META_date"] = get_timestamp() - :param dict params: parameters to save - :param str or Path output_dir: directory to save the parameters to - :param str hf_tag: the model tag on huggingface - :param bool verbose: whether to log the parameters + self.logger.info(f"Saving parameters to {metadata_path}") + with open(metadata_path, "w") as write_file: + json.dump(exported_params, write_file, indent=4) - :return: None - """ - output_dir = Path(output_dir) if output_dir is not None else Path.cwd() - session_settings = params - session_settings["huggingface-model-tag"] = "" if hf_tag is None else hf_tag - session_settings["date-run"] = get_timestamp() - - metadata_path = output_dir / "summarization-parameters.json" - logging.info(f"Saving parameters to {metadata_path}") - with open(metadata_path, "w") as write_file: - json.dump(session_settings, write_file) - - logging.debug(f"Saved parameters to {metadata_path}") - if verbose: - # log the parameters - logging.info(f"parameters: {session_settings}") + logging.debug(f"Saved parameters to {metadata_path}") + if verbose: + self.logger.info(f"parameters: {exported_params}") diff --git a/src/textsum/utils.py b/src/textsum/utils.py index 73b5ad6..6538e0f 100644 --- a/src/textsum/utils.py +++ b/src/textsum/utils.py @@ -15,15 +15,6 @@ datefmt="%m/%d/%Y %I:%M:%S", ) -# ------------------------- # - -TEXT_EXAMPLE_URLS = { - 
"whisper_lecture": "https://pastebin.com/raw/X9PEgS2w", - "hf_blog_clip": "https://pastebin.com/raw/1RMg1Naz", -} - -# ------------------------- # - def get_timestamp() -> str: """ @@ -131,95 +122,6 @@ def truncate_word_count(text, max_words=512): return processed -def load_text_examples( - urls: dict = TEXT_EXAMPLE_URLS, target_dir: str or Path = None -) -> Path: - """ - load_text_examples - load the text examples from the web to a directory - - :param dict urls: the urls to the text examples, defaults to TEXT_EXAMPLE_URLS - :param str or Path target_dir: the path to the target directory, defaults to the current working directory - :return Path: the path to the directory containing the text examples - """ - target_dir = Path.cwd() if target_dir is None else Path(target_dir) - target_dir.mkdir(exist_ok=True) - - for name, url in urls.items(): # download the examples - subprocess.run(["wget", url, "-O", target_dir / f"{name}.txt"]) - - return target_dir - - -TEXT_EX_EXTENSIONS = [".txt", ".md"] - - -def load_example_filenames( - example_path: str or Path, ext: list = TEXT_EX_EXTENSIONS -) -> dict: - """ - load_example_filenames - load the example filenames from a directory - - :param strorPath example_path: the path to the examples directory - :param list ext: the file extensions to load (default: [".txt", ".md"]) - :return dict: the example filenames - """ - example_path = Path(example_path) - if not example_path.exists(): - # download the examples - logging.info("Downloading the examples...") - example_path = load_text_examples(target_dir=example_path) - - # load the examples into a list - examples = {f.name: f.resolve() for f in example_path.glob("*") if f.suffix in ext} - logging.info(f"Loaded {len(examples)} examples from {example_path}") - return examples - - -def save_summary( - summarize_output, outpath: str or Path = None, write_scores=True -) -> Path: - """ - - save_summary - save the summary generated from summarize_via_tokenbatches() to a text file - - :param list summarize_output: the output from summarize_via_tokenbatches() - :param strorPath outpath: the path to the output file, defaults to the current working directory - :param bool write_scores: whether to write the scores to the output file, defaults to True - :return Path: the path to the output file - - Example in use: - _summaries = summarize_via_tokenbatches( - text, - batch_length=token_batch_length, - batch_stride=batch_stride, - **settings, - ) - save_summary(_summaries, outpath=outpath, write_scores=True) - """ - - outpath = ( - Path.cwd() / f"document_summary_{get_timestamp()}.txt" - if outpath is None - else Path(outpath) - ) - sum_text = [s["summary"][0] for s in summarize_output] - sum_scores = [f"\n - {round(s['summary_score'],4)}" for s in summarize_output] - scores_text = "\n".join(sum_scores) - full_summary = "\n\t".join(sum_text) - - with open(outpath, "w", encoding="utf-8", errors="ignore") as fo: - fo.writelines(full_summary) - if write_scores: - with open(outpath, "a", encoding="utf-8", errors="ignore") as fo: - - fo.write("\n" * 3) - fo.write(f"\n\nSection Scores:\n") - fo.writelines(scores_text) - fo.write("\n\n---\n") - - return outpath - - def setup_logging(loglevel, logfile=None) -> None: """Setup basic logging you will need something like this in your main script: