diff --git a/src/neural/xla/hlo_builder.cc b/src/neural/xla/hlo_builder.cc
index 873d2d50a8..24eaccb2d1 100644
--- a/src/neural/xla/hlo_builder.cc
+++ b/src/neural/xla/hlo_builder.cc
@@ -290,12 +290,18 @@ HloFlow HloBuilder::Gather(HloFlow input, HloFlow indices,
 }
 
 HloFlow HloBuilder::Dot(HloFlow lhs, HloFlow rhs,
-                        const pblczero::XlaDotDimensionNumbers& dn) {
+                        const pblczero::XlaDotDimensionNumbers& dn,
+                        const pblczero::XlaShapeProto::Type out_type) {
   HloTensorType lhs_shape(lhs->shape());
   HloTensorType rhs_shape(rhs->shape());
-  HloTensorType new_shape(lhs_shape.GetElementType());
-  if (lhs_shape.GetElementType() != rhs_shape.GetElementType()) {
-    throw Exception("Dot operands must have the same element type");
+  HloTensorType new_shape(
+      out_type == pblczero::XlaShapeProto::PRIMITIVE_TYPE_INVALID
+          ? lhs_shape.GetElementType()
+          : out_type);
+  if (lhs_shape.GetElementType() != rhs_shape.GetElementType() &&
+      out_type == pblczero::XlaShapeProto::PRIMITIVE_TYPE_INVALID) {
+    throw Exception(
+        "Dot operands must have the same element type or explicit output type");
   }
   if (dn.lhs_batch_dimensions_size() != dn.rhs_batch_dimensions_size()) {
     throw Exception("Dot batch dimensions must have the same size");
@@ -471,6 +477,18 @@ HloFlow HloBuilder::Slice(
   return flow;
 }
 
+HloFlow HloBuilder::RoundNearestEven(HloFlow input) {
+  return MakeInstruction("round-nearest-even", input->shape(), {input});
+}
+
+HloFlow HloBuilder::Clamp(HloFlow min, HloFlow input, HloFlow max) {
+  if (min->shape().dimensions() != input->shape().dimensions() ||
+      input->shape().dimensions() != max->shape().dimensions()) {
+    throw Exception("Clamp operands must have the same shape");
+  }
+  return MakeInstruction("clamp", input->shape(), {min, input, max});
+}
+
 namespace {
 // Go over all "parameter" instructions of the computation and assign
 // "parameter_number" field with increasing numbers.
diff --git a/src/neural/xla/hlo_builder.h b/src/neural/xla/hlo_builder.h
index 1211446765..30a9110adf 100644
--- a/src/neural/xla/hlo_builder.h
+++ b/src/neural/xla/hlo_builder.h
@@ -115,7 +115,9 @@ class HloBuilder {
   HloFlow Maximum(HloFlow lhs, HloFlow rhs);
   HloFlow Reshape(HloFlow input, const HloTensorType& new_shape);
   HloFlow Dot(HloFlow lhs, HloFlow rhs,
-              const pblczero::XlaDotDimensionNumbers& dimension_numbers);
+              const pblczero::XlaDotDimensionNumbers& dimension_numbers,
+              const pblczero::XlaShapeProto::Type out_type =
+                  pblczero::XlaShapeProto::PRIMITIVE_TYPE_INVALID);
   HloFlow Slice(
       HloFlow input,
       const std::vector<pblczero::HloInstructionProto::SliceDimensions>& slice);
@@ -140,6 +142,8 @@ class HloBuilder {
   // Direction is one of "EQ", "NE", "LT", "LE", "GT", "GE".
   HloFlow Compare(HloFlow lhs, HloFlow rhs, std::string_view direction);
   HloFlow Select(HloFlow condition, HloFlow on_true, HloFlow on_false);
+  HloFlow RoundNearestEven(HloFlow input);
+  HloFlow Clamp(HloFlow min, HloFlow input, HloFlow max);
   // Insert a computation into the module, under given name. Dependent
   // computations are also merged into the module.
   HloComputation AddComputation(std::string_view name,
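The two new element-wise primitives map directly onto the HLO "round-nearest-even" and "clamp" instructions, and together with the optional Dot output type they are the building blocks that the quantized ONNX ops below lower to. Round-nearest-even is banker's rounding: ties go to the even neighbor, the same rule ONNX Round specifies. A minimal standalone sketch of that semantics (plain C++, not part of the patch; std::nearbyintf under the default FE_TONEAREST rounding mode behaves identically):

  #include <cfenv>
  #include <cmath>
  #include <cstdio>
  #include <initializer_list>

  int main() {
    std::fesetround(FE_TONEAREST);  // round half to even; the default mode
    for (float x : {0.5f, 1.5f, 2.5f, 3.5f, -2.5f}) {
      // Prints 0, 2, 2, 4, -2: ties always land on the even integer.
      std::printf("%5.1f -> %5.1f\n", x, std::nearbyintf(x));
    }
  }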
diff --git a/src/neural/xla/onnx2hlo.cc b/src/neural/xla/onnx2hlo.cc
index a33923de25..a74275ee80 100644
--- a/src/neural/xla/onnx2hlo.cc
+++ b/src/neural/xla/onnx2hlo.cc
@@ -29,6 +29,7 @@
 #include <algorithm>
 #include <cmath>
 #include <cstddef>
+#include <set>
 #include <string>
 #include <vector>
@@ -40,6 +41,7 @@
 #include "utils/exception.h"
 #include "utils/fp16_utils.h"
 #include "utils/fp8_utils.h"
+#include "utils/string.h"
 
 namespace lczero {
 namespace {
@@ -97,7 +99,7 @@ void FetchMutableForType(pblczero::XlaLiteralProto* literal,
       break;
     default:
       throw Exception(
-          "Unsupported type for constant input " +
+          "Unsupported type for mutable input " +
           pblczero::XlaShapeProto::Type_Name(literal->shape().element_type()));
   }
 }
@@ -145,6 +147,13 @@ void LiteralOutInOpDifferentTypes(pblczero::XlaLiteralProto* dst,
       });
 }
 
+size_t GetLiteralByteSize(const pblczero::XlaLiteralProto& literal) {
+  size_t res;
+  FetchConstForType(literal, literal.shape().element_type(),
+                    [&](const auto& x) { res = x.size() * sizeof(x[0]); });
+  return res;
+}
+
 pblczero::XlaLiteralProto ConstOpConvert(
     const pblczero::XlaLiteralProto& input,
     const pblczero::XlaShapeProto::Type& to_type) {
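GetLiteralByteSize() backs the new options_.max_inline_constant_size checks further down: constant folding now only inlines a literal when the folded result is small enough. The same element-count-times-element-width computation, sketched over a hypothetical variant-based literal (the real code dispatches through FetchConstForType on XlaLiteralProto instead):

  #include <cstdint>
  #include <cstdio>
  #include <variant>
  #include <vector>

  // Stand-in for XlaLiteralProto's per-type repeated fields.
  using Literal = std::variant<std::vector<int8_t>, std::vector<int32_t>,
                               std::vector<float>>;

  size_t GetLiteralByteSize(const Literal& literal) {
    // Element count times element width, whatever the element type is.
    return std::visit(
        [](const auto& xs) { return xs.size() * sizeof(xs[0]); }, literal);
  }

  int main() {
    std::printf("%zu\n", GetLiteralByteSize(std::vector<int32_t>{1, 2, 3}));
  }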
onnx_op_to_builder_["Sigmoid"] = &Onnx2HloConverter::OpSigmoid; onnx_op_to_builder_["Shape"] = &Onnx2HloConverter::OpShape; @@ -500,6 +523,7 @@ class Onnx2HloConverter { onnx_op_to_builder_["Squeeze"] = &Onnx2HloConverter::OpSqueeze; onnx_op_to_builder_["Sub"] = &Onnx2HloConverter::OpSub; onnx_op_to_builder_["Tanh"] = &Onnx2HloConverter::OpTanh; + onnx_op_to_builder_["Tile"] = &Onnx2HloConverter::OpTile; onnx_op_to_builder_["Transpose"] = &Onnx2HloConverter::OpTranspose; onnx_op_to_builder_["Unsqueeze"] = &Onnx2HloConverter::OpUnsqueeze; } @@ -657,7 +681,12 @@ class Onnx2HloConverter { bool AllInputsConstant(const pblczero::NodeProto& node) { for (const auto& input : node.input()) { const std::string name(input); - if (initializers_.count(name)) continue; + if (auto tensor = initializers_.find(name); + tensor != initializers_.end() && + tensor->second->raw_data().size() <= + options_.max_inline_constant_size) { + continue; + } if (auto iter = onnx_name_to_hlo_flow_.find(name); iter != onnx_name_to_hlo_flow_.end() && iter->second->opcode() == "constant") { @@ -1045,11 +1074,12 @@ class Onnx2HloConverter { GetAttributeAs(node, "to")); const auto hlo_type = OnnxTypeToXlaType(onnx_type); if (input->shape().element_type() == hlo_type) return {input}; - // Only convert constants of int64 to int32 as that's what TF does. if (AllInputsConstant(node) && CanConvertConstant(hlo_type) && CanConvertConstant(input->shape().element_type())) { - return {builder_.Constant( - ConstOpConvert(*GetConstantInput(node, 0), hlo_type))}; + auto literal = ConstOpConvert(*GetConstantInput(node, 0), hlo_type); + if (GetLiteralByteSize(literal) <= options_.max_inline_constant_size) { + return {builder_.Constant(literal)}; + } } return {builder_.Convert(input, hlo_type)}; } @@ -1124,7 +1154,10 @@ class Onnx2HloConverter { for (size_t i = 0; i < node.input_size(); ++i) { constants.push_back(*GetConstantInput(node, i)); } - return {builder_.Constant(ConstOpConcat(constants, axis))}; + auto literal = ConstOpConcat(constants, axis); + if (GetLiteralByteSize(literal) <= options_.max_inline_constant_size) { + return {builder_.Constant(literal)}; + } } std::vector inputs; for (size_t i = 0; i < node.input_size(); ++i) { @@ -1381,6 +1414,11 @@ class Onnx2HloConverter { } std::vector OpMatMul(const pblczero::NodeProto& node) { + return MakeMatMul(node, pblczero::XlaShapeProto::PRIMITIVE_TYPE_INVALID); + } + + std::vector MakeMatMul(const pblczero::NodeProto& node, + const pblczero::XlaShapeProto::Type type) { CheckKnownAttributes(node, 2, {}); auto* lhs = GetInput(node, 0); auto* rhs = GetInput(node, 1); @@ -1397,7 +1435,7 @@ class Onnx2HloConverter { dn.add_lhs_batch_dimensions(i); dn.add_rhs_batch_dimensions(i); } - return {builder_.Dot(lhs, rhs, dn)}; + return {builder_.Dot(lhs, rhs, dn, type)}; } std::vector OpGlobalAveragePool(const pblczero::NodeProto& node) { @@ -1492,6 +1530,179 @@ class Onnx2HloConverter { return {builder_.Multiply(flow, input)}; } + std::vector OpEinsum(const pblczero::NodeProto& node) { + CheckKnownAttributes(node, std::numeric_limits::max(), + {"equation"}); + std::string equation(GetAttribute(node, "equation")->s()); + std::vector inputs; + for (size_t i = 0; i < node.input_size(); ++i) { + inputs.push_back(GetInput(node, i)); + } + if (inputs.size() != 2) { + throw Exception("Only 2 inputs supported"); + } + auto pos_comma = equation.find(','); + auto pos_arrow = equation.find("->"); + std::string eq_a = Trim(equation.substr(0, pos_comma)); + std::string eq_b = + Trim(equation.substr(pos_comma + 
@@ -1492,6 +1530,179 @@ class Onnx2HloConverter {
     return {builder_.Multiply(flow, input)};
   }
 
+  std::vector<HloFlow> OpEinsum(const pblczero::NodeProto& node) {
+    CheckKnownAttributes(node, std::numeric_limits<size_t>::max(),
+                         {"equation"});
+    std::string equation(GetAttribute(node, "equation")->s());
+    std::vector<HloFlow> inputs;
+    for (size_t i = 0; i < node.input_size(); ++i) {
+      inputs.push_back(GetInput(node, i));
+    }
+    if (inputs.size() != 2) {
+      throw Exception("Only 2 inputs supported");
+    }
+    auto pos_comma = equation.find(',');
+    auto pos_arrow = equation.find("->");
+    std::string eq_a = Trim(equation.substr(0, pos_comma));
+    std::string eq_b =
+        Trim(equation.substr(pos_comma + 1, pos_arrow - pos_comma - 1));
+    std::string eq_out = Trim(equation.substr(pos_arrow + 2));
+    auto has_dups = [](const auto& s) {
+      return std::set(s.begin(), s.end()).size() < s.size();
+    };
+    if (eq_a.empty() || eq_b.empty() || eq_out.empty() ||
+        equation.find('.') != std::string::npos || has_dups(eq_a) ||
+        has_dups(eq_b) || has_dups(eq_out)) {
+      throw Exception("Unsupported equation: " + equation);
+    }
+
+    pblczero::XlaDotDimensionNumbers dn;
+    std::string batch;
+    std::string non_contracting;
+    for (size_t i = 0; i < eq_a.size(); i++) {
+      auto j = eq_b.find(eq_a[i]);
+      auto k = eq_out.find(eq_a[i]);
+      if (j != std::string::npos && k != std::string::npos) {
+        // Batch dimension.
+        dn.add_lhs_batch_dimensions(i);
+        dn.add_rhs_batch_dimensions(j);
+        batch += eq_a[i];
+      } else if (j != std::string::npos && k == std::string::npos) {
+        // Contracting dimension.
+        dn.add_lhs_contracting_dimensions(i);
+        dn.add_rhs_contracting_dimensions(j);
+      } else if (j == std::string::npos && k != std::string::npos) {
+        // Non-contracting dimension.
+        non_contracting += eq_a[i];
+      } else {
+        throw Exception("LHS dimension " + eq_a.substr(i, 1) + " not used");
+      }
+    }
+    for (size_t j = 0; j < eq_b.size(); j++) {
+      auto i = eq_a.find(eq_b[j]);
+      auto k = eq_out.find(eq_b[j]);
+      if (i == std::string::npos && k != std::string::npos) {
+        // Non-contracting dimension.
+        non_contracting += eq_b[j];
+      } else if (i == std::string::npos && k == std::string::npos) {
+        throw Exception("RHS dimension " + eq_b.substr(j, 1) + " not used");
+      }
+    }
+    std::string out = batch + non_contracting;
+    std::vector<int64_t> perm;
+    for (size_t k = 0; k < eq_out.size(); k++) {
+      auto l = out.find(eq_out[k]);
+      if (l != std::string::npos) {
+        perm.push_back(l);
+      } else {
+        throw Exception("Unknown output dimension " + eq_out.substr(k, 1));
+      }
+    }
+    auto flow = builder_.Dot(inputs[0], inputs[1], dn);
+    return {builder_.Transpose(flow, perm)};
+  }
+
+  std::vector<HloFlow> OpTile(const pblczero::NodeProto& node) {
+    CheckKnownAttributes(node, 2, {});
+    auto* input = GetInput(node, 0);
+    const auto repeats = *GetConstantInputAsVec(node, 1);
+    std::vector<int64_t> shape;
+    HloTensorType input_shape(input->shape());
+    if (input_shape.Rank() != repeats.size()) {
+      throw Exception("Incompatible shapes");
+    }
+    for (size_t i = 0; i < repeats.size(); ++i) {
+      shape.push_back(input_shape.GetDimension(i) * repeats[i]);
+    }
+    return {DoBroadcast(input, shape)};
+  }
+
+  std::vector<HloFlow> OpRound(const pblczero::NodeProto& node) {
+    CheckKnownAttributes(node, 1, {});
+    auto* input = GetInput(node, 0);
+    return {builder_.RoundNearestEven(input)};
+  }
+
+  std::vector<HloFlow> OpClip(const pblczero::NodeProto& node) {
+    CheckKnownAttributes(node, 3, {});
+    auto* input = GetInput(node, 0);
+    auto* min = GetInput(node, 1);
+    auto* max = GetInput(node, 2);
+    HloTensorType shape(input->shape());
+    min = builder_.Broadcast(min, shape, {});
+    max = builder_.Broadcast(max, shape, {});
+    return {builder_.Clamp(min, input, max)};
+  }
+
+  std::vector<HloFlow> OpMatMulInteger(const pblczero::NodeProto& node) {
+    return MakeMatMul(node, pblczero::XlaShapeProto::S32);
+  }
+
+  std::vector<HloFlow> OpQuantizeLinear(const pblczero::NodeProto& node) {
+    CheckKnownAttributes(node, 3, {"saturate"});
+    auto* input = GetInput(node, 0);
+    auto* scale = GetInput(node, 1);
+    auto* zero_point = GetInput(node, 2, true);
+    bool saturate =
+        GetOptionalAttributeAs<bool>(node, "saturate").value_or(true);
+    HloTensorType shape(input->shape());
+    scale = builder_.Broadcast(scale, shape, {});
+    auto flow = builder_.Divide(input, scale);
+    const auto in_type = input->shape().element_type();
+    auto out_type = pblczero::XlaShapeProto::U8;
+    if (zero_point) {
+      out_type = zero_point->shape().element_type();
+      zero_point = builder_.Convert(zero_point, in_type);
+      zero_point = builder_.Broadcast(zero_point, shape, {});
+      flow = builder_.Add(flow, zero_point);
+    }
+    switch (out_type) {
+      case pblczero::XlaShapeProto::S8: {
+        auto* min = MakeScalar(-128, in_type);
+        auto* max = MakeScalar(127, in_type);
+        min = builder_.Broadcast(min, shape, {});
+        max = builder_.Broadcast(max, shape, {});
+        flow = builder_.Clamp(min, flow, max);
+      } break;
+      case pblczero::XlaShapeProto::U8: {
+        auto* min = MakeScalar(0, in_type);
+        auto* max = MakeScalar(255, in_type);
+        min = builder_.Broadcast(min, shape, {});
+        max = builder_.Broadcast(max, shape, {});
+        flow = builder_.Clamp(min, flow, max);
+      } break;
+      case pblczero::XlaShapeProto::F8E4M3FN: {
+        if (!saturate) break;
+        auto* min = MakeScalar(-448, in_type);
+        auto* max = MakeScalar(448, in_type);
+        min = builder_.Broadcast(min, shape, {});
+        max = builder_.Broadcast(max, shape, {});
+        flow = builder_.Clamp(min, flow, max);
+      } break;
+      default:
+        throw Exception("Unsupported quantization type: " +
+                        pblczero::XlaShapeProto::Type_Name(out_type));
+    }
+    return {builder_.Convert(flow, out_type)};
+  }
+
+  std::vector<HloFlow> OpDequantizeLinear(const pblczero::NodeProto& node) {
+    CheckKnownAttributes(node, 3, {});
+    auto* input = GetInput(node, 0);
+    auto* scale = GetInput(node, 1);
+    auto* zero_point = GetInput(node, 2, true);
+    const auto type = scale->shape().element_type();
+    auto flow = builder_.Convert(input, type);
+    HloTensorType shape(flow->shape());
+    if (zero_point) {
+      zero_point = builder_.Convert(zero_point, type);
+      zero_point = builder_.Broadcast(zero_point, shape, {});
+      flow = builder_.Subtract(flow, zero_point);
+    }
+    scale = builder_.Broadcast(scale, shape, {});
+    return {builder_.Multiply(flow, scale)};
+  }
+
   /////////////////////////////////////////////////////////////////////////////
   // Helper computations
   /////////////////////////////////////////////////////////////////////////////
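To illustrate OpEinsum's bucketing above: each LHS letter of "lhs,rhs->out" is classified as a batch dimension (it appears in rhs and out), a contracting dimension (in rhs only), or an LHS non-contracting dimension (in out only); anything else is an error. A standalone sketch of just that classification on the batched-matmul equation "bij,bjk->bik", with validation and the XLA dimension numbers omitted:

  #include <cstdio>
  #include <string>

  void ClassifyLhs(const std::string& lhs, const std::string& rhs,
                   const std::string& out) {
    for (char c : lhs) {
      bool in_rhs = rhs.find(c) != std::string::npos;
      bool in_out = out.find(c) != std::string::npos;
      const char* kind = in_rhs ? (in_out ? "batch" : "contracting")
                                : (in_out ? "lhs non-contracting" : "unused");
      std::printf("%c: %s\n", c, kind);
    }
  }

  int main() {
    // "bij,bjk->bik": b is batch, i is the LHS free dimension, j contracts.
    ClassifyLhs("bij", "bjk", "bik");
  }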
@@ -1557,6 +1768,12 @@ class Onnx2HloConverter {
                                    sizeof(f8e5m2));
         literal.set_f8e5m2s(f8e5m2_view);
       } break;
+      case pblczero::XlaShapeProto::F8E4M3FN: {
+        uint8_t f8e4m3fn = FP32toFP8E4M3FN(value);
+        std::string_view f8e4m3fn_view(
+            reinterpret_cast<const char*>(&f8e4m3fn), sizeof(f8e4m3fn));
+        literal.set_f8e4m3fns(f8e4m3fn_view);
+      } break;
       case pblczero::XlaShapeProto::S32:
         literal.add_s32s(value);
         break;
diff --git a/src/utils/protomessage.cc b/src/utils/protomessage.cc
index 5433b2ed7a..a44a6c878a 100644
--- a/src/utils/protomessage.cc
+++ b/src/utils/protomessage.cc
@@ -39,7 +39,6 @@ uint64_t ReadFixed(const std::uint8_t** iter, size_t size,
 }
 
 void WriteFixed(uint64_t value, size_t size, std::string* out) {
-  out->reserve(out->size() + size);
   for (size_t i = 0; i < size; ++i) {
     out->push_back(static_cast<char>(static_cast<uint8_t>(value)));
     value /= 256;
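For reference, the scalar math that OpQuantizeLinear and OpDequantizeLinear above lower to, shown for the default uint8 case: divide by the scale, round, add the zero point, and saturate to the output range. This is a sketch of the ONNX semantics rather than lczero code; the rounding that the HLO pipeline gets from the float-to-integer convert is made explicit here with std::nearbyintf:

  #include <algorithm>
  #include <cmath>
  #include <cstdint>
  #include <cstdio>

  uint8_t QuantizeLinear(float x, float scale, uint8_t zero_point) {
    float q = std::nearbyintf(x / scale) + zero_point;  // round half to even
    return static_cast<uint8_t>(std::clamp(q, 0.0f, 255.0f));  // saturate
  }

  float DequantizeLinear(uint8_t q, float scale, uint8_t zero_point) {
    return (static_cast<int>(q) - zero_point) * scale;
  }

  int main() {
    const float scale = 0.1f;
    const uint8_t zero_point = 128;
    uint8_t q = QuantizeLinear(0.37f, scale, zero_point);  // 132
    // Round-tripping recovers the input to within scale / 2: prints 0.40.
    std::printf("q=%u x=%.2f\n", q, DequantizeLinear(q, scale, zero_point));
  }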