diff --git a/README.md b/README.md
index 27645a1..34a870a 100644
--- a/README.md
+++ b/README.md
@@ -39,7 +39,47 @@ slanet_plus is an upgraded version of the SLANet model bundled with PaddleX, with substantially improved accuracy
-### Relationship with [TableStructureRec](https://github.com/RapidAI/TableStructureRec)
+### Changelog
+
+
+
+#### 2024-11-24 update
+- Support GPU inference and adapt to RapidOCR per-character recognition matching
+
+#### 2024-10-13 update
+- Add the latest paddlex-SLANet-plus model (ONNX export is not yet supported due to paddle2onnx limitations)
+
+#### 2023-12-29 v0.1.3 update
+
+- Improve the visualization of results
+
+#### 2023-12-27 v0.1.2 update
+
+- Add a parameter to return cell bounding boxes
+- Improve the visualization function
+
+#### 2023-07-17 v0.1.0 update
+
+- Decouple the `rapidocr_onnxruntime` part from `rapid_table`, making the dependency optional and the library more flexible.
+
+- Add the input parameter `ocr_result` to the interface (see the sketch below):
+  - If `ocr_result` is specified when calling the function, OCR is not run again; its format must match the return value of `rapidocr_onnxruntime`.
+  - If `ocr_result` is not specified but `rapidocr_onnxruntime` is installed, that library is invoked automatically for recognition.
+  - If `ocr_result` is not specified and `rapidocr_onnxruntime` is not installed, an error is raised; one of the two conditions must be met.
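+
+  A minimal sketch of the first two cases (illustrative image path; assumes `rapidocr_onnxruntime` is installed):
+
+  ```python
+  from rapid_table import RapidTable
+  from rapidocr_onnxruntime import RapidOCR
+
+  table_engine = RapidTable()
+  ocr_engine = RapidOCR()
+
+  # Case 1: run OCR yourself and pass ocr_result; RapidTable skips its internal OCR.
+  ocr_result, _ = ocr_engine("test_images/table.jpg")
+  table_html_str, table_cell_bboxes, elapse = table_engine("test_images/table.jpg", ocr_result)
+
+  # Case 2: omit ocr_result; the installed rapidocr_onnxruntime is invoked automatically.
+  table_html_str, table_cell_bboxes, elapse = table_engine("test_images/table.jpg")
+  ```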
+
+#### 2023-07-10 v0.0.13 update
+
+- Change the interface for passing an OCR instance into table reconstruction; any other OCR instance can be passed in, provided its interface matches `rapidocr_onnxruntime`
+
+#### 2023-07-06 v0.0.12 update
+
+- Remove the `` element from the returned table HTML string for easier downstream unification.
+- Format the code with Black
+
+
+
+
+### Relationship with [TableStructureRec](https://github.com/RapidAI/TableStructureRec)
TableStructureRec is a collection library of table recognition algorithms; it currently provides inference packages for the `wired_table_rec` wired-table recognition algorithm and the `lineless_table_rec` lineless-table recognition algorithm.
@@ -58,6 +98,7 @@ RapidTable is adapted from the table recognition part of PP-Structure. Since PP-Structu
```bash
pip install rapidocr_onnxruntime
pip install rapid_table
+# pip install onnxruntime-gpu  # for GPU inference
```
### Usage
@@ -76,14 +117,22 @@ table_engine = RapidTable()
from pathlib import Path
from rapid_table import RapidTable, VisTable
+from rapidocr_onnxruntime import RapidOCR
+from rapid_table.table_structure.utils import trans_char_ocr_res
+
table_engine = RapidTable()
+# Enable ONNX GPU inference
+# table_engine = RapidTable(use_cuda=True)
ocr_engine = RapidOCR()
viser = VisTable()
img_path = 'test_images/table.jpg'
ocr_result, _ = ocr_engine(img_path)
+# Per-character matching
+# ocr_result, _ = ocr_engine(img_path, return_word_box=True)
+# ocr_result = trans_char_ocr_res(ocr_result)
table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_result)
save_dir = Path("./inference_results/")
@@ -134,39 +183,3 @@ print(table_html_str)
Methods | Recall | Precision | F-measure | FPS |
SegLink [26] | 70.0 | 86.0 | 77.0 | 8.9 |
PixelLink [4] | 73.2 | 83.0 | 77.8 | |
TextSnake [18] | 73.9 | 83.2 | 78.3 | 1.1 |
TextField [37] | 75.9 | 87.4 | 81.3 | 5.2 |
MSR [38] | 76.7 | 87.4 | 81.7 | |
FTSN [3] | 77.1 | 87.6 | 82.0 | |
LSE [30] | 81.7 | 84.2 | 82.9 | |
CRAFT [2] | 78.2 | 88.2 | 82.9 | 8.6 |
MCN [16] | 79 | 88 | 83 | |
ATRR [35] | 82.1 | 85.2 | 83.6 | |
PAN [34] | 83.8 | 84.4 | 84.1 | 30.2 |
DB [12] | 79.2 | 91.5 | 84.9 | 32.0 |
DRRG [41] | 82.30 | 88.05 | 85.08 | |
Ours (SynText) | 80.68 | 85.40 | 82.97 | 12.68 |
Ours (MLT-17) | 84.54 | 86.62 | 85.57 | 12.31 |
-
-### Changelog
-
-
-
-#### 2023-12-29 v0.1.3 update
-
-- Improve the visualization of results
-
-#### 2023-12-27 v0.1.2 update
-
-- Add a parameter to return cell bounding boxes
-- Improve the visualization function
-
-#### 2023-07-17 v0.1.0 update
-
-- Decouple the `rapidocr_onnxruntime` part from `rapid_table`, making the dependency optional and the library more flexible.
-
-- Add the input parameter `ocr_result` to the interface:
-  - If `ocr_result` is specified when calling the function, OCR is not run again; its format must match the return value of `rapidocr_onnxruntime`.
-  - If `ocr_result` is not specified but `rapidocr_onnxruntime` is installed, that library is invoked automatically for recognition.
-  - If `ocr_result` is not specified and `rapidocr_onnxruntime` is not installed, an error is raised; one of the two conditions must be met.
-
-#### 2023-07-10 v0.0.13 update
-
-- Change the interface for passing an OCR instance into table reconstruction; any other OCR instance can be passed in, provided its interface matches `rapidocr_onnxruntime`
-
-#### 2023-07-06 v0.0.12 update
-
-- Remove the `` element from the returned table HTML string for easier downstream unification.
-- Format the code with Black
-
-#### 2024-10-13 update
-- Add the latest paddlex-SLANet-plus model (ONNX export is not yet supported due to paddle2onnx limitations)
-
-
diff --git a/rapid_table/main.py b/rapid_table/main.py
index 02d6516..b2a52c0 100644
--- a/rapid_table/main.py
+++ b/rapid_table/main.py
@@ -19,7 +19,7 @@
 class RapidTable:
-    def __init__(self, model_path: Optional[str] = None, model_type: str = None):
+    def __init__(self, model_path: Optional[str] = None, model_type: Optional[str] = None, use_cuda: bool = False):
         if model_path is None:
             model_path = str(
                 root_dir / "models" / "slanet-plus.onnx"
@@ -27,7 +27,11 @@ def __init__(self, model_path: Optional[str] = None, model_type: str = None):
             model_type = "slanet-plus"
         self.model_type = model_type
         self.load_img = LoadImage()
-        self.table_structure = TableStructurer(model_path)
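+        # Bundle inference options for OrtInferSession; use_cuda toggles the CUDA execution provider.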
+        config = {
+            "model_path": model_path,
+            "use_cuda": use_cuda,
+        }
+        self.table_structure = TableStructurer(config)
         self.table_matcher = TableMatch()
try:
diff --git a/rapid_table/table_structure/logger.py b/rapid_table/table_structure/logger.py
new file mode 100644
index 0000000..2950987
--- /dev/null
+++ b/rapid_table/table_structure/logger.py
@@ -0,0 +1,21 @@
+# -*- encoding: utf-8 -*-
+# @Author: Jocker1212
+# @Contact: xinyijianggo@gmail.com
+import logging
+from functools import lru_cache
+
+
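+# Cache one logger per name so repeated calls don't attach duplicate handlers.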
+@lru_cache(maxsize=32)
+def get_logger(name: str) -> logging.Logger:
+ logger = logging.getLogger(name)
+ logger.setLevel(logging.DEBUG)
+
+ fmt = "%(asctime)s - %(name)s - %(levelname)s: %(message)s"
+ format_str = logging.Formatter(fmt)
+
+    sh = logging.StreamHandler()
+    sh.setLevel(logging.DEBUG)
+    sh.setFormatter(format_str)
+
+    logger.addHandler(sh)
+ return logger
diff --git a/rapid_table/table_structure/table_structure.py b/rapid_table/table_structure/table_structure.py
index 884d004..c2509d8 100644
--- a/rapid_table/table_structure/table_structure.py
+++ b/rapid_table/table_structure/table_structure.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
+from typing import Any, Dict
import numpy as np
@@ -19,10 +20,10 @@
 class TableStructurer:
-    def __init__(self, model_path: str):
+    def __init__(self, config: Dict[str, Any]):
         self.preprocess_op = TablePreprocess()
-        self.session = OrtInferSession(model_path)
+        self.session = OrtInferSession(config)
         self.character = self.session.get_metadata()
         self.postprocess_op = TableLabelDecode(self.character)
@@ -37,7 +38,7 @@ def __call__(self, img):
         img = np.expand_dims(img, axis=0)
         img = img.copy()
-        outputs = self.session(img)
+        # OrtInferSession now expects a list with one array per model input.
+        outputs = self.session([img])
         preds = {"loc_preds": outputs[0], "structure_probs": outputs[1]}
diff --git a/rapid_table/table_structure/utils.py b/rapid_table/table_structure/utils.py
index 5c12ab1..c88cf3c 100644
--- a/rapid_table/table_structure/utils.py
+++ b/rapid_table/table_structure/utils.py
@@ -14,60 +14,233 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
-from pathlib import Path
+import os
+import platform
+import traceback
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Union
+
 import cv2
+
import numpy as np
-from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
+from onnxruntime import (
+ GraphOptimizationLevel,
+ InferenceSession,
+ SessionOptions,
+ get_available_providers,
+ get_device,
+)
+
+from rapid_table.table_structure.logger import get_logger
+
+
+class EP(Enum):
+ CPU_EP = "CPUExecutionProvider"
+ CUDA_EP = "CUDAExecutionProvider"
+ DIRECTML_EP = "DmlExecutionProvider"
 class OrtInferSession:
-    def __init__(self, onnx_path: str):
-        self._verify_model(onnx_path)
+    def __init__(self, config: Dict[str, Any]):
+        self.logger = get_logger("OrtInferSession")
+
+        model_path = config.get("model_path", None)
+        self._verify_model(model_path)
+
+        self.cfg_use_cuda = config.get("use_cuda", None)
+        self.cfg_use_dml = config.get("use_dml", None)
+
+        self.had_providers: List[str] = get_available_providers()
+        EP_list = self._get_ep_list()
+
+        sess_opt = self._init_sess_opts(config)
+        self.session = InferenceSession(
+            model_path,
+            sess_options=sess_opt,
+            providers=EP_list,
+        )
+        self._verify_providers()
+
+    @staticmethod
+    def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
         sess_opt = SessionOptions()
         sess_opt.log_severity_level = 4
         sess_opt.enable_cpu_mem_arena = False
         sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

-        cpu_ep = "CPUExecutionProvider"
-        cpu_provider_options = {
+        # Apply explicit thread counts only when they fall within [1, cpu_nums];
+        # os.cpu_count() can return None, so fall back to 1.
+        cpu_nums = os.cpu_count() or 1
+        intra_op_num_threads = config.get("intra_op_num_threads", -1)
+        if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums:
+            sess_opt.intra_op_num_threads = intra_op_num_threads
+
+        inter_op_num_threads = config.get("inter_op_num_threads", -1)
+        if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums:
+            sess_opt.inter_op_num_threads = inter_op_num_threads
+
+        return sess_opt
+
+    def get_metadata(self, key: str = "character") -> list:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        content_list = meta_dict[key].splitlines()
+        return content_list
+
+    def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
+        cpu_provider_opts = {
             "arena_extend_strategy": "kSameAsRequested",
         }
-        EP_list = [(cpu_ep, cpu_provider_options)]
+        EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]

-        self._verify_model(onnx_path)
-        self.session = InferenceSession(
-            onnx_path, sess_options=sess_opt, providers=EP_list
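+        # ONNX Runtime tries providers in list order, so GPU providers are inserted ahead of the CPU fallback.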
+        cuda_provider_opts = {
+            "device_id": 0,
+            "arena_extend_strategy": "kNextPowerOfTwo",
+            "cudnn_conv_algo_search": "EXHAUSTIVE",
+            "do_copy_in_default_stream": True,
+        }
+        self.use_cuda = self._check_cuda()
+        if self.use_cuda:
+            EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts))
+
+        self.use_directml = self._check_dml()
+        if self.use_directml:
+            self.logger.info(
+                "Windows 10 or above detected, trying to use DirectML as the primary provider"
+            )
+            directml_options = (
+                cuda_provider_opts if self.use_cuda else cpu_provider_opts
+            )
+            EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
+        return EP_list
+
+    def _check_cuda(self) -> bool:
+        if not self.cfg_use_cuda:
+            return False
+
+        cur_device = get_device()
+        if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers:
+            return True
+
+        self.logger.warning(
+            "%s is not in available providers (%s). Use %s inference by default.",
+            EP.CUDA_EP.value,
+            self.had_providers,
+            self.had_providers[0],
         )
-
-    def __call__(self, input_content: np.ndarray) -> np.ndarray:
-        input_dict = dict(zip(self.get_input_names(), [input_content]))
+        self.logger.info("Note: rapidocr_paddle is recommended for inference on GPU.")
+        self.logger.info(
+            "(For reference only) If you want to use GPU acceleration, do the following:"
+        )
+        self.logger.info(
+            "First, uninstall all onnxruntime packages in the current environment."
+        )
+        self.logger.info(
+            "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`."
+        )
+        self.logger.info(
+            "\tNote that the onnxruntime-gpu version must match your CUDA and cuDNN versions."
+        )
+        self.logger.info(
+            "\tYou can refer to this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html"
+        )
+        self.logger.info(
+            "Third, ensure %s is in the available providers list, e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider'].",
+            EP.CUDA_EP.value,
+        )
+        return False
+
+    def _check_dml(self) -> bool:
+        if not self.cfg_use_dml:
+            return False
+
+        cur_os = platform.system()
+        if cur_os != "Windows":
+            self.logger.warning(
+                "DirectML is only supported on Windows. The current OS is %s. Use %s inference by default.",
+                cur_os,
+                self.had_providers[0],
+            )
+            return False
+
+        cur_windows_version = int(platform.release().split(".")[0])
+        if cur_windows_version < 10:
+            self.logger.warning(
+                "DirectML is only supported on Windows 10 and above. The current Windows version is %s. Use %s inference by default.",
+                cur_windows_version,
+                self.had_providers[0],
+            )
+            return False
+
+        if EP.DIRECTML_EP.value in self.had_providers:
+            return True
+
+        self.logger.warning(
+            "%s is not in available providers (%s). Use %s inference by default.",
+            EP.DIRECTML_EP.value,
+            self.had_providers,
+            self.had_providers[0],
+        )
+        self.logger.info("If you want to use DirectML acceleration, do the following:")
+        self.logger.info(
+            "First, uninstall all onnxruntime packages in the current environment."
+        )
+        self.logger.info(
+            "Second, install onnxruntime-directml by `pip install onnxruntime-directml`."
+        )
+        self.logger.info(
+            "Third, ensure %s is in the available providers list, e.g. ['DmlExecutionProvider', 'CPUExecutionProvider'].",
+            EP.DIRECTML_EP.value,
+        )
+        return False
+
+    def _verify_providers(self):
+        session_providers = self.session.get_providers()
+        first_provider = session_providers[0]
+
+        if self.use_cuda and first_provider != EP.CUDA_EP.value:
+            self.logger.warning(
+                "%s is not available in the current environment; inference falls back to %s.",
+                EP.CUDA_EP.value,
+                first_provider,
+            )
+
+        if self.use_directml and first_provider != EP.DIRECTML_EP.value:
+            self.logger.warning(
+                "%s is not available in the current environment; inference falls back to %s.",
+                EP.DIRECTML_EP.value,
+                first_provider,
+            )
+
+    def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
+        input_dict = dict(zip(self.get_input_names(), input_content))
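+        # output_names=None asks ONNX Runtime to return every model output in order.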
         try:
-            return self.session.run(self.get_output_names(), input_dict)
+            return self.session.run(None, input_dict)
         except Exception as e:
-            raise ONNXRuntimeError("ONNXRuntime inferece failed.") from e
+            error_info = traceback.format_exc()
+            raise ONNXRuntimeError(error_info) from e
-    def get_input_names(
-        self,
-    ):
+    def get_input_names(self) -> List[str]:
         return [v.name for v in self.session.get_inputs()]

-    def get_output_names(
-        self,
-    ):
+    def get_output_names(self) -> List[str]:
         return [v.name for v in self.session.get_outputs()]

-    def get_metadata(self, key: str = "character") -> list:
+    def get_character_list(self, key: str = "character") -> List[str]:
         meta_dict = self.session.get_modelmeta().custom_metadata_map
-        content_list = meta_dict[key].splitlines()
-        return content_list
+        return meta_dict[key].splitlines()
+
+    def have_key(self, key: str = "character") -> bool:
+        meta_dict = self.session.get_modelmeta().custom_metadata_map
+        return key in meta_dict
     @staticmethod
-    def _verify_model(model_path):
+    def _verify_model(model_path: Union[str, Path, None]):
+        if model_path is None:
+            raise ValueError("model_path is None!")
+
         model_path = Path(model_path)
         if not model_path.exists():
             raise FileNotFoundError(f"{model_path} does not exist.")
+
         if not model_path.is_file():
             raise FileExistsError(f"{model_path} is not a file.")
@@ -355,3 +528,16 @@ def __call__(self, data):
         for key in self.keep_keys:
             data_list.append(data[key])
         return data_list
+
+
+def trans_char_ocr_res(ocr_res):
+    """Flatten line-level OCR results that carry word boxes into per-word entries.
+
+    Each input item follows the rapidocr_onnxruntime `return_word_box=True`
+    layout: [line_box, text, score, word_boxes, words]. The output is a list of
+    [word_box, word, score] entries, matching the `ocr_result` format RapidTable expects.
+    """
+    word_result = []
+    for res in ocr_res:
+        score = res[2]
+        for word_box, word in zip(res[3], res[4]):
+            word_result.append([word_box, word, score])
+    return word_result