diff --git a/README.md b/README.md
index d0f6594..3f2397a 100644
--- a/README.md
+++ b/README.md
@@ -21,16 +21,16 @@ slanet_plus is an upgraded version of SLANet bundled with PaddleX, with substantially higher accuracy
 unitable is a transformer model from UniTable. It has the highest accuracy, currently supports PyTorch inference only (with GPU acceleration), and its trained weights come from the [OhMyTable project](https://github.com/Sanster/OhMyTable)
 
-| Model type | Model name | Inference framework | Model size | Inference time (single 60KB image) |
-  |:--------------:|:--------------------------------------:| :------: |:------: |:------: |
-| English | `en_ppstructure_mobile_v2_SLANet.onnx` | onnxruntime |7.3M |0.15s |
-| Chinese | `ch_ppstructure_mobile_v2_SLANet.onnx` | onnxruntime |7.4M |0.15s |
-| slanet_plus (Chinese) | `slanet-plus.onnx` | onnxruntime |6.8M |0.15s |
-| unitable (Chinese) | `unitable(encoder.pth,decoder.pth)` | pytorch |500M |cpu(6s) gpu-4090(1.5s)|
+| `model_type` | Model name | Inference framework | Model size | Inference time (single 60KB image) |
+  |:--------------|:--------------------------------------| :------: |:------ |:------ |
+| `ppstructure_en` | `en_ppstructure_mobile_v2_SLANet.onnx` | onnxruntime |7.3M |0.15s |
+| `ppstructure_zh` | `ch_ppstructure_mobile_v2_SLANet.onnx` | onnxruntime |7.4M |0.15s |
+| `slanet_plus` | `slanet-plus.onnx` | onnxruntime |6.8M |0.15s |
+| `unitable` | `unitable(encoder.pth,decoder.pth)` | pytorch |500M |cpu(6s) gpu-4090(1.5s)|
 
 Model sources\
-[PaddleOCR Table Recognition](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/table/README_ch.md) \
-[PaddleX-SlaNetPlus Table Recognition](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/module_usage/tutorials/ocr_modules/table_structure_recognition.md) \
+[PaddleOCR Table Recognition](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/table/README_ch.md)\
+[PaddleX-SlaNetPlus Table Recognition](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/module_usage/tutorials/ocr_modules/table_structure_recognition.md)\
 [Unitable](https://github.com/poloclub/unitable?tab=readme-ov-file)
 
 Model download: [link](https://github.com/RapidAI/RapidTable/releases/tag/assets)
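A minimal sketch of selecting a backend through `model_type` (values per the table above; `BaseConfig` defaults to `"slanet-plus"`, and the GPU/unitable line mirrors the usage example later in this README):

```python
# A sketch, assuming the model_type values listed in the table above.
from rapid_table import RapidTable

table_engine = RapidTable()  # default model_type: "slanet-plus" (ONNX, CPU)
# table_engine = RapidTable(model_type="ppstructure_zh")  # Chinese PP-Structure model
# table_engine = RapidTable(use_cuda=True, device="cuda:0", model_type="unitable")  # PyTorch
```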
@@ -41,51 +41,6 @@
 
 Demo
 
-### Changelog
-
-<details>
-
-#### 2024.12.30 update
-
-- Added table recognition with the Unitable model, using the PyTorch framework
-
-#### 2024.11.24 update
-
-- Added GPU inference, adapted RapidOCR per-character recognition matching, and added logical-coordinate output and visualization
-
-#### 2024.10.13 update
-
-- Added the latest PaddleX SLANet-plus model (ONNX export is not yet possible due to paddle2onnx limitations)
-
-#### 2023-12-29 v0.1.3 update
-
-- Improved the visualization of results
-
-#### 2023-12-27 v0.1.2 update
-
-- Added a parameter to return cell bounding boxes
-- Improved the visualization function
-
-#### 2023-07-17 v0.1.0 update
-
-- Decoupled the `rapidocr_onnxruntime` part from `rapid_table`, making the dependency optional and usage more flexible.
-
-- Added the input parameter `ocr_result`:
-  - If `ocr_result` is provided in the call, OCR is not run again. Its format must match the return value of `rapidocr_onnxruntime`.
-  - If `ocr_result` is not provided but `rapidocr_onnxruntime` is installed, that library is invoked automatically for recognition.
-  - If `ocr_result` is not provided and `rapidocr_onnxruntime` is not installed, an error is raised. One of the two conditions must be met.
-
-#### 2023-07-10 v0.0.13 update
-
-- Changed the interface for passing an OCR instance into table reconstruction; other OCR instances can be passed in, provided the interface matches `rapidocr_onnxruntime`.
-
-#### 2023-07-06 v0.0.12 update
-
-- Removed the `<thead></thead><tbody></tbody>` elements from the returned HTML table string for downstream consistency.
-- Reformatted the code with Black.
-
-</details>
-
 ### Relationship with [TableStructureRec](https://github.com/RapidAI/TableStructureRec)
 
 The TableStructureRec library is a collection of table recognition algorithms. It currently bundles inference packages for `wired_table_rec` (wired-table recognition) and `lineless_table_rec` (lineless-table recognition).
 
@@ -105,8 +60,8 @@ RapidTable was extracted from the table recognition part of PP-Structure. Since PP-Structu
 
 ```bash
 pip install rapidocr_onnxruntime
 pip install rapid_table
-#pip install rapid_table[torch] # for unitable inference
-#pip install onnxruntime-gpu # for onnx gpu inference
+# pip install rapid_table[torch] # for unitable inference
+# pip install onnxruntime-gpu # for onnx gpu inference
 ```
 
 ### Usage
 
@@ -130,10 +85,13 @@ from rapidocr_onnxruntime import RapidOCR
 from rapid_table.table_structure.utils import trans_char_ocr_res
 
 table_engine = RapidTable()
+
 # Enable ONNX GPU inference
 # table_engine = RapidTable(use_cuda=True)
+
 # Use the PyTorch-based unitable model
 # table_engine = RapidTable(use_cuda=True, device="cuda:0", model_type="unitable")
+
 ocr_engine = RapidOCR()
 viser = VisTable()
@@ -143,6 +101,7 @@ ocr_result, _ = ocr_engine(img_path)
 
 # Per-character matching
 # ocr_result, _ = ocr_engine(img_path, return_word_box=True)
 # ocr_result = trans_char_ocr_res(ocr_result)
+
 table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_result)
 
 save_dir = Path("./inference_results/")
@@ -155,6 +114,7 @@ viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_p
 
 # Return logical coordinates
 # table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, ocr_result, return_logic_points=True)
+
 # save_logic_path = save_dir / f"vis_logic_{Path(img_path).name}"
 # viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path,logic_points, save_logic_path)
 
@@ -163,33 +123,32 @@ print(table_html_str)
 
 #### Running from the command line
 
-- Usage:
-
-  ```bash
-  $ rapid_table -h
-  usage: rapid_table [-h] [-v] -img IMG_PATH [-m MODEL_PATH]
+```bash
+$ rapid_table -h
+usage: rapid_table [-h] [-v] -img IMG_PATH [-m MODEL_PATH]
+
+optional arguments:
+-h, --help show this help message and exit
+-v, --vis Whether to visualize the layout results.
+-img IMG_PATH, --img_path IMG_PATH
+Path to image for layout.
+-m MODEL_PATH, --model_path MODEL_PATH
+The model path used for inference.
+```
 
-  optional arguments:
-  -h, --help show this help message and exit
-  -v, --vis Whether to visualize the layout results.
-  -img IMG_PATH, --img_path IMG_PATH
-  Path to image for layout.
-  -m MODEL_PATH, --model_path MODEL_PATH
-  The model path used for inference.
-  ```
 
-- Example:
+Example:
 
-  ```bash
-  rapid_table -v -img test_images/table.jpg
-  ```
+```bash
+rapid_table -v -img test_images/table.jpg
+```
-
 ### Results
 
 #### Returned result
 
-```html
+
+```html @@ -316,9 +275,56 @@ print(table_html_str) ``` + + #### 可视化结果
+<div>
+    <table border="1">
+        <tr><td>Methods</td><td></td><td></td><td></td><td>FPS</td></tr>
+        <tr><td>SegLink [26]</td><td>70.0</td><td>86.0</td><td>77.0</td><td>8.9</td></tr>
+        <tr><td>PixelLink [4]</td><td>73.2</td><td>83.0</td><td>77.8</td><td></td></tr>
+        <tr><td>TextSnake [18]</td><td>73.9</td><td>83.2</td><td>78.3</td><td>1.1</td></tr>
+        <tr><td>TextField [37]</td><td>75.9</td><td>87.4</td><td>81.3</td><td>5.2</td></tr>
+        <tr><td>MSR[38]</td><td>76.7</td><td>87.4</td><td>81.7</td><td></td></tr>
+        <tr><td>FTSN [3]</td><td>77.1</td><td>87.6</td><td>82.0</td><td></td></tr>
+        <tr><td>LSE[30]</td><td>81.7</td><td>84.2</td><td>82.9</td><td></td></tr>
+        <tr><td>CRAFT [2]</td><td>78.2</td><td>88.2</td><td>82.9</td><td>8.6</td></tr>
+        <tr><td>MCN[16]</td><td>79</td><td>88</td><td>83</td><td></td></tr>
+        <tr><td>ATRR[35]</td><td>82.1</td><td>85.2</td><td>83.6</td><td></td></tr>
+        <tr><td>PAN [34]</td><td>83.8</td><td>84.4</td><td>84.1</td><td>30.2</td></tr>
+        <tr><td>DB[12]</td><td>79.2</td><td>91.5</td><td>84.9</td><td>32.0</td></tr>
+        <tr><td>DRRG[41]</td><td>82.30</td><td>88.05</td><td>85.08</td><td></td></tr>
+        <tr><td>Ours (SynText)</td><td>80.68</td><td>85</td><td>82.97</td><td>12.68</td></tr>
+        <tr><td>Ours (MLT-17)</td><td>84.54</td><td>86.62</td><td>85.57</td><td>12.31</td></tr>
+    </table>
+</div>
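The block above is plain HTML, so the returned `table_html_str` can be handed to standard tooling. A minimal sketch, assuming pandas is installed (pandas is not a `rapid_table` dependency):

```python
from io import StringIO

import pandas as pd

# table_html_str comes from the table_engine call in the usage example above
df = pd.read_html(StringIO(table_html_str))[0]
print(df.head())
```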
+
+### Changelog
+
+<details>
+
+#### 2024.12.30 update
+
+- Added table recognition with the Unitable model, using the PyTorch framework
+
+#### 2024.11.24 update
+
+- Added GPU inference, adapted RapidOCR per-character recognition matching, and added logical-coordinate output and visualization
+
+#### 2024.10.13 update
+
+- Added the latest PaddleX SLANet-plus model (ONNX export is not yet possible due to paddle2onnx limitations)
+
+#### 2023-12-29 v0.1.3 update
+
+- Improved the visualization of results
+
+#### 2023-12-27 v0.1.2 update
+
+- Added a parameter to return cell bounding boxes
+- Improved the visualization function
+
+#### 2023-07-17 v0.1.0 update
+
+- Decoupled the `rapidocr_onnxruntime` part from `rapid_table`, making the dependency optional and usage more flexible.
+
+- Added the input parameter `ocr_result` (see the sketch following this changelog):
+  - If `ocr_result` is provided in the call, OCR is not run again. Its format must match the return value of `rapidocr_onnxruntime`.
+  - If `ocr_result` is not provided but `rapidocr_onnxruntime` is installed, that library is invoked automatically for recognition.
+  - If `ocr_result` is not provided and `rapidocr_onnxruntime` is not installed, an error is raised. One of the two conditions must be met.
+
+#### 2023-07-10 v0.0.13 update
+
+- Changed the interface for passing an OCR instance into table reconstruction; other OCR instances can be passed in, provided the interface matches `rapidocr_onnxruntime`.
+
+#### 2023-07-06 v0.0.12 update
+
+- Removed the `<thead></thead><tbody></tbody>` elements from the returned HTML table string for downstream consistency.
+- Reformatted the code with Black.
+
+</details>
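A minimal sketch of the `ocr_result` pass-through described in v0.1.0, assuming the `rapidocr_onnxruntime` output format of `[box, text, score]` triplets:

```python
from rapid_table import RapidTable
from rapidocr_onnxruntime import RapidOCR

img_path = "test_images/table.jpg"

ocr_engine = RapidOCR()
table_engine = RapidTable()

# ocr_result is a list of [box, text, score] entries, the same format
# rapidocr_onnxruntime returns; supplying it skips the internal OCR pass.
ocr_result, _ = ocr_engine(img_path)
table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_result)
```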
diff --git a/rapid_table/download_model.py b/rapid_table/download_model.py
index 9e9c496..ab65fe3 100644
--- a/rapid_table/download_model.py
+++ b/rapid_table/download_model.py
@@ -19,11 +19,14 @@
     }
 }
 
+
 class DownloadModel:
     cur_dir = PROJECT_DIR
 
     @staticmethod
-    def get_model_path(model_type: str, sub_file_type: str, path: Union[str, Path, None]) -> str:
+    def get_model_path(
+        model_type: str, sub_file_type: str, path: Union[str, Path, None]
+    ) -> str:
         if path is not None:
             return path
 
@@ -32,9 +35,7 @@ def get_model_path(model_type: str, sub_file_type: str, path: Union[str, Path, N
             model_path = DownloadModel.download(model_url)
             return model_path
 
-        logger.info(
-            "model url is None, using the default download model %s", path
-        )
+        logger.info("model url is None, using the default download model %s", path)
         return path
 
     @classmethod
diff --git a/rapid_table/main.py b/rapid_table/main.py
index b412790..9567e6c 100644
--- a/rapid_table/main.py
+++ b/rapid_table/main.py
@@ -13,7 +13,7 @@
 import numpy as np
 
 from .download_model import DownloadModel
-from .params import accept_kwargs_as_dataclass, BaseConfig
+from .params import BaseConfig, accept_kwargs_as_dataclass
 from .table_matcher import TableMatch
 from .table_structure import TableStructurer, TableStructureUnitable
 from .utils import LoadImage, VisTable
@@ -26,13 +26,21 @@ class RapidTable:
     def __init__(self, config: BaseConfig):
         self.model_type = config.model_type
         self.load_img = LoadImage()
+
         if self.model_type == "unitable":
-            config.encoder_path = DownloadModel.get_model_path(self.model_type, "encoder", config.encoder_path)
-            config.decoder_path = DownloadModel.get_model_path(self.model_type, "decoder", config.decoder_path)
-            config.vocab_path = DownloadModel.get_model_path(self.model_type, "vocab", config.vocab_path)
+            config.encoder_path = DownloadModel.get_model_path(
+                self.model_type, "encoder", config.encoder_path
+            )
+            config.decoder_path = DownloadModel.get_model_path(
+                self.model_type, "decoder", config.decoder_path
+            )
+            config.vocab_path = DownloadModel.get_model_path(
+                self.model_type, "vocab", config.vocab_path
+            )
             self.table_structure = TableStructureUnitable(asdict(config))
         else:
             self.table_structure = TableStructurer(asdict(config))
+
         self.table_matcher = TableMatch()
 
         try:
@@ -44,7 +52,7 @@ def __call__(
         self,
         img_content: Union[str, np.ndarray, bytes, Path],
         ocr_result: List[Union[List[List[float]], str, str]] = None,
-        return_logic_points = False
+        return_logic_points=False,
     ) -> Tuple[str, float]:
         if self.ocr_engine is None and ocr_result is None:
             raise ValueError(
@@ -65,14 +73,17 @@ def __call__(
         if self.model_type == "slanet-plus":
             pred_bboxes = self.adapt_slanet_plus(img, pred_bboxes)
         pred_html = self.table_matcher(pred_structures, pred_bboxes, dt_boxes, rec_res)
+
         # Filter out placeholder bboxes
         mask = ~np.all(pred_bboxes == 0, axis=1)
         pred_bboxes = pred_bboxes[mask]
+
         # Off by default so that upgrades from older versions do not break callers
         if return_logic_points:
             logic_points = self.table_matcher.decode_logic_points(pred_structures)
             elapse = time.time() - s
             return pred_html, pred_bboxes, logic_points, elapse
+
         elapse = time.time() - s
         return pred_html, pred_bboxes, elapse
 
@@ -93,6 +104,7 @@ def get_boxes_recs(
             r_boxes.append(box)
         dt_boxes = np.array(r_boxes)
         return dt_boxes, rec_res
+
     def adapt_slanet_plus(self, img: np.ndarray, pred_bboxes: np.ndarray) -> np.ndarray:
         h, w = img.shape[:2]
         resized = 488
@@ -103,6 +115,7 @@ def adapt_slanet_plus(self, img: np.ndarray, pred_bboxes: np.ndar
         pred_bboxes[:, 1::2] *= h_ratio
         return pred_bboxes
 
+
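The two return shapes of `__call__` above, as a sketch (`table_engine`, `img_path`, and `ocr_result` as in the README example):

```python
# Default: a 3-tuple, kept for backward compatibility (see the comment above)
table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_result)

# Opt-in: a 4-tuple that also carries the decoded logical coordinates
table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(
    img_path, ocr_result, return_logic_points=True
)
```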
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument(
diff --git a/rapid_table/params.py b/rapid_table/params.py
index c5756aa..cbbaaee 100644
--- a/rapid_table/params.py
+++ b/rapid_table/params.py
@@ -7,6 +7,7 @@
 root_dir = Path(__file__).resolve().parent
 logger = get_logger("params")
 
+
 @dataclass
 class BaseConfig:
     model_type: str = "slanet-plus"
@@ -25,16 +26,20 @@ def wrapper(*args, **kwargs):
             if len(args) == 2 and isinstance(args[1], cls):
                 # If a ModelConfig instance was already passed in, call the function directly
                 return func(*args, **kwargs)
-            else:
-                # Collect the fields defined on cls
-                cls_fields = {field.name for field in fields(cls)}
-                # Drop kwargs that are not defined fields
-                filtered_kwargs = {k: v for k, v in kwargs.items() if k in cls_fields}
-                # Warn about undefined fields
-                for k in (kwargs.keys() - cls_fields):
-                    logger.warning(f"Warning: '{k}' is not a valid field in {cls.__name__} and will be ignored.")
-                # Create the ModelConfig instance and call the function
-                config = cls(**filtered_kwargs)
-                return func(args[0], config=config)
+
+            # Collect the fields defined on cls
+            cls_fields = {field.name for field in fields(cls)}
+            # Drop kwargs that are not defined fields
+            filtered_kwargs = {k: v for k, v in kwargs.items() if k in cls_fields}
+            # Warn about undefined fields
+            for k in kwargs.keys() - cls_fields:
+                logger.warning(
+                    f"Warning: '{k}' is not a valid field in {cls.__name__} and will be ignored."
+                )
+            # Create the ModelConfig instance and call the function
+            config = cls(**filtered_kwargs)
+            return func(args[0], config=config)
+
         return wrapper
-    return decorator
\ No newline at end of file
+
+    return decorator
diff --git a/rapid_table/table_matcher/matcher.py b/rapid_table/table_matcher/matcher.py
index bc976ed..00bea6a 100644
--- a/rapid_table/table_matcher/matcher.py
+++ b/rapid_table/table_matcher/matcher.py
@@ -29,7 +29,7 @@ def __call__(self, pred_structures, pred_bboxes, dt_boxes, rec_res):
         pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
         return pred_html
 
-    def match_result(self, dt_boxes, pred_bboxes, min_iou=0.1 ** 8):
+    def match_result(self, dt_boxes, pred_bboxes, min_iou=0.1**8):
         matched = {}
         for i, gt_box in enumerate(dt_boxes):
             distances = []
@@ -52,7 +52,8 @@ def match_result(self, dt_boxes, pred_bboxes, min_iou=0.1 ** 8):
             # must > min_iou
             if sorted_distances[0][1] >= 1 - min_iou:
                 continue
-            if distances.index(sorted_distances[0]) not in matched.keys():
+
+            if distances.index(sorted_distances[0]) not in matched:
                 matched[distances.index(sorted_distances[0])] = [i]
             else:
                 matched[distances.index(sorted_distances[0])].append(i)
@@ -114,6 +115,7 @@ def get_pred_html(self, pred_structures, matched_index, ocr_contents):
         filter_elements = ["<thead>", "</thead>", "<tbody>", "</tbody>"]
         end_html = [v for v in end_html if v not in filter_elements]
         return "".join(end_html), end_html
+
     def decode_logic_points(self, pred_structures):
         logic_points = []
         current_row = 0
@@ -134,22 +136,24 @@ def mark_occupied(row, col, rowspan, colspan):
 
         while i < len(pred_structures):
             token = pred_structures[i]
 
-            if token == '<tr>':
+            if token == "<tr>":
                 current_col = 0  # Reset the current column index each time a <tr> is seen
-            elif token == '</tr>':
+            elif token == "</tr>":
                 current_row += 1  # The row has ended; increment the row index
-            elif token .startswith('<td'):
-                if 'colspan=' in pred_structures[j]:
-                    colspan = int(pred_structures[j].split('=')[1].strip('"\''))
-                elif 'rowspan=' in pred_structures[j]:
-                    rowspan = int(pred_structures[j].split('=')[1].strip('"\''))
+            elif token.startswith("<td"):
+                while j < len(pred_structures) and not pred_structures[
+                    j
+                ].startswith(">"):
+                    if "colspan=" in pred_structures[j]:
+                        colspan = int(pred_structures[j].split("=")[1].strip("\"'"))
+                    elif "rowspan=" in pred_structures[j]:
+                        rowspan = int(pred_structures[j].split("=")[1].strip("\"'"))
                     j += 1
                 # Skip the attribute tokens that have already been processed
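A sketch of what the decorator above means for callers, assuming `RapidTable.__init__` is wrapped with `accept_kwargs_as_dataclass(BaseConfig)` (as the import in `main.py` suggests):

```python
from rapid_table import RapidTable

# "model_type" is a BaseConfig field, so it is kept; "not_a_field" is not,
# so it is dropped with a logged warning instead of raising a TypeError.
engine = RapidTable(model_type="slanet-plus", not_a_field=True)
```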
diff --git a/rapid_table/table_matcher/utils.py b/rapid_table/table_matcher/utils.py
index 3ec8fcc..57a613c 100644
--- a/rapid_table/table_matcher/utils.py
+++ b/rapid_table/table_matcher/utils.py
@@ -14,6 +14,7 @@
 # -*- encoding: utf-8 -*-
 # @Author: SWHL
 # @Contact: liekkaskono@163.com
+import copy
 import re
 
@@ -48,7 +49,7 @@ def deal_isolate_span(thead_part):
         spanStr_in_isolateItem = span_part.group()
         # 3. merge the span number into the span token format string.
         if spanStr_in_isolateItem is not None:
-            corrected_item = "<td {}>".format(spanStr_in_isolateItem)
+            corrected_item = f"<td {spanStr_in_isolateItem}>"
             corrected_list.append(corrected_item)
         else:
             corrected_list.append(None)
@@ -243,6 +244,6 @@ def compute_iou(rec1, rec2):
     # judge if there is an intersect
     if left_line >= right_line or top_line >= bottom_line:
         return 0.0
-    else:
-        intersect = (right_line - left_line) * (bottom_line - top_line)
-        return (intersect / (sum_area - intersect)) * 1.0
+
+    intersect = (right_line - left_line) * (bottom_line - top_line)
+    return (intersect / (sum_area - intersect)) * 1.0
diff --git a/rapid_table/table_structure/table_structure.py b/rapid_table/table_structure/table_structure.py
index c2509d8..9152603 100644
--- a/rapid_table/table_structure/table_structure.py
+++ b/rapid_table/table_structure/table_structure.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import time
-from typing import Dict, Any
+from typing import Any, Dict
 
 import numpy as np
 
diff --git a/rapid_table/table_structure/unitable_modules.py b/rapid_table/table_structure/unitable_modules.py
index 367bd57..5b8dac3 100644
--- a/rapid_table/table_structure/unitable_modules.py
+++ b/rapid_table/table_structure/unitable_modules.py
@@ -1,21 +1,522 @@
-from functools import partial
 from dataclasses import dataclass
+from functools import partial
 from typing import Optional
+
 import torch
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import functional as F
 from torch.nn.modules.transformer import _get_activation_fn
 
-TOKEN_WHITE_LIST = [1, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325,
326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509] +TOKEN_WHITE_LIST = [ + 1, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 119, + 120, + 121, + 122, + 123, + 124, + 125, + 126, + 127, + 128, + 129, + 130, + 131, + 132, + 133, + 134, + 135, + 136, + 137, + 138, + 139, + 140, + 141, + 142, + 143, + 144, + 145, + 146, + 147, + 148, + 149, + 150, + 151, + 152, + 153, + 154, + 155, + 156, + 157, + 158, + 159, + 160, + 161, + 162, + 163, + 164, + 165, + 166, + 167, + 168, + 169, + 170, + 171, + 172, + 173, + 174, + 175, + 176, + 177, + 178, + 179, + 180, + 181, + 182, + 183, + 184, + 185, + 186, + 187, + 188, + 189, + 190, + 191, + 192, + 193, + 194, + 195, + 196, + 197, + 198, + 199, + 200, + 201, + 202, + 203, + 204, + 205, + 206, + 207, + 208, + 209, + 210, + 211, + 212, + 213, + 214, + 215, + 216, + 217, + 218, + 219, + 220, + 221, + 222, + 223, + 224, + 225, + 226, + 227, + 228, + 229, + 230, + 231, + 232, + 233, + 234, + 235, + 236, + 237, + 238, + 239, + 240, + 241, + 242, + 243, + 244, + 245, + 246, + 247, + 248, + 249, + 250, + 251, + 252, + 253, + 254, + 255, + 256, + 257, + 258, + 259, + 260, + 261, + 262, + 263, + 264, + 265, + 266, + 267, + 268, + 269, + 270, + 271, + 272, + 273, + 274, + 275, + 276, + 277, + 278, + 279, + 280, + 281, + 282, + 283, + 284, + 285, + 286, + 287, + 288, + 289, + 290, + 291, + 292, + 293, + 294, + 295, + 296, + 297, + 298, + 299, + 300, + 301, + 302, + 303, + 304, + 305, + 306, + 307, + 308, + 309, + 310, + 311, + 312, + 313, + 314, + 315, + 316, + 317, + 318, + 319, + 320, + 321, + 322, + 323, + 324, + 325, + 326, + 327, + 328, + 329, + 330, + 331, + 332, + 333, + 334, + 335, + 336, + 337, + 338, + 339, + 340, + 341, + 342, + 343, + 344, + 345, + 346, + 347, + 348, + 349, + 350, + 351, + 352, + 353, + 354, + 355, + 356, + 357, + 358, + 359, + 360, + 361, + 362, + 363, + 364, + 365, + 366, + 367, + 368, + 369, + 370, + 371, + 372, + 373, + 374, + 375, + 376, + 377, + 378, + 379, + 380, + 381, + 382, + 383, + 384, + 385, + 386, + 387, + 388, + 389, + 390, + 391, + 392, + 393, + 394, + 395, + 396, 
+ 397, + 398, + 399, + 400, + 401, + 402, + 403, + 404, + 405, + 406, + 407, + 408, + 409, + 410, + 411, + 412, + 413, + 414, + 415, + 416, + 417, + 418, + 419, + 420, + 421, + 422, + 423, + 424, + 425, + 426, + 427, + 428, + 429, + 430, + 431, + 432, + 433, + 434, + 435, + 436, + 437, + 438, + 439, + 440, + 441, + 442, + 443, + 444, + 445, + 446, + 447, + 448, + 449, + 450, + 451, + 452, + 453, + 454, + 455, + 456, + 457, + 458, + 459, + 460, + 461, + 462, + 463, + 464, + 465, + 466, + 467, + 468, + 469, + 470, + 471, + 472, + 473, + 474, + 475, + 476, + 477, + 478, + 479, + 480, + 481, + 482, + 483, + 484, + 485, + 486, + 487, + 488, + 489, + 490, + 491, + 492, + 493, + 494, + 495, + 496, + 497, + 498, + 499, + 500, + 501, + 502, + 503, + 504, + 505, + 506, + 507, + 508, + 509, +] class ImgLinearBackbone(nn.Module): def __init__( - self, - d_model: int, - patch_size: int, - in_chan: int = 3, + self, + d_model: int, + patch_size: int, + in_chan: int = 3, ) -> None: super().__init__() @@ -34,9 +535,7 @@ def forward(self, x: Tensor) -> Tensor: class Encoder(nn.Module): - def __init__( - self - ) -> None: + def __init__(self) -> None: super().__init__() self.patch_size = 16 @@ -57,11 +556,17 @@ def __init__( batch_first=True, norm_first=self.norm_first, ) - norm_layer=partial(nn.LayerNorm, eps=1e-6) + norm_layer = partial(nn.LayerNorm, eps=1e-6) self.norm = norm_layer(self.d_model) - self.backbone = ImgLinearBackbone(d_model=self.d_model, patch_size=self.patch_size) - self.pos_embed = PositionEmbedding(max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout) - self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=self.n_encoder_layer, enable_nested_tensor=False) + self.backbone = ImgLinearBackbone( + d_model=self.d_model, patch_size=self.patch_size + ) + self.pos_embed = PositionEmbedding( + max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout + ) + self.encoder = nn.TransformerEncoder( + encoder_layer, num_layers=self.n_encoder_layer, enable_nested_tensor=False + ) def forward(self, x: Tensor) -> Tensor: src_feature = self.backbone(x) @@ -89,10 +594,10 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: class TokenEmbedding(nn.Module): def __init__( - self, - vocab_size: int, - d_model: int, - padding_idx: int, + self, + vocab_size: int, + d_model: int, + padding_idx: int, ) -> None: super().__init__() assert vocab_size > 0 @@ -125,20 +630,29 @@ def __post_init__(self): self.intermediate_size = find_multiple(n_hidden, 256) self.head_dim = self.dim // self.n_head + class KVCache(nn.Module): def __init__( - self, - max_batch_size, - max_seq_length, - n_heads, - head_dim, - dtype=torch.bfloat16, - device="cpu", + self, + max_batch_size, + max_seq_length, + n_heads, + head_dim, + dtype=torch.bfloat16, + device="cpu", ): super().__init__() cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) - self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype, device=device), persistent=False) - self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype, device=device), persistent=False) + self.register_buffer( + "k_cache", + torch.zeros(cache_shape, dtype=dtype, device=device), + persistent=False, + ) + self.register_buffer( + "v_cache", + torch.zeros(cache_shape, dtype=dtype, device=device), + persistent=False, + ) def update(self, input_pos, k_val, v_val): # input_pos: [S], k_val: [B, H, S, D] @@ -152,13 +666,11 @@ def update(self, input_pos, k_val, v_val): return k_out[:bs], v_out[:bs] + class 
GPTFastDecoder(nn.Module): - def __init__( - self - ) -> None: + def __init__(self) -> None: super().__init__() - self.vocab_size = 960 self.padding_idx = 2 self.prefix_token_id = 11 @@ -179,9 +691,17 @@ def __init__( norm_first=self.norm_first, ) self.config = config - self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) - self.token_embed = TokenEmbedding(vocab_size=self.vocab_size, d_model=self.d_model, padding_idx=self.padding_idx) - self.pos_embed = PositionEmbedding(max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout) + self.layers = nn.ModuleList( + TransformerBlock(config) for _ in range(config.n_layer) + ) + self.token_embed = TokenEmbedding( + vocab_size=self.vocab_size, + d_model=self.d_model, + padding_idx=self.padding_idx, + ) + self.pos_embed = PositionEmbedding( + max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout + ) self.generator = nn.Linear(self.d_model, self.vocab_size) self.token_white_list = TOKEN_WHITE_LIST self.mask_cache: Optional[Tensor] = None @@ -193,7 +713,10 @@ def setup_caches(self, max_batch_size, max_seq_length, dtype, device): b.multihead_attn.k_cache = None b.multihead_attn.v_cache = None - if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: + if ( + self.max_seq_length >= max_seq_length + and self.max_batch_size >= max_batch_size + ): return head_dim = self.config.dim // self.config.n_head max_seq_length = find_multiple(max_seq_length, 8) @@ -207,24 +730,23 @@ def setup_caches(self, max_batch_size, max_seq_length, dtype, device): self.config.n_head, head_dim, dtype, - device + device, ) b.multihead_attn.k_cache = None b.multihead_attn.v_cache = None - self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)).to(device) - - def forward( - self, - memory: Tensor, - tgt: Tensor - ) -> Tensor: + self.causal_mask = torch.tril( + torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool) + ).to(device) + def forward(self, memory: Tensor, tgt: Tensor) -> Tensor: input_pos = torch.tensor([tgt.shape[1] - 1], device=tgt.device, dtype=torch.int) tgt = tgt[:, -1:] tgt_feature = self.pos_embed(self.token_embed(tgt), input_pos=input_pos) # tgt = self.decoder(tgt_feature, memory, input_pos) - with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): + with torch.backends.cuda.sdp_kernel( + enable_flash=False, enable_mem_efficient=False, enable_math=True + ): logits = tgt_feature tgt_mask = self.causal_mask[None, None, input_pos] for i, layer in enumerate(self.layers): @@ -238,6 +760,7 @@ def forward( _, next_tokens = probs.topk(1) return next_tokens + class TransformerBlock(nn.Module): def __init__(self, config: ModelArgs) -> None: super().__init__() @@ -260,11 +783,11 @@ def __init__(self, config: ModelArgs) -> None: self.activation = _get_activation_fn(config.activation) def forward( - self, - tgt: Tensor, - memory: Tensor, - tgt_mask: Tensor, - input_pos: Tensor, + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Tensor, + input_pos: Tensor, ) -> Tensor: if self.norm_first: x = tgt @@ -299,10 +822,10 @@ def __init__(self, config: ModelArgs): self.dim = config.dim def forward( - self, - x: Tensor, - mask: Tensor, - input_pos: Optional[Tensor] = None, + self, + x: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] = None, ) -> Tensor: bsz, seqlen, _ = x.shape @@ -362,9 +885,9 @@ def get_kv(self, xa: torch.Tensor): return k, v def forward( - self, - x: Tensor, 
- xa: Tensor, + self, + x: Tensor, + xa: Tensor, ): q = self.query(x) batch_size, target_seq_len, _ = q.shape @@ -383,6 +906,6 @@ def forward( batch_size, target_seq_len, self.n_head * self.head_dim, - ) + ) return self.out(wv) diff --git a/rapid_table/table_structure/utils.py b/rapid_table/table_structure/utils.py index c88cf3c..e06f610 100644 --- a/rapid_table/table_structure/utils.py +++ b/rapid_table/table_structure/utils.py @@ -17,12 +17,11 @@ import os import platform import traceback - -import cv2 from enum import Enum from pathlib import Path from typing import Any, Dict, List, Tuple, Union +import cv2 import numpy as np from onnxruntime import ( GraphOptimizationLevel, @@ -79,10 +78,12 @@ def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions: sess_opt.inter_op_num_threads = inter_op_num_threads return sess_opt + def get_metadata(self, key: str = "character") -> list: meta_dict = self.session.get_modelmeta().custom_metadata_map content_list = meta_dict[key].splitlines() return content_list + def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]: cpu_provider_opts = { "arena_extend_strategy": "kSameAsRequested", diff --git a/rapid_table/utils.py b/rapid_table/utils.py index 77224ce..82d7cef 100644 --- a/rapid_table/utils.py +++ b/rapid_table/utils.py @@ -4,7 +4,7 @@ import os from io import BytesIO from pathlib import Path -from typing import Optional, Union, List +from typing import List, Optional, Union import cv2 import numpy as np @@ -14,9 +14,7 @@ class LoadImage: - def __init__( - self, - ): + def __init__(self): pass def __call__(self, img: InputType) -> np.ndarray: @@ -79,20 +77,18 @@ class LoadImageError(Exception): class VisTable: - def __init__( - self, - ): + def __init__(self): self.load_img = LoadImage() def __call__( - self, - img_path: Union[str, Path], - table_html_str: str, - save_html_path: Optional[str] = None, - table_cell_bboxes: Optional[np.ndarray] = None, - save_drawed_path: Optional[str] = None, - logic_points: List[List[float]] = None, - save_logic_path: Optional[str] = None, + self, + img_path: Union[str, Path], + table_html_str: str, + save_html_path: Optional[str] = None, + table_cell_bboxes: Optional[np.ndarray] = None, + save_drawed_path: Optional[str] = None, + logic_points: List[List[float]] = None, + save_logic_path: Optional[str] = None, ) -> None: if save_html_path: html_with_border = self.insert_border_style(table_html_str) @@ -114,8 +110,10 @@ def __call__( if save_drawed_path: self.save_img(save_drawed_path, drawed_img) if save_logic_path and logic_points: - polygons = [[box[0],box[1], box[4], box[5]] for box in table_cell_bboxes] - self.plot_rec_box_with_logic_info(img_path, save_logic_path, logic_points, polygons) + polygons = [[box[0], box[1], box[4], box[5]] for box in table_cell_bboxes] + self.plot_rec_box_with_logic_info( + img_path, save_logic_path, logic_points, polygons + ) return drawed_img def insert_border_style(self, table_html_str: str): @@ -137,7 +135,9 @@ def insert_border_style(self, table_html_str: str): html_with_border = f"{prefix_table}{style_res}{suffix_table}" return html_with_border - def plot_rec_box_with_logic_info(self, img_path, output_path, logic_points, sorted_polygons): + def plot_rec_box_with_logic_info( + self, img_path, output_path, logic_points, sorted_polygons + ): """ :param img_path :param output_path @@ -208,4 +208,3 @@ def save_img(save_path: Union[str, Path], img: np.ndarray): def save_html(save_path: Union[str, Path], html: str): with open(save_path, "w", encoding="utf-8") as f: 
f.write(html) - diff --git a/setup.py b/setup.py index 0489760..8dc2e60 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ def get_readme(): MODULE_NAME = "rapid_table" obtainer = GetPyPiLatestVersion() latest_version = obtainer(MODULE_NAME) -VERSION_NUM = obtainer.version_add_one(latest_version) +VERSION_NUM = obtainer.version_add_one(latest_version, add_patch=True) if len(sys.argv) > 2: match_str = " ".join(sys.argv[2:])
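A self-contained sketch tying the pieces of this diff together — logical coordinates plus `VisTable` (argument order per `rapid_table/utils.py` above; the paths below are assumptions, and calling the engine without `ocr_result` requires `rapidocr_onnxruntime` to be installed):

```python
from pathlib import Path

from rapid_table import RapidTable
from rapid_table.utils import VisTable

img_path = "test_images/table.jpg"
save_dir = Path("./inference_results/")
save_dir.mkdir(parents=True, exist_ok=True)

table_engine = RapidTable()
viser = VisTable()

table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(
    img_path, return_logic_points=True
)
viser(
    img_path,
    table_html_str,
    save_dir / f"{Path(img_path).stem}.html",       # save_html_path
    table_cell_bboxes,
    save_dir / f"vis_{Path(img_path).name}",        # save_drawed_path
    logic_points,
    save_dir / f"vis_logic_{Path(img_path).name}",  # save_logic_path
)
```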