Skip to content

Commit

Permalink
[MINOR] Type inference for conv2d() forward pass
Browse files Browse the repository at this point in the history
This change adds type inference to the conv2d operation. It merely passes the input type of the matrix to the output type and sets the dimensions to the index/size type.
  • Loading branch information
corepointer committed Apr 26, 2024
1 parent 25f2a3c commit 39cd4e6
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
8 changes: 8 additions & 0 deletions src/ir/daphneir/DaphneInferTypesOpInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,14 @@ std::vector<Type> daphne::RecodeOp::inferTypes() {
return {resTy, dictTy};
}

std::vector<Type> daphne::Conv2DForwardOp::inferTypes() {
MLIRContext * ctx = getContext();
Type srcTy = getInput().getType().dyn_cast<daphne::MatrixType>();
Builder builder(ctx);

// output matrix of same type as input, height/width dimensions as size/index type
return {srcTy, builder.getIndexType(), builder.getIndexType()};
}
// ****************************************************************************
// Type inference function
// ****************************************************************************
Expand Down
3 changes: 2 additions & 1 deletion src/ir/daphneir/DaphneOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,8 @@ def Daphne_BiasAddForwardOp : Daphne_Op<"biasAddForward", [ CUDASupport ]> {
// Convolution
// ----------------------------------------------------------------------------

def Daphne_Conv2DForwardOp : Daphne_Op<"Convolution_Forward", [ CUDASupport ]> {
def Daphne_Conv2DForwardOp : Daphne_Op<"Convolution_Forward",
[ DeclareOpInterfaceMethods<InferTypesOpInterface>, CUDASupport ]> {
let arguments = (ins
MatrixOf<[FloatScalar]>:$input, MatrixOf<[FloatScalar]>:$filter, MatrixOf<[FloatScalar]>:$bias,
// input shape
Expand Down

0 comments on commit 39cd4e6

Please sign in to comment.