Skip to content

Commit

Permalink
[acc] Add attribute for combined constructs (#80319)
Browse files Browse the repository at this point in the history
Combined constructs are decomposed into separate operations. However,
this does not adhere to the `acc` dialect's goal of being able to regenerate
clauses that are semantically equivalent to the user's intent. Thus, add an
attribute to keep track of combined constructs.
  • Loading branch information
razvanlupusoru committed Mar 7, 2024
1 parent cfdfeb4 commit a435e1f
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 7 deletions.
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/OpenACC/OpenACC.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ static constexpr StringLiteral getRoutineInfoAttrName() {
return StringLiteral("acc.routine_info");
}

/// Returns the name of the combined-constructs attribute, taken directly from
/// `CombinedConstructsTypeAttr::name` so the two cannot drift apart.
static constexpr StringLiteral getCombinedConstructsAttrName() {
return CombinedConstructsTypeAttr::name;
}

struct RuntimeCounters
: public mlir::SideEffects::Resource::Base<RuntimeCounters> {
mlir::StringRef getName() final { return "AccRuntimeCounters"; }
Expand Down
34 changes: 30 additions & 4 deletions mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,24 @@ def GangArgTypeArrayAttr :
let constBuilderCall = ?;
}

// Combined constructs enumerations.
// One case per compute construct that can be combined with `loop`
// (kernels, parallel, serial). Each case lists, in order, the C++ enumerator
// name, its 32-bit integer value, and its string form.
def OpenACC_KernelsLoop : I32EnumAttrCase<"KernelsLoop", 1, "kernels_loop">;
def OpenACC_ParallelLoop : I32EnumAttrCase<"ParallelLoop", 2, "parallel_loop">;
def OpenACC_SerialLoop : I32EnumAttrCase<"SerialLoop", 3, "serial_loop">;

// 32-bit enum `CombinedConstructsType` grouping the three combined-construct
// cases, emitted into the `::mlir::acc` C++ namespace.
def OpenACC_CombinedConstructsType : I32EnumAttr<"CombinedConstructsType",
"Differentiate between combined constructs",
[OpenACC_KernelsLoop, OpenACC_ParallelLoop, OpenACC_SerialLoop]> {
// Do not generate an attribute class from this enum definition itself; the
// attribute is declared separately by OpenACC_CombinedConstructsAttr.
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::acc";
}

// OpenACC dialect attribute wrapping the CombinedConstructsType enum, with
// the mnemonic `combined_constructs`.
def OpenACC_CombinedConstructsAttr : EnumAttr<OpenACC_Dialect,
OpenACC_CombinedConstructsType,
"combined_constructs"> {
// The leading empty literal (``) suppresses the space that would otherwise
// precede `<`, so the enum value is written as `<value>` with no gap.
let assemblyFormat = [{ ```<` $value `>` }];
}

// Define a resource for the OpenACC runtime counters.
def OpenACC_RuntimeCounters : Resource<"::mlir::acc::RuntimeCounters">;

Expand Down Expand Up @@ -933,7 +951,8 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
OptionalAttr<DefaultValueAttr>:$defaultAttr);
OptionalAttr<DefaultValueAttr>:$defaultAttr,
UnitAttr:$combined);

let regions = (region AnyRegion:$region);

Expand Down Expand Up @@ -993,6 +1012,7 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
}];

