Skip to content

Commit

Permalink
Merge branch 'develop' into RUF005-1
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo committed Aug 10, 2024
2 parents 24e6b04 + 2e4b1f4 commit e1f19c6
Show file tree
Hide file tree
Showing 1,313 changed files with 19,573 additions and 8,465 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ repos:
paddle/cinn/utils/registry.h
)$
# For Python files
- repo: https://github.com/psf/black.git
rev: 23.3.0
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.8.0
hooks:
- id: black
- repo: https://github.com/astral-sh/ruff-pre-commit
Expand Down
2 changes: 2 additions & 0 deletions cmake/cinn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,8 @@ set(PUBLISH_LIBS ON)
if(PUBLISH_LIBS)
set(core_includes
"${core_includes};paddle/cinn/runtime/cuda/cinn_cuda_runtime_source.cuh")
set(core_includes
"${core_includes};paddle/cinn/runtime/hip/cinn_hip_runtime_source.h")
set(core_includes
"${core_includes};paddle/common/flags.h;paddle/utils/test_macros.h")
foreach(header ${core_includes})
Expand Down
1 change: 0 additions & 1 deletion cmake/external/json.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ set(JSON_INCLUDE_DIR ${JSON_PREFIX_DIR}/include)
set(SOURCE_DIR ${PADDLE_SOURCE_DIR}/third_party/nlohmann_json)
set(SOURCE_INCLUDE_DIR ${SOURCE_DIR}/include)

include_directories(${JSON_INCLUDE_DIR})
include_directories(${SOURCE_INCLUDE_DIR})

