diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index d13136d252d2a..37eb0e0edc67c 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -163,6 +163,69 @@ export class WebNNBackend { return id; } + // Register WebNN Constant operands from external data. + public registerMLConstant( + externalFilePath: string, + dataOffset: number, + dataLength: number, + builder: MLGraphBuilder, + desc: MLOperandDescriptor, + mountedFiles: Map | undefined, + ): MLOperand { + // If available, "Module.MountedFiles" is a Map for all preloaded files. + if (!mountedFiles) { + throw new Error('External mounted files are not available.'); + } + + let filePath = externalFilePath; + if (externalFilePath.startsWith('./')) { + filePath = externalFilePath.substring(2); + } + const fileData = mountedFiles.get(filePath); + if (!fileData) { + throw new Error(`File with name ${filePath} not found in preloaded files.`); + } + + if (dataOffset + dataLength > fileData.byteLength) { + throw new Error('Out of bounds: data offset and length exceed the external file data size.'); + } + + const buffer = fileData.slice(dataOffset, dataOffset + dataLength).buffer; + let bufferView: ArrayBufferView; + switch (desc.dataType) { + case 'float32': + bufferView = new Float32Array(buffer); + break; + case 'float16': + bufferView = new Uint16Array(buffer); + break; + case 'int32': + bufferView = new Int32Array(buffer); + break; + case 'uint32': + bufferView = new Uint32Array(buffer); + break; + case 'int64': + bufferView = new BigInt64Array(buffer); + break; + case 'uint64': + bufferView = new BigUint64Array(buffer); + break; + case 'int8': + bufferView = new Int8Array(buffer); + break; + case 'uint8': + bufferView = new Uint8Array(buffer); + break; + default: + throw new Error(`Unsupported data type: ${desc.dataType} in creating WebNN Constant from external data.`); + } + + LOG_DEBUG('verbose', () => `[WebNN] registerMLConstant {dataType: ${desc.dataType}, shape: ${desc.shape}}}`); + + return builder.constant(desc, bufferView); + } + public flush(): void { // Unlike the WebGPU backend, the WebNN backend does not need to flush any pending operations. } diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 74c359881a1d7..2af9f95ad059e 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -165,37 +165,6 @@ Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Int4x2) DEFINE_INT4_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(UInt4x2) -static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, - const std::filesystem::path& tensor_proto_dir, - std::basic_string& external_file_path, - onnxruntime::FileOffsetType& file_offset, - SafeInt& tensor_byte_size) { - ORT_RETURN_IF_NOT(onnxruntime::utils::HasExternalData(tensor_proto), - "Tensor does not have external data to read from."); - - ORT_RETURN_IF(!onnxruntime::utils::HasDataType(tensor_proto) || onnxruntime::utils::HasString(tensor_proto), - "External data type cannot be UNDEFINED or STRING."); - - std::unique_ptr external_data_info; - ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info)); - - const auto& location = external_data_info->GetRelPath(); - - external_file_path = location == onnxruntime::utils::kTensorProtoMemoryAddressTag ? std::filesystem::path(location) - : (tensor_proto_dir / location); - - ORT_RETURN_IF_ERROR(onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &tensor_byte_size)); - const size_t external_data_length = external_data_info->GetLength(); - ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size, - "TensorProto: ", tensor_proto.name(), - " external data size mismatch. Computed size: ", *&tensor_byte_size, - ", external_data.length: ", external_data_length); - - file_offset = external_data_info->GetOffset(); - - return Status::OK(); -} - // Read external data for tensor in unint8_t* form and return Status::OK() if the data is read successfully. // Uses the tensor_proto_dir to construct the full path for external data. If tensor_proto_dir == nullptr // then uses the current directory instead. @@ -261,6 +230,37 @@ Status TensorProtoToOrtValueImpl(const Env& env, const std::filesystem::path& mo namespace utils { +Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& tensor_proto_dir, + std::basic_string& external_file_path, + onnxruntime::FileOffsetType& file_offset, + SafeInt& tensor_byte_size) { + ORT_RETURN_IF_NOT(onnxruntime::utils::HasExternalData(tensor_proto), + "Tensor does not have external data to read from."); + + ORT_RETURN_IF(!onnxruntime::utils::HasDataType(tensor_proto) || onnxruntime::utils::HasString(tensor_proto), + "External data type cannot be UNDEFINED or STRING."); + + std::unique_ptr external_data_info; + ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info)); + + const auto& location = external_data_info->GetRelPath(); + + external_file_path = location == onnxruntime::utils::kTensorProtoMemoryAddressTag ? std::filesystem::path(location) + : (tensor_proto_dir / location); + + ORT_RETURN_IF_ERROR(onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &tensor_byte_size)); + const size_t external_data_length = external_data_info->GetLength(); + ORT_RETURN_IF_NOT(external_data_length == 0 || external_data_length == tensor_byte_size, + "TensorProto: ", tensor_proto.name(), + " external data size mismatch. Computed size: ", *&tensor_byte_size, + ", external_data.length: ", external_data_length); + + file_offset = external_data_info->GetOffset(); + + return Status::OK(); +} + void SetRawDataInTensorProto(ONNX_NAMESPACE::TensorProto& tensor_proto, std::string&& param) { tensor_proto.set_raw_data(std::move(param)); } diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 227ba0706197e..262f7adaca1cb 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -23,6 +23,20 @@ namespace onnxruntime { namespace utils { +/** + * This function is used to get the external data info from the given tensor proto. + * @param tensor_proto given initializer tensor + * @param tensor_proto_dir directory of the tensor proto file + * @param external_file_path output external file path + * @param file_offset output tensor offset + * @param tensor_byte_size output tensor byte size + * @returns Status::OK() if the function is executed successfully + */ +Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const std::filesystem::path& tensor_proto_dir, + std::basic_string& external_file_path, + onnxruntime::FileOffsetType& file_offset, + SafeInt& tensor_byte_size); /** * This function is used to convert the endianess of Tensor data. * Mostly, will be used in big endian system to support the model file diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index 8da255a288f17..fffe964e6aaf2 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -12,27 +12,6 @@ namespace onnxruntime { namespace webnn { - -// Shared functions. -bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node& node, - const logging::Logger& logger) { - for (const auto* node_arg : node.InputDefs()) { - const auto& input_name(node_arg->Name()); - if (!Contains(initializers, input_name)) - continue; - - const auto& tensor = *initializers.at(input_name); - if (tensor.has_data_location() && - tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { - LOGS(logger, VERBOSE) << "Initializer [" << input_name - << "] with external data location are not currently supported"; - return true; - } - } - - return false; -} - // Add operator related. Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node, @@ -58,10 +37,6 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons if (!HasSupportedOutputsImpl(node, wnn_limits, logger)) return false; - // We do not support external initializers for now. - if (HasExternalInitializer(initializers, node, logger)) - return false; - if (!HasSupportedOpSet(node, logger)) return false; diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 044baa738e8c4..8a7fea0cde431 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -112,56 +112,73 @@ Status ModelBuilder::RegisterInitializers() { auto num_elements = SafeInt(Product(shape)); emscripten::val view = emscripten::val::undefined(); std::byte* tensor_ptr = nullptr; - if (tensor.has_raw_data()) { - tensor_ptr = reinterpret_cast(const_cast(tensor.raw_data().c_str())); + + if (utils::HasExternalData(tensor)) { + // Create WebNN Constant from external data. + std::basic_string external_file_path; + onnxruntime::FileOffsetType data_offset; + SafeInt tensor_byte_size; + ORT_RETURN_IF_ERROR(utils::GetExternalDataInfo( + tensor, graph_viewer_.ModelPath(), external_file_path, data_offset, tensor_byte_size)); + + auto jsepRegisterMLConstant = emscripten::val::module_property("jsepRegisterMLConstant"); + operand = jsepRegisterMLConstant(emscripten::val(external_file_path), + static_cast(data_offset), + static_cast(tensor_byte_size), + wnn_builder_, + desc); } else { - // Store temporary unpacked_tensor. - unpacked_tensors_.push_back({}); - std::vector& unpacked_tensor = unpacked_tensors_.back(); - ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor)); - tensor_ptr = reinterpret_cast(unpacked_tensor.data()); - } - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_BOOL: - case ONNX_NAMESPACE::TensorProto_DataType_UINT8: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT8: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_UINT32: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - case ONNX_NAMESPACE::TensorProto_DataType_UINT64: - view = emscripten::val{emscripten::typed_memory_view(num_elements, - reinterpret_cast(tensor_ptr))}; - break; - default: - break; + if (tensor.has_raw_data()) { + tensor_ptr = reinterpret_cast(const_cast(tensor.raw_data().c_str())); + } else { + // Store temporary unpacked_tensor. + unpacked_tensors_.push_back({}); + std::vector& unpacked_tensor = unpacked_tensors_.back(); + ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor)); + tensor_ptr = reinterpret_cast(unpacked_tensor.data()); + } + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: + view = emscripten::val{emscripten::typed_memory_view(num_elements, + reinterpret_cast(tensor_ptr))}; + break; + default: + break; + } + + // Wasm memory grow will cause all array buffers reallocation, which will be treated as detached + // buffers in JS side. Simply create a copy to fix it. + operand = wnn_builder_.call("constant", desc, view.call("slice")); } - - // Wasm memory grow will cause all array buffers reallocation, which will be treated as detached - // buffers in JS side. Simply create a copy to fix it. - operand = wnn_builder_.call("constant", desc, view.call("slice")); } else { // TODO: support other type. return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 68332d07a9782..78d60326dd0a8 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -235,5 +235,10 @@ Module['jsepInit'] = (name, params) => { Module['jsepRegisterMLTensor'] = (tensor, dataType, shape) => { return backend['registerMLTensor'](tensor, dataType, shape); } + + Module.jsepRegisterMLConstant = (externalFilePath, dataOffset, dataLength, builder, desc) => { + return backend['registerMLConstant']( + externalFilePath, dataOffset, dataLength, builder, desc, Module.MountedFiles); + } } };