-
Notifications
You must be signed in to change notification settings - Fork 106
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP]add mluop cholesky #1146
base: master
Are you sure you want to change the base?
[WIP]add mluop cholesky #1146
Conversation
} | ||
|
||
__mlu_global__ void inverse_kernel(int batch, float* d_input, int ld_input, | ||
int stride_input, float* d_output, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
.mlu文件中的一些共性问题:
1.关键步骤缺少不要的注释
2.sync_cluster存在多用的问题,建议对于每个sync和sync_cluster加上必要的注释说明同步了什么操作,目的是啥
3.调用cnnl代码的逻辑不要放到.mlu中,.mlu文件本质上是deivce上的函数,调用cnnl的接口是个host侧行为
4.不要自己创建cnrtQueue,统一使用外部传入的handle->queue
5.涉及到cnrtDim的,用policyFunc函数封装
6.变量命名不清晰,建议不要缩写名字,提升可读性
const int lda, int width, float* sram_buffer, | ||
float* dst) { | ||
int id = taskId % 4; | ||
int span = CPOTF_NB; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
.mlu文件中的一些共性问题:
1.关键步骤缺少必要的注释
2.sync_cluster存在多用的问题,建议对于每个sync和sync_cluster加上必要的注释说明同步了什么操作,目的是啥
3.调用cnnl代码的逻辑不要放到.mlu中,.mlu文件本质上是deivce上的函数,调用cnnl的接口是个host侧行为
4.不要自己创建cnrtQueue,统一使用外部传入的handle->queue
5.涉及到cnrtDim的,用policyFunc函数封装
6.变量命名不清晰,建议不要缩写名字,提升可读性
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
*************************************************************************/ | ||
|
||
#include "cholesky.h" | ||
#define COMPLEX_OFFSET(A, off) (((float*)A) + (2 * (off))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里的目的是啥,建议加上注释
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
kernels/cholesky/cholesky.h
Outdated
#define CLUSTER_NUM 1 | ||
#define M (TASK_NUM * POTF_NB) | ||
#define ZERO 0.0 | ||
#define SHARED_MEM_SIZE (((M * POTF_NB / TASK_NUM * 4) + (POTF_NB * POTF_NB))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议加上注释,说明空间是怎么使用的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
#define M (TASK_NUM * POTF_NB) | ||
#define ZERO 0.0 | ||
#define SHARED_MEM_SIZE (((M * POTF_NB / TASK_NUM * 4) + (POTF_NB * POTF_NB))) | ||
#define OFFSET_ROW(A, i, j) A + ((i) * (lda) + (j)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议加上注释,说明这些offset宏的目的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
测试报告最后也更新下吧,另外测试报告中性能部分建议测试下float/complex upper=false和upper=true的性能 |
随机测试时较多case会出现精度问题&coredum,建议使用下面的脚本生成下case并做自我测试
|
另外针对用户感知到的一些tensor信息,如下所列,支持的做下测试,不支持的可以参考下其他算子做好参数拦截 |
测试的generator代码中也有问题,会有下面的问题,也请修复下 |
请问具体是在哪个规模下出现了精度问题或者coredump问题呢,麻烦举出一些例子我优先复现然后修复 |
出错的case: |
kernels/cholesky/cholesky_union1.mlu
Outdated
if (batch == 1) { | ||
func_type = CNRT_FUNC_TYPE_UNION1; | ||
} else if (batch == 2) { | ||
func_type = CNRT_FUNC_TYPE_UNION2; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
板卡上不一定有这个类型,建议参考这里进行设置:
Line 191 in 5ae8c94
*k_type = mluop::runtime::getJobLimitCapabilityCnrtFuncType(handle); |
mlu-ops/kernels/fft/c2c_fft/c2c_fft_host.cpp
Line 1668 in 5ae8c94
int task_type = mluop::runtime::getJobLimitCapability(handle); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改为U1类型
kernels/cholesky/cholesky.cpp
Outdated
type_size * size_a * lda * ((uint64_t)batch_size - 16), | ||
CNRT_MEM_TRANS_DIR_DEV2DEV)); | ||
} else { | ||
CNRT_CHECK(cnrtMemcpy(d_output, workspace, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不建议使用cnrtMemcpy和cnrtMemset,cnrtQueueSync,会对上层使用mlu_graph有问题
建议cnrtMemcpy使用片上的__memcpy来替换
cnrtMemset使用片上设置数据来替换
cnrtQueueSync可以去掉,对于同一个queue来说,queue内的kernel调用(使用<<<>>>)是串行的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
kernels/cholesky/cholesky_union1.mlu
Outdated
func_type = CNRT_FUNC_TYPE_UNION4; | ||
carry_batch = 4; | ||
} else { | ||
func_type = CNRT_FUNC_TYPE_UNION8; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里要根据板卡的实际最大cluster数目来,这里写死了U8,有些板卡没有U8这个类型
可以参考这里的写法
mlu-ops/kernels/fft/rfft/rfft_host.cpp
Line 1032 in 662a162
int task_type = mluop::runtime::getJobLimitCapability(handle); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
其他类似的写死U8的地方也请一起修改下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
|
||
|
||
|
||
if (result_mul) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
结果验收上参考svd,需要验收输出结果L或者U, L@LT,以及output结果是下三角或者上三角,当前的处理,只处理了第一种方式,需要增加另外两种的测试。
关于结果这块的比较上:当前result_mul默认是false,只测试了结果的上下三角这块
同一个case还需要同时测试result_mul是true时,结果的还原性
另外还需要增加测试结果一定是上三角或者下三角的测试,这个可以参考https://github.com/pytorch/pytorch/blob/main/test/test_linalg.py#L622
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
另外generator的逻辑也麻烦根据上面的comments做下update
void cpu_compute(float* cpu_c, int n_, int ldda_, bool upper_, bool trans_, | ||
mluOpDataType_t type_) { | ||
if (trans_) { | ||
for (int64_t i = 0; i < n_; i++) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cpu计算过程这里加上关键的计算步骤吧,方便后续维护和阅读
if (parser_->device() != CPU) { | ||
if (result_mul) { | ||
for (int i = 0; i < batch_size_; i++) { | ||
if (type_ == MLUOP_DTYPE_FLOAT) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里做的trans和fill_zeo,设置1加下注释说明下意图,方便理解和后续维护
Thanks for your contribution and we appreciate it a lot. 🚀🚀
1. Motivation
Please describe your motivation and the goal you want to achieve through this pull request.
2. Modification
Please briefly describe what modification is made in this pull request, and indicate where to make the modification.
Are new test cases added? If so, please post the corresponding generator-PR link here.
3. Test Report
If you want to know how to do operator testing, you can see GTest-User-Guide-zh.
3.1 Modification Details
3.1.1 Accuracy Acceptance Standard
For static threshold standard details, see: MLU-OPS™ Accuracy Acceptance Standard.
3.1.2 Operator Scheme checklist
3.2 Accuracy Test
3.2.1 Accuracy Test
If you have checked the following items, please tick the relevant box.
3.2.2 Parameter Check
Test Point-1:
When a new operator is submitted, the test points are given and the test results are stated
. Acceptance Standard:Normal error
.Test Point-2:
Whether illegal parameters are passed
. Acceptance Standard:Normal error
.3.3 Performance Test
See MLU-OPS™ Performance Acceptance Standard for details.
Platform:MLU370
Platform:MLU590
3.4 Summary Analysis
Please give a brief overview here, if you want to note and summarize the content.