set(JSON_BuildTests
Expand Down
10 changes: 5 additions & 5 deletions cmake/phi.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ function(kernel_declare TARGET_LIST)
"(PD_REGISTER_KERNEL|PD_REGISTER_KERNEL_FOR_ALL_DTYPE|PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE|PD_REGISTER_KERNEL_FOR_ALL_BACKEND_DTYPE_EXCEPT_CUSTOM)\\([ \t\r\n]*[a-z0-9_]*,[[ \\\t\r\n\/]*[a-z0-9_]*]?[ \\\t\r\n]*[a-zA-Z_]*,[ \\\t\r\n]*[A-Z_]*"
first_registry
"${kernel_impl}")
if(DEFINED REDUCE_INFERENCE_LIB_SIZE)
if("${first_registry}" MATCHES ".*_grad,.*")
continue()
endif()
endif()
set(kernel_declare_id "")
while(NOT first_registry STREQUAL "")
string(REPLACE "${first_registry}" "" kernel_impl "${kernel_impl}")
Expand Down Expand Up @@ -162,11 +167,6 @@ function(kernel_declare TARGET_LIST)
endwhile()
# append kernel declare into declarations.h
if(NOT kernel_declare_id STREQUAL "")
if(DEFINED REDUCE_INFERENCE_LIB_SIZE)
if(${kernel_declare_id} MATCHES ".*_grad,.*")
continue()
endif()
endif()
file(APPEND ${kernel_declare_file} "${kernel_declare_id}\n")
endif()
endforeach()
Expand Down
10 changes: 8 additions & 2 deletions paddle/cinn/adt/equation_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,18 @@ EquationGraphTopoWalker<VT, FT> GetSubgraph(
};
const auto& VisitInputVariables =
[graph, IsSelected](FT function, const std::function<void(VT)>& Visit) {
CHECK(IsSelected(function));
PADDLE_ENFORCE_EQ(
IsSelected(function),
true,
phi::errors::PreconditionNotMet("The function must be selected."));
graph.VisitInputVariables(function, Visit);
};
const auto& VisitOutputVariables =
[graph, IsSelected](FT function, const std::function<void(VT)>& Visit) {
CHECK(IsSelected(function));
PADDLE_ENFORCE_EQ(
IsSelected(function),
true,
phi::errors::PreconditionNotMet("The function must be selected."));
graph.VisitOutputVariables(function, Visit);
};
return EquationGraphTopoWalker<VT, FT>(
Expand Down
28 changes: 14 additions & 14 deletions paddle/cinn/adt/generate_map_expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ bool CollectRewrittenReductionOpStmts(const OpStmt& op_stmt,
PADDLE_ENFORCE_EQ(
op.Has<const ::pir::Operation*>(),
true,
phi::errors::InvalidArgument(
::common::errors::InvalidArgument(
"The op should have a value of type ::pir::Operation*"));
if (GetOpPatternKind(op.Get<const ::pir::Operation*>()) ==
hlir::framework::OpPatternKind::kReduction) {
Expand Down Expand Up @@ -241,7 +241,7 @@ std::vector<std::shared_ptr<IGroup>> GenerateIGroups(
PADDLE_ENFORCE_EQ(
!op_stmts->empty(),
true,
phi::errors::InvalidArgument("The op_stmts should not be empty"));
::common::errors::InvalidArgument("The op_stmts should not be empty"));

PartitionIGroupOpStmts(op_stmts, [&](const auto& igroup_spec) {
ret.push_back(MakeIGroup(igroup_spec));
Expand Down Expand Up @@ -278,12 +278,12 @@ std::unordered_map<Variable, const Value> MakeSdIterator2Iterator(
std::unordered_map<Variable, const Value> ret{};

for (std::size_t i = 0; i < igroup.loop_iterators()->size(); ++i) {
PADDLE_ENFORCE_EQ(
ret.emplace(igroup.loop_iterators()->at(i),
igroup.loop_iterators()->at(i))
.second,
true,
phi::errors::InvalidArgument("The loop iterator should be unique"));
PADDLE_ENFORCE_EQ(ret.emplace(igroup.loop_iterators()->at(i),
igroup.loop_iterators()->at(i))
.second,
true,
::common::errors::InvalidArgument(
"The loop iterator should be unique"));
}

return ret;
Expand Down Expand Up @@ -344,10 +344,10 @@ LoopDescriptor4IterVarT MakeGetterLoopDescriptor4IterVar(
using Cache = std::unordered_map<Iterator, LoopDescriptor>;
const auto& sd_iter2sd = std::make_shared<Cache>();
for (std::size_t i = 0; i < loop_iters->size(); ++i) {
PADDLE_ENFORCE_EQ(
sd_iter2sd->emplace(loop_iters->at(i), sd->at(i)).second,
true,
phi::errors::InvalidArgument("The loop iterator should be unique"));
PADDLE_ENFORCE_EQ(sd_iter2sd->emplace(loop_iters->at(i), sd->at(i)).second,
true,
::common::errors::InvalidArgument(
"The loop iterator should be unique"));
}
return [sd_iter2sd](const auto& sd_iter) { return sd_iter2sd->at(sd_iter); };
}
Expand All @@ -359,7 +359,7 @@ TreeMerger<Stmt> MakeTreeMerger(const MapIr& map_ir) {
PADDLE_ENFORCE_EQ(
cache->emplace(op_stmt, map_ir.loop_iterators()).second,
true,
phi::errors::InvalidArgument("The op_stmt should be unique"));
::common::errors::InvalidArgument("The op_stmt should be unique"));
}

TreeMerger<Stmt> tree_merger{};
Expand All @@ -383,7 +383,7 @@ MapStmt<Stmt> MakeMapStmt(const MapIrList& map_irs) {
"The size of stmts should be 1, but got %d.", stmts->size()));
PADDLE_ENFORCE_EQ(stmts->at(0).Has<MapStmt<Stmt>>(),
true,
phi::errors::InvalidArgument(
::common::errors::InvalidArgument(
"The stmts should have a value of type MapStmt<Stmt>"));
return stmts->at(0).Get<MapStmt<Stmt>>();
}
Expand Down
13 changes: 7 additions & 6 deletions paddle/cinn/adt/get_sub_reshape_dim_ranges.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ GetSubReshapeDimRanges(const List<DimExpr>& lhs_dims,
PADDLE_ENFORCE_EQ(
!lhs_dims->empty(),
true,
phi::errors::InvalidArgument("Sorry,but lhs_dims is empty"));
PADDLE_ENFORCE_EQ(!rhs_dims->empty(),
true,
phi::errors::InvalidArgument("Sory,but rhs_dims is empty"));
::common::errors::InvalidArgument("Sorry,but lhs_dims is empty"));
PADDLE_ENFORCE_EQ(
!rhs_dims->empty(),
true,
::common::errors::InvalidArgument("Sory,but rhs_dims is empty"));
std::vector<std::pair<int, int>> lhs_ranges{};
std::vector<std::pair<int, int>> rhs_ranges{};
int lhs_start = 0;
Expand All @@ -59,7 +60,7 @@ GetSubReshapeDimRanges(const List<DimExpr>& lhs_dims,
PADDLE_ENFORCE_EQ(
dims->at(i).Has<std::int64_t>(),
true,
phi::errors::InvalidArgument("dims->at(i) is not int64_t"));
::common::errors::InvalidArgument("dims->at(i) is not int64_t"));
ret *= dims->at(i).Get<std::int64_t>();
}
return ret;
Expand Down Expand Up @@ -95,7 +96,7 @@ GetSubReshapeDimRanges(const List<DimExpr>& lhs_dims,
}
PADDLE_ENFORCE_EQ(lhs_end == lhs_dims->size() && rhs_end == rhs_dims->size(),
true,
phi::errors::InvalidArgument(
::common::errors::InvalidArgument(
"lhs_end is not equal to lhs_dims->size() and rhs_end "
"is not equal to rhs_dims->size()"));
if (lhs_start < lhs_end && rhs_start < rhs_end) {
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/adt/igroup.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ std::shared_ptr<IndexExprInferContext> MakeIndexExprInferContext(
.emplace(anchor_iterators->at(i), anchor_iterators->at(i))
.second,
true,
phi::errors::InvalidArgument(
::common::errors::InvalidArgument(
"The element in anchor iterators failed to insert in anchor "
"iterator2value! Please check."));
}
Expand Down
6 changes: 3 additions & 3 deletions paddle/cinn/adt/igroup.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class IGroup final {
const List<Iterator>& loop_iterators() const {
PADDLE_ENFORCE_EQ(anchor_sd_equation_ctx_.has_value(),
true,
phi::errors::InvalidArgument(
::common::errors::InvalidArgument(
"The anchor_sd_equation_ctx_ has no value."));
return anchor_sd_equation_ctx_.value().sd_iterators();
}
Expand Down Expand Up @@ -128,7 +128,7 @@ class IGroup final {
PADDLE_ENFORCE_EQ(
index2tensor->emplace(index, tensor).second,
true,
phi::errors::InvalidArgument(
::common::errors::InvalidArgument(
"The index2tensor map has already contained the index."));
(*tensor2indexes)[tensor].emplace_back(index);
}
Expand All @@ -138,7 +138,7 @@ class IGroup final {
PADDLE_ENFORCE_EQ(
index2tensor->emplace(index, tensor).second,
true,
phi::errors::InvalidArgument(
::common::errors::InvalidArgument(
"The index2tensor map has already contained the index."));
(*tensor2indexes)[tensor].emplace_back(index);
}
Expand Down
12 changes: 6 additions & 6 deletions paddle/cinn/adt/inline_translator.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct InlineTranslator final {
static DstTree Call(const SrcTree& src_tree) {
PADDLE_ENFORCE_EQ((src_tree.template Has<MapT<SrcTree>>()),
true,
phi::errors::InvalidArgument(
::common::errors::InvalidArgument(
"src_tree.template should have <MapT<SrcTree>>()"));
const MapT<DstTree> dst_tree =
CallMap(src_tree.template Get<MapT<SrcTree>>());
Expand Down Expand Up @@ -102,7 +102,7 @@ struct InlineTranslator final {
const auto& [arg_tensor] = arg_leaf.tuple();
PADDLE_ENFORCE_EQ(producer_tensor == arg_tensor,
true,
phi::errors::InvalidArgument(
::common::errors::InvalidArgument(
"producer_tensor should be equal to arg_tensor"));
List<OpExpr> ret{};
ret->assign(op_call_children->begin(), op_call_children->end());
Expand All @@ -117,7 +117,7 @@ struct InlineTranslator final {
PADDLE_ENFORCE_EQ(
(consumer_tree.template Has<OpCallT<OpExpr>>()),
true,
phi::errors::InvalidArgument(
::common::errors::InvalidArgument(
"consumer_tree.template should have <OpCallT<OpExpr>>()"));
const auto& op_call = consumer_tree.template Get<OpCallT<OpExpr>>();
const auto& op_call_children =
Expand All @@ -126,7 +126,7 @@ struct InlineTranslator final {
PADDLE_ENFORCE_EQ(
(op_call_child.template Has<Load<TensorT>>()),
true,
phi::errors::InvalidArgument(
::common::errors::InvalidArgument(
"op_call_child.template should have <Load<TensorT>>()"));
}

Expand Down Expand Up @@ -181,7 +181,7 @@ struct InlineTranslator final {
index2dst_leaf.emplace(i, NaiveTranslateLeaf(*std::next(begin, i)))
.second,
true,
phi::errors::InvalidArgument(
::common::errors::InvalidArgument(
"index2dst_leaf.emplace should return true"));
}
// Inline dst leaves
Expand Down Expand Up @@ -215,7 +215,7 @@ struct InlineTranslator final {
static DstLeaf NaiveTranslateLeaf(const SrcTree& src_tree) {
PADDLE_ENFORCE_EQ(src_tree.template Has<SrcLeaf>(),
true,
phi::errors::InvalidArgument(
::common::errors::InvalidArgument(
"src_tree.template should have <SrcLeaf>()"));
const auto& [tensor, op_call] = src_tree.template Get<SrcLeaf>().tuple();
const List<Load<TensorT>>& src_loads =
Expand Down
8 changes: 6 additions & 2 deletions paddle/cinn/adt/m_ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,16 @@ void CollectTensorIndexIterators(const TensorIndexExpr& tensor_index_expr,

void CollectTensorIndexIteratorsImpl(const Undefined& tensor_index_expr,
std::unordered_set<Iterator>* ret) {
PADDLE_THROW(::common::errors::Unimplemented("Not Implemented"));
PADDLE_THROW(::common::errors::Unimplemented(
"CollectTensorIndexIteratorsImpl is not implemented for Undefined tensor "
"index expression. Please check your input."));
}

void CollectTensorIndexIteratorsImpl(const Ok& ok,
std::unordered_set<Iterator>* ret) {
PADDLE_THROW(::common::errors::Unimplemented("Not Implemented"));
PADDLE_THROW(::common::errors::Unimplemented(
"CollectTensorIndexIteratorsImpl is not implemented for Ok state. Please "
"ensure the function is correctly called."));
}

void CollectTensorIndexIteratorsImpl(const Iterator& iterator,
Expand Down
7 changes: 6 additions & 1 deletion paddle/cinn/adt/map_expr_ctx.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ class MapExprCtx final {
::pir::Operation* node,
const std::vector<ir::LoweredFunc>& lowered_funcs) {
Node2LoweredFuncs* map = &node2lowered_funcs_;
CHECK(map->emplace(node, ir::ir_utils::IRCopy(lowered_funcs)).second);
PADDLE_ENFORCE_EQ(
map->emplace(node, ir::ir_utils::IRCopy(lowered_funcs)).second,
true,
::common::errors::InvalidArgument(
"Failed to emplace the node in the map. Ensure that the node is "
"valid and the operation is correct."));
}

const Node2LoweredFuncs& node2lowered_funcs() const {
Expand Down
6 changes: 3 additions & 3 deletions paddle/cinn/adt/naive_bidirection_equation_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ OpArgIndexes<std::optional<Index>> MakeOutMsgOpArgIndexes(
for (const auto& out_msg_in_index : *opt_out_msg_in_indexes) {
PADDLE_ENFORCE_EQ(out_msg_in_index.has_value(),
true,
phi::errors::InvalidArgument(
::common::errors::InvalidArgument(
"The out_msg_in_index should have value."));
out_msg_in_indexes->emplace_back(out_msg_in_index.value());
}
Expand Down Expand Up @@ -118,7 +118,7 @@ void NaiveBidirectionEquationGenerator::InitInMsgIndex2OutMsgIndex() {
this->in_msg_index2out_msg_index_.emplace(in_index, out_index)
.second,
true,
phi::errors::InvalidArgument(
::common::errors::InvalidArgument(
"The out_msg_index2in_msg_index_ map has already "
"contained the out_index."));
});
Expand Down Expand Up @@ -172,7 +172,7 @@ NaiveBidirectionEquationGenerator::MakeGetterOpStmt4OpPlaceHolder() const {
->emplace(fake_op_placeholders_->at(i), op_stmts_->at(i))
.second,
true,
phi::errors::InvalidArgument(
::common::errors::InvalidArgument(
"The fake_op_placeholder2op_stmt map has already contained the "
"fake_op_placeholder."));
}
Expand Down
20 changes: 10 additions & 10 deletions paddle/cinn/adt/naive_op_equation_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ void GenerateOpEquationsImpl(const ::pir::Operation* op_node,
hlir::framework::pir::CompatibleInfo::OpName(*op_node));
PADDLE_ENFORCE_EQ(generate_equations.Find(cinn_op),
true,
phi::errors::NotFound("generate_equations not found"));
::common::errors::NotFound("generate_equations not found"));
generate_equations[cinn_op](ctx);
}

Expand Down Expand Up @@ -127,13 +127,13 @@ GetArgStaticDimT MakeGetterArgStaticDim(const List<Tensor>& tensors) {
return [=](std::size_t tensor_idx,
std::size_t dim_idx) -> std::optional<std::int64_t> {
const auto& opt_expr = GetArgDim(tensors, tensor_idx, dim_idx);
PADDLE_ENFORCE_EQ(
opt_expr.has_value(),
true,
phi::errors::InvalidArgument("Sorry,but opt_expr don't has value"));
PADDLE_ENFORCE_EQ(opt_expr.has_value(),
true,
::common::errors::InvalidArgument(
"Sorry,but opt_expr don't has value"));
PADDLE_ENFORCE_EQ(opt_expr.value().Has<std::int64_t>(),
true,
phi::errors::InvalidArgument(
::common::errors::InvalidArgument(
"Sorry,but opt_expr should has value int64_t"));
return opt_expr.value().Get<std::int64_t>();
};
Expand Down Expand Up @@ -220,10 +220,10 @@ GenerateContext4LocalOpStmt(const List<OpStmt>& op_stmts) {

for (const auto& op_stmt : *op_stmts) {
const auto& ctx = MakeContextAndGenerateEquations(op_stmt);
PADDLE_ENFORCE_EQ(
op_stmt2equation_ctx->emplace(op_stmt, ctx).second,
true,
phi::errors::InvalidArgument("op_stmt2equation_ctx insert failed"));
PADDLE_ENFORCE_EQ(op_stmt2equation_ctx->emplace(op_stmt, ctx).second,
true,
::common::errors::InvalidArgument(
"op_stmt2equation_ctx insert failed"));
}

return [op_stmt2equation_ctx](const auto& op_stmt) {
Expand Down
5 changes: 4 additions & 1 deletion paddle/cinn/adt/print_utils/print_schedule_descriptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ std::string ToTxtString(const LoopDescriptor& loop_descriptor) {
*string += vectorize.iter_var_name();
},
[&](const Unroll& unroll) { *string += unroll.iter_var_name(); }};
CHECK(loop_size.Has<std::int64_t>());
PADDLE_ENFORCE_EQ(loop_size.Has<std::int64_t>(),
true,
::common::errors::InvalidArgument(
"The loop_size should have type int64_t."));
*string += "=0.." + std::to_string(loop_size.Get<std::int64_t>());
return ret;
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/adt/print_utils/print_schedule_descriptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#pragma once

#include <string>

#include "paddle/common/enforce.h"
namespace cinn::adt {

class LoopDescriptor;
Expand Down
Loading

0 comments on commit e1f19c6

Please sign in to comment.