
Merge branch 'koi-convolution-fixes' into 'master'
CRFModel (Koi codepath): update needed for fixes in koi window kernel

See merge request machine-learning/dorado!616
GKolling committed Oct 13, 2023
2 parents 02b383e + c6f03ef commit 6be61d0
Showing 2 changed files with 26 additions and 21 deletions.
cmake/Koi.cmake (2 changes: 1 addition & 1 deletion)
@@ -20,7 +20,7 @@ endfunction()
 
 if(CMAKE_SYSTEM_NAME STREQUAL "Linux" OR WIN32)
 
-    set(KOI_VERSION 0.3.9)
+    set(KOI_VERSION 0.4.0)
     if(BUILD_KOI_FROM_SOURCE)
         message(STATUS "Building Koi from source")
         set(KOI_DIR "${DORADO_3RD_PARTY_DOWNLOAD}/koi")
dorado/nn/CRFModel.cpp (45 changes: 25 additions & 20 deletions)
@@ -263,23 +263,26 @@ struct ConvolutionImpl : Module {
         int64_t batch_size = wm.current_sizes[0];
         int64_t chunk_size_in = wm.current_sizes[1];
         int64_t chunk_size_out = chunk_size_in / stride;
-        if (next_layer_is_lstm) {
-            switch (auto lstm_mode = get_cuda_lstm_mode(0, out_size)) {
+        if (next_layer_is_lstm || in_size > 16) {
+            // For conv2 with in_size > 16 we can use the same codepath as QUANTISED_NTC
+            LstmMode lstm_mode =
+                    next_layer_is_lstm ? get_cuda_lstm_mode(0, out_size) : LstmMode::QUANTISED_NTC;
+            switch (lstm_mode) {
             case LstmMode::CUTLASS_TNC_I8:
-                wm.reserve({chunk_size_out, batch_size, in_size, window_size}, torch::kF16);
+                wm.reserve({chunk_size_out, batch_size, window_size, in_size}, torch::kF16);
                 wm.reserve({chunk_size_out * batch_size, out_size}, torch::kF16);
                 wm.reserve({chunk_size_out + 3, batch_size, out_size}, torch::kI8);
                 break;
             case LstmMode::QUANTISED_NTC:
-                wm.reserve({batch_size, chunk_size_out, in_size, window_size}, torch::kF16);
+                wm.reserve({batch_size, chunk_size_out, window_size, in_size}, torch::kF16);
                 wm.reserve({batch_size, chunk_size_out, out_size}, torch::kF16);
                 break;
             case LstmMode::CUTLASS_TNC_F16:
-                wm.reserve({chunk_size_out, batch_size, in_size, window_size}, torch::kF16);
+                wm.reserve({chunk_size_out, batch_size, window_size, in_size}, torch::kF16);
                 wm.reserve({chunk_size_out + 3, batch_size, out_size}, torch::kF16);
                 break;
             case LstmMode::CUBLAS_TN2C:
-                wm.reserve({chunk_size_out, batch_size, in_size, window_size}, torch::kF16);
+                wm.reserve({chunk_size_out, batch_size, window_size, in_size}, torch::kF16);
                 wm.reserve({chunk_size_out + 1, batch_size, 2, out_size}, torch::kF16);
                 break;
             default:
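
The reserve-shape changes in this hunk swap the last two dimensions of the windowed input buffer from NTCW ([N, T_out, C_in, W]) to NTWC ([N, T_out, W, C_in]), the layout the updated Koi window kernel writes: in row-major storage the input channel becomes the fastest-varying index, so each (n, t) slice flattens into one contiguous [W * C_in] row. A minimal sketch of the resulting flat offset, assuming standard row-major strides (the helper name is hypothetical, not part of Koi or dorado):

    #include <cstdint>

    // Flat row-major offset into an NTWC buffer of shape [N, T_out, W, C_in].
    // All C_in channels of one window position are contiguous, so each
    // (n, t) slice flattens to a [W * C_in] row ready for a GEMM.
    int64_t ntwc_offset(int64_t n, int64_t t, int64_t w, int64_t c,
                        int64_t T_out, int64_t W, int64_t C_in) {
        return ((n * T_out + t) * W + w) * C_in + c;
    }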
@@ -300,27 +303,29 @@ struct ConvolutionImpl : Module {
         int chunk_size_out = chunk_size_in / stride;
 
         // TODO: make device weights permanent?
-        auto w_device = conv->weight.view({out_size, in_size * window_size})
-                                .t()
+        // conv->weight is [C_out, C_in, W], we want [W, C_in, C_out]
+        auto w_device = conv->weight.permute({2, 1, 0})
+                                .contiguous()
                                 .to(in.options())
-                                .contiguous();
+                                .view({window_size * in_size, out_size});
         auto b_device = conv->bias.to(in.options());
 
-        if (next_layer_is_lstm) {
-            auto lstm_mode = get_cuda_lstm_mode(0, out_size);
-            torch::Tensor ntcw_mat, tncw_mat;
+        if (next_layer_is_lstm || in_size > 16) {
+            // For conv2 with in_size > 16 we can use the same codepath as QUANTISED_NTC
+            LstmMode lstm_mode =
+                    next_layer_is_lstm ? get_cuda_lstm_mode(0, out_size) : LstmMode::QUANTISED_NTC;
+            torch::Tensor ntwc_mat, tnwc_mat;
             if (lstm_mode == LstmMode::QUANTISED_NTC) {
-                ntcw_mat = wm.next({batch_size, chunk_size_out, in_size, window_size}, torch::kF16);
+                ntwc_mat = wm.next({batch_size, chunk_size_out, in_size, window_size}, torch::kF16);
             } else {
-                tncw_mat = wm.next({chunk_size_out, batch_size, in_size, window_size}, torch::kF16);
-                ntcw_mat = tncw_mat.transpose(0, 1);
+                tnwc_mat = wm.next({chunk_size_out, batch_size, in_size, window_size}, torch::kF16);
+                ntwc_mat = tnwc_mat.transpose(0, 1);
             }
-            host_window_ntcw_f16(stream, in.stride(0), in.stride(1), in.stride(2), batch_size,
-                                 chunk_size_in, in_size, window_size, stride, ntcw_mat.stride(0),
-                                 ntcw_mat.stride(1), ntcw_mat.stride(2), ntcw_mat.stride(3),
-                                 in.data_ptr(), ntcw_mat.data_ptr());
+            host_window_ntwc_f16(stream, batch_size, chunk_size_in, in_size, window_size, stride,
+                                 ntwc_mat.stride(0), ntwc_mat.stride(1), in.data_ptr(),
+                                 ntwc_mat.data_ptr());
 
-            auto mm_in = wm.current.view({-1, in_size * window_size});
+            auto mm_in = wm.current.view({-1, window_size * in_size});
             torch::Tensor mm_out, out;
             if (lstm_mode == LstmMode::QUANTISED_NTC) {
                 // Output is [N, T_out, C_out], F16
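
Taken together, the forward-pass changes window the input into NTWC order and feed it to a GEMM against the conv weight reshaped to [W * C_in, C_out], so the flattened operands line up as [N * T_out, W * C_in] x [W * C_in, C_out]. Below is a self-contained libtorch sketch of that equivalence, using Tensor::unfold in place of Koi's host_window_ntwc_f16 kernel and F32 on CPU rather than F16 on GPU; the sizes, stride, and padding are assumptions chosen for illustration:

    #include <torch/torch.h>
    #include <iostream>

    int main() {
        const int64_t N = 2, C_in = 4, C_out = 16, T = 12, W = 5, stride = 2;
        auto in = torch::randn({N, C_in, T});  // NCT, as Conv1d expects

        torch::nn::Conv1d conv(
                torch::nn::Conv1dOptions(C_in, C_out, W).stride(stride).padding(W / 2));
        auto ref = conv(in);  // [N, C_out, T_out]

        // NTWC windowing: pad, take a [W] window per output step, then move the
        // window dimension ahead of the channel dimension.
        auto padded = torch::constant_pad_nd(in, {W / 2, W / 2});  // [N, C_in, T + 2*(W/2)]
        auto ntwc = padded.unfold(2, W, stride)     // [N, C_in, T_out, W]
                            .permute({0, 2, 3, 1})  // [N, T_out, W, C_in]
                            .contiguous();

        // Weight [C_out, C_in, W] -> permute to [W, C_in, C_out] -> view as
        // [W * C_in, C_out], matching the flattened NTWC rows.
        auto w_mat = conv->weight.permute({2, 1, 0}).contiguous().view({W * C_in, C_out});
        auto out = ntwc.view({N, -1, W * C_in}).matmul(w_mat) + conv->bias;  // [N, T_out, C_out]

        std::cout << (out.transpose(1, 2).allclose(ref, 1e-4, 1e-4) ? "match" : "mismatch")
                  << std::endl;
        return 0;
    }

In the diff, the host_window_ntwc_f16 call plays the role of the pad/unfold/permute steps, filling the windowed buffer directly on the GPU stream, which is why it only needs the batch and time strides of the output tensor.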
