From e5101859b48ba429c6bea9a5af370fa984ea1a4a Mon Sep 17 00:00:00 2001
From: Georg Kolling
Date: Fri, 29 Sep 2023 16:20:06 +0100
Subject: [PATCH 1/2] [skip-CI] CRFModel.cpp (Koi codepath): update needed for
 fixes in koi window kernel. Support conv2 with in_size > 16. Needs Koi
 version bump

---
 dorado/nn/CRFModel.cpp | 45 +++++++++++++++++++++++++--------------------
 1 file changed, 25 insertions(+), 20 deletions(-)

diff --git a/dorado/nn/CRFModel.cpp b/dorado/nn/CRFModel.cpp
index 632915248..57202322d 100644
--- a/dorado/nn/CRFModel.cpp
+++ b/dorado/nn/CRFModel.cpp
@@ -263,23 +263,26 @@ struct ConvolutionImpl : Module {
         int64_t batch_size = wm.current_sizes[0];
         int64_t chunk_size_in = wm.current_sizes[1];
         int64_t chunk_size_out = chunk_size_in / stride;
-        if (next_layer_is_lstm) {
-            switch (auto lstm_mode = get_cuda_lstm_mode(0, out_size)) {
+        if (next_layer_is_lstm || in_size > 16) {
+            // For conv2 with in_size > 16 we can use the same codepath as QUANTISED_NTC
+            LstmMode lstm_mode =
+                    next_layer_is_lstm ? get_cuda_lstm_mode(0, out_size) : LstmMode::QUANTISED_NTC;
+            switch (lstm_mode) {
             case LstmMode::CUTLASS_TNC_I8:
-                wm.reserve({chunk_size_out, batch_size, in_size, window_size}, torch::kF16);
+                wm.reserve({chunk_size_out, batch_size, window_size, in_size}, torch::kF16);
                 wm.reserve({chunk_size_out * batch_size, out_size}, torch::kF16);
                 wm.reserve({chunk_size_out + 3, batch_size, out_size}, torch::kI8);
                 break;
             case LstmMode::QUANTISED_NTC:
-                wm.reserve({batch_size, chunk_size_out, in_size, window_size}, torch::kF16);
+                wm.reserve({batch_size, chunk_size_out, window_size, in_size}, torch::kF16);
                 wm.reserve({batch_size, chunk_size_out, out_size}, torch::kF16);
                 break;
             case LstmMode::CUTLASS_TNC_F16:
-                wm.reserve({chunk_size_out, batch_size, in_size, window_size}, torch::kF16);
+                wm.reserve({chunk_size_out, batch_size, window_size, in_size}, torch::kF16);
                 wm.reserve({chunk_size_out + 3, batch_size, out_size}, torch::kF16);
                 break;
             case LstmMode::CUBLAS_TN2C:
-                wm.reserve({chunk_size_out, batch_size, in_size, window_size}, torch::kF16);
+                wm.reserve({chunk_size_out, batch_size, window_size, in_size}, torch::kF16);
                 wm.reserve({chunk_size_out + 1, batch_size, 2, out_size}, torch::kF16);
                 break;
             default:
@@ -300,27 +303,29 @@ struct ConvolutionImpl : Module {
         int chunk_size_out = chunk_size_in / stride;
 
         // TODO: make device weights permanent?
-        auto w_device = conv->weight.view({out_size, in_size * window_size})
-                                .t()
+        // conv->weight is [C_out, C_in, W], we want [W, C_in, C_out]
+        auto w_device = conv->weight.permute({2, 1, 0})
+                                .contiguous()
                                 .to(in.options())
-                                .contiguous();
+                                .view({window_size * in_size, out_size});
         auto b_device = conv->bias.to(in.options());
 
-        if (next_layer_is_lstm) {
-            auto lstm_mode = get_cuda_lstm_mode(0, out_size);
-            torch::Tensor ntcw_mat, tncw_mat;
+        if (next_layer_is_lstm || in_size > 16) {
+            // For conv2 with in_size > 16 we can use the same codepath as QUANTISED_NTC
+            LstmMode lstm_mode =
+                    next_layer_is_lstm ? get_cuda_lstm_mode(0, out_size) : LstmMode::QUANTISED_NTC;
+            torch::Tensor ntwc_mat, tnwc_mat;
             if (lstm_mode == LstmMode::QUANTISED_NTC) {
-                ntcw_mat = wm.next({batch_size, chunk_size_out, in_size, window_size}, torch::kF16);
+                ntwc_mat = wm.next({batch_size, chunk_size_out, window_size, in_size}, torch::kF16);
             } else {
-                tncw_mat = wm.next({chunk_size_out, batch_size, in_size, window_size}, torch::kF16);
-                ntcw_mat = tncw_mat.transpose(0, 1);
+                tnwc_mat = wm.next({chunk_size_out, batch_size, window_size, in_size}, torch::kF16);
+                ntwc_mat = tnwc_mat.transpose(0, 1);
             }
-            host_window_ntcw_f16(stream, in.stride(0), in.stride(1), in.stride(2), batch_size,
-                                 chunk_size_in, in_size, window_size, stride, ntcw_mat.stride(0),
-                                 ntcw_mat.stride(1), ntcw_mat.stride(2), ntcw_mat.stride(3),
-                                 in.data_ptr(), ntcw_mat.data_ptr());
+            host_window_ntwc_f16(stream, batch_size, chunk_size_in, in_size, window_size, stride,
+                                 ntwc_mat.stride(0), ntwc_mat.stride(1), in.data_ptr(),
+                                 ntwc_mat.data_ptr());
 
-            auto mm_in = wm.current.view({-1, in_size * window_size});
+            auto mm_in = wm.current.view({-1, window_size * in_size});
             torch::Tensor mm_out, out;
             if (lstm_mode == LstmMode::QUANTISED_NTC) {
                 // Output is [N, T_out, C_out], F16

From c6f03efeaca2f5701419f94d31b4319de658c050 Mon Sep 17 00:00:00 2001
From: Georg Kolling
Date: Thu, 12 Oct 2023 11:20:47 +0100
Subject: [PATCH 2/2] Bump Koi version to 0.4.0

---
 cmake/Koi.cmake | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cmake/Koi.cmake b/cmake/Koi.cmake
index 20147b60d..5122dd589 100644
--- a/cmake/Koi.cmake
+++ b/cmake/Koi.cmake
@@ -20,7 +20,7 @@ endfunction()
 
 
 if(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR WIN32)
-    set(KOI_VERSION 0.3.9)
+    set(KOI_VERSION 0.4.0)
     if(BUILD_KOI_FROM_SOURCE)
        message(STATUS "Building Koi from source")
        set(KOI_DIR "${DORADO_3RD_PARTY_DOWNLOAD}/koi")
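
Note, not part of the patches above: PATCH 1/2 switches the windowed convolution input from NTCW to NTWC order, so that the flattened window [W, C_in] lines up with the weight rearranged to [W, C_in, C_out] (the new w_device). The following is a minimal, CPU-only libtorch sketch of that layout argument. It is an illustration under simplifying assumptions, not dorado code: it skips padding and the Koi host_window_ntwc_f16 kernel, and the names x, w, win and w_mat are invented for the example.

// ntwc_layout_check.cpp -- standalone sketch, requires only libtorch.
#include <torch/torch.h>

#include <iostream>

int main() {
    const int64_t N = 2, C_in = 16, C_out = 96, T_in = 32, W = 5, stride = 1;
    const int64_t T_out = (T_in - W) / stride + 1;  // no padding in this sketch

    auto x = torch::randn({N, C_in, T_in});   // conv input  [N, C_in, T_in]
    auto w = torch::randn({C_out, C_in, W});  // conv weight [C_out, C_in, W]
    auto b = torch::randn({C_out});

    // Reference convolution.
    auto ref = torch::conv1d(x, w, b, stride);  // [N, C_out, T_out]

    // Windowed input in NTWC order: win[n][t][w][c] = x[n][c][t * stride + w].
    auto win = torch::empty({N, T_out, W, C_in});
    for (int64_t t = 0; t < T_out; ++t) {
        win.select(1, t).copy_(x.narrow(2, t * stride, W).permute({0, 2, 1}));
    }

    // Weight rearranged to [W, C_in, C_out] and flattened, mirroring w_device in the patch.
    auto w_mat = w.permute({2, 1, 0}).contiguous().view({W * C_in, C_out});

    // One GEMM over all flattened windows, plus bias.
    auto out = torch::addmm(b, win.view({-1, W * C_in}), w_mat).view({N, T_out, C_out});

    // The NTWC flattening matches the [W, C_in, C_out] weight order, so this prints 1.
    std::cout << torch::allclose(ref.transpose(1, 2), out, 1e-4, 1e-4) << std::endl;
    return 0;
}

The patch performs the same windowing on the GPU through Koi's host_window_ntwc_f16, which, per its name, fills the working-memory buffer directly in this N, T, W, C order before the single GEMM against w_device.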