Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[js/webgpu] Add GatherND #22847

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions js/web/docs/webgpu-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ Do not modify directly.*
| Gather | ai.onnx(1-10,11-12,13+) | |
| GatherBlockQuantized | com.microsoft(1+) | |
| GatherElements | ai.onnx(11-12,13+) | |
| GatherND | ai.onnx(11-12,13+) | |
| Gelu | ai.onnx(20+); com.microsoft(1+) | |
| Gemm | ai.onnx(7-8,9-10,11-12,13+) | |
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
Expand Down
2 changes: 2 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import { einsum, parseEinsumAttributes } from './ops/einsum';
import { expand } from './ops/expand';
import { fastGelu } from './ops/fast-gelu';
import { gather, parseGatherAttributes } from './ops/gather';
import { gatherND, parseGatherNDAttributes } from './ops/gather-nd';
import { gatherBlockQuantized, parseGatherBlockQuantizedAttributes } from './ops/gather-block-quantized';
import { gatherElements, parseGatherElementsAttributes } from './ops/gather-elements';
import { gemm, parseGemmAttributes } from './ops/gemm';
Expand Down Expand Up @@ -100,6 +101,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Gather', [gather, parseGatherAttributes]],
['GatherElements', [gatherElements, parseGatherElementsAttributes]],
['GatherBlockQuantized', [gatherBlockQuantized, parseGatherBlockQuantizedAttributes]],
['GatherND', [gatherND, parseGatherNDAttributes]],
['Gelu', [unaryOps.gelu]],
['Gemm', [gemm, parseGemmAttributes]],
['GlobalAveragePool', [pool.globalAveragePool, pool.parseGlobalAveragePoolAttributes]],
Expand Down
186 changes: 186 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/ops/gather-nd.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import { DataType } from '../../../wasm-common';
import { TensorView } from '../../tensor-view';
import { ShapeUtil } from '../../util';
import { AttributeWithCacheKey } from '../attribute-with-cache-key';
import { ComputeContext, ProgramInputTensorInfoDependency, ProgramUniform } from '../types';

import { createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType } from './common';

export interface GatherNDAttributes extends AttributeWithCacheKey {
readonly batchDims: number;
}

const computeSliceOffsetsKernel = (
context: ComputeContext,
indicesData: TensorView,
sizesFromSliceDimsData: number[],
batchDims: number,
inputDims: readonly number[],
numSlices: number,
numSlicesPerBatch: number,
inputBatchStride: number,
numSliceDims: number,
) => {
const inputDependencies: ProgramInputTensorInfoDependency[] = ['type'];
const programUniforms: ProgramUniform[] = [
{ type: DataType.uint32, data: numSlices },
{ type: DataType.uint32, data: batchDims },
{ type: DataType.uint32, data: inputDims },
{ type: DataType.uint32, data: sizesFromSliceDimsData },
{ type: DataType.uint32, data: numSlicesPerBatch },
{ type: DataType.uint32, data: inputBatchStride },
{ type: DataType.uint32, data: numSliceDims },
];

const outputShape = [numSlices];
programUniforms.push(...createTensorShapeVariables(indicesData.dims, outputShape));

const getShaderSource = (shaderHelper: ShaderHelper) => {
const indices = inputVariable('indices_data', indicesData.dataType, indicesData.dims.length);
const output = outputVariable('input_slice_offsets_data', indicesData.dataType, 1, 1);
const variables = [indices, output];
const uniforms: UniformsArrayType = [
{ name: 'output_size', type: 'u32' },
{ name: 'batch_dims', type: 'u32' },
{ name: 'input_dims', type: 'u32', length: inputDims.length },
{ name: 'sizes_from_slice_dims_data', type: 'u32', length: sizesFromSliceDimsData.length },
{ name: 'num_slices_per_batch', type: 'u32' },
{ name: 'input_batch_stride', type: 'u32' },
{ name: 'num_slice_dims', type: 'u32' },
];
return `
${shaderHelper.registerUniforms(uniforms).declareVariables(...variables)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
let batch_idx = global_idx / uniforms.num_slices_per_batch;
let base_offset = batch_idx * uniforms.input_batch_stride;

let slice_indices_base_offset = global_idx * uniforms.num_slice_dims;
var relative_slice_offset = 0;
for (var dim_idx = 0u; dim_idx < uniforms.num_slice_dims; dim_idx ++) {
var index = i32(indices_data[dim_idx + slice_indices_base_offset].x);
let input_dim_idx = uniforms.batch_dims + dim_idx;
if (index < 0) {
${
inputDims.length === 1
? 'index += i32(uniforms.input_dims);'
: 'index += i32(uniforms.input_dims[input_dim_idx]);'
}
}
${
sizesFromSliceDimsData.length === 1
? 'relative_slice_offset += index * i32(uniforms.sizes_from_slice_dims_data);'
: 'relative_slice_offset += index * i32(uniforms.sizes_from_slice_dims_data[dim_idx]);'
}
}

input_slice_offsets_data[global_idx].x = base_offset + u32(relative_slice_offset);
}`;
};

return context.compute(
{
name: 'computeSliceOffsetsKernel',
shaderCache: { hint: `${inputDims.length === 1}_${sizesFromSliceDimsData.length === 1}`, inputDependencies },
getRunData: () => ({
outputs: [{ dims: outputShape, dataType: context.inputs[1].dataType }],
dispatchGroup: { x: Math.ceil(numSlices / 64) },
programUniforms,
}),
getShaderSource,
},
{ inputs: [indicesData], outputs: [-1] },
)[0];
};

