diff --git a/README.md b/README.md
index d0f6594..3f2397a 100644
--- a/README.md
+++ b/README.md
@@ -21,16 +21,16 @@ slanet_plus是paddlex内置的SLANet升级版模型,准确率有大幅提升
unitable是来源unitable的transformer模型,精度最高,暂仅支持pytorch推理,支持gpu推理加速,训练权重来源于 [OhMyTable项目](https://github.com/Sanster/OhMyTable)
-| 模型类型 | 模型名称 | 推理框架 |模型大小 |推理耗时(单图 60KB)|
- |:--------------:|:--------------------------------------:| :------: |:------: |:------: |
-| 英文 | `en_ppstructure_mobile_v2_SLANet.onnx` | onnxruntime |7.3M |0.15s |
-| 中文 | `ch_ppstructure_mobile_v2_SLANet.onnx` | onnxruntime |7.4M |0.15s |
-| slanet_plus 中文 | `slanet-plus.onnx` | onnxruntime |6.8M |0.15s |
-| unitable 中文 | `unitable(encoder.pth,decoder.pth)` | pytorch |500M |cpu(6s) gpu-4090(1.5s)|
+| `model_type` | 模型名称 | 推理框架 |模型大小 |推理耗时(单图 60KB)|
+ |:--------------|:--------------------------------------| :------: |:------ |:------ |
+| `ppstructure_en` | `en_ppstructure_mobile_v2_SLANet.onnx` | onnxruntime |7.3M |0.15s |
+| `ppstructure_zh` | `ch_ppstructure_mobile_v2_SLANet.onnx` | onnxruntime |7.4M |0.15s |
+| `slanet_plus` | `slanet-plus.onnx` | onnxruntime |6.8M |0.15s |
+| `unitable` | `unitable(encoder.pth,decoder.pth)` | pytorch |500M |cpu(6s) gpu-4090(1.5s)|
模型来源\
-[PaddleOCR 表格识别](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/table/README_ch.md) \
-[PaddleX-SlaNetPlus 表格识别](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/module_usage/tutorials/ocr_modules/table_structure_recognition.md) \
+[PaddleOCR 表格识别](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/table/README_ch.md)\
+[PaddleX-SlaNetPlus 表格识别](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/module_usage/tutorials/ocr_modules/table_structure_recognition.md)\
[Unitable](https://github.com/poloclub/unitable?tab=readme-ov-file)
模型下载地址为:[link](https://github.com/RapidAI/RapidTable/releases/tag/assets)
@@ -41,51 +41,6 @@ unitable是来源unitable的transformer模型,精度最高,暂仅支持pytor
-### 更新日志
-
-
-
-#### 2024.12.30 update
-
-- 支持Unitable模型的表格识别,使用pytorch框架
-
-#### 2024.11.24 update
-
-- 支持gpu推理,适配 rapidOCR 单字识别匹配,支持逻辑坐标返回及可视化
-
-#### 2024.10.13 update
-
-- 补充最新paddlex-SLANet-plus 模型(paddle2onnx原因暂不能支持onnx)
-
-#### 2023-12-29 v0.1.3 update
-
-- 优化可视化结果部分
-
-#### 2023-12-27 v0.1.2 update
-
-- 添加返回cell坐标框参数
-- 完善可视化函数
-
-#### 2023-07-17 v0.1.0 update
-
-- 将`rapidocr_onnxruntime`部分从`rapid_table`中解耦合出来,给出选项是否依赖,更加灵活。
-
-- 增加接口输入参数`ocr_result`:
- - 如果在调用函数时,事先指定了`ocr_result`参数值,则不会再走OCR。其中`ocr_result`格式需要和`rapidocr_onnxruntime`返回值一致。
- - 如果未指定`ocr_result`参数值,但是事先安装了`rapidocr_onnxruntime`库,则会自动调用该库,进行识别。
- - 如果`ocr_result`未指定,且`rapidocr_onnxruntime`未安装,则会报错。必须满足两个条件中一个。
-
-#### 2023-07-10 v0.0.13 updata
-
-- 更改传入表格还原中OCR的实例接口,可以传入其他OCR实例,前提要与`rapidocr_onnxruntime`接口一致
-
-#### 2023-07-06 v0.0.12 update
-
-- 去掉返回表格的html字符串中的``元素,便于后续统一。
-- 采用Black工具优化代码
-
-
-
### 与[TableStructureRec](https://github.com/RapidAI/TableStructureRec)关系
TableStructureRec库是一个表格识别算法的集合库,当前有`wired_table_rec`有线表格识别算法和`lineless_table_rec`无线表格识别算法的推理包。
@@ -105,8 +60,8 @@ RapidTable是整理自PP-Structure中表格识别部分而来。由于PP-Structu
```bash
pip install rapidocr_onnxruntime
pip install rapid_table
-#pip install rapid_table[torch] # for unitable inference
-#pip install onnxruntime-gpu # for onnx gpu inference
+# pip install rapid_table[torch] # for unitable inference
+# pip install onnxruntime-gpu # for onnx gpu inference
```
### 使用方式
@@ -130,10 +85,13 @@ from rapidocr_onnxruntime import RapidOCR
from rapid_table.table_structure.utils import trans_char_ocr_res
table_engine = RapidTable()
+
# 开启onnx-gpu推理
# table_engine = RapidTable(use_cuda=True)
+
# 使用torch推理版本的unitable模型
# table_engine = RapidTable(use_cuda=True, device="cuda:0", model_type="unitable")
+
ocr_engine = RapidOCR()
viser = VisTable()
@@ -143,6 +101,7 @@ ocr_result, _ = ocr_engine(img_path)
# 单字匹配
# ocr_result, _ = ocr_engine(img_path, return_word_box=True)
# ocr_result = trans_char_ocr_res(ocr_result)
+
table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_result)
save_dir = Path("./inference_results/")
@@ -155,6 +114,7 @@ viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_p
# 返回逻辑坐标
# table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, ocr_result, return_logic_points=True)
+
# save_logic_path = save_dir / f"vis_logic_{Path(img_path).name}"
# viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path,logic_points, save_logic_path)
@@ -163,33 +123,32 @@ print(table_html_str)
#### 终端运行
-- 用法:
-
- ```bash
- $ rapid_table -h
- usage: rapid_table [-h] [-v] -img IMG_PATH [-m MODEL_PATH]
+```bash
+$ rapid_table -h
+usage: rapid_table [-h] [-v] -img IMG_PATH [-m MODEL_PATH]
+
+optional arguments:
+-h, --help show this help message and exit
+-v, --vis Whether to visualize the layout results.
+-img IMG_PATH, --img_path IMG_PATH
+ Path to image for layout.
+-m MODEL_PATH, --model_path MODEL_PATH
+ The model path used for inference.
+```
- optional arguments:
- -h, --help show this help message and exit
- -v, --vis Whether to visualize the layout results.
- -img IMG_PATH, --img_path IMG_PATH
- Path to image for layout.
- -m MODEL_PATH, --model_path MODEL_PATH
- The model path used for inference.
- ```
+示例:
-- 示例:
+```bash
+rapid_table -v -img test_images/table.jpg
+```
- ```bash
- rapid_table -v -img test_images/table.jpg
- ```
-
### 结果
#### 返回结果
-```html
+
+```html
@@ -316,9 +275,56 @@ print(table_html_str)
```
+
+
#### 可视化结果
Methods | | | | FPS |
SegLink [26] | 70.0 | 86d> | 77.0 | 8.9 |
PixelLink [4] | 73.2 | 83.0 | 77.8 | |
TextSnake [18] | 73.9 | 83.2 | 78.3 | 1.1 |
TextField [37] | 75.9 | 87.4 | 81.3 | 5.2 |
MSR[38] | 76.7 | 87.87.4 | 81.7 | |
FTSN [3] | 77.1 | 87.6 | 82.0 | |
LSE[30] | 81.7 | 84.2 | 82.9 | <>
CRAFT [2] | 78.2 | 88.2 | 82.9 | 8.6 |
MCN[16] | 79 | 88 | 83 | |
ATRR>[35] | 82.1 | 85.2 | 83.6 | |
PAN [34] | 83.8 | 84.4 | 84.1 | 30.2 |
DB[12] | 79.2 | 91.5 | 84.9 | 32.0 |
DRRG[41] | 82.30 | 88.05 | 85.08 | |
Ours (SynText) | 80.68 | 8582.97 | 12.68 | |
Ours (MLT-17) | 84.54 | 86.62 | 85.57 | 12.31 |
+
+### 更新日志
+
+
+
+#### 2024.12.30 update
+
+- 支持Unitable模型的表格识别,使用pytorch框架
+
+#### 2024.11.24 update
+
+- 支持gpu推理,适配 rapidOCR 单字识别匹配,支持逻辑坐标返回及可视化
+
+#### 2024.10.13 update
+
+- 补充最新paddlex-SLANet-plus 模型(paddle2onnx原因暂不能支持onnx)
+
+#### 2023-12-29 v0.1.3 update
+
+- 优化可视化结果部分
+
+#### 2023-12-27 v0.1.2 update
+
+- 添加返回cell坐标框参数
+- 完善可视化函数
+
+#### 2023-07-17 v0.1.0 update
+
+- 将`rapidocr_onnxruntime`部分从`rapid_table`中解耦合出来,给出选项是否依赖,更加灵活。
+
+- 增加接口输入参数`ocr_result`:
+ - 如果在调用函数时,事先指定了`ocr_result`参数值,则不会再走OCR。其中`ocr_result`格式需要和`rapidocr_onnxruntime`返回值一致。
+ - 如果未指定`ocr_result`参数值,但是事先安装了`rapidocr_onnxruntime`库,则会自动调用该库,进行识别。
+ - 如果`ocr_result`未指定,且`rapidocr_onnxruntime`未安装,则会报错。必须满足两个条件中一个。
+
+#### 2023-07-10 v0.0.13 updata
+
+- 更改传入表格还原中OCR的实例接口,可以传入其他OCR实例,前提要与`rapidocr_onnxruntime`接口一致
+
+#### 2023-07-06 v0.0.12 update
+
+- 去掉返回表格的html字符串中的``元素,便于后续统一。
+- 采用Black工具优化代码
+
+
diff --git a/rapid_table/download_model.py b/rapid_table/download_model.py
index 9e9c496..ab65fe3 100644
--- a/rapid_table/download_model.py
+++ b/rapid_table/download_model.py
@@ -19,11 +19,14 @@
}
}
+
class DownloadModel:
cur_dir = PROJECT_DIR
@staticmethod
- def get_model_path(model_type: str, sub_file_type: str, path: Union[str, Path, None]) -> str:
+ def get_model_path(
+ model_type: str, sub_file_type: str, path: Union[str, Path, None]
+ ) -> str:
if path is not None:
return path
@@ -32,9 +35,7 @@ def get_model_path(model_type: str, sub_file_type: str, path: Union[str, Path, N
model_path = DownloadModel.download(model_url)
return model_path
- logger.info(
- "model url is None, using the default download model %s", path
- )
+ logger.info("model url is None, using the default download model %s", path)
return path
@classmethod
diff --git a/rapid_table/main.py b/rapid_table/main.py
index b412790..9567e6c 100644
--- a/rapid_table/main.py
+++ b/rapid_table/main.py
@@ -13,7 +13,7 @@
import numpy as np
from .download_model import DownloadModel
-from .params import accept_kwargs_as_dataclass, BaseConfig
+from .params import BaseConfig, accept_kwargs_as_dataclass
from .table_matcher import TableMatch
from .table_structure import TableStructurer, TableStructureUnitable
from .utils import LoadImage, VisTable
@@ -26,13 +26,21 @@ class RapidTable:
def __init__(self, config: BaseConfig):
self.model_type = config.model_type
self.load_img = LoadImage()
+
if self.model_type == "unitable":
- config.encoder_path = DownloadModel.get_model_path(self.model_type, "encoder", config.encoder_path)
- config.decoder_path = DownloadModel.get_model_path(self.model_type, "decoder", config.decoder_path)
- config.vocab_path = DownloadModel.get_model_path(self.model_type, "vocab", config.vocab_path)
+ config.encoder_path = DownloadModel.get_model_path(
+ self.model_type, "encoder", config.encoder_path
+ )
+ config.decoder_path = DownloadModel.get_model_path(
+ self.model_type, "decoder", config.decoder_path
+ )
+ config.vocab_path = DownloadModel.get_model_path(
+ self.model_type, "vocab", config.vocab_path
+ )
self.table_structure = TableStructureUnitable(asdict(config))
else:
self.table_structure = TableStructurer(asdict(config))
+
self.table_matcher = TableMatch()
try:
@@ -44,7 +52,7 @@ def __call__(
self,
img_content: Union[str, np.ndarray, bytes, Path],
ocr_result: List[Union[List[List[float]], str, str]] = None,
- return_logic_points = False
+ return_logic_points=False,
) -> Tuple[str, float]:
if self.ocr_engine is None and ocr_result is None:
raise ValueError(
@@ -65,14 +73,17 @@ def __call__(
if self.model_type == "slanet-plus":
pred_bboxes = self.adapt_slanet_plus(img, pred_bboxes)
pred_html = self.table_matcher(pred_structures, pred_bboxes, dt_boxes, rec_res)
+
# 过滤掉占位的bbox
mask = ~np.all(pred_bboxes == 0, axis=1)
pred_bboxes = pred_bboxes[mask]
+
# 避免低版本升级后出现问题,默认不打开
if return_logic_points:
logic_points = self.table_matcher.decode_logic_points(pred_structures)
elapse = time.time() - s
return pred_html, pred_bboxes, logic_points, elapse
+
elapse = time.time() - s
return pred_html, pred_bboxes, elapse
@@ -93,6 +104,7 @@ def get_boxes_recs(
r_boxes.append(box)
dt_boxes = np.array(r_boxes)
return dt_boxes, rec_res
+
def adapt_slanet_plus(self, img: np.ndarray, pred_bboxes: np.ndarray) -> np.ndarray:
h, w = img.shape[:2]
resized = 488
@@ -103,6 +115,7 @@ def adapt_slanet_plus(self, img: np.ndarray, pred_bboxes: np.ndarray) -> np.ndar
pred_bboxes[:, 1::2] *= h_ratio
return pred_bboxes
+
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
diff --git a/rapid_table/params.py b/rapid_table/params.py
index c5756aa..cbbaaee 100644
--- a/rapid_table/params.py
+++ b/rapid_table/params.py
@@ -7,6 +7,7 @@
root_dir = Path(__file__).resolve().parent
logger = get_logger("params")
+
@dataclass
class BaseConfig:
model_type: str = "slanet-plus"
@@ -25,16 +26,20 @@ def wrapper(*args, **kwargs):
if len(args) == 2 and isinstance(args[1], cls):
# 如果已经传递了 ModelConfig 实例,直接调用函数
return func(*args, **kwargs)
- else:
- # 提取 cls 中定义的字段
- cls_fields = {field.name for field in fields(cls)}
- # 过滤掉未定义的字段
- filtered_kwargs = {k: v for k, v in kwargs.items() if k in cls_fields}
- # 发出警告对于未定义的字段
- for k in (kwargs.keys() - cls_fields):
- logger.warning(f"Warning: '{k}' is not a valid field in {cls.__name__} and will be ignored.")
- # 创建 ModelConfig 实例并调用函数
- config = cls(**filtered_kwargs)
- return func(args[0], config=config)
+
+ # 提取 cls 中定义的字段
+ cls_fields = {field.name for field in fields(cls)}
+ # 过滤掉未定义的字段
+ filtered_kwargs = {k: v for k, v in kwargs.items() if k in cls_fields}
+ # 发出警告对于未定义的字段
+ for k in kwargs.keys() - cls_fields:
+ logger.warning(
+ f"Warning: '{k}' is not a valid field in {cls.__name__} and will be ignored."
+ )
+ # 创建 ModelConfig 实例并调用函数
+ config = cls(**filtered_kwargs)
+ return func(args[0], config=config)
+
return wrapper
- return decorator
\ No newline at end of file
+
+ return decorator
diff --git a/rapid_table/table_matcher/matcher.py b/rapid_table/table_matcher/matcher.py
index bc976ed..00bea6a 100644
--- a/rapid_table/table_matcher/matcher.py
+++ b/rapid_table/table_matcher/matcher.py
@@ -29,7 +29,7 @@ def __call__(self, pred_structures, pred_bboxes, dt_boxes, rec_res):
pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
return pred_html
- def match_result(self, dt_boxes, pred_bboxes, min_iou=0.1 ** 8):
+ def match_result(self, dt_boxes, pred_bboxes, min_iou=0.1**8):
matched = {}
for i, gt_box in enumerate(dt_boxes):
distances = []
@@ -52,7 +52,8 @@ def match_result(self, dt_boxes, pred_bboxes, min_iou=0.1 ** 8):
# must > min_iou
if sorted_distances[0][1] >= 1 - min_iou:
continue
- if distances.index(sorted_distances[0]) not in matched.keys():
+
+ if distances.index(sorted_distances[0]) not in matched:
matched[distances.index(sorted_distances[0])] = [i]
else:
matched[distances.index(sorted_distances[0])].append(i)
@@ -114,6 +115,7 @@ def get_pred_html(self, pred_structures, matched_index, ocr_contents):
filter_elements = ["", "", "", ""]
end_html = [v for v in end_html if v not in filter_elements]
return "".join(end_html), end_html
+
def decode_logic_points(self, pred_structures):
logic_points = []
current_row = 0
@@ -134,22 +136,24 @@ def mark_occupied(row, col, rowspan, colspan):
while i < len(pred_structures):
token = pred_structures[i]
- if token == '':
+ if token == "
":
current_col = 0 # 每次遇到
时,重置当前列号
- elif token == '
':
+ elif token == "":
current_row += 1 # 行结束,行号增加
- elif token .startswith(' | ':
+ if token != " | ":
j += 1
# 提取 colspan 和 rowspan 属性
- while j < len(pred_structures) and not pred_structures[j].startswith('>'):
- if 'colspan=' in pred_structures[j]:
- colspan = int(pred_structures[j].split('=')[1].strip('"\''))
- elif 'rowspan=' in pred_structures[j]:
- rowspan = int(pred_structures[j].split('=')[1].strip('"\''))
+ while j < len(pred_structures) and not pred_structures[
+ j
+ ].startswith(">"):
+ if "colspan=" in pred_structures[j]:
+ colspan = int(pred_structures[j].split("=")[1].strip("\"'"))
+ elif "rowspan=" in pred_structures[j]:
+ rowspan = int(pred_structures[j].split("=")[1].strip("\"'"))
j += 1
# 跳过已经处理过的属性 token
diff --git a/rapid_table/table_matcher/utils.py b/rapid_table/table_matcher/utils.py
index 3ec8fcc..57a613c 100644
--- a/rapid_table/table_matcher/utils.py
+++ b/rapid_table/table_matcher/utils.py
@@ -14,6 +14,7 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
+import copy
import re
@@ -48,7 +49,7 @@ def deal_isolate_span(thead_part):
spanStr_in_isolateItem = span_part.group()
# 3. merge the span number into the span token format string.
if spanStr_in_isolateItem is not None:
- corrected_item = " | ".format(spanStr_in_isolateItem)
+ corrected_item = f" | "
corrected_list.append(corrected_item)
else:
corrected_list.append(None)
@@ -243,6 +244,6 @@ def compute_iou(rec1, rec2):
# judge if there is an intersect
if left_line >= right_line or top_line >= bottom_line:
return 0.0
- else:
- intersect = (right_line - left_line) * (bottom_line - top_line)
- return (intersect / (sum_area - intersect)) * 1.0
+
+ intersect = (right_line - left_line) * (bottom_line - top_line)
+ return (intersect / (sum_area - intersect)) * 1.0
diff --git a/rapid_table/table_structure/table_structure.py b/rapid_table/table_structure/table_structure.py
index c2509d8..9152603 100644
--- a/rapid_table/table_structure/table_structure.py
+++ b/rapid_table/table_structure/table_structure.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
-from typing import Dict, Any
+from typing import Any, Dict
import numpy as np
diff --git a/rapid_table/table_structure/unitable_modules.py b/rapid_table/table_structure/unitable_modules.py
index 367bd57..5b8dac3 100644
--- a/rapid_table/table_structure/unitable_modules.py
+++ b/rapid_table/table_structure/unitable_modules.py
@@ -1,21 +1,522 @@
-from functools import partial
from dataclasses import dataclass
+from functools import partial
from typing import Optional
+
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from torch.nn.modules.transformer import _get_activation_fn
-TOKEN_WHITE_LIST = [1, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, 506, 507, 508, 509]
+TOKEN_WHITE_LIST = [
+ 1,
+ 12,
+ 13,
+ 14,
+ 15,
+ 16,
+ 17,
+ 18,
+ 19,
+ 20,
+ 21,
+ 22,
+ 23,
+ 24,
+ 25,
+ 26,
+ 27,
+ 28,
+ 29,
+ 30,
+ 31,
+ 32,
+ 33,
+ 34,
+ 35,
+ 36,
+ 37,
+ 38,
+ 39,
+ 40,
+ 41,
+ 42,
+ 43,
+ 44,
+ 45,
+ 46,
+ 47,
+ 48,
+ 49,
+ 50,
+ 51,
+ 52,
+ 53,
+ 54,
+ 55,
+ 56,
+ 57,
+ 58,
+ 59,
+ 60,
+ 61,
+ 62,
+ 63,
+ 64,
+ 65,
+ 66,
+ 67,
+ 68,
+ 69,
+ 70,
+ 71,
+ 72,
+ 73,
+ 74,
+ 75,
+ 76,
+ 77,
+ 78,
+ 79,
+ 80,
+ 81,
+ 82,
+ 83,
+ 84,
+ 85,
+ 86,
+ 87,
+ 88,
+ 89,
+ 90,
+ 91,
+ 92,
+ 93,
+ 94,
+ 95,
+ 96,
+ 97,
+ 98,
+ 99,
+ 100,
+ 101,
+ 102,
+ 103,
+ 104,
+ 105,
+ 106,
+ 107,
+ 108,
+ 109,
+ 110,
+ 111,
+ 112,
+ 113,
+ 114,
+ 115,
+ 116,
+ 117,
+ 118,
+ 119,
+ 120,
+ 121,
+ 122,
+ 123,
+ 124,
+ 125,
+ 126,
+ 127,
+ 128,
+ 129,
+ 130,
+ 131,
+ 132,
+ 133,
+ 134,
+ 135,
+ 136,
+ 137,
+ 138,
+ 139,
+ 140,
+ 141,
+ 142,
+ 143,
+ 144,
+ 145,
+ 146,
+ 147,
+ 148,
+ 149,
+ 150,
+ 151,
+ 152,
+ 153,
+ 154,
+ 155,
+ 156,
+ 157,
+ 158,
+ 159,
+ 160,
+ 161,
+ 162,
+ 163,
+ 164,
+ 165,
+ 166,
+ 167,
+ 168,
+ 169,
+ 170,
+ 171,
+ 172,
+ 173,
+ 174,
+ 175,
+ 176,
+ 177,
+ 178,
+ 179,
+ 180,
+ 181,
+ 182,
+ 183,
+ 184,
+ 185,
+ 186,
+ 187,
+ 188,
+ 189,
+ 190,
+ 191,
+ 192,
+ 193,
+ 194,
+ 195,
+ 196,
+ 197,
+ 198,
+ 199,
+ 200,
+ 201,
+ 202,
+ 203,
+ 204,
+ 205,
+ 206,
+ 207,
+ 208,
+ 209,
+ 210,
+ 211,
+ 212,
+ 213,
+ 214,
+ 215,
+ 216,
+ 217,
+ 218,
+ 219,
+ 220,
+ 221,
+ 222,
+ 223,
+ 224,
+ 225,
+ 226,
+ 227,
+ 228,
+ 229,
+ 230,
+ 231,
+ 232,
+ 233,
+ 234,
+ 235,
+ 236,
+ 237,
+ 238,
+ 239,
+ 240,
+ 241,
+ 242,
+ 243,
+ 244,
+ 245,
+ 246,
+ 247,
+ 248,
+ 249,
+ 250,
+ 251,
+ 252,
+ 253,
+ 254,
+ 255,
+ 256,
+ 257,
+ 258,
+ 259,
+ 260,
+ 261,
+ 262,
+ 263,
+ 264,
+ 265,
+ 266,
+ 267,
+ 268,
+ 269,
+ 270,
+ 271,
+ 272,
+ 273,
+ 274,
+ 275,
+ 276,
+ 277,
+ 278,
+ 279,
+ 280,
+ 281,
+ 282,
+ 283,
+ 284,
+ 285,
+ 286,
+ 287,
+ 288,
+ 289,
+ 290,
+ 291,
+ 292,
+ 293,
+ 294,
+ 295,
+ 296,
+ 297,
+ 298,
+ 299,
+ 300,
+ 301,
+ 302,
+ 303,
+ 304,
+ 305,
+ 306,
+ 307,
+ 308,
+ 309,
+ 310,
+ 311,
+ 312,
+ 313,
+ 314,
+ 315,
+ 316,
+ 317,
+ 318,
+ 319,
+ 320,
+ 321,
+ 322,
+ 323,
+ 324,
+ 325,
+ 326,
+ 327,
+ 328,
+ 329,
+ 330,
+ 331,
+ 332,
+ 333,
+ 334,
+ 335,
+ 336,
+ 337,
+ 338,
+ 339,
+ 340,
+ 341,
+ 342,
+ 343,
+ 344,
+ 345,
+ 346,
+ 347,
+ 348,
+ 349,
+ 350,
+ 351,
+ 352,
+ 353,
+ 354,
+ 355,
+ 356,
+ 357,
+ 358,
+ 359,
+ 360,
+ 361,
+ 362,
+ 363,
+ 364,
+ 365,
+ 366,
+ 367,
+ 368,
+ 369,
+ 370,
+ 371,
+ 372,
+ 373,
+ 374,
+ 375,
+ 376,
+ 377,
+ 378,
+ 379,
+ 380,
+ 381,
+ 382,
+ 383,
+ 384,
+ 385,
+ 386,
+ 387,
+ 388,
+ 389,
+ 390,
+ 391,
+ 392,
+ 393,
+ 394,
+ 395,
+ 396,
+ 397,
+ 398,
+ 399,
+ 400,
+ 401,
+ 402,
+ 403,
+ 404,
+ 405,
+ 406,
+ 407,
+ 408,
+ 409,
+ 410,
+ 411,
+ 412,
+ 413,
+ 414,
+ 415,
+ 416,
+ 417,
+ 418,
+ 419,
+ 420,
+ 421,
+ 422,
+ 423,
+ 424,
+ 425,
+ 426,
+ 427,
+ 428,
+ 429,
+ 430,
+ 431,
+ 432,
+ 433,
+ 434,
+ 435,
+ 436,
+ 437,
+ 438,
+ 439,
+ 440,
+ 441,
+ 442,
+ 443,
+ 444,
+ 445,
+ 446,
+ 447,
+ 448,
+ 449,
+ 450,
+ 451,
+ 452,
+ 453,
+ 454,
+ 455,
+ 456,
+ 457,
+ 458,
+ 459,
+ 460,
+ 461,
+ 462,
+ 463,
+ 464,
+ 465,
+ 466,
+ 467,
+ 468,
+ 469,
+ 470,
+ 471,
+ 472,
+ 473,
+ 474,
+ 475,
+ 476,
+ 477,
+ 478,
+ 479,
+ 480,
+ 481,
+ 482,
+ 483,
+ 484,
+ 485,
+ 486,
+ 487,
+ 488,
+ 489,
+ 490,
+ 491,
+ 492,
+ 493,
+ 494,
+ 495,
+ 496,
+ 497,
+ 498,
+ 499,
+ 500,
+ 501,
+ 502,
+ 503,
+ 504,
+ 505,
+ 506,
+ 507,
+ 508,
+ 509,
+]
class ImgLinearBackbone(nn.Module):
def __init__(
- self,
- d_model: int,
- patch_size: int,
- in_chan: int = 3,
+ self,
+ d_model: int,
+ patch_size: int,
+ in_chan: int = 3,
) -> None:
super().__init__()
@@ -34,9 +535,7 @@ def forward(self, x: Tensor) -> Tensor:
class Encoder(nn.Module):
- def __init__(
- self
- ) -> None:
+ def __init__(self) -> None:
super().__init__()
self.patch_size = 16
@@ -57,11 +556,17 @@ def __init__(
batch_first=True,
norm_first=self.norm_first,
)
- norm_layer=partial(nn.LayerNorm, eps=1e-6)
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.norm = norm_layer(self.d_model)
- self.backbone = ImgLinearBackbone(d_model=self.d_model, patch_size=self.patch_size)
- self.pos_embed = PositionEmbedding(max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout)
- self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=self.n_encoder_layer, enable_nested_tensor=False)
+ self.backbone = ImgLinearBackbone(
+ d_model=self.d_model, patch_size=self.patch_size
+ )
+ self.pos_embed = PositionEmbedding(
+ max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout
+ )
+ self.encoder = nn.TransformerEncoder(
+ encoder_layer, num_layers=self.n_encoder_layer, enable_nested_tensor=False
+ )
def forward(self, x: Tensor) -> Tensor:
src_feature = self.backbone(x)
@@ -89,10 +594,10 @@ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
class TokenEmbedding(nn.Module):
def __init__(
- self,
- vocab_size: int,
- d_model: int,
- padding_idx: int,
+ self,
+ vocab_size: int,
+ d_model: int,
+ padding_idx: int,
) -> None:
super().__init__()
assert vocab_size > 0
@@ -125,20 +630,29 @@ def __post_init__(self):
self.intermediate_size = find_multiple(n_hidden, 256)
self.head_dim = self.dim // self.n_head
+
class KVCache(nn.Module):
def __init__(
- self,
- max_batch_size,
- max_seq_length,
- n_heads,
- head_dim,
- dtype=torch.bfloat16,
- device="cpu",
+ self,
+ max_batch_size,
+ max_seq_length,
+ n_heads,
+ head_dim,
+ dtype=torch.bfloat16,
+ device="cpu",
):
super().__init__()
cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
- self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype, device=device), persistent=False)
- self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype, device=device), persistent=False)
+ self.register_buffer(
+ "k_cache",
+ torch.zeros(cache_shape, dtype=dtype, device=device),
+ persistent=False,
+ )
+ self.register_buffer(
+ "v_cache",
+ torch.zeros(cache_shape, dtype=dtype, device=device),
+ persistent=False,
+ )
def update(self, input_pos, k_val, v_val):
# input_pos: [S], k_val: [B, H, S, D]
@@ -152,13 +666,11 @@ def update(self, input_pos, k_val, v_val):
return k_out[:bs], v_out[:bs]
+
class GPTFastDecoder(nn.Module):
- def __init__(
- self
- ) -> None:
+ def __init__(self) -> None:
super().__init__()
-
self.vocab_size = 960
self.padding_idx = 2
self.prefix_token_id = 11
@@ -179,9 +691,17 @@ def __init__(
norm_first=self.norm_first,
)
self.config = config
- self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
- self.token_embed = TokenEmbedding(vocab_size=self.vocab_size, d_model=self.d_model, padding_idx=self.padding_idx)
- self.pos_embed = PositionEmbedding(max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout)
+ self.layers = nn.ModuleList(
+ TransformerBlock(config) for _ in range(config.n_layer)
+ )
+ self.token_embed = TokenEmbedding(
+ vocab_size=self.vocab_size,
+ d_model=self.d_model,
+ padding_idx=self.padding_idx,
+ )
+ self.pos_embed = PositionEmbedding(
+ max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout
+ )
self.generator = nn.Linear(self.d_model, self.vocab_size)
self.token_white_list = TOKEN_WHITE_LIST
self.mask_cache: Optional[Tensor] = None
@@ -193,7 +713,10 @@ def setup_caches(self, max_batch_size, max_seq_length, dtype, device):
b.multihead_attn.k_cache = None
b.multihead_attn.v_cache = None
- if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
+ if (
+ self.max_seq_length >= max_seq_length
+ and self.max_batch_size >= max_batch_size
+ ):
return
head_dim = self.config.dim // self.config.n_head
max_seq_length = find_multiple(max_seq_length, 8)
@@ -207,24 +730,23 @@ def setup_caches(self, max_batch_size, max_seq_length, dtype, device):
self.config.n_head,
head_dim,
dtype,
- device
+ device,
)
b.multihead_attn.k_cache = None
b.multihead_attn.v_cache = None
- self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)).to(device)
-
- def forward(
- self,
- memory: Tensor,
- tgt: Tensor
- ) -> Tensor:
+ self.causal_mask = torch.tril(
+ torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
+ ).to(device)
+ def forward(self, memory: Tensor, tgt: Tensor) -> Tensor:
input_pos = torch.tensor([tgt.shape[1] - 1], device=tgt.device, dtype=torch.int)
tgt = tgt[:, -1:]
tgt_feature = self.pos_embed(self.token_embed(tgt), input_pos=input_pos)
# tgt = self.decoder(tgt_feature, memory, input_pos)
- with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
+ with torch.backends.cuda.sdp_kernel(
+ enable_flash=False, enable_mem_efficient=False, enable_math=True
+ ):
logits = tgt_feature
tgt_mask = self.causal_mask[None, None, input_pos]
for i, layer in enumerate(self.layers):
@@ -238,6 +760,7 @@ def forward(
_, next_tokens = probs.topk(1)
return next_tokens
+
class TransformerBlock(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
@@ -260,11 +783,11 @@ def __init__(self, config: ModelArgs) -> None:
self.activation = _get_activation_fn(config.activation)
def forward(
- self,
- tgt: Tensor,
- memory: Tensor,
- tgt_mask: Tensor,
- input_pos: Tensor,
+ self,
+ tgt: Tensor,
+ memory: Tensor,
+ tgt_mask: Tensor,
+ input_pos: Tensor,
) -> Tensor:
if self.norm_first:
x = tgt
@@ -299,10 +822,10 @@ def __init__(self, config: ModelArgs):
self.dim = config.dim
def forward(
- self,
- x: Tensor,
- mask: Tensor,
- input_pos: Optional[Tensor] = None,
+ self,
+ x: Tensor,
+ mask: Tensor,
+ input_pos: Optional[Tensor] = None,
) -> Tensor:
bsz, seqlen, _ = x.shape
@@ -362,9 +885,9 @@ def get_kv(self, xa: torch.Tensor):
return k, v
def forward(
- self,
- x: Tensor,
- xa: Tensor,
+ self,
+ x: Tensor,
+ xa: Tensor,
):
q = self.query(x)
batch_size, target_seq_len, _ = q.shape
@@ -383,6 +906,6 @@ def forward(
batch_size,
target_seq_len,
self.n_head * self.head_dim,
- )
+ )
return self.out(wv)
diff --git a/rapid_table/table_structure/utils.py b/rapid_table/table_structure/utils.py
index c88cf3c..e06f610 100644
--- a/rapid_table/table_structure/utils.py
+++ b/rapid_table/table_structure/utils.py
@@ -17,12 +17,11 @@
import os
import platform
import traceback
-
-import cv2
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union
+import cv2
import numpy as np
from onnxruntime import (
GraphOptimizationLevel,
@@ -79,10 +78,12 @@ def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
sess_opt.inter_op_num_threads = inter_op_num_threads
return sess_opt
+
def get_metadata(self, key: str = "character") -> list:
meta_dict = self.session.get_modelmeta().custom_metadata_map
content_list = meta_dict[key].splitlines()
return content_list
+
def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
cpu_provider_opts = {
"arena_extend_strategy": "kSameAsRequested",
diff --git a/rapid_table/utils.py b/rapid_table/utils.py
index 77224ce..82d7cef 100644
--- a/rapid_table/utils.py
+++ b/rapid_table/utils.py
@@ -4,7 +4,7 @@
import os
from io import BytesIO
from pathlib import Path
-from typing import Optional, Union, List
+from typing import List, Optional, Union
import cv2
import numpy as np
@@ -14,9 +14,7 @@
class LoadImage:
- def __init__(
- self,
- ):
+ def __init__(self):
pass
def __call__(self, img: InputType) -> np.ndarray:
@@ -79,20 +77,18 @@ class LoadImageError(Exception):
class VisTable:
- def __init__(
- self,
- ):
+ def __init__(self):
self.load_img = LoadImage()
def __call__(
- self,
- img_path: Union[str, Path],
- table_html_str: str,
- save_html_path: Optional[str] = None,
- table_cell_bboxes: Optional[np.ndarray] = None,
- save_drawed_path: Optional[str] = None,
- logic_points: List[List[float]] = None,
- save_logic_path: Optional[str] = None,
+ self,
+ img_path: Union[str, Path],
+ table_html_str: str,
+ save_html_path: Optional[str] = None,
+ table_cell_bboxes: Optional[np.ndarray] = None,
+ save_drawed_path: Optional[str] = None,
+ logic_points: List[List[float]] = None,
+ save_logic_path: Optional[str] = None,
) -> None:
if save_html_path:
html_with_border = self.insert_border_style(table_html_str)
@@ -114,8 +110,10 @@ def __call__(
if save_drawed_path:
self.save_img(save_drawed_path, drawed_img)
if save_logic_path and logic_points:
- polygons = [[box[0],box[1], box[4], box[5]] for box in table_cell_bboxes]
- self.plot_rec_box_with_logic_info(img_path, save_logic_path, logic_points, polygons)
+ polygons = [[box[0], box[1], box[4], box[5]] for box in table_cell_bboxes]
+ self.plot_rec_box_with_logic_info(
+ img_path, save_logic_path, logic_points, polygons
+ )
return drawed_img
def insert_border_style(self, table_html_str: str):
@@ -137,7 +135,9 @@ def insert_border_style(self, table_html_str: str):
html_with_border = f"{prefix_table}{style_res}{suffix_table}"
return html_with_border
- def plot_rec_box_with_logic_info(self, img_path, output_path, logic_points, sorted_polygons):
+ def plot_rec_box_with_logic_info(
+ self, img_path, output_path, logic_points, sorted_polygons
+ ):
"""
:param img_path
:param output_path
@@ -208,4 +208,3 @@ def save_img(save_path: Union[str, Path], img: np.ndarray):
def save_html(save_path: Union[str, Path], html: str):
with open(save_path, "w", encoding="utf-8") as f:
f.write(html)
-
diff --git a/setup.py b/setup.py
index 0489760..8dc2e60 100644
--- a/setup.py
+++ b/setup.py
@@ -19,7 +19,7 @@ def get_readme():
MODULE_NAME = "rapid_table"
obtainer = GetPyPiLatestVersion()
latest_version = obtainer(MODULE_NAME)
-VERSION_NUM = obtainer.version_add_one(latest_version)
+VERSION_NUM = obtainer.version_add_one(latest_version, add_patch=True)
if len(sys.argv) > 2:
match_str = " ".join(sys.argv[2:])