From eddb716ab976001c00127d69c96c4ace671f04e5 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Wed, 13 Mar 2024 15:00:08 -0700 Subject: [PATCH] use target slot analysis to simplify insert rotate pass --- .../TensorExt/Transforms/InsertRotate.td | 25 ++++++++++++++----- lib/Dialect/TensorExt/Transforms/BUILD | 1 + .../TensorExt/Transforms/InsertRotate.cpp | 22 +++++++++------- .../TensorExt/Transforms/RotateAndReduce.cpp | 11 +++++++- tests/simd/hamming_distance.mlir | 6 ++++- 5 files changed, 48 insertions(+), 17 deletions(-) diff --git a/include/Dialect/TensorExt/Transforms/InsertRotate.td b/include/Dialect/TensorExt/Transforms/InsertRotate.td index 7e977f1fb..b9e681d26 100644 --- a/include/Dialect/TensorExt/Transforms/InsertRotate.td +++ b/include/Dialect/TensorExt/Transforms/InsertRotate.td @@ -7,7 +7,12 @@ include "mlir/Dialect/Arith/IR/ArithOps.td" include "mlir/Dialect/Tensor/IR/TensorOps.td" include "mlir/IR/PatternBase.td" -// TODO(#512): Support target slot selection when the downstream op is an insert. +// Get the target_slot attribute from an op, if it exists, or else +// return a zero index attribute. +def GetTargetSlotAttr : NativeCodeCall< + "$0.getOwner()->hasAttr(\"target_slot\")" + " ? llvm::cast($0.getOwner()->getAttr(\"target_slot\"))" + " : $_builder.getIndexAttr(0)">; // The patterns in this file are intended to align with the automatic-SIMD // batching heuristics from the HECO project. See section 4.4 of @@ -20,23 +25,31 @@ include "mlir/IR/PatternBase.td" // canonicalization patterns will remove duplicated rotations. foreach ArithOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in { def InsertRotations_#ArithOp : Pattern< - (ArithOp + (ArithOp:$arithOp (Tensor_ExtractOp $t1, (variadic $i1)), (Tensor_ExtractOp $t2, (variadic $i2)), $overflow), [ - (TensorExt_RotateOp:$r1 $t1, $i1), - (TensorExt_RotateOp:$r2 $t2, $i2), + (TensorExt_RotateOp:$r1 $t1, + (Arith_SubIOp $i1, (Arith_ConstantOp (GetTargetSlotAttr $arithOp)), DefOverflow)), + (TensorExt_RotateOp:$r2 $t2, + (Arith_SubIOp $i2, (Arith_ConstantOp (GetTargetSlotAttr $arithOp)), DefOverflow)), (ArithOp:$opResult $r1, $r2, $overflow), (Tensor_ExtractOp $opResult, - (MakeSingleResultVariadic (Arith_ConstantOp ConstantAttr))), + (MakeSingleResultVariadic + (Arith_ConstantOp (GetTargetSlotAttr $arithOp)))), ] >; } + // Pre-align the first op's operands to the index that the result is -// used for in a subsequent op. +// used for in a subsequent op. This is used to simplify the IR +// primarily when there is no specific slot target selected for an op. In +// that case, the above pattern will still replace extractions with +// rotations, and the simplifications will occur by replacing triples +// of rotations with pairs. // TODO(#514): handle OuterOp with two different InnerOps on the LHS and RHS foreach InnerOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in { foreach OuterOp = [Arith_AddIOp, Arith_SubIOp, Arith_MulIOp] in { diff --git a/lib/Dialect/TensorExt/Transforms/BUILD b/lib/Dialect/TensorExt/Transforms/BUILD index 9e534b610..59f01d8b3 100644 --- a/lib/Dialect/TensorExt/Transforms/BUILD +++ b/lib/Dialect/TensorExt/Transforms/BUILD @@ -69,6 +69,7 @@ cc_library( ], deps = [ "@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen", + "@heir//lib/Dialect/Secret/IR:Dialect", "@heir//lib/Dialect/TensorExt/IR:Dialect", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", diff --git a/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp b/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp index 373980db7..816d6b5b5 100644 --- a/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp +++ b/lib/Dialect/TensorExt/Transforms/InsertRotate.cpp @@ -54,15 +54,19 @@ struct InsertRotate : impl::InsertRotateBase { return; } - LLVM_DEBUG({ - getOperation()->walk([&](Operation *op) { - if (op->getNumResults() == 0) return; - auto *targetSlotLattice = - solver.lookupState( - op->getResult(0)); - llvm::dbgs() << "Target slot for op " << *op << ": " - << targetSlotLattice->getValue() << "\n"; - }); + // Annotate all arith ops with their target slot attribute, so that it can + // be matched in the DRR rules. + OpBuilder builder(context); + getOperation()->walk([&](Operation *op) { + if (op->getNumResults() == 0) return; + auto *targetSlotLattice = + solver.lookupState( + op->getResult(0)); + if (targetSlotLattice->getValue().isInitialized()) { + op->setAttr( + "target_slot", + builder.getIndexAttr(targetSlotLattice->getValue().getValue())); + } }); alignment::populateWithGenerated(patterns); diff --git a/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp b/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp index fe5808271..7c14c4187 100644 --- a/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp +++ b/lib/Dialect/TensorExt/Transforms/RotateAndReduce.cpp @@ -2,6 +2,7 @@ #include +#include "include/Dialect/Secret/IR/SecretOps.h" #include "include/Dialect/TensorExt/IR/TensorExtOps.h" #include "llvm/include/llvm/ADT/DenseSet.h" // from @llvm-project #include "llvm/include/llvm/ADT/StringRef.h" // from @llvm-project @@ -54,6 +55,9 @@ struct RotateAndReduce : impl::RotateAndReduceBase { llvm::TypeSwitch(upstreamOpPtr) .Case( [&](auto upstreamOp) { return success(); }) + // Ignore generic ops + .template Case( + [&](auto upstreamOp) { return success(); }) .template Case( [&](auto upstreamOp) { opCounts[upstreamOp->getName().getStringRef()]++; @@ -129,7 +133,12 @@ struct RotateAndReduce : impl::RotateAndReduceBase { accessIndices.insert(accessIndex); return success(); }) - .Default([&](Operation *op) { return failure(); }); + .Default([&](Operation *op) { + LLVM_DEBUG(llvm::dbgs() << "Not continuing because type switch " + "encountered unsupported op " + << op->getName() << "\n"); + return failure(); + }); if (failed(result)) { return; diff --git a/tests/simd/hamming_distance.mlir b/tests/simd/hamming_distance.mlir index 10a8dfb94..f8aa39500 100644 --- a/tests/simd/hamming_distance.mlir +++ b/tests/simd/hamming_distance.mlir @@ -1,5 +1,5 @@ // RUN: heir-opt --secretize=entry-function=hamming --wrap-generic --canonicalize --cse \ -// RUN: --full-loop-unroll --insert-rotate --cse --canonicalize \ +// RUN: --full-loop-unroll --cse --canonicalize --insert-rotate --cse --canonicalize \ // RUN: %s | FileCheck %s // CHECK-LABEL: @hamming @@ -10,10 +10,14 @@ // CHECK-NEXT: arith.addi // CHECK-NEXT: tensor_ext.rotate // CHECK-NEXT: arith.addi +// CHECK-NEXT: tensor_ext.rotate // CHECK-NEXT: arith.addi // CHECK-NEXT: tensor.extract // CHECK-NEXT: secret.yield +// TODO(#521): support rotate-and-reduce when the input is already a series of incremental rotations, +// as this IR is currently lowered to 4-1 rotate operations to sum after doing (x-y)**2 in SIMD. + func.func @hamming(%arg0: tensor<4xi16> {secret.secret}, %arg1: tensor<4xi16> {secret.secret}) -> i16 { %c0 = arith.constant 0 : index %c0_si16 = arith.constant 0 : i16