Skip to content

Commit

Permalink
Merge pull request #17 from cubxxw/feat/add-tts-whisper
Browse files Browse the repository at this point in the history
feat: add whisper git config?
  • Loading branch information
cubxxw authored Dec 30, 2024
2 parents c479892 + 829610d commit 8637976
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 12 deletions.
5 changes: 5 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,8 @@ VOICEFLOW_VOLCENGINE_TTS_TOKEN=''

# 语音服务端口配置
VOICEFLOW_SERVER_PORT=18080 # 语音服务端口

# Whisper 配置
# https://fireworks.ai/dashboard/models/getting-started
VOICEFLOW_WHISPER_API_KEY=''
VOICEFLOW_WHISPER_ENDPOINT="https://audio-turbo.us-virginia-1.direct.fireworks.ai/v1/audio/transcriptions"
19 changes: 16 additions & 3 deletions configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ minio:
storage_path: "voiceflow/audio/"

stt:
# 可选值:azure、 google、 local、 assemblyai-ws、 volcengine、 aws、 assemblyai
provider: assemblyai
# 可选值:azure、 google、 local、 assemblyai-ws、 volcengine、 aws、 assemblyai、 whisper-v3
provider: whisper-v3

tts:
# 可选值:azure、 google、 local、 volcengine
Expand Down Expand Up @@ -121,4 +121,17 @@ logging:
max_backups: 3
max_age: 28
compress: true
report_caller: true
report_caller: true

