Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use target slot analysis to simplify insert-rotate #530

Merged
merged 1 commit on Mar 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 19 additions & 6 deletions include/Dialect/TensorExt/Transforms/InsertRotate.td
Expand Up @@ -7,7 +7,12 @@ include "mlir/Dialect/Arith/IR/ArithOps.td"
include "mlir/Dialect/Tensor/IR/TensorOps.td"
include "mlir/IR/PatternBase.td"

// TODO(#512): Support target slot selection when the downstream op is an insert.
// NativeCodeCall helper: read the `target_slot` attribute from the op that
// owns the captured value ($0.getOwner()) and cast it to an IntegerAttr.
// If the op carries no `target_slot` attribute, fall back to an index
// attribute of 0, so patterns using this helper can always rely on a valid
// slot attribute being produced (aligning rotations to slot zero by default).
def GetTargetSlotAttr : NativeCodeCall<
"$0.getOwner()->hasAttr(\"target_slot\")"
" ? llvm::cast<mlir::IntegerAttr>($0.getOwner()->getAttr(\"target_slot\"))"
" : $_builder.getIndexAttr(0)">;

// The patterns in this file are intended to align with the automatic-SIMD
// batching heuristics from the HECO project. See section 4.4 of
Expand All @@ -20,23 +25,31 @@ include "mlir/IR/PatternBase.td"
// canonicalization patterns will remove duplicated rotations.
foreach ArithOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in {
def InsertRotations_#ArithOp : Pattern<
(ArithOp
(ArithOp:$arithOp
(Tensor_ExtractOp $t1, (variadic $i1)),
(Tensor_ExtractOp $t2, (variadic $i2)),
$overflow),
[
(TensorExt_RotateOp:$r1 $t1, $i1),
(TensorExt_RotateOp:$r2 $t2, $i2),
(TensorExt_RotateOp:$r1 $t1,
(Arith_SubIOp $i1, (Arith_ConstantOp (GetTargetSlotAttr $arithOp)), DefOverflow)),
(TensorExt_RotateOp:$r2 $t2,
(Arith_SubIOp $i2, (Arith_ConstantOp (GetTargetSlotAttr $arithOp)), DefOverflow)),
(ArithOp:$opResult $r1, $r2, $overflow),
(Tensor_ExtractOp
$opResult,
(MakeSingleResultVariadic (Arith_ConstantOp ConstantAttr<IndexAttr, "0">))),
(MakeSingleResultVariadic
(Arith_ConstantOp (GetTargetSlotAttr $arithOp)))),
]
>;
}


// Pre-align the first op's operands to the index that the result is
// used for in a subsequent op.
// used for in a subsequent op. This is used to simplify the IR
// primarily when there is no specific slot target selected for an op. In
// that case, the above pattern will still replace extractions with
// rotations, and the simplifications will occur by replacing triples
// of rotations with pairs.
// TODO(#514): handle OuterOp with two different InnerOps on the LHS and RHS
foreach InnerOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in {
foreach OuterOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in {
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/TensorExt/Transforms/BUILD
Expand Up @@ -69,6 +69,7 @@ cc_library(
],
deps = [
"@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen",
"@heir//lib/Dialect/Secret/IR:Dialect",
"@heir//lib/Dialect/TensorExt/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
Expand Down
22 changes: 13 additions & 9 deletions lib/Dialect/TensorExt/Transforms/InsertRotate.cpp
Expand Up @@ -54,15 +54,19 @@ struct InsertRotate : impl::InsertRotateBase<InsertRotate> {
return;
}

LLVM_DEBUG({
getOperation()->walk([&](Operation *op) {
if (op->getNumResults() == 0) return;
auto *targetSlotLattice =
solver.lookupState<target_slot_analysis::TargetSlotLattice>(
op->getResult(0));
llvm::dbgs() << "Target slot for op " << *op << ": "
<< targetSlotLattice->getValue() << "\n";
});
// Annotate all arith ops with their target slot attribute, so that it can
// be matched in the DRR rules.
OpBuilder builder(context);
getOperation()->walk([&](Operation *op) {
if (op->getNumResults() == 0) return;
auto *targetSlotLattice =
solver.lookupState<target_slot_analysis::TargetSlotLattice>(
op->getResult(0));
if (targetSlotLattice->getValue().isInitialized()) {
op->setAttr(
"target_slot",
builder.getIndexAttr(targetSlotLattice->getValue().getValue()));
}
});

alignment::populateWithGenerated(patterns);
Expand Down
11 changes: 10 additions & 1 deletion lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp
Expand Up @@ -2,6 +2,7 @@

#include <cstdint>

#include "include/Dialect/Secret/IR/SecretOps.h"
#include "include/Dialect/TensorExt/IR/TensorExtOps.h"
#include "llvm/include/llvm/ADT/DenseSet.h" // from @llvm-project
#include "llvm/include/llvm/ADT/StringRef.h" // from @llvm-project
Expand Down Expand Up @@ -54,6 +55,9 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
llvm::TypeSwitch<Operation *, LogicalResult>(upstreamOpPtr)
.Case<arith::ConstantOp>(
[&](auto upstreamOp) { return success(); })
// Ignore generic ops
.template Case<secret::GenericOp>(
[&](auto upstreamOp) { return success(); })
AlexanderViand-Intel marked this conversation as resolved.
Show resolved Hide resolved
.template Case<arith::AddIOp, arith::MulIOp>(
[&](auto upstreamOp) {
opCounts[upstreamOp->getName().getStringRef()]++;
Expand Down Expand Up @@ -129,7 +133,12 @@ struct RotateAndReduce : impl::RotateAndReduceBase<RotateAndReduce> {
accessIndices.insert(accessIndex);
return success();
})
.Default([&](Operation *op) { return failure(); });
.Default([&](Operation *op) {
LLVM_DEBUG(llvm::dbgs() << "Not continuing because type switch "
"encountered unsupported op "
<< op->getName() << "\n");
return failure();
});

if (failed(result)) {
return;
Expand Down
6 changes: 5 additions & 1 deletion tests/simd/hamming_distance.mlir
@@ -1,5 +1,5 @@
// RUN: heir-opt --secretize=entry-function=hamming --wrap-generic --canonicalize --cse \
// RUN: --full-loop-unroll --insert-rotate --cse --canonicalize \
// RUN: --full-loop-unroll --cse --canonicalize --insert-rotate --cse --canonicalize \
j2kun marked this conversation as resolved.
Show resolved Hide resolved
// RUN: %s | FileCheck %s

// CHECK-LABEL: @hamming
Expand All @@ -10,10 +10,14 @@
// CHECK-NEXT: arith.addi
// CHECK-NEXT: tensor_ext.rotate
// CHECK-NEXT: arith.addi
// CHECK-NEXT: tensor_ext.rotate
// CHECK-NEXT: arith.addi
// CHECK-NEXT: tensor.extract
// CHECK-NEXT: secret.yield

// TODO(#521): support rotate-and-reduce when the input is already a series of incremental rotations,
// as this IR is currently lowered to 4-1 rotate operations to sum after doing (x-y)**2 in SIMD.

func.func @hamming(%arg0: tensor<4xi16> {secret.secret}, %arg1: tensor<4xi16> {secret.secret}) -> i16 {
%c0 = arith.constant 0 : index
%c0_si16 = arith.constant 0 : i16
Expand Down