add target slot analysis
j2kun committed Mar 12, 2024
1 parent 54fc80d commit fea6f58
Showing 6 changed files with 217 additions and 5 deletions.
9 changes: 9 additions & 0 deletions include/Analysis/TargetSlotAnalysis/BUILD
@@ -0,0 +1,9 @@
# TargetSlotAnalysis analysis pass
package(
    default_applicable_licenses = ["@heir//:license"],
    default_visibility = ["//visibility:public"],
)

exports_files(
    ["TargetSlotAnalysis.h"],
)
97 changes: 97 additions & 0 deletions include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h
@@ -0,0 +1,97 @@
#ifndef INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_
#define INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_

#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace target_slot_analysis {

class TargetSlot {
 public:
  TargetSlot() : value(std::nullopt) {}
  TargetSlot(int64_t value) : value(value) {}
  ~TargetSlot() = default;

  /// Whether the slot target is initialized. It can be uninitialized when the
  /// state hasn't been set during the analysis.
  bool isInitialized() const { return value.has_value(); }

  /// Get a known slot target.
  const int64_t &getValue() const {
    assert(isInitialized());
    return *value;
  }

  bool operator==(const TargetSlot &rhs) const { return value == rhs.value; }

  /// Join two target slots.
  static TargetSlot join(const TargetSlot &lhs, const TargetSlot &rhs) {
    if (!lhs.isInitialized()) return rhs;
    if (!rhs.isInitialized()) return lhs;
    // If they are both initialized, use an arbitrary deterministic rule to
    // select one. A more sophisticated analysis could try to determine which
    // slot is more likely to lead to beneficial optimizations.
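    // For example, join(TargetSlot{3}, TargetSlot{5}) yields TargetSlot{3}.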
    return TargetSlot{lhs.getValue() < rhs.getValue() ? lhs.getValue()
                                                      : rhs.getValue()};
  }

  void print(raw_ostream &os) const { os << value; }

 private:
  /// The target slot, if known.
  std::optional<int64_t> value;

  friend mlir::Diagnostic &operator<<(mlir::Diagnostic &diagnostic,
                                      const TargetSlot &foo) {
    if (foo.isInitialized()) {
      return diagnostic << foo.getValue();
    }
    return diagnostic << "uninitialized";
  }
};

inline raw_ostream &operator<<(raw_ostream &os, const TargetSlot &v) {
  v.print(os);
  return os;
}

class TargetSlotLattice : public dataflow::Lattice<TargetSlot> {
 public:
  using Lattice::Lattice;
};

/// An analysis that identifies a target slot for an SSA value in a program.
/// This is used by downstream passes to determine how to align rotations in
/// vectorized FHE schemes.
///
/// We use a backward dataflow analysis because the target slot propagates
/// backward from its final use to the arithmetic operations at which rotations
/// can be optimized.
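///
/// For example, given IR like
///
///   %c3 = arith.constant 3 : index
///   %r = tensor.insert %v into %dest[%c3] : tensor<16xi32>
///
/// the analysis would assign target slot 3 to %v and, transitively, to the
/// operations that produce %v.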
class TargetSlotAnalysis
    : public dataflow::SparseBackwardDataFlowAnalysis<TargetSlotLattice> {
 public:
  explicit TargetSlotAnalysis(DataFlowSolver &solver,
                              SymbolTableCollection &symbolTable)
      : SparseBackwardDataFlowAnalysis(solver, symbolTable) {}
  ~TargetSlotAnalysis() override = default;
  using SparseBackwardDataFlowAnalysis::SparseBackwardDataFlowAnalysis;

  // Given the computed results of the operation, update its operand lattice
  // values.
  void visitOperation(Operation *op, ArrayRef<TargetSlotLattice *> operands,
                      ArrayRef<const TargetSlotLattice *> results) override;

  void visitBranchOperand(OpOperand &operand) override{};
  void visitCallOperand(OpOperand &operand) override{};
  void setToExitState(TargetSlotLattice *lattice) override{};
};
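
// Typical usage (as in InsertRotate.cpp in this change): load
// dataflow::DeadCodeAnalysis, dataflow::SparseConstantPropagation, and
// TargetSlotAnalysis into a DataFlowSolver, run solver.initializeAndRun on the
// root op, then query each value's TargetSlotLattice via solver.lookupState.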

} // namespace target_slot_analysis
} // namespace heir
} // namespace mlir

#endif // INCLUDE_ANALYSIS_TARGETSLOTANALYSIS_TARGETSLOTANALYSIS_H_
18 changes: 18 additions & 0 deletions lib/Analysis/TargetSlotAnalysis/BUILD
@@ -0,0 +1,18 @@
# TargetSlotAnalysis analysis pass
package(
    default_applicable_licenses = ["@heir//:license"],
    default_visibility = ["//visibility:public"],
)

cc_library(
    name = "TargetSlotAnalysis",
    srcs = ["TargetSlotAnalysis.cpp"],
    hdrs = ["@heir//include/Analysis/TargetSlotAnalysis:TargetSlotAnalysis.h"],
    deps = [
        "@heir//lib/Dialect:Utils",
        "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TensorDialect",
    ],
)
55 changes: 55 additions & 0 deletions lib/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.cpp
@@ -0,0 +1,55 @@
#include "include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h"

#include "lib/Dialect/Utils.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project

#define DEBUG_TYPE "target-slot-analysis"

namespace mlir {
namespace heir {
namespace target_slot_analysis {

void TargetSlotAnalysis::visitOperation(
    Operation *op, ArrayRef<TargetSlotLattice *> operands,
    ArrayRef<const TargetSlotLattice *> results) {
  llvm::TypeSwitch<Operation &>(*op)
      .Case<tensor::InsertOp>([&](auto insertOp) {
        LLVM_DEBUG({ llvm::dbgs() << "Visiting: " << *op << "\n"; });
        auto insertIndexRes = get1DExtractionIndex<tensor::InsertOp>(insertOp);
        // If the target slot can't be statically determined, we can't
        // propagate anything through the IR.
        if (failed(insertIndexRes)) return;

        // The target slot propagates to the value being inserted, which is
        // the first positional argument.
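        // For example, for %r = tensor.insert %v into %dest[%c3], operands[0]
        // is the lattice attached to %v and the new slot is 3.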
        TargetSlotLattice *lattice = operands[0];
        TargetSlot newSlot = TargetSlot{insertIndexRes.value()};
        LLVM_DEBUG({
          llvm::dbgs() << "Joining " << lattice->getValue() << " and "
                       << newSlot << " --> "
                       << TargetSlot::join(lattice->getValue(), newSlot)
                       << "\n";
        });
        ChangeResult changed = lattice->join(newSlot);
        propagateIfChanged(lattice, changed);
      })
      .Default([&](Operation &op) {
        // By default, an op propagates its result target slots to all its
        // operands.
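        // For example, an elementwise op such as arith.addi passes its
        // result's target slot through to each of its operands.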
        for (const TargetSlotLattice *r : results) {
          for (TargetSlotLattice *operand : operands) {
            ChangeResult result = operand->join(*r);
            propagateIfChanged(operand, result);
          }
        }
      });
}

} // namespace target_slot_analysis
} // namespace heir
} // namespace mlir
1 change: 1 addition & 0 deletions lib/Dialect/TensorExt/Transforms/BUILD
@@ -27,6 +27,7 @@ cc_library(
        "@heir//include/Dialect/TensorExt/IR:canonicalize_inc_gen",
        "@heir//include/Dialect/TensorExt/Transforms:insert_rotate_inc_gen",
        "@heir//include/Dialect/TensorExt/Transforms:pass_inc_gen",
        "@heir//lib/Analysis/TargetSlotAnalysis",
        "@heir//lib/Dialect/TensorExt/IR:Dialect",
        "@llvm-project//mlir:ArithDialect",
        "@llvm-project//mlir:IR",
42 changes: 37 additions & 5 deletions lib/Dialect/TensorExt/Transforms/InsertRotate.cpp
@@ -2,14 +2,21 @@

#include <utility>

#include "include/Analysis/TargetSlotAnalysis/TargetSlotAnalysis.h"
#include "include/Dialect/TensorExt/IR/TensorExtOps.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/include/mlir/IR/Matchers.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project

#define DEBUG_TYPE "insert-rotate"

namespace mlir {
namespace heir {
namespace tensor_ext {
@@ -33,6 +40,31 @@ struct InsertRotate : impl::InsertRotateBase<InsertRotate> {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);

    SymbolTableCollection symbolTable;
    DataFlowSolver solver;
    // These two upstream analyses are required dependencies for any sparse
    // dataflow analysis, or else the analysis will be a no-op. Cf.
    // https://github.com/llvm/llvm-project/issues/58922
    solver.load<dataflow::DeadCodeAnalysis>();
    solver.load<dataflow::SparseConstantPropagation>();
    solver.load<target_slot_analysis::TargetSlotAnalysis>(symbolTable);
    if (failed(solver.initializeAndRun(getOperation()))) {
      getOperation()->emitOpError() << "Failed to run the analysis.\n";
      signalPassFailure();
      return;
    }

    LLVM_DEBUG({
      getOperation()->walk([&](Operation *op) {
        if (op->getNumResults() == 0) return;
        auto *targetSlotLattice =
            solver.lookupState<target_slot_analysis::TargetSlotLattice>(
                op->getResult(0));
        llvm::dbgs() << "Target slot for op " << *op << ": "
                     << targetSlotLattice->getValue() << "\n";
      });
    });

    alignment::populateWithGenerated(patterns);
    canonicalization::populateWithGenerated(patterns);
    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));