let assemblyFormat = [{
( `combined` `(` `loop` `)` $combined^)?
oilist(
`dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
| `async` `(` custom<DeviceTypeOperands>($asyncOperands,
Expand Down Expand Up @@ -1068,7 +1088,8 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
Variadic<OpenACC_PointerLikeTypeInterface>:$gangFirstPrivateOperands,
OptionalAttr<SymbolRefArrayAttr>:$firstprivatizations,
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
OptionalAttr<DefaultValueAttr>:$defaultAttr);
OptionalAttr<DefaultValueAttr>:$defaultAttr,
UnitAttr:$combined);

let regions = (region AnyRegion:$region);

Expand Down Expand Up @@ -1109,6 +1130,7 @@ def OpenACC_SerialOp : OpenACC_Op<"serial",
}];

let assemblyFormat = [{
( `combined` `(` `loop` `)` $combined^)?
oilist(
`dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
| `async` `(` custom<DeviceTypeOperands>($asyncOperands,
Expand Down Expand Up @@ -1182,7 +1204,8 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
Optional<I1>:$selfCond,
UnitAttr:$selfAttr,
Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
OptionalAttr<DefaultValueAttr>:$defaultAttr);
OptionalAttr<DefaultValueAttr>:$defaultAttr,
UnitAttr:$combined);

let regions = (region AnyRegion:$region);

Expand Down Expand Up @@ -1242,6 +1265,7 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
}];

let assemblyFormat = [{
( `combined` `(` `loop` `)` $combined^)?
oilist(
`dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
| `async` `(` custom<DeviceTypeOperands>($asyncOperands,
Expand Down Expand Up @@ -1573,7 +1597,8 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",
Variadic<OpenACC_PointerLikeTypeInterface>:$privateOperands,
OptionalAttr<SymbolRefArrayAttr>:$privatizations,
Variadic<AnyType>:$reductionOperands,
OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes
OptionalAttr<SymbolRefArrayAttr>:$reductionRecipes,
OptionalAttr<OpenACC_CombinedConstructsAttr>:$combined
);

let results = (outs Variadic<AnyType>:$results);
Expand Down Expand Up @@ -1665,6 +1690,7 @@ def OpenACC_LoopOp : OpenACC_Op<"loop",

let hasCustomAssemblyFormat = 1;
let assemblyFormat = [{
custom<CombinedConstructsLoop>($combined)
oilist(
`gang` `` custom<GangClause>($gangOperands, type($gangOperands),
$gangOperandsArgType, $gangOperandsDeviceType,
Expand Down
51 changes: 51 additions & 0 deletions mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1283,6 +1283,50 @@ static void printDeviceTypeOperandsWithKeywordOnly(
p << ")";
}

/// Parses the optional `combined(<construct>)` clause of `acc.loop`.
///
/// The accepted spellings are `combined(kernels)`, `combined(parallel)` and
/// `combined(serial)`. When the `combined` keyword is not present, nothing is
/// consumed, `attr` is left untouched and the parse succeeds.
///
/// \param parser assembly parser positioned where the clause may begin.
/// \param attr   receives the attribute describing the parsed construct.
/// \returns failure when `(`, the construct name, or `)` is missing; an
///          explicit diagnostic is emitted for an unknown construct name.
static ParseResult
parseCombinedConstructsLoop(mlir::OpAsmParser &parser,
                            mlir::acc::CombinedConstructsTypeAttr &attr) {
  // The whole clause is optional: without the keyword there is nothing to do.
  if (!succeeded(parser.parseOptionalKeyword("combined")))
    return success();

  if (parser.parseLParen())
    return failure();

  // Map the construct keyword onto its enumerator; the keywords are probed in
  // the order kernels, parallel, serial.
  mlir::acc::CombinedConstructsType construct;
  if (succeeded(parser.parseOptionalKeyword("kernels"))) {
    construct = mlir::acc::CombinedConstructsType::KernelsLoop;
  } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
    construct = mlir::acc::CombinedConstructsType::ParallelLoop;
  } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
    construct = mlir::acc::CombinedConstructsType::SerialLoop;
  } else {
    parser.emitError(parser.getCurrentLocation(),
                     "expected compute construct name");
    return failure();
  }

  // The attribute is assigned before the closing parenthesis is checked.
  attr = mlir::acc::CombinedConstructsTypeAttr::get(parser.getContext(),
                                                    construct);

  if (parser.parseRParen())
    return failure();
  return success();
}

/// Prints the `combined(<construct>)` clause of `acc.loop`.
///
/// Emits `combined(kernels)`, `combined(parallel)` or `combined(serial)`
/// according to the attribute value, and prints nothing at all when the
/// attribute is null.
///
/// \param p    printer receiving the clause text.
/// \param op   operation being printed (unused here).
/// \param attr combined-constructs attribute; may be null.
static void
printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op,
                            mlir::acc::CombinedConstructsTypeAttr attr) {
  // No attribute means the loop is not part of a combined construct.
  if (!attr)
    return;

  switch (attr.getValue()) {
  case mlir::acc::CombinedConstructsType::KernelsLoop:
    p << "combined(kernels)";
    return;
  case mlir::acc::CombinedConstructsType::ParallelLoop:
    p << "combined(parallel)";
    return;
  case mlir::acc::CombinedConstructsType::SerialLoop:
    p << "combined(serial)";
    return;
  }
}

//===----------------------------------------------------------------------===//
// SerialOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1851,6 +1895,13 @@ LogicalResult acc::LoopOp::verify() {
"reductions", false)))
return failure();

if (getCombined().has_value() &&
(getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
return emitError("unexpected combined constructs attribute");
}

// Check non-empty body().
if (getRegion().empty())
return emitError("expected non-empty body.");
Expand Down
40 changes: 40 additions & 0 deletions mlir/test/Dialect/OpenACC/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -738,3 +738,43 @@ func.func @acc_atomic_capture(%x: memref<i32>, %y: memref<i32>, %v: memref<i32>,
acc.terminator
}
}

// -----

// Negative test: `combined` on `acc.parallel` with empty parentheses is
// rejected because the `loop` keyword is required inside them.
func.func @acc_combined() {
// expected-error @below {{expected 'loop'}}
acc.parallel combined() {
}

return
}

// -----

// Negative test: `combined` on `acc.loop` must name a compute construct
// (kernels, parallel or serial); `loop` is not an accepted name there.
func.func @acc_combined() {
// expected-error @below {{expected compute construct name}}
acc.loop combined(loop) {
}

return
}

// -----

// Negative test: on `acc.parallel` the first token inside `combined(...)`
// must be `loop`, so a leading construct name is rejected.
func.func @acc_combined() {
// expected-error @below {{expected 'loop'}}
acc.parallel combined(parallel loop) {
}

return
}

// -----

// Negative test: on `acc.loop` only a single construct name may appear inside
// `combined(...)`; the extra `loop` token is reported as a missing `)`.
func.func @acc_combined() {
// expected-error @below {{expected ')'}}
acc.loop combined(parallel loop) {
}

return
}
46 changes: 43 additions & 3 deletions mlir/test/Dialect/OpenACC/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1846,9 +1846,49 @@ func.func @acc_atomic_capture(%v: memref<i32>, %x: memref<i32>, %expr: i32) {

// -----

%c2 = arith.constant 2 : i32
%c1 = arith.constant 1 : i32
acc.parallel num_gangs({%c2 : i32} [#acc.device_type<default>], {%c1 : i32, %c1 : i32, %c1 : i32} [#acc.device_type<nvidia>]) {
// `acc.parallel` with `num_gangs` operand groups for two device types: a
// single value for `default` and three values for `nvidia`.
// CHECK-LABEL: func.func @acc_num_gangs
func.func @acc_num_gangs() {
%c2 = arith.constant 2 : i32
%c1 = arith.constant 1 : i32
acc.parallel num_gangs({%c2 : i32} [#acc.device_type<default>], {%c1 : i32, %c1 : i32, %c1 : i32} [#acc.device_type<nvidia>]) {
}

return
}

// CHECK: acc.parallel num_gangs({%c2{{.*}} : i32} [#acc.device_type<default>], {%c1{{.*}} : i32, %c1{{.*}} : i32, %c1{{.*}} : i32} [#acc.device_type<nvidia>])

// -----

// Round-trip test for combined constructs: each compute construct
// (parallel, kernels, serial) is marked `combined(loop)` and contains an
// `acc.loop` marked with the matching construct name.
// CHECK-LABEL: func.func @acc_combined
func.func @acc_combined() {
acc.parallel combined(loop) {
acc.loop combined(parallel) {
acc.yield
}
acc.terminator
}

acc.kernels combined(loop) {
acc.loop combined(kernels) {
acc.yield
}
acc.terminator
}

acc.serial combined(loop) {
acc.loop combined(serial) {
acc.yield
}
acc.terminator
}

return
}

// CHECK: acc.parallel combined(loop)
// CHECK: acc.loop combined(parallel)
// CHECK: acc.kernels combined(loop)
// CHECK: acc.loop combined(kernels)
// CHECK: acc.serial combined(loop)
// CHECK: acc.loop combined(serial)

0 comments on commit a435e1f

Please sign in to comment.