import encryption for aistudio & fix sync bn
zhangyubo0722 authored Jan 3, 2025
1 parent 4f7476d commit e314510
Showing 2 changed files with 57 additions and 10 deletions.
19 changes: 19 additions & 0 deletions ppocr/utils/export_model.py
@@ -331,6 +331,12 @@ def export_single_model(
        model = dynamic_to_static(model, arch_config, logger, input_shape)

    if quanter is None:
        try:
            import encryption  # Attempt to import the encryption module for AIStudio's encryption model
        except (
            ModuleNotFoundError
        ):  # Encryption is not needed if the module cannot be imported
            print("Skipping import of the encryption module")
        if config["Global"].get("export_with_pir", False):
            paddle_version = version.parse(paddle.__version__)
            assert (
@@ -349,6 +355,18 @@ def export_single_model(
    return


def convert_bn(model):
    # Recursively replace SyncBatchNorm layers, which require a distributed
    # runtime, with equivalent BatchNorm2D layers that reuse the same
    # parameters and running statistics, so the model can be exported.
    for n, m in model.named_children():
        if isinstance(m, nn.SyncBatchNorm):
            bn = nn.BatchNorm2D(
                m._num_features, m._momentum, m._epsilon, m._weight_attr, m._bias_attr
            )
            bn.set_dict(m.state_dict())
            setattr(model, n, bn)
        else:
            convert_bn(m)


def export(config, base_model=None, save_path=None):
    if paddle.distributed.get_rank() != 0:
        return
@@ -424,6 +442,7 @@ def export(config, base_model=None, save_path=None):
    else:
        model = build_model(config["Architecture"])
        load_model(config, model, model_type=config["Architecture"]["model_type"])
        convert_bn(model)
    model.eval()

    if not save_path:
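The two changes to export_model.py are independent: the optional import lets AIStudio's encryption hook register itself when the encryption module is installed, and convert_bn removes SyncBatchNorm layers, which only work under a distributed runtime, before the model is traced and saved. A minimal sketch of what convert_bn does on a toy network; the TinyNet layer below is illustrative and not part of the commit, and the snippet assumes Paddle and this PaddleOCR revision are installed:

from paddle import nn

from ppocr.utils.export_model import convert_bn


class TinyNet(nn.Layer):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2D(3, 8, 3)
        self.bn = nn.SyncBatchNorm(8)  # usable only under paddle.distributed

    def forward(self, x):
        return self.bn(self.conv(x))


net = TinyNet()
convert_bn(net)  # swaps SyncBatchNorm for BatchNorm2D in place, keeping weights and statistics
assert isinstance(net.bn, nn.BatchNorm2D)

After the swap the exported inference model no longer depends on a multi-GPU environment, which is what the "fix sync bn" part of the commit title refers to.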
48 changes: 38 additions & 10 deletions ppocr/utils/save_load.py
@@ -26,6 +26,14 @@
from ppocr.utils.logging import get_logger
from ppocr.utils.network import maybe_download_params

try:
    import encryption  # Attempt to import the encryption module for AIStudio's encryption model

    encrypted = encryption.is_encryption_needed()
except ImportError:
    get_logger().warning("Skipping import of the encryption module.")
    encrypted = False  # Encryption is not needed if the module cannot be imported

__all__ = ["load_model"]


@@ -278,13 +286,11 @@ def update_train_results(config, prefix, metric_info, done_flag=False, last_num=
    else:
        train_results = {}
        train_results["model_name"] = config["Global"]["pdx_model_name"]
        label_dict_path = os.path.abspath(
            config["Global"].get("character_dict_path", "")
        )
        label_dict_path = config["Global"].get("character_dict_path", "")
        if label_dict_path != "":
            label_dict_path = os.path.abspath(label_dict_path)
            if not os.path.exists(label_dict_path):
                label_dict_path = ""
        label_dict_path = label_dict_path
        train_results["label_dict"] = label_dict_path
        train_results["train_log"] = "train.log"
        train_results["visualdl_log"] = ""
@@ -305,9 +311,20 @@ def update_train_results(config, prefix, metric_info, done_flag=False, last_num=
            raise ValueError("No metric score found.")
        train_results["models"]["best"]["score"] = metric_score
        for tag in save_model_tag:
            train_results["models"]["best"][tag] = os.path.join(
                prefix, f"{prefix}.{tag}" if tag != "pdstates" else f"{prefix}.states"
            )
            if tag == "pdparams" and encrypted:
                train_results["models"]["best"][tag] = os.path.join(
                    prefix,
                    (
                        f"{prefix}.encrypted.{tag}"
                        if tag != "pdstates"
                        else f"{prefix}.states"
                    ),
                )
            else:
                train_results["models"]["best"][tag] = os.path.join(
                    prefix,
                    f"{prefix}.{tag}" if tag != "pdstates" else f"{prefix}.states",
                )
        for tag in save_inference_tag:
            train_results["models"]["best"][tag] = os.path.join(
                prefix,
@@ -329,9 +346,20 @@ def update_train_results(config, prefix, metric_info, done_flag=False, last_num=
            metric_score = 0
        train_results["models"][f"last_{1}"]["score"] = metric_score
        for tag in save_model_tag:
            train_results["models"][f"last_{1}"][tag] = os.path.join(
                prefix, f"{prefix}.{tag}" if tag != "pdstates" else f"{prefix}.states"
            )
            if tag == "pdparams" and encrypted:
                train_results["models"][f"last_{1}"][tag] = os.path.join(
                    prefix,
                    (
                        f"{prefix}.encrypted.{tag}"
                        if tag != "pdstates"
                        else f"{prefix}.states"
                    ),
                )
            else:
                train_results["models"][f"last_{1}"][tag] = os.path.join(
                    prefix,
                    f"{prefix}.{tag}" if tag != "pdstates" else f"{prefix}.states",
                )
        for tag in save_inference_tag:
            train_results["models"][f"last_{1}"][tag] = os.path.join(
                prefix,
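The save_load.py change makes the recorded weights path track the encryption state: when the optional encryption module is importable and is_encryption_needed() returns True, the pdparams entry for both the best and the latest model points at "<prefix>.encrypted.pdparams" instead of the plain "<prefix>.pdparams". A small self-contained sketch of that path selection; the model_file helper and the "best_accuracy" prefix are illustrative, not part of the commit:

import os


def model_file(prefix, tag, encrypted):
    # Mirrors the branch added in update_train_results: only the pdparams
    # entry switches to the ".encrypted" name when encryption is active.
    if tag == "pdparams" and encrypted:
        name = f"{prefix}.encrypted.{tag}" if tag != "pdstates" else f"{prefix}.states"
    else:
        name = f"{prefix}.{tag}" if tag != "pdstates" else f"{prefix}.states"
    return os.path.join(prefix, name)


print(model_file("best_accuracy", "pdparams", encrypted=True))   # best_accuracy/best_accuracy.encrypted.pdparams
print(model_file("best_accuracy", "pdparams", encrypted=False))  # best_accuracy/best_accuracy.pdparams
print(model_file("best_accuracy", "pdstates", encrypted=False))  # best_accuracy/best_accuracy.states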
