class GoogleOcrOptions(OcrOptions):
    """Options for the Google Cloud Vision OCR engine.

    The fields below are consumed by ``GoogleOcrModel`` when it builds the
    Vision client and issues text-detection requests.
    """

    # Discriminator value used to select this options class in the
    # ``ocr_options`` union of the PDF pipeline options.
    kind: Literal["googleocr"] = "googleocr"

    # Language codes forwarded to the Vision API as ``language_hints``.
    lang: List[str] = ["en", "de"]

    # Path to the service-account credentials file. The environment variable
    # is read when an options object is created (not when this module is
    # imported), so a variable exported after import is still honoured.
    # ``None`` means "not configured"; the model raises in that case.
    google_ocr_config_file_path: Optional[str] = Field(
        default_factory=lambda: os.getenv("GOOGLE_CONFIG_FILE_PATH")
    )

    # Host name passed to the client as ``api_endpoint``. Despite the field
    # name this is an endpoint, not a bare region identifier.
    google_ocr_region: Optional[str] = "eu-vision.googleapis.com"

    model_config = ConfigDict(
        extra="forbid",
    )
import io
import logging
from typing import Iterable, List

from docling_core.types.doc import BoundingBox, CoordOrigin

from docling.datamodel.base_models import Cell, OcrCell, Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import GoogleOcrOptions
from docling.datamodel.settings import settings
from docling.models.base_ocr_model import BaseOcrModel
from docling.utils.profiling import TimeRecorder

_log = logging.getLogger(__name__)


class GoogleOcrModel(BaseOcrModel):
    """OCR model backed by the Google Cloud Vision text-detection API.

    Each OCR rectangle of a page is rendered to a PNG, sent to the Vision
    API, and every returned word becomes one ``OcrCell`` in page coordinates.
    """

    def __init__(self, enabled: bool, options: GoogleOcrOptions):
        """Create the model and, when enabled, the Vision API client.

        Args:
            enabled: When False no client is created and ``__call__`` passes
                pages through untouched.
            options: Language hints, credentials file path and API endpoint.

        Raises:
            ImportError: The Google client libraries are not installed.
            FileNotFoundError: No credentials file path was configured.
        """
        super().__init__(enabled=enabled, options=options)
        self.options: GoogleOcrOptions

        self.scale = 3  # multiplier for 72 dpi == 216 dpi.
        self.reader = None
        # The ``google.cloud.vision`` module, kept so that request objects can
        # be built later without re-importing inside the per-rectangle loop.
        self._vision = None

        if self.enabled:
            # Keep the try body to the imports only, so that an ImportError
            # raised while loading credentials or building the client is not
            # misreported as "packages not installed".
            try:
                from google.cloud import vision
                from google.oauth2 import service_account
            except ImportError as err:
                raise ImportError(
                    "Failed to import required libraries for Google OCR. Ensure that the "
                    "'google-cloud-vision' and 'google-auth' packages are installed. "
                    "You can install them using 'pip install google-cloud-vision google-auth'."
                ) from err

            _log.debug("Initializing Google OCR ")
            self._vision = vision
            self.image_context = {"language_hints": self.options.lang}
            client_options = {"api_endpoint": self.options.google_ocr_region}
            if self.options.google_ocr_config_file_path is None:
                raise FileNotFoundError(
                    "Google OCR Config File is missing. Please provide a valid file path "
                    "via the GOOGLE_CONFIG_FILE_PATH environment variable."
                )
            google_creds = service_account.Credentials.from_service_account_file(
                self.options.google_ocr_config_file_path
            )
            self.reader = vision.ImageAnnotatorClient(
                credentials=google_creds, client_options=client_options
            )

    def __call__(
        self, conv_res: ConversionResult, page_batch: Iterable[Page]
    ) -> Iterable[Page]:
        """Run OCR over every valid page of the batch and yield the pages.

        Pages are yielded unchanged when the model is disabled or when the
        page backend reports itself as invalid. Otherwise ``page.cells`` is
        replaced by the result of ``post_process_cells``.

        Raises:
            RuntimeError: The Vision API returned an error for an image.
        """
        if not self.enabled:
            yield from page_batch
            return

        for page in page_batch:
            assert page._backend is not None
            if not page._backend.is_valid():
                yield page
                continue

            with TimeRecorder(conv_res, "ocr"):
                assert self.reader is not None

                ocr_rects = self.get_ocr_rects(page)
                all_ocr_cells: List[OcrCell] = []
                for ocr_rect in ocr_rects:
                    # Skip zero area boxes
                    if ocr_rect.area() == 0:
                        continue
                    all_ocr_cells.extend(self._recognize_rect(page, ocr_rect))

                # Post-process the cells
                page.cells = self.post_process_cells(all_ocr_cells, page.cells)

            # DEBUG code:
            if settings.debug.visualize_ocr:
                self.draw_ocr_rects_and_cells(conv_res, page, ocr_rects, show=True)

            yield page

    def _recognize_rect(self, page: Page, ocr_rect: BoundingBox) -> List[OcrCell]:
        """Send one page rectangle to the Vision API and return its word cells."""
        high_res_image = page._backend.get_page_image(
            scale=self.scale, cropbox=ocr_rect
        )
        # Serialise the Pillow image to PNG bytes via an in-memory buffer.
        buffer = io.BytesIO()
        high_res_image.save(buffer, "PNG")

        google_response = self.reader.text_detection(
            image=self._vision.Image(content=buffer.getvalue()),
            image_context=self.image_context,
        )
        # Per-image failures are reported inside the response instead of being
        # raised; without this check a failed request yields zero cells.
        if google_response.error.message:
            raise RuntimeError(
                f"Google OCR request failed: {google_response.error.message}"
            )

        cells: List[OcrCell] = []
        for file_page in google_response.full_text_annotation.pages:
            for block in file_page.blocks:
                for paragraph in block.paragraphs:
                    for word in paragraph.words:
                        # Ids restart at 0 for every rectangle.
                        cells.append(self._word_to_cell(len(cells), word, ocr_rect))
        return cells

    def _word_to_cell(self, index: int, word, ocr_rect: BoundingBox) -> OcrCell:
        """Convert one Vision word annotation into an ``OcrCell`` in page coordinates."""
        vertices = word.bounding_box.vertices
        xs = [vertex.x for vertex in vertices]
        ys = [vertex.y for vertex in vertices]

        # Vertices are in pixels of the up-scaled crop: undo the scale, then
        # shift by the crop's top-left corner to get page coordinates.
        left = min(xs) / self.scale + ocr_rect.l
        top = min(ys) / self.scale + ocr_rect.t
        right = max(xs) / self.scale + ocr_rect.l
        bottom = max(ys) / self.scale + ocr_rect.t

        return OcrCell(
            id=index,
            text="".join(symbol.text for symbol in word.symbols),
            # NOTE(review): scaled by 100 here; confirm which confidence range
            # downstream consumers of OcrCell expect.
            confidence=word.confidence * 100,
            bbox=BoundingBox.from_tuple(
                coord=(left, top, right, bottom),
                origin=CoordOrigin.TOPLEFT,
            ),
        )
+# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -376,6 +376,17 @@ webencodings = "*" [package.extras] css = ["tinycss2 (>=1.1.0,<1.5)"] +[[package]] +name = "cachetools" +version = "5.5.0" +description = "Extensible memoizing collections and decorators" +optional = false +python-versions = ">=3.7" +files = [ + {file = "cachetools-5.5.0-py3-none-any.whl", hash = "sha256:02134e8439cdc2ffb62023ce1debca2944c3f289d66bb17ead3ab3dede74b292"}, + {file = "cachetools-5.5.0.tar.gz", hash = "sha256:2cc24fb4cbe39633fb7badd9db9ca6295d766d9c2995f245725a46715d050f2a"}, +] + [[package]] name = "certifi" version = "2024.8.30" @@ -1354,6 +1365,102 @@ gitdb = ">=4.0.1,<5" doc = ["sphinx (==4.3.2)", "sphinx-autodoc-typehints", "sphinx-rtd-theme", "sphinxcontrib-applehelp (>=1.0.2,<=1.0.4)", "sphinxcontrib-devhelp (==1.0.2)", "sphinxcontrib-htmlhelp (>=2.0.0,<=2.0.1)", "sphinxcontrib-qthelp (==1.0.3)", "sphinxcontrib-serializinghtml (==1.1.5)"] test = ["coverage[toml]", "ddt (>=1.1.1,!=1.4.3)", "mock", "mypy", "pre-commit", "pytest (>=7.3.1)", "pytest-cov", "pytest-instafail", "pytest-mock", "pytest-sugar", "typing-extensions"] +[[package]] +name = "google-api-core" +version = "2.24.0" +description = "Google API client core library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google_api_core-2.24.0-py3-none-any.whl", hash = "sha256:10d82ac0fca69c82a25b3efdeefccf6f28e02ebb97925a8cce8edbfe379929d9"}, + {file = "google_api_core-2.24.0.tar.gz", hash = "sha256:e255640547a597a4da010876d333208ddac417d60add22b6851a0c66a831fcaf"}, +] + +[package.dependencies] +google-auth = ">=2.14.1,<3.0.dev0" +googleapis-common-protos = ">=1.56.2,<2.0.dev0" +grpcio = [ + {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, +] 
+grpcio-status = [ + {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, +] +proto-plus = [ + {version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""}, + {version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""}, +] +protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" +requests = ">=2.18.0,<3.0.0.dev0" + +[package.extras] +async-rest = ["google-auth[aiohttp] (>=2.35.0,<3.0.dev0)"] +grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status (>=1.33.2,<2.0.dev0)", "grpcio-status (>=1.49.1,<2.0.dev0)"] +grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] +grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0.dev0)"] + +[[package]] +name = "google-auth" +version = "2.37.0" +description = "Google Authentication Library" +optional = false +python-versions = ">=3.7" +files = [ + {file = "google_auth-2.37.0-py2.py3-none-any.whl", hash = "sha256:42664f18290a6be591be5329a96fe30184be1a1badb7292a7f686a9659de9ca0"}, + {file = "google_auth-2.37.0.tar.gz", hash = "sha256:0054623abf1f9c83492c63d3f47e77f0a544caa3d40b2d98e099a611c2dd5d00"}, +] + +[package.dependencies] +cachetools = ">=2.0.0,<6.0" +pyasn1-modules = ">=0.2.1" +rsa = ">=3.1.4,<5" + +[package.extras] +aiohttp = ["aiohttp (>=3.6.2,<4.0.0.dev0)", "requests (>=2.20.0,<3.0.0.dev0)"] +enterprise-cert = ["cryptography", "pyopenssl"] +pyjwt = ["cryptography (>=38.0.3)", "pyjwt (>=2.0)"] +pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"] +reauth = ["pyu2f (>=0.1.5)"] +requests = ["requests (>=2.20.0,<3.0.0.dev0)"] + +[[package]] +name = "google-cloud-vision" +version = "3.9.0" +description = "Google Cloud Vision API client library" +optional = false +python-versions = ">=3.7" 
+files = [ + {file = "google_cloud_vision-3.9.0-py2.py3-none-any.whl", hash = "sha256:9acec27ee05bd197f0d89c97e9719712ef245e0c37fd428e6af09a15a082fbef"}, + {file = "google_cloud_vision-3.9.0.tar.gz", hash = "sha256:21226aac9cb4ba45bf89cc2e107aea19e4f78f9736eb1de56837e0c2989fecff"}, +] + +[package.dependencies] +google-api-core = {version = ">=1.34.1,<2.0.dev0 || >=2.11.dev0,<3.0.0dev", extras = ["grpc"]} +google-auth = ">=2.14.1,<2.24.0 || >2.24.0,<2.25.0 || >2.25.0,<3.0.0dev" +proto-plus = [ + {version = ">=1.22.3,<2.0.0dev", markers = "python_version < \"3.13\""}, + {version = ">=1.25.0,<2.0.0dev", markers = "python_version >= \"3.13\""}, +] +protobuf = ">=3.20.2,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0dev" + +[[package]] +name = "googleapis-common-protos" +version = "1.66.0" +description = "Common protobufs used in Google APIs" +optional = false +python-versions = ">=3.7" +files = [ + {file = "googleapis_common_protos-1.66.0-py2.py3-none-any.whl", hash = "sha256:d7abcd75fabb2e0ec9f74466401f6c119a0b498e27370e9be4c94cb7e382b8ed"}, + {file = "googleapis_common_protos-1.66.0.tar.gz", hash = "sha256:c3e7b33d15fdca5374cc0a7346dd92ffa847425cc4ea941d970f13680052ec8c"}, +] + +[package.dependencies] +protobuf = ">=3.20.2,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<6.0.0.dev0" + +[package.extras] +grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] + [[package]] name = "griffe" version = "1.5.1" @@ -1449,6 +1556,22 @@ files = [ [package.extras] protobuf = ["grpcio-tools (>=1.67.1)"] +[[package]] +name = "grpcio-status" +version = "1.67.1" +description = "Status proto mapping for gRPC" +optional = false +python-versions = ">=3.8" +files = [ + {file = "grpcio_status-1.67.1-py3-none-any.whl", hash = "sha256:16e6c085950bdacac97c779e6a502ea671232385e6e37f258884d6883392c2bd"}, + {file = "grpcio_status-1.67.1.tar.gz", hash = 
"sha256:2bf38395e028ceeecfd8866b081f61628114b384da7d51ae064ddc8d766a5d11"}, +] + +[package.dependencies] +googleapis-common-protos = ">=1.5.5" +grpcio = ">=1.67.1" +protobuf = ">=5.26.1,<6.0dev" + [[package]] name = "h11" version = "0.14.0" @@ -4362,6 +4485,23 @@ files = [ {file = "propcache-0.2.1.tar.gz", hash = "sha256:3f77ce728b19cb537714499928fe800c3dda29e8d9428778fc7c186da4c09a64"}, ] +[[package]] +name = "proto-plus" +version = "1.25.0" +description = "Beautiful, Pythonic protocol buffers." +optional = false +python-versions = ">=3.7" +files = [ + {file = "proto_plus-1.25.0-py3-none-any.whl", hash = "sha256:c91fc4a65074ade8e458e95ef8bac34d4008daa7cce4a12d6707066fca648961"}, + {file = "proto_plus-1.25.0.tar.gz", hash = "sha256:fbb17f57f7bd05a68b7707e745e26528b0b3c34e378db91eef93912c54982d91"}, +] + +[package.dependencies] +protobuf = ">=3.19.0,<6.0.0dev" + +[package.extras] +testing = ["google-api-core (>=1.31.5)"] + [[package]] name = "protobuf" version = "5.29.1" @@ -4491,6 +4631,31 @@ files = [ [package.extras] test = ["cffi", "hypothesis", "pandas", "pytest", "pytz"] +[[package]] +name = "pyasn1" +version = "0.6.1" +description = "Pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208)" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyasn1-0.6.1-py3-none-any.whl", hash = "sha256:0d632f46f2ba09143da3a8afe9e33fb6f92fa2320ab7e886e2d0f7672af84629"}, + {file = "pyasn1-0.6.1.tar.gz", hash = "sha256:6f580d2bdd84365380830acf45550f2511469f673cb4a5ae3857a3170128b034"}, +] + +[[package]] +name = "pyasn1-modules" +version = "0.4.1" +description = "A collection of ASN.1-based protocols modules" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pyasn1_modules-0.4.1-py3-none-any.whl", hash = "sha256:49bfa96b45a292b711e986f222502c1c9a5e1f4e568fc30e2574a6c7d07838fd"}, + {file = "pyasn1_modules-0.4.1.tar.gz", hash = "sha256:c28e2dbf9c06ad61c71a075c7e0f9fd0f1b0bb2d2ad4377f240d33ac2ab60a7c"}, +] + +[package.dependencies] 
+pyasn1 = ">=0.4.6,<0.7.0" + [[package]] name = "pyclipper" version = "1.3.0.post6" @@ -5842,6 +6007,20 @@ files = [ {file = "rpds_py-0.22.3.tar.gz", hash = "sha256:e32fee8ab45d3c2db6da19a5323bc3362237c8b653c70194414b892fd06a080d"}, ] +[[package]] +name = "rsa" +version = "4.9" +description = "Pure-Python RSA implementation" +optional = false +python-versions = ">=3.6,<4" +files = [ + {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"}, + {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"}, +] + +[package.dependencies] +pyasn1 = ">=0.1.3" + [[package]] name = "rtree" version = "1.3.0" @@ -7608,4 +7787,4 @@ tesserocr = ["tesserocr"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "5271637a86ae221be362a288546c9fee3e3e25e5b323c997464c032c284716bd" +content-hash = "abc7b155a26345198ea2163465eaf67a50d577c26683ad424cfeb4221a3943fa" diff --git a/pyproject.toml b/pyproject.toml index 99a32a1c..52ac4ccf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,10 @@ pydantic-settings = "^2.3.0" huggingface_hub = ">=0.23,<1" requests = "^2.32.3" easyocr = "^1.7" +google-api-core="^2.13.0" +google-auth="^2.23.4" +google-cloud-vision="^3.4.5" +googleapis-common-protos="^1.61.0" tesserocr = { version = "^2.7.1", optional = true } certifi = ">=2024.7.4" rtree = "^1.3.0" diff --git a/tests/test_e2e_ocr_conversion.py b/tests/test_e2e_ocr_conversion.py index 73a943af..6eb6ed98 100644 --- a/tests/test_e2e_ocr_conversion.py +++ b/tests/test_e2e_ocr_conversion.py @@ -7,6 +7,7 @@ from docling.datamodel.document import ConversionResult from docling.datamodel.pipeline_options import ( EasyOcrOptions, + GoogleOcrOptions, OcrMacOptions, OcrOptions, PdfPipelineOptions, @@ -62,6 +63,7 @@ def test_e2e_conversions(): TesseractOcrOptions(force_full_page_ocr=True), TesseractCliOcrOptions(force_full_page_ocr=True), 
RapidOcrOptions(force_full_page_ocr=True), + GoogleOcrOptions(force_full_page_ocr=True), ] # only works on mac