diff --git a/src/ir/daphneir/DaphneInferTypesOpInterface.cpp b/src/ir/daphneir/DaphneInferTypesOpInterface.cpp index 290ee4818..3160a1bf7 100644 --- a/src/ir/daphneir/DaphneInferTypesOpInterface.cpp +++ b/src/ir/daphneir/DaphneInferTypesOpInterface.cpp @@ -544,6 +544,14 @@ std::vector daphne::RecodeOp::inferTypes() { return {resTy, dictTy}; } +std::vector daphne::Conv2DForwardOp::inferTypes() { + MLIRContext * ctx = getContext(); + Type srcTy = getInput().getType().dyn_cast(); + Builder builder(ctx); + + // output matrix of same type as input, height/width dimensions as size/index type + return {srcTy, builder.getIndexType(), builder.getIndexType()}; +} // **************************************************************************** // Type inference function // **************************************************************************** diff --git a/src/ir/daphneir/DaphneOps.td b/src/ir/daphneir/DaphneOps.td index 35b437916..55bcfc144 100644 --- a/src/ir/daphneir/DaphneOps.td +++ b/src/ir/daphneir/DaphneOps.td @@ -820,7 +820,8 @@ def Daphne_BiasAddForwardOp : Daphne_Op<"biasAddForward", [ CUDASupport ]> { // Convolution // ---------------------------------------------------------------------------- -def Daphne_Conv2DForwardOp : Daphne_Op<"Convolution_Forward", [ CUDASupport ]> { +def Daphne_Conv2DForwardOp : Daphne_Op<"Convolution_Forward", + [ DeclareOpInterfaceMethods, CUDASupport ]> { let arguments = (ins MatrixOf<[FloatScalar]>:$input, MatrixOf<[FloatScalar]>:$filter, MatrixOf<[FloatScalar]>:$bias, // input shape