Skip to content

Commit

Permalink
update rapidocr_api (#242)
Browse files Browse the repository at this point in the history
Co-authored-by: 何熙 <[email protected]>
  • Loading branch information
xmxoxo and 何熙 authored Oct 28, 2024
1 parent 4ab68d6 commit 3aa4463
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 21 deletions.
18 changes: 18 additions & 0 deletions api/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
FROM python:3.10.11-slim-buster

ENV DEBIAN_FRONTEND=noninteractive

# Set the working directory
WORKDIR /app

# Install rapidocr_api and replace opencv-python with opencv-python-headless
# in a single layer: files removed by a later RUN would otherwise still be
# stored in the earlier layer and keep the image larger than necessary.
RUN pip install --no-cache-dir rapidocr_api -i https://mirrors.aliyun.com/pypi/simple && \
    pip uninstall -y opencv-python && \
    pip install --no-cache-dir opencv-python-headless -i https://mirrors.aliyun.com/pypi/simple

EXPOSE 9003

# Available after the upgrade:
# CMD ["bash", "-c", "rapidocr_api -ip 0.0.0.0 -p 9003 -workers 2"]
CMD ["rapidocr_api"]

53 changes: 53 additions & 0 deletions api/README.md
Original file line number Diff line number Diff line change
@@ -1 +1,54 @@
### See [Documentation](https://rapidai.github.io/RapidOCRDocs/install_usage/rapidocr_api/usage/)


### API 修改说明

* uvicorn启动时,reload参数设置为False,避免反复加载;
* 增加了启动参数: workers,可启动多个实例,以满足多并发需求。
* 可通过环境变量传递模型参数:det_model_path, cls_model_path, rec_model_path;
* 接口中可传入参数,控制是否使用检测、方向分类和识别这三部分的模型;客户端调用见`demo.py`
* 增加了Dockerfile,可自行构建镜像。

启动服务端:

Windows下启动:

```for win shell
set det_model_path=I:\models\图像相关\OCR\RapidOCR\PP-OCRv4\ch_PP-OCRv4_det_server_infer.onnx
set cls_model_path=
set rec_model_path=I:\models\图像相关\OCR\RapidOCR\PP-OCRv4\ch_PP-OCRv4_rec_server_infer.onnx
rapidocr_api
```

Linux下启动:
```shell
# 默认参数启动
rapidocr_api

# 指定参数:端口与进程数量;
rapidocr_api -ip 0.0.0.0 -p 9005 -workers 2

# 指定模型
export det_model_path=/mnt/sda1/models/PP-OCRv4/ch_PP-OCRv4_det_server_infer.onnx
export rec_model_path=/mnt/sda1/models/PP-OCRv4/ch_PP-OCRv4_rec_server_infer.onnx
rapidocr_api -ip 0.0.0.0 -p 9005 -workers 2
```


客户端调用说明:
```
cd api
python demo.py
```

构建镜像:
```
cd api
sudo docker build -t="rapidocr_api:0.1.1" .
```
启动镜像:

```
docker run -p 9003:9003 --name rapidocr_api1 --restart always -d rapidocr_api:0.1.1
```
36 changes: 34 additions & 2 deletions api/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,49 @@

# print(response.text)

import time
import base64

import requests

url = "http://localhost:9003/ocr"
img_path = "../python/tests/test_files/ch_en_num.jpg"

# Upper bound (in seconds) for every request, so the demo fails fast instead of
# blocking forever when the server is unreachable or stalls.
request_timeout = 60

# Method 1: send the image as a base64-encoded form field
stime = time.time()
with open(img_path, "rb") as fa:
    img_str = base64.b64encode(fa.read())

payload = {"image_data": img_str}
response = requests.post(url, data=payload, timeout=request_timeout)

print(response.json())
etime = time.time() - stime
print(f'用时:{etime:.3f}秒')

print('-'*40)

# Method 2: upload the image as a multipart file
stime = time.time()
with open(img_path, 'rb') as f:
    # The test images are JPEG files, so label the upload accordingly.
    file_dict = {'image_file': (img_path, f, 'image/jpeg')}
    # The request must be sent while the file handle is still open.
    response = requests.post(url, files=file_dict, timeout=request_timeout)
print(response.json())

etime = time.time() - stime
print(f'用时:{etime:.3f}秒')
print('-'*40)

# Method 3: control whether the detection, orientation-classification and
# recognition models are used; use_det=False skips the detection model.
stime = time.time()
img_path = "../python/tests/test_files/test_without_det.jpg"

with open(img_path, 'rb') as f:
    file_dict = {'image_file': (img_path, f, 'image/jpeg')}
    # Extra form fields carrying the control parameters
    data = {"use_det":False, "use_cls":True, "use_rec":True}
    response = requests.post(url, files=file_dict, data=data, timeout=request_timeout)
print(response.json())

etime = time.time() - stime
print(f'用时:{etime:.3f}秒')
print('-'*40)
48 changes: 29 additions & 19 deletions api/rapidocr_api/main.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]

