tts : add OuteTTS support #10784

Merged on Dec 18, 2024 (45 commits).

Commits
89eaf50
server : add "tokens" output
ggerganov Dec 16, 2024
06e8540
server : output embeddings for all tokens when pooling = none
ggerganov Dec 17, 2024
1b18b2d
server : be explicit about the pooling type in the tests
ggerganov Dec 17, 2024
e65556f
server : do not normalize embeddings when there is no pooling
ggerganov Dec 17, 2024
f169965
llama : add OuteTTS support (wip)
ggerganov Dec 10, 2024
ff2ea75
wip
ggerganov Dec 10, 2024
aac7e04
extract features
ggerganov Dec 10, 2024
6ef1409
first conv
ggerganov Dec 10, 2024
5296c96
group norm
ggerganov Dec 10, 2024
3d08d62
resnet conv
ggerganov Dec 10, 2024
13dd894
resnet
ggerganov Dec 10, 2024
3046fde
attn
ggerganov Dec 10, 2024
435cfd7
pos net
ggerganov Dec 10, 2024
b3ba05e
layer norm
ggerganov Dec 10, 2024
fe6dd5a
convnext
ggerganov Dec 11, 2024
839035d
head
ggerganov Dec 11, 2024
eb1b70f
hann window
ggerganov Dec 11, 2024
a1f08ad
fix n_embd + remove llama.cpp hacks
ggerganov Dec 11, 2024
e728cfd
compute hann window
ggerganov Dec 11, 2024
5a1c98e
fft
ggerganov Dec 11, 2024
e527971
spectrum processing
ggerganov Dec 11, 2024
191da33
clean-up
ggerganov Dec 11, 2024
b9a011e
tts : receive input text and generate codes
ggerganov Dec 11, 2024
db61391
clip : fix new conv name
ggerganov Dec 11, 2024
8329e85
tts : minor fix
ggerganov Dec 11, 2024
d4fa34b
tts : add header + minor fixes
ggerganov Dec 11, 2024
2221e54
tts : add matchematical constant
ggerganov Dec 11, 2024
906a0ed
tts : fix sampling + cut initial noise
ggerganov Dec 11, 2024
1d7c27c
tts : fixes
ggerganov Dec 11, 2024
3d54be4
tts : update default samplers
ggerganov Dec 16, 2024
befdcd2
tts : text pre-processing
ggerganov Dec 16, 2024
e70f140
tts : outetts-voc -> wavtokenizer-dec
ggerganov Dec 16, 2024
c096bbd
tts : remove hardcoded constants
ggerganov Dec 16, 2024
d1ef627
tts : fix tensor shapes
ggerganov Dec 16, 2024
980d631
llama : refactor wavtokenizer tensors
ggerganov Dec 16, 2024
35259e5
cont
ggerganov Dec 16, 2024
2033fb7
cont [no ci]
ggerganov Dec 16, 2024
824fa75
llama : update WavTokenizer to non-causal attn
ggerganov Dec 17, 2024
d291c74
llama : handle no-vocab detokenization
ggerganov Dec 16, 2024
5038abe
tts : add Python example for OuteTTS (wip)
ggerganov Dec 17, 2024
edb7896
tts : extend python example to generate spectrogram
ggerganov Dec 17, 2024
2a1a6f6
server : fix rebase artifacts
ggerganov Dec 18, 2024
29df666
tts : enable "return_tokens" in Python example
ggerganov Dec 18, 2024
a95191c
tts : minor fixes
ggerganov Dec 18, 2024
c0df192
common : support HF download for vocoder
ggerganov Dec 18, 2024
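
The commit sequence above builds the WavTokenizer decoder graph piece by piece (convolutions, group norm, ResNet and attention blocks, posnet, ConvNeXt, head) and then the audio reconstruction path (Hann window, FFT, spectrum processing). As a rough illustration of the windowing step referenced by commits eb1b70f and e728cfd, a textbook periodic Hann window looks like the sketch below; this is not the code added in this PR.

```cpp
#include <cmath>
#include <vector>

// Sketch only: periodic Hann window of length n, as typically applied to each
// frame before the FFT when turning decoder features into a spectrogram.
static std::vector<float> hann_window(int n) {
    const float pi = 3.14159265358979323846f;
    std::vector<float> w(n);
    for (int i = 0; i < n; ++i) {
        w[i] = 0.5f * (1.0f - std::cos(2.0f * pi * (float) i / (float) n));
    }
    return w;
}
```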
60 changes: 44 additions & 16 deletions common/arg.cpp
@@ -119,29 +119,33 @@ std::string common_arg::to_string() {
// utils
//

static void common_params_handle_model_default(common_params & params) {
if (!params.hf_repo.empty()) {
static void common_params_handle_model_default(
std::string & model,
std::string & model_url,
std::string & hf_repo,
std::string & hf_file) {
if (!hf_repo.empty()) {
// short-hand to avoid specifying --hf-file -> default it to --model
if (params.hf_file.empty()) {
if (params.model.empty()) {
if (hf_file.empty()) {
if (model.empty()) {
throw std::invalid_argument("error: --hf-repo requires either --hf-file or --model\n");
}
params.hf_file = params.model;
} else if (params.model.empty()) {
hf_file = model;
} else if (model.empty()) {
// this is to avoid different repo having same file name, or same file name in different subdirs
std::string filename = params.hf_repo + "_" + params.hf_file;
std::string filename = hf_repo + "_" + hf_file;
// to make sure we don't have any slashes in the filename
string_replace_all(filename, "/", "_");
params.model = fs_get_cache_file(filename);
model = fs_get_cache_file(filename);
}
} else if (!params.model_url.empty()) {
if (params.model.empty()) {
auto f = string_split<std::string>(params.model_url, '#').front();
} else if (!model_url.empty()) {
if (model.empty()) {
auto f = string_split<std::string>(model_url, '#').front();
f = string_split<std::string>(f, '?').front();
params.model = fs_get_cache_file(string_split<std::string>(f, '/').back());
model = fs_get_cache_file(string_split<std::string>(f, '/').back());
}
} else if (params.model.empty()) {
params.model = DEFAULT_MODEL_PATH;
} else if (model.empty()) {
model = DEFAULT_MODEL_PATH;
}
}
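
For illustration, the branch above resolves a model path roughly as follows when only --hf-repo and --hf-file are given. The repo and file names below are hypothetical, and the snippet only mirrors the logic of the function; it is not part of the diff.

```cpp
// Hypothetical inputs -- any HF repo/file pair is handled the same way.
std::string hf_repo = "OuteAI/OuteTTS-0.2-500M-GGUF";
std::string hf_file = "OuteTTS-0.2-500M-Q8_0.gguf";
std::string model;   // empty -> derive a local cache path

// same steps as the function above (string_replace_all / fs_get_cache_file
// are the helpers declared in common.h)
std::string filename = hf_repo + "_" + hf_file;
string_replace_all(filename, "/", "_");
// filename == "OuteAI_OuteTTS-0.2-500M-GGUF_OuteTTS-0.2-500M-Q8_0.gguf"
model = fs_get_cache_file(filename);   // e.g. <cache dir>/OuteAI_OuteTTS-...
```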

@@ -276,7 +280,9 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n");
}

common_params_handle_model_default(params);
// TODO: refactor model params in a common struct
common_params_handle_model_default(params.model, params.model_url, params.hf_repo, params.hf_file);
common_params_handle_model_default(params.vocoder.model, params.vocoder.model_url, params.vocoder.hf_repo, params.vocoder.hf_file);

if (params.escape) {
string_process_escapes(params.prompt);
@@ -842,7 +848,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_sparam());
add_opt(common_arg(
{"--sampling-seq"}, "SEQUENCE",
{"--sampling-seq", "--sampler-seq"}, "SEQUENCE",
string_format("simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str()),
[](common_params & params, const std::string & value) {
params.sampling.samplers = common_sampler_types_from_chars(value);
@@ -1581,6 +1587,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.hf_file = value;
}
).set_env("LLAMA_ARG_HF_FILE"));
add_opt(common_arg(
{"-hfrv", "--hf-repo-v"}, "REPO",
"Hugging Face model repository for the vocoder model (default: unused)",
[](common_params & params, const std::string & value) {
params.vocoder.hf_repo = value;
}
).set_env("LLAMA_ARG_HF_REPO_V"));
add_opt(common_arg(
{"-hffv", "--hf-file-v"}, "FILE",
"Hugging Face model file for the vocoder model (default: unused)",
[](common_params & params, const std::string & value) {
params.vocoder.hf_file = value;
}
).set_env("LLAMA_ARG_HF_FILE_V"));
add_opt(common_arg(
{"-hft", "--hf-token"}, "TOKEN",
"Hugging Face access token (default: value from HF_TOKEN environment variable)",
@@ -2178,5 +2198,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_MODEL_DRAFT"));

add_opt(common_arg(
{"-mv", "--model-vocoder"}, "FNAME",
"vocoder model for audio generation (default: unused)",
[](common_params & params, const std::string & value) {
params.vocoder.model = value;
}
).set_examples({LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_SERVER}));

return ctx_arg;
}
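
Taken together, the new -mv/--model-vocoder, -hfrv/--hf-repo-v and -hffv/--hf-file-v options (plus their LLAMA_ARG_* environment variables) let the TTS example resolve two models through the same machinery. Below is a hedged sketch of how a caller might consume the parsed parameters; the vocoder fields come from this diff, while the surrounding code is illustrative only and assumes the usual common_params_parse() entry point.

```cpp
#include "common.h"

int main(int argc, char ** argv) {
    common_params params;
    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_TTS, nullptr)) {
        return 1;
    }
    // params.model         -> OuteTTS text-to-codes model  (-m,  or -hfr/-hff)
    // params.vocoder.model -> WavTokenizer decoder model   (-mv, or -hfrv/-hffv)
    // Both paths were already defaulted to cache locations by the
    // common_params_handle_model_default() calls shown above.
    return 0;
}
```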
7 changes: 4 additions & 3 deletions common/common.cpp
@@ -1095,7 +1095,7 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
#define CURL_MAX_RETRY 3
#define CURL_RETRY_DELAY_SECONDS 2

static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_attempts, int retry_delay_seconds) {
static bool curl_perform_with_retry(const std::string & url, CURL * curl, int max_attempts, int retry_delay_seconds) {
int remaining_attempts = max_attempts;

while (remaining_attempts > 0) {
@@ -1119,7 +1119,6 @@ static bool curl_perform_with_retry(const std::string& url, CURL* curl, int max_
}

static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) {

// Initialize libcurl
std::unique_ptr<CURL, decltype(&curl_easy_cleanup)> curl(curl_easy_init(), &curl_easy_cleanup);
if (!curl) {
@@ -1192,11 +1191,13 @@ static bool common_download_file(const std::string & url, const std::string & pa
std::string etag;
std::string last_modified;
};

common_load_model_from_url_headers headers;

{
typedef size_t(*CURLOPT_HEADERFUNCTION_PTR)(char *, size_t, size_t, void *);
auto header_callback = [](char * buffer, size_t /*size*/, size_t n_items, void * userdata) -> size_t {
common_load_model_from_url_headers *headers = (common_load_model_from_url_headers *) userdata;
common_load_model_from_url_headers * headers = (common_load_model_from_url_headers *) userdata;

static std::regex header_regex("([^:]+): (.*)\r\n");
static std::regex etag_regex("ETag", std::regex_constants::icase);
13 changes: 12 additions & 1 deletion common/common.h
@@ -80,6 +80,7 @@ enum llama_example {
LLAMA_EXAMPLE_LLAVA,
LLAMA_EXAMPLE_LOOKUP,
LLAMA_EXAMPLE_PARALLEL,
LLAMA_EXAMPLE_TTS,

LLAMA_EXAMPLE_COUNT,
};
@@ -159,6 +160,7 @@ struct common_params_sampling {

struct common_params_speculative {
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading

int32_t n_ctx = 0; // draft context size
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding
@@ -172,6 +174,14 @@ struct common_params_speculative {
std::string model = ""; // draft model for speculative decoding // NOLINT
};

struct common_params_vocoder {
std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT

std::string model = ""; // model path // NOLINT
std::string model_url = ""; // model url to download // NOLINT
};

struct common_params {
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 4096; // context size
@@ -214,8 +224,9 @@ struct common_params {
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings

struct common_params_sampling sampling;
struct common_params_sampling sampling;
struct common_params_speculative speculative;
struct common_params_vocoder vocoder;

std::string model = ""; // model path // NOLINT
std::string model_alias = ""; // model alias // NOLINT
59 changes: 52 additions & 7 deletions convert_hf_to_gguf.py
@@ -221,17 +221,17 @@ def set_gguf_parameters(self):
self.gguf_writer.add_context_length(n_ctx)
logger.info(f"gguf: context length = {n_ctx}")

n_embd = self.find_hparam(["hidden_size", "n_embd"])
self.gguf_writer.add_embedding_length(n_embd)
logger.info(f"gguf: embedding length = {n_embd}")
if (n_embd := self.find_hparam(["hidden_size", "n_embd"], optional=True)) is not None:
self.gguf_writer.add_embedding_length(n_embd)
logger.info(f"gguf: embedding length = {n_embd}")

if (n_ff := self.find_hparam(["intermediate_size", "n_inner"], optional=True)) is not None:
self.gguf_writer.add_feed_forward_length(n_ff)
logger.info(f"gguf: feed forward length = {n_ff}")

n_head = self.find_hparam(["num_attention_heads", "n_head"])
self.gguf_writer.add_head_count(n_head)
logger.info(f"gguf: head count = {n_head}")
if (n_head := self.find_hparam(["num_attention_heads", "n_head"], optional=True)) is not None:
self.gguf_writer.add_head_count(n_head)
logger.info(f"gguf: head count = {n_head}")

if (n_head_kv := self.hparams.get("num_key_value_heads")) is not None:
self.gguf_writer.add_head_count_kv(n_head_kv)
@@ -296,7 +296,9 @@ def prepare_tensors(self):
break

for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
data = data_torch.squeeze().numpy()
# TODO: why do we squeeze here?
# data = data_torch.squeeze().numpy()
data = data_torch.numpy()

# if data ends up empty, it means data_torch was a scalar tensor -> restore
if len(data.shape) == 0:
@@ -324,6 +326,8 @@ def prepare_tensors(self):
gguf.MODEL_TENSOR.TIME_MIX_W2,
gguf.MODEL_TENSOR.TIME_MIX_DECAY_W1,
gguf.MODEL_TENSOR.TIME_MIX_DECAY_W2,
gguf.MODEL_TENSOR.POSNET_NORM1,
gguf.MODEL_TENSOR.POSNET_NORM2,
)
)
or not new_name.endswith(".weight")
@@ -689,6 +693,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
return res
# Marker: End get_vocab_base_pre

def _set_vocab_none(self) -> None:
self.gguf_writer.add_tokenizer_model("none")

def _set_vocab_gpt2(self) -> None:
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
@@ -2027,6 +2034,44 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
yield name, data


@Model.register("WavTokenizerDec")
class WavTokenizerDecModel(Model):
model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

if \
name.endswith("codebook.cluster_size") or \
name.endswith("codebook.embed_avg") or \
name.endswith("codebook.inited"):
logger.debug(f"Skipping {name!r}")
return []

logger.info(f"{self.map_tensor_name(name)} -> {data_torch.shape}")

return [(self.map_tensor_name(name), data_torch)]

def set_vocab(self):
self._set_vocab_none()

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_vocab_size (self.hparams["vocab_size"])
self.gguf_writer.add_features_length (self.hparams["n_embd_features"])
self.gguf_writer.add_feed_forward_length(self.hparams["n_ff"])
self.gguf_writer.add_group_norm_eps (self.hparams["group_norm_epsilon"])
self.gguf_writer.add_group_norm_groups (self.hparams["group_norm_groups"])

self.gguf_writer.add_posnet_embedding_length(self.hparams["posnet"]["n_embd"])
self.gguf_writer.add_posnet_block_count (self.hparams["posnet"]["n_layer"])

self.gguf_writer.add_convnext_embedding_length(self.hparams["convnext"]["n_embd"])
self.gguf_writer.add_convnext_block_count (self.hparams["convnext"]["n_layer"])

self.gguf_writer.add_causal_attention(False)
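
On the C++ side, the metadata written here can be read back through the gguf API when loading the vocoder. The sketch below is illustrative only; in particular the key name is an assumption (arch-prefixed, mirroring add_posnet_embedding_length) and is not taken from this PR.

```cpp
#include <cstdint>
#include "ggml.h"   // gguf_* API (lives in gguf.h in later trees)

// Sketch only: read back one of the new WavTokenizer keys.
static uint32_t read_posnet_embd(const char * fname) {
    struct gguf_init_params gparams = { /*no_alloc =*/ true, /*ctx =*/ NULL };
    struct gguf_context * gctx = gguf_init_from_file(fname, gparams);
    if (!gctx) {
        return 0;
    }
    uint32_t n_embd_posnet = 0;
    // assumed key name, derived from the writer call above
    const int kid = gguf_find_key(gctx, "wavtokenizer-dec.posnet.embedding_length");
    if (kid >= 0) {
        n_embd_posnet = gguf_get_val_u32(gctx, kid);
    }
    gguf_free(gctx);
    return n_embd_posnet;
}
```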


@Model.register("Qwen2MoeForCausalLM")
class Qwen2MoeModel(Model):
model_arch = gguf.MODEL_ARCH.QWEN2MOE
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -51,6 +51,7 @@ else()
add_subdirectory(speculative)
add_subdirectory(speculative-simple)
add_subdirectory(tokenize)
add_subdirectory(tts)
add_subdirectory(gen-docs)
if (NOT GGML_BACKEND_DL)
# these examples use the backends directly and cannot be built with dynamic loading
6 changes: 3 additions & 3 deletions examples/llava/clip.cpp
@@ -896,7 +896,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
mlp_3 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_3, 1, 0, 2, 3));
mlp_3 = ggml_reshape_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]);
// stride = 1, padding = 1, bias is nullptr
block_1 = ggml_conv_depthwise_2d(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1);
block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1);

// layer norm
// // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
@@ -944,7 +944,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
// block_2
{
// stride = 2
block_1 = ggml_conv_depthwise_2d(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1);
block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1);

// block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
// layer norm
@@ -1005,7 +1005,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
// mlp_2 ne [24, 24, 2048, 1]
mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0);
// weight ne = [3, 3, 2048, 1]
struct ggml_tensor * peg_0 = ggml_conv_depthwise_2d(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1);
struct ggml_tensor * peg_0 = ggml_conv_2d_dw(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1);
peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3));
peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b);
mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3));
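
The clip.cpp hunks are a pure rename from ggml_conv_depthwise_2d to the new ggml_conv_2d_dw name (commit db61391); the arguments are unchanged. A minimal sketch of the assumed call shape, using the kernel/activation shapes from the comments above:

```cpp
#include "ggml.h"

// Sketch only: depthwise 2D conv with the argument order assumed from the call
// sites above: (ctx, kernel, input, s0, s1, p0, p1, d0, d1).
static struct ggml_tensor * dw_conv_example(struct ggml_context * ctx0) {
    // hypothetical shapes: 3x3 depthwise kernel over a 24x24x2048 activation
    struct ggml_tensor * kernel = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32,  3,  3, 2048, 1);
    struct ggml_tensor * x      = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 24, 24, 2048, 1);
    return ggml_conv_2d_dw(ctx0, kernel, x, /*s0*/ 1, /*s1*/ 1, /*p0*/ 1, /*p1*/ 1, /*d0*/ 1, /*d1*/ 1);
}
```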
5 changes: 5 additions & 0 deletions examples/tts/CMakeLists.txt
@@ -0,0 +1,5 @@
set(TARGET llama-tts)
add_executable(${TARGET} tts.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)