Skip to content

Commit

Permalink
[tuner] add an iree-opt pass to strip configuration from executable s…
Browse files Browse the repository at this point in the history
…ources (#19069)

This PR aims to address the first task in
nod-ai/shark-ai#453: adding an iree-opt
pass that removes configuration from executable sources. The
corresponding test is also included to ensure its correct functionality.

---------

Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu authored Nov 11, 2024
1 parent 300e0c3 commit 55b998a
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ iree_compiler_cc_library(
"RemoveSingleIterationLoop.cpp",
"ReplaceSlowMinMaxOps.cpp",
"SplitFullPartialTransferPass.cpp",
"StripCompilationInfoPass.cpp",
"TensorDynamicDimAnalysis.cpp",
"TensorToVectorVectorizePad.cpp",
"TestExecutablePreprocessing.cpp",
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ iree_cc_library(
"RemoveSingleIterationLoop.cpp"
"ReplaceSlowMinMaxOps.cpp"
"SplitFullPartialTransferPass.cpp"
"StripCompilationInfoPass.cpp"
"TensorDynamicDimAnalysis.cpp"
"TensorToVectorVectorizePad.cpp"
"TestExecutablePreprocessing.cpp"
Expand Down
5 changes: 5 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,11 @@ def RemoveSingleIterationLoopPass :
let summary = "Remove distributed loop with single iteration.";
}

// Strips all codegen configuration (lowering configs and translation info)
// from executable sources so a tuner can re-assign them from scratch.
def StripCompilationInfoPass :
    Pass<"iree-codegen-strip-compilation-info", ""> {
  // Fixed typo: "all the the" -> "all the".
  let summary = "Remove all the lowering configuration and translation info attributes.";
}

// TODO: Replace with upstream: https://github.com/iree-org/iree/issues/18759
def IREELoopInvariantCodeMotionPass :
Pass<"iree-loop-invariant-code-motion", ""> {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright 2024 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/Passes.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-codegen-strip-compilation-info"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_STRIPCOMPILATIONINFOPASS
#include "iree/compiler/Codegen/Common/Passes.h.inc"

namespace {

/// Pattern that removes the `translation_info` attribute from any function-like
/// op that carries one. Fails to match (no-op) on functions without it.
struct StripFuncOpTranslationInfo final
    : OpInterfaceRewritePattern<mlir::FunctionOpInterface> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(mlir::FunctionOpInterface funcOp,
                                PatternRewriter &rewriter) const final {
    // Nothing to strip on functions that carry no translation info.
    if (!getTranslationInfo(funcOp))
      return failure();

    // Wrap the attribute erasure in modifyOpInPlace so the rewriter is
    // notified of the in-place update.
    auto stripTranslationInfo = [&]() { eraseTranslationInfo(funcOp); };
    rewriter.modifyOpInPlace(funcOp, stripTranslationInfo);
    return success();
  }
};

/// Pattern that removes the `compilation_info` and `lowering_config`
/// attributes from linalg ops. Fails to match (no-op) when the op carries
/// neither attribute.
struct StripLinalgOpCompilationInfo final
    : OpInterfaceRewritePattern<linalg::LinalgOp> {
  using OpInterfaceRewritePattern::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp,
                                PatternRewriter &rewriter) const final {
    // Look each attribute up once; the original queried the op twice per
    // attribute (once in the match guard, again inside the rewrite lambda).
    auto compilationInfo = getCompilationInfo(linalgOp);
    auto loweringConfig = getLoweringConfig(linalgOp);
    if (!compilationInfo && !loweringConfig)
      return failure();

    rewriter.modifyOpInPlace(linalgOp, [&]() {
      if (compilationInfo) {
        // Erase the compilation info configuration if it exists.
        eraseCompilationInfo(linalgOp);
      }
      if (loweringConfig) {
        // Erase the lowering configuration from the root operation if it
        // exists.
        eraseLoweringConfig(linalgOp);
      }
    });

    return success();
  }
};

/// Pass that walks the operation and strips translation info from functions
/// and compilation info / lowering configs from linalg ops.
struct StripCompilationInfoPass final
    : impl::StripCompilationInfoPassBase<StripCompilationInfoPass> {
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    // Register both stripping patterns in a single call.
    patterns.add<StripFuncOpTranslationInfo, StripLinalgOpCompilationInfo>(
        context);
    walkAndApplyPatterns(getOperation(), std::move(patterns));
  }
};
} // namespace
} // namespace mlir::iree_compiler
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/test/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ iree_lit_test_suite(
"remove_trivial_loops.mlir",
"repeated_matcher_use.mlir",
"replace_slow_min_max_ops.mlir",
"strip_compilation_info.mlir",
"test_partitionable_loops_interface.mlir",
"tile_and_distribute_to_workgroups_func_scope.mlir",
"tile_and_distribute_to_workgroups.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ iree_lit_test_suite(
"remove_trivial_loops.mlir"
"repeated_matcher_use.mlir"
"replace_slow_min_max_ops.mlir"
"strip_compilation_info.mlir"
"test_partitionable_loops_interface.mlir"
"tile_and_distribute_to_workgroups.mlir"
"tile_and_distribute_to_workgroups_func_scope.mlir"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// RUN: iree-opt --split-input-file --iree-codegen-strip-compilation-info %s | FileCheck %s

// Case 1: translation_info on a top-level function is removed.
#translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [64, 1, 1] subgroup_size = 64>
func.func @main() attributes {translation_info = #translation_info} {
  return
}

// CHECK-LABEL: func.func @main
// CHECK-NOT: iree_codegen.translation_info
// CHECK-NOT: LLVMGPUVectorDistribute

// -----

// Case 2: translation_info is stripped from every function nested inside an
// executable variant, not only the first one.
#pipeline_layout = #hal.pipeline.layout<bindings = [
  #hal.pipeline.binding<storage_buffer>
]>
hal.executable private @strip_main {
  hal.executable.variant public @strip_main target(#hal.executable.target<"", "", {}>) {
    hal.executable.export public @entry_point layout(#pipeline_layout)
    builtin.module {
      func.func @fn1() attributes {translation_info = #iree_codegen.translation_info<None subgroup_size = 32>} {
        return
      }
      func.func @fn2() attributes {translation_info = #iree_codegen.translation_info<None subgroup_size = 32>} {
        return
      }
    }
  }
}

// CHECK-LABEL: hal.executable private @strip_main
// CHECK: @fn1
// CHECK-NOT: translation_info =
// CHECK: @fn2
// CHECK-NOT: translation_info =
// CHECK: return

// -----

// Case 3: a compilation_info attribute (which bundles a lowering_config and a
// translation_info) on a linalg op is removed entirely.
#config = #iree_codegen.lowering_config<tile_sizes = [[128, 256], [16, 16]]>
#translation = #iree_codegen.translation_info<None workgroup_size = [16, 8, 1] subgroup_size = 64>
#compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation>
func.func @matmul_128x1024x256(%lhs : tensor<128x256xf32>, %rhs: tensor<256x1024xf32>, %init: tensor<128x1024xf32>) -> tensor<128x1024xf32> {
  %result = linalg.matmul {compilation_info = #compilation} ins(%lhs, %rhs : tensor<128x256xf32>, tensor<256x1024xf32>) outs(%init : tensor<128x1024xf32>) -> tensor<128x1024xf32>
  return %result : tensor<128x1024xf32>
}

// CHECK-LABEL: func.func @matmul_128x1024x256
// CHECK-NOT: iree_codegen.translation_info
// CHECK-NOT: iree_codegen.lowering_config
// CHECK-NOT: iree_codegen.compilation_info

// -----

// Case 4: a standalone lowering_config on a linalg op is removed.
#config = #iree_codegen.lowering_config<tile_sizes = [[128, 256], [16, 16]]>
func.func @matmul_128x1024x256_1(%lhs : tensor<128x256xf32>, %rhs: tensor<256x1024xf32>, %init: tensor<128x1024xf32>) -> tensor<128x1024xf32> {
  %result = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[128, 256], [16, 16]]>} ins(%lhs, %rhs : tensor<128x256xf32>, tensor<256x1024xf32>) outs(%init : tensor<128x1024xf32>) -> tensor<128x1024xf32>
  return %result : tensor<128x1024xf32>
}

// CHECK-LABEL: func.func @matmul_128x1024x256_1
// CHECK-NOT: iree_codegen.lowering_config

0 comments on commit 55b998a

Please sign in to comment.