const createGatherNDProgramInfo = (context: ComputeContext, attributes: GatherNDAttributes) => {
const inputs = context.inputs;
const inputShape = inputs[0].dims;
const inputType = inputs[0].dataType;
const indicesShape = inputs[1].dims;
const numSliceDims = indicesShape[indicesShape.length - 1];
const numSlices = ShapeUtil.sizeToDimension(indicesShape, indicesShape.length - 1);
const sliceSize = ShapeUtil.sizeFromDimension(inputShape, attributes.batchDims + numSliceDims);
const numBatches = ShapeUtil.sizeToDimension(inputShape, attributes.batchDims);
const inputBatchStride = ShapeUtil.sizeFromDimension(inputShape, attributes.batchDims);
const numSlicesPerBatch = numSlices / numBatches;
const sizesFromSliceDims = new Array(numSliceDims);
{
let runningProduct = sliceSize;
for (let i = 0; i < numSliceDims; ++i) {
sizesFromSliceDims[numSliceDims - 1 - i] = runningProduct;
runningProduct *= inputShape[attributes.batchDims + numSliceDims - 1 - i];
}
}

const inputSliceOffsets = computeSliceOffsetsKernel(
context,
inputs[1],
sizesFromSliceDims,
attributes.batchDims,
inputShape,
numSlices,
numSlicesPerBatch,
inputBatchStride,
numSliceDims,
);

const lastIndicesDimension = attributes.batchDims + numSliceDims;
if (lastIndicesDimension > inputShape.length) {
throw new Error('last dimension of indices must not be larger than rank of input tensor');
}

const outputShape = indicesShape.slice(0, -1).concat(inputShape.slice(lastIndicesDimension));
const outputSize = ShapeUtil.size(outputShape);

const programUniforms: ProgramUniform[] = [
{ type: DataType.uint32, data: outputSize },
{ type: DataType.uint32, data: sliceSize },
...createTensorShapeVariables(inputs[0].dims, inputSliceOffsets.dims, outputShape),
];

const getShaderSource = (shaderHelper: ShaderHelper) => {
const input = inputVariable('data', inputs[0].dataType, inputs[0].dims.length);
const indices = inputVariable('slice_offsets', inputSliceOffsets.dataType, inputSliceOffsets.dims.length);

const output = outputVariable('output', inputs[0].dataType, outputShape.length);
return `
${shaderHelper
.registerUniform('output_size', 'u32')
.registerUniform('slice_size', 'u32')
.declareVariables(input, indices, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
let slice_offset = slice_offsets[global_idx / uniforms.slice_size].x;
output[global_idx] = data[u32(slice_offset) + global_idx % uniforms.slice_size];
}`;
};
context.compute(
{
name: 'GatherND',
shaderCache: { hint: attributes.cacheKey, inputDependencies: ['rank', 'rank'] },
getRunData: () => ({
outputs: [{ dims: outputShape, dataType: inputType }],
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
programUniforms,
}),
getShaderSource,
},
{ inputs: [inputs[0], inputSliceOffsets] },
);
};

