diff --git a/README.md b/README.md index 27645a1..34a870a 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,47 @@ slanet_plus是paddlex内置的SLANet升级版模型,准确率有大幅提升 Demo -### 与[TableStructureRec](https://github.com/RapidAI/TableStructureRec) 关系 +### 更新日志 + +
+ +#### 2024.11.24 update +- 支持gpu推理,适配 rapidOCR 单字识别匹配 + +#### 2024.10.13 update +- 补充最新paddlex-SLANet-plus 模型(paddle2onnx原因暂不能支持onnx) + +#### 2023-12-29 v0.1.3 update + +- 优化可视化结果部分 + +#### 2023-12-27 v0.1.2 update + +- 添加返回cell坐标框参数 +- 完善可视化函数 + +#### 2023-07-17 v0.1.0 update + +- 将`rapidocr_onnxruntime`部分从`rapid_table`中解耦合出来,给出选项是否依赖,更加灵活。 + +- 增加接口输入参数`ocr_result`: + - 如果在调用函数时,事先指定了`ocr_result`参数值,则不会再走OCR。其中`ocr_result`格式需要和`rapidocr_onnxruntime`返回值一致。 + - 如果未指定`ocr_result`参数值,但是事先安装了`rapidocr_onnxruntime`库,则会自动调用该库,进行识别。 + - 如果`ocr_result`未指定,且`rapidocr_onnxruntime`未安装,则会报错。必须满足两个条件中一个。 + +#### 2023-07-10 v0.0.13 updata + +- 更改传入表格还原中OCR的实例接口,可以传入其他OCR实例,前提要与`rapidocr_onnxruntime`接口一致 + +#### 2023-07-06 v0.0.12 update + +- 去掉返回表格的html字符串中的``元素,便于后续统一。 +- 采用Black工具优化代码 + +
+ + +### 与[TableStructureRec](https://github.com/RapidAI/TableStructureRec)关系 TableStructureRec库是一个表格识别算法的集合库,当前有`wired_table_rec`有线表格识别算法和`lineless_table_rec`无线表格识别算法的推理包。 @@ -58,6 +98,7 @@ RapidTable是整理自PP-Structure中表格识别部分而来。由于PP-Structu ```bash pip install rapidocr_onnxruntime pip install rapid_table +#pip install onnxruntime-gpu # for gpu inference ``` ### 使用方式 @@ -76,14 +117,22 @@ table_engine = RapidTable() from pathlib import Path from rapid_table import RapidTable, VisTable +from rapidocr_onnxruntime import RapidOCR +from rapid_table.table_structure.utils import trans_char_ocr_res + table_engine = RapidTable() +# 开启onnx-gpu推理 +# table_engine = RapidTable(use_cuda=True) ocr_engine = RapidOCR() viser = VisTable() img_path = 'test_images/table.jpg' ocr_result, _ = ocr_engine(img_path) +# 单字匹配 +# ocr_result, _ = ocr_engine(img_path, return_word_box=True) +# ocr_result = trans_char_ocr_res(ocr_result) table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_result) save_dir = Path("./inference_results/") @@ -134,39 +183,3 @@ print(table_html_str) <>
MethodsFPS
SegLink [26]70.086d>77.08.9
PixelLink [4]73.283.077.8
TextSnake [18]73.983.278.31.1
TextField [37]75.987.481.35.2
MSR[38]76.787.87.481.7
FTSN [3]77.187.682.0
LSE[30]81.784.282.9
CRAFT [2]78.288.282.98.6
MCN[16]798883
ATRR[35]82.185.283.6
PAN [34]83.884.484.130.2
DB[12]79.291.584.932.0
DRRG[41]82.3088.0585.08
Ours (SynText)80.688582.9712.68
Ours (MLT-17)84.5486.6285.5712.31
- -### 更新日志 - -
- -#### 2023-12-29 v0.1.3 update - -- 优化可视化结果部分 - -#### 2023-12-27 v0.1.2 update - -- 添加返回cell坐标框参数 -- 完善可视化函数 - -#### 2023-07-17 v0.1.0 update - -- 将`rapidocr_onnxruntime`部分从`rapid_table`中解耦合出来,给出选项是否依赖,更加灵活。 - -- 增加接口输入参数`ocr_result`: - - 如果在调用函数时,事先指定了`ocr_result`参数值,则不会再走OCR。其中`ocr_result`格式需要和`rapidocr_onnxruntime`返回值一致。 - - 如果未指定`ocr_result`参数值,但是事先安装了`rapidocr_onnxruntime`库,则会自动调用该库,进行识别。 - - 如果`ocr_result`未指定,且`rapidocr_onnxruntime`未安装,则会报错。必须满足两个条件中一个。 - -#### 2023-07-10 v0.0.13 updata - -- 更改传入表格还原中OCR的实例接口,可以传入其他OCR实例,前提要与`rapidocr_onnxruntime`接口一致 - -#### 2023-07-06 v0.0.12 update - -- 去掉返回表格的html字符串中的``元素,便于后续统一。 -- 采用Black工具优化代码 - -#### 2024.10.13 update -- 补充最新paddlex-SLANet-plus 模型(paddle2onnx原因暂不能支持onnx) - -
diff --git a/rapid_table/main.py b/rapid_table/main.py index 02d6516..b2a52c0 100644 --- a/rapid_table/main.py +++ b/rapid_table/main.py @@ -19,7 +19,7 @@ class RapidTable: - def __init__(self, model_path: Optional[str] = None, model_type: str = None): + def __init__(self, model_path: Optional[str] = None, model_type: str = None, use_cuda: bool = False): if model_path is None: model_path = str( root_dir / "models" / "slanet-plus.onnx" @@ -27,7 +27,11 @@ def __init__(self, model_path: Optional[str] = None, model_type: str = None): model_type = "slanet-plus" self.model_type = model_type self.load_img = LoadImage() - self.table_structure = TableStructurer(model_path) + config = { + "model_path": model_path, + "use_cuda": use_cuda, + } + self.table_structure = TableStructurer(config) self.table_matcher = TableMatch() try: diff --git a/rapid_table/table_structure/logger.py b/rapid_table/table_structure/logger.py new file mode 100644 index 0000000..2950987 --- /dev/null +++ b/rapid_table/table_structure/logger.py @@ -0,0 +1,21 @@ +# -*- encoding: utf-8 -*- +# @Author: Jocker1212 +# @Contact: xinyijianggo@gmail.com +import logging +from functools import lru_cache + + +@lru_cache(maxsize=32) +def get_logger(name: str) -> logging.Logger: + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + + fmt = "%(asctime)s - %(name)s - %(levelname)s: %(message)s" + format_str = logging.Formatter(fmt) + + sh = logging.StreamHandler() + sh.setLevel(logging.DEBUG) + + logger.addHandler(sh) + sh.setFormatter(format_str) + return logger diff --git a/rapid_table/table_structure/table_structure.py b/rapid_table/table_structure/table_structure.py index 884d004..c2509d8 100644 --- a/rapid_table/table_structure/table_structure.py +++ b/rapid_table/table_structure/table_structure.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import time +from typing import Dict, Any import numpy as np @@ -19,10 +20,10 @@ class TableStructurer: - def __init__(self, model_path: str): + def __init__(self, config: Dict[str, Any]): self.preprocess_op = TablePreprocess() - self.session = OrtInferSession(model_path) + self.session = OrtInferSession(config) self.character = self.session.get_metadata() self.postprocess_op = TableLabelDecode(self.character) @@ -37,7 +38,7 @@ def __call__(self, img): img = np.expand_dims(img, axis=0) img = img.copy() - outputs = self.session(img) + outputs = self.session([img]) preds = {"loc_preds": outputs[0], "structure_probs": outputs[1]} diff --git a/rapid_table/table_structure/utils.py b/rapid_table/table_structure/utils.py index 5c12ab1..c88cf3c 100644 --- a/rapid_table/table_structure/utils.py +++ b/rapid_table/table_structure/utils.py @@ -14,60 +14,233 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com -from pathlib import Path +import os +import platform +import traceback import cv2 +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Tuple, Union + import numpy as np -from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions +from onnxruntime import ( + GraphOptimizationLevel, + InferenceSession, + SessionOptions, + get_available_providers, + get_device, +) + +from rapid_table.table_structure.logger import get_logger + + +class EP(Enum): + CPU_EP = "CPUExecutionProvider" + CUDA_EP = "CUDAExecutionProvider" + DIRECTML_EP = "DmlExecutionProvider" class OrtInferSession: - def __init__(self, onnx_path: str): - self._verify_model(onnx_path) + def __init__(self, config: Dict[str, Any]): + self.logger = get_logger("OrtInferSession") + + model_path = config.get("model_path", None) + self._verify_model(model_path) + + self.cfg_use_cuda = config.get("use_cuda", None) + self.cfg_use_dml = config.get("use_dml", None) + + self.had_providers: List[str] = get_available_providers() + EP_list = self._get_ep_list() + + sess_opt = self._init_sess_opts(config) + self.session = InferenceSession( + model_path, + sess_options=sess_opt, + providers=EP_list, + ) + self._verify_providers() + @staticmethod + def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions: sess_opt = SessionOptions() sess_opt.log_severity_level = 4 sess_opt.enable_cpu_mem_arena = False sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL - cpu_ep = "CPUExecutionProvider" - cpu_provider_options = { + cpu_nums = os.cpu_count() + intra_op_num_threads = config.get("intra_op_num_threads", -1) + if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums: + sess_opt.intra_op_num_threads = intra_op_num_threads + + inter_op_num_threads = config.get("inter_op_num_threads", -1) + if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums: + sess_opt.inter_op_num_threads = inter_op_num_threads + + return sess_opt + def get_metadata(self, key: str = "character") -> list: + meta_dict = self.session.get_modelmeta().custom_metadata_map + content_list = meta_dict[key].splitlines() + return content_list + def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]: + cpu_provider_opts = { "arena_extend_strategy": "kSameAsRequested", } - EP_list = [(cpu_ep, cpu_provider_options)] + EP_list = [(EP.CPU_EP.value, cpu_provider_opts)] - self._verify_model(onnx_path) - self.session = InferenceSession( - onnx_path, sess_options=sess_opt, providers=EP_list + cuda_provider_opts = { + "device_id": 0, + "arena_extend_strategy": "kNextPowerOfTwo", + "cudnn_conv_algo_search": "EXHAUSTIVE", + "do_copy_in_default_stream": True, + } + self.use_cuda = self._check_cuda() + if self.use_cuda: + EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts)) + + self.use_directml = self._check_dml() + if self.use_directml: + self.logger.info( + "Windows 10 or above detected, try to use DirectML as primary provider" + ) + directml_options = ( + cuda_provider_opts if self.use_cuda else cpu_provider_opts + ) + EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options)) + return EP_list + + def _check_cuda(self) -> bool: + if not self.cfg_use_cuda: + return False + + cur_device = get_device() + if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers: + return True + + self.logger.warning( + "%s is not in available providers (%s). Use %s inference by default.", + EP.CUDA_EP.value, + self.had_providers, + self.had_providers[0], ) - - def __call__(self, input_content: np.ndarray) -> np.ndarray: - input_dict = dict(zip(self.get_input_names(), [input_content])) + self.logger.info("!!!Recommend to use rapidocr_paddle for inference on GPU.") + self.logger.info( + "(For reference only) If you want to use GPU acceleration, you must do:" + ) + self.logger.info( + "First, uninstall all onnxruntime pakcages in current environment." + ) + self.logger.info( + "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`." + ) + self.logger.info( + "\tNote the onnxruntime-gpu version must match your cuda and cudnn version." + ) + self.logger.info( + "\tYou can refer this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html" + ) + self.logger.info( + "Third, ensure %s is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']", + EP.CUDA_EP.value, + ) + return False + + def _check_dml(self) -> bool: + if not self.cfg_use_dml: + return False + + cur_os = platform.system() + if cur_os != "Windows": + self.logger.warning( + "DirectML is only supported in Windows OS. The current OS is %s. Use %s inference by default.", + cur_os, + self.had_providers[0], + ) + return False + + cur_window_version = int(platform.release().split(".")[0]) + if cur_window_version < 10: + self.logger.warning( + "DirectML is only supported in Windows 10 and above OS. The current Windows version is %s. Use %s inference by default.", + cur_window_version, + self.had_providers[0], + ) + return False + + if EP.DIRECTML_EP.value in self.had_providers: + return True + + self.logger.warning( + "%s is not in available providers (%s). Use %s inference by default.", + EP.DIRECTML_EP.value, + self.had_providers, + self.had_providers[0], + ) + self.logger.info("If you want to use DirectML acceleration, you must do:") + self.logger.info( + "First, uninstall all onnxruntime pakcages in current environment." + ) + self.logger.info( + "Second, install onnxruntime-directml by `pip install onnxruntime-directml`" + ) + self.logger.info( + "Third, ensure %s is in available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']", + EP.DIRECTML_EP.value, + ) + return False + + def _verify_providers(self): + session_providers = self.session.get_providers() + first_provider = session_providers[0] + + if self.use_cuda and first_provider != EP.CUDA_EP.value: + self.logger.warning( + "%s is not avaiable for current env, the inference part is automatically shifted to be executed under %s.", + EP.CUDA_EP.value, + first_provider, + ) + + if self.use_directml and first_provider != EP.DIRECTML_EP.value: + self.logger.warning( + "%s is not available for current env, the inference part is automatically shifted to be executed under %s.", + EP.DIRECTML_EP.value, + first_provider, + ) + + def __call__(self, input_content: List[np.ndarray]) -> np.ndarray: + input_dict = dict(zip(self.get_input_names(), input_content)) try: - return self.session.run(self.get_output_names(), input_dict) + return self.session.run(None, input_dict) except Exception as e: - raise ONNXRuntimeError("ONNXRuntime inferece failed.") from e + error_info = traceback.format_exc() + raise ONNXRuntimeError(error_info) from e - def get_input_names( - self, - ): + def get_input_names(self) -> List[str]: return [v.name for v in self.session.get_inputs()] - def get_output_names( - self, - ): + def get_output_names(self) -> List[str]: return [v.name for v in self.session.get_outputs()] - def get_metadata(self, key: str = "character") -> list: + def get_character_list(self, key: str = "character") -> List[str]: meta_dict = self.session.get_modelmeta().custom_metadata_map - content_list = meta_dict[key].splitlines() - return content_list + return meta_dict[key].splitlines() + + def have_key(self, key: str = "character") -> bool: + meta_dict = self.session.get_modelmeta().custom_metadata_map + if key in meta_dict.keys(): + return True + return False @staticmethod - def _verify_model(model_path): + def _verify_model(model_path: Union[str, Path, None]): + if model_path is None: + raise ValueError("model_path is None!") + model_path = Path(model_path) if not model_path.exists(): raise FileNotFoundError(f"{model_path} does not exists.") + if not model_path.is_file(): raise FileExistsError(f"{model_path} is not a file.") @@ -355,3 +528,16 @@ def __call__(self, data): for key in self.keep_keys: data_list.append(data[key]) return data_list + + +def trans_char_ocr_res(ocr_res): + word_result = [] + for res in ocr_res: + score = res[2] + for word_box, word in zip(res[3], res[4]): + word_res = [] + word_res.append(word_box) + word_res.append(word) + word_res.append(score) + word_result.append(word_res) + return word_result