From 9903564d4b978fe2e6b2c1519dc690b5a4abcab9 Mon Sep 17 00:00:00 2001
From: liudongwei
Date: Wed, 7 Aug 2024 18:14:34 +0800
Subject: [PATCH] add Volcengine TTS/ASR Plugin and add tts_parallel config

---
 robot/AI.py                   |  58 +++++
 robot/ASR.py                  |  26 ++-
 robot/Conversation.py         |   2 +-
 robot/TTS.py                  |  28 ++-
 robot/sdk/VolcengineSpeech.py | 388 ++++++++++++++++++++++++++++++++++
 static/default.yml            |  53 +++--
 6 files changed, 535 insertions(+), 20 deletions(-)
 create mode 100644 robot/sdk/VolcengineSpeech.py

diff --git a/robot/AI.py b/robot/AI.py
index 98499735..d3b44261 100644
--- a/robot/AI.py
+++ b/robot/AI.py
@@ -506,6 +506,64 @@ def chat(self, texts, _):
             logger.critical("Tongyi robot failed to response for %r", msg, exc_info=True)
             return "抱歉, Tongyi回答失败"
 
+class CozeRobot(AbstractRobot):
+    SLUG = "coze"
+
+    def __init__(self, botid, token, **kwargs):
+        super(self.__class__, self).__init__()
+        self.botid = botid
+        self.token = token
+        self.userid = str(get_mac())[:32]
+
+    @classmethod
+    def get_config(cls):
+        return config.get("coze", {})
+
+    def chat(self, texts, parsed=None):
+        """
+        使用coze聊天
+
+        Arguments:
+        texts -- user input, typically speech, to be parsed by a module
+        """
+        msg = "".join(texts)
+        msg = utils.stripPunctuation(msg)
+        try:
+            url = "https://api.coze.cn/open_api/v2/chat"
+
+            body = {
+                "conversation_id": "123",
+                "bot_id": self.botid,
+                "user": self.userid,
+                "query": msg,
+                "stream": False
+            }
+            headers = {
+                "Authorization": "Bearer " + self.token,
+                "Content-Type": "application/json",
+                "Accept": "*/*",
+                "Host": "api.coze.cn",
+                "Connection": "keep-alive"
+            }
+            r = requests.post(url, headers=headers, json=body)
+            respond = json.loads(r.text)
+            result = ""
+            logger.info(f"{self.SLUG} 原始响应:{respond}")
+            if "messages" in respond:
+                for m in respond["messages"]:
+                    if m["type"] == "answer":
+                        result = m["content"].replace("\n", "").replace("\r", "")
+            if result == "":
+                result = "抱歉,扣子回答失败"
+            logger.info(f"{self.SLUG} 回答:{result}")
+            return result
+        except Exception:
+            logger.critical(
+                "Coze robot failed to respond for %r", msg, exc_info=True
+            )
+            return "抱歉, 扣子回答失败"
 
 def get_unknown_response():
     """
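Note for reviewers: CozeRobot.get_config() reads a `coze` section from the user config, but this patch does not add a sample entry to static/default.yml. A minimal sketch of the dict the robot would expect at runtime, with the keys inferred from CozeRobot.__init__ (both values are placeholders; the engine loader is expected to pass this dict as keyword arguments):

    # Hypothetical illustration of what config.get("coze", {}) should return;
    # not part of this patch.
    coze_profile = {
        "botid": "your-coze-bot-id",     # bot ID from the Coze console
        "token": "your-coze-api-token",  # personal access token sent as the Bearer token
    }
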
diff --git a/robot/ASR.py b/robot/ASR.py
index fbda2768..4a282137 100755
--- a/robot/ASR.py
+++ b/robot/ASR.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 import json
 from aip import AipSpeech
-from .sdk import TencentSpeech, AliSpeech, XunfeiSpeech, BaiduSpeech, FunASREngine
+from .sdk import TencentSpeech, AliSpeech, XunfeiSpeech, BaiduSpeech, FunASREngine, VolcengineSpeech
 from . import utils, config
 from robot import logging
 from abc import ABCMeta, abstractmethod
@@ -267,6 +267,30 @@ def transcribe(self, fp):
             logger.critical(f"{self.SLUG} 语音识别出错了", stack_info=True)
             return ""
 
+class VolcengineASR(AbstractASR):
+    """
+    VolcengineASR 火山引擎语音识别
+    """
+
+    SLUG = "volcengine-asr"
+
+    def __init__(self, **kargs):
+        super(self.__class__, self).__init__()
+        self.volcengine_asr = VolcengineSpeech.VolcengineASR(**kargs)
+
+    @classmethod
+    def get_config(cls):
+        return config.get("volcengine-asr", {})
+
+    def transcribe(self, fp):
+        result = self.volcengine_asr.execute(fp)
+        if result:
+            logger.info(f"{self.SLUG} 语音识别到了:{result}")
+            return result
+        else:
+            logger.critical(f"{self.SLUG} 语音识别出错了", stack_info=True)
+            return ""
+
 
 def get_engine_by_slug(slug=None):
     """
     Returns:
diff --git a/robot/Conversation.py b/robot/Conversation.py
index 5cc96f41..318afae5 100644
--- a/robot/Conversation.py
+++ b/robot/Conversation.py
@@ -304,7 +304,7 @@ def _tts(self, lines, cache, onCompleted=None):
         pattern = r"http[s]?://.+"
         logger.info("_tts")
         with self.tts_lock:
-            with ThreadPoolExecutor(max_workers=5) as pool:
+            with ThreadPoolExecutor(max_workers=config.get("tts_parallel", 5)) as pool:
                 all_task = []
                 index = 0
                 for line in lines:
diff --git a/robot/TTS.py b/robot/TTS.py
index c55e30d1..61c048d0 100644
--- a/robot/TTS.py
+++ b/robot/TTS.py
@@ -17,7 +17,7 @@
 from pypinyin import lazy_pinyin
 from pydub import AudioSegment
 from abc import ABCMeta, abstractmethod
-from .sdk import TencentSpeech, AliSpeech, XunfeiSpeech, atc, VITSClient
+from .sdk import TencentSpeech, AliSpeech, XunfeiSpeech, atc, VITSClient, VolcengineSpeech
 import requests
 from xml.etree import ElementTree
 
@@ -469,6 +469,32 @@ def get_engine_by_slug(slug=None):
     logger.info(f"使用 {engine.SLUG} TTS 引擎")
     return engine.get_instance()
 
+class VolcengineTTS(AbstractTTS):
+    """
+    VolcengineTTS 语音合成
+    """
+
+    SLUG = "volcengine-tts"
+
+    def __init__(self, appid, token, cluster, voice_type, **args):
+        super(self.__class__, self).__init__()
+        self.engine = VolcengineSpeech.VolcengineTTS(appid=appid, token=token, cluster=cluster, voice_type=voice_type)
+
+    @classmethod
+    def get_config(cls):
+        # Try to get volcengine-tts config from config
+        return config.get("volcengine-tts", {})
+
+    def get_speech(self, text):
+        result = self.engine.execute(text)
+        if result is None:
+            logger.critical(f"{self.SLUG} 合成失败!", stack_info=True)
+        else:
+            tmpfile = os.path.join(constants.TEMP_PATH, uuid.uuid4().hex + ".mp3")
+            with open(tmpfile, "wb") as f:
+                f.write(result)
+            logger.info(f"{self.SLUG} 语音合成成功,合成路径:{tmpfile}")
+            return tmpfile
+
 
 def get_engines():
     def get_subclasses(cls):
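The Conversation.py change above makes the size of the synthesis thread pool configurable instead of hard-coding 5. On the Volcengine free tier, whose TTS quota is limited to QPS 2, tts_parallel should be set to 2. A minimal standalone sketch of the same pattern (function and variable names here are illustrative, not from the patch):

    from concurrent.futures import ThreadPoolExecutor

    def synthesize_all(lines, synthesize, tts_parallel=5):
        # Bound the number of concurrent TTS requests to tts_parallel,
        # mirroring ThreadPoolExecutor(max_workers=config.get("tts_parallel", 5)).
        with ThreadPoolExecutor(max_workers=tts_parallel) as pool:
            futures = [pool.submit(synthesize, line) for line in lines]
            return [f.result() for f in futures]
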
diff --git a/robot/sdk/VolcengineSpeech.py b/robot/sdk/VolcengineSpeech.py
new file mode 100644
index 00000000..32aa0beb
--- /dev/null
+++ b/robot/sdk/VolcengineSpeech.py
@@ -0,0 +1,388 @@
+#coding=utf-8
+
+"""
+requires Python 3.6 or later
+
+pip install websockets
+pip install requests
+"""
+
+import asyncio
+import base64
+import gzip
+import hmac
+import json
+import requests
+import logging
+import os
+import uuid
+import wave
+from enum import Enum
+from hashlib import sha256
+from io import BytesIO
+from typing import List
+from urllib.parse import urlparse
+import time
+import websockets
+from robot import config
+
+
+audio_format = "wav"  # wav 或者 mp3,根据实际音频格式设置
+
+PROTOCOL_VERSION = 0b0001
+DEFAULT_HEADER_SIZE = 0b0001
+
+PROTOCOL_VERSION_BITS = 4
+HEADER_BITS = 4
+MESSAGE_TYPE_BITS = 4
+MESSAGE_TYPE_SPECIFIC_FLAGS_BITS = 4
+MESSAGE_SERIALIZATION_BITS = 4
+MESSAGE_COMPRESSION_BITS = 4
+RESERVED_BITS = 8
+
+# Message Type:
+CLIENT_FULL_REQUEST = 0b0001
+CLIENT_AUDIO_ONLY_REQUEST = 0b0010
+SERVER_FULL_RESPONSE = 0b1001
+SERVER_ACK = 0b1011
+SERVER_ERROR_RESPONSE = 0b1111
+
+# Message Type Specific Flags
+NO_SEQUENCE = 0b0000  # no check sequence
+POS_SEQUENCE = 0b0001
+NEG_SEQUENCE = 0b0010
+NEG_SEQUENCE_1 = 0b0011
+
+# Message Serialization
+NO_SERIALIZATION = 0b0000
+JSON = 0b0001
+THRIFT = 0b0011
+CUSTOM_TYPE = 0b1111
+
+# Message Compression
+NO_COMPRESSION = 0b0000
+GZIP = 0b0001
+CUSTOM_COMPRESSION = 0b1111
+
+
+def generate_header(
+    version=PROTOCOL_VERSION,
+    message_type=CLIENT_FULL_REQUEST,
+    message_type_specific_flags=NO_SEQUENCE,
+    serial_method=JSON,
+    compression_type=GZIP,
+    reserved_data=0x00,
+    extension_header=bytes()
+):
+    """
+    protocol_version(4 bits), header_size(4 bits),
+    message_type(4 bits), message_type_specific_flags(4 bits)
+    serialization_method(4 bits) message_compression(4 bits)
+    reserved (8bits) 保留字段
+    header_extensions 扩展头(大小等于 8 * 4 * (header_size - 1) )
+    """
+    header = bytearray()
+    header_size = int(len(extension_header) / 4) + 1
+    header.append((version << 4) | header_size)
+    header.append((message_type << 4) | message_type_specific_flags)
+    header.append((serial_method << 4) | compression_type)
+    header.append(reserved_data)
+    header.extend(extension_header)
+    return header
+
+
+def generate_full_default_header():
+    return generate_header()
+
+
+def generate_audio_default_header():
+    return generate_header(
+        message_type=CLIENT_AUDIO_ONLY_REQUEST
+    )
+
+
+def generate_last_audio_default_header():
+    return generate_header(
+        message_type=CLIENT_AUDIO_ONLY_REQUEST,
+        message_type_specific_flags=NEG_SEQUENCE
+    )
+
+
+def parse_response(res):
+    """
+    protocol_version(4 bits), header_size(4 bits),
+    message_type(4 bits), message_type_specific_flags(4 bits)
+    serialization_method(4 bits) message_compression(4 bits)
+    reserved (8bits) 保留字段
+    header_extensions 扩展头(大小等于 8 * 4 * (header_size - 1) )
+    payload 类似于 http 请求体
+    """
+    protocol_version = res[0] >> 4
+    header_size = res[0] & 0x0f
+    message_type = res[1] >> 4
+    message_type_specific_flags = res[1] & 0x0f
+    serialization_method = res[2] >> 4
+    message_compression = res[2] & 0x0f
+    reserved = res[3]
+    header_extensions = res[4:header_size * 4]
+    payload = res[header_size * 4:]
+    result = {}
+    payload_msg = None
+    payload_size = 0
+    if message_type == SERVER_FULL_RESPONSE:
+        payload_size = int.from_bytes(payload[:4], "big", signed=True)
+        payload_msg = payload[4:]
+    elif message_type == SERVER_ACK:
+        seq = int.from_bytes(payload[:4], "big", signed=True)
+        result['seq'] = seq
+        if len(payload) >= 8:
+            payload_size = int.from_bytes(payload[4:8], "big", signed=False)
+            payload_msg = payload[8:]
+    elif message_type == SERVER_ERROR_RESPONSE:
+        code = int.from_bytes(payload[:4], "big", signed=False)
+        result['code'] = code
+        payload_size = int.from_bytes(payload[4:8], "big", signed=False)
+        payload_msg = payload[8:]
+    if payload_msg is None:
+        return result
+    if message_compression == GZIP:
+        payload_msg = gzip.decompress(payload_msg)
+    if serialization_method == JSON:
+        payload_msg = json.loads(str(payload_msg, "utf-8"))
+    elif serialization_method != NO_SERIALIZATION:
+        payload_msg = str(payload_msg, "utf-8")
+    result['payload_msg'] = payload_msg
+    result['payload_size'] = payload_size
+    return result
+
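As a quick reference, a worked example of the 4-byte binary header that generate_header() emits with its defaults, assuming the definitions above (illustrative only, not an extra test shipped in this patch):

    # (PROTOCOL_VERSION=1, header_size=1)    -> 0x11
    # (CLIENT_FULL_REQUEST=1, NO_SEQUENCE=0) -> 0x10
    # (JSON=1, GZIP=1)                       -> 0x11
    # reserved                               -> 0x00
    assert bytes(generate_header()) == b"\x11\x10\x11\x00"
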
+
+def read_wav_info(data: bytes = None):
+    with BytesIO(data) as _f:
+        wave_fp = wave.open(_f, 'rb')
+        nchannels, sampwidth, framerate, nframes = wave_fp.getparams()[:4]
+        wave_bytes = wave_fp.readframes(nframes)
+    return nchannels, sampwidth, framerate, nframes, len(wave_bytes)
+
+
+class AudioType(Enum):
+    LOCAL = 1  # 使用本地音频文件
+
+
+class AsrWsClient:
+    def __init__(self, audio_path, cluster, **kwargs):
+        """
+        :param audio_path: 本地音频文件路径
+        :param cluster: 火山引擎语音识别集群名
+        """
+        self.audio_path = audio_path
+        self.cluster = cluster
+        self.success_code = 1000  # success code, default is 1000
+        self.seg_duration = int(kwargs.get("seg_duration", 15000))
+        self.nbest = int(kwargs.get("nbest", 1))
+        self.appid = kwargs.get("appid", "")
+        self.token = kwargs.get("token", "")
+        self.ws_url = kwargs.get("ws_url", "wss://openspeech.bytedance.com/api/v2/asr")
+        self.uid = kwargs.get("uid", "streaming_asr_demo")
+        self.workflow = kwargs.get("workflow", "audio_in,resample,partition,vad,fe,decode,itn,nlu_punctuate")
+        self.show_language = kwargs.get("show_language", False)
+        self.show_utterances = kwargs.get("show_utterances", False)
+        self.result_type = kwargs.get("result_type", "full")
+        self.format = kwargs.get("format", "wav")
+        self.rate = kwargs.get("sample_rate", 16000)
+        self.language = kwargs.get("language", "zh-CN")
+        self.bits = kwargs.get("bits", 16)
+        self.channel = kwargs.get("channel", 1)
+        self.codec = kwargs.get("codec", "raw")
+        self.audio_type = kwargs.get("audio_type", AudioType.LOCAL)
+        self.secret = kwargs.get("secret", "access_secret")
+        self.auth_method = kwargs.get("auth_method", "token")
+        self.mp3_seg_size = int(kwargs.get("mp3_seg_size", 10000))
+
+    def construct_request(self, reqid):
+        req = {
+            'app': {
+                'appid': self.appid,
+                'cluster': self.cluster,
+                'token': self.token,
+            },
+            'user': {
+                'uid': self.uid
+            },
+            'request': {
+                'reqid': reqid,
+                'nbest': self.nbest,
+                'workflow': self.workflow,
+                'show_language': self.show_language,
+                'show_utterances': self.show_utterances,
+                'result_type': self.result_type,
+                "sequence": 1
+            },
+            'audio': {
+                'format': self.format,
+                'rate': self.rate,
+                'language': self.language,
+                'bits': self.bits,
+                'channel': self.channel,
+                'codec': self.codec
+            }
+        }
+        return req
+
+    @staticmethod
+    def slice_data(data: bytes, chunk_size: int):
+        """
+        slice data
+        :param data: wav data
+        :param chunk_size: the segment size in one request
+        :return: segment data, last flag
+        """
+        data_len = len(data)
+        offset = 0
+        while offset + chunk_size < data_len:
+            yield data[offset: offset + chunk_size], False
+            offset += chunk_size
+        else:
+            yield data[offset: data_len], True
+
+    def _real_processor(self, request_params: dict) -> dict:
+        pass
+
+    def token_auth(self):
+        return {'Authorization': 'Bearer; {}'.format(self.token)}
+
+    def signature_auth(self, data):
+        header_dicts = {
+            'Custom': 'auth_custom',
+        }
+
+        url_parse = urlparse(self.ws_url)
+        input_str = 'GET {} HTTP/1.1\n'.format(url_parse.path)
+        auth_headers = 'Custom'
+        for header in auth_headers.split(','):
+            input_str += '{}\n'.format(header_dicts[header])
+        input_data = bytearray(input_str, 'utf-8')
+        input_data += data
+        mac = base64.urlsafe_b64encode(
+            hmac.new(self.secret.encode('utf-8'), input_data, digestmod=sha256).digest())
+        header_dicts['Authorization'] = 'HMAC256; access_token="{}"; mac="{}"; h="{}"'.format(
+            self.token, str(mac, 'utf-8'), auth_headers)
+        return header_dicts
+
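A small sanity check of the segmentation math that execute() (below) feeds into slice_data(): with the default seg_duration of 15000 ms and 16 kHz, 16-bit, mono WAV data, each audio-only request carries about 480 KB of PCM. The numbers follow directly from the code; the snippet itself is only illustrative:

    nchannels, sampwidth, framerate = 1, 2, 16000    # mono, 16-bit, 16 kHz
    seg_duration = 15000                              # ms, AsrWsClient default
    size_per_sec = nchannels * sampwidth * framerate  # 32000 bytes/s
    segment_size = int(size_per_sec * seg_duration / 1000)
    assert segment_size == 480000
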
+    async def segment_data_processor(self, wav_data: bytes, segment_size: int):
+        reqid = str(uuid.uuid4())
+        # 构建 full client request,并序列化压缩
+        request_params = self.construct_request(reqid)
+        payload_bytes = str.encode(json.dumps(request_params))
+        payload_bytes = gzip.compress(payload_bytes)
+        full_client_request = bytearray(generate_full_default_header())
+        full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size(4 bytes)
+        full_client_request.extend(payload_bytes)  # payload
+        header = None
+        if self.auth_method == "token":
+            header = self.token_auth()
+        elif self.auth_method == "signature":
+            header = self.signature_auth(full_client_request)
+        async with websockets.connect(self.ws_url, extra_headers=header, max_size=1000000000) as ws:
+            # 发送 full client request
+            await ws.send(full_client_request)
+            res = await ws.recv()
+            result = parse_response(res)
+            if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code:
+                return result
+            for seq, (chunk, last) in enumerate(AsrWsClient.slice_data(wav_data, segment_size), 1):
+                # if no compression, comment this line
+                payload_bytes = gzip.compress(chunk)
+                audio_only_request = bytearray(generate_audio_default_header())
+                if last:
+                    audio_only_request = bytearray(generate_last_audio_default_header())
+                audio_only_request.extend((len(payload_bytes)).to_bytes(4, 'big'))  # payload size(4 bytes)
+                audio_only_request.extend(payload_bytes)  # payload
+                # 发送 audio-only client request
+                await ws.send(audio_only_request)
+                res = await ws.recv()
+                result = parse_response(res)
+                if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code:
+                    return result
+            return result
+
+    async def execute(self):
+        with open(self.audio_path, mode="rb") as _f:
+            data = _f.read()
+        audio_data = bytes(data)
+        if self.format == "mp3":
+            segment_size = self.mp3_seg_size
+            return await self.segment_data_processor(audio_data, segment_size)
+        if self.format != "wav":
+            raise Exception("format should be wav or mp3")
+        nchannels, sampwidth, framerate, nframes, wav_len = read_wav_info(
+            audio_data)
+        size_per_sec = nchannels * sampwidth * framerate
+        segment_size = int(size_per_sec * self.seg_duration / 1000)
+        return await self.segment_data_processor(audio_data, segment_size)
+
+
+class VolcengineASR(object):
+    def __init__(self, **kwargs) -> None:
+        self.appid = kwargs['appid']
+        self.token = kwargs['token']
+        self.cluster = kwargs['cluster']
+
+    def execute(self, path):
+        """
+        :param path: 本地音频文件路径
+        :return: 识别出的文本,失败时返回空字符串
+        """
+        audio_type = AudioType.LOCAL
+        text = ""
+        asr_ws_client = AsrWsClient(
+            audio_path=path,
+            cluster=self.cluster,
+            appid=self.appid,
+            token=self.token,
+            audio_type=audio_type,
+        )
+        try:
+            result = asyncio.run(asr_ws_client.execute())
+            if result['payload_msg']['code'] == 1000:
+                text = result["payload_msg"]["result"][0]["text"]
+        except Exception as e:
+            text = ""
+        return text
+
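A minimal usage sketch of the VolcengineASR wrapper above (credentials and the WAV path are placeholders; the AsrWsClient defaults assume 16 kHz, 16-bit, mono WAV input):

    if __name__ == "__main__":
        asr = VolcengineASR(appid="your-appid", token="your-token", cluster="your-cluster")
        # prints the recognized text, or "" if the request fails
        print(asr.execute("/path/to/record.wav"))
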
+
+class VolcengineTTS(object):
+    def __init__(self, appid, token, cluster, voice_type) -> None:
+        self.appid, self.token, self.cluster, self.voice_type = appid, token, cluster, voice_type
+
+    def execute(self, text):
+        api_url = "https://openspeech.bytedance.com/api/v1/tts"
+        header = {"Authorization": f"Bearer;{self.token}"}
+        request_json = {
+            "app": {
+                "appid": self.appid,
+                "token": self.token,
+                "cluster": self.cluster
+            },
+            "user": {
+                "uid": "388808087185088"
+            },
+            "audio": {
+                "voice_type": self.voice_type,
+                "encoding": "mp3",
+                "speed_ratio": 1.0,
+                "volume_ratio": 1.0,
+                "pitch_ratio": 1.0,
+            },
+            "request": {
+                "reqid": str(uuid.uuid4()),
+                "text": text,
+                "text_type": "plain",
+                "operation": "query",
+                "with_frontend": 1,
+                "frontend_type": "unitTson"
+            }
+        }
+        try:
+            resp = requests.post(api_url, json.dumps(request_json), headers=header)
+            if "data" in resp.json():
+                data = resp.json()["data"]
+                return base64.b64decode(data)
+        except Exception as e:
+            logging.exception("volcengine tts 合成请求失败: %s", e)
+        return None
\ No newline at end of file
diff --git a/static/default.yml b/static/default.yml
index 686c68b1..14c8b915 100755
--- a/static/default.yml
+++ b/static/default.yml
@@ -91,26 +91,28 @@ lru_cache:
 
 # 语音合成服务配置
 # 可选值:
-# han-tts - HanTTS
-# baidu-tts - 百度语音合成
-# xunfei-tts - 讯飞语音合成
-# ali-tts - 阿里语音合成
-# tencent-tts - 腾讯云语音合成
-# azure-tts - 微软语音合成
-# mac-tts - macOS 系统自带TTS(mac 系统推荐)
-# edge-tts - 基于 Edge 的 TTS(推荐)
-# VITS - 基于 VITS 的AI语音合成
+# han-tts        - HanTTS
+# baidu-tts      - 百度语音合成
+# xunfei-tts     - 讯飞语音合成
+# ali-tts        - 阿里语音合成
+# tencent-tts    - 腾讯云语音合成
+# azure-tts      - 微软语音合成
+# mac-tts        - macOS 系统自带TTS(mac 系统推荐)
+# edge-tts       - 基于 Edge 的 TTS(推荐)
+# VITS           - 基于 VITS 的AI语音合成
+# volcengine-tts - 火山引擎语音合成
 tts_engine: edge-tts
 
 # 语音识别服务配置
 # 可选值:
-# baidu-asr - 百度在线语音识别
-# xunfei-asr - 讯飞语音识别
-# ali-asr - 阿里语音识别
-# tencent-asr - 腾讯云语音识别(推荐)
-# azure-asr - 微软语音识别
-# openai - OpenAI Whisper
-# fun-asr - 达摩院FunASR语音识别
+# baidu-asr      - 百度在线语音识别
+# xunfei-asr     - 讯飞语音识别
+# ali-asr        - 阿里语音识别
+# tencent-asr    - 腾讯云语音识别(推荐)
+# azure-asr      - 微软语音识别
+# openai         - OpenAI Whisper
+# fun-asr        - 达摩院FunASR语音识别
+# volcengine-asr - 火山引擎语音识别
 asr_engine: baidu-asr
 
 # 百度语音服务
@@ -162,6 +164,13 @@ tencent_yuyin:
   voiceType: 0 # 0: 女声1;1:男生1;2:男生2
   language: 1 # 1: 中文;2:英文
 
+volcengine-asr:
+  # 在官网 https://www.volcengine.com/ 申请语音技术服务,有免费额度
+  appid: ''
+  token: ''
+  cluster: ''
+
+
 # 达摩院FunASR实时语音转写服务软件包
 fun_asr:
   # 导出模型流程:https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/runtime/python/libtorch#export-the-model
@@ -199,6 +208,16 @@ edge-tts:
   # 中文推荐 `zh` 开头的音色
   voice: zh-CN-XiaoxiaoNeural
 
+# 火山引擎TTS服务
+volcengine-tts:
+  # 在官网 https://www.volcengine.com/ 申请语音技术服务,有免费额度
+  # 音色列表 voice_type 有几十种,详见 https://www.volcengine.com/docs/6561/97465
+  # 免费额度 QPS 限制为 2,需要设置 tts_parallel: 2
+  appid: ""
+  token: ""
+  cluster: ""
+  voice_type: ""
+
 # 基于 VITS 的AI语音合成
 VITS:
   # 需要自行搭建vits-simple-api服务器:https://github.com/Artrajz/vits-simple-api
@@ -362,4 +381,4 @@ weather:
   enable: false
   key: '心知天气 API Key'
 
-
+tts_parallel: 5
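A minimal end-to-end sketch of the new TTS wrapper together with the config added above (all credential values are placeholders; execute() returns decoded MP3 bytes on success and None on failure). To route wukong-robot through it, set tts_engine: volcengine-tts, fill in the volcengine-tts section, and keep tts_parallel at 2 on the free tier:

    from robot.sdk.VolcengineSpeech import VolcengineTTS

    tts = VolcengineTTS(appid="your-appid", token="your-token",
                        cluster="your-cluster", voice_type="your-voice-type")
    audio = tts.execute("你好,这是火山引擎合成的语音。")
    if audio:
        with open("hello.mp3", "wb") as f:
            f.write(audio)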