Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Generalize cggi-straight-line-vectorizer
Supports any elementwise mappable ops, so we can apply it to the BGV path as well.
- Loading branch information
Showing
14 changed files
with
263 additions
and
55 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,5 +33,4 @@ gentbl_cc_library( | |
exports_files([ | ||
"Passes.h", | ||
"SetDefaultParameters.h", | ||
"StraightLineVectorizer.h", | ||
]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# StraightLineVectorizer tablegen and headers. | ||
|
||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") | ||
|
||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
exports_files([ | ||
"StraightLineVectorizer.h", | ||
]) | ||
|
||
gentbl_cc_library( | ||
name = "pass_inc_gen", | ||
tbl_outs = [ | ||
( | ||
[ | ||
"-gen-pass-decls", | ||
"-name=StraightLineVectorizer", | ||
], | ||
"StraightLineVectorizer.h.inc", | ||
), | ||
( | ||
["-gen-pass-doc"], | ||
"StraightLineVectorizerPasses.md", | ||
), | ||
], | ||
tblgen = "@llvm-project//mlir:mlir-tblgen", | ||
td_file = "StraightLineVectorizer.td", | ||
deps = [ | ||
"@llvm-project//mlir:OpBaseTdFiles", | ||
"@llvm-project//mlir:PassBaseTdFiles", | ||
], | ||
) |
18 changes: 18 additions & 0 deletions
18
include/Transforms/StraightLineVectorizer/StraightLineVectorizer.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#ifndef INCLUDE_TRANSFORMS_STRAIGHTLINEVECTORIZER_STRAIGHTLINEVECTORIZER_H_ | ||
#define INCLUDE_TRANSFORMS_STRAIGHTLINEVECTORIZER_STRAIGHTLINEVECTORIZER_H_ | ||
|
||
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
|
||
#define GEN_PASS_DECL | ||
#include "include/Transforms/StraightLineVectorizer/StraightLineVectorizer.h.inc" | ||
|
||
#define GEN_PASS_REGISTRATION | ||
#include "include/Transforms/StraightLineVectorizer/StraightLineVectorizer.h.inc" | ||
|
||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // INCLUDE_TRANSFORMS_STRAIGHTLINEVECTORIZER_STRAIGHTLINEVECTORIZER_H_ |
15 changes: 15 additions & 0 deletions
15
include/Transforms/StraightLineVectorizer/StraightLineVectorizer.td
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#ifndef INCLUDE_TRANSFORMS_STRAIGHTLINEVECTORIZER_STRAIGHTLINEVECTORIZER_TD_ | ||
#define INCLUDE_TRANSFORMS_STRAIGHTLINEVECTORIZER_STRAIGHTLINEVECTORIZER_TD_ | ||
|
||
include "mlir/Pass/PassBase.td" | ||
|
||
def StraightLineVectorizer : Pass<"straight-line-vectorize"> { | ||
let summary = "A vectorizer for straight line programs."; | ||
let description = [{ | ||
This pass ignores control flow and only vectorizes straight-line programs | ||
within a given region. | ||
}]; | ||
let dependentDialects = []; | ||
} | ||
|
||
#endif // INCLUDE_TRANSFORMS_STRAIGHTLINEVECTORIZER_STRAIGHTLINEVECTORIZER_TD_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
cc_library( | ||
name = "StraightLineVectorizer", | ||
srcs = ["StraightLineVectorizer.cpp"], | ||
hdrs = [ | ||
"@heir//include/Transforms/StraightLineVectorizer:StraightLineVectorizer.h", | ||
], | ||
deps = [ | ||
"@heir//include/Graph", | ||
"@heir//include/Transforms/StraightLineVectorizer:pass_inc_gen", | ||
"@llvm-project//llvm:Support", | ||
"@llvm-project//mlir:Analysis", | ||
"@llvm-project//mlir:ArithDialect", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:Pass", | ||
"@llvm-project//mlir:Support", | ||
"@llvm-project//mlir:TensorDialect", | ||
"@llvm-project//mlir:Transforms", | ||
], | ||
) |
164 changes: 164 additions & 0 deletions
164
lib/Transforms/StraightLineVectorizer/StraightLineVectorizer.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
#include "include/Transforms/StraightLineVectorizer/StraightLineVectorizer.h" | ||
|
||
#include "include/Graph/Graph.h" | ||
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project | ||
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project | ||
#include "mlir/include/mlir/Analysis/SliceAnalysis.h" // from @llvm-project | ||
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project | ||
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project | ||
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project | ||
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project | ||
#include "mlir/include/mlir/Transforms/Passes.h" // from @llvm-project | ||
#include "mlir/include/mlir/Transforms/TopologicalSortUtils.h" // from @llvm-project | ||
|
||
#define DEBUG_TYPE "straight-line-vectorizer" | ||
|
||
namespace mlir { | ||
namespace heir { | ||
|
||
#define GEN_PASS_DEF_STRAIGHTLINEVECTORIZER | ||
#include "include/Transforms/StraightLineVectorizer/StraightLineVectorizer.h.inc" | ||
|
||
/// Returns true if the two operations can be combined into a single vectorized | ||
/// operation. | ||
bool areCompatible(Operation *lhs, Operation *rhs) { | ||
if (lhs->getName() != rhs->getName() || | ||
lhs->getDialect() != rhs->getDialect() || | ||
lhs->getResultTypes() != rhs->getResultTypes()) { | ||
return false; | ||
} | ||
return OpTrait::hasElementwiseMappableTraits(lhs); | ||
} | ||
|
||
bool tryVectorizeBlock(Block *block) { | ||
graph::Graph<Operation *> graph; | ||
for (auto &op : block->getOperations()) { | ||
if (!op.hasTrait<OpTrait::Elementwise>()) { | ||
continue; | ||
} | ||
|
||
graph.addVertex(&op); | ||
SetVector<Operation *> backwardSlice; | ||
BackwardSliceOptions options; | ||
options.omitBlockArguments = true; | ||
|
||
getBackwardSlice(&op, &backwardSlice, options); | ||
for (auto *upstreamDep : backwardSlice) { | ||
// An edge from upstreamDep to `op` means that upstreamDep must be | ||
// computed before `op`. | ||
graph.addEdge(upstreamDep, &op); | ||
} | ||
} | ||
|
||
if (graph.empty()) { | ||
return false; | ||
} | ||
|
||
auto result = graph.sortGraphByLevels(); | ||
assert(succeeded(result) && | ||
"Only possible failure is a cycle in the SSA graph!"); | ||
auto levels = result.value(); | ||
|
||
LLVM_DEBUG({ | ||
llvm::dbgs() | ||
<< "Found operations to vectorize. In topo-sorted level order:\n"; | ||
int level_num = 0; | ||
for (const auto &level : levels) { | ||
llvm::dbgs() << "\nLevel " << level_num++ << ":\n"; | ||
for (auto op : level) { | ||
llvm::dbgs() << " - " << *op << "\n"; | ||
} | ||
} | ||
}); | ||
|
||
bool madeReplacement = false; | ||
for (const auto &level : levels) { | ||
DenseMap<Operation *, SmallVector<Operation *, 4>> compatibleOps; | ||
for (auto *op : level) { | ||
bool foundCompatible = false; | ||
for (auto &[key, bucket] : compatibleOps) { | ||
if (areCompatible(key, op)) { | ||
compatibleOps[key].push_back(op); | ||
foundCompatible = true; | ||
} | ||
} | ||
if (!foundCompatible) { | ||
compatibleOps[op].push_back(op); | ||
} | ||
} | ||
LLVM_DEBUG(llvm::dbgs() | ||
<< "Partitioned level of size " << level.size() << " into " | ||
<< compatibleOps.size() << " groups of compatible ops\n"); | ||
|
||
for (auto &[key, bucket] : compatibleOps) { | ||
if (bucket.size() < 2) { | ||
continue; | ||
} | ||
|
||
LLVM_DEBUG({ | ||
llvm::dbgs() << "Vectorizing ops:\n"; | ||
for (auto op : bucket) { | ||
llvm::dbgs() << " - " << *op << "\n"; | ||
} | ||
}); | ||
|
||
OpBuilder builder(bucket.back()); | ||
// relies on CGGI ops having a single result type | ||
Type elementType = key->getResultTypes()[0]; | ||
RankedTensorType tensorType = RankedTensorType::get( | ||
{static_cast<int64_t>(bucket.size())}, elementType); | ||
|
||
SmallVector<Value, 4> vectorizedOperands; | ||
for (int operandIndex = 0; operandIndex < key->getNumOperands(); | ||
++operandIndex) { | ||
SmallVector<Value, 4> operands; | ||
operands.reserve(bucket.size()); | ||
for (auto *op : bucket) { | ||
operands.push_back(op->getOperand(operandIndex)); | ||
} | ||
auto fromElementsOp = builder.create<tensor::FromElementsOp>( | ||
key->getLoc(), tensorType, operands); | ||
vectorizedOperands.push_back(fromElementsOp.getResult()); | ||
} | ||
|
||
Operation *vectorizedOp = builder.clone(*key); | ||
vectorizedOp->setOperands(vectorizedOperands); | ||
vectorizedOp->getResult(0).setType(tensorType); | ||
|
||
int bucketIndex = 0; | ||
for (auto *op : bucket) { | ||
auto extractionIndex = builder.create<arith::ConstantOp>( | ||
op->getLoc(), builder.getIndexAttr(bucketIndex)); | ||
auto extractOp = builder.create<tensor::ExtractOp>( | ||
op->getLoc(), elementType, vectorizedOp->getResult(0), | ||
extractionIndex.getResult()); | ||
op->replaceAllUsesWith(ValueRange{extractOp.getResult()}); | ||
bucketIndex++; | ||
} | ||
|
||
for (auto *op : bucket) { | ||
op->erase(); | ||
} | ||
madeReplacement = true; | ||
} | ||
} | ||
|
||
return madeReplacement; | ||
} | ||
|
||
struct StraightLineVectorizer | ||
: impl::StraightLineVectorizerBase<StraightLineVectorizer> { | ||
using StraightLineVectorizerBase::StraightLineVectorizerBase; | ||
|
||
void runOnOperation() override { | ||
getOperation()->walk<WalkOrder::PreOrder>([&](Block *block) { | ||
if (tryVectorizeBlock(block)) { | ||
sortTopologically(block); | ||
} | ||
}); | ||
} | ||
}; | ||
|
||
} // namespace heir | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters