[AArch64] Add more complete support for BF16
We can use a small amount of integer arithmetic to round FP32 to BF16
and extend BF16 to FP32.

While a number of operations still require promotion, the set can be
reduced for some fairly simple operations such as abs, copysign, and
fneg; those are left to a follow-up.

A few neat optimizations are implemented:
- Round-inexact-to-odd is used for F64 to BF16 rounding.
- Quieting signaling NaNs for f32 -> bf16 tries to detect whether a
  prior operation makes the quieting unnecessary.
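
As a sketch of the integer arithmetic involved (an illustration of the
technique, not the committed code — the commit builds the equivalent
SelectionDAG nodes in AArch64ISelLowering.cpp):

  #include <cstdint>
  #include <cstring>

  // f32 -> bf16 with round-to-nearest-even, integer ops only.
  static uint16_t f32_to_bf16(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof u);
    if ((u & 0x7FFFFFFFu) > 0x7F800000u)   // NaN: truncate, set the quiet bit
      return (uint16_t)((u >> 16) | 0x0040);
    u += 0x7FFF + ((u >> 16) & 1);         // rounding bias; ties go to even
    return (uint16_t)(u >> 16);
  }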
majnemer committed Mar 3, 2024
1 parent 2435dcd commit 3dd6750
Showing 13 changed files with 842 additions and 197 deletions.
5 changes: 3 additions & 2 deletions llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1573,13 +1573,14 @@ class TargetLoweringBase {
    assert((VT.isInteger() || VT.isFloatingPoint()) &&
           "Cannot autopromote this type, add it with AddPromotedToType.");

+   uint64_t VTBits = VT.getScalarSizeInBits();
    MVT NVT = VT;
    do {
      NVT = (MVT::SimpleValueType)(NVT.SimpleTy+1);
      assert(NVT.isInteger() == VT.isInteger() && NVT != MVT::isVoid &&
             "Didn't find type to promote to!");
-   } while (!isTypeLegal(NVT) ||
-            getOperationAction(Op, NVT) == Promote);
+   } while (VTBits >= NVT.getScalarSizeInBits() || !isTypeLegal(NVT) ||
+            getOperationAction(Op, NVT) == Promote);
    return NVT;
  }
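
A note on the new size guard (an editorial inference; the commit message does
not spell this out): in the MVT::SimpleValueType ordering, bf16 is immediately
followed by f16, which is also 16 bits wide, so the old loop could "promote" a
bf16 operation to f16 even though f16 is no wider and cannot represent all
bf16 values. The VTBits >= NVT.getScalarSizeInBits() check skips same-size
candidates and forces promotion to a strictly wider type such as f32.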

7 changes: 7 additions & 0 deletions llvm/include/llvm/CodeGen/ValueTypes.h
@@ -107,6 +107,13 @@ namespace llvm {
      return changeExtendedVectorElementType(EltVT);
    }

+   /// Return a VT for a type whose attributes match ourselves with the
+   /// exception of the element type that is chosen by the caller.
+   EVT changeElementType(EVT EltVT) const {
+     EltVT = EltVT.getScalarType();
+     return isVector() ? changeVectorElementType(EltVT) : EltVT;
+   }
+
/// Return the type converted to an equivalently sized integer or vector
/// with integer element type. Similar to changeVectorElementTypeToInteger,
/// but also handles scalars.
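A hypothetical usage sketch for the new changeElementType helper (the variable
names and the LLVMContext Ctx are assumptions, not from the diff):

  EVT V4F32  = EVT::getVectorVT(Ctx, MVT::f32, 4);
  EVT V4BF16 = V4F32.changeElementType(MVT::bf16);         // v4bf16: shape kept
  EVT Scalar = EVT(MVT::f32).changeElementType(MVT::bf16); // bf16: scalar case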
290 changes: 201 additions & 89 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -249,6 +249,9 @@ enum NodeType : unsigned {
  FCMLEz,
  FCMLTz,

+ // Round wide FP to narrow FP with inexact results to odd.
+ FCVTXN,
+
// Vector across-lanes addition
// Only the lower result lane is defined.
SADDV,
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -7547,7 +7547,7 @@ class BaseSIMDCmpTwoScalar<bit U, bits<2> size, bits<2> size2, bits<5> opcode,
let mayRaiseFPException = 1, Uses = [FPCR] in
class SIMDInexactCvtTwoScalar<bits<5> opcode, string asm>
: I<(outs FPR32:$Rd), (ins FPR64:$Rn), asm, "\t$Rd, $Rn", "",
-     [(set (f32 FPR32:$Rd), (int_aarch64_sisd_fcvtxn (f64 FPR64:$Rn)))]>,
+     [(set (f32 FPR32:$Rd), (AArch64fcvtxn (f64 FPR64:$Rn)))]>,
Sched<[WriteVd]> {
bits<5> Rd;
bits<5> Rn;
37 changes: 37 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -756,6 +756,11 @@ def AArch64fcmgtz: SDNode<"AArch64ISD::FCMGTz", SDT_AArch64fcmpz>;
def AArch64fcmlez: SDNode<"AArch64ISD::FCMLEz", SDT_AArch64fcmpz>;
def AArch64fcmltz: SDNode<"AArch64ISD::FCMLTz", SDT_AArch64fcmpz>;

+def AArch64fcvtxn_n: SDNode<"AArch64ISD::FCVTXN", SDTFPRoundOp>;
+def AArch64fcvtxn: PatFrags<(ops node:$Rn),
+                            [(f32 (int_aarch64_sisd_fcvtxn (f64 node:$Rn))),
+                             (f32 (AArch64fcvtxn_n (f64 node:$Rn)))]>;

def AArch64bici: SDNode<"AArch64ISD::BICi", SDT_AArch64vecimm>;
def AArch64orri: SDNode<"AArch64ISD::ORRi", SDT_AArch64vecimm>;
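
Since AArch64fcvtxn is a PatFrags, the single SIMDInexactCvtTwoScalar pattern
updated in AArch64InstrFormats.td above now selects the FCVTXN instruction for
both the pre-existing int_aarch64_sisd_fcvtxn intrinsic and the new
AArch64ISD::FCVTXN node emitted by the BF16 lowering.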

@@ -1276,6 +1281,9 @@ def BFMLALTIdx : SIMDBF16MLALIndex<1, "bfmlalt", int_aarch64_neon_bfmlalt>;
def BFCVTN : SIMD_BFCVTN;
def BFCVTN2 : SIMD_BFCVTN2;

+def : Pat<(v4bf16 (any_fpround (v4f32 V128:$Rn))),
+          (EXTRACT_SUBREG (BFCVTN V128:$Rn), dsub)>;

// Vector-scalar BFDOT:
// The second source operand of the 64-bit variant of BF16DOTlane is a 128-bit
// register (the instruction uses a single 32-bit lane from it), so the pattern
@@ -1296,6 +1304,8 @@ def : Pat<(v2f32 (int_aarch64_neon_bfdot

let Predicates = [HasNEONorSME, HasBF16] in {
def BFCVT : BF16ToSinglePrecision<"bfcvt">;
+// Round FP32 to BF16.
+def : Pat<(bf16 (any_fpround (f32 FPR32:$Rn))), (BFCVT $Rn)>;
}

// ARMv8.6A AArch64 matrix multiplication
@@ -4648,6 +4658,22 @@ let Predicates = [HasFullFP16] in {
//===----------------------------------------------------------------------===//

defm FCVT : FPConversion<"fcvt">;
+// Helper to get bf16 into fp32.
+def cvt_bf16_to_fp32 :
+  OutPatFrag<(ops node:$Rn),
+             (f32 (COPY_TO_REGCLASS
+                     (i32 (UBFMWri
+                             (i32 (COPY_TO_REGCLASS (INSERT_SUBREG (f32 (IMPLICIT_DEF)),
+                                                                   node:$Rn, hsub), GPR32)),
+                             (i64 (i32shift_a (i64 16))),
+                             (i64 (i32shift_b (i64 16))))),
+                     FPR32))>;
+// Pattern for bf16 -> fp32.
+def : Pat<(f32 (any_fpextend (bf16 FPR16:$Rn))),
+          (cvt_bf16_to_fp32 FPR16:$Rn)>;
+// Pattern for bf16 -> fp64.
+def : Pat<(f64 (any_fpextend (bf16 FPR16:$Rn))),
+          (FCVTDSr (f32 (cvt_bf16_to_fp32 FPR16:$Rn)))>;

//===----------------------------------------------------------------------===//
// Floating point single operand instructions.
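
The cvt_bf16_to_fp32 helper above moves the half into a GPR and uses UBFMWri
(with these operands, a left shift by 16) to build the f32 bit pattern. A
hedged C++ model of what it computes (an illustration, not committed code):

  #include <cstdint>
  #include <cstring>

  // BF16 is the high half of an IEEE-754 binary32, so extension is exact.
  static float bf16_to_f32(uint16_t bf) {
    uint32_t bits = (uint32_t)bf << 16; // what the UBFM shift achieves
    float f;
    std::memcpy(&f, &bits, sizeof f);
    return f;
  }

The bf16 -> fp64 pattern simply extends through f32 (FCVTDSr), which loses
nothing because every bf16 value is exactly representable as an f32.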
@@ -5002,6 +5028,9 @@ defm FCVTNU : SIMDTwoVectorFPToInt<1,0,0b11010, "fcvtnu",int_aarch64_neon_fcvtnu
defm FCVTN : SIMDFPNarrowTwoVector<0, 0, 0b10110, "fcvtn">;
def : Pat<(v4i16 (int_aarch64_neon_vcvtfp2hf (v4f32 V128:$Rn))),
(FCVTNv4i16 V128:$Rn)>;
+//def : Pat<(concat_vectors V64:$Rd,
+//                          (v4bf16 (any_fpround (v4f32 V128:$Rn)))),
+//          (FCVTNv8bf16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Rd, dsub), V128:$Rn)>;
def : Pat<(concat_vectors V64:$Rd,
(v4i16 (int_aarch64_neon_vcvtfp2hf (v4f32 V128:$Rn)))),
(FCVTNv8i16 (INSERT_SUBREG (IMPLICIT_DEF), V64:$Rd, dsub), V128:$Rn)>;
@@ -5686,6 +5715,11 @@ defm USQADD : SIMDTwoScalarBHSDTied< 1, 0b00011, "usqadd",
def : Pat<(v1i64 (AArch64vashr (v1i64 V64:$Rn), (i32 63))),
(CMLTv1i64rz V64:$Rn)>;

+// Round FP64 to BF16.
+let Predicates = [HasNEONorSME, HasBF16] in
+def : Pat<(bf16 (any_fpround (f64 FPR64:$Rn))),
+          (BFCVT (FCVTXNv1i64 $Rn))>;

def : Pat<(v1i64 (int_aarch64_neon_fcvtas (v1f64 FPR64:$Rn))),
(FCVTASv1i64 FPR64:$Rn)>;
def : Pat<(v1i64 (int_aarch64_neon_fcvtau (v1f64 FPR64:$Rn))),
@@ -7698,6 +7732,9 @@ def : Pat<(v4i32 (anyext (v4i16 V64:$Rn))), (USHLLv4i16_shift V64:$Rn, (i32 0))>
def : Pat<(v2i64 (sext (v2i32 V64:$Rn))), (SSHLLv2i32_shift V64:$Rn, (i32 0))>;
def : Pat<(v2i64 (zext (v2i32 V64:$Rn))), (USHLLv2i32_shift V64:$Rn, (i32 0))>;
def : Pat<(v2i64 (anyext (v2i32 V64:$Rn))), (USHLLv2i32_shift V64:$Rn, (i32 0))>;
+// Vector bf16 -> fp32 is implemented morally as a zext + shift.
+def : Pat<(v4f32 (any_fpextend (v4bf16 V64:$Rn))),
+          (USHLLv4i16_shift V64:$Rn, (i32 16))>;
// Also match an extend from the upper half of a 128 bit source register.
def : Pat<(v8i16 (anyext (v8i8 (extract_high_v16i8 (v16i8 V128:$Rn)) ))),
(USHLLv16i8_shift V128:$Rn, (i32 0))>;
llvm/lib/Target/AArch64/GISel/AArch64PostLegalizerLowering.cpp
@@ -1022,13 +1022,17 @@ void applyLowerVectorFCMP(MachineInstr &MI, MachineRegisterInfo &MRI,

  bool Invert = false;
  AArch64CC::CondCode CC, CC2 = AArch64CC::AL;
- if (Pred == CmpInst::Predicate::FCMP_ORD && IsZero) {
+ if ((Pred == CmpInst::Predicate::FCMP_ORD ||
+      Pred == CmpInst::Predicate::FCMP_UNO) &&
+     IsZero) {
    // The special case "fcmp ord %a, 0" is the canonical check that LHS isn't
    // NaN, so equivalent to a == a and doesn't need the two comparisons an
    // "ord" normally would.
+   // Similarly, "fcmp uno %a, 0" is the canonical check that LHS is NaN and is
+   // thus equivalent to a != a.
    RHS = LHS;
    IsZero = false;
-   CC = AArch64CC::EQ;
+   CC = Pred == CmpInst::Predicate::FCMP_ORD ? AArch64CC::EQ : AArch64CC::NE;
  } else
    changeVectorFCMPPredToAArch64CC(Pred, CC, CC2, Invert);
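
The identities the comment relies on, restated as C++ (assuming strict
IEEE-754 semantics, i.e. no fast-math):

  bool ord_zero(float x) { return x == x; } // fcmp ord x, 0.0: x is not NaN
  bool uno_zero(float x) { return x != x; } // fcmp uno x, 0.0: x is NaN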

8 changes: 4 additions & 4 deletions llvm/test/Analysis/CostModel/AArch64/reduce-fadd.ll
@@ -13,7 +13,7 @@ define void @strict_fp_reductions() {
; CHECK-NEXT: Cost Model: Found an estimated cost of 28 for instruction: %fadd_v8f32 = call float @llvm.vector.reduce.fadd.v8f32(float 0.000000e+00, <8 x float> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 6 for instruction: %fadd_v2f64 = call double @llvm.vector.reduce.fadd.v2f64(double 0.000000e+00, <2 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f64 = call double @llvm.vector.reduce.fadd.v4f64(double 0.000000e+00, <4 x double> undef)
-; CHECK-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f8 = call bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR0000, <4 x bfloat> undef)
+; CHECK-NEXT: Cost Model: Found an estimated cost of 18 for instruction: %fadd_v4f8 = call bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR0000, <4 x bfloat> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 20 for instruction: %fadd_v4f128 = call fp128 @llvm.vector.reduce.fadd.v4f128(fp128 undef, <4 x fp128> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret void
;
@@ -24,7 +24,7 @@ define void @strict_fp_reductions() {
; FP16-NEXT: Cost Model: Found an estimated cost of 28 for instruction: %fadd_v8f32 = call float @llvm.vector.reduce.fadd.v8f32(float 0.000000e+00, <8 x float> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 6 for instruction: %fadd_v2f64 = call double @llvm.vector.reduce.fadd.v2f64(double 0.000000e+00, <2 x double> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f64 = call double @llvm.vector.reduce.fadd.v4f64(double 0.000000e+00, <4 x double> undef)
-; FP16-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f8 = call bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR0000, <4 x bfloat> undef)
+; FP16-NEXT: Cost Model: Found an estimated cost of 18 for instruction: %fadd_v4f8 = call bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR0000, <4 x bfloat> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 20 for instruction: %fadd_v4f128 = call fp128 @llvm.vector.reduce.fadd.v4f128(fp128 undef, <4 x fp128> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret void
;
@@ -72,7 +72,7 @@ define void @fast_fp_reductions() {
; CHECK-NEXT: Cost Model: Found an estimated cost of 5 for instruction: %fadd_v4f64_reassoc = call reassoc double @llvm.vector.reduce.fadd.v4f64(double 0.000000e+00, <4 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %fadd_v7f64 = call fast double @llvm.vector.reduce.fadd.v7f64(double 0.000000e+00, <7 x double> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 15 for instruction: %fadd_v9f64_reassoc = call reassoc double @llvm.vector.reduce.fadd.v9f64(double 0.000000e+00, <9 x double> undef)
-; CHECK-NEXT: Cost Model: Found an estimated cost of 6 for instruction: %fadd_v4f8 = call reassoc bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR8000, <4 x bfloat> undef)
+; CHECK-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %fadd_v4f8 = call reassoc bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR8000, <4 x bfloat> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f128 = call reassoc fp128 @llvm.vector.reduce.fadd.v4f128(fp128 undef, <4 x fp128> undef)
; CHECK-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret void
;
@@ -95,7 +95,7 @@ define void @fast_fp_reductions() {
; FP16-NEXT: Cost Model: Found an estimated cost of 5 for instruction: %fadd_v4f64_reassoc = call reassoc double @llvm.vector.reduce.fadd.v4f64(double 0.000000e+00, <4 x double> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 9 for instruction: %fadd_v7f64 = call fast double @llvm.vector.reduce.fadd.v7f64(double 0.000000e+00, <7 x double> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 15 for instruction: %fadd_v9f64_reassoc = call reassoc double @llvm.vector.reduce.fadd.v9f64(double 0.000000e+00, <9 x double> undef)
-; FP16-NEXT: Cost Model: Found an estimated cost of 6 for instruction: %fadd_v4f8 = call reassoc bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR8000, <4 x bfloat> undef)
+; FP16-NEXT: Cost Model: Found an estimated cost of 10 for instruction: %fadd_v4f8 = call reassoc bfloat @llvm.vector.reduce.fadd.v4bf16(bfloat 0xR8000, <4 x bfloat> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 12 for instruction: %fadd_v4f128 = call reassoc fp128 @llvm.vector.reduce.fadd.v4f128(fp128 undef, <4 x fp128> undef)
; FP16-NEXT: Cost Model: Found an estimated cost of 0 for instruction: ret void
;
@@ -321,18 +321,15 @@
bb.0:
liveins: $q0, $q1
-    ; Should be inverted. Needs two compares.
; CHECK-LABEL: name: uno_zero
; CHECK: liveins: $q0, $q1
; CHECK-NEXT: {{ $}}
; CHECK-NEXT: %lhs:_(<2 x s64>) = COPY $q0
-    ; CHECK-NEXT: [[FCMGEZ:%[0-9]+]]:_(<2 x s64>) = G_FCMGEZ %lhs
-    ; CHECK-NEXT: [[FCMLTZ:%[0-9]+]]:_(<2 x s64>) = G_FCMLTZ %lhs
-    ; CHECK-NEXT: [[OR:%[0-9]+]]:_(<2 x s64>) = G_OR [[FCMLTZ]], [[FCMGEZ]]
+    ; CHECK-NEXT: [[FCMEQ:%[0-9]+]]:_(<2 x s64>) = G_FCMEQ %lhs, %lhs(<2 x s64>)
     ; CHECK-NEXT: [[C:%[0-9]+]]:_(s64) = G_CONSTANT i64 -1
     ; CHECK-NEXT: [[BUILD_VECTOR:%[0-9]+]]:_(<2 x s64>) = G_BUILD_VECTOR [[C]](s64), [[C]](s64)
-    ; CHECK-NEXT: [[XOR:%[0-9]+]]:_(<2 x s64>) = G_XOR [[OR]], [[BUILD_VECTOR]]
+    ; CHECK-NEXT: [[XOR:%[0-9]+]]:_(<2 x s64>) = G_XOR [[FCMEQ]], [[BUILD_VECTOR]]
; CHECK-NEXT: $q0 = COPY [[XOR]](<2 x s64>)
; CHECK-NEXT: RET_ReallyLR implicit $q0
%lhs:_(<2 x s64>) = COPY $q0
@@ -187,10 +187,7 @@ entry:
define <8 x bfloat> @insertzero_v4bf16(<4 x bfloat> %a) {
; CHECK-LABEL: insertzero_v4bf16:
; CHECK: // %bb.0: // %entry
-; CHECK-NEXT: movi d4, #0000000000000000
-; CHECK-NEXT: movi d5, #0000000000000000
-; CHECK-NEXT: movi d6, #0000000000000000
-; CHECK-NEXT: movi d7, #0000000000000000
+; CHECK-NEXT: fmov d0, d0
; CHECK-NEXT: ret
entry:
%shuffle.i = shufflevector <4 x bfloat> %a, <4 x bfloat> zeroinitializer, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
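A note on the new output above (editorial, inferred from AArch64 register
semantics rather than stated in the commit): a write to a D register zeroes
the upper 64 bits of the enclosing V register, so fmov d0, d0 by itself
produces the zeroed high lanes the shuffle requires; the four movi zero
vectors in the old output are no longer needed.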

1 comment on commit 3dd6750

@davemgreen (Collaborator) commented on 3dd6750, Mar 5, 2024


Hi.

Where was the review for this? It is much larger than anything I would expect to be committed without review; it seems to include several different things that I would expect to be committed separately, and some of it still looks quite messy.

Thanks
Dave
