use target slot analysis to simplify insert rotate pass
j2kun committed Mar 18, 2024
1 parent cb54c78 commit 48ca56d
Showing 6 changed files with 51 additions and 20 deletions.
6 changes: 3 additions & 3 deletions include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h
@@ -122,9 +122,9 @@ class TargetSlotAnalysis
void visitOperation(Operation *op, ArrayRef<TargetSlotLattice *> operands,
ArrayRef<const TargetSlotLattice *> results) override;

void visitBranchOperand(OpOperand &operand) override {};
void visitCallOperand(OpOperand &operand) override {};
void setToExitState(TargetSlotLattice *lattice) override {};
void visitBranchOperand(OpOperand &operand) override{};
void visitCallOperand(OpOperand &operand) override{};
void setToExitState(TargetSlotLattice *lattice) override{};
};

} // namespace target_slot_analysis
25 changes: 19 additions & 6 deletions include/Dialect/TensorExt/Transforms/InsertRotate.td
@@ -7,7 +7,12 @@ include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"
include "mlir/IR/PatternBase.td"

// TODO(#512): Support target slot selection when the downstream op is an insert.
// Get the target_slot attribute from an op, if it exists, or else
// return a zero index attribute.
def GetTargetSlotAttr : NativeCodeCall<
"$0.getOwner()->hasAttr(\"target_slot\")"
" ? llvm::cast<mlir::IntegerAttr>($0.getOwner()->getAttr(\"target_slot\"))"
" : $_builder.getIndexAttr(0)">;

// The patterns in this file are intended to align with the automatic-SIMD
// batching heuristics from the HECO project. See section 4.4 of
@@ -20,23 +25,31 @@ include "mlir/IR/PatternBase.td"
// canonicalization patterns will remove duplicated rotations.
foreach ArithOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in {
def InsertRotations_#ArithOp : Pattern<
(ArithOp
(ArithOp:$arithOp
(Tensor_ExtractOp $t1, (variadic $i1)),
(Tensor_ExtractOp $t2, (variadic $i2)),
$overflow),
[
(TensorExt_RotateOp:$r1 $t1, $i1),
(TensorExt_RotateOp:$r2 $t2, $i2),
(TensorExt_RotateOp:$r1 $t1,
(Arith_SubIOp $i1, (Arith_ConstantOp (GetTargetSlotAttr $arithOp)), DefOverflow)),
(TensorExt_RotateOp:$r2 $t2,
(Arith_SubIOp $i2, (Arith_ConstantOp (GetTargetSlotAttr $arithOp)), DefOverflow)),
(ArithOp:$opResult $r1, $r2, $overflow),
(Tensor_ExtractOp
$opResult,
(MakeSingleResultVariadic (Arith_ConstantOp ConstantAttr<IndexAttr, "0">))),
(MakeSingleResultVariadic
(Arith_ConstantOp (GetTargetSlotAttr $arithOp)))),
]
>;
}
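
As an illustration of the rewrite these patterns perform (a hand-written sketch, not output copied from the pass; the value names and the exact assembly syntax of tensor_ext.rotate are only illustrative), an elementwise add of two extracted scalars

  %lhs = tensor.extract %t1[%i1] : tensor<4xi16>
  %rhs = tensor.extract %t2[%i2] : tensor<4xi16>
  %sum = arith.addi %lhs, %rhs : i16

becomes a whole-tensor add whose operands are rotated so the result lands in the op's target slot %s, followed by a single extraction at that slot:

  %shift1 = arith.subi %i1, %s : index
  %shift2 = arith.subi %i2, %s : index
  %r1 = tensor_ext.rotate %t1, %shift1 : tensor<4xi16>, index
  %r2 = tensor_ext.rotate %t2, %shift2 : tensor<4xi16>, index
  %sum = arith.addi %r1, %r2 : tensor<4xi16>
  %res = tensor.extract %sum[%s] : tensor<4xi16>

When consecutive ops share a target slot, their rotations and extractions line up, and the canonicalization patterns mentioned above can remove the duplicates.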


// Pre-align the first op's operands to the index that the result is
// used for in a subsequent op.
// used for in a subsequent op. This is used to simplify the IR
// primarily when there is no specific slot target selected for an op. In
// that case, the above pattern will still replace extractions with
// rotations, and the simplifications will occur by replacing triples
// of rotations with pairs.
// TODO(#514): handle OuterOp with two different InnerOps on the LHS and RHS
foreach InnerOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in {
foreach OuterOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in {
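The pre-alignment patterns themselves are elided in this hunk, but the effect the comment above describes can be sketched by hand (this is a sketch of the intended simplification, not the literal DRR output, and the sign of the combined shift depends on the rotation convention): when one arith op's result only feeds another op at a different slot, the extra rotation on the result is absorbed into the operand rotations.

  %r1 = tensor_ext.rotate %t1, %a : tensor<4xi16>, index
  %r2 = tensor_ext.rotate %t2, %b : tensor<4xi16>, index
  %inner = arith.addi %r1, %r2 : tensor<4xi16>
  %r3 = tensor_ext.rotate %inner, %c : tensor<4xi16>, index
  %outer = arith.muli %r3, %t3 : tensor<4xi16>

collapses to two rotations by folding %c into the operand shifts:

  %r1 = tensor_ext.rotate %t1, %ac : tensor<4xi16>, index   // %ac: %a combined with %c
  %r2 = tensor_ext.rotate %t2, %bc : tensor<4xi16>, index   // %bc: %b combined with %c
  %inner = arith.addi %r1, %r2 : tensor<4xi16>
  %outer = arith.muli %inner, %t3 : tensor<4xi16>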
1 change: 1 addition & 0 deletions lib/Dialect/TensorExt/Transforms/BUILD
@@ -69,6 +69,7 @@ cc_library(
],
deps = [
"@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
22 changes: 13 additions & 9 deletions lib/Dialect/TensorExt/Transforms/InsertRotate.cpp
@@ -54,15 +54,19 @@ struct InsertRotate : impl::InsertRotateBase<InsertRotate> {
return;
}

LLVM_DEBUG({
getOperation()->walk([&](Operation *op) {
if (op->getNumResults() == 0) return;
auto *targetSlotLattice =
solver.lookupState<target_slot_analysis::TargetSlotLattice>(
op->getResult(0));
llvm::dbgs() << "Target slot for op " << *op << ": "
<< targetSlotLattice->getValue() << "\n";
});
// Annotate all arith ops with their target slot attribute, so that it can
// be matched in the DRR rules.
OpBuilder builder(context);
getOperation()->walk([&](Operation *op) {
if (op->getNumResults() == 0) return;
auto *targetSlotLattice =
solver.lookupState<target_slot_analysis::TargetSlotLattice>(
op->getResult(0));
if (targetSlotLattice->getValue().isInitialized()) {
op->setAttr(
"target_slot",
builder.getIndexAttr(targetSlotLattice->getValue().getValue()));
}
});

alignment::populateWithGenerated(patterns);
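After this walk, every arith op whose result has an initialized target slot carries the attribute directly in the IR, roughly like the following (the slot value 3 is chosen only for illustration):

  %0 = arith.addi %a, %b {target_slot = 3 : index} : tensor<4xi16>

This is what the GetTargetSlotAttr helper in InsertRotate.td reads back during pattern application, falling back to slot 0 for ops that were never annotated.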
11 changes: 10 additions & 1 deletion lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp
@@ -2,6 +2,7 @@

#include <cstdint>

#include "include/Dialect/Secret/IR/SecretOps.h"
#include "include/Dialect/TensorExt/IR/TensorExtOps.h"
#include "llvm/include/llvm/ADT/DenseSet.h" // from @llvm-project
#include "llvm/include/llvm/ADT/StringRef.h" // from @llvm-project
@@ -54,6 +55,9 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
llvm::TypeSwitch<Operation *, LogicalResult>(upstreamOpPtr)
.Case<arith::ConstantOp>(
[&](auto upstreamOp) { return success(); })
// Ignore generic ops
.template Case<secret::GenericOp>(
[&](auto upstreamOp) { return success(); })
.template Case<arith::AddIOp, arith::MulIOp>(
[&](auto upstreamOp) {
opCounts[upstreamOp->getName().getStringRef()]++;
@@ -129,7 +133,12 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
accessIndices.insert(accessIndex);
return success();
})
.Default([&](Operation *op) { return failure(); });
.Default([&](Operation *op) {
LLVM_DEBUG(llvm::dbgs() << "Not continuing because type switch "
"encountered unsupported op "
<< op->getName() << "\n");
return failure();
});

if (failed(result)) {
return;
6 changes: 5 additions & 1 deletion tests/simd/hamming_distance.mlir
@@ -1,5 +1,5 @@
// RUN: heir-opt --secretize=entry-function=hamming --wrap-generic --canonicalize --cse \
// RUN: --full-loop-unroll --insert-rotate --cse --canonicalize \
// RUN: --full-loop-unroll --cse --canonicalize --insert-rotate --cse --canonicalize \
// RUN: %s | FileCheck %s

// CHECK-LABEL: @hamming
@@ -10,10 +10,14 @@
// CHECK-NEXT: arith.addi
// CHECK-NEXT: tensor_ext.rotate
// CHECK-NEXT: arith.addi
// CHECK-NEXT: tensor_ext.rotate
// CHECK-NEXT: arith.addi
// CHECK-NEXT: tensor.extract
// CHECK-NEXT: secret.yield

// TODO(#521): support rotate-and-reduce when the input is already a series of incremental rotations,
// as this IR is currently lowered to 4-1 rotate operations to sum after doing (x-y)**2 in SIMD.

func.func @hamming(%arg0: tensor<4xi16> {secret.secret}, %arg1: tensor<4xi16> {secret.secret}) -> i16 {
%c0 = arith.constant 0 : index
%c0_si16 = arith.constant 0 : i16
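
For context on TODO(#521) above: with 4 slots, the current output sums via three incremental rotations (the rotate/addi chain in the CHECK lines), whereas a rotate-and-reduce lowering would need only log2(4) = 2 rotations. A hand-written sketch of that shape, not something the pass currently produces (value names illustrative, %sq holding the squared differences):

  %r2 = tensor_ext.rotate %sq, %c2 : tensor<4xi16>, index
  %p = arith.addi %sq, %r2 : tensor<4xi16>
  %r1 = tensor_ext.rotate %p, %c1 : tensor<4xi16>, index
  %total = arith.addi %p, %r1 : tensor<4xi16>
  %result = tensor.extract %total[%c0] : tensor<4xi16>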
