From 6d586330d9b6d2a2f487d7985e5d586e22b1e47a Mon Sep 17 00:00:00 2001 From: chqy99 <1216494776@qq.com> Date: Tue, 19 Nov 2024 14:11:06 +0800 Subject: [PATCH 1/4] [Feature](mluOpRoiAlignRotatedForward): bin_cycle vector --- .../roi_align_rotated/roi_align_rotated.md | 6 +- .../roi_align_rotated_forward_vector.md | 435 +++++++++++++ .../roi_align_rotated/roi_align_rotated.cpp | 53 +- kernels/roi_align_rotated/roi_align_rotated.h | 15 +- .../roi_align_rotated_block.mlu | 4 +- .../roi_align_rotated_forward_vector.mlu | 605 ++++++++++++++++++ kernels/utils/common.h | 16 +- 7 files changed, 1095 insertions(+), 39 deletions(-) create mode 100644 docs/design_docs/roi_align_rotated/roi_align_rotated_forward_vector.md create mode 100644 kernels/roi_align_rotated/roi_align_rotated_forward_vector.mlu diff --git a/docs/design_docs/roi_align_rotated/roi_align_rotated.md b/docs/design_docs/roi_align_rotated/roi_align_rotated.md index 81dfbdf9a..2053a5e27 100644 --- a/docs/design_docs/roi_align_rotated/roi_align_rotated.md +++ b/docs/design_docs/roi_align_rotated/roi_align_rotated.md @@ -37,7 +37,7 @@ roi_align_rotated算子应用于FOTS网络结构中,以双线性插值的方 |是否需要支持原位 | 否 | | 是否需要支持stride机制 | 否 | | 是否需要支持广播 | 否 | -| 0元素检查是否直接返回 | 是 | | +| 0元素检查是否直接返回 | 是 | ### 1.2 算子功能和应用场景描述 @@ -208,7 +208,7 @@ mluOpStatus_t MLUOP_WIN_API mluOpRoiAlignRotatedBackward(mluOpHandle_t handle, |---------|---------|---------|---------| | input | output | input | output | |---------|---------|---------|---------| - + ``` 与前向类似,反向计算时把空间均分为4部分,保持每次处理的input和output的大小相同。 ### 3.4 性能优化设计 @@ -252,7 +252,7 @@ RoiAlign类的算子是IO瓶颈,在一个bin中需要处理多个采样点, 6、output的HW维度需要分别与参数中的pooled_height和pooled_width保持一致。 反向: - + 1、指针为空防呆; 2、0元素检查防呆,VLOG(5)打印信息; diff --git a/docs/design_docs/roi_align_rotated/roi_align_rotated_forward_vector.md b/docs/design_docs/roi_align_rotated/roi_align_rotated_forward_vector.md new file mode 100644 index 000000000..ccd3e4611 --- /dev/null +++ 
b/docs/design_docs/roi_align_rotated/roi_align_rotated_forward_vector.md @@ -0,0 +1,435 @@ +# roi_align_rotated_forward 向量化实现设计方案 + +* #### 文档基本信息 + +| 算子名称 | roi_align_rotated_forward 向量化实现 | +| ---------- | ---------------------------------- | +| 编制人/日期 | 陈其阳/2024-11-4 | + +* #### 修改记录 +| 版本号 | 修订人 | 修订日期 | 修订描述 | +| ----- | ------ | ------- | ------- | +| V 1.0 | 陈其阳 | 2024-11-4 | 首次提交 | + +* #### 内容描述 + +本文档为`roi_align_rotated_forward`算子向量化实现的设计文档,包括需求分析、接口设计、方案设计、性能优化记录。 + +## 1 需求分析 + +### 1.1 算子需求分析 + +| 算子功能简介| 以双线性插值的方式提取非整数大小且带有旋转的roi的特征图| +|-------------|--------------------------------------------------------------| +| 需求来源 | mmcv | +| 应用网络 | FOTS | +| 输入数据类型| half, float | +| 输入Shape | input1: [batch, hi, wi, channels]; input2: [roi_nums, 6] | +| 输入Layout | input1: NHWC; input2: ARRAY | +| 输出数据类型| half, float | +| 输出Shape | [roi_nums, ho, wo, channels] | +| 输出Layout | NHWC | +|是否含有dim/axis等类似语义的参数且该参数支持负数/其他特殊处理 | 否| +|是否含有labels/index等类似语义的参数且该参数支持负数/界外情况/其他特殊处理 | 否| +|是否需要支持原位 | 否 | +| 是否需要支持stride机制 | 否 | +| 是否需要支持广播 | 否 | +| 0元素检查是否直接返回 | 是 | + +### 1.2 算子功能和应用场景描述 + +在FOTS网络中,roi_align_rotated算子用于统一检测和识别到端到端的pipeline中,输入检测分支中得到的带有旋转角度的bounding boxes,提取对应的特征图用于后续的识别。 + +![通道的二维展开](fots_framework.png) + +### 1.3 算子输入输出参数要求 +#### 1.3.1 roi_align_rotated_forward + +| 参数 | 语义 | 类型(输入/输出) | 支持类型 | 物理布局 | 规模限制 | +| ------------- | ---- | ----------------- | ----------- | -------- | -------- | +| handle | MLUOP句柄,保存运行的上下文信息 | 输入 | | / | 无 | +| features_desc | 输入特征图数据的描述信息 | 输入 | | / | features的维度必须是4 | +| features | 输入数据,指向输入特征图数据的mlu首地址 | 输入 | half, float | NHWC | 无 | +| rois_desc | roi数据的描述信息 | 输入 | | / | rois的维度必须是2,且第二维的大小必须是6 | +| rois | 输入数据,指向rois的mlu地址 | 输入 | half, float | ARRAY | 无 | +| pooled_height | 输出output的height | 输入 | int | / | 无 | +| pooled_width | 输出output的width | 输入 | int | / | 无 | +| sample_ratio | 一个bin的采样率 | 输入 | int | / | 无 | +| spatial_scale | rois在feature map上的缩放比例 | 输入 | float | / | 无 | +| aligned | 决定rois中的像素是否需要偏移 | 输入 | bool | 
/ | 无 | +| clockwise | 是否顺时针旋转 | 输入 | bool | / | 无 | +| output_desc | 输出数据的描述信息 | 输入 | | / | output的维度必须是4,且第一维大小与rois的第一维大小一致,第二维大小与pooled_height一致,第三维大小与pooled_width一致,第四维大小与features的第四维大小一致 | +| output | 指向输出数据的mlu首地址 | 输出 | half, float | NHWC | 无 | + +### 1.4 算子限制 +#### 1.4.1 roi_align_rotated_forward +- rois是一个二维的Tensor,其中第一维与output的第一维相同,最后一维必须等于6。每个roi包含(batch_id,x,y, w, h, θ),其中,x和y表示的是roi中心点的坐标,w和h分别是roi的宽和高,θ表示边框逆时针旋转的角度。 + +- rois中batch_id的值在[0, batch-1]范围内,其中batch是features的第一维的大小,rois中参数x,y,w和h与spatial_scale的乘积值不能超过参数类型可表示的范围;rois中包含NaN和infinity数据时,只有x和y支持infinity数据,其它都不支持。 + +- output的最高维与rois的最高维相等,最后一维大小与features的最后一维相等。 + +- features, rois, output数据类型要相同。 + +### 1.5 验收标准 + +#### 1.5.1 精度验收标准 + +- 采用动态阈值: + diff=[diff1, diff2], threshold_rate=[10, 10]。 + +## 2 算子接口设计 + +### 2.1 参考接口 + +- MMCV +```c++ +// forward +template +__global__ void roi_align_rotated_forward_cuda_kernel( + const int nthreads, const scalar_t *bottom_data, + const scalar_t *bottom_rois, const scalar_t spatial_scale, + const int sample_num, const bool aligned, const bool clockwise, + const int channels, const int height, const int width, + const int pooled_height, const int pooled_width, scalar_t *top_data); +``` + +### 2.2 接口设计 + +```c++ +// forward +mluOpStatus_t MLUOP_WIN_API mluOpRoiAlignRotatedForward(mluOpHandle_t handle, + const mluOpTensorDescriptor_t features_desc, + const void *features, + const mluOpTensorDescriptor_t rois_desc, + const void *rois, + const int pooled_height, + const int pooled_width, + const int sample_ratio, + const float spatial_scale, + const bool aligned, + const bool clockwise, + const mluOpTensorDescriptor_t output_desc, + void *output); +``` + +## 3 实现方案设计 + +### 3.1 实现方案 +在支持 gather.vector 后,相比于一次处理单个输出点,一次处理多个输出点的方案有更高的 IO 效率,能够显著提升性能。
+ + +该算子的MMCV 实现有 6 层循环,分别是[roi_nums, pooled_height, pooled_width, channels, roi_bin_grid_h, roi_bin_grid_w]。
+其中[roi_nums, pooled_height, pooled_width, channels]的一个点(bin)对应一个线程,每个线程中的循环是[roi_bin_grid_h, roi_bin_grid_w]。
+该算子涉及大量的坐标计算,坐标计算与 channels 维度无关,因此将该维度放置于最后处理,可以复用坐标信息。
+注:曾考虑将 [roi_bin_grid_h, roi_bin_grid_w] 循环放置在 [pooled_height, pooled_width] 之前,但单个 bin 内存在累加计算,若把该累加拆分给多个计算单元,就需要 atomic_add,且控制逻辑会更复杂。
+ + +MLU 将6层循环分成4个部分,第一循环[roi_nums, pooled_height, pooled_width],第二循环[channels](channels累加结果进行缓存),第三循环[roi_bin_grid_h, roi_bin_grid_w],第四循环[channels_cache](复用坐标信息)。
+ + +(1)第一循环 +- 核间拆分,记第一循环中总数量为 n1,使用最简单拆分逻辑即可。 +```c++ + for (uint32_t i = taskId; i < n1; i+=taskDim) +``` +- [roi_info 计算](#321-roi_info-计算),roi_info 信息只与 roi_idx 有关,若roi_idx 与上次相同时,则跳过该步。
+ + +(2)第二循环 +- 由于 NRAM 大小限制,设置 channels 缓存量为1024,如果 channels 大于 1024,会进行循环。记 channels 缓存空间为 output_channels。
+- 在第三循环开始前,对 output_channels 刷 0。
+- 在第三循环结束后,将 output_channels 存到 output。(前两层循环与 output 一一对应)
+注:channels 循环位于坐标计算的外层,每次 channels 循环都会重复计算坐标;channels 过大时,gather.vector 需要执行多次,IO 开销增大。即 channels 过大时,向量化实现的性能会下降。
+ + +(3)第三循环(向量化的重点) +- h_idx, w_idx 序列构造
+roi_bin_grid_h,roi_bin_grid_w 对应两个方向的采样率。
+sample_ratio>0 时,采样率都等于 sample_ratio,通常值大小在 2~9 之间。
+sample_ratio<=0 时,采样率为 roi_height(width)/pooled_height(width) 向上取整,参考 roi_align,采样率范围通常在30以内。
+根据采样率,向上选择序列长度(包括:8,16,32),序列长度记为 bin_order_num。为方便起见,h,w方向处理的数量一致。如果采样率超过32,则需要循环处理。
+使用vv.index构造自增序列(从 .5f 开始,步长为 1),该序列会缓存,方便多次构建二维序列。
+二维序列 w_idx 构造,相当于长度为 roi_bin_grid_w 自增序列复制 roi_bin_grid_h 次,使用 stride=0 的 memcpy2d 可以实现。
+二维序列 h_idx 构造,相当于长度为 roi_bin_grid_h 自增序列每个位置扩展至 roi_bin_grid_w 次,使用 __extension 实现。
+- 计算 x, y 序列
+先计算 h_idx_in_bin, w_idx_in_bin。
+h_idx_in_bin = roi_start_h + ph * bin_size_h + (h_idx + .5f) * bin_size_h / roi_bin_grid_h;
+w_idx_in_bin = roi_start_w + pw * bin_size_w + (w_idx + .5f) * bin_size_w / roi_bin_grid_w;
+注:除法部分不能用乘法代替,否则坐标计算出现误差,而后续又有筛选操作,会使得精度严重下降。
+y = h_idx_in_bin * cos_theta - w_idx_in_bin * sin_theta + roi_center_h;
+x = h_idx_in_bin * sin_theta + w_idx_in_bin * cos_theta + roi_center_w;
+- 筛选有效点
+if (y < -1.0 || y > height || x < -1.0 || x > width) return 0;
+mmcv 中有上一行所示的判断,无效点的 val 为 0,累加时 += 0 等价于不做处理。
+bangC 实现为 [筛选有效点](#322-筛选有效点)。
+注:mmcv和MLU实现时,nan 都不会被筛选出去。
+- 计算双线性插值四个点的坐标(pos)和权值(w)
+[双线性插值前半部](#323-双线性插值),由于需要处理边界情况,使用更为直接的标量处理。
+- 坐标去重,权重相加
+根据坐标在GDRAM上取值是对性能影响最大的一步,通常来说有坐标重复的点,去重后可以减少IO次数,进而提升性能。
+去重的逻辑较为复杂也不便向量化实现,在[双线性插值后半部](#323-双线性插值)中实现。
+注:经测试,去重后点的数量(unique_num)一般变少,有时能取得几倍的性能提升。然而,权重提前相加,改变了计算顺序,使得 nan/inf 不对齐。
+ + +(4)第四循环 +- 在[channels_cache]中循环,记单次最大取 max_once_c 个 channel。
+在不支持 gather.vector 的机器上,一次最大只能取一个 channel,max_once_c = 1。
+在支持 gather.vector 的机器上,unique_num * max_once_c 不能超过 NRAM 空间限制,详细见[拆分](#33-拆分)。
+- 从 input 中根据坐标信息去取值(对性能影响最大的一步)
+pos 需要变成字节偏移,即乘以 channels * sizeof(T)。
+input 可能不按 64B 对齐,需要做对齐处理,pos 还需加上 input 对齐的偏移。
+- 计算双线性插值结果 +取数后 v([unique_num, once_c]),w([unique_num]),要进行广播乘法(目前只能先转置再调用__bang_cycle_mul),得到 val([unique_num, once_c])。
+使用 __bang_sumpool 对 val 做累加得到 val_sum([once_c]) 。val_sum 加到 output_channels 中。
+注:mmcv 的累加顺序一定是从前往后,而 sumpool 累加顺序不是,会使得精度有偏差,inf/nan 无法对齐。
+ + +### 3.2 伪代码实现 + +#### 3.2.1 roi_info 计算 +```c++ +template +__mlu_func__ void getRoiInfo(const T *rois_dram, int roi_idx, + const mluOpRoiAlignRotatedParams ¶ms, + int &roi_batch_ind, T &roi_center_h, + T &roi_center_w, T &bin_size_h, T &bin_size_w, + int &roi_bin_grid_h, int &roi_bin_grid_w, + T &roi_start_h, T &roi_start_w, T &cos_theta, + T &sin_theta, T &count) { + const T *roi_info = rois_dram + roi_idx * ROI_OFFSET; + roi_batch_ind = (int)roi_info[0]; + T offset = params.aligned ? (T)0.5 : (T)0.0; + roi_center_w = roi_info[1] * (T)params.spatial_scale - offset; + roi_center_h = roi_info[2] * (T)params.spatial_scale - offset; + T roi_width = roi_info[3] * (T)params.spatial_scale; + T roi_height = roi_info[4] * (T)params.spatial_scale; + T theta = roi_info[5]; + if (params.clockwise) { + theta = -(theta); + } + if (!params.aligned) { + roi_width = fmaxf(roi_width, (T)1.0); + roi_height = fmaxf(roi_height, (T)1.0); + } + + bin_size_h = roi_height / (T)params.pooled_height; + bin_size_w = roi_width / (T)params.pooled_width; + + if constexpr (sr_gt0) { + roi_bin_grid_h = params.sample_ratio; + roi_bin_grid_w = params.sample_ratio; + } else { + if constexpr (std::is_same::value) { + roi_bin_grid_h = __half2int_up(bin_size_h); + roi_bin_grid_w = __half2int_up(bin_size_w); + } else { + roi_bin_grid_h = __float2int_up(bin_size_h); + roi_bin_grid_w = __float2int_up(bin_size_w); + } + } + + roi_start_h = roi_height / (T)-2.0; + roi_start_w = roi_width / (T)-2.0; + + if constexpr (std::is_same::value) { + cos_theta = __cn_scalar_cos_f16(theta); + sin_theta = __cn_scalar_sin_f16(theta); + } else { + cos_theta = __cn_scalar_cos_f32(theta); + sin_theta = __cn_scalar_sin_f32(theta); + } + + count = fmaxf(T(roi_bin_grid_h * roi_bin_grid_w), (T)1.0); +} +``` + +#### 3.2.2 筛选有效点 +```c++ +template +__mlu_func__ void selectValidPoint(const int height, const int width, T *nram_y, + T *nram_x, const uint32_t deal_num, T *aux1, + T *aux2, T *aux3, uint32_t &valid_num) { + // y < 
-1.0 + __bang_lt_scalar(aux1, nram_y, (T)-1, deal_num); + // || y > height + __bang_gt_scalar(aux2, nram_y, (T)height, deal_num); + __bang_or(aux3, aux1, aux2, deal_num); + // || x < -1 + __bang_lt_scalar(aux1, nram_x, (T)-1, deal_num); + __bang_or(aux3, aux3, aux1, deal_num); + // || x > width + __bang_gt_scalar(aux2, nram_x, (T)width, deal_num); + __bang_or(aux3, aux3, aux2, deal_num); + __bang_not(aux3, aux3, deal_num); + __bang_filter(nram_y, nram_y, aux3, deal_num); + valid_num = __bang_filter(nram_x, nram_x, aux3, deal_num); +} +``` + +#### 3.2.3 双线性插值 +```c++ +template +__mlu_func__ void bilinearInterpolatePosWeight( + const int height, const int width, T *nram_y, T *nram_x, + const uint32_t valid_num, uint32_t *pos1, uint32_t *pos2, uint32_t *pos3, + uint32_t *pos4, T *w1, T *w2, T *w3, T *w4, uint32_t &unique_num) { + for (uint32_t i = 0; i < valid_num; ++i) { + T y = nram_y[i]; + T x = nram_x[i]; + + if (y <= 0) y = 0; + if (x <= 0) x = 0; + + int y_low, x_low, y_high, x_high; + if constexpr (std::is_same::value) { + y_low = __half2int(y); + x_low = __half2int(x); + } else { + y_low = __float2int(y); + x_low = __float2int(x); + } + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. 
- lx; + pos1[i] = y_low * width + x_low; + pos2[i] = y_low * width + x_high; + pos3[i] = y_high * width + x_low; + pos4[i] = y_high * width + x_high; + w1[i] = hy * hx; + w2[i] = hy * lx; + w3[i] = ly * hx; + w4[i] = ly * lx; + } + // unique + unique_num = 0; + for (int i = 0; i < valid_num; ++i) { + if (w1[i] < 0) { + continue; + } + for (int j = i + 1; j < valid_num; ++j) { + if (pos1[i] == pos1[j]) { + // Points at the same position + w1[i] += w1[j]; + w2[i] += w2[j]; + w3[i] += w3[j]; + w4[i] += w4[j]; + w1[j] = -1; + } + } + if (unique_num != i) { + pos1[unique_num] = pos1[i]; + pos2[unique_num] = pos2[i]; + pos3[unique_num] = pos3[i]; + pos4[unique_num] = pos4[i]; + w1[unique_num] = w1[i]; + w2[unique_num] = w2[i]; + w3[unique_num] = w3[i]; + w4[unique_num] = w4[i]; + } + unique_num += 1; + } +} +``` + +### 3.3 拆分 + +#### 3.3.1 核间拆分 +目前使用简单拆分的逻辑: +```c++ +uint32_t n1 = rois_num * pooled_height * pooled_width; +for (uint32_t i = taskId; i < n1; i+=taskDim) { + ... +} +``` +总逻辑分为 6 个维度,如果针对特定规模的话,可以采用别的拆分方法,目前没有这样的需求。 + +#### 3.3.2 核内拆分 +记原始自增序列长度为 bin_order_num,二维扩展后长度为 bin_hw_order_num。
+bin_hw_order_num = bin_order_num * bin_order_num(即 bin_order_num 的平方)。
+ + +固定NRAM空间划分为: +| name | size | 用途 | +| ------ | ------ | ---------- | +| order | 128 | 起点 0.5,步长为 1 的自增序列 | +| output_channels | sizeof(T) * 1024 | 存储 output 结果 | +| bin_h | sizeof(T) * bin_hw_order_num | h_idx 二维序列 | +| bin_w | sizeof(T) * bin_hw_order_num | w_idx 二维序列 | +| y | sizeof(T) * bin_hw_order_num | y 序列 | +| x | sizeof(T) * bin_hw_order_num | x 序列 | +| w1 | sizeof(T) * bin_hw_order_num | 计算 y,x 时的缓冲区。w1 权重 | +| w2 | sizeof(T) * bin_hw_order_num | 计算 y,x 时的缓冲区。w2 权重 | +| w3 | sizeof(T) * bin_hw_order_num | 计算 y,x 时的缓冲区。w3 权重 | +| w4 | sizeof(T) * bin_hw_order_num | w4 权重 | +| pos1 | sizeof(uint) * bin_hw_order_num | pos1 坐标 | +| pos2 | sizeof(uint) * bin_hw_order_num | pos2 坐标 | +| pos3 | sizeof(uint) * bin_hw_order_num | pos3 坐标 | +| pos4 | sizeof(uint) * bin_hw_order_num | pos4 坐标 | + + +剩余空间对齐均分为三份 vi, vi_t, val,记空间大小为 max_v_size。
+其中 vi 复用多次,最终的 val_sum 也存储于 vi 中。
+此时 max_once_c = max_v_size / unique_num / sizeof(T)。
+以float 类型为例: +- 若 bin_order_num 为 32,固定的 size 为 53376, max_vi_size 为 113280。 +unique_num 最大可到 bin_hw_order_num(1024),此时 max_once_c = 27。 +- 若 bin_order_num 为 8,固定的 size 为 7296, max_vi_size 为 128640。 +unique_num 最大可到 bin_hw_order_num(64),此时 max_once_c = 502。 + + +### 3.4 性能优化设计 +1.向量化加速。 +2.减少重复计算,例如:roi_info 计算,bin_h、bin_w 二维序列构建等。 +3.使用 fuse.nram 融合三条以上的乘加法。 +4.双线性插值坐标进行查重,减少 IO 的数量。 + + +### 3.5 可维护性设计 + +1、每个函数都有相应的注释,表明该函数的功能以及参数信息; + +2、算子对应的feature提交,bug修复等,均应记录在对应的维护表中; + +3、在kernel入口处应用参数检查,log打印,kernel出口应有相应的debug信息打印; + +4、不支持的数据类型与物理布局应有相应的检查报错; + +### 3.6 算子防呆检查 + 1、指针为空防呆; + + 2、0元素检查防呆,VLOG(5)打印信息; + + 3、features和output必须为4维,rois必须要为2维,且rois的第二维大小必须是6; + + 4、features和output的layout必须相同,且都为NHWC; + + 5、output和rois的第一维必须相等,features和output的第四维必须相等; + + 6、output的HW维度需要分别与参数中的pooled_height和pooled_width保持一致。 + + +## 4 算子性能/精度问题 & 优化记录 + +### 4.1 当前存在问题的规模说明 + +暂无 + +### 4.2 已经过优化的规模说明 + +1.适用于向量化的规模:(sample_ratio >=3 || sample_ratio <= 0) && channels < 1024 diff --git a/kernels/roi_align_rotated/roi_align_rotated.cpp b/kernels/roi_align_rotated/roi_align_rotated.cpp index 3b4e8a95c..28765af0e 100644 --- a/kernels/roi_align_rotated/roi_align_rotated.cpp +++ b/kernels/roi_align_rotated/roi_align_rotated.cpp @@ -96,13 +96,6 @@ mluOpStatus_t MLUOP_WIN_API mluOpRoiAlignRotatedForward( return MLUOP_STATUS_BAD_PARAM; } - const int channel = features_desc->dims[3]; - const int width = features_desc->dims[2]; - const int height = features_desc->dims[1]; - const int batch = features_desc->dims[0]; - const int rois_nums = rois_desc->dims[0]; - mluOpDataType_t data_type = features_desc->dtype; - PARAM_CHECK_GT(API, pooled_height, 0); PARAM_CHECK_GT(API, pooled_width, 0); PARAM_CHECK_GE(API, spatial_scale, 0); @@ -117,10 +110,15 @@ mluOpStatus_t MLUOP_WIN_API mluOpRoiAlignRotatedForward( PARAM_CHECK(API, output != nullptr); PARAM_CHECK(API, rois != nullptr); + const int channel = features_desc->dims[3]; + const int width = features_desc->dims[2]; + const int height = 
features_desc->dims[1]; + const int batch = features_desc->dims[0]; + const int rois_nums = rois_desc->dims[0]; VLOG(5) << "pool_height: " << pooled_height << ",pool_width: " << pooled_width << ",channel: " << channel << ",roi nums: " << rois_nums << "."; VLOG(5) << "batch: " << batch << ",height: " << height << ",width: " << width - << "."; + << ",sample_ratio: " << sample_ratio << "."; if (MLUOP_GEN_CASE_ON_NEW) { GEN_CASE_START("roi_align_rotated_forward", "ROI_ALIGN_ROTATED_FORWARD"); @@ -142,18 +140,33 @@ mluOpStatus_t MLUOP_WIN_API mluOpRoiAlignRotatedForward( clockwise); GEN_CASE_TEST_PARAM_NEW(true, true, false, 0.003, 0.003, 0); } - mluOpRoiAlignRotatedParams roiAlignRotatedParams{pooled_height, pooled_width, - sample_ratio, spatial_scale, - aligned, clockwise}; + mluOpRoiAlignRotatedParams roiAlignRotatedParams{ + aligned, clockwise, pooled_height, + pooled_width, sample_ratio, spatial_scale}; + cnrtDim3_t k_dim; cnrtFunctionType_t k_type; policyFunc(handle, rois_nums * pooled_height * pooled_width, &k_dim, &k_type); - VLOG(5) << "[mluOpRoiAlignRotatedForward] launch kernel policyFunc[" - << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << "]."; - CHECK_RETURN(API, KernelRoiAlignRotatedForward( - k_dim, k_type, handle->queue, features_desc->dtype, - features, rois, batch, height, width, channel, - rois_nums, roiAlignRotatedParams, output)); + + uint32_t sample_ratio_split = 3, channels_split = 1024; + if (handle->arch >= MLUOP_MLU590 && channel <= channels_split && + (sample_ratio >= sample_ratio_split || sample_ratio <= 0) + ) { + VLOG(5) << "[mluOpRoiAlignRotatedForwardVector] launch kernel policyFunc[" + << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << "]."; + CHECK_RETURN(API, KernelRoiAlignRotatedForwardVector( + k_dim, k_type, handle->queue, features_desc->dtype, + features, rois, batch, height, width, channel, + rois_nums, roiAlignRotatedParams, output)); + } else { + VLOG(5) << "[mluOpRoiAlignRotatedForward] launch kernel policyFunc[" + << 
k_dim.x << ", " << k_dim.y << ", " << k_dim.z << "]."; + CHECK_RETURN(API, KernelRoiAlignRotatedForward( + k_dim, k_type, handle->queue, features_desc->dtype, + features, rois, batch, height, width, channel, + rois_nums, roiAlignRotatedParams, output)); + } + VLOG(5) << "Kernel KernelRoiAlignRotatedForward."; GEN_CASE_END(); return MLUOP_STATUS_SUCCESS; @@ -259,9 +272,9 @@ mluOpStatus_t MLUOP_WIN_API mluOpRoiAlignRotatedBackward( clockwise); GEN_CASE_TEST_PARAM_NEW(true, true, false, 0.003, 0.003, 0); } - mluOpRoiAlignRotatedParams roiAlignRotatedParams{pooled_height, pooled_width, - sample_ratio, spatial_scale, - aligned, clockwise}; + mluOpRoiAlignRotatedParams roiAlignRotatedParams{ + aligned, clockwise, pooled_height, + pooled_width, sample_ratio, spatial_scale}; cnrtDim3_t k_dim; cnrtFunctionType_t k_type; diff --git a/kernels/roi_align_rotated/roi_align_rotated.h b/kernels/roi_align_rotated/roi_align_rotated.h index 14fdf2eea..5fc00ee78 100644 --- a/kernels/roi_align_rotated/roi_align_rotated.h +++ b/kernels/roi_align_rotated/roi_align_rotated.h @@ -26,22 +26,29 @@ #include "mlu_op.h" struct mluOpRoiAlignRotatedParams { + bool aligned; + bool clockwise; int pooled_height; int pooled_width; int sample_ratio; float spatial_scale; - bool aligned; - bool clockwise; }; -mluOpStatus_t MLUOP_WIN_API KernelRoiAlignRotatedForward( +mluOpStatus_t KernelRoiAlignRotatedForward( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + mluOpDataType_t d_type, const void *features, const void *rois, + const int batch, const int height, const int width, const int channel, + const int rois_num, const mluOpRoiAlignRotatedParams rroiAlignParams, + void *output); + +mluOpStatus_t KernelRoiAlignRotatedForwardVector( cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, mluOpDataType_t d_type, const void *features, const void *rois, const int batch, const int height, const int width, const int channel, const int rois_num, const mluOpRoiAlignRotatedParams 
rroiAlignParams, void *output); -mluOpStatus_t MLUOP_WIN_API KernelRoiAlignRotatedBackward( +mluOpStatus_t KernelRoiAlignRotatedBackward( cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, mluOpDataType_t d_type, const void *top_grad, const void *rois, const int batch, const int height, const int width, const int channel, diff --git a/kernels/roi_align_rotated/roi_align_rotated_block.mlu b/kernels/roi_align_rotated/roi_align_rotated_block.mlu index 684de42df..b28a133db 100644 --- a/kernels/roi_align_rotated/roi_align_rotated_block.mlu +++ b/kernels/roi_align_rotated/roi_align_rotated_block.mlu @@ -449,7 +449,7 @@ __mlu_global__ void roiAlignRotatedBackward( } } -mluOpStatus_t MLUOP_WIN_API KernelRoiAlignRotatedForward( +mluOpStatus_t KernelRoiAlignRotatedForward( cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, mluOpDataType_t d_type, const void *features, const void *rois, const int batch, const int height, const int width, const int channel, @@ -474,7 +474,7 @@ mluOpStatus_t MLUOP_WIN_API KernelRoiAlignRotatedForward( return MLUOP_STATUS_SUCCESS; } -mluOpStatus_t MLUOP_WIN_API KernelRoiAlignRotatedBackward( +mluOpStatus_t KernelRoiAlignRotatedBackward( cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, mluOpDataType_t d_type, const void *top_grad, const void *rois, const int batch, const int height, const int width, const int channel, diff --git a/kernels/roi_align_rotated/roi_align_rotated_forward_vector.mlu b/kernels/roi_align_rotated/roi_align_rotated_forward_vector.mlu new file mode 100644 index 000000000..0e5ffc64c --- /dev/null +++ b/kernels/roi_align_rotated/roi_align_rotated_forward_vector.mlu @@ -0,0 +1,605 @@ +/************************************************************************* + * Copyright (C) [2024] by Cambricon, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + * + * The above copyright notice and this permission notice shall be included + * in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "roi_align_rotated.h" + +#include "core/logging.h" +#include "kernels/debug.h" +#include "kernels/kernel.h" +#include "kernels/utils/common.h" +#include "mlu.h" + +__nram__ int8_t nram_buffer[MAX_NRAM_SIZE]; + +#define ROI_OFFSET 6 + +template +__mlu_func__ void getRoiInfo(const T *rois_dram, int roi_idx, + const mluOpRoiAlignRotatedParams ¶ms, + int &roi_batch_ind, T &roi_center_h, + T &roi_center_w, T &bin_size_h, T &bin_size_w, + int &roi_bin_grid_h, int &roi_bin_grid_w, + T &roi_start_h, T &roi_start_w, T &cos_theta, + T &sin_theta, T &count) { + const T *roi_info = rois_dram + roi_idx * ROI_OFFSET; + roi_batch_ind = (int)roi_info[0]; + T offset = params.aligned ? 
(T)0.5 : (T)0.0; + roi_center_w = roi_info[1] * (T)params.spatial_scale - offset; + roi_center_h = roi_info[2] * (T)params.spatial_scale - offset; + T roi_width = roi_info[3] * (T)params.spatial_scale; + T roi_height = roi_info[4] * (T)params.spatial_scale; + T theta = roi_info[5]; + if (params.clockwise) { + theta = -(theta); + } + if (!params.aligned) { + roi_width = fmaxf(roi_width, (T)1.0); + roi_height = fmaxf(roi_height, (T)1.0); + } + + bin_size_h = roi_height / (T)params.pooled_height; + bin_size_w = roi_width / (T)params.pooled_width; + + if constexpr (sr_gt0) { + roi_bin_grid_h = params.sample_ratio; + roi_bin_grid_w = params.sample_ratio; + } else { + if constexpr (std::is_same::value) { + roi_bin_grid_h = __half2int_up(bin_size_h); + roi_bin_grid_w = __half2int_up(bin_size_w); + } else { + roi_bin_grid_h = __float2int_up(bin_size_h); + roi_bin_grid_w = __float2int_up(bin_size_w); + } + } + + roi_start_h = roi_height / (T)-2.0; + roi_start_w = roi_width / (T)-2.0; + + if constexpr (std::is_same::value) { + cos_theta = __cn_scalar_cos_f16(theta); + sin_theta = __cn_scalar_sin_f16(theta); + } else { + cos_theta = __cn_scalar_cos_f32(theta); + sin_theta = __cn_scalar_sin_f32(theta); + } + + count = fmaxf(T(roi_bin_grid_h * roi_bin_grid_w), (T)1.0); +} + +template +__mlu_func__ void getXYorder(T *bh_order, T *bw_order, T h_offset, T w_offset, + uint32_t deal_num, T *nram_y, T *nram_x, T *aux1, + T *aux2, T *aux3, T bin_size_h, T bin_size_w, + int roi_bin_grid_h, int roi_bin_grid_w, T h_bias, + T w_bias, T roi_center_h, T roi_center_w, + T cos_theta, T sin_theta) { + // h_idx_in_bin = (bh_order + h_offset)* bin_size_h / + // roi_bin_grid_h + h_bias + // w_idx_in_bin = (bw_order + w_offset)* bin_size_w / + // roi_bin_grid_w + w_bias + __bang_add_scalar(aux1, bh_order, h_offset, deal_num); + __bang_mul_scalar(aux1, aux1, bin_size_h, deal_num); + __bang_add_scalar(aux2, bw_order, w_offset, deal_num); + __bang_mul_scalar(aux2, aux2, bin_size_w, deal_num); + // 
Coordinate calculation requires high precision. + // must use div + if constexpr (std::is_same::value) { + __asm__ volatile( + "div.scalar.nram.f16 [%[dst]], [%[src0]], " + "%[src1], %[num];\n\t" ::[dst] "r"(aux1), + [ src0 ] "r"(aux1), [ src1 ] "r"((T)roi_bin_grid_h), + [ num ] "r"(deal_num)); + __asm__ volatile( + "div.scalar.nram.f16 [%[dst]], [%[src0]], " + "%[src1], %[num];\n\t" ::[dst] "r"(aux2), + [ src0 ] "r"(aux2), [ src1 ] "r"((T)roi_bin_grid_w), + [ num ] "r"(deal_num)); + } else { + __asm__ volatile( + "div.scalar.nram.f32 [%[dst]], [%[src0]], " + "%[src1], %[num];\n\t" ::[dst] "r"(aux1), + [ src0 ] "r"(aux1), [ src1 ] "r"((T)roi_bin_grid_h), + [ num ] "r"(deal_num)); + __asm__ volatile( + "div.scalar.nram.f32 [%[dst]], [%[src0]], " + "%[src1], %[num];\n\t" ::[dst] "r"(aux2), + [ src0 ] "r"(aux2), [ src1 ] "r"((T)roi_bin_grid_w), + [ num ] "r"(deal_num)); + } + __bang_add_scalar(aux1, aux1, h_bias, deal_num); + __bang_add_scalar(aux2, aux2, w_bias, deal_num); + + // y = h_idx_in_bin * cos_theta - + // w_idx_in_bin * sin_theta + roi_center_h + // x = h_idx_in_bin * sin_theta + + // w_idx_in_bin * cos_theta + roi_center_w + if constexpr (std::is_same::value) { + // calu y + __bang_mul_scalar(aux3, aux2, sin_theta, deal_num); + __asm__ volatile( + "fuse.nram.f16 [%[dst]], %[num], [%[src0]], .mul(%[cos_v]), " + ".sub([%[src1]]), .add(%[rh]);\n\t" ::[dst] "r"(nram_y), + [ num ] "r"(deal_num), [ src0 ] "r"(aux1), [ cos_v ] "r"(cos_theta), + [ src1 ] "r"(aux3), [ rh ] "r"(roi_center_h)); + // calu x + __bang_mul_scalar(aux3, aux2, cos_theta, deal_num); + __asm__ volatile( + "fuse.nram.f16 [%[dst]], %[num], [%[src0]], .mul(%[sin_v]), " + ".add([%[src1]]), .add(%[rw]);\n\t" ::[dst] "r"(nram_x), + [ num ] "r"(deal_num), [ src0 ] "r"(aux1), [ sin_v ] "r"(sin_theta), + [ src1 ] "r"(aux3), [ rw ] "r"(roi_center_w)); + } else { + // calu y + __bang_mul_scalar(aux3, aux2, sin_theta, deal_num); + __asm__ volatile( + "fuse.nram.f32 [%[dst]], %[num], [%[src0]], 
.mul(%[cos_v]), " + ".sub([%[src1]]), .add(%[rh]);\n\t" ::[dst] "r"(nram_y), + [ num ] "r"(deal_num), [ src0 ] "r"(aux1), [ cos_v ] "r"(cos_theta), + [ src1 ] "r"(aux3), [ rh ] "r"(roi_center_h)); + // calu x + __bang_mul_scalar(aux3, aux2, cos_theta, deal_num); + __asm__ volatile( + "fuse.nram.f32 [%[dst]], %[num], [%[src0]], .mul(%[sin_v]), " + ".add([%[src1]]), .add(%[rw]);\n\t" ::[dst] "r"(nram_x), + [ num ] "r"(deal_num), [ src0 ] "r"(aux1), [ sin_v ] "r"(sin_theta), + [ src1 ] "r"(aux3), [ rw ] "r"(roi_center_w)); + } +} + +template +__mlu_func__ void selectValidPoint(const int height, const int width, T *nram_y, + T *nram_x, const uint32_t deal_num, T *aux1, + T *aux2, T *aux3, uint32_t &valid_num) { + // y < -1.0 + __bang_lt_scalar(aux1, nram_y, (T)-1, deal_num); + // || y > height + __bang_gt_scalar(aux2, nram_y, (T)height, deal_num); + __bang_or(aux3, aux1, aux2, deal_num); + // || x < -1 + __bang_lt_scalar(aux1, nram_x, (T)-1, deal_num); + __bang_or(aux3, aux3, aux1, deal_num); + // || x > width + __bang_gt_scalar(aux2, nram_x, (T)width, deal_num); + __bang_or(aux3, aux3, aux2, deal_num); + __bang_not(aux3, aux3, deal_num); + __bang_filter(nram_y, nram_y, aux3, deal_num); + valid_num = __bang_filter(nram_x, nram_x, aux3, deal_num); +} + +template +__mlu_func__ void bilinearInterpolatePosWeight( + const int height, const int width, T *nram_y, T *nram_x, + const uint32_t valid_num, uint32_t *pos1, uint32_t *pos2, uint32_t *pos3, + uint32_t *pos4, T *w1, T *w2, T *w3, T *w4, uint32_t &unique_num) { + for (uint32_t i = 0; i < valid_num; ++i) { + T y = nram_y[i]; + T x = nram_x[i]; + + if (y <= 0) y = 0; + if (x <= 0) x = 0; + + int y_low, x_low, y_high, x_high; + if constexpr (std::is_same::value) { + y_low = __half2int(y); + x_low = __half2int(x); + } else { + y_low = __float2int(y); + x_low = __float2int(x); + } + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + 
x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + // w1 cache x_low, w3 cache x_high + // pos1 cache y_low, pos2 cache y_high + ((uint32_t *)w1)[i] = x_low; + ((uint32_t *)w3)[i] = x_high; + pos1[i] = y_low; + pos3[i] = y_high; + + // nram_y cache ly, nram_x cache lx + // T ly = y - y_low; + // T lx = x - x_low; + nram_y[i] = y - y_low; + nram_x[i] = x - x_low; + } + // pos2 = y_low * width; + __bang_mul_scalar(pos2, pos1, width, valid_num); + // pos1 = y_low * width + x_low; + __bang_add(pos1, pos2, (uint32_t *)w1, valid_num); + // pos2 = y_low * width + x_high; + __bang_add(pos2, pos2, (uint32_t *)w3, valid_num); + + // pos3 = y_high * width; + __bang_mul_scalar(pos4, pos3, width, valid_num); + // pos3 = y_high * width + x_low; + __bang_add(pos3, pos4, (uint32_t *)w1, valid_num); + // pos4 = y_high * width + x_high; + __bang_add(pos4, pos4, (uint32_t *)w3, valid_num); + + // w3 cache hy, w4 cache hx + // T hy = 1. - ly, hx = 1. - lx; + __bang_mul_scalar(w3, nram_y, -1.0, valid_num); + __bang_add_scalar(w3, w3, 1.0, valid_num); + __bang_mul_scalar(w4, nram_x, -1.0, valid_num); + __bang_add_scalar(w4, w4, 1.0, valid_num); + + // w1[i] = hy * hx; + __bang_mul(w1, w3, w4, valid_num); + // w2[i] = hy * lx; + __bang_mul(w2, w3, nram_x, valid_num); + // w3[i] = ly * hx; + __bang_mul(w3, nram_y, w4, valid_num); + // w4[i] = ly * lx; + __bang_mul(w4, nram_y, nram_x, valid_num); + + // unique + unique_num = 0; + for (int i = 0; i < valid_num; ++i) { + if (w1[i] < 0) { + continue; + } + for (int j = i + 1; j < valid_num; ++j) { + if (pos1[i] == pos1[j]) { + // Points at the same position + w1[i] += w1[j]; + w2[i] += w2[j]; + w3[i] += w3[j]; + w4[i] += w4[j]; + w1[j] = -1; + } + } + if (unique_num != i) { + pos1[unique_num] = pos1[i]; + pos2[unique_num] = pos2[i]; + pos3[unique_num] = pos3[i]; + pos4[unique_num] = pos4[i]; + w1[unique_num] = w1[i]; + w2[unique_num] = w2[i]; + w3[unique_num] = w3[i]; + w4[unique_num] = w4[i]; + } + 
unique_num += 1; + } +} + +template +__mlu_func__ void handleChannels(const T *input, uint32_t deal_channels, + uint32_t valid_num, T *w1, T *w2, T *w3, T *w4, + uint32_t *pos1, uint32_t *pos2, uint32_t *pos3, + uint32_t *pos4, int32_t &pos_offset, T *v, + T *v_t, T *val) { + // gather dst,src,offset must 64B align + int32_t align_offset = ((int64_t)input & 0x3f); + if (align_offset != 0) { + input = (T *)((int64_t)input - align_offset); + } + if (align_offset != pos_offset) { + __bang_add_scalar(pos1, pos1, align_offset - pos_offset, valid_num); + __bang_add_scalar(pos2, pos2, align_offset - pos_offset, valid_num); + __bang_add_scalar(pos3, pos3, align_offset - pos_offset, valid_num); + __bang_add_scalar(pos4, pos4, align_offset - pos_offset, valid_num); + pos_offset = align_offset; + } + uint32_t hwc_num = deal_channels * valid_num; + + __gather(v, input, pos1, deal_channels * sizeof(T), GDRAM2NRAM, + deal_channels * sizeof(T), valid_num); + if (deal_channels == 1) { + __bang_mul(val, v, w1, valid_num); + } else { + __bang_transpose(val, v, valid_num, deal_channels); + __bang_cycle_mul(val, val, w1, hwc_num, valid_num); + } + + __gather(v, input, pos2, deal_channels * sizeof(T), GDRAM2NRAM, + deal_channels * sizeof(T), valid_num); + if (deal_channels == 1) { + __bang_mul(v_t, v, w2, valid_num); + } else { + __bang_transpose(v_t, v, valid_num, deal_channels); + __bang_cycle_mul(v_t, v_t, w2, hwc_num, valid_num); + } + __bang_add(val, val, v_t, hwc_num); + + __gather(v, input, pos3, deal_channels * sizeof(T), GDRAM2NRAM, + deal_channels * sizeof(T), valid_num); + if (deal_channels == 1) { + __bang_mul(v_t, v, w3, valid_num); + } else { + __bang_transpose(v_t, v, valid_num, deal_channels); + __bang_cycle_mul(v_t, v_t, w3, hwc_num, valid_num); + } + __bang_add(val, val, v_t, hwc_num); + + __gather(v, input, pos4, deal_channels * sizeof(T), GDRAM2NRAM, + deal_channels * sizeof(T), valid_num); + if (deal_channels == 1) { + __bang_mul(v_t, v, w4, valid_num); + } else { 
+ __bang_transpose(v_t, v, valid_num, deal_channels); + __bang_cycle_mul(v_t, v_t, w4, hwc_num, valid_num); + } + __bang_add(val, val, v_t, hwc_num); + + if (deal_channels != 1) { + __bang_transpose(v_t, val, deal_channels, valid_num); + __bang_sumpool(v, v_t, deal_channels, 1, valid_num, 1, valid_num, + valid_num, 1); + } else { + __bang_sumpool(v, val, deal_channels, 1, valid_num, 1, valid_num, + valid_num, 1); + } +} + +template +__mlu_global__ void roiAlignRotatedForward( + const T *input_dram, const T *rois_dram, const int batch, const int height, + const int width, const uint32_t channels, const int rois_num, + const mluOpRoiAlignRotatedParams params, T *output_dram) { +#if __BANG_ARCH__ >= 592 + if (coreId == 0x80) { + return; + } + /* + Cache + | | order | output_channels | + | size | 128 | 1024 * sizeof(T) | + */ + T *order = (T *)nram_buffer; + // Construct order(0.5,1.5,2.5,...,31.5) + uint32_t bin_order_num = 8; + if (params.sample_ratio > 16) { + bin_order_num = 32; + } else if (params.sample_ratio > 8) { + bin_order_num = 16; + } + + __mlu_op_gen_stage_index(order, bin_order_num, (T)0.5); + const uint32_t bin_hw_order_num = bin_order_num * bin_order_num; + + const uint32_t cache_channels_num = 1024; + // use for store + T *output_channels = order + NFU_ALIGN_SIZE / sizeof(T); + /* + bilinear_interpolate, coordinates can be reused + | | | volatile | + | name | bin_h,bin_w,y,x,w1,w2,w3,w4 | p1,p2,p3,p4 | v1,v2,v3 | + | type | T | uint32_t | T | + | num | bin_hw_order_num | | + */ + T *bin_h_order = output_channels + cache_channels_num; + T *bin_w_order = bin_h_order + bin_hw_order_num; + T *nram_y = bin_w_order + bin_hw_order_num; + T *nram_x = nram_y + bin_hw_order_num; + T *nram_w1 = nram_x + bin_hw_order_num; + T *nram_w2 = nram_w1 + bin_hw_order_num; + T *nram_w3 = nram_w2 + bin_hw_order_num; + T *nram_w4 = nram_w3 + bin_hw_order_num; + uint32_t *nram_pos1 = (uint32_t *)(nram_w4 + bin_hw_order_num); + uint32_t *nram_pos2 = nram_pos1 + 
bin_hw_order_num; + uint32_t *nram_pos3 = nram_pos2 + bin_hw_order_num; + uint32_t *nram_pos4 = nram_pos3 + bin_hw_order_num; + + uint32_t fixed_size = NFU_ALIGN_SIZE + cache_channels_num * sizeof(T) + + 8 * bin_hw_order_num * sizeof(T) + + 4 * bin_hw_order_num * sizeof(uint32_t); + uint32_t max_v_size = + FLOOR_ALIGN((MAX_NRAM_SIZE - fixed_size) / 3, NFU_ALIGN_SIZE); + + T *nram_vi = (T *)(nram_pos4 + bin_hw_order_num); + T *nram_vi_t = nram_vi + max_v_size / sizeof(T); + T *nram_val = nram_vi_t + max_v_size / sizeof(T); + + // If dynamic creation of sequences + bool construct_order = true; + // bin_order only needs to be calu once + if constexpr (sr_gt0) { + if (params.sample_ratio < bin_order_num) { + construct_order = false; + // construct bin_w_idx in bin_loop + __memcpy_async(bin_w_order, order, params.sample_ratio * sizeof(T), + NRAM2NRAM, params.sample_ratio * sizeof(T), 0, + params.sample_ratio - 1); + // construct bin_h_idx in bin_loop + __bang_write_value(nram_w1, params.sample_ratio, + (uint8_t)params.sample_ratio); + __extension(bin_h_order, order, (uint8_t *)nram_w1, sizeof(T), NRAM2NRAM, + params.sample_ratio); + } + } + + // roi_info only needs to be calculated once + uint32_t last_roi_idx = -1; + int roi_batch_ind = 0, roi_bin_grid_h = 0, roi_bin_grid_w = 0; + T roi_center_h = 0, roi_center_w = 0, bin_size_h = 0, bin_size_w = 0, + roi_start_h = 0, roi_start_w = 0, cos_theta = 0, sin_theta = 0, count = 0; + uint32_t pooled_num = params.pooled_height * params.pooled_width; + + // loop order [roi, pooled_height, pooled_width] + for (uint32_t bin_i = taskId; bin_i < rois_num * pooled_num; + bin_i += taskDim) { + uint32_t roi_idx = bin_i / pooled_num; + if (roi_idx != last_roi_idx) { + getRoiInfo(rois_dram, roi_idx, params, roi_batch_ind, + roi_center_h, roi_center_w, bin_size_h, bin_size_w, + roi_bin_grid_h, roi_bin_grid_w, roi_start_h, + roi_start_w, cos_theta, sin_theta, count); + last_roi_idx = roi_idx; + } + + uint32_t ph = bin_i % pooled_num / 
params.pooled_width; + uint32_t pw = bin_i % params.pooled_width; + T h_bias = roi_start_h + ph * bin_size_h; + T w_bias = roi_start_w + pw * bin_size_w; + + // channels max cache 1024 + for (uint32_t c_cache_i = 0; c_cache_i < channels; + c_cache_i += cache_channels_num) { + uint32_t cur_cache_c = cache_channels_num; + if (cur_cache_c + c_cache_i > channels) { + cur_cache_c = channels - c_cache_i; + } + __bang_write_zero(output_channels, cur_cache_c); + + for (uint32_t h_idx = 0; h_idx < roi_bin_grid_h; h_idx += bin_order_num) { + uint32_t deal_bin_h = bin_order_num; + if (h_idx + deal_bin_h > roi_bin_grid_h) { + deal_bin_h = roi_bin_grid_h - h_idx; + } + for (uint32_t w_idx = 0; w_idx < roi_bin_grid_w; + w_idx += bin_order_num) { + uint32_t deal_bin_w = bin_order_num; + if (w_idx + deal_bin_w > roi_bin_grid_w) { + deal_bin_w = roi_bin_grid_w - w_idx; + } + uint32_t deal_num = deal_bin_h * deal_bin_w; + + if (construct_order) { + // construct bin_w_idx in bin_loop + __memcpy_async(bin_w_order, order, deal_bin_w * sizeof(T), + NRAM2NRAM, deal_bin_w * sizeof(T), 0, + deal_bin_h - 1); + // construct bin_h_idx in bin_loop + __bang_write_value(nram_w1, deal_bin_h, (uint8_t)deal_bin_w); + __extension(bin_h_order, order, (uint8_t *)nram_w1, sizeof(T), + NRAM2NRAM, deal_bin_h); + } + + // getXYorder + getXYorder(bin_h_order, bin_w_order, (T)h_idx, (T)w_idx, deal_num, + nram_y, nram_x, nram_w1, nram_w2, nram_w3, bin_size_h, + bin_size_w, roi_bin_grid_h, roi_bin_grid_w, h_bias, w_bias, + roi_center_h, roi_center_w, cos_theta, sin_theta); + + uint32_t valid_num = 0; + selectValidPoint(height, width, nram_y, nram_x, deal_num, nram_w1, + nram_w2, nram_w3, valid_num); + if (valid_num == 0) { + continue; + } + + // bilinearInterpolate + uint32_t unique_num = 0; + bilinearInterpolatePosWeight(height, width, nram_y, nram_x, valid_num, + nram_pos1, nram_pos2, nram_pos3, + nram_pos4, nram_w1, nram_w2, nram_w3, + nram_w4, unique_num); + + __bang_mul_scalar(nram_pos1, nram_pos1, 
channels * sizeof(T), + unique_num); + __bang_mul_scalar(nram_pos2, nram_pos2, channels * sizeof(T), + unique_num); + __bang_mul_scalar(nram_pos3, nram_pos3, channels * sizeof(T), + unique_num); + __bang_mul_scalar(nram_pos4, nram_pos4, channels * sizeof(T), + unique_num); + + int pos_offset = 0; + uint32_t max_once_c = max_v_size / sizeof(T) / unique_num; + // Same coordinates in different channels + for (uint32_t ci = 0; ci < cur_cache_c; ci += max_once_c) { + uint32_t cur_c = max_once_c; + if (cur_c + ci > cur_cache_c) { + cur_c = cur_cache_c - ci; + } + handleChannels( + input_dram + roi_batch_ind * height * width * channels + + c_cache_i + ci, + cur_c, unique_num, nram_w1, nram_w2, nram_w3, nram_w4, + nram_pos1, nram_pos2, nram_pos3, nram_pos4, pos_offset, + nram_vi, nram_vi_t, nram_val); + __bang_add(output_channels + ci, output_channels + ci, nram_vi, + cur_c); + } + } + } + if constexpr (std::is_same::value) { + __asm__ volatile( + "div.scalar.nram.f16 [%[dst]], [%[src0]], " + "%[src1], %[num];\n\t" ::[dst] "r"(output_channels), + [ src0 ] "r"(output_channels), [ src1 ] "r"((T)count), + [ num ] "r"(cur_cache_c)); + } else { + __asm__ volatile( + "div.scalar.nram.f32 [%[dst]], [%[src0]], " + "%[src1], %[num];\n\t" ::[dst] "r"(output_channels), + [ src0 ] "r"(output_channels), [ src1 ] "r"((T)count), + [ num ] "r"(cur_cache_c)); + } + __memcpy(output_dram + bin_i * channels + c_cache_i, output_channels, + cur_cache_c * sizeof(T), NRAM2GDRAM); + } + } +#endif +} + +mluOpStatus_t KernelRoiAlignRotatedForwardVector( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + mluOpDataType_t d_type, const void *features, const void *rois, + const int batch, const int height, const int width, const int channel, + const int rois_num, const mluOpRoiAlignRotatedParams rroiAlignParams, + void *output) { + switch (d_type) { + case MLUOP_DTYPE_FLOAT: { + if (rroiAlignParams.sample_ratio > 0) { + KERNEL_CHECK(roiAlignRotatedForward + <<>>( + (float *)features, 
(float *)rois, batch, height, width, + channel, rois_num, rroiAlignParams, (float *)output)); + } else { + KERNEL_CHECK(roiAlignRotatedForward + <<>>( + (float *)features, (float *)rois, batch, height, width, + channel, rois_num, rroiAlignParams, (float *)output)); + } + } break; + case MLUOP_DTYPE_HALF: { + if (rroiAlignParams.sample_ratio > 0) { + KERNEL_CHECK(roiAlignRotatedForward + <<>>( + (half *)features, (half *)rois, batch, height, width, + channel, rois_num, rroiAlignParams, (half *)output)); + } else { + KERNEL_CHECK(roiAlignRotatedForward + <<>>( + (half *)features, (half *)rois, batch, height, width, + channel, rois_num, rroiAlignParams, (half *)output)); + } + } break; + default: + break; + } + return MLUOP_STATUS_SUCCESS; +} diff --git a/kernels/utils/common.h b/kernels/utils/common.h index f4ca0f2af..bceb8ccd4 100644 --- a/kernels/utils/common.h +++ b/kernels/utils/common.h @@ -534,23 +534,19 @@ __mlu_vector__ void __mlu_op_arange_vv_(T *dst_nram, T start_index, T step) { const uint32_t vv_num = __vv_get_length() / sizeof(T); #if _BANG_ARCH_ <= 592 - if (std::is_same::value) { - MLUOP_ARANGE_VV_IMPL(vv_uint16, vv_num, dst_nram, start_index, step); - } else if (std::is_same::value) { - MLUOP_ARANGE_VV_IMPL(vv_int16, vv_num, dst_nram, start_index, step); - } else if (std::is_same::value) { + if constexpr(std::is_same::value) { MLUOP_ARANGE_VV_IMPL(vv_uint32, vv_num, dst_nram, start_index, step); - } else if (std::is_same::value) { + } else if constexpr(std::is_same::value) { MLUOP_ARANGE_VV_IMPL(vv_int32, vv_num, dst_nram, start_index, step); } #endif // if _BANG_ARCH_ <= 592 - if (std::is_same::value) { + if constexpr(std::is_same::value) { MLUOP_ARANGE_VV_IMPL(vv_uint16, vv_num, dst_nram, start_index, step); - } else if (std::is_same::value) { + } else if constexpr(std::is_same::value) { MLUOP_ARANGE_VV_IMPL(vv_int16, vv_num, dst_nram, start_index, step); - } else if (std::is_same::value) { + } else if constexpr(std::is_same::value) { 
MLUOP_ARANGE_VV_IMPL(vv_float, vv_num, dst_nram, start_index, step); - } else if (std::is_same::value) { + } else if constexpr(std::is_same::value) { MLUOP_ARANGE_VV_IMPL(vv_half, vv_num, dst_nram, start_index, step); } return; From 273400c39031b006317f4dbc6feeb34ee00e2e17 Mon Sep 17 00:00:00 2001 From: chqy99 <1216494776@qq.com> Date: Tue, 19 Nov 2024 18:56:09 +0800 Subject: [PATCH 2/4] [Feature](mluOpRoiAlignRotatedForward): improve perf --- .../roi_align_rotated_forward_vector.md | 7 +- .../roi_align_rotated_forward_vector.mlu | 148 ++++++++---------- 2 files changed, 69 insertions(+), 86 deletions(-) diff --git a/docs/design_docs/roi_align_rotated/roi_align_rotated_forward_vector.md b/docs/design_docs/roi_align_rotated/roi_align_rotated_forward_vector.md index ccd3e4611..233aefd2f 100644 --- a/docs/design_docs/roi_align_rotated/roi_align_rotated_forward_vector.md +++ b/docs/design_docs/roi_align_rotated/roi_align_rotated_forward_vector.md @@ -71,6 +71,8 @@ - features, rois, output数据类型要相同。 +- 仅支持 5xx 以上系列 + ### 1.5 验收标准 #### 1.5.1 精度验收标准 @@ -177,11 +179,12 @@ bangC 实现为 [筛选有效点](#322-筛选有效点)。
在不支持 gather.vector 的机器上,一次最大只能取一个 channel,max_once_c = 1。
在支持 gather.vector 的机器上,unique * max_once_c 要不超过 NRAM 空间限制,详细见[拆分](#33-拆分)。
- 从 input 中根据坐标信息去取值(对性能影响最大的一步)
+双线性插值的四个邻近点各构成一个坐标向量和一个权重向量,将四个坐标向量拼接成一个,四个权重向量也拼接成一个。这样一次 gather 可以同时处理更多的数据,更容易将 outstanding 打满,约有 10%~20% 的性能收益。
pos 需要变成字节偏移,即乘以 channels * sizeof(T)。
input 可能不按 64B 对齐,需要做对齐处理,pos 还需加上 input 对齐的偏移。
- 计算双线性插值结果 -取数后 v([unique_num, once_c]),w([unique_num]),要进行广播乘法(目前只能先转置再调用__bang_cycle_mul),得到 val([unique_num, once_c])。
-使用 __bang_sumpool 对 val 做累加得到 val_sum([once_c]) 。val_sum 加到 output_channels 中。
+取数后 v([unique_num, once_c]),w([unique_num]),要进行广播乘法(目前只能先转置再调用 __bang_cycle_mul),得到 v * w([unique_num, once_c])。
+使用 __bang_sumpool 对 v * w 做累加得到 val_sum([once_c])。val_sum 加到 output_channels 中。
注:mmcv 的累加顺序一定是从前往后,而 sumpool 累加顺序不是,会使得精度有偏差,inf/nan 无法对齐。
diff --git a/kernels/roi_align_rotated/roi_align_rotated_forward_vector.mlu b/kernels/roi_align_rotated/roi_align_rotated_forward_vector.mlu index 0e5ffc64c..893bc9563 100644 --- a/kernels/roi_align_rotated/roi_align_rotated_forward_vector.mlu +++ b/kernels/roi_align_rotated/roi_align_rotated_forward_vector.mlu @@ -190,7 +190,7 @@ template __mlu_func__ void bilinearInterpolatePosWeight( const int height, const int width, T *nram_y, T *nram_x, const uint32_t valid_num, uint32_t *pos1, uint32_t *pos2, uint32_t *pos3, - uint32_t *pos4, T *w1, T *w2, T *w3, T *w4, uint32_t &unique_num) { + uint32_t *pos4, T *w1, T *w2, T *w3, T *w4) { for (uint32_t i = 0; i < valid_num; ++i) { T y = nram_y[i]; T x = nram_x[i]; @@ -263,9 +263,14 @@ __mlu_func__ void bilinearInterpolatePosWeight( __bang_mul(w3, nram_y, w4, valid_num); // w4[i] = ly * lx; __bang_mul(w4, nram_y, nram_x, valid_num); +} +template +__mlu_func__ void getUniquePos(const uint32_t valid_num, uint32_t *pos1, + uint32_t *pos2, uint32_t *pos3, uint32_t *pos4, + T *w1, T *w2, T *w3, T *w4, + uint32_t &unique_num) { // unique - unique_num = 0; for (int i = 0; i < valid_num; ++i) { if (w1[i] < 0) { continue; @@ -278,6 +283,10 @@ __mlu_func__ void bilinearInterpolatePosWeight( w3[i] += w3[j]; w4[i] += w4[j]; w1[j] = -1; + } else { + // if the dot is not at the same position, + // follow-up dots are not at the same positon + break; } } if (unique_num != i) { @@ -296,71 +305,29 @@ __mlu_func__ void bilinearInterpolatePosWeight( template __mlu_func__ void handleChannels(const T *input, uint32_t deal_channels, - uint32_t valid_num, T *w1, T *w2, T *w3, T *w4, - uint32_t *pos1, uint32_t *pos2, uint32_t *pos3, - uint32_t *pos4, int32_t &pos_offset, T *v, - T *v_t, T *val) { + uint32_t vec_num, T *w, uint32_t *pos, + int32_t &pos_offset, T *val, T *v_t) { // gather dst,src,offset must 64B align int32_t align_offset = ((int64_t)input & 0x3f); if (align_offset != 0) { input = (T *)((int64_t)input - align_offset); } if 
(align_offset != pos_offset) { - __bang_add_scalar(pos1, pos1, align_offset - pos_offset, valid_num); - __bang_add_scalar(pos2, pos2, align_offset - pos_offset, valid_num); - __bang_add_scalar(pos3, pos3, align_offset - pos_offset, valid_num); - __bang_add_scalar(pos4, pos4, align_offset - pos_offset, valid_num); + __bang_add_scalar(pos, pos, align_offset - pos_offset, vec_num); pos_offset = align_offset; } - uint32_t hwc_num = deal_channels * valid_num; - - __gather(v, input, pos1, deal_channels * sizeof(T), GDRAM2NRAM, - deal_channels * sizeof(T), valid_num); - if (deal_channels == 1) { - __bang_mul(val, v, w1, valid_num); - } else { - __bang_transpose(val, v, valid_num, deal_channels); - __bang_cycle_mul(val, val, w1, hwc_num, valid_num); - } - - __gather(v, input, pos2, deal_channels * sizeof(T), GDRAM2NRAM, - deal_channels * sizeof(T), valid_num); - if (deal_channels == 1) { - __bang_mul(v_t, v, w2, valid_num); - } else { - __bang_transpose(v_t, v, valid_num, deal_channels); - __bang_cycle_mul(v_t, v_t, w2, hwc_num, valid_num); - } - __bang_add(val, val, v_t, hwc_num); - - __gather(v, input, pos3, deal_channels * sizeof(T), GDRAM2NRAM, - deal_channels * sizeof(T), valid_num); - if (deal_channels == 1) { - __bang_mul(v_t, v, w3, valid_num); - } else { - __bang_transpose(v_t, v, valid_num, deal_channels); - __bang_cycle_mul(v_t, v_t, w3, hwc_num, valid_num); - } - __bang_add(val, val, v_t, hwc_num); - - __gather(v, input, pos4, deal_channels * sizeof(T), GDRAM2NRAM, - deal_channels * sizeof(T), valid_num); - if (deal_channels == 1) { - __bang_mul(v_t, v, w4, valid_num); - } else { - __bang_transpose(v_t, v, valid_num, deal_channels); - __bang_cycle_mul(v_t, v_t, w4, hwc_num, valid_num); - } - __bang_add(val, val, v_t, hwc_num); + uint32_t hwc_num = deal_channels * vec_num; + __gather(val, input, pos, deal_channels * sizeof(T), GDRAM2NRAM, + deal_channels * sizeof(T), vec_num); if (deal_channels != 1) { - __bang_transpose(v_t, val, deal_channels, valid_num); - 
__bang_sumpool(v, v_t, deal_channels, 1, valid_num, 1, valid_num, - valid_num, 1); + __bang_transpose(v_t, val, vec_num, deal_channels); + __bang_cycle_mul(v_t, v_t, w, hwc_num, vec_num); + __bang_transpose(val, v_t, deal_channels, vec_num); } else { - __bang_sumpool(v, val, deal_channels, 1, valid_num, 1, valid_num, - valid_num, 1); + __bang_mul(val, val, w, vec_num); } + __bang_sumpool(v_t, val, deal_channels, 1, vec_num, 1, vec_num, vec_num, 1); } template @@ -394,10 +361,10 @@ __mlu_global__ void roiAlignRotatedForward( T *output_channels = order + NFU_ALIGN_SIZE / sizeof(T); /* bilinear_interpolate, coordinates can be reused - | | | volatile | - | name | bin_h,bin_w,y,x,w1,w2,w3,w4 | p1,p2,p3,p4 | v1,v2,v3 | - | type | T | uint32_t | T | - | num | bin_hw_order_num | | + | | | volatile | + | name | bin_h,bin_w,y,x,w1,w2,w3,w4 | p1,p2,p3,p4 | val, val_t | + | type | T | uint32_t | T | + | num | bin_hw_order_num | | */ T *bin_h_order = output_channels + cache_channels_num; T *bin_w_order = bin_h_order + bin_hw_order_num; @@ -416,11 +383,10 @@ __mlu_global__ void roiAlignRotatedForward( 8 * bin_hw_order_num * sizeof(T) + 4 * bin_hw_order_num * sizeof(uint32_t); uint32_t max_v_size = - FLOOR_ALIGN((MAX_NRAM_SIZE - fixed_size) / 3, NFU_ALIGN_SIZE); + FLOOR_ALIGN((MAX_NRAM_SIZE - fixed_size) / 2, NFU_ALIGN_SIZE); - T *nram_vi = (T *)(nram_pos4 + bin_hw_order_num); - T *nram_vi_t = nram_vi + max_v_size / sizeof(T); - T *nram_val = nram_vi_t + max_v_size / sizeof(T); + T *nram_val = (T *)(nram_pos4 + bin_hw_order_num); + T *nram_val_t = nram_val + max_v_size / sizeof(T); // If dynamic creation of sequences bool construct_order = true; @@ -511,36 +477,50 @@ __mlu_global__ void roiAlignRotatedForward( } // bilinearInterpolate - uint32_t unique_num = 0; - bilinearInterpolatePosWeight(height, width, nram_y, nram_x, valid_num, - nram_pos1, nram_pos2, nram_pos3, - nram_pos4, nram_w1, nram_w2, nram_w3, - nram_w4, unique_num); + bilinearInterpolatePosWeight( + height, width, 
nram_y, nram_x, valid_num, nram_pos1, nram_pos2, + nram_pos3, nram_pos4, nram_w1, nram_w2, nram_w3, nram_w4); + // pos de-duplication + uint32_t unique_num = 0; + getUniquePos(valid_num, nram_pos1, nram_pos2, nram_pos3, nram_pos4, + nram_w1, nram_w2, nram_w3, nram_w4, unique_num); + + // Combine Four Discrete Data Sets into One + if (bin_hw_order_num != unique_num) { + __sync(); + __memcpy_async(nram_w1 + unique_num, nram_w2, + unique_num * sizeof(T), NRAM2NRAM); + __memcpy_async(nram_pos1 + unique_num, nram_pos2, + unique_num * sizeof(uint32_t), NRAM2NRAM); + __memcpy_async(nram_w1 + 2 * unique_num, nram_w3, + unique_num * sizeof(T), NRAM2NRAM); + __memcpy_async(nram_pos1 + 2 * unique_num, nram_pos3, + unique_num * sizeof(uint32_t), NRAM2NRAM); + __memcpy_async(nram_w1 + 3 * unique_num, nram_w4, + unique_num * sizeof(T), NRAM2NRAM); + __memcpy_async(nram_pos1 + 3 * unique_num, nram_pos4, + unique_num * sizeof(uint32_t), NRAM2NRAM); + __sync(); + } + uint32_t vec_num = 4 * unique_num; __bang_mul_scalar(nram_pos1, nram_pos1, channels * sizeof(T), - unique_num); - __bang_mul_scalar(nram_pos2, nram_pos2, channels * sizeof(T), - unique_num); - __bang_mul_scalar(nram_pos3, nram_pos3, channels * sizeof(T), - unique_num); - __bang_mul_scalar(nram_pos4, nram_pos4, channels * sizeof(T), - unique_num); + vec_num); int pos_offset = 0; - uint32_t max_once_c = max_v_size / sizeof(T) / unique_num; - // Same coordinates in different channels + uint32_t max_once_c = max_v_size / sizeof(T) / vec_num; + + // Same coordinates in different channels for (uint32_t ci = 0; ci < cur_cache_c; ci += max_once_c) { uint32_t cur_c = max_once_c; if (cur_c + ci > cur_cache_c) { cur_c = cur_cache_c - ci; } - handleChannels( - input_dram + roi_batch_ind * height * width * channels + - c_cache_i + ci, - cur_c, unique_num, nram_w1, nram_w2, nram_w3, nram_w4, - nram_pos1, nram_pos2, nram_pos3, nram_pos4, pos_offset, - nram_vi, nram_vi_t, nram_val); - __bang_add(output_channels + ci, output_channels + 
ci, nram_vi, + handleChannels(input_dram + c_cache_i + ci + + roi_batch_ind * height * width * channels, + cur_c, vec_num, nram_w1, nram_pos1, pos_offset, + nram_val, nram_val_t); + __bang_add(output_channels + ci, output_channels + ci, nram_val_t, cur_c); } } From 63caf1a6f3ae30adc6f3b9764513bfec3aa6ca7e Mon Sep 17 00:00:00 2001 From: chqy99 <1216494776@qq.com> Date: Mon, 25 Nov 2024 10:30:37 +0800 Subject: [PATCH 3/4] [Feature](mluOpRoiAlignRotatedForward): package div_scalar --- .../roi_align_rotated_forward_vector.mlu | 57 +++++++------------ 1 file changed, 21 insertions(+), 36 deletions(-) diff --git a/kernels/roi_align_rotated/roi_align_rotated_forward_vector.mlu b/kernels/roi_align_rotated/roi_align_rotated_forward_vector.mlu index 893bc9563..217d4df28 100644 --- a/kernels/roi_align_rotated/roi_align_rotated_forward_vector.mlu +++ b/kernels/roi_align_rotated/roi_align_rotated_forward_vector.mlu @@ -32,6 +32,23 @@ __nram__ int8_t nram_buffer[MAX_NRAM_SIZE]; #define ROI_OFFSET 6 +template +__mlu_func__ void mluopDivScalar(T *dst, T *src, T value, uint32_t num) { + if constexpr (std::is_same::value) { + __asm__ volatile( + "div.scalar.nram.f16 [%[dst]], [%[src0]], " + "%[src1], %[num];\n\t" ::[dst] "r"(dst), + [ src0 ] "r"(src), [ src1 ] "r"(value), + [ num ] "r"(num)); + } else { + __asm__ volatile( + "div.scalar.nram.f32 [%[dst]], [%[src0]], " + "%[src1], %[num];\n\t" ::[dst] "r"(dst), + [ src0 ] "r"(src), [ src1 ] "r"(value), + [ num ] "r"(num)); + } +} + template __mlu_func__ void getRoiInfo(const T *rois_dram, int roi_idx, const mluOpRoiAlignRotatedParams ¶ms, @@ -103,29 +120,8 @@ __mlu_func__ void getXYorder(T *bh_order, T *bw_order, T h_offset, T w_offset, __bang_mul_scalar(aux2, aux2, bin_size_w, deal_num); // Coordinate calculation requires high precision. 
// must use div - if constexpr (std::is_same::value) { - __asm__ volatile( - "div.scalar.nram.f16 [%[dst]], [%[src0]], " - "%[src1], %[num];\n\t" ::[dst] "r"(aux1), - [ src0 ] "r"(aux1), [ src1 ] "r"((T)roi_bin_grid_h), - [ num ] "r"(deal_num)); - __asm__ volatile( - "div.scalar.nram.f16 [%[dst]], [%[src0]], " - "%[src1], %[num];\n\t" ::[dst] "r"(aux2), - [ src0 ] "r"(aux2), [ src1 ] "r"((T)roi_bin_grid_w), - [ num ] "r"(deal_num)); - } else { - __asm__ volatile( - "div.scalar.nram.f32 [%[dst]], [%[src0]], " - "%[src1], %[num];\n\t" ::[dst] "r"(aux1), - [ src0 ] "r"(aux1), [ src1 ] "r"((T)roi_bin_grid_h), - [ num ] "r"(deal_num)); - __asm__ volatile( - "div.scalar.nram.f32 [%[dst]], [%[src0]], " - "%[src1], %[num];\n\t" ::[dst] "r"(aux2), - [ src0 ] "r"(aux2), [ src1 ] "r"((T)roi_bin_grid_w), - [ num ] "r"(deal_num)); - } + mluopDivScalar(aux1, aux1, (T)roi_bin_grid_h, deal_num); + mluopDivScalar(aux2, aux2, (T)roi_bin_grid_w, deal_num); __bang_add_scalar(aux1, aux1, h_bias, deal_num); __bang_add_scalar(aux2, aux2, w_bias, deal_num); @@ -525,19 +521,8 @@ __mlu_global__ void roiAlignRotatedForward( } } } - if constexpr (std::is_same::value) { - __asm__ volatile( - "div.scalar.nram.f16 [%[dst]], [%[src0]], " - "%[src1], %[num];\n\t" ::[dst] "r"(output_channels), - [ src0 ] "r"(output_channels), [ src1 ] "r"((T)count), - [ num ] "r"(cur_cache_c)); - } else { - __asm__ volatile( - "div.scalar.nram.f32 [%[dst]], [%[src0]], " - "%[src1], %[num];\n\t" ::[dst] "r"(output_channels), - [ src0 ] "r"(output_channels), [ src1 ] "r"((T)count), - [ num ] "r"(cur_cache_c)); - } + mluopDivScalar(output_channels, output_channels, (T)count, + cur_cache_c); __memcpy(output_dram + bin_i * channels + c_cache_i, output_channels, cur_cache_c * sizeof(T), NRAM2GDRAM); } From 84fb471e024c05e1544dc628e81840bd4a3ccffb Mon Sep 17 00:00:00 2001 From: chqy99 <1216494776@qq.com> Date: Mon, 25 Nov 2024 11:08:59 +0800 Subject: [PATCH 4/4] [Feature](mluOpRoiAlignRotatedForward): fix comments --- 
.../roi_align_rotated_forward_vector.mlu | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/kernels/roi_align_rotated/roi_align_rotated_forward_vector.mlu b/kernels/roi_align_rotated/roi_align_rotated_forward_vector.mlu index 217d4df28..d226df82c 100644 --- a/kernels/roi_align_rotated/roi_align_rotated_forward_vector.mlu +++ b/kernels/roi_align_rotated/roi_align_rotated_forward_vector.mlu @@ -130,14 +130,14 @@ __mlu_func__ void getXYorder(T *bh_order, T *bw_order, T h_offset, T w_offset, // x = h_idx_in_bin * sin_theta + // w_idx_in_bin * cos_theta + roi_center_w if constexpr (std::is_same::value) { - // calu y + // calculate y __bang_mul_scalar(aux3, aux2, sin_theta, deal_num); __asm__ volatile( "fuse.nram.f16 [%[dst]], %[num], [%[src0]], .mul(%[cos_v]), " ".sub([%[src1]]), .add(%[rh]);\n\t" ::[dst] "r"(nram_y), [ num ] "r"(deal_num), [ src0 ] "r"(aux1), [ cos_v ] "r"(cos_theta), [ src1 ] "r"(aux3), [ rh ] "r"(roi_center_h)); - // calu x + // calculate x __bang_mul_scalar(aux3, aux2, cos_theta, deal_num); __asm__ volatile( "fuse.nram.f16 [%[dst]], %[num], [%[src0]], .mul(%[sin_v]), " @@ -145,14 +145,14 @@ __mlu_func__ void getXYorder(T *bh_order, T *bw_order, T h_offset, T w_offset, [ num ] "r"(deal_num), [ src0 ] "r"(aux1), [ sin_v ] "r"(sin_theta), [ src1 ] "r"(aux3), [ rw ] "r"(roi_center_w)); } else { - // calu y + // calculate y __bang_mul_scalar(aux3, aux2, sin_theta, deal_num); __asm__ volatile( "fuse.nram.f32 [%[dst]], %[num], [%[src0]], .mul(%[cos_v]), " ".sub([%[src1]]), .add(%[rh]);\n\t" ::[dst] "r"(nram_y), [ num ] "r"(deal_num), [ src0 ] "r"(aux1), [ cos_v ] "r"(cos_theta), [ src1 ] "r"(aux3), [ rh ] "r"(roi_center_h)); - // calu x + // calculate x __bang_mul_scalar(aux3, aux2, cos_theta, deal_num); __asm__ volatile( "fuse.nram.f32 [%[dst]], %[num], [%[src0]], .mul(%[sin_v]), " @@ -218,7 +218,7 @@ __mlu_func__ void bilinearInterpolatePosWeight( } // w1 cache x_low, w3 cache x_high - // pos1 cache y_low, pos2 cache 
y_high + // pos1 cache y_low, pos3 cache y_high ((uint32_t *)w1)[i] = x_low; ((uint32_t *)w3)[i] = x_high; pos1[i] = y_low; @@ -237,7 +237,7 @@ __mlu_func__ void bilinearInterpolatePosWeight( // pos2 = y_low * width + x_high; __bang_add(pos2, pos2, (uint32_t *)w3, valid_num); - // pos3 = y_high * width; + // pos4 = y_high * width; __bang_mul_scalar(pos4, pos3, width, valid_num); // pos3 = y_high * width + x_low; __bang_add(pos3, pos4, (uint32_t *)w1, valid_num); @@ -251,13 +251,13 @@ __mlu_func__ void bilinearInterpolatePosWeight( __bang_mul_scalar(w4, nram_x, -1.0, valid_num); __bang_add_scalar(w4, w4, 1.0, valid_num); - // w1[i] = hy * hx; + // w1 = hy * hx; __bang_mul(w1, w3, w4, valid_num); - // w2[i] = hy * lx; + // w2 = hy * lx; __bang_mul(w2, w3, nram_x, valid_num); - // w3[i] = ly * hx; + // w3 = ly * hx; __bang_mul(w3, nram_y, w4, valid_num); - // w4[i] = ly * lx; + // w4 = ly * lx; __bang_mul(w4, nram_y, nram_x, valid_num); }