Skip to content

Commit

Permalink
feat: add unit testing of cli
Browse files Browse the repository at this point in the history
  • Loading branch information
SWHL committed Jan 10, 2025
1 parent f327707 commit 14d7ce7
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
15 changes: 8 additions & 7 deletions rapid_table/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def get_model_path(
raise ValueError(f"Model URL: {type(model_url)} is not between str and dict.")


def main():
def parse_args(arg_list: Optional[List[str]] = None):
parser = argparse.ArgumentParser()
parser.add_argument(
"-v",
Expand All @@ -189,7 +189,12 @@ def main():
default=ModelType.SLANETPLUS.value,
choices=list(KEY_TO_MODEL_URL),
)
args = parser.parse_args()
args = parser.parse_args(arg_list)
return args


def main(arg_list: Optional[List[str]] = None):
args = parse_args(arg_list)

try:
ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR()
Expand All @@ -205,11 +210,7 @@ def main():

ocr_result, _ = ocr_engine(img)
table_results = table_engine(img, ocr_result)
table_html_str, table_cell_bboxes = (
table_results.pred_html,
table_results.cell_bboxes,
)
print(table_html_str)
print(table_results.pred_html)

viser = VisTable()
if args.vis:
Expand Down
12 changes: 12 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: [email protected]
import shlex
import sys
from pathlib import Path

Expand All @@ -13,6 +14,7 @@
sys.path.append(str(root_dir))

from rapid_table import RapidTable, RapidTableInput
from rapid_table.main import main

ocr_engine = RapidOCR()

Expand All @@ -23,6 +25,16 @@
img_path = str(test_file_dir / "table.jpg")


@pytest.mark.parametrize(
"command, expected_output",
[(f"--img_path {img_path} --model_type slanet_plus", 1274)],
)
def test_main(capsys, command, expected_output):
main(shlex.split(command))
output = capsys.readouterr().out.rstrip()
assert len(output) == expected_output


@pytest.mark.parametrize("model_type", ["slanet_plus", "unitable"])
def test_ocr_input(model_type):
ocr_res, _ = ocr_engine(img_path)
Expand Down

0 comments on commit 14d7ce7

Please sign in to comment.