Skip to content

Commit

Permalink
update: stop using modal.NFS in OCR example
Browse files Browse the repository at this point in the history
  • Loading branch information
thundergolfer committed Jul 28, 2023
1 parent 863a0f8 commit 18bb294
Showing 1 changed file with 24 additions and 11 deletions.
35 changes: 24 additions & 11 deletions 09_job_queues/doc_ocr_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,41 @@
#
# `donut` downloads the weights for pre-trained models to a local directory, if those weights don't already exist.
# To decrease start-up time, we want this download to happen just once, even across separate function invocations.
# To accomplish this, we use a [`NetworkFileSystem`](/docs/guide/shared-volumes), a writable volume that can be attached
# to Modal functions and persisted across function runs.
# To accomplish this, we use the [`Image.run_function`](/docs/reference/modal.Image#run_function) method, which allows
# us to run some code at image build time to save the model weights into the image.

# Persisted network file system named "doc_ocr_model_vol"; the handler's
# `network_file_systems` argument maps CACHE_PATH to it.
# NOTE(review): the commit title says the example stops using modal.NFS, so
# this line looks like the removed side of the diff -- confirm before relying on it.
volume = modal.NetworkFileSystem.persisted("doc_ocr_model_vol")
# Directory passed as `cache_dir` both when the weights are downloaded and
# when the model is loaded in the handler.
CACHE_PATH = "/root/model_cache"
# Repo id passed to `snapshot_download` and to `DonutModel.from_pretrained`.
MODEL_NAME = "naver-clova-ix/donut-base-finetuned-cord-v2"


def download_model_weights() -> None:
    """Download the `MODEL_NAME` snapshot into `CACHE_PATH`.

    Passed to `Image.run_function` when the container image is defined, so
    the download is intended to run at image build time rather than on each
    function invocation.
    """
    # Imported inside the function; presumably because `huggingface_hub` is
    # only guaranteed to be installed inside the container image -- confirm.
    import huggingface_hub

    huggingface_hub.snapshot_download(
        cache_dir=CACHE_PATH,
        repo_id=MODEL_NAME,
    )


# Container image for the OCR handler: a Debian slim base with pinned Python
# dependencies, followed by a build step that runs `download_model_weights`.
image = (
modal.Image.debian_slim()
.pip_install(
"donut-python==1.0.7",
# Provides `snapshot_download`, which `download_model_weights` imports.
"huggingface-hub==0.16.4",
"transformers==4.21.3",
"timm==0.5.4",
)
# Must come after `pip_install`: the function imports `huggingface_hub`.
.run_function(download_model_weights)
)

# ## Handler function
#
# Now let's define our handler function. Using the [@stub.function()](https://modal.com/docs/reference/modal.Stub#function)
# decorator, we set up a Modal [Function](/docs/reference/modal.Function) that uses GPUs,
# has a [`NetworkFileSystem`](/docs/guide/shared-volumes) mount, runs on a [custom container image](/docs/guide/custom-container),
# runs on a [custom container image](/docs/guide/custom-container),
# and automatically [retries](/docs/guide/retries#function-retries) failures up to 3 times.


@stub.function(
gpu="any",
image=modal.Image.debian_slim().pip_install(
"donut-python==1.0.7",
"transformers==4.21.3",
"timm==0.5.4",
),
network_file_systems={CACHE_PATH: volume},
image=image,
retries=3,
)
def parse_receipt(image: bytes):
Expand All @@ -67,7 +80,7 @@ def parse_receipt(image: bytes):
# Use donut fine-tuned on an OCR dataset.
task_prompt = "<s_cord-v2>"
pretrained_model = DonutModel.from_pretrained(
"naver-clova-ix/donut-base-finetuned-cord-v2",
MODEL_NAME,
cache_dir=CACHE_PATH,
)

Expand Down

0 comments on commit 18bb294

Please sign in to comment.