import argparse
import base64
import io
import os
import sys
from pathlib import Path
from typing import Dict
Expand All @@ -19,37 +21,47 @@

class OCRAPIUtils:
    """Wrap a RapidOCR engine and convert its results into a plain dict."""

    def __init__(self) -> None:
        # Optional model paths are read from environment variables; an unset
        # variable yields None, which is forwarded to RapidOCR unchanged.
        det_model_path = os.getenv("det_model_path", None)
        cls_model_path = os.getenv("cls_model_path", None)
        rec_model_path = os.getenv("rec_model_path", None)

        self.ocr = RapidOCR(
            det_model_path=det_model_path,
            cls_model_path=cls_model_path,
            rec_model_path=rec_model_path,
        )

    def __call__(self, img: "Image.Image", use_det=None, use_cls=None, use_rec=None) -> Dict:
        """Run OCR on ``img`` and return the results keyed by their index.

        Args:
            img: Image to process; it is converted with ``np.array`` first.
            use_det: Forwarded to the engine as ``use_det``.
            use_cls: Forwarded to the engine as ``use_cls``.
            use_rec: Forwarded to the engine as ``use_rec``.

        Returns:
            ``{}`` when the engine returns a falsy result. Otherwise a dict
            mapping ``"0"``, ``"1"``, ... to a dict that contains whichever of
            ``"rec_txt"``, ``"score"`` and ``"dt_boxes"`` are present in the
            corresponding result entry.
        """
        img = np.array(img)
        ocr_res, _ = self.ocr(img, use_det=use_det, use_cls=use_cls, use_rec=use_rec)

        if not ocr_res:
            return {}

        # Entries differ in length and content depending on which stages ran,
        # so each field is identified by its type rather than its position.
        out_dict = {}
        for i, dats in enumerate(ocr_res):
            values = {}
            for dat in dats:
                if isinstance(dat, str):
                    values["rec_txt"] = dat
                elif isinstance(dat, (float, np.floating)):
                    # Covers Python floats and every NumPy float width, not
                    # only np.float64, so scores are never silently dropped.
                    values["score"] = f"{dat:.4f}"
                elif isinstance(dat, np.ndarray):
                    # Arrays are turned into nested lists so that the response
                    # stays JSON-serialisable.
                    values["dt_boxes"] = dat.tolist()
                elif isinstance(dat, list):
                    values["dt_boxes"] = dat
            out_dict[str(i)] = values

        return out_dict


app = FastAPI()
processor = OCRAPIUtils()


@app.get("/")
async def root():
    """Return a static welcome message for GET requests to "/"."""
    return {"message": "Welcome to RapidOCR API Server!"}


@app.post("/ocr")
async def ocr(image_file: UploadFile = None, image_data: str = Form(None)):
async def ocr(image_file: UploadFile = None, image_data: str = Form(None),
use_det: bool = Form(None), use_cls: bool = Form(None), use_rec: bool = Form(None)):

if image_file:
img = Image.open(image_file.file)
elif image_data:
Expand All @@ -60,19 +72,17 @@ async def ocr(image_file: UploadFile = None, image_data: str = Form(None)):
raise ValueError(
"When sending a post request, data or files must have a value."
)

ocr_res = processor(img)
ocr_res = processor(img, use_det=use_det, use_cls=use_cls, use_rec=use_rec)
return ocr_res


def main():
    """Parse the command-line options and start the API server with uvicorn.

    Options:
        -ip / --ip: address to bind to (default ``0.0.0.0``).
        -p / --port: port to listen on (default ``9003``).
        -workers / --workers: number of worker processes (default ``1``).
    """
    parser = argparse.ArgumentParser("rapidocr_api")
    parser.add_argument("-ip", "--ip", type=str, default="0.0.0.0", help="IP Address")
    parser.add_argument("-p", "--port", type=int, default=9003, help="IP port")
    parser.add_argument('-workers', "--workers", type=int, default=1, help='number of worker process')
    args = parser.parse_args()

    # Auto-reload stays off so the application is not loaded repeatedly; it is
    # passed as a real boolean rather than the integer 0.
    uvicorn.run(
        "rapidocr_api.main:app",
        host=args.ip,
        port=args.port,
        reload=False,
        workers=args.workers,
    )

if __name__ == "__main__":
main()

0 comments on commit 3aa4463

Please sign in to comment.