Skip to content

Commit

Permalink
[Feature](bangc-ops): add concat binary operator.
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengleiZL committed Oct 25, 2023
1 parent 232ec2c commit 7840df0
Show file tree
Hide file tree
Showing 9 changed files with 596 additions and 100 deletions.
41 changes: 41 additions & 0 deletions bangc-ops/kernels/concat/concat.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*************************************************************************
* Copyright (C) [2023] by Cambricon, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a
* copy of this software and associated documentation files (the
* "Software"), to deal in the Software without restriction, including
* without limitation the rights to use, copy, modify, merge, publish,
* distribute, sublicense, and/or sell copies of the Software, and to
* permit persons to whom the Software is furnished to do so, subject to
* the following conditions:
*
* The above copyright notice and this permission notice shall be included
* in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernels/kernel_wrapper/wrapper.h"

mluOpStatus_t MLUOP_WIN_API mluOpConcat(
mluOpHandle_t handle,
const int concat_num,
const int axis,
const mluOpTensorDescriptor_t inputs_desc[],
const void *const inputs[],
void *workspace,
size_t workspace_size,
const mluOpTensorDescriptor_t output_desc,
void *output) {
ConcatWrapper wrapper;
mluOpStatus_t ret = wrapper.invoke(handle, concat_num, axis, inputs_desc,
inputs, workspace, workspace_size,
output_desc, output);
return ret;
}

Binary file modified bangc-ops/kernels/kernel_wrapper/lib/libextops.a
Binary file not shown.
159 changes: 59 additions & 100 deletions bangc-ops/kernels/kernel_wrapper/wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@
#include "mlu_op.h"
#include "export_statement.h"

#define KERNEL_REGISTER(OP_NAME, PARAMS, ...) \
class OP_NAME##Wrapper { \
public: \
OP_NAME##Wrapper() {} \
~OP_NAME##Wrapper() {} \
mluOpStatus_t invoke(PARAMS); \
std::string op_name = #OP_NAME; \
};
#define KERNEL_REGISTER(OP_NAME, PARAMS, ...) \
class OP_NAME##Wrapper { \
public: \
OP_NAME##Wrapper() {} \
~OP_NAME##Wrapper() {} \
mluOpStatus_t invoke(PARAMS); \
std::string op_name = #OP_NAME; \
};

/* Kernel param types macro defination */

Expand Down Expand Up @@ -148,74 +148,38 @@
const mluOpTensorDescriptor_t, void *, void *, size_t

#define REDUCE_PARAM_TYPE \
mluOpHandle_t, const mluOpReduceDescriptor_t, void *, size_t, \
const void *, const mluOpTensorDescriptor_t, const void *, \
const size_t, void *, const void *, const mluOpTensorDescriptor_t, \
void *

#define ROIALIGNBACKWARD_PARAM_TYPE \
mluOpHandle_t, \
const float, \
const int, \
const bool, \
const mluOpTensorDescriptor_t, \
const void *, \
const mluOpTensorDescriptor_t, \
const void *, \
const mluOpTensorDescriptor_t, \
void *

#define ROIALIGNBACKWARD_V2_PARAM_TYPE \
mluOpHandle_t, \
const mluOpTensorDescriptor_t, \
const void *, \
const mluOpTensorDescriptor_t, \
const void *, \
const mluOpTensorDescriptor_t, \
const void *, \
const mluOpTensorDescriptor_t, \
const void *, \
const float, \
const int, \
const bool, \
const int, \
const mluOpTensorDescriptor_t, \
void *

#define ROIPOOLINGFORWARD_PARAM_TYPE \
mluOpHandle_t, \
mluOpPoolingMode_t, \
const mluOpTensorDescriptor_t, \
const void *, \
const mluOpTensorDescriptor_t, \
const void *, \
float, \
const mluOpTensorDescriptor_t, \
void *, \
int *
mluOpHandle_t, const mluOpReduceDescriptor_t, void *, size_t, const void *, \
const mluOpTensorDescriptor_t, const void *, const size_t, void *, \
const void *, const mluOpTensorDescriptor_t, void *

#define ROIALIGNBACKWARD_PARAM_TYPE \
mluOpHandle_t, const float, const int, const bool, \
const mluOpTensorDescriptor_t, const void *, \
const mluOpTensorDescriptor_t, const void *, \
const mluOpTensorDescriptor_t, void *

#define ROIPOOLINGBACKWARD_PARAM_TYPE \
mluOpHandle_t, \
mluOpPoolingMode_t, \
const mluOpTensorDescriptor_t, \
const void *, \
const mluOpTensorDescriptor_t, \
const void *, \
const mluOpTensorDescriptor_t, \
const int *, \
const float, \
const mluOpTensorDescriptor_t, \
void *
#define ROIALIGNBACKWARD_V2_PARAM_TYPE \
mluOpHandle_t, const mluOpTensorDescriptor_t, const void *, \
const mluOpTensorDescriptor_t, const void *, \
const mluOpTensorDescriptor_t, const void *, \
const mluOpTensorDescriptor_t, const void *, const float, const int, \
const bool, const int, const mluOpTensorDescriptor_t, void *

#define ROIPOOLINGFORWARD_PARAM_TYPE \
mluOpHandle_t, mluOpPoolingMode_t, const mluOpTensorDescriptor_t, \
const void *, const mluOpTensorDescriptor_t, const void *, float, \
const mluOpTensorDescriptor_t, void *, int *

