chore: optimize code
SWHL committed Jan 1, 2025
1 parent d8d7662 commit 8904b15
Showing 11 changed files with 742 additions and 189 deletions.
154 changes: 80 additions & 74 deletions README.md
Expand Up @@ -21,16 +21,16 @@ slanet_plus是paddlex内置的SLANet升级版模型,准确率有大幅提升

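unitable is a transformer model from the Unitable project. It has the highest accuracy, currently supports PyTorch inference only (with GPU acceleration), and its trained weights come from the [OhMyTable project](https://github.com/Sanster/OhMyTable)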

| Model type | Model name | Inference framework | Model size | Inference time (single 60 KB image) |
|:--------------:|:--------------------------------------:| :------: |:------: |:------: |
| English | `en_ppstructure_mobile_v2_SLANet.onnx` | onnxruntime |7.3M |0.15s |
| Chinese | `ch_ppstructure_mobile_v2_SLANet.onnx` | onnxruntime |7.4M |0.15s |
| slanet_plus Chinese | `slanet-plus.onnx` | onnxruntime |6.8M |0.15s |
| unitable Chinese | `unitable(encoder.pth,decoder.pth)` | pytorch |500M |cpu(6s) gpu-4090(1.5s)|
| `model_type` | Model name | Inference framework | Model size | Inference time (single 60 KB image) |
|:--------------|:--------------------------------------| :------: |:------ |:------ |
| `ppstructure_en` | `en_ppstructure_mobile_v2_SLANet.onnx` | onnxruntime |7.3M |0.15s |
| `ppstructure_zh` | `ch_ppstructure_mobile_v2_SLANet.onnx` | onnxruntime |7.4M |0.15s |
| `slanet_plus` | `slanet-plus.onnx` | onnxruntime |6.8M |0.15s |
| `unitable` | `unitable(encoder.pth,decoder.pth)` | pytorch |500M |cpu(6s) gpu-4090(1.5s)|

Model sources\
[PaddleOCR table recognition](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/table/README_ch.md) \
[PaddleX-SlaNetPlus table recognition](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/module_usage/tutorials/ocr_modules/table_structure_recognition.md) \
[PaddleOCR table recognition](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/table/README_ch.md)\
[PaddleX-SlaNetPlus table recognition](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/module_usage/tutorials/ocr_modules/table_structure_recognition.md)\
[Unitable](https://github.com/poloclub/unitable?tab=readme-ov-file)

Models can be downloaded here: [link](https://github.com/RapidAI/RapidTable/releases/tag/assets)
Expand All @@ -41,51 +41,6 @@ unitable是来源unitable的transformer模型,精度最高,暂仅支持pytor
<img src="https://github.com/RapidAI/RapidTable/releases/download/assets/preview.gif" alt="Demo" width="100%" height="100%">
</div>

### Changelog

<details>

#### 2024.12.30 update

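- Added table recognition with the Unitable model, using the PyTorch framework

#### 2024.11.24 update

- Added GPU inference, adapted to RapidOCR single-character recognition matching, and added returning and visualizing logical coordinates

#### 2024.10.13 update

- Added the latest paddlex SLANet-plus model (ONNX is not supported for now because of a paddle2onnx issue)

#### 2023-12-29 v0.1.3 update

- Improved the visualization results

#### 2023-12-27 v0.1.2 update

- Added a parameter for returning cell bounding boxes
- Improved the visualization function

#### 2023-07-17 v0.1.0 update

- Decoupled the `rapidocr_onnxruntime` part from `rapid_table` and made the dependency optional, which is more flexible.

- Added the `ocr_result` input parameter to the interface.
    - If `ocr_result` is specified when calling the function, OCR is not run again. The `ocr_result` format must match the return value of `rapidocr_onnxruntime`.
    - If `ocr_result` is not specified but `rapidocr_onnxruntime` is installed, that library is called automatically for recognition.
    - If `ocr_result` is not specified and `rapidocr_onnxruntime` is not installed, an error is raised. One of the two conditions must be met.

#### 2023-07-10 v0.0.13 update

- Changed the interface for passing the OCR instance into table restoration; other OCR instances can be passed in, provided their interface matches `rapidocr_onnxruntime`

#### 2023-07-06 v0.0.12 update

- Removed the `<thead></thead><tbody></tbody>` elements from the returned table HTML string to simplify later unification.
- Formatted the code with Black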

</details>

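### Relationship with [TableStructureRec](https://github.com/RapidAI/TableStructureRec)

TableStructureRec is a collection of table recognition algorithms. It currently provides inference packages for the `wired_table_rec` wired (bordered) table algorithm and the `lineless_table_rec` lineless (borderless) table algorithm.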
Expand All @@ -105,8 +60,8 @@ RapidTable是整理自PP-Structure中表格识别部分而来。由于PP-Structu
```bash
pip install rapidocr_onnxruntime
pip install rapid_table
#pip install rapid_table[torch] # for unitable inference
#pip install onnxruntime-gpu # for onnx gpu inference
# pip install rapid_table[torch] # for unitable inference
# pip install onnxruntime-gpu # for onnx gpu inference
```

### Usage
Expand All @@ -130,10 +85,13 @@ from rapidocr_onnxruntime import RapidOCR
from rapid_table.table_structure.utils import trans_char_ocr_res

table_engine = RapidTable()

# Enable ONNX GPU inference
# table_engine = RapidTable(use_cuda=True)

# Use the unitable model with torch inference
# table_engine = RapidTable(use_cuda=True, device="cuda:0", model_type="unitable")

ocr_engine = RapidOCR()
viser = VisTable()

Expand All @@ -143,6 +101,7 @@ ocr_result, _ = ocr_engine(img_path)
# Single-character matching
# ocr_result, _ = ocr_engine(img_path, return_word_box=True)
# ocr_result = trans_char_ocr_res(ocr_result)

table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_result)

save_dir = Path("./inference_results/")
Expand All @@ -155,6 +114,7 @@ viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_p

# Return logical coordinates
# table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, ocr_result, return_logic_points=True)

# save_logic_path = save_dir / f"vis_logic_{Path(img_path).name}"
# viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path,logic_points, save_logic_path)

Expand All @@ -163,33 +123,32 @@ print(table_html_str)

#### Run from the terminal

- Usage:

```bash
$ rapid_table -h
usage: rapid_table [-h] [-v] -img IMG_PATH [-m MODEL_PATH]
```bash
$ rapid_table -h
usage: rapid_table [-h] [-v] -img IMG_PATH [-m MODEL_PATH]

optional arguments:
-h, --help show this help message and exit
-v, --vis Whether to visualize the layout results.
-img IMG_PATH, --img_path IMG_PATH
Path to image for layout.
-m MODEL_PATH, --model_path MODEL_PATH
The model path used for inference.
```

optional arguments:
-h, --help show this help message and exit
-v, --vis Whether to visualize the layout results.
-img IMG_PATH, --img_path IMG_PATH
Path to image for layout.
-m MODEL_PATH, --model_path MODEL_PATH
The model path used for inference.
```
Example:

- Example:
```bash
rapid_table -v -img test_images/table.jpg
```

```bash
rapid_table -v -img test_images/table.jpg
```

### Results

#### Returned result

```html
<details>

```html
<html>
<body>
<table>
Expand Down Expand Up @@ -316,9 +275,56 @@ print(table_html_str)
</html>
```

</details>

#### Visualization result

<div align="center">
<table><tr><td>Methods</td><td></td><td></td><td></td><td>FPS</td></tr><tr><td>SegLink [26]</td><td>70.0</td><td>86.0</td><td>77.0</td><td>8.9</td></tr><tr><td>PixelLink [4]</td><td>73.2</td><td>83.0</td><td>77.8</td><td></td></tr><tr><td>TextSnake [18]</td><td>73.9</td><td>83.2</td><td>78.3</td><td>1.1</td></tr><tr><td>TextField [37]</td><td>75.9</td><td>87.4</td><td>81.3</td><td>5.2</td></tr><tr><td>MSR[38]</td><td>76.7</td><td>87.4</td><td>81.7</td><td></td></tr><tr><td>FTSN [3]</td><td>77.1</td><td>87.6</td><td>82.0</td><td></td></tr><tr><td>LSE[30]</td><td>81.7</td><td>84.2</td><td>82.9</td><td></td></tr><tr><td>CRAFT [2]</td><td>78.2</td><td>88.2</td><td>82.9</td><td>8.6</td></tr><tr><td>MCN[16]</td><td>79</td><td>88</td><td>83</td><td></td></tr><tr><td>ATRR[35]</td><td>82.1</td><td>85.2</td><td>83.6</td><td></td></tr><tr><td>PAN [34]</td><td>83.8</td><td>84.4</td><td>84.1</td><td>30.2</td></tr><tr><td>DB[12]</td><td>79.2</td><td>91.5</td><td>84.9</td><td>32.0</td></tr><tr><td>DRRG[41]</td><td>82.30</td><td>88.05</td><td>85.08</td><td></td></tr><tr><td>Ours (SynText)</td><td>80.68</td><td>85.40</td><td>82.97</td><td>12.68</td></tr><tr><td>Ours (MLT-17)</td><td>84.54</td><td>86.62</td><td>85.57</td><td>12.31</td></tr></table>

</div>

### Changelog

<details>

#### 2024.12.30 update

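- Added table recognition with the Unitable model, using the PyTorch framework

#### 2024.11.24 update

- Added GPU inference, adapted to RapidOCR single-character recognition matching, and added returning and visualizing logical coordinates

#### 2024.10.13 update

- Added the latest paddlex SLANet-plus model (ONNX is not supported for now because of a paddle2onnx issue)

#### 2023-12-29 v0.1.3 update

- Improved the visualization results

#### 2023-12-27 v0.1.2 update

- Added a parameter for returning cell bounding boxes
- Improved the visualization function

#### 2023-07-17 v0.1.0 update

- Decoupled the `rapidocr_onnxruntime` part from `rapid_table` and made the dependency optional, which is more flexible.

- Added the `ocr_result` input parameter to the interface.
    - If `ocr_result` is specified when calling the function, OCR is not run again. The `ocr_result` format must match the return value of `rapidocr_onnxruntime`.
    - If `ocr_result` is not specified but `rapidocr_onnxruntime` is installed, that library is called automatically for recognition.
    - If `ocr_result` is not specified and `rapidocr_onnxruntime` is not installed, an error is raised. One of the two conditions must be met.

#### 2023-07-10 v0.0.13 update

- Changed the interface for passing the OCR instance into table restoration; other OCR instances can be passed in, provided their interface matches `rapidocr_onnxruntime`

#### 2023-07-06 v0.0.12 update

- Removed the `<thead></thead><tbody></tbody>` elements from the returned table HTML string to simplify later unification.
- Formatted the code with Black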

</details>
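
The `ocr_result` rules from the v0.1.0 changelog entry can be sketched as a small fallback function. This is an illustrative sketch only, not the library's code: `resolve_ocr_result` and `fake_ocr_engine` are hypothetical names, and the engine is assumed to return a `(result, elapse)` pair, as the `ocr_result, _ = ocr_engine(img_path)` call in the README usage suggests.

```python
from typing import Any, Callable, List, Optional, Tuple

# Assumed engine shape: takes an image path, returns (ocr_result, elapse).
OcrEngine = Callable[[str], Tuple[List[Any], Any]]


def resolve_ocr_result(
    img_path: str,
    ocr_result: Optional[List[Any]] = None,
    ocr_engine: Optional[OcrEngine] = None,
) -> List[Any]:
    """Hypothetical helper mirroring the three rules in the changelog."""
    # Rule 1: a caller-supplied result wins, so OCR is not run again.
    if ocr_result is not None:
        return ocr_result
    # Rule 2: otherwise fall back to an available OCR engine.
    if ocr_engine is not None:
        result, _ = ocr_engine(img_path)
        return result
    # Rule 3: neither is available, so fail loudly.
    raise ValueError("Provide ocr_result or make an OCR engine available.")


def fake_ocr_engine(img_path: str) -> Tuple[List[Any], float]:
    """Stand-in engine so the sketch runs without rapidocr_onnxruntime."""
    box = [[0.0, 0.0], [10.0, 0.0], [10.0, 5.0], [0.0, 5.0]]
    return [[box, "cell", 0.99]], 0.01


given = resolve_ocr_result("table.jpg", ocr_result=[["box", "text", 1.0]])
detected = resolve_ocr_result("table.jpg", ocr_engine=fake_ocr_engine)
```

Calling `resolve_ocr_result("table.jpg")` with neither argument raises `ValueError`, which corresponds to the "one of the two conditions must be met" rule.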
9 changes: 5 additions & 4 deletions rapid_table/download_model.py
Expand Up @@ -19,11 +19,14 @@
}
}


class DownloadModel:
    cur_dir = PROJECT_DIR

    @staticmethod
    def get_model_path(model_type: str, sub_file_type: str, path: Union[str, Path, None]) -> str:
    def get_model_path(
        model_type: str, sub_file_type: str, path: Union[str, Path, None]
    ) -> str:
        if path is not None:
            return path

Expand All @@ -32,9 +35,7 @@ def get_model_path(model_type: str, sub_file_type: str, path: Union[str, Path, N
            model_path = DownloadModel.download(model_url)
            return model_path

        logger.info(
            "model url is None, using the default download model %s", path
        )
        logger.info("model url is None, using the default download model %s", path)
        return path

    @classmethod
Expand Down
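
The path-resolution order visible in `get_model_path` (explicit path first, then a download from a known URL) can be sketched as follows. This is a hedged illustration, not the project's implementation: `DEMO_MODEL_URLS`, `resolve_model_path`, and `fake_download` are hypothetical, the URL is a placeholder, and the sketch normalizes the result to `str`, which the original does not do.

```python
from pathlib import Path
from typing import Callable, Dict, Optional, Union

# Hypothetical URL table; the real mapping is not visible in this diff.
DEMO_MODEL_URLS: Dict[str, Dict[str, str]] = {
    "unitable": {"encoder": "https://example.com/models/encoder.pth"},
}


def resolve_model_path(
    model_type: str,
    sub_file_type: str,
    path: Union[str, Path, None],
    download: Callable[[str], str],
) -> Optional[str]:
    # An explicit path always takes precedence over downloading.
    if path is not None:
        return str(path)
    url = DEMO_MODEL_URLS.get(model_type, {}).get(sub_file_type)
    if url:
        return download(url)
    # No URL is known for this combination, so there is nothing to return.
    return None


def fake_download(url: str) -> str:
    """Pretend to download: only derive a local file name from the URL."""
    return str(Path("models") / url.rsplit("/", 1)[-1])


explicit = resolve_model_path("unitable", "encoder", "my/encoder.pth", fake_download)
fetched = resolve_model_path("unitable", "encoder", None, fake_download)
missing = resolve_model_path("unitable", "decoder", None, fake_download)
```

Passing the downloader in as a parameter keeps the sketch free of network access; the original calls `DownloadModel.download` directly.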
23 changes: 18 additions & 5 deletions rapid_table/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import numpy as np

from .download_model import DownloadModel
from .params import accept_kwargs_as_dataclass, BaseConfig
from .params import BaseConfig, accept_kwargs_as_dataclass
from .table_matcher import TableMatch
from .table_structure import TableStructurer, TableStructureUnitable
from .utils import LoadImage, VisTable
Expand All @@ -26,13 +26,21 @@ class RapidTable:
    def __init__(self, config: BaseConfig):
        self.model_type = config.model_type
        self.load_img = LoadImage()

        if self.model_type == "unitable":
            config.encoder_path = DownloadModel.get_model_path(self.model_type, "encoder", config.encoder_path)
            config.decoder_path = DownloadModel.get_model_path(self.model_type, "decoder", config.decoder_path)
            config.vocab_path = DownloadModel.get_model_path(self.model_type, "vocab", config.vocab_path)
            config.encoder_path = DownloadModel.get_model_path(
                self.model_type, "encoder", config.encoder_path
            )
            config.decoder_path = DownloadModel.get_model_path(
                self.model_type, "decoder", config.decoder_path
            )
            config.vocab_path = DownloadModel.get_model_path(
                self.model_type, "vocab", config.vocab_path
            )
            self.table_structure = TableStructureUnitable(asdict(config))
        else:
            self.table_structure = TableStructurer(asdict(config))

        self.table_matcher = TableMatch()

        try:
Expand All @@ -44,7 +52,7 @@ def __call__(
        self,
        img_content: Union[str, np.ndarray, bytes, Path],
        ocr_result: List[Union[List[List[float]], str, str]] = None,
        return_logic_points = False
        return_logic_points=False,
    ) -> Tuple[str, float]:
        if self.ocr_engine is None and ocr_result is None:
            raise ValueError(
Expand All @@ -65,14 +73,17 @@ def __call__(
        if self.model_type == "slanet-plus":
            pred_bboxes = self.adapt_slanet_plus(img, pred_bboxes)
        pred_html = self.table_matcher(pred_structures, pred_bboxes, dt_boxes, rec_res)

        # Filter out placeholder bboxes (rows that are all zeros)
        mask = ~np.all(pred_bboxes == 0, axis=1)
        pred_bboxes = pred_bboxes[mask]

        # Disabled by default to avoid breaking users who upgrade from older versions
        if return_logic_points:
            logic_points = self.table_matcher.decode_logic_points(pred_structures)
            elapse = time.time() - s
            return pred_html, pred_bboxes, logic_points, elapse

        elapse = time.time() - s
        return pred_html, pred_bboxes, elapse

Expand All @@ -93,6 +104,7 @@ def get_boxes_recs(
            r_boxes.append(box)
        dt_boxes = np.array(r_boxes)
        return dt_boxes, rec_res

    def adapt_slanet_plus(self, img: np.ndarray, pred_bboxes: np.ndarray) -> np.ndarray:
        h, w = img.shape[:2]
        resized = 488
Expand All @@ -103,6 +115,7 @@ def adapt_slanet_plus(self, img: np.ndarray, pred_bboxes: np.ndarray) -> np.ndar
        pred_bboxes[:, 1::2] *= h_ratio
        return pred_bboxes


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
Expand Down
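
The `return_logic_points` flag in `__call__` changes the length of the returned tuple, so existing three-value callers keep working and new callers opt in to the fourth value. A minimal sketch of that pattern, assuming nothing about the real models: `recognize` and its dummy values are hypothetical stand-ins for `RapidTable.__call__`.

```python
import time
from typing import List


def recognize(return_logic_points: bool = False):
    """Hypothetical stand-in that mimics the two return shapes."""
    start = time.time()
    pred_html = "<html><body><table></table></body></html>"
    pred_bboxes: List[List[float]] = [[0.0, 0.0, 10.0, 10.0]]

    # Opt-in path: four values, with logic points in third position.
    if return_logic_points:
        logic_points = [[0, 0, 0, 0]]
        return pred_html, pred_bboxes, logic_points, time.time() - start

    # Default path: the original three-value shape.
    return pred_html, pred_bboxes, time.time() - start


# Callers that predate the flag unpack three values.
html, bboxes, elapse = recognize()

# Callers that want logic points unpack four.
html, bboxes, logic_points, elapse = recognize(return_logic_points=True)
```

The trade-off of this design is that the unpacking at the call site must match the flag; unpacking three values from a four-value result raises `ValueError`.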
29 changes: 17 additions & 12 deletions rapid_table/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
root_dir = Path(__file__).resolve().parent
logger = get_logger("params")


@dataclass
class BaseConfig:
    model_type: str = "slanet-plus"
Expand All @@ -25,16 +26,20 @@ def wrapper(*args, **kwargs):
            if len(args) == 2 and isinstance(args[1], cls):
                # If a ModelConfig instance was already passed, call the function directly
                return func(*args, **kwargs)
            else:
                # Collect the fields defined on cls
                cls_fields = {field.name for field in fields(cls)}
                # Drop kwargs that are not defined fields
                filtered_kwargs = {k: v for k, v in kwargs.items() if k in cls_fields}
                # Warn about undefined fields
                for k in (kwargs.keys() - cls_fields):
                    logger.warning(f"Warning: '{k}' is not a valid field in {cls.__name__} and will be ignored.")
                # Create a ModelConfig instance and call the function
                config = cls(**filtered_kwargs)
                return func(args[0], config=config)

            # Collect the fields defined on cls
            cls_fields = {field.name for field in fields(cls)}
            # Drop kwargs that are not defined fields
            filtered_kwargs = {k: v for k, v in kwargs.items() if k in cls_fields}
            # Warn about undefined fields
            for k in kwargs.keys() - cls_fields:
                logger.warning(
                    f"Warning: '{k}' is not a valid field in {cls.__name__} and will be ignored."
                )
            # Create a ModelConfig instance and call the function
            config = cls(**filtered_kwargs)
            return func(args[0], config=config)

        return wrapper
    return decorator

    return decorator
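
To see how the kwargs-to-dataclass decorator behaves end to end, here is a self-contained sketch. It follows the logic visible in the hunk above, but `DemoConfig` and `DemoEngine` are hypothetical stand-ins for `BaseConfig` and `RapidTable`, and the use of `functools.wraps` is an addition of this sketch, since the full decorator is not shown in the diff.

```python
import logging
from dataclasses import dataclass, fields
from functools import wraps

logger = logging.getLogger("params_sketch")


@dataclass
class DemoConfig:
    """Hypothetical config; the real BaseConfig has more fields."""
    model_type: str = "slanet-plus"
    use_cuda: bool = False


def accept_kwargs_as_dataclass(cls):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # A ready-made config instance was passed positionally: use it as-is.
            if len(args) == 2 and isinstance(args[1], cls):
                return func(*args, **kwargs)

            # Keep only kwargs that name a dataclass field; warn about the rest.
            cls_fields = {f.name for f in fields(cls)}
            filtered = {k: v for k, v in kwargs.items() if k in cls_fields}
            for k in kwargs.keys() - cls_fields:
                logger.warning(
                    "'%s' is not a valid field in %s and will be ignored.",
                    k,
                    cls.__name__,
                )
            return func(args[0], config=cls(**filtered))

        return wrapper

    return decorator


class DemoEngine:
    @accept_kwargs_as_dataclass(DemoConfig)
    def __init__(self, config: DemoConfig):
        self.config = config


engine_from_kwargs = DemoEngine(model_type="unitable", unknown_option=1)
engine_from_config = DemoEngine(DemoConfig(use_cuda=True))
```

Both construction styles end up with a `DemoConfig` on the instance; the unknown keyword is logged and dropped rather than raising `TypeError`.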