# 添加 Whisper 配置段
whisper:
api_key: "your-api-key"
endpoint: "https://audio-turbo.us-virginia-1.direct.fireworks.ai/v1/audio/transcriptions"
model: "accounts/fireworks/models/whisper-v3-turbo"
temperature: 0
vad_model: "silero"
max_retries: 3
timeout: 30
language: "auto"
task: "transcribe"
batch_size: 30
5 changes: 4 additions & 1 deletion internal/server/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ func InitServices() {
if err != nil {
logger.Fatalf("配置初始化失败: %v", err)
}
sttService = stt.NewService(cfg.STT.Provider)
sttService, err = stt.NewService(cfg.STT.Provider)
if err != nil {
logger.Fatalf("STT 服务初始化失败: %v", err)
}
ttsService = tts.NewService(cfg.TTS.Provider)
// llmService = llm.NewService(cfg.LLM.Provider)
storageService = storage.NewService()
Expand Down
21 changes: 13 additions & 8 deletions internal/stt/stt.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@
package stt

import (
"fmt"

assemblyai "github.com/telepace/voiceflow/internal/stt/assemblyai"
aaiws "github.com/telepace/voiceflow/internal/stt/assemblyai-ws"
"github.com/telepace/voiceflow/internal/stt/azure"
"github.com/telepace/voiceflow/internal/stt/google"
"github.com/telepace/voiceflow/internal/stt/local"
"github.com/telepace/voiceflow/internal/stt/volcengine"
"github.com/telepace/voiceflow/internal/stt/whisper"
"github.com/telepace/voiceflow/pkg/logger"
)

Expand All @@ -17,22 +20,24 @@ type Service interface {
}

// NewService 根据配置返回相应的 STT 服务实现
func NewService(provider string) Service {
func NewService(provider string) (Service, error) {
logger.Debugf("Using STT provider: %s", provider)
switch provider {
case "azure":
return azure.NewAzureSTT()
return azure.NewAzureSTT(), nil
case "google":
return google.NewGoogleSTT()
return google.NewGoogleSTT(), nil
case "assemblyai-ws":
return aaiws.NewAssemblyAI()
return aaiws.NewAssemblyAI(), nil
case "volcengine":
return volcengine.NewVolcengineSTT()
return volcengine.NewVolcengineSTT(), nil
case "local":
return local.NewLocalSTT()
return local.NewLocalSTT(), nil
case "assemblyai":
return assemblyai.NewAssemblyAI()
return assemblyai.NewAssemblyAI(), nil
case "whisper-v3":
return whisper.NewWhisperSTT(), nil
default:
return local.NewLocalSTT()
return nil, fmt.Errorf("未知的 STT 提供商: %s", provider)
}
}
110 changes: 110 additions & 0 deletions internal/stt/whisper/whisper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package whisper

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"

"github.com/telepace/voiceflow/pkg/config"
"github.com/telepace/voiceflow/pkg/logger"
)

// WhisperSTT carries the settings used to call a Whisper transcription
// endpoint over HTTP.
type WhisperSTT struct {
	// apiKey is sent as the Bearer token, endpoint is the URL the request is
	// POSTed to, and model is forwarded as the "model" form field.
	apiKey, endpoint, model string
	// temperature is forwarded as the "temperature" form field.
	temperature float64
	// vadModel is forwarded as the "vad_model" form field.
	vadModel string
}

// WhisperResponse is the JSON document decoded from a successful (HTTP 200)
// transcription response.
type WhisperResponse struct {
	// Text is the transcript; Recognize returns it to the caller.
	Text string `json:"text"`
	// Language is decoded from the "language" key and only logged by Recognize.
	Language string `json:"language,omitempty"`
	// Duration is decoded from the "duration" key; Recognize logs it as seconds.
	Duration float64 `json:"duration,omitempty"`
}

// NewWhisperSTT builds a WhisperSTT from the Whisper section of the loaded
// application configuration.
//
// It calls logger.Fatalf when the configuration cannot be obtained.
func NewWhisperSTT() *WhisperSTT {
	cfg, err := config.GetConfig()
	if err != nil {
		logger.Fatalf("配置初始化失败: %v", err)
	}

	// Copy only the settings that Recognize actually sends to the endpoint.
	stt := new(WhisperSTT)
	stt.apiKey = cfg.Whisper.APIKey
	stt.endpoint = cfg.Whisper.Endpoint
	stt.model = cfg.Whisper.Model
	stt.temperature = cfg.Whisper.Temperature
	stt.vadModel = cfg.Whisper.VADModel
	return stt
}

// Recognize uploads audioData to the configured Whisper endpoint as a
// multipart form and returns the transcribed text.
//
// The audio is sent as the "file" part under the fixed file name "audio.mp3",
// together with the "model", "temperature" and "vad_model" fields. audioURL is
// not used by this implementation.
//
// An error is returned when audioData is empty, when the form or request
// cannot be built, when sending fails, when the endpoint answers with a status
// other than 200, or when the response body cannot be decoded as JSON.
// Underlying errors are wrapped with %w so callers can inspect them with
// errors.Is / errors.As.
func (w *WhisperSTT) Recognize(audioData []byte, audioURL string) (string, error) {
	// Nothing to transcribe: fail before doing any network I/O.
	if len(audioData) == 0 {
		return "", fmt.Errorf("音频数据为空")
	}

	// Assemble the multipart body in memory.
	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)

	part, err := writer.CreateFormFile("file", "audio.mp3")
	if err != nil {
		return "", fmt.Errorf("创建表单文件失败: %w", err)
	}
	if _, err := part.Write(audioData); err != nil {
		return "", fmt.Errorf("写入音频数据失败: %w", err)
	}

	// Extra request parameters; every write is checked instead of being
	// silently dropped.
	fields := []struct{ name, value string }{
		{"model", w.model},
		{"temperature", fmt.Sprintf("%f", w.temperature)},
		{"vad_model", w.vadModel},
	}
	for _, f := range fields {
		if err := writer.WriteField(f.name, f.value); err != nil {
			return "", fmt.Errorf("写入表单字段 %s 失败: %w", f.name, err)
		}
	}

	// Close writes the terminating boundary; the body is incomplete without it.
	if err := writer.Close(); err != nil {
		return "", fmt.Errorf("关闭 writer 失败: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, w.endpoint, body)
	if err != nil {
		return "", fmt.Errorf("创建请求失败: %w", err)
	}
	req.Header.Set("Authorization", "Bearer "+w.apiKey)
	req.Header.Set("Content-Type", writer.FormDataContentType())

	// NOTE(review): no request timeout is applied here, so a stalled endpoint
	// can block this call indefinitely. Consider wiring the configured Whisper
	// timeout into a shared http.Client — confirm its unit first.
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("发送请求失败: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the body is only used to enrich the error message, so a
		// read failure here is deliberately ignored.
		bodyBytes, _ := io.ReadAll(resp.Body)
		return "", fmt.Errorf("API 请求失败,状态码: %d,响应: %s",
			resp.StatusCode, string(bodyBytes))
	}

	var result WhisperResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return "", fmt.Errorf("解析响应失败: %w", err)
	}

	logger.Infof("语音识别完成,语言: %s, 时长: %.2f秒",
		result.Language, result.Duration)

	return result.Text, nil
}

// StreamRecognize always fails: Whisper V3 Turbo currently has no streaming
// mode, so no data is read from audioDataChan and nothing is sent on
// transcriptChan.
func (w *WhisperSTT) StreamRecognize(ctx context.Context, audioDataChan <-chan []byte,
	transcriptChan chan<- string) error {
	errUnsupported := fmt.Errorf("Whisper V3 Turbo 不支持流式处理")
	return errUnsupported
}
14 changes: 14 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,19 @@ type AWSConfig struct {
Region string `yaml:"region"`
}

// WhisperConfig maps the "whisper" configuration section onto typed fields.
type WhisperConfig struct {
	// APIKey is the credential for the Whisper endpoint.
	APIKey string `mapstructure:"api_key"`
	// Endpoint is the transcription URL.
	Endpoint string `mapstructure:"endpoint"`
	// Model is the model identifier sent with each request.
	Model string `mapstructure:"model"`
	// Temperature is the sampling temperature sent with each request.
	Temperature float64 `mapstructure:"temperature"`
	// VADModel names the voice-activity-detection model sent with each request.
	VADModel string `mapstructure:"vad_model"`
	// NOTE(review): presumably the number of retry attempts — confirm against its consumer.
	MaxRetries int `mapstructure:"max_retries"`
	// NOTE(review): presumably a request timeout in seconds — confirm the unit.
	Timeout int `mapstructure:"timeout"`
	Language string `mapstructure:"language"` // optional explicit language
	Task string `mapstructure:"task"` // "transcribe" or "translate"
	BatchSize int `mapstructure:"batch_size"` // audio segment size (seconds)
}

type Config struct {
Server struct {
Port int
Expand Down Expand Up @@ -129,6 +142,7 @@ type Config struct {
Compress bool `mapstructure:"compress"`
ReportCaller bool `mapstructure:"report_caller"`
}
Whisper WhisperConfig `mapstructure:"whisper"`
}

var (
Expand Down

0 comments on commit 8637976

Please sign in to comment.