Skip to content

Commit

Permalink
Merge pull request #37 from RapidAI/dev_unitable
Browse files Browse the repository at this point in the history
Dev unitable
  • Loading branch information
SWHL authored Jan 1, 2025
2 parents 01869db + 4a70f60 commit d8d7662
Show file tree
Hide file tree
Showing 25 changed files with 142 additions and 1,088 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/publish_whl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,12 @@ jobs:
pip install -r requirements.txt
pip install rapidocr_onnxruntime
pip install torch
pip install torchvision
pip install tokenizers
pip install pytest
pytest tests/test_table.py
pytest tests/test_table_torch.py
GenerateWHL_PushPyPi:
needs: UnitTesting
Expand Down
59 changes: 0 additions & 59 deletions .github/workflows/publish_whl_torch.yml

This file was deleted.

65 changes: 6 additions & 59 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ RapidTable是整理自PP-Structure中表格识别部分而来。由于PP-Structu
```bash
pip install rapidocr_onnxruntime
pip install rapid_table
#pip install rapid_table_torch # for unitable inference
#pip install onnxruntime-gpu # for gpu inference
#pip install rapid_table[torch] # for unitable inference
#pip install onnxruntime-gpu # for onnx gpu inference
```

### 使用方式
Expand All @@ -117,11 +117,11 @@ RapidTable类提供model_path参数,可以自行指定上述2个模型,默

```python
table_engine = RapidTable()
# table_engine = RapidTable(use_cuda=True, device="cuda:0", model_type="unitable")
```

完整示例:

#### onnx版本
```python
from pathlib import Path

Expand All @@ -132,6 +132,8 @@ from rapid_table.table_structure.utils import trans_char_ocr_res
table_engine = RapidTable()
# 开启onnx-gpu推理
# table_engine = RapidTable(use_cuda=True)
# 使用torch推理版本的unitable模型
# table_engine = RapidTable(use_cuda=True, device="cuda:0", model_type="unitable")
ocr_engine = RapidOCR()
viser = VisTable()

Expand Down Expand Up @@ -159,41 +161,8 @@ viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_p
print(table_html_str)
```

#### torch版本
```python
from pathlib import Path
from rapidocr_onnxruntime import RapidOCR

from rapid_table_torch import RapidTable, VisTable
from rapid_table_torch.table_structure.utils import trans_char_ocr_res

if __name__ == '__main__':
# Init
ocr_engine = RapidOCR()
table_engine = RapidTable(device="cpu") # 默认使用cpu,若使用cuda,则传入device="cuda:0"
viser = VisTable()
img_path = "tests/test_files/image34.png"
# OCR,本模型检测框比较精准,配合单字匹配效果更好
ocr_result, _ = ocr_engine(img_path, return_word_box=True)
ocr_result = trans_char_ocr_res(ocr_result)
boxes, txts, scores = list(zip(*ocr_result))
# Save
save_dir = Path("outputs")
save_dir.mkdir(parents=True, exist_ok=True)

save_html_path = save_dir / f"{Path(img_path).stem}.html"
save_drawed_path = save_dir / f"{Path(img_path).stem}_table_vis{Path(img_path).suffix}"
# 返回逻辑坐标
table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, ocr_result)
save_logic_path = save_dir / f"vis_logic_{Path(img_path).name}"
vis_imged = viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path, logic_points,
save_logic_path)
print(f"elapse:{elapse}")
```

#### 终端运行

##### onnx:
- 用法:

```bash
Expand All @@ -214,29 +183,7 @@ print(f"elapse:{elapse}")
```bash
rapid_table -v -img test_images/table.jpg
```

##### pytorch:
- 用法:

```bash
$ rapid_table_torch -h
usage: rapid_table_torch [-h] [-v] -img IMG_PATH [-d DEVICE]

