Skip to content

Commit

Permalink
Add: 指定された音声合成モデルをロード・アンロードする API を追加
Browse files Browse the repository at this point in the history
事前にロードしたり、メモリ節約のためにアンロードしたりができるようになる (ロードするだけなら /initialize_speaker でもできるにはできたが、分かりづらかった)
  • Loading branch information
tsukumijima committed Dec 24, 2024
1 parent b2ef9ea commit 35659a7
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 3 deletions.
2 changes: 1 addition & 1 deletion voicevox_engine/app/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _get_core_characters(version: str | None) -> list[CoreCharacter]:
generate_library_router(library_manager, verify_mutability_allowed)
)
# generate_aivm_models_router() は AivisSpeech Engine 独自追加ルーター
app.include_router(generate_aivm_models_router(aivm_manager, verify_mutability_allowed)) # noqa # fmt: skip
app.include_router(generate_aivm_models_router(aivm_manager, tts_engines, verify_mutability_allowed)) # noqa # fmt: skip
app.include_router(
generate_preset_router(preset_manager, verify_mutability_allowed)
)
Expand Down
50 changes: 49 additions & 1 deletion voicevox_engine/app/routers/aivm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@

from voicevox_engine.aivm_manager import AivmManager
from voicevox_engine.model import AivmInfo
from voicevox_engine.tts_pipeline.style_bert_vits2_tts_engine import (
StyleBertVITS2TTSEngine,
)
from voicevox_engine.tts_pipeline.tts_engine import LATEST_VERSION, TTSEngineManager

from ..dependencies import VerifyMutabilityAllowed


def generate_aivm_models_router(
aivm_manager: AivmManager,
tts_engines: TTSEngineManager,
verify_mutability: VerifyMutabilityAllowed,
) -> APIRouter:
"""音声合成モデル管理 API Router を生成する"""
Expand Down Expand Up @@ -79,11 +84,54 @@ def get_aivm_info(

return aivm_manager.get_aivm_info(aivm_uuid)

@router.post(
"/{aivm_uuid}/load",
status_code=204,
summary="指定された音声合成モデルをロードする",
)
def load_aivm(
aivm_uuid: Annotated[str, Path(description="音声合成モデルの UUID")],
) -> None:
"""
指定された音声合成モデルをロードします。すでにロード済みの場合は何も行われません。
実行しなくても他の API は使用できますが、初回実行時に時間がかかることがあります。
"""

# まず対応する音声合成モデルがインストールされているかを確認
# 存在しない場合は内部で HTTPException が送出される
aivm_info = aivm_manager.get_aivm_info(aivm_uuid)

# StyleBertVITS2TTSEngine を取得し、音声合成モデルをロード
engine = tts_engines.get_engine(LATEST_VERSION)
assert isinstance(engine, StyleBertVITS2TTSEngine)
engine.load_model(str(aivm_info.manifest.uuid))

@router.post(
"/{aivm_uuid}/unload",
status_code=204,
summary="指定された音声合成モデルをアンロードする",
)
def unload_aivm(
aivm_uuid: Annotated[str, Path(description="音声合成モデルの UUID")],
) -> None:
"""
指定された音声合成モデルをアンロードします。
"""

# まず対応する音声合成モデルがインストールされているかを確認
# 存在しない場合は内部で HTTPException が送出される
aivm_info = aivm_manager.get_aivm_info(aivm_uuid)

# StyleBertVITS2TTSEngine を取得し、音声合成モデルをアンロード
engine = tts_engines.get_engine(LATEST_VERSION)
assert isinstance(engine, StyleBertVITS2TTSEngine)
engine.unload_model(str(aivm_info.manifest.uuid))

@router.delete(
"/{aivm_uuid}/uninstall",
status_code=204,
dependencies=[Depends(verify_mutability)],
summary="音声合成モデルをアンインストールする",
summary="指定された音声合成モデルをアンインストールする",
)
def uninstall_aivm(
aivm_uuid: Annotated[str, Path(description="音声合成モデルの UUID")]
Expand Down
26 changes: 25 additions & 1 deletion voicevox_engine/tts_pipeline/style_bert_vits2_tts_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def load_model(self, aivm_uuid: str) -> TTSModel:
onnx_providers=self.onnx_providers,
) # fmt: skip
start_time = time.time()
logger.info(f"Loading {aivm_info.manifest.name} ({aivm_uuid})...")
logger.info(f"Loading {aivm_info.manifest.name} ({aivm_uuid}) ...")
tts_model.load()
logger.info(
f"{aivm_info.manifest.name} ({aivm_uuid}) loaded. ({time.time() - start_time:.2f}s)"
Expand All @@ -260,6 +260,30 @@ def load_model(self, aivm_uuid: str) -> TTSModel:
self.tts_models[aivm_uuid] = tts_model
return tts_model

def unload_model(self, aivm_uuid: str) -> None:
"""
指定された AIVM の UUID に対応する音声合成モデルをアンロードする
継承元の TTSEngine には存在しない、StyleBertVITS2TTSEngine 固有のメソッド
Parameters
----------
aivm_uuid : str
AIVM の UUID
"""

# モデルがロードされていない場合は何もしない
if not self.is_model_loaded(aivm_uuid):
return

# モデルをアンロード
aivm_info = self.aivm_manager.get_aivm_info(aivm_uuid)
start_time = time.time()
logger.info(f"Unloading {aivm_info.manifest.name} ({aivm_uuid}) ...")
del self.tts_models[aivm_uuid]
logger.info(
f"{aivm_info.manifest.name} ({aivm_uuid}) unloaded. ({time.time() - start_time:.2f}s)"
)

def is_model_loaded(self, aivm_uuid: str) -> bool:
"""
指定された AIVM の UUID に対応する音声合成モデルがロード済みかどうかを返す
Expand Down

0 comments on commit 35659a7

Please sign in to comment.