diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 4376bb323cc..e9cf9eccade 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -366,7 +366,9 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i) type_t::fp8e4m3fn_type, type_t::fp8e5m2_type, type_t::int8_type, + type_t::uint8_type, type_t::int32_type, + type_t::uint32_type, type_t::bool_type}; // Preliminary type check. if(not contains(allowed_types, result_type)) diff --git a/src/targets/gpu/mlir.cpp b/src/targets/gpu/mlir.cpp index e96c6a5d7dd..2907f7eb160 100644 --- a/src/targets/gpu/mlir.cpp +++ b/src/targets/gpu/mlir.cpp @@ -322,14 +322,14 @@ struct mlir_program result = mlirF64TypeGet(ctx.get()); else if(as.is_integral()) { - // Note: rocMLIR use signless integer type for tensors types. This - // will translate to signed implementation for current supported - // operations. if(as.is_unsigned()) { - MIGRAPHX_THROW("Unsupported type: " + std::to_string(as.type_enum())); + result = mlirIntegerTypeUnsignedGet(ctx.get(), as.size() * 8); + } + else + { + result = mlirIntegerTypeSignedGet(ctx.get(), as.size() * 8); } - result = mlirIntegerTypeGet(ctx.get(), as.size() * 8); // number of bits } else MIGRAPHX_THROW("Unsupported type: " + std::to_string(as.type_enum())); @@ -718,10 +718,6 @@ struct mlir_program literal r = ins->get_literal(); auto sh = ins->get_shape(); - // mlir works only with signed types. change uint4 to (int4 + unsigned-flag) - if(shape::is_unsigned(sh.type()) and ins->outputs()[0]->name() == "unpack_int4") - sh = ins->get_shape().with_type(shape::int8_type); - MlirType shaped_type = make_mlir_shaped(sh); MlirType tensor_type = rocmlirMIXRShapedTypeAsTensor(shaped_type); MlirAttribute mlir_value_attr = @@ -729,13 +725,6 @@ struct mlir_program ops.add_attributes({{"value", mlir_value_attr}}); } - if(ins->name() == "unpack_int4") - { - auto sh = get_shape(ins); - ops.add_attributes( - {{"isUnsigned", shape::is_unsigned(sh.type())}}); // flag for uint4 - } - if(ins->name() == "convolution" or ins->name() == "dot") { pp = diff --git a/test/gpu/fuse_mlir.cpp b/test/gpu/fuse_mlir.cpp index 06034f77976..37fbf76da01 100644 --- a/test/gpu/fuse_mlir.cpp +++ b/test/gpu/fuse_mlir.cpp @@ -528,6 +528,89 @@ TEST_CASE(dequantizelinear_dot) EXPECT(p1.sort() == p2.sort()); } +TEST_CASE(unsigned_dequantizelinear_dot) +{ + migraphx::program p1; + { + auto* mm = p1.get_main_module(); + + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 3, 5}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::uint8_type, {2, 5, 2}}); + auto scalelit = + mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 2, 2}})); + auto zplit = + mm->add_literal(migraphx::generate_literal({migraphx::shape::uint8_type, {2, 2, 2}})); + + auto unsqueeze1 = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), scalelit); + auto broadcast1 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze1); + auto reshape1 = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 6, 2}}}), broadcast1); + auto scale = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), reshape1); + + auto unsqueeze2 = + mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), zplit); + auto broadcast2 = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze2); + auto reshape2 = + mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 6, 2}}}), broadcast2); + auto zp = mm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), reshape2); + + auto dq = add_pointwise( + p1, "main:pointwise0", {y, scale, zp}, single_pointwise("dequantizelinear")); + auto dot = mm->add_instruction(migraphx::make_op("dot"), x, dq); + mm->add_return({dot}); + } + run_pass(p1); + migraphx::program p2; + { + auto* mm = p2.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 3, 5}}); + auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::uint8_type, {2, 5, 2}}); + auto scalelit = + mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 2, 2}})); + auto zplit = + mm->add_literal(migraphx::generate_literal({migraphx::shape::uint8_type, {2, 2, 2}})); + + auto fused = add_mlir( + p2, + "main:pointwise0:mlir_dot0", + {y, scalelit, zplit, x}, + {"x0", "x1", "x2", "x3"}, + [=](auto* pm, const auto& inputs) { + auto unsqueeze1 = + pm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), inputs[1]); + auto broadcast1 = pm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze1); + auto reshape1 = pm->add_instruction( + migraphx::make_op("reshape", {{"dims", {2, 6, 2}}}), broadcast1); + auto scale = pm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), + reshape1); + + auto unsqueeze2 = + pm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), inputs[2]); + auto broadcast2 = pm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze2); + auto reshape2 = pm->add_instruction( + migraphx::make_op("reshape", {{"dims", {2, 6, 2}}}), broadcast2); + auto zp = pm->add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), + reshape2); + + auto dq = pm->add_instruction( + migraphx::make_op("dequantizelinear"), inputs[0], scale, zp); + auto dot = pm->add_instruction(migraphx::make_op("dot"), inputs[3], dq); + return std::make_tuple(dot->get_operator(), dot); + }); + mm->add_return({fused}); + } + EXPECT(p1.sort() == p2.sort()); +} + TEST_CASE(unpack_int4_dot) { migraphx::program p1; diff --git a/test/gpu/mlir.cpp b/test/gpu/mlir.cpp index 5453f60004a..108bb941fd8 100644 --- a/test/gpu/mlir.cpp +++ b/test/gpu/mlir.cpp @@ -308,10 +308,10 @@ TEST_CASE(quant_dot_add) { std::string mlir_output = R"__migraphx__( module { - func.func @mlir_quant_dot_add(%arg0: !migraphx.shaped<1x5x4xi8, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xi8, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi32, 15x3x1>) -> !migraphx.shaped<1x5x3xi32, 15x3x1> attributes ${attrs} { - %0 = migraphx.quant_dot %arg0, %arg1 : <1x5x4xi8, 20x4x1>, <1x4x3xi8, 12x3x1> -> <1x5x3xi32, 15x3x1> - %1 = migraphx.add %0, %arg2 : <1x5x3xi32, 15x3x1>, <1x5x3xi32, 15x3x1> -> <1x5x3xi32, 15x3x1> - return %1 : !migraphx.shaped<1x5x3xi32, 15x3x1> + func.func @mlir_quant_dot_add(%arg0: !migraphx.shaped<1x5x4xsi8, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xsi8, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xsi32, 15x3x1>) -> !migraphx.shaped<1x5x3xsi32, 15x3x1> attributes ${attrs} { + %0 = migraphx.quant_dot %arg0, %arg1 : <1x5x4xsi8, 20x4x1>, <1x4x3xsi8, 12x3x1> -> <1x5x3xsi32, 15x3x1> + %1 = migraphx.add %0, %arg2 : <1x5x3xsi32, 15x3x1>, <1x5x3xsi32, 15x3x1> -> <1x5x3xsi32, 15x3x1> + return %1 : !migraphx.shaped<1x5x3xsi32, 15x3x1> } } )__migraphx__"; @@ -395,11 +395,11 @@ TEST_CASE(conv_int8_dequantize_quantize) { std::string mlir_output = R"__migraphx__( module { - func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: !migraphx.shaped<2x8x3x3xi8, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xi8, 128x16x4x1>, %arg2: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg3: !migraphx.shaped<1x2x2x2xi32, 8x4x2x1>) -> !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> attributes ${attrs} { - %0 = migraphx.quant_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xi8, 128x16x4x1>, <2x8x3x3xi8, 72x9x3x1> -> <1x2x2x2xi32, 8x4x2x1> - %1 = migraphx.dequantizelinear %0, %arg2, %arg3 : <1x2x2x2xi32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1> - %2 = migraphx.quantizelinear %1, %arg2, %arg3 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> -> <1x2x2x2xi32, 8x4x2x1> - return %2 : !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> + func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: !migraphx.shaped<2x8x3x3xsi8, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xsi8, 128x16x4x1>, %arg2: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg3: !migraphx.shaped<1x2x2x2xsi32, 8x4x2x1>) -> !migraphx.shaped<1x2x2x2xsi32, 8x4x2x1> attributes ${attrs} { + %0 = migraphx.quant_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xsi8, 128x16x4x1>, <2x8x3x3xsi8, 72x9x3x1> -> <1x2x2x2xsi32, 8x4x2x1> + %1 = migraphx.dequantizelinear %0, %arg2, %arg3 : <1x2x2x2xsi32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xsi32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1> + %2 = migraphx.quantizelinear %1, %arg2, %arg3 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xsi32, 8x4x2x1> -> <1x2x2x2xsi32, 8x4x2x1> + return %2 : !migraphx.shaped<1x2x2x2xsi32, 8x4x2x1> } } )__migraphx__"; @@ -458,9 +458,9 @@ TEST_CASE(dot_where) { std::string mlir_output = R"__migraphx__( module { - func.func @mlir_dot_where(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi8, 15x3x1>, %arg3: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes ${attrs} { + func.func @mlir_dot_where(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xsi8, 15x3x1>, %arg3: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes ${attrs} { %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1> - %1 = migraphx.where %arg2, %0, %arg3 : <1x5x3xi8, 15x3x1>, <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1> + %1 = migraphx.where %arg2, %0, %arg3 : <1x5x3xsi8, 15x3x1>, <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1> return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1> } } @@ -487,11 +487,11 @@ module { TEST_CASE(int4_unpack_ir) { std::string mlir_output = R"__migraphx__( -module { - func.func @mlir_unpack_int4(%arg0: !migraphx.shaped<2x1xi8, 1x1>) -> !migraphx.shaped<2x2xi8, 2x1> attributes ${attrs} { - %0 = migraphx.unpack %arg0 {axis = 1 : i64, isUnsigned = false} : <2x1xi8, 1x1> -> <2x2xi8, 2x1> - return %0 : !migraphx.shaped<2x2xi8, 2x1> - } +module { + func.func @mlir_unpack_int4(%arg0: !migraphx.shaped<2x1xsi8, 1x1>) -> !migraphx.shaped<2x2xsi8, 2x1> attributes ${attrs} { + %0 = migraphx.unpack %arg0 {axis = 1 : i64} : <2x1xsi8, 1x1> -> <2x2xsi8, 2x1> + return %0 : !migraphx.shaped<2x2xsi8, 2x1> + } } )__migraphx__"; migraphx::module m; @@ -513,12 +513,12 @@ module { TEST_CASE(int4_unpack_conv) { std::string mlir_output = R"__migraphx__( - module { - func.func @mlir_unpack_int4_quant_convolution(%arg0: !migraphx.shaped<2x8x2x1xi8, 16x2x1x1>, %arg1: !migraphx.shaped<1x8x4x4xi8, 128x16x4x1>) -> !migraphx.shaped<1x2x3x3xi32, 18x9x3x1> attributes ${attrs} { - %0 = migraphx.unpack %arg0 {axis = 3 : i64, isUnsigned = false} : <2x8x2x1xi8, 16x2x1x1> -> <2x8x2x2xi8, 32x4x2x1> - %1 = migraphx.quant_convolution %arg1, %0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xi8, 128x16x4x1>, <2x8x2x2xi8, 32x4x2x1> -> <1x2x3x3xi32, 18x9x3x1> - return %1 : !migraphx.shaped<1x2x3x3xi32, 18x9x3x1> - } +module { + func.func @mlir_unpack_int4_quant_convolution(%arg0: !migraphx.shaped<2x8x2x1xsi8, 16x2x1x1>, %arg1: !migraphx.shaped<1x8x4x4xsi8, 128x16x4x1>) -> !migraphx.shaped<1x2x3x3xsi32, 18x9x3x1> attributes ${attrs} { + %0 = migraphx.unpack %arg0 {axis = 3 : i64} : <2x8x2x1xsi8, 16x2x1x1> -> <2x8x2x2xsi8, 32x4x2x1> + %1 = migraphx.quant_convolution %arg1, %0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xsi8, 128x16x4x1>, <2x8x2x2xsi8, 32x4x2x1> -> <1x2x3x3xsi32, 18x9x3x1> + return %1 : !migraphx.shaped<1x2x3x3xsi32, 18x9x3x1> + } } )__migraphx__"; migraphx::module m; @@ -537,4 +537,116 @@ TEST_CASE(int4_unpack_conv) EXPECT(verify_mlir(m)); } +TEST_CASE(int4_unpack_dequantizelinear) +{ + std::string mlir_output = R"__migraphx__( +module { + func.func @mlir_unsqueeze_reshape_slice_unsqueeze_reshape_slice_unpack_int4_dequantizelinear_dot(%arg0: !migraphx.shaped<2x3x5xf32, 15x5x1>, %arg1: !migraphx.shaped<2x5x1xsi8, 5x1x1>, %arg2: !migraphx.shaped<2x2x2xf32, 4x2x1>, %arg3: !migraphx.shaped<2x2x2xsi8, 4x2x1>) -> !migraphx.shaped<2x3x2xf32, 6x2x1> attributes ${attrs} { + %0 = migraphx.reshape %arg2 {dims = [2, 2, 1, 2]} : <2x2x2xf32, 4x2x1> -> <2x2x1x2xf32, 4x2x2x1> + %1 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [2, 2, 3, 2]} : <2x2x1x2xf32, 4x2x2x1> -> <2x2x3x2xf32, 4x2x0x1> + %2 = migraphx.reshape %1 {dims = [2, 6, 2]} : <2x2x3x2xf32, 4x2x0x1> -> <2x6x2xf32, 12x2x1> + %3 = migraphx.slice %2 {axes = [1], ends = [5], starts = [0]} : <2x6x2xf32, 12x2x1> -> <2x5x2xf32, 12x2x1> + %4 = migraphx.reshape %arg3 {dims = [2, 2, 1, 2]} : <2x2x2xsi8, 4x2x1> -> <2x2x1x2xsi8, 4x2x2x1> + %5 = migraphx.multibroadcast %4 {out_dyn_dims = [], out_lens = [2, 2, 3, 2]} : <2x2x1x2xsi8, 4x2x2x1> -> <2x2x3x2xsi8, 4x2x0x1> + %6 = migraphx.reshape %5 {dims = [2, 6, 2]} : <2x2x3x2xsi8, 4x2x0x1> -> <2x6x2xsi8, 12x2x1> + %7 = migraphx.slice %6 {axes = [1], ends = [5], starts = [0]} : <2x6x2xsi8, 12x2x1> -> <2x5x2xsi8, 12x2x1> + %8 = migraphx.unpack %arg1 {axis = 2 : i64} : <2x5x1xsi8, 5x1x1> -> <2x5x2xsi8, 10x2x1> + %9 = migraphx.dequantizelinear %8, %3, %7 : <2x5x2xsi8, 10x2x1>, <2x5x2xf32, 12x2x1>, !migraphx.shaped<2x5x2xsi8, 12x2x1> -> <2x5x2xf32, 10x2x1> + %10 = migraphx.dot %arg0, %9 : <2x3x5xf32, 15x5x1>, <2x5x2xf32, 10x2x1> -> <2x3x2xf32, 6x2x1> + return %10 : !migraphx.shaped<2x3x2xf32, 6x2x1> + } +} +)__migraphx__"; + migraphx::module m; + auto x0 = m.add_parameter("x0", migraphx::shape{migraphx::shape::float_type, {2, 3, 5}}); + auto x1 = m.add_parameter("x1", migraphx::shape{migraphx::shape::int8_type, {2, 5, 1}}); + auto x2 = m.add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}}); + auto x3 = m.add_parameter("x3", migraphx::shape{migraphx::shape::int8_type, {2, 2, 2}}); + + auto unsqueeze1 = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), x2); + auto broadcast1 = m.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze1); + auto reshape1 = + m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 6, 2}}}), broadcast1); + auto scale = m.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), reshape1); + + auto unsqueeze2 = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), x3); + auto broadcast2 = m.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze2); + auto reshape2 = + m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 6, 2}}}), broadcast2); + auto zp = m.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), reshape2); + + auto unpack = m.add_instruction(migraphx::make_op("unpack_int4"), x1); + auto dq = m.add_instruction(migraphx::make_op("dequantizelinear"), unpack, scale, zp); + auto dot = m.add_instruction(migraphx::make_op("dot"), x0, dq); + m.add_return({dot}); + auto s = migraphx::gpu::dump_mlir(m); + // Skip test if MLIR is not enabled + if(s.empty()) + return; + auto mlir_output_with_attrs = + migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}}); + CHECK(encode(s) == encode(mlir_output_with_attrs)); + EXPECT(verify_mlir(m)); +} + +TEST_CASE(uint4_unpack_dequantizelinear) +{ + std::string mlir_output = R"__migraphx__( +module { + func.func @mlir_unsqueeze_reshape_slice_unsqueeze_reshape_slice_unpack_int4_dequantizelinear_dot(%arg0: !migraphx.shaped<2x3x5xf32, 15x5x1>, %arg1: !migraphx.shaped<2x5x1xui8, 5x1x1>, %arg2: !migraphx.shaped<2x2x2xf32, 4x2x1>, %arg3: !migraphx.shaped<2x2x2xui8, 4x2x1>) -> !migraphx.shaped<2x3x2xf32, 6x2x1> attributes ${attrs} { + %0 = migraphx.reshape %arg2 {dims = [2, 2, 1, 2]} : <2x2x2xf32, 4x2x1> -> <2x2x1x2xf32, 4x2x2x1> + %1 = migraphx.multibroadcast %0 {out_dyn_dims = [], out_lens = [2, 2, 3, 2]} : <2x2x1x2xf32, 4x2x2x1> -> <2x2x3x2xf32, 4x2x0x1> + %2 = migraphx.reshape %1 {dims = [2, 6, 2]} : <2x2x3x2xf32, 4x2x0x1> -> <2x6x2xf32, 12x2x1> + %3 = migraphx.slice %2 {axes = [1], ends = [5], starts = [0]} : <2x6x2xf32, 12x2x1> -> <2x5x2xf32, 12x2x1> + %4 = migraphx.reshape %arg3 {dims = [2, 2, 1, 2]} : <2x2x2xui8, 4x2x1> -> <2x2x1x2xui8, 4x2x2x1> + %5 = migraphx.multibroadcast %4 {out_dyn_dims = [], out_lens = [2, 2, 3, 2]} : <2x2x1x2xui8, 4x2x2x1> -> <2x2x3x2xui8, 4x2x0x1> + %6 = migraphx.reshape %5 {dims = [2, 6, 2]} : <2x2x3x2xui8, 4x2x0x1> -> <2x6x2xui8, 12x2x1> + %7 = migraphx.slice %6 {axes = [1], ends = [5], starts = [0]} : <2x6x2xui8, 12x2x1> -> <2x5x2xui8, 12x2x1> + %8 = migraphx.unpack %arg1 {axis = 2 : i64} : <2x5x1xui8, 5x1x1> -> <2x5x2xui8, 10x2x1> + %9 = migraphx.dequantizelinear %8, %3, %7 : <2x5x2xui8, 10x2x1>, <2x5x2xf32, 12x2x1>, !migraphx.shaped<2x5x2xui8, 12x2x1> -> <2x5x2xf32, 10x2x1> + %10 = migraphx.dot %arg0, %9 : <2x3x5xf32, 15x5x1>, <2x5x2xf32, 10x2x1> -> <2x3x2xf32, 6x2x1> + return %10 : !migraphx.shaped<2x3x2xf32, 6x2x1> + } +} +)__migraphx__"; + migraphx::module m; + auto x0 = m.add_parameter("x0", migraphx::shape{migraphx::shape::float_type, {2, 3, 5}}); + auto x1 = m.add_parameter("x1", migraphx::shape{migraphx::shape::uint8_type, {2, 5, 1}}); + auto x2 = m.add_parameter("x2", migraphx::shape{migraphx::shape::float_type, {2, 2, 2}}); + auto x3 = m.add_parameter("x3", migraphx::shape{migraphx::shape::uint8_type, {2, 2, 2}}); + + auto unsqueeze1 = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), x2); + auto broadcast1 = m.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze1); + auto reshape1 = + m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 6, 2}}}), broadcast1); + auto scale = m.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), reshape1); + + auto unsqueeze2 = m.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), x3); + auto broadcast2 = m.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze2); + auto reshape2 = + m.add_instruction(migraphx::make_op("reshape", {{"dims", {2, 6, 2}}}), broadcast2); + auto zp = m.add_instruction( + migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), reshape2); + + auto unpack = m.add_instruction(migraphx::make_op("unpack_int4"), x1); + auto dq = m.add_instruction(migraphx::make_op("dequantizelinear"), unpack, scale, zp); + auto dot = m.add_instruction(migraphx::make_op("dot"), x0, dq); + m.add_return({dot}); + auto s = migraphx::gpu::dump_mlir(m); + // Skip test if MLIR is not enabled + if(s.empty()) + return; + auto mlir_output_with_attrs = + migraphx::interpolate_string(mlir_output, {{"attrs", get_attrs()}}); + CHECK(encode(s) == encode(mlir_output_with_attrs)); + EXPECT(verify_mlir(m)); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); }