Skip to content

Commit

Permalink
Batch processing (#13)
Browse files Browse the repository at this point in the history
This PR contains some more behind-the-scenes improvements related to
ease-of-use and/or batch processing for the `Summarizer` class:

- disable the progress bar for within-loop summarization of a single
long string
- add a 'smart' `__call__` function that hands off to the text and
filepath processing functions
- small improvements/updates to docs

---------

Signed-off-by: peter szemraj <[email protected]>
  • Loading branch information
pszemraj authored Feb 18, 2024
1 parent d51c4cd commit 82bafca
Show file tree
Hide file tree
Showing 8 changed files with 168 additions and 97 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file. Dates are displayed in UTC.

Generated by [`auto-changelog`](https://github.com/CookPete/auto-changelog).

#### [v0.2.0](https://github.com/pszemraj/textsum/compare/v0.1.5...v0.2.0)

> 8 July 2023
- Draft: support faster inference methods [`#8`](https://github.com/pszemraj/textsum/pull/8)

#### [v0.1.5](https://github.com/pszemraj/textsum/compare/v0.1.3...v0.1.5)

> 31 January 2023
Expand Down
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,6 @@ summarizer = Summarizer(load_in_8bit=True)

If using the python API, it's better to initiate tf32 yourself; see [here](https://huggingface.co/docs/transformers/perf_train_gpu_one#tf32) for how.

Here are some suggestions for additions to the README in order to reflect the latest changes in the `__init__` method of your `Summarizer` class:

### Using Optimum ONNX Runtime

> ⚠️ **Note:** This feature is experimental and might not work as expected. Use at your own risk. ⚠️🧪
Expand Down Expand Up @@ -324,7 +322,8 @@ See the [CONTRIBUTING.md](CONTRIBUTING.md) file for details on how to contribute
- [x] LLM.int8 inference
- [x] optimum inference integration
- [ ] better documentation [in the wiki](https://github.com/pszemraj/textsum/wiki), details on improving performance (speed, quality, memory usage, etc.)
- [ ] improvements to the PDF OCR helper module
- [x] in-progress
- [ ] improvements to the PDF OCR helper module (_TBD - may focus more on being a summarization tool_)

_Other ideas? Open an issue or PR!_

Expand Down
8 changes: 4 additions & 4 deletions src/textsum/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
"""
import sys

from . import summarize, utils

if sys.version_info[:2] >= (3, 8):
# Import directly (no need for conditional) when `python_requires = >= 3.8`
from importlib.metadata import PackageNotFoundError, version # pragma: no cover
from importlib.metadata import PackageNotFoundError # pragma: no cover
from importlib.metadata import version
else:
from importlib_metadata import PackageNotFoundError, version # pragma: no cover
from importlib_metadata import PackageNotFoundError # pragma: no cover
from importlib_metadata import version

try:
# Change here if project is renamed and does not equal the package name
Expand Down
26 changes: 10 additions & 16 deletions src/textsum/app.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,31 @@
"""
app.py - a module to run the text summarization app (gradio interface)
"""

import contextlib
import logging
import os
import random
import re
import time
from pathlib import Path

os.environ["USE_TORCH"] = "1"
os.environ["DEMO_MAX_INPUT_WORDS"] = "2048" # number of words to truncate input to
os.environ["DEMO_MAX_INPUT_PAGES"] = "20" # number of pages to truncate PDFs to
os.environ["TOKENIZERS_PARALLELISM"] = "false" # parallelism is buggy with gradio

logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)

import gradio as gr
import nltk
from cleantext import clean
from doctr.models import ocr_predictor

from textsum.pdf2text import convert_PDF_to_Text
from textsum.summarize import Summarizer
from textsum.utils import truncate_word_count, get_timestamp
from textsum.utils import get_timestamp, truncate_word_count

os.environ["USE_TORCH"] = "1"
os.environ["DEMO_MAX_INPUT_WORDS"] = "2048" # number of words to truncate input to
os.environ["DEMO_MAX_INPUT_PAGES"] = "20" # number of pages to truncate PDFs to
os.environ["TOKENIZERS_PARALLELISM"] = "false" # parallelism is buggy with gradio

logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
_here = Path.cwd()

nltk.download("stopwords") # TODO=find where this requirement originates from
Expand Down Expand Up @@ -214,7 +213,6 @@ def main():

demo = gr.Blocks()
with demo:

gr.Markdown("# Summarization UI with `textsum`")
gr.Markdown(
f"""
Expand All @@ -224,21 +222,18 @@ def main():
"""
)
with gr.Column():

gr.Markdown("## Load Inputs & Select Parameters")
gr.Markdown(
"Enter text below in the text area. The text will be summarized [using the selected parameters](https://huggingface.co/blog/how-to-generate). Optionally load an example below or upload a file. (`.txt` or `.pdf` - _[link to guide](https://i.imgur.com/c6Cs9ly.png)_)"
)
with gr.Row(variant="compact"):
with gr.Column(scale=0.5, variant="compact"):

num_beams = gr.Radio(
choices=[2, 3, 4],
label="Beam Search: # of Beams",
value=2,
)
with gr.Column(variant="compact"):

uploaded_file = gr.File(
label="File Upload",
file_count="single",
Expand All @@ -251,7 +246,6 @@ def main():
placeholder="Enter text to summarize, the text will be cleaned and truncated on Spaces. Narrative, academic (both papers and lecture transcription), and article text work well. May take a bit to generate depending on the input text :)",
)
with gr.Column(min_width=100, scale=0.5):

load_file_button = gr.Button("Upload File")

with gr.Column():
Expand Down
10 changes: 7 additions & 3 deletions src/textsum/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Usage:
textsum-dir --help
"""

import logging
import pprint as pp
import random
Expand Down Expand Up @@ -31,7 +32,7 @@ def main(
batch_length: int = 4096,
batch_stride: int = 16,
num_beams: int = 4,
length_penalty: float = 0.8,
length_penalty: float = 1.0,
repetition_penalty: float = 2.5,
max_length_ratio: float = 0.25,
min_length: int = 8,
Expand All @@ -44,6 +45,7 @@ def main(
logfile: Optional[str] = None,
file_extension: str = "txt",
skip_completed: bool = False,
disable_progress_bar: bool = False,
):
"""
Main function to summarize text files in a directory.
Expand All @@ -61,7 +63,7 @@ def main(
batch_length (int, optional): The length of each batch. Default: 4096.
batch_stride (int, optional): The stride of each batch. Default: 16.
num_beams (int, optional): The number of beams to use for beam search. Default: 4.
length_penalty (float, optional): The length penalty to use for decoding. Default: 0.8.
length_penalty (float, optional): The length penalty to use for decoding. Default: 1.0.
repetition_penalty (float, optional): The repetition penalty to use for beam search. Default: 2.5.
max_length_ratio (float, optional): The maximum length of the summary as a ratio of the batch length. Default: 0.25.
min_length (int, optional): The minimum length of the summary. Default: 8.
Expand All @@ -74,6 +76,7 @@ def main(
logfile (str, optional): Path to the log file. This will set loglevel to INFO (if not set) and write to the file.
file_extension (str, optional): The file extension to use when searching for input files., defaults to "txt"
skip_completed (bool, optional): Skip files that have already been summarized. Default: False.
disable_progress_bar (bool, optional): Disable the progress bar for intra-file summarization batches. Default: False.
Returns:
None
Expand Down Expand Up @@ -107,6 +110,7 @@ def main(
compile_model=compile,
optimum_onnx=optimum_onnx,
force_cache=force_cache,
disable_progress_bar=disable_progress_bar,
**params,
)
summarizer.print_config()
Expand Down Expand Up @@ -142,7 +146,7 @@ def main(
failed_files.append(f)
if isinstance(e, RuntimeError):
# if a runtime error occurs, exit immediately
logging.error("Not continuing summarization due to runtime error")
logging.error("Stopping summarization: runtime error")
failed_files.extend(input_files[input_files.index(f) + 1 :])
break

Expand Down
18 changes: 6 additions & 12 deletions src/textsum/pdf2text.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,6 @@
"""

import logging
from pathlib import Path

logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(message)s",
datefmt="%m/%d/%Y %I:%M:%S",
)


import os
import re
import shutil
Expand All @@ -29,6 +20,12 @@
from doctr.models import ocr_predictor
from spellchecker import SpellChecker

logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(message)s",
datefmt="%m/%d/%Y %I:%M:%S",
)


def simple_rename(filepath, target_ext=".txt"):
_fp = Path(filepath)
Expand Down Expand Up @@ -136,7 +133,6 @@ def clean_OCR(ugly_text: str):


def move2completed(from_dir, filename, new_folder="completed", verbose=False):

# this is the better version
old_filepath = join(from_dir, filename)

Expand Down Expand Up @@ -275,7 +271,6 @@ def cleantxt_ocr(ugly_text, lower=False, lang: str = "en") -> str:


def format_ocr_out(OCR_data):

if isinstance(OCR_data, list):
text = " ".join(OCR_data)
else:
Expand Down Expand Up @@ -322,7 +317,6 @@ def convert_PDF_to_Text(
ocr_model=None,
max_pages: int = 20,
):

st = time.perf_counter()
PDF_file = Path(PDF_file)
ocr_model = ocr_predictor(pretrained=True) if ocr_model is None else ocr_model
Expand Down
Loading

0 comments on commit 82bafca

Please sign in to comment.