optional arguments:
-h, --help show this help message and exit
-v, --vis Whether to visualize the layout results.
-img IMG_PATH, --img_path IMG_PATH
Path to image for layout.
-d DEVICE, --device device
The model device used for inference.
```

- 示例:

```bash
rapid_table_torch -v -img test_images/table.jpg
```


### 结果

#### 返回结果
Expand Down
50 changes: 24 additions & 26 deletions demo_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,29 @@
from pathlib import Path
from rapidocr_onnxruntime import RapidOCR

from rapid_table_torch import RapidTable, VisTable
from rapid_table_torch.table_structure.utils import trans_char_ocr_res
from rapid_table import RapidTable, VisTable
from rapid_table.table_structure.utils import trans_char_ocr_res

if __name__ == '__main__':
# Init
ocr_engine = RapidOCR()
table_engine = RapidTable(encoder_path="rapid_table_torch/models/encoder.pth",
decoder_path="rapid_table_torch/models/decoder.pth",
vocab_path="rapid_table_torch/models/vocab.json",
device="cpu")
viser = VisTable()
img_path = "tests/test_files/image34.png"
# OCR
ocr_result, _ = ocr_engine(img_path, return_word_box=True)
ocr_result = trans_char_ocr_res(ocr_result)
boxes, txts, scores = list(zip(*ocr_result))
# Save
save_dir = Path("outputs")
save_dir.mkdir(parents=True, exist_ok=True)
ocr_engine = RapidOCR()
table_engine = RapidTable(use_cuda=True, device="cuda:0", model_type="unitable")
viser = VisTable()
img_path = "tests/test_files/table.jpg"
# OCR
ocr_result, _ = ocr_engine(img_path, return_word_box=True)
ocr_result = trans_char_ocr_res(ocr_result)
boxes, txts, scores = list(zip(*ocr_result))
# Save
save_dir = Path("outputs")
save_dir.mkdir(parents=True, exist_ok=True)

save_html_path = save_dir / f"{Path(img_path).stem}.html"
save_drawed_path = save_dir / f"{Path(img_path).stem}_table_vis{Path(img_path).suffix}"
# 返回逻辑坐标
table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, ocr_result)
save_logic_path = save_dir / f"vis_logic_{Path(img_path).name}"
vis_imged = viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path, logic_points,
save_logic_path)
print(f"elapse:{elapse}")
save_html_path = save_dir / f"{Path(img_path).stem}.html"
save_drawed_path = save_dir / f"{Path(img_path).stem}_table_vis{Path(img_path).suffix}"

table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_result)
viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path)
# 返回逻辑坐标
# table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, ocr_result, return_logic_points=True)
# save_logic_path = save_dir / f"vis_logic_{Path(img_path).name}"
# vis_imged = viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path, logic_points,
# save_logic_path)
print(f"elapse:{elapse}")
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,33 @@
logger = get_logger("DownloadModel")
CUR_DIR = Path(__file__).resolve()
PROJECT_DIR = CUR_DIR.parent

ROOT_URL = "https://www.modelscope.cn/studio/jockerK/TableRec/resolve/master/models/table_rec/unitable/"
KEY_TO_MODEL_URL = {
"unitable": {
"encoder": f"{ROOT_URL}/encoder.pth",
"decoder": f"{ROOT_URL}/decoder.pth",
"vocab": f"{ROOT_URL}/vocab.json",
}
}

class DownloadModel:
cur_dir = PROJECT_DIR

@staticmethod
def get_model_path(model_type: str, sub_file_type: str, path: Union[str, Path, None]) -> str:
if path is not None:
return path

model_url = KEY_TO_MODEL_URL.get(model_type, {}).get(sub_file_type, None)
if model_url:
model_path = DownloadModel.download(model_url)
return model_path

logger.info(
"model url is None, using the default download model %s", path
)
return path

@classmethod
def download(cls, model_full_url: Union[str, Path]) -> str:
save_dir = cls.cur_dir / "models"
Expand Down
File renamed without changes.
32 changes: 18 additions & 14 deletions rapid_table/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,34 @@
import copy
import importlib
import time
from dataclasses import asdict
from pathlib import Path
from typing import List, Optional, Tuple, Union
from typing import List, Tuple, Union

import cv2
import numpy as np

from .download_model import DownloadModel
from .params import accept_kwargs_as_dataclass, BaseConfig
from .table_matcher import TableMatch
from .table_structure import TableStructurer
from .table_structure import TableStructurer, TableStructureUnitable
from .utils import LoadImage, VisTable

root_dir = Path(__file__).resolve().parent


class RapidTable:
def __init__(self, model_path: Optional[str] = None, model_type: str = None, use_cuda: bool = False):
if model_path is None:
model_path = str(
root_dir / "models" / "slanet-plus.onnx"
)
model_type = "slanet-plus"
self.model_type = model_type
@accept_kwargs_as_dataclass(BaseConfig)
def __init__(self, config: BaseConfig):
self.model_type = config.model_type
self.load_img = LoadImage()
config = {
"model_path": model_path,
"use_cuda": use_cuda,
}
self.table_structure = TableStructurer(config)
if self.model_type == "unitable":
config.encoder_path = DownloadModel.get_model_path(self.model_type, "encoder", config.encoder_path)
config.decoder_path = DownloadModel.get_model_path(self.model_type, "decoder", config.decoder_path)
config.vocab_path = DownloadModel.get_model_path(self.model_type, "vocab", config.vocab_path)
self.table_structure = TableStructureUnitable(asdict(config))
else:
self.table_structure = TableStructurer(asdict(config))
self.table_matcher = TableMatch()

try:
Expand Down Expand Up @@ -64,6 +65,9 @@ def __call__(
if self.model_type == "slanet-plus":
pred_bboxes = self.adapt_slanet_plus(img, pred_bboxes)
pred_html = self.table_matcher(pred_structures, pred_bboxes, dt_boxes, rec_res)
# 过滤掉占位的bbox
mask = ~np.all(pred_bboxes == 0, axis=1)
pred_bboxes = pred_bboxes[mask]
# 避免低版本升级后出现问题,默认不打开
if return_logic_points:
logic_points = self.table_matcher.decode_logic_points(pred_structures)
Expand Down
40 changes: 40 additions & 0 deletions rapid_table/params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from dataclasses import dataclass, fields
from functools import wraps
from pathlib import Path

from rapid_table.logger import get_logger

root_dir = Path(__file__).resolve().parent
logger = get_logger("params")

@dataclass
class BaseConfig:
model_type: str = "slanet-plus"
model_path: str = str(root_dir / "models" / "slanet-plus.onnx")
use_cuda: bool = False
device: str = "cpu"
encoder_path: str = None
decoder_path: str = None
vocab_path: str = None


def accept_kwargs_as_dataclass(cls):
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
if len(args) == 2 and isinstance(args[1], cls):
# 如果已经传递了 ModelConfig 实例,直接调用函数
return func(*args, **kwargs)
else:
# 提取 cls 中定义的字段
cls_fields = {field.name for field in fields(cls)}
# 过滤掉未定义的字段
filtered_kwargs = {k: v for k, v in kwargs.items() if k in cls_fields}
# 发出警告对于未定义的字段
for k in (kwargs.keys() - cls_fields):
logger.warning(f"Warning: '{k}' is not a valid field in {cls.__name__} and will be ignored.")
# 创建 ModelConfig 实例并调用函数
config = cls(**filtered_kwargs)
return func(args[0], config=config)
return wrapper
return decorator
5 changes: 4 additions & 1 deletion rapid_table/table_matcher/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __call__(self, pred_structures, pred_bboxes, dt_boxes, rec_res):
pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
return pred_html

def match_result(self, dt_boxes, pred_bboxes):
def match_result(self, dt_boxes, pred_bboxes, min_iou=0.1 ** 8):
matched = {}
for i, gt_box in enumerate(dt_boxes):
distances = []
Expand All @@ -49,6 +49,9 @@ def match_result(self, dt_boxes, pred_bboxes):
sorted_distances = sorted(
sorted_distances, key=lambda item: (item[1], item[0])
)
# must > min_iou
if sorted_distances[0][1] >= 1 - min_iou:
continue
if distances.index(sorted_distances[0]) not in matched.keys():
matched[distances.index(sorted_distances[0])] = [i]
else:
Expand Down
1 change: 1 addition & 0 deletions rapid_table/table_structure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .table_structure import TableStructurer
from .table_structure_unitable import TableStructureUnitable
Loading

0 comments on commit d8d7662

Please sign in to comment.