Skip to content

Commit

Permalink
[flang][Lower] Convert OMP Map and related functions to evaluate::Expr (
Browse files Browse the repository at this point in the history
#81626)

The related functions are `gatherDataOperandAddrAndBounds` and
`genBoundsOps`. The former is used in OpenACC as well, and it was
updated to pass evaluate::Expr instead of parser objects.

The difference in the test case comes from unfolded conversions of index
expressions, which are explicitly of type integer(kind=8).

Delete now unused `findRepeatableClause2` and `findClause2`.

Add `AsGenericExpr` that takes std::optional. It already returns
optional Expr. Making it accept an optional Expr as input would reduce
the number of necessary checks when handling frequent optional values in
evaluator.

[Clause representation 4/6]
  • Loading branch information
kparzysz committed Mar 20, 2024
1 parent 0177a95 commit 8411549
Show file tree
Hide file tree
Showing 8 changed files with 335 additions and 268 deletions.
8 changes: 8 additions & 0 deletions flang/include/flang/Evaluate/tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,14 @@ inline Expr<SomeType> AsGenericExpr(Expr<SomeType> &&x) { return std::move(x); }
std::optional<Expr<SomeType>> AsGenericExpr(DataRef &&);
std::optional<Expr<SomeType>> AsGenericExpr(const Symbol &);

// Propagate std::optional from input to output.
template <typename A>
std::optional<Expr<SomeType>> AsGenericExpr(std::optional<A> &&x) {
if (!x)
return std::nullopt;
return AsGenericExpr(std::move(*x));
}

template <typename A>
common::IfNoLvalue<Expr<SomeKind<ResultType<A>::category>>, A> AsCategoryExpr(
A &&x) {
Expand Down
389 changes: 234 additions & 155 deletions flang/lib/Lower/DirectivesCommon.h

Large diffs are not rendered by default.

54 changes: 35 additions & 19 deletions flang/lib/Lower/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,11 @@ getSymbolFromAccObject(const Fortran::parser::AccObject &accObject) {
Fortran::parser::GetLastName(arrayElement->base);
return *name.symbol;
}
if (const auto *component =
Fortran::parser::Unwrap<Fortran::parser::StructureComponent>(
*designator)) {
return *component->component.symbol;
}
} else if (const auto *name =
std::get_if<Fortran::parser::Name>(&accObject.u)) {
return *name->symbol;
Expand All @@ -286,17 +291,20 @@ genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
mlir::acc::DataClause dataClause, bool structured,
bool implicit, bool setDeclareAttr = false) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objectList.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
Fortran::semantics::MaybeExpr designator =
std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
stmtCtx, accObject, operandLocation,
asFortran, bounds,
/*treatIndexAsSection=*/true);
mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
converter, builder, semanticsContext, stmtCtx, symbol, designator,
operandLocation, asFortran, bounds,
/*treatIndexAsSection=*/true);

// If the input value is optional and is not a descriptor, we use the
// rawInput directly.
Expand All @@ -321,16 +329,19 @@ static void genDeclareDataOperandOperations(
llvm::SmallVectorImpl<mlir::Value> &dataOperands,
mlir::acc::DataClause dataClause, bool structured, bool implicit) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objectList.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
Fortran::semantics::MaybeExpr designator =
std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
stmtCtx, accObject, operandLocation,
asFortran, bounds);
mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
converter, builder, semanticsContext, stmtCtx, symbol, designator,
operandLocation, asFortran, bounds);
EntryOp op = createDataEntryOp<EntryOp>(
builder, operandLocation, info.addr, asFortran, bounds, structured,
implicit, dataClause, info.addr.getType());
Expand All @@ -339,8 +350,7 @@ static void genDeclareDataOperandOperations(
if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(info.addr.getType()))) {
mlir::OpBuilder modBuilder(builder.getModule().getBodyRegion());
modBuilder.setInsertionPointAfter(builder.getFunction());
std::string prefix =
converter.mangleName(getSymbolFromAccObject(accObject));
std::string prefix = converter.mangleName(symbol);
createDeclareAllocFuncWithArg<EntryOp>(
modBuilder, builder, operandLocation, info.addr.getType(), prefix,
asFortran, dataClause);
Expand Down Expand Up @@ -770,16 +780,19 @@ genPrivatizations(const Fortran::parser::AccObjectList &objectList,
llvm::SmallVectorImpl<mlir::Value> &dataOperands,
llvm::SmallVector<mlir::Attribute> &privatizations) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objectList.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
Fortran::semantics::MaybeExpr designator =
std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
stmtCtx, accObject, operandLocation,
asFortran, bounds);
mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
converter, builder, semanticsContext, stmtCtx, symbol, designator,
operandLocation, asFortran, bounds);
RecipeOp recipe;
mlir::Type retTy = getTypeFromBounds(bounds, info.addr.getType());
if constexpr (std::is_same_v<RecipeOp, mlir::acc::PrivateRecipeOp>) {
Expand Down Expand Up @@ -1340,16 +1353,19 @@ genReductions(const Fortran::parser::AccObjectListWithReduction &objectList,
const auto &op =
std::get<Fortran::parser::AccReductionOperator>(objectList.t);
mlir::acc::ReductionOperator mlirOp = getReductionOperator(op);
Fortran::evaluate::ExpressionAnalyzer ea{semanticsContext};
for (const auto &accObject : objects.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
mlir::Location operandLocation = genOperandLocation(converter, accObject);
Fortran::semantics::Symbol &symbol = getSymbolFromAccObject(accObject);
Fortran::semantics::MaybeExpr designator =
std::visit([&](auto &&s) { return ea.Analyze(s); }, accObject.u);
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
Fortran::parser::AccObject, mlir::acc::DataBoundsOp,
mlir::acc::DataBoundsType>(converter, builder, semanticsContext,
stmtCtx, accObject, operandLocation,
asFortran, bounds);
mlir::acc::DataBoundsOp, mlir::acc::DataBoundsType>(
converter, builder, semanticsContext, stmtCtx, symbol, designator,
operandLocation, asFortran, bounds);

mlir::Type reductionTy = fir::unwrapRefType(info.addr.getType());
if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(reductionTy))
Expand Down
44 changes: 20 additions & 24 deletions flang/lib/Lower/OpenMP/ClauseProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -818,65 +818,61 @@ bool ClauseProcessor::processMap(
llvm::SmallVectorImpl<const Fortran::semantics::Symbol *> *mapSymbols)
const {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
return findRepeatableClause2<ClauseTy::Map>(
[&](const ClauseTy::Map *mapClause,
return findRepeatableClause<omp::clause::Map>(
[&](const omp::clause::Map &clause,
const Fortran::parser::CharBlock &source) {
using Map = omp::clause::Map;
mlir::Location clauseLocation = converter.genLocation(source);
const auto &oMapType =
std::get<std::optional<Fortran::parser::OmpMapType>>(
mapClause->v.t);
const auto &oMapType = std::get<std::optional<Map::MapType>>(clause.t);
llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
// If the map type is specified, then process it else Tofrom is the
// default.
if (oMapType) {
const Fortran::parser::OmpMapType::Type &mapType =
std::get<Fortran::parser::OmpMapType::Type>(oMapType->t);
const Map::MapType::Type &mapType =
std::get<Map::MapType::Type>(oMapType->t);
switch (mapType) {
case Fortran::parser::OmpMapType::Type::To:
case Map::MapType::Type::To:
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
break;
case Fortran::parser::OmpMapType::Type::From:
case Map::MapType::Type::From:
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
break;
case Fortran::parser::OmpMapType::Type::Tofrom:
case Map::MapType::Type::Tofrom:
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
break;
case Fortran::parser::OmpMapType::Type::Alloc:
case Fortran::parser::OmpMapType::Type::Release:
case Map::MapType::Type::Alloc:
case Map::MapType::Type::Release:
// alloc and release is the default map_type for the Target Data
// Ops, i.e. if no bits for map_type is supplied then alloc/release
// is implicitly assumed based on the target directive. Default
// value for Target Data and Enter Data is alloc and for Exit Data
// it is release.
break;
case Fortran::parser::OmpMapType::Type::Delete:
case Map::MapType::Type::Delete:
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
}

if (std::get<std::optional<Fortran::parser::OmpMapType::Always>>(
oMapType->t))
if (std::get<std::optional<Map::MapType::Always>>(oMapType->t))
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
} else {
mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
}

for (const Fortran::parser::OmpObject &ompObject :
std::get<Fortran::parser::OmpObjectList>(mapClause->v.t).v) {
for (const omp::Object &object : std::get<omp::ObjectList>(clause.t)) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;

Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
Fortran::parser::OmpObject, mlir::omp::MapBoundsOp,
mlir::omp::MapBoundsType>(
converter, firOpBuilder, semaCtx, stmtCtx, ompObject,
clauseLocation, asFortran, bounds, treatIndexAsSection);
mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
converter, firOpBuilder, semaCtx, stmtCtx, *object.id(),
object.ref(), clauseLocation, asFortran, bounds,
treatIndexAsSection);

auto origSymbol =
converter.getSymbolAddress(*getOmpObjectSymbol(ompObject));
auto origSymbol = converter.getSymbolAddress(*object.id());
mlir::Value symAddr = info.addr;
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
symAddr = origSymbol;
Expand All @@ -899,7 +895,7 @@ bool ClauseProcessor::processMap(
mapSymLocs->push_back(symAddr.getLoc());

if (mapSymbols)
mapSymbols->push_back(getOmpObjectSymbol(ompObject));
mapSymbols->push_back(object.id());
}
});
}
Expand Down
59 changes: 11 additions & 48 deletions flang/lib/Lower/OpenMP/ClauseProcessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,6 @@ class ClauseProcessor {
/// Utility to find a clause within a range in the clause list.
template <typename T>
static ClauseIterator findClause(ClauseIterator begin, ClauseIterator end);
template <typename T>
static ClauseIterator2 findClause2(ClauseIterator2 begin,
ClauseIterator2 end);

/// Return the first instance of the given clause found in the clause list or
/// `nullptr` if not present. If more than one instance is expected, use
Expand All @@ -179,10 +176,6 @@ class ClauseProcessor {
bool findRepeatableClause(
std::function<void(const T &, const Fortran::parser::CharBlock &source)>
callbackFn) const;
template <typename T>
bool findRepeatableClause2(
std::function<void(const T *, const Fortran::parser::CharBlock &source)>
callbackFn) const;

/// Set the `result` to a new `mlir::UnitAttr` if the clause is present.
template <typename T>
Expand All @@ -198,32 +191,31 @@ template <typename T>
bool ClauseProcessor::processMotionClauses(
Fortran::lower::StatementContext &stmtCtx,
llvm::SmallVectorImpl<mlir::Value> &mapOperands) {
return findRepeatableClause2<T>(
[&](const T *motionClause, const Fortran::parser::CharBlock &source) {
return findRepeatableClause<T>(
[&](const T &clause, const Fortran::parser::CharBlock &source) {
mlir::Location clauseLocation = converter.genLocation(source);
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();

static_assert(std::is_same_v<T, ClauseProcessor::ClauseTy::To> ||
std::is_same_v<T, ClauseProcessor::ClauseTy::From>);
static_assert(std::is_same_v<T, omp::clause::To> ||
std::is_same_v<T, omp::clause::From>);

// TODO Support motion modifiers: present, mapper, iterator.
constexpr llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
std::is_same_v<T, ClauseProcessor::ClauseTy::To>
std::is_same_v<T, omp::clause::To>
? llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO
: llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;

for (const Fortran::parser::OmpObject &ompObject : motionClause->v.v) {
for (const omp::Object &object : clause.v) {
llvm::SmallVector<mlir::Value> bounds;
std::stringstream asFortran;
Fortran::lower::AddrAndBoundsInfo info =
Fortran::lower::gatherDataOperandAddrAndBounds<
Fortran::parser::OmpObject, mlir::omp::MapBoundsOp,
mlir::omp::MapBoundsType>(
converter, firOpBuilder, semaCtx, stmtCtx, ompObject,
clauseLocation, asFortran, bounds, treatIndexAsSection);
mlir::omp::MapBoundsOp, mlir::omp::MapBoundsType>(
converter, firOpBuilder, semaCtx, stmtCtx, *object.id(),
object.ref(), clauseLocation, asFortran, bounds,
treatIndexAsSection);

auto origSymbol =
converter.getSymbolAddress(*getOmpObjectSymbol(ompObject));
auto origSymbol = converter.getSymbolAddress(*object.id());
mlir::Value symAddr = info.addr;
if (origSymbol && fir::isTypeWithDescriptor(origSymbol.getType()))
symAddr = origSymbol;
Expand Down Expand Up @@ -273,17 +265,6 @@ ClauseProcessor::findClause(ClauseIterator begin, ClauseIterator end) {
return end;
}

template <typename T>
ClauseProcessor::ClauseIterator2
ClauseProcessor::findClause2(ClauseIterator2 begin, ClauseIterator2 end) {
for (ClauseIterator2 it = begin; it != end; ++it) {
if (std::get_if<T>(&it->u))
return it;
}

return end;
}

template <typename T>
const T *ClauseProcessor::findUniqueClause(
const Fortran::parser::CharBlock **source) const {
Expand Down Expand Up @@ -314,24 +295,6 @@ bool ClauseProcessor::findRepeatableClause(
return found;
}

template <typename T>
bool ClauseProcessor::findRepeatableClause2(
std::function<void(const T *, const Fortran::parser::CharBlock &source)>
callbackFn) const {
bool found = false;
ClauseIterator2 nextIt, endIt = clauses2.v.end();
for (ClauseIterator2 it = clauses2.v.begin(); it != endIt; it = nextIt) {
nextIt = findClause2<T>(it, endIt);

if (nextIt != endIt) {
callbackFn(&std::get<T>(nextIt->u), nextIt->source);
found = true;
++nextIt;
}
}
return found;
}

template <typename T>
bool ClauseProcessor::markClauseOccurrence(mlir::UnitAttr &result) const {
if (findUniqueClause<T>()) {
Expand Down
7 changes: 2 additions & 5 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -930,11 +930,8 @@ static OpTy genTargetEnterExitDataUpdateOp(
cp.processNowait(nowaitAttr);

if constexpr (std::is_same_v<OpTy, mlir::omp::TargetUpdateOp>) {
cp.processMotionClauses<Fortran::parser::OmpClause::To>(stmtCtx,
mapOperands);
cp.processMotionClauses<Fortran::parser::OmpClause::From>(stmtCtx,
mapOperands);

cp.processMotionClauses<clause::To>(stmtCtx, mapOperands);
cp.processMotionClauses<clause::From>(stmtCtx, mapOperands);
} else {
cp.processMap(currentLocation, directive, stmtCtx, mapOperands);
}
Expand Down
2 changes: 1 addition & 1 deletion flang/test/Lower/OpenACC/acc-bounds.f90
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ subroutine acc_optional_data3(a, n)
! CHECK: fir.result %c0{{.*}} : index
! CHECK: }
! CHECK: %[[BOUNDS:.*]] = acc.bounds lowerbound(%c0{{.*}} : index) upperbound(%{{.*}} : index) extent(%{{.*}} : index) stride(%[[STRIDE]] : index) startIdx(%c1 : index) {strideInBytes = true}
! CHECK: %[[NOCREATE:.*]] = acc.nocreate varPtr(%[[DECL_A]]#1 : !fir.ref<!fir.array<?xf32>>) bounds(%14) -> !fir.ref<!fir.array<?xf32>> {name = "a(1:n)"}
! CHECK: %[[NOCREATE:.*]] = acc.nocreate varPtr(%[[DECL_A]]#1 : !fir.ref<!fir.array<?xf32>>) bounds(%[[BOUNDS]]) -> !fir.ref<!fir.array<?xf32>> {name = "a(1:n)"}
! CHECK: acc.data dataOperands(%[[NOCREATE]] : !fir.ref<!fir.array<?xf32>>) {

end module

0 comments on commit 8411549

Please sign in to comment.