Skip to content

Commit

Permalink
Relax assumptions on dimensionality in air.dma as with optimizations …
Browse files Browse the repository at this point in the history
…the number of dimensions may be different than memref rank (Xilinx#469)
  • Loading branch information
nirvedhmeshram authored Mar 1, 2024
1 parent a9bc6ed commit c4d79db
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 34 deletions.
34 changes: 0 additions & 34 deletions mlir/lib/Conversion/ConvertToAIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1060,30 +1060,13 @@ class AIRDmaToAIRChannelConversion
} else
return failure();

auto src_rank = src_type.getRank();
auto dst_rank = dst_type.getRank();

SmallVector<Value, 4> src_offsets = op.getSrcOffsets();
SmallVector<Value, 4> dst_offsets = op.getDstOffsets();
SmallVector<Value, 4> src_sizes = op.getSrcSizes();
SmallVector<Value, 4> dst_sizes = op.getDstSizes();
SmallVector<Value, 4> src_strides = op.getSrcStrides();
SmallVector<Value, 4> dst_strides = op.getDstStrides();

if (src_offsets.size()) {
if (src_sizes.size() != (unsigned)src_rank)
return failure();
if (src_strides.size() != (unsigned)src_rank)
return failure();
}

if (dst_offsets.size()) {
if (dst_sizes.size() != (unsigned)dst_rank)
return failure();
if (dst_strides.size() != (unsigned)dst_rank)
return failure();
}

std::set<Operation *> erased;
SmallVector<air::ChannelInterface, 1> externalGetPut;
SmallVector<air::ChannelInterface, 1> internalGetPut;
Expand Down Expand Up @@ -1472,30 +1455,13 @@ class AIRDemoteDmaToAIRHierarchyConversion
return failure(); // This pass is currently not able to promote in memory
// tier

auto src_rank = src_type.getRank();
auto dst_rank = dst_type.getRank();

SmallVector<Value, 4> src_offsets = op.getSrcOffsets();
SmallVector<Value, 4> dst_offsets = op.getDstOffsets();
SmallVector<Value, 4> src_sizes = op.getSrcSizes();
SmallVector<Value, 4> dst_sizes = op.getDstSizes();
SmallVector<Value, 4> src_strides = op.getSrcStrides();
SmallVector<Value, 4> dst_strides = op.getDstStrides();

if (src_offsets.size()) {
if (src_sizes.size() != (unsigned)src_rank)
return failure();
if (src_strides.size() != (unsigned)src_rank)
return failure();
}

if (dst_offsets.size()) {
if (dst_sizes.size() != (unsigned)dst_rank)
return failure();
if (dst_strides.size() != (unsigned)dst_rank)
return failure();
}

std::set<Operation *> erased;

{
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Transform/AIRDependency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ class AIRDependency
for (unsigned i = 0; i < sink_op_memcpy.getSrcStrides().size(); i++)
sink_op_scalar_ins.push_back(sink_op_memcpy.getSrcStrides()[i]);
if (sink_op_memcpy.getSrcOffsets().size()) {
numDimsSrc = sink_op_memcpy.getSrcOffsets().size();
for (unsigned i = 0; i < numDimsSrc; i++) {
src_indices.push_back(sink_op_memcpy.getSrcOffsets()[i]);
}
Expand All @@ -366,6 +367,7 @@ class AIRDependency
for (unsigned i = 0; i < sink_op_memcpy.getDstStrides().size(); i++)
sink_op_scalar_outs.push_back(sink_op_memcpy.getDstStrides()[i]);
if (sink_op_memcpy.getDstOffsets().size()) {
numDimsDst = sink_op_memcpy.getDstOffsets().size();
for (unsigned i = 0; i < numDimsDst; i++) {
dst_indices.push_back(sink_op_memcpy.getDstOffsets()[i]);
}
Expand Down Expand Up @@ -1018,6 +1020,7 @@ class AIRDependency
unsigned numDimsSrc =
memcpy.getSrcMemref().getType().cast<MemRefType>().getRank();
if (memcpy.getSrcOffsets().size()) {
numDimsSrc = memcpy.getSrcOffsets().size();
for (unsigned i = 0; i < numDimsSrc; i++) {
src_indices.push_back(memcpy.getSrcOffsets()[i]);
}
Expand All @@ -1034,6 +1037,7 @@ class AIRDependency
memcpy.getDstMemref().getType().cast<MemRefType>().getRank();
SmallVector<Value, 2> dst_indices;
if (memcpy.getDstOffsets().size()) {
numDimsDst = memcpy.getDstOffsets().size();
for (unsigned i = 0; i < numDimsDst; i++) {
dst_indices.push_back(memcpy.getDstOffsets()[i]);
}
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Util/Dependency.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1818,6 +1818,7 @@ void dependencyTracer::pushDepsAtCurrentScope(mlir::Value operand,
memcpy.getSrcMemref().getType().cast<MemRefType>().getRank();
SmallVector<Value, 2> src_indices;
if (memcpy.getSrcOffsets().size()) {
numDimsSrc = memcpy.getSrcOffsets().size();
for (unsigned i = 0; i < numDimsSrc; i++) {
src_indices.push_back(memcpy.getSrcOffsets()[i]);
}
Expand All @@ -1834,6 +1835,7 @@ void dependencyTracer::pushDepsAtCurrentScope(mlir::Value operand,
memcpy.getDstMemref().getType().cast<MemRefType>().getRank();
SmallVector<Value, 2> dst_indices;
if (memcpy.getDstOffsets().size()) {
numDimsDst = memcpy.getDstOffsets().size();
for (unsigned i = 0; i < numDimsDst; i++) {
dst_indices.push_back(memcpy.getDstOffsets()[i]);
}
Expand Down Expand Up @@ -2034,6 +2036,7 @@ void dependencyTracer::getPartialMemrefFromOp(
for (unsigned i = 0; i < sink_op_memcpy.getSrcStrides().size(); i++)
sink_op_scalar_ins.push_back(sink_op_memcpy.getSrcStrides()[i]);
if (sink_op_memcpy.getSrcOffsets().size()) {
numDimsSrc = sink_op_memcpy.getSrcOffsets().size();
for (unsigned i = 0; i < numDimsSrc; i++) {
src_indices.push_back(sink_op_memcpy.getSrcOffsets()[i]);
}
Expand All @@ -2058,6 +2061,7 @@ void dependencyTracer::getPartialMemrefFromOp(
for (unsigned i = 0; i < sink_op_memcpy.getDstStrides().size(); i++)
sink_op_scalar_outs.push_back(sink_op_memcpy.getDstStrides()[i]);
if (sink_op_memcpy.getDstOffsets().size()) {
numDimsDst = sink_op_memcpy.getDstOffsets().size();
for (unsigned i = 0; i < numDimsDst; i++) {
dst_indices.push_back(sink_op_memcpy.getDstOffsets()[i]);
}
Expand Down

0 comments on commit c4d79db

Please sign in to comment.