export const gatherND = (context: ComputeContext, attributes: GatherNDAttributes): void => {
createGatherNDProgramInfo(context, attributes);
};

export const parseGatherNDAttributes = (attributes: Record<string, unknown>): GatherNDAttributes => {
const batchDims = attributes.batch_dims as number;
return {
batchDims,
cacheKey: ``,
};
};
147 changes: 147 additions & 0 deletions js/web/test/data/ops/gathernd.jsonc
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
[
{
"name": "GatherND int32",
"operator": "GatherND",
"attributes": [],
"cases": [
{
"name": "data[4] indices[]",
"inputs": [
{
"data": [100, 101, 102, 777, 778, 779, 1000, 1001, 1002],
"dims": [9],
"type": "int32"
},
{
"data": [0, 4, 8],
"dims": [3, 1],
"type": "int64"
}
],
"outputs": [
{
"data": [100, 778, 1002],
"dims": [3],
"type": "int32"
}
]
}
]
},
{
"name": "GatherND float32",
"operator": "GatherND",
"attributes": [],
"cases": [
{
"name": "data[4] indices[]",
"inputs": [
{
"data": [100.1, 101.2, 102.3, 777.4, 778.5, 779.6, 1000.7, 1001.8, 1002.9],
"dims": [9],
"type": "float32"
},
{
"data": [0, 4, 8],
"dims": [3, 1],
"type": "int64"
}
],
"outputs": [
{
"data": [100.0999984741211, 778.5, 1002.9000244140625],
"dims": [3],
"type": "float32"
}
]
}
]
},
{
"name": "GatherND int32 [2 2 2], batch_dims",
"operator": "GatherND",
"attributes": [{ "name": "batch_dims", "data": 1, "type": "int" }],
"cases": [
{
"name": "data[4] indices[]",
"inputs": [
{
"data": [0, 1, 2, 3, 4, 5, 6, 7],
"dims": [2, 2, 2],
"type": "int32"
},
{
"data": [1, 0],
"dims": [2, 1],
"type": "int64"
}
],
"outputs": [
{
"data": [2, 3, 4, 5],
"dims": [2, 2],
"type": "int32"
}
]
}
]
},
{
"name": "GatherND float16",
"operator": "GatherND",
"attributes": [],
"cases": [
{
"name": "data[4] indices[]",
"inputs": [
{
"data": [100.1, 101.2, 102.3, 777.4, 778.5, 779.6, 1000.7, 1001.8, 1002.9],
"dims": [9],
"type": "float16"
},
{
"data": [0, 4, 8],
"dims": [3, 1],
"type": "int64"
}
],
"outputs": [
{
"data": [100.0999984741211, 778.5, 1002.9000244140625],
"dims": [3],
"type": "float16"
}
]
}
]
},
{
"name": "GatherND uint32 [2 2 2], batch_dims",
"operator": "GatherND",
"attributes": [{ "name": "batch_dims", "data": 1, "type": "int" }],
"cases": [
{
"name": "data[4] indices[]",
"inputs": [
{
"data": [0, 1, 2, 3, 4, 5, 6, 7],
"dims": [2, 2, 2],
"type": "uint32"
},
{
"data": [1, 0],
"dims": [2, 1],
"type": "int64"
}
],
"outputs": [
{
"data": [2, 3, 4, 5],
"dims": [2, 2],
"type": "uint32"
}
]
}
]
}
]
1 change: 1 addition & 0 deletions js/web/test/suite-test-list.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -1365,6 +1365,7 @@
"gather.jsonc",
"gather-block-quantized.jsonc",
"gather-elements.jsonc",
"gathernd.jsonc",
"gemm.jsonc",
"global-average-pool.jsonc",
"greater.jsonc",
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/core/providers/js/js_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gat
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, GatherElements);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherElements);

class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, GatherND);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, GatherND);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherND);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 9, Slice);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Slice);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Slice);
Expand Down Expand Up @@ -667,6 +671,10 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, GatherElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherElements)>,

BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, GatherND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, GatherND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, GatherND)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Resize)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, Resize)>,
Expand Down
Loading