#define ROIPOOLINGBACKWARD_PARAM_TYPE \
mluOpHandle_t, mluOpPoolingMode_t, const mluOpTensorDescriptor_t, \
const void *, const mluOpTensorDescriptor_t, const void *, \
const mluOpTensorDescriptor_t, const int *, const float, \
const mluOpTensorDescriptor_t, void *

#define TRANSFORM_PARAM_TYPE \
mluOpHandle_t, \
const mluOpPointerMode_t, \
const void *, \
const mluOpTensorDescriptor_t, \
const void *, \
const void *, \
const mluOpTensorDescriptor_t, \
void *
#define TRANSFORM_PARAM_TYPE \
mluOpHandle_t, const mluOpPointerMode_t, const void *, \
const mluOpTensorDescriptor_t, const void *, const void *, \
const mluOpTensorDescriptor_t, void *

#define SYNCBATCHNORMSTATS_PARAM_TYPE \
mluOpHandle_t, const mluOpTensorDescriptor_t, const void *, const float, \
Expand Down Expand Up @@ -273,31 +237,25 @@
const mluOpTensorDescriptor_t, const void *, \
const mluOpTensorDescriptor_t, void *

#define SYNCBATCHNORMBACKWARDELEMT_V2_PARAM_TYPE \
mluOpHandle_t, const mluOpTensorDescriptor_t, const void *, \
const mluOpTensorDescriptor_t, \
const void *, const mluOpTensorDescriptor_t, const void *, \
const mluOpTensorDescriptor_t, const void *, \
const mluOpTensorDescriptor_t, const void *, \
const mluOpTensorDescriptor_t, const void *, \
const mluOpTensorDescriptor_t, const void *, \
const mluOpTensorDescriptor_t, const void *, \
#define SYNCBATCHNORMBACKWARDELEMT_V2_PARAM_TYPE \
mluOpHandle_t, const mluOpTensorDescriptor_t, const void *, \
const mluOpTensorDescriptor_t, const void *, \
const mluOpTensorDescriptor_t, const void *, \
const mluOpTensorDescriptor_t, const void *, \
const mluOpTensorDescriptor_t, const void *, \
const mluOpTensorDescriptor_t, const void *, \
const mluOpTensorDescriptor_t, const void *, \
const mluOpTensorDescriptor_t, const void *, \
const mluOpTensorDescriptor_t, void *diff_x

#define TRANSFORM_PARAM_TYPE \
mluOpHandle_t, \
const mluOpPointerMode_t, \
const void *, \
const mluOpTensorDescriptor_t, \
const void *, \
const void *, \
const mluOpTensorDescriptor_t, \
void *
#define STRIDEDSLICE_PARAM_TYPE \
mluOpHandle_t, const mluOpTensorDescriptor_t, const void *, const int *, \
const int *, const int *, const mluOpTensorDescriptor_t, void *

#define STRIDEDSLICE_PARAM_TYPE \
mluOpHandle_t, const mluOpTensorDescriptor_t, const void *, \
const int *, const int *, const int *, \
const mluOpTensorDescriptor_t, void *
#define CONCAT_PARAM_TYPE \
mluOpHandle_t, const int, const int, const mluOpTensorDescriptor_t[], \
const void *const *, void *, size_t, const mluOpTensorDescriptor_t, \
void *

/* Kernel register */
KERNEL_REGISTER(addN, ADDN_PARAM_TYPE);
Expand Down Expand Up @@ -331,13 +289,14 @@ KERNEL_REGISTER(SyncBatchNormGatherStatsWithCounts,
SYNCBATCHNORMGATHERSTATSWITHCOUNTS_PARAM_TYPE);
KERNEL_REGISTER(SyncBatchNormElemt, SYNCBATCHNORMELEMT_PARAM_TYPE);
KERNEL_REGISTER(SyncBatchnormBackwardReduce,
SYNCBATCHNORMBACKWADREDUCE_PARAM_TYPE);
SYNCBATCHNORMBACKWADREDUCE_PARAM_TYPE);
KERNEL_REGISTER(SyncBatchnormBackwardReduceV2,
SYNCBATCHNORMBACKWADREDUCE_V2_PARAM_TYPE);
SYNCBATCHNORMBACKWADREDUCE_V2_PARAM_TYPE);
KERNEL_REGISTER(SyncBatchNormBackwardElemt,
SYNCBATCHNORMBACKWARDELEMT_PARAM_TYPE);
SYNCBATCHNORMBACKWARDELEMT_PARAM_TYPE);
KERNEL_REGISTER(SyncBatchNormBackwardElemtV2,
SYNCBATCHNORMBACKWARDELEMT_V2_PARAM_TYPE);
SYNCBATCHNORMBACKWARDELEMT_V2_PARAM_TYPE);
KERNEL_REGISTER(transform, TRANSFORM_PARAM_TYPE);
KERNEL_REGISTER(StridedSlice, STRIDEDSLICE_PARAM_TYPE);
KERNEL_REGISTER(Concat, CONCAT_PARAM_TYPE);
#endif // KERNELS_KERNEL_WRAPPER_WRAPPER_H
Loading

0 comments on commit 7840df0

Please sign in to comment.