From e026d09506cb41534b60d7c43370d6a03d1fb92e Mon Sep 17 00:00:00 2001 From: Xinwei Xiong <3293172751NSS@gmail.com> Date: Wed, 25 Dec 2024 12:45:12 +0700 Subject: [PATCH] feat: add common version mange --- internal/stt/assemblyai/assemblyai.go | 95 ++++++++++++--------------- 1 file changed, 42 insertions(+), 53 deletions(-) diff --git a/internal/stt/assemblyai/assemblyai.go b/internal/stt/assemblyai/assemblyai.go index 943b827..0dd6dab 100644 --- a/internal/stt/assemblyai/assemblyai.go +++ b/internal/stt/assemblyai/assemblyai.go @@ -45,28 +45,46 @@ func (s *STT) Recognize(audioData []byte, audioURL string) (string, error) { func (s *STT) transcribeFromURL(audioURL string) (string, error) { ctx := context.Background() - // 第一次尝试��使用语言检测 - params := s.buildParams() + // 第一次尝试:启用语言检测 + params := &aai.TranscriptOptionalParams{ + LanguageDetection: aai.Bool(true), + LanguageConfidenceThreshold: aai.Float64(0.1), // 设置较低的初始阈值 + Punctuate: aai.Bool(s.cfg.AssemblyAI.Punctuate), + FormatText: aai.Bool(s.cfg.AssemblyAI.FormatText), + SpeechThreshold: aai.Float64(s.cfg.AssemblyAI.SpeechThreshold), + Multichannel: aai.Bool(s.cfg.AssemblyAI.Multichannel), + } + transcript, err := s.client.Transcripts.TranscribeFromURL(ctx, audioURL, params) if err != nil { // 检查是否是语言置信度错误 if s.isLanguageConfidenceError(err) && s.cfg.AssemblyAI.DefaultLanguageCode != "" { - // 使用默认语言重试 - logger.Infof("语言置信度低于阈值 %.2f,使用默认语言 %s 重试", - s.cfg.AssemblyAI.LanguageConfidenceThreshold, + logger.Infof("第一次尝试失败(语言置信度低), 使用默认语言 %s 重试", s.cfg.AssemblyAI.DefaultLanguageCode) - // 构建新的参数,使用默认语言 - params = s.buildParamsWithDefaultLanguage() - transcript, err = s.client.Transcripts.TranscribeFromURL(ctx, audioURL, params) + // 第二次尝试:禁用语言检测,使用固定语言 + retryParams := &aai.TranscriptOptionalParams{ + LanguageDetection: aai.Bool(false), // 明确禁用语言检测 + LanguageCode: aai.TranscriptLanguageCode(s.cfg.AssemblyAI.DefaultLanguageCode), + // 基础参数 + Punctuate: aai.Bool(s.cfg.AssemblyAI.Punctuate), + FormatText: aai.Bool(s.cfg.AssemblyAI.FormatText), + SpeechThreshold: aai.Float64(s.cfg.AssemblyAI.SpeechThreshold), + Multichannel: aai.Bool(s.cfg.AssemblyAI.Multichannel), + // 不再设置 LanguageConfidenceThreshold + } + + // 记录重试请求参数 + logger.Debugf("重试请求参数: %+v", retryParams) + + transcript, err = s.client.Transcripts.TranscribeFromURL(ctx, audioURL, retryParams) if err != nil { - return "", fmt.Errorf("语言置信度低于阈值 %.2f,使用默认语言 %s 重试失败: %v", - s.cfg.AssemblyAI.LanguageConfidenceThreshold, - s.cfg.AssemblyAI.DefaultLanguageCode, - err) + return "", fmt.Errorf("使用默认语言 %s 重试失败: %v", + s.cfg.AssemblyAI.DefaultLanguageCode, err) } + } else { + return "", fmt.Errorf("转录请求失败: %v", err) } - return "", fmt.Errorf("转录请求失败: %v", err) } // 使用指数退避策略,轮询转录状态 @@ -132,34 +150,20 @@ func (s *STT) buildParams() *aai.TranscriptOptionalParams { aaiCfg := s.cfg.AssemblyAI params := &aai.TranscriptOptionalParams{ - // 将 string 转换为 SpeechModel 类型 - SpeechModel: aai.SpeechModel(aaiCfg.Model), - LanguageDetection: aai.Bool(aaiCfg.LanguageDetection), - LanguageConfidenceThreshold: aai.Float64(aaiCfg.LanguageConfidenceThreshold), - Punctuate: aai.Bool(aaiCfg.Punctuate), - FormatText: aai.Bool(aaiCfg.FormatText), - Disfluencies: aai.Bool(aaiCfg.Disfluencies), - FilterProfanity: aai.Bool(aaiCfg.FilterProfanity), - AudioStartFrom: aai.Int64(aaiCfg.AudioStartFrom), - AudioEndAt: aai.Int64(aaiCfg.AudioEndAt), - SpeechThreshold: aai.Float64(aaiCfg.SpeechThreshold), - Multichannel: aai.Bool(aaiCfg.Multichannel), + SpeechModel: aai.SpeechModel(aaiCfg.Model), + Punctuate: aai.Bool(aaiCfg.Punctuate), + FormatText: aai.Bool(aaiCfg.FormatText), + SpeechThreshold: aai.Float64(aaiCfg.SpeechThreshold), + Multichannel: aai.Bool(aaiCfg.Multichannel), } - // 如果设置了固定的 language_code,则禁用语言检测并指定语言 - if aaiCfg.LanguageCode != "" { - params.LanguageDetection = aai.Bool(false) - params.LanguageCode = aai.TranscriptLanguageCode(aaiCfg.LanguageCode) - } - - // 如果配置了词汇增强 + // 词汇增强设置 if len(aaiCfg.WordBoost) > 0 { params.WordBoost = aaiCfg.WordBoost - // 将 string 转换为 TranscriptBoostParam 类型 params.BoostParam = aai.TranscriptBoostParam(aaiCfg.BoostParam) } - // 如果配置了自定义拼写 + // 自定义拼写设置 if len(aaiCfg.CustomSpelling) > 0 { var customSpellings []aai.TranscriptCustomSpelling for _, cs := range aaiCfg.CustomSpelling { @@ -174,24 +178,9 @@ func (s *STT) buildParams() *aai.TranscriptOptionalParams { return params } -// 新增:检查是否是语言置信度错误 +// isLanguageConfidenceError 优化错误检测逻辑 func (s *STT) isLanguageConfidenceError(err error) bool { - return strings.Contains(err.Error(), "below the requested confidence threshold value") -} - -// 新增:使用默认语言构建参数 -func (s *STT) buildParamsWithDefaultLanguage() *aai.TranscriptOptionalParams { - // 不再调用 s.buildParams(),防止里面带了 threshold - // 自己手动指定二次请求想要的字段 - return &aai.TranscriptOptionalParams{ - LanguageDetection: aai.Bool(false), - LanguageCode: aai.TranscriptLanguageCode(s.cfg.AssemblyAI.DefaultLanguageCode), - Punctuate: aai.Bool(true), - FormatText: aai.Bool(true), - SpeechThreshold: aai.Float64(s.cfg.AssemblyAI.SpeechThreshold), - Multichannel: aai.Bool(s.cfg.AssemblyAI.Multichannel), - AudioStartFrom: aai.Int64(s.cfg.AssemblyAI.AudioStartFrom), - AudioEndAt: aai.Int64(s.cfg.AssemblyAI.AudioEndAt), - BoostParam: aai.TranscriptBoostParam(s.cfg.AssemblyAI.BoostParam), - } + errMsg := err.Error() + return strings.Contains(errMsg, "below the requested confidence threshold") || + strings.Contains(errMsg, "confidence threshold value") }