-
Notifications
You must be signed in to change notification settings - Fork 616
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[tuner] add an iree-opt pass to strip configuration from executable s…
…ources (#19069) This PR aims to address the first task in nod-ai/shark-ai#453: adding an iree-opt pass that removes configuration from executable sources. The corresponding test is also included to ensure its correct functionality. --------- Signed-off-by: Bangtian Liu <[email protected]>
- Loading branch information
1 parent
300e0c3
commit 55b998a
Showing
7 changed files
with
142 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
compiler/src/iree/compiler/Codegen/Common/StripCompilationInfoPass.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
// Copyright 2024 The IREE Authors | ||
// | ||
// Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
||
#include "iree/compiler/Codegen/Common/Passes.h" | ||
#include "mlir/Transforms/WalkPatternRewriteDriver.h" | ||
|
||
#define DEBUG_TYPE "iree-codegen-strip-compilation-info" | ||
|
||
namespace mlir::iree_compiler { | ||
|
||
#define GEN_PASS_DEF_STRIPCOMPILATIONINFOPASS | ||
#include "iree/compiler/Codegen/Common/Passes.h.inc" | ||
|
||
namespace { | ||
|
||
struct StripFuncOpTranslationInfo final | ||
: OpInterfaceRewritePattern<mlir::FunctionOpInterface> { | ||
using OpInterfaceRewritePattern::OpInterfaceRewritePattern; | ||
LogicalResult matchAndRewrite(mlir::FunctionOpInterface funcOp, | ||
PatternRewriter &rewriter) const final { | ||
if (!getTranslationInfo(funcOp)) | ||
return failure(); | ||
|
||
rewriter.modifyOpInPlace(funcOp, [&]() { | ||
// If the function has translation info, erase it. | ||
eraseTranslationInfo(funcOp); | ||
}); | ||
|
||
return success(); | ||
} | ||
}; | ||
|
||
struct StripLinalgOpCompilationInfo final | ||
: OpInterfaceRewritePattern<linalg::LinalgOp> { | ||
using OpInterfaceRewritePattern::OpInterfaceRewritePattern; | ||
LogicalResult matchAndRewrite(linalg::LinalgOp linalgOp, | ||
PatternRewriter &rewriter) const final { | ||
if (!getCompilationInfo(linalgOp) && !getLoweringConfig(linalgOp)) | ||
return failure(); | ||
|
||
rewriter.modifyOpInPlace(linalgOp, [&]() { | ||
if (getCompilationInfo(linalgOp)) { | ||
// Erase the compilation info configuration if it exists. | ||
eraseCompilationInfo(linalgOp); | ||
} | ||
if (getLoweringConfig(linalgOp)) { | ||
// Erase the lowering configuration from root operation if it | ||
// exists. | ||
eraseLoweringConfig(linalgOp); | ||
} | ||
}); | ||
|
||
return success(); | ||
} | ||
}; | ||
|
||
struct StripCompilationInfoPass final | ||
: impl::StripCompilationInfoPassBase<StripCompilationInfoPass> { | ||
void runOnOperation() override { | ||
MLIRContext *ctx = &getContext(); | ||
RewritePatternSet patterns(ctx); | ||
patterns.add<StripFuncOpTranslationInfo>(ctx); | ||
patterns.add<StripLinalgOpCompilationInfo>(ctx); | ||
walkAndApplyPatterns(getOperation(), std::move(patterns)); | ||
} | ||
}; | ||
} // namespace | ||
} // namespace mlir::iree_compiler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
62 changes: 62 additions & 0 deletions
62
compiler/src/iree/compiler/Codegen/Common/test/strip_compilation_info.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
// RUN: iree-opt --split-input-file --iree-codegen-strip-compilation-info %s | FileCheck %s | ||
|
||
#translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [64, 1, 1] subgroup_size = 64> | ||
func.func @main() attributes {translation_info = #translation_info} { | ||
return | ||
} | ||
|
||
// CHECK-LABEL: func.func @main | ||
// CHECK-NOT: iree_codegen.translation_info | ||
// CHECK-NOT: LLVMGPUVectorDistribute | ||
|
||
// ----- | ||
|
||
#pipeline_layout = #hal.pipeline.layout<bindings = [ | ||
#hal.pipeline.binding<storage_buffer> | ||
]> | ||
hal.executable private @strip_main { | ||
hal.executable.variant public @strip_main target(#hal.executable.target<"", "", {}>) { | ||
hal.executable.export public @entry_point layout(#pipeline_layout) | ||
builtin.module { | ||
func.func @fn1() attributes {translation_info = #iree_codegen.translation_info<None subgroup_size = 32>} { | ||
return | ||
} | ||
func.func @fn2() attributes {translation_info = #iree_codegen.translation_info<None subgroup_size = 32>} { | ||
return | ||
} | ||
} | ||
} | ||
} | ||
|
||
// CHECK-LABEL: hal.executable private @strip_main | ||
// CHECK: @fn1 | ||
// CHECK-NOT: translation_info = | ||
// CHECK: @fn2 | ||
// CHECK-NOT: translation_info = | ||
// CHECK: return | ||
|
||
// ----- | ||
|
||
#config = #iree_codegen.lowering_config<tile_sizes = [[128, 256], [16, 16]]> | ||
#translation = #iree_codegen.translation_info<None workgroup_size = [16, 8, 1] subgroup_size = 64> | ||
#compilation = #iree_codegen.compilation_info<lowering_config = #config, translation_info = #translation> | ||
func.func @matmul_128x1024x256(%lhs : tensor<128x256xf32>, %rhs: tensor<256x1024xf32>, %init: tensor<128x1024xf32>) -> tensor<128x1024xf32> { | ||
%result = linalg.matmul {compilation_info = #compilation} ins(%lhs, %rhs : tensor<128x256xf32>, tensor<256x1024xf32>) outs(%init : tensor<128x1024xf32>) -> tensor<128x1024xf32> | ||
return %result : tensor<128x1024xf32> | ||
} | ||
|
||
// CHECK-LABEL: func.func @matmul_128x1024x256 | ||
// CHECK-NOT: iree_codegen.translation_info | ||
// CHECK-NOT: iree_codegen.lowering_config | ||
// CHECK-NOT: iree_codegen.compilation_info | ||
|
||
// ----- | ||
|
||
#config = #iree_codegen.lowering_config<tile_sizes = [[128, 256], [16, 16]]> | ||
func.func @matmul_128x1024x256_1(%lhs : tensor<128x256xf32>, %rhs: tensor<256x1024xf32>, %init: tensor<128x1024xf32>) -> tensor<128x1024xf32> { | ||
%result = linalg.matmul {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[128, 256], [16, 16]]>} ins(%lhs, %rhs : tensor<128x256xf32>, tensor<256x1024xf32>) outs(%init : tensor<128x1024xf32>) -> tensor<128x1024xf32> | ||
return %result : tensor<128x1024xf32> | ||
} | ||
|
||
// CHECK-LABEL: func.func @matmul_128x1024x256_1 | ||
// CHECK-NOT: iree_codegen.lowering_config |