Skip to content

Commit

Permalink
Merge pull request #573 from j2kun:heco-examples
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 620288526
  • Loading branch information
Copybara-Service committed Mar 29, 2024
2 parents ef52e8d + 75f1276 commit 23a5dc6
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 13 deletions.
16 changes: 13 additions & 3 deletions include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h
Expand Up @@ -111,9 +111,16 @@ class TargetSlotLattice : public dataflow::Lattice<TargetSlot> {
class TargetSlotAnalysis
: public dataflow::SparseBackwardDataFlowAnalysis<TargetSlotLattice> {
public:
explicit TargetSlotAnalysis(DataFlowSolver &solver,
SymbolTableCollection &symbolTable)
: SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
explicit TargetSlotAnalysis(
DataFlowSolver &solver, SymbolTableCollection &symbolTable,
// The dataflow solver is a private member of the base analysis
// class, so if we want to access it we have to get it explicitly from
// the caller. It's required that this solver is pre-loaded with a
// SparseConstantPropagation analysis. I'd like a better way to do
// this: maybe pass a callback?
const DataFlowSolver *sccpAnalysis)
: SparseBackwardDataFlowAnalysis(solver, symbolTable),
sccpAnalysis(sccpAnalysis) {}
~TargetSlotAnalysis() override = default;
using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;

Expand All @@ -125,6 +132,9 @@ class TargetSlotAnalysis
void visitBranchOperand(OpOperand &operand) override {};
void visitCallOperand(OpOperand &operand) override {};
void setToExitState(TargetSlotLattice *lattice) override {};

private:
const DataFlowSolver *sccpAnalysis;
};

} // namespace target_slot_analysis
Expand Down
44 changes: 39 additions & 5 deletions lib/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.cpp
@@ -1,8 +1,10 @@
#include "include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h"

#include "lib/Dialect/Utils.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
Expand All @@ -22,15 +24,47 @@ void TargetSlotAnalysis::visitOperation(
llvm::TypeSwitch<Operation &>(*op)
.Case<tensor::InsertOp>([&](auto insertOp) {
LLVM_DEBUG({ llvm::dbgs() << "Visiting: " << *op << "\n"; });
auto insertIndexRes = get1DExtractionIndex<tensor::InsertOp>(insertOp);
auto insertIndices = insertOp.getIndices();
if (insertIndices.size() != 1) {
LLVM_DEBUG(llvm::dbgs() << "At " << insertOp
<< " can't handle >1D insertion index\n");
return;
}

Value insertIndexValue = insertOp.getIndices()[0];
const dataflow::Lattice<dataflow::ConstantValue> *insertIndexLattice =
sccpAnalysis
->lookupState<dataflow::Lattice<dataflow::ConstantValue>>(
insertIndexValue);

if (insertIndexLattice) {
LLVM_DEBUG(llvm::dbgs()
<< "At " << insertOp << " SCCP analysis gives lattice of "
<< *insertIndexLattice << "\n");
}

// If the target slot can't be statically determined, we can't
// propagate anything through the IR.
if (failed(insertIndexRes)) return;
if (!insertIndexLattice ||
insertIndexLattice->getValue().isUninitialized() ||
!insertIndexLattice->getValue().getConstantValue()) {
LLVM_DEBUG(
llvm::dbgs()
<< "At " << insertOp
<< " can't statically determine constant insertion index\n");
return;
}
Attribute insertIndexAttr =
insertIndexLattice->getValue().getConstantValue();
auto insertIndexIntAttr = insertIndexAttr.dyn_cast<IntegerAttr>();
assert(insertIndexIntAttr &&
"If 1D insertion index is constant, it must be integer");
int64_t insertIndexConst = insertIndexIntAttr.getInt();

// The target slot propagates to the value inserted, which is the first
// positional argument
TargetSlotLattice *lattice = operands[0];
TargetSlot newSlot = TargetSlot{insertIndexRes.value()};
TargetSlot newSlot = TargetSlot{insertIndexConst};
LLVM_DEBUG({
llvm::dbgs() << "Joining " << lattice->getValue() << " and "
<< newSlot << " --> "
Expand Down
30 changes: 25 additions & 5 deletions lib/Dialect/TensorExt/Transforms/InsertRotate.cpp
Expand Up @@ -42,27 +42,47 @@ struct InsertRotate : impl::InsertRotateBase<InsertRotate> {

SymbolTableCollection symbolTable;
DataFlowSolver solver;
// These two upstream analyses are required dependencies for any sparse
// dataflow analysis, or else the analysis will be a no-op. Cf.

// These two upstream analyses are required to be instantiated in any
// sparse dataflow analysis, or else the analysis will be a no-op. Cf.
// https://github.com/llvm/llvm-project/issues/58922
solver.load<dataflow::DeadCodeAnalysis>();
solver.load<dataflow::SparseConstantPropagation>();
solver.load<target_slot_analysis::TargetSlotAnalysis>(symbolTable);
if (failed(solver.initializeAndRun(getOperation()))) {
getOperation()->emitOpError() << "Failed to run the analysis.\n";
signalPassFailure();
return;
}

// We want to use the result of the sparse constant propagation from the
// first dataflow solver as an input to the target slot analysis. For some
// reason, actually running `--sccp` before this pass causes the IR to
// simplify away some operations that are needed to properly identify
// target slots. So the SparseConstantPropagation above is a simulated
// folding of arith operations, so as to identify when insertion indices
// are statically inferable.
//
// TODO(#572): find a better way to depend dataflow analyses on each other.
DataFlowSolver solver2;
solver2.load<dataflow::DeadCodeAnalysis>();
solver2.load<dataflow::SparseConstantPropagation>();
solver2.load<target_slot_analysis::TargetSlotAnalysis>(symbolTable,
&solver);
if (failed(solver2.initializeAndRun(getOperation()))) {
getOperation()->emitOpError() << "Failed to run the analysis.\n";
signalPassFailure();
return;
}

// Annotate all arith ops with their target slot attribute, so that it can
// be matched in the DRR rules.
OpBuilder builder(context);
getOperation()->walk([&](Operation *op) {
if (op->getNumResults() == 0) return;
auto *targetSlotLattice =
solver.lookupState<target_slot_analysis::TargetSlotLattice>(
solver2.lookupState<target_slot_analysis::TargetSlotLattice>(
op->getResult(0));
if (targetSlotLattice->getValue().isInitialized()) {
if (targetSlotLattice && targetSlotLattice->getValue().isInitialized()) {
op->setAttr(
"target_slot",
builder.getIndexAttr(targetSlotLattice->getValue().getValue()));
Expand Down
1 change: 1 addition & 0 deletions tests/heir_simd_vectorizer/box_blur_64x64.mlir
Expand Up @@ -29,6 +29,7 @@ module {
// CHECK-NEXT: secret.yield %[[v15]]
// CHECK-NEXT: } -> !secret.secret<tensor<4096xi16>>
// CHECK-NEXT: return %[[v0]]

func.func @box_blur(%arg0: tensor<4096xi16>) -> tensor<4096xi16> {
%c4096 = arith.constant 4096 : index
%c64 = arith.constant 64 : index
Expand Down
82 changes: 82 additions & 0 deletions tests/heir_simd_vectorizer/roberts_cross_4x4.mlir
@@ -0,0 +1,82 @@
// Ported from https://github.com/MarbleHE/HECO/blob/3e13744233ab0c09030a41ef98b4e061b6fa2eac/evaluation/benchmark/heco_input/robertscross_4x4.mlir

// RUN: heir-opt --secretize=entry-function=roberts_cross --wrap-generic --canonicalize --cse \
// RUN: --heir-simd-vectorizer %s | FileCheck %s

module{
// CHECK-LABEL: @roberts_cross
// CHECK-SAME: (%[[arg0:.*]]: !secret.secret<tensor<16xi16>>) -> !secret.secret<tensor<16xi16>> {
// CHECK-NEXT: %[[c15:.*]] = arith.constant 15 : index
// CHECK-NEXT: %[[c11:.*]] = arith.constant 11 : index
// CHECK-NEXT: secret.generic ins(%[[arg0]] : !secret.secret<tensor<16xi16>>) {
// CHECK-NEXT: ^bb0(%[[arg1:.*]]: tensor<16xi16>):
// CHECK-NEXT: %[[v1:.*]] = tensor_ext.rotate %[[arg1]], %[[c11]]
// CHECK-NEXT: %[[v2:.*]] = arith.subi %[[v1]], %[[arg1]]
// CHECK-NEXT: %[[v3:.*]] = tensor_ext.rotate %[[arg1]], %[[c15]]
// CHECK-NEXT: %[[v4:.*]] = arith.subi %[[v1]], %[[v3]]
// CHECK-NEXT: %[[v5:.*]] = arith.muli %[[v2]], %[[v2]]
// CHECK-NEXT: %[[v6:.*]] = arith.muli %[[v4]], %[[v4]]
// CHECK-NEXT: %[[v7:.*]] = arith.addi %[[v5]], %[[v6]]
func.func @roberts_cross(%img: tensor<16xi16>) -> tensor<16xi16> {
%c16 = arith.constant 16 : index
%c4 = arith.constant 4 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c-1 = arith.constant -1 : index

// Each point p = img[x][y], where x is row and y is column, in the new image will equal:
// (img[x-1][y-1] - img[x][y])^2 + (img[x-1][y] - img[x][y-1])^2
%r = affine.for %x = 0 to 4 iter_args(%imgx = %img) -> tensor<16xi16> {
%1 = affine.for %y = 0 to 4 iter_args(%imgy = %imgx) -> tensor<16xi16> {

// fetch img[x-1][y-1]
%4 = arith.addi %x, %c-1 : index
%5 = arith.muli %4, %c4 : index
%6 = arith.addi %y, %c-1 : index
%7 = arith.addi %5, %6 : index
%8 = arith.remui %7, %c16 : index
%9 = tensor.extract %img[%8] : tensor<16xi16>

// fetch img[x][y]
%10 = arith.muli %x, %c4 : index
%11 = arith.addi %10, %y : index
%12 = arith.remui %11, %c16 : index
%13 = tensor.extract %img[%12] : tensor<16xi16>

// subtract those two
%14 = arith.subi %9, %13 : i16

// fetch img[x-1][y]
%15 = arith.addi %x, %c-1 : index
%16 = arith.muli %15, %c4 : index
%17 = arith.addi %y, %c-1 : index
%18 = arith.addi %16, %17 : index
%19 = arith.remui %18, %c16 : index
%20 = tensor.extract %img[%19] : tensor<16xi16>

// fetch img[x][y-1]
%21 = arith.muli %x, %c4 : index
%22 = arith.addi %y, %c-1 : index
%23 = arith.addi %21, %22 : index
%24 = arith.remui %23, %c16 : index
%25 = tensor.extract %img[%24] : tensor<16xi16>

// subtract those two
%26 = arith.subi %20, %25 : i16

// square each difference
%27 = arith.muli %14, %14 : i16
%28 = arith.muli %26, %26 : i16

// add the squares
%29 = arith.addi %27, %28 : i16

// save to result[x][y]
%30 = tensor.insert %29 into %imgy[%12] : tensor<16xi16>
affine.yield %30: tensor<16xi16>
}
affine.yield %1 : tensor<16xi16>
}
return %r : tensor<16xi16>
}
}

0 comments on commit 23a5dc6

Please sign in to comment.