Add support for unsigned types with mlir #3582

Open · wants to merge 9 commits into base: develop
2 changes: 2 additions & 0 deletions src/targets/gpu/fuse_mlir.cpp
@@ -366,7 +366,9 @@ bool is_pointwise_op_supported_by_mlir(const instruction& i)
         type_t::fp8e4m3fn_type,
         type_t::fp8e5m2_type,
         type_t::int8_type,
+        type_t::uint8_type,
         type_t::int32_type,
+        type_t::uint32_type,
         type_t::bool_type};
     // Preliminary type check.
     if(not contains(allowed_types, result_type))
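For context, here is a self-contained sketch of the allow-list check this hunk extends; it is not the MIGraphX source, and the type and helper names are illustrative. A pointwise instruction is only considered for MLIR offload when its result type is in the supported set, which with this change also includes the unsigned integer types.

#include <algorithm>
#include <initializer_list>

// Illustrative stand-in for migraphx::shape::type_t.
enum class type_t { int8_type, uint8_type, int32_type, uint32_type, bool_type };

// Simple range membership test, similar in spirit to migraphx::contains.
template <class Range, class T>
bool contains(const Range& r, const T& value)
{
    return std::find(r.begin(), r.end(), value) != r.end();
}

bool is_pointwise_result_type_supported(type_t result_type)
{
    // Preliminary type check: unsigned 8-bit and 32-bit types are now allowed.
    static const auto allowed_types = {type_t::int8_type,
                                       type_t::uint8_type,
                                       type_t::int32_type,
                                       type_t::uint32_type,
                                       type_t::bool_type};
    return contains(allowed_types, result_type);
}

The real is_pointwise_op_supported_by_mlir performs further checks beyond the result type; only the type allow-list is shown here.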
21 changes: 5 additions & 16 deletions src/targets/gpu/mlir.cpp
@@ -322,14 +322,14 @@ struct mlir_program
             result = mlirF64TypeGet(ctx.get());
         else if(as.is_integral())
         {
-            // Note: rocMLIR use signless integer type for tensors types. This
-            // will translate to signed implementation for current supported
-            // operations.
             if(as.is_unsigned())
             {
-                MIGRAPHX_THROW("Unsupported type: " + std::to_string(as.type_enum()));
+                result = mlirIntegerTypeUnsignedGet(ctx.get(), as.size() * 8);
             }
-            result = mlirIntegerTypeGet(ctx.get(), as.size() * 8); // number of bits
+            else
+            {
+                result = mlirIntegerTypeSignedGet(ctx.get(), as.size() * 8);
+            }
         }
         else
             MIGRAPHX_THROW("Unsupported type: " + std::to_string(as.type_enum()));
@@ -718,24 +718,13 @@ struct mlir_program
         literal r = ins->get_literal();
         auto sh   = ins->get_shape();

-        // mlir works only with signed types. change uint4 to (int4 + unsigned-flag)
-        if(shape::is_unsigned(sh.type()) and ins->outputs()[0]->name() == "unpack_int4")
-            sh = ins->get_shape().with_type(shape::int8_type);
-
         MlirType shaped_type = make_mlir_shaped(sh);
         MlirType tensor_type = rocmlirMIXRShapedTypeAsTensor(shaped_type);
         MlirAttribute mlir_value_attr =
             mlirDenseElementsAttrRawBufferGet(tensor_type, r.get_shape().bytes(), r.data());
         ops.add_attributes({{"value", mlir_value_attr}});
     }

-    if(ins->name() == "unpack_int4")
-    {
-        auto sh = get_shape(ins);
-        ops.add_attributes(
-            {{"isUnsigned", shape::is_unsigned(sh.type())}}); // flag for uint4
-    }
-
     if(ins->name() == "convolution" or ins->name() == "dot")
     {
         pp =
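A minimal sketch of the new type mapping, outside of MIGraphX, using only the upstream MLIR C API (it assumes an MLIR build that installs the mlir-c headers; the helper name and the main() driver are illustrative). Before this change unsigned inputs threw and signed integers lowered to a signless iN type; with this change the signedness is preserved in the MLIR integer type.

#include <mlir-c/BuiltinTypes.h>
#include <mlir-c/IR.h>

// Map "size in bytes + signedness" (what migraphx::shape exposes via
// as.size() and as.is_unsigned()) to an MLIR integer type.
static MlirType make_integer_type(MlirContext ctx, unsigned bytes, bool is_unsigned)
{
    // mlirIntegerTypeGet would produce a signless iN type; the signed and
    // unsigned getters keep the distinction (si8/ui8, si32/ui32, ...).
    return is_unsigned ? mlirIntegerTypeUnsignedGet(ctx, bytes * 8)
                       : mlirIntegerTypeSignedGet(ctx, bytes * 8);
}

int main()
{
    MlirContext ctx = mlirContextCreate();
    mlirTypeDump(make_integer_type(ctx, 1, true));  // ui8
    mlirTypeDump(make_integer_type(ctx, 4, false)); // si32
    mlirContextDestroy(ctx);
    return 0;
}

Because the integer type now carries its signedness, the uint4-to-int8 workaround and the isUnsigned attribute for unpack_int4 removed in the second hunk are no longer needed.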
83 changes: 83 additions & 0 deletions test/gpu/fuse_mlir.cpp
@@ -528,6 +528,89 @@ TEST_CASE(dequantizelinear_dot)
EXPECT(p1.sort() == p2.sort());
}

TEST_CASE(unsigned_dequantizelinear_dot)
{
migraphx::program p1;
{
auto* mm = p1.get_main_module();

auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 3, 5}});
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::uint8_type, {2, 5, 2}});
auto scalelit =
mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 2, 2}}));
auto zplit =
mm->add_literal(migraphx::generate_literal({migraphx::shape::uint8_type, {2, 2, 2}}));

auto unsqueeze1 =
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), scalelit);
auto broadcast1 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze1);
auto reshape1 =
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 6, 2}}}), broadcast1);
auto scale = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), reshape1);

auto unsqueeze2 =
mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), zplit);
auto broadcast2 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze2);
auto reshape2 =
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 6, 2}}}), broadcast2);
auto zp = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}), reshape2);

auto dq = add_pointwise(
p1, "main:pointwise0", {y, scale, zp}, single_pointwise("dequantizelinear"));
auto dot = mm->add_instruction(migraphx::make_op("dot"), x, dq);
mm->add_return({dot});
}
run_pass(p1);
migraphx::program p2;
{
auto* mm = p2.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 3, 5}});
auto y = mm->add_parameter("y", migraphx::shape{migraphx::shape::uint8_type, {2, 5, 2}});
auto scalelit =
mm->add_literal(migraphx::generate_literal({migraphx::shape::float_type, {2, 2, 2}}));
auto zplit =
mm->add_literal(migraphx::generate_literal({migraphx::shape::uint8_type, {2, 2, 2}}));

auto fused = add_mlir(
p2,
"main:pointwise0:mlir_dot0",
{y, scalelit, zplit, x},
{"x0", "x1", "x2", "x3"},
[=](auto* pm, const auto& inputs) {
auto unsqueeze1 =
pm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), inputs[1]);
auto broadcast1 = pm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze1);
auto reshape1 = pm->add_instruction(
migraphx::make_op("reshape", {{"dims", {2, 6, 2}}}), broadcast1);
auto scale = pm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}),
reshape1);

auto unsqueeze2 =
pm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", {2}}}), inputs[2]);
auto broadcast2 = pm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", {2, 2, 3, 2}}}), unsqueeze2);
auto reshape2 = pm->add_instruction(
migraphx::make_op("reshape", {{"dims", {2, 6, 2}}}), broadcast2);
auto zp = pm->add_instruction(
migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {5}}}),
reshape2);

auto dq = pm->add_instruction(
migraphx::make_op("dequantizelinear"), inputs[0], scale, zp);
auto dot = pm->add_instruction(migraphx::make_op("dot"), inputs[3], dq);
return std::make_tuple(dot->get_operator(), dot);
});
mm->add_return({fused});
}
EXPECT(p1.sort() == p2.sort());
}
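
As a side note on what the fused dequantizelinear computes per element (y = (x - zero_point) * scale, per the ONNX definition), here is a small stand-alone sketch, not part of the test, showing why uint8 data and a uint8 zero point need genuinely unsigned handling:

#include <cstdint>

// Scalar dequantize: widen to int before subtracting so the difference can be
// negative even though both inputs are unsigned 8-bit values.
inline float dequantize(std::uint8_t x, float scale, std::uint8_t zero_point)
{
    return static_cast<float>(static_cast<int>(x) - static_cast<int>(zero_point)) * scale;
}

// Example: dequantize(10, 1.0f, 200) == -190.0f. If the uint8 inputs were
// instead reinterpreted as signed int8 (10 and -56), the widened subtraction
// would give 66.0f, so the generated kernel has to know the values are unsigned.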

TEST_CASE(unpack_int4_dot)
{
migraphx::program p1;