-
Notifications
You must be signed in to change notification settings - Fork 110
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from glucauze/safetensors
v1.1.2 Safetensors
- Loading branch information
Showing
12 changed files
with
115 additions
and
47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,9 @@ | ||
cython | ||
dill==0.3.6 | ||
ifnude | ||
insightface==0.7.3 | ||
onnx==1.14.0 | ||
onnxruntime==1.15.0 | ||
opencv-python==4.7.0.72 | ||
pandas | ||
pydantic==1.10.9 | ||
dill==0.3.6 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import glob | ||
import os | ||
from typing import List | ||
from insightface.app.common import Face | ||
from safetensors.torch import save_file, safe_open | ||
import torch | ||
|
||
import modules.scripts as scripts | ||
from modules import scripts | ||
from scripts.faceswaplab_utils.faceswaplab_logging import logger | ||
import dill as pickle # will be removed in future versions | ||
|
||
|
||
def save_face(face: Face, filename: str) -> None: | ||
tensors = { | ||
"embedding": torch.tensor(face["embedding"]), | ||
"gender": torch.tensor(face["gender"]), | ||
"age": torch.tensor(face["age"]), | ||
} | ||
save_file(tensors, filename) | ||
|
||
|
||
def load_face(filename: str) -> Face: | ||
if filename.endswith(".pkl"): | ||
logger.warning( | ||
"Pkl files for faces are deprecated to enhance safety, they will be unsupported in future versions." | ||
) | ||
logger.warning("The file will be converted to .safetensors") | ||
logger.warning( | ||
"You can also use this script https://gist.github.com/glucauze/4a3c458541f2278ad801f6625e5b9d3d" | ||
) | ||
with open(filename, "rb") as file: | ||
logger.info("Load pkl") | ||
face = Face(pickle.load(file)) | ||
logger.warning( | ||
"Convert to safetensors, you can remove the pkl version once you have ensured that the safetensor is working" | ||
) | ||
save_face(face, filename.replace(".pkl", ".safetensors")) | ||
return face | ||
|
||
elif filename.endswith(".safetensors"): | ||
face = {} | ||
with safe_open(filename, framework="pt", device="cpu") as f: | ||
for k in f.keys(): | ||
logger.debug("load key %s", k) | ||
face[k] = f.get_tensor(k).numpy() | ||
return Face(face) | ||
|
||
raise NotImplementedError("Unknown file type, face extraction not implemented") | ||
|
||
|
||
def get_face_checkpoints() -> List[str]: | ||
""" | ||
Retrieve a list of face checkpoint paths. | ||
This function searches for face files with the extension ".safetensors" in the specified directory and returns a list | ||
containing the paths of those files. | ||
Returns: | ||
list: A list of face paths, including the string "None" as the first element. | ||
""" | ||
faces_path = os.path.join( | ||
scripts.basedir(), "models", "faceswaplab", "faces", "*.safetensors" | ||
) | ||
faces = glob.glob(faces_path) | ||
|
||
faces_path = os.path.join( | ||
scripts.basedir(), "models", "faceswaplab", "faces", "*.pkl" | ||
) | ||
faces += glob.glob(faces_path) | ||
|
||
return ["None"] + sorted(faces) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters