feat: support BatchedInferencePipeline (#169)
Fedir Zadniprovskyi committed Dec 16, 2024
1 parent 65a92dc commit 3ed59e8
Showing 5 changed files with 16 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -6,7 +6,7 @@ requires-python = ">=3.12,<3.13"
dependencies = [
"ctranslate2>=4.5.0",
"fastapi>=0.115.0",
"faster-whisper>=1.0.3",
"faster-whisper>=1.1.0",
"huggingface-hub>=0.25.1",
"numpy>=2.1.1",
"piper-phonemize ; platform_machine == 'x86_64'",
1 change: 1 addition & 0 deletions src/faster_whisper_server/asr.py
@@ -31,6 +31,7 @@ def _transcribe(
prompt: str | None = None,
) -> tuple[Transcription, transcribe.TranscriptionInfo]:
start = time.perf_counter()
+# NOTE: should `BatchedInferencePipeline` be used here?
segments, transcription_info = self.whisper.transcribe(
audio.data,
initial_prompt=prompt,
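The NOTE added above leaves the non-batched path in `asr.py` unchanged. For illustration only (not part of this commit), wrapping the already loaded model there would mirror the pattern used in `stt.py` below; the `batch_size` value is an assumption:

```python
from faster_whisper.transcribe import BatchedInferencePipeline

# Sketch only: inside `_transcribe`, the batched path would wrap the loaded
# WhisperModel (`self.whisper`) instead of calling it directly.
pipeline = BatchedInferencePipeline(model=self.whisper)
segments, transcription_info = pipeline.transcribe(
    audio.data,
    initial_prompt=prompt,
    batch_size=8,  # assumed value; `batch_size` is specific to the batched pipeline
)
```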
4 changes: 4 additions & 0 deletions src/faster_whisper_server/config.py
@@ -168,6 +168,10 @@ class WhisperConfig(BaseModel):
-1: Never unload the model.
0: Unload the model immediately after usage.
"""
+use_batched_mode: bool = False
+"""
+Whether to use batched mode (introduced in the `faster-whisper` 1.1.0 release) for inference. This will likely become the default in the future, and the configuration option will be removed.
+"""  # noqa: E501


class Config(BaseSettings):
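A minimal sketch of toggling the new flag through the environment-driven settings. The `WHISPER__USE_BATCHED_MODE` variable name assumes pydantic-settings' double-underscore nesting for the `whisper` section, and that the remaining `Config` fields have defaults:

```python
import os

from faster_whisper_server.config import Config

# Assumed variable name, following the double-underscore nesting convention
# for the nested `whisper` config section.
os.environ["WHISPER__USE_BATCHED_MODE"] = "true"

config = Config()  # environment is read when the settings object is constructed
print(config.whisper.use_batched_mode)  # expected: True
```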
7 changes: 5 additions & 2 deletions src/faster_whisper_server/routers/stt.py
@@ -21,6 +21,7 @@
from fastapi.responses import StreamingResponse
from fastapi.websockets import WebSocketState
from faster_whisper.audio import decode_audio
+from faster_whisper.transcribe import BatchedInferencePipeline
from faster_whisper.vad import VadOptions, get_speech_timestamps
from numpy import float32
from numpy.typing import NDArray
@@ -188,7 +189,8 @@ def translate_file(
if response_format is None:
response_format = config.default_response_format
with model_manager.load_model(model) as whisper:
-segments, transcription_info = whisper.transcribe(
+whisper_model = BatchedInferencePipeline(model=whisper) if config.whisper.use_batched_mode else whisper
+segments, transcription_info = whisper_model.transcribe(
audio,
task=Task.TRANSLATE,
initial_prompt=prompt,
@@ -252,7 +254,8 @@ def transcribe_file(
"It only makes sense to provide `timestamp_granularities[]` when `response_format` is set to `verbose_json`. See https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities." # noqa: E501
)
with model_manager.load_model(model) as whisper:
-segments, transcription_info = whisper.transcribe(
+whisper_model = BatchedInferencePipeline(model=whisper) if config.whisper.use_batched_mode else whisper
+segments, transcription_info = whisper_model.transcribe(
audio,
task=Task.TRANSCRIBE,
language=language,
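Outside the server, the conditional wrapping above corresponds to plain `faster-whisper` usage. A standalone sketch; the model name, audio file, and batch size are illustrative:

```python
from faster_whisper import WhisperModel
from faster_whisper.transcribe import BatchedInferencePipeline

# Illustrative model; any CTranslate2-converted Whisper model works.
model = WhisperModel("Systran/faster-whisper-small", device="cpu", compute_type="int8")
batched = BatchedInferencePipeline(model=model)

# The batched pipeline exposes the same `transcribe` interface and adds `batch_size`.
segments, info = batched.transcribe("audio.wav", batch_size=8)
for segment in segments:
    print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")
```

Keeping the batched path behind `use_batched_mode` leaves sequential inference as the default for now, while the config docstring notes that batched mode will likely become the default later.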
9 changes: 5 additions & 4 deletions uv.lock

Some generated files are not rendered by default.
