feat: model unloading
Fedir Zadniprovskyi authored and fedirz committed Oct 1, 2024
1 parent 1a02399 commit caba05a
Showing 8 changed files with 319 additions and 104 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -19,10 +19,10 @@ dependencies = [
client = [
"keyboard>=0.13.5",
]
# NOTE: when installing `dev` group, all other groups should also be installed
dev = [
"anyio>=4.4.0",
"basedpyright>=1.18.0",
"pytest-antilru>=2.0.0",
"pytest-asyncio>=0.24.0",
"pytest-xdist>=3.6.1",
"pytest>=8.3.3",
25 changes: 9 additions & 16 deletions src/faster_whisper_server/config.py
@@ -1,7 +1,6 @@
import enum
from typing import Self

from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict

SAMPLES_PER_SECOND = 16000
@@ -163,6 +162,12 @@ class WhisperConfig(BaseModel):
compute_type: Quantization = Field(default=Quantization.DEFAULT)
cpu_threads: int = 0
num_workers: int = 1
ttl: int = Field(default=300, ge=-1)
"""
Time in seconds until the model is unloaded if it is not being used.
-1: Never unload the model.
0: Unload the model immediately after usage.
"""


class Config(BaseSettings):
@@ -198,10 +203,6 @@ class Config(BaseSettings):
"""
default_response_format: ResponseFormat = ResponseFormat.JSON
whisper: WhisperConfig = WhisperConfig()
max_models: int = 1
"""
Maximum number of models that can be loaded at a time.
"""
preload_models: list[str] = Field(
default_factory=list,
examples=[
@@ -210,8 +211,8 @@
],
)
"""
List of models to preload on startup. Shouldn't be greater than `max_models`. By default, the model is first loaded on first request.
""" # noqa: E501
List of models to preload on startup. By default, the model is first loaded on first request.
"""
max_no_data_seconds: float = 1.0
"""
Max duration to wait for the next audio chunk before transcription is finalized and the connection is closed.
@@ -230,11 +231,3 @@
Controls how many latest seconds of audio are being passed through VAD.
Should be greater than `max_inactivity_seconds`
"""

@model_validator(mode="after")
def ensure_preloaded_models_is_lte_max_models(self) -> Self:
if len(self.preload_models) > self.max_models:
raise ValueError(
f"Number of preloaded models ({len(self.preload_models)}) is greater than max_models ({self.max_models})" # noqa: E501
)
return self
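
The `max_models` cap and its validator are gone; unloading is now driven by the per-model `ttl` field instead. Below is a minimal sketch (not part of the commit) of setting the new field programmatically. Field names come from this diff; the TTL value is illustrative and the remaining `WhisperConfig` fields are assumed to keep their defaults. If the project's `SettingsConfigDict` enables nested environment variables, the same value could likely be supplied through the environment, but that depends on settings not shown here.

```python
from faster_whisper_server.config import Config, WhisperConfig

# ttl=300 (default): unload 5 minutes after the last request
# ttl=0: unload immediately after each request
# ttl=-1: never unload
config = Config(whisper=WhisperConfig(ttl=60))
print(config.whisper.ttl)  # 60
```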
2 changes: 1 addition & 1 deletion src/faster_whisper_server/dependencies.py
@@ -18,7 +18,7 @@ def get_config() -> Config:
@lru_cache
def get_model_manager() -> ModelManager:
config = get_config() # HACK
return ModelManager(config)
return ModelManager(config.whisper)


ModelManagerDependency = Annotated[ModelManager, Depends(get_model_manager)]
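
Because `get_model_manager` is wrapped in `@lru_cache`, every request resolves to the same `ModelManager` built from `config.whisper`. An illustrative check of that singleton behaviour (nothing here is added by the commit itself):

```python
from faster_whisper_server.dependencies import get_model_manager

# lru_cache on a zero-argument function returns the same object on every call,
# so all routes share one set of loaded models and TTL timers.
assert get_model_manager() is get_model_manager()
```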
144 changes: 114 additions & 30 deletions src/faster_whisper_server/model_manager.py
@@ -3,48 +3,132 @@
from collections import OrderedDict
import gc
import logging
import threading
import time
from typing import TYPE_CHECKING

from faster_whisper import WhisperModel

if TYPE_CHECKING:
from collections.abc import Callable

from faster_whisper_server.config import (
Config,
WhisperConfig,
)

logger = logging.getLogger(__name__)

# TODO: enable concurrent model downloads


class SelfDisposingWhisperModel:
def __init__(
self,
model_id: str,
whisper_config: WhisperConfig,
*,
on_unload: Callable[[str], None] | None = None,
) -> None:
self.model_id = model_id
self.whisper_config = whisper_config
self.on_unload = on_unload

self.ref_count: int = 0
self.rlock = threading.RLock()
self.expire_timer: threading.Timer | None = None
self.whisper: WhisperModel | None = None

def unload(self) -> None:
with self.rlock:
if self.whisper is None:
raise ValueError(f"Model {self.model_id} is not loaded. {self.ref_count=}")
if self.ref_count > 0:
raise ValueError(f"Model {self.model_id} is still in use. {self.ref_count=}")
if self.expire_timer:
self.expire_timer.cancel()
self.whisper = None
# WARN: ~300 MB of memory will still be held by the model. See https://github.com/SYSTRAN/faster-whisper/issues/992
gc.collect()
logger.info(f"Model {self.model_id} unloaded")
if self.on_unload is not None:
self.on_unload(self.model_id)

def _load(self) -> None:
with self.rlock:
assert self.whisper is None
logger.debug(f"Loading model {self.model_id}")
start = time.perf_counter()
self.whisper = WhisperModel(
self.model_id,
device=self.whisper_config.inference_device,
device_index=self.whisper_config.device_index,
compute_type=self.whisper_config.compute_type,
cpu_threads=self.whisper_config.cpu_threads,
num_workers=self.whisper_config.num_workers,
)
logger.info(f"Model {self.model_id} loaded in {time.perf_counter() - start:.2f}s")

def _increment_ref(self) -> None:
with self.rlock:
self.ref_count += 1
if self.expire_timer:
logger.debug(f"Model was set to expire in {self.expire_timer.interval}s, cancelling")
self.expire_timer.cancel()
logger.debug(f"Incremented ref count for {self.model_id}, {self.ref_count=}")

def _decrement_ref(self) -> None:
with self.rlock:
self.ref_count -= 1
logger.debug(f"Decremented ref count for {self.model_id}, {self.ref_count=}")
if self.ref_count <= 0:
if self.whisper_config.ttl > 0:
logger.info(f"Model {self.model_id} is idle, scheduling offload in {self.whisper_config.ttl}s")
self.expire_timer = threading.Timer(self.whisper_config.ttl, self.unload)
self.expire_timer.start()
elif self.whisper_config.ttl == 0:
logger.info(f"Model {self.model_id} is idle, unloading immediately")
self.unload()
else:
logger.info(f"Model {self.model_id} is idle, not unloading")

def __enter__(self) -> WhisperModel:
with self.rlock:
if self.whisper is None:
self._load()
self._increment_ref()
assert self.whisper is not None
return self.whisper

def __exit__(self, *_args) -> None: # noqa: ANN002
self._decrement_ref()


class ModelManager:
def __init__(self, config: Config) -> None:
self.config = config
self.loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
def __init__(self, whisper_config: WhisperConfig) -> None:
self.whisper_config = whisper_config
self.loaded_models: OrderedDict[str, SelfDisposingWhisperModel] = OrderedDict()
self._lock = threading.Lock()

def load_model(self, model_name: str) -> WhisperModel:
if model_name in self.loaded_models:
logger.debug(f"{model_name} model already loaded")
return self.loaded_models[model_name]
if len(self.loaded_models) >= self.config.max_models:
oldest_model_name = next(iter(self.loaded_models))
logger.info(
f"Max models ({self.config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
def _handle_model_unload(self, model_name: str) -> None:
with self._lock:
if model_name in self.loaded_models:
del self.loaded_models[model_name]

def unload_model(self, model_name: str) -> None:
with self._lock:
model = self.loaded_models.get(model_name)
if model is None:
raise KeyError(f"Model {model_name} not found")
self.loaded_models[model_name].unload()

def load_model(self, model_name: str) -> SelfDisposingWhisperModel:
with self._lock:
if model_name in self.loaded_models:
logger.debug(f"{model_name} model already loaded")
return self.loaded_models[model_name]
self.loaded_models[model_name] = SelfDisposingWhisperModel(
model_name,
self.whisper_config,
on_unload=self._handle_model_unload,
)
del self.loaded_models[oldest_model_name]
gc.collect()
logger.debug(f"Loading {model_name}...")
start = time.perf_counter()
# NOTE: will raise an exception if the model name isn't valid. Should I do an explicit check?
whisper = WhisperModel(
model_name,
device=self.config.whisper.inference_device,
device_index=self.config.whisper.device_index,
compute_type=self.config.whisper.compute_type,
cpu_threads=self.config.whisper.cpu_threads,
num_workers=self.config.whisper.num_workers,
)
logger.info(
f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {self.config.whisper.inference_device}({self.config.whisper.compute_type}) will be used for inference." # noqa: E501
)
self.loaded_models[model_name] = whisper
return whisper
return self.loaded_models[model_name]
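
`SelfDisposingWhisperModel` is a ref-counted context manager: `__enter__` loads the model on first use and increments the count, `__exit__` decrements it, and once the count reaches zero the model is unloaded immediately (`ttl == 0`), after a `threading.Timer` fires (`ttl > 0`), or never (`ttl < 0`). A usage sketch, assuming the imports shown in this diff; the model id, audio path, and TTL value are placeholders:

```python
import time

from faster_whisper_server.config import WhisperConfig
from faster_whisper_server.model_manager import ModelManager

manager = ModelManager(WhisperConfig(ttl=5))  # unload 5 s after last use

with manager.load_model("Systran/faster-whisper-tiny.en") as whisper:
    # ref_count > 0 inside the block, so the model cannot be unloaded here
    segments, info = whisper.transcribe("audio.wav")

# on __exit__ the ref count drops to 0 and a 5 s expiry timer is scheduled
time.sleep(6)
print("Systran/faster-whisper-tiny.en" in manager.loaded_models)  # False
```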
18 changes: 10 additions & 8 deletions src/faster_whisper_server/routers/misc.py
@@ -1,7 +1,5 @@
from __future__ import annotations

import gc

from fastapi import (
APIRouter,
Response,
@@ -42,15 +40,19 @@ def get_running_models(
def load_model_route(model_manager: ModelManagerDependency, model_name: str) -> Response:
if model_name in model_manager.loaded_models:
return Response(status_code=409, content="Model already loaded")
model_manager.load_model(model_name)
with model_manager.load_model(model_name):
pass
return Response(status_code=201)


@router.delete("/api/ps/{model_name:path}", tags=["experimental"], summary="Unload a model from memory.")
def stop_running_model(model_manager: ModelManagerDependency, model_name: str) -> Response:
model = model_manager.loaded_models.get(model_name)
if model is not None:
del model_manager.loaded_models[model_name]
gc.collect()
try:
model_manager.unload_model(model_name)
return Response(status_code=204)
return Response(status_code=404)
except (KeyError, ValueError) as e:
match e:
case KeyError():
return Response(status_code=404, content="Model not found")
case ValueError():
return Response(status_code=409, content=str(e))
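
With the manager now owning unloading, the route maps `KeyError` to 404 and `ValueError` (model still referenced) to 409 instead of deleting from `loaded_models` and calling `gc.collect()` itself. A client-side sketch; the DELETE path comes from the decorator above, the POST path for loading is assumed to mirror it, and the host/port and model id are placeholders:

```python
import httpx

model = "Systran/faster-whisper-tiny.en"

with httpx.Client(base_url="http://localhost:8000") as client:
    r = client.post(f"/api/ps/{model}")
    print(r.status_code)  # 201 loaded, 409 if it was already loaded

    r = client.delete(f"/api/ps/{model}")
    print(r.status_code)  # 204 unloaded, 404 not loaded, 409 still in use
```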
96 changes: 48 additions & 48 deletions src/faster_whisper_server/routers/stt.py
@@ -142,20 +142,20 @@ def translate_file(
model = config.whisper.model
if response_format is None:
response_format = config.default_response_format
whisper = model_manager.load_model(model)
segments, transcription_info = whisper.transcribe(
file.file,
task=Task.TRANSLATE,
initial_prompt=prompt,
temperature=temperature,
vad_filter=vad_filter,
)
segments = TranscriptionSegment.from_faster_whisper_segments(segments)

if stream:
return segments_to_streaming_response(segments, transcription_info, response_format)
else:
return segments_to_response(segments, transcription_info, response_format)
with model_manager.load_model(model) as whisper:
segments, transcription_info = whisper.transcribe(
file.file,
task=Task.TRANSLATE,
initial_prompt=prompt,
temperature=temperature,
vad_filter=vad_filter,
)
segments = TranscriptionSegment.from_faster_whisper_segments(segments)

if stream:
return segments_to_streaming_response(segments, transcription_info, response_format)
else:
return segments_to_response(segments, transcription_info, response_format)


# HACK: Since Form() doesn't support `alias`, we need to use a workaround.
@@ -206,23 +206,23 @@ def transcribe_file(
logger.warning(
"It only makes sense to provide `timestamp_granularities[]` when `response_format` is set to `verbose_json`. See https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities." # noqa: E501
)
whisper = model_manager.load_model(model)
segments, transcription_info = whisper.transcribe(
file.file,
task=Task.TRANSCRIBE,
language=language,
initial_prompt=prompt,
word_timestamps="word" in timestamp_granularities,
temperature=temperature,
vad_filter=vad_filter,
hotwords=hotwords,
)
segments = TranscriptionSegment.from_faster_whisper_segments(segments)

if stream:
return segments_to_streaming_response(segments, transcription_info, response_format)
else:
return segments_to_response(segments, transcription_info, response_format)
with model_manager.load_model(model) as whisper:
segments, transcription_info = whisper.transcribe(
file.file,
task=Task.TRANSCRIBE,
language=language,
initial_prompt=prompt,
word_timestamps="word" in timestamp_granularities,
temperature=temperature,
vad_filter=vad_filter,
hotwords=hotwords,
)
segments = TranscriptionSegment.from_faster_whisper_segments(segments)

if stream:
return segments_to_streaming_response(segments, transcription_info, response_format)
else:
return segments_to_response(segments, transcription_info, response_format)


async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
@@ -280,24 +280,24 @@ async def transcribe_stream(
"vad_filter": vad_filter,
"condition_on_previous_text": False,
}
whisper = model_manager.load_model(model)
asr = FasterWhisperASR(whisper, **transcribe_opts)
audio_stream = AudioStream()
async with asyncio.TaskGroup() as tg:
tg.create_task(audio_receiver(ws, audio_stream))
async for transcription in audio_transcriber(asr, audio_stream, min_duration=config.min_duration):
logger.debug(f"Sending transcription: {transcription.text}")
if ws.client_state == WebSocketState.DISCONNECTED:
break
with model_manager.load_model(model) as whisper:
asr = FasterWhisperASR(whisper, **transcribe_opts)
audio_stream = AudioStream()
async with asyncio.TaskGroup() as tg:
tg.create_task(audio_receiver(ws, audio_stream))
async for transcription in audio_transcriber(asr, audio_stream, min_duration=config.min_duration):
logger.debug(f"Sending transcription: {transcription.text}")
if ws.client_state == WebSocketState.DISCONNECTED:
break

if response_format == ResponseFormat.TEXT:
await ws.send_text(transcription.text)
elif response_format == ResponseFormat.JSON:
await ws.send_json(CreateTranscriptionResponseJson.from_transcription(transcription).model_dump())
elif response_format == ResponseFormat.VERBOSE_JSON:
await ws.send_json(
CreateTranscriptionResponseVerboseJson.from_transcription(transcription).model_dump()
)
if response_format == ResponseFormat.TEXT:
await ws.send_text(transcription.text)
elif response_format == ResponseFormat.JSON:
await ws.send_json(CreateTranscriptionResponseJson.from_transcription(transcription).model_dump())
elif response_format == ResponseFormat.VERBOSE_JSON:
await ws.send_json(
CreateTranscriptionResponseVerboseJson.from_transcription(transcription).model_dump()
)

if ws.client_state != WebSocketState.DISCONNECTED:
logger.info("Closing the connection.")