diff --git a/llvm/include/llvm/IR/InstrTypes.h b/llvm/include/llvm/IR/InstrTypes.h
index e8c2cba8418dc..41b6fcd2bd453 100644
--- a/llvm/include/llvm/IR/InstrTypes.h
+++ b/llvm/include/llvm/IR/InstrTypes.h
@@ -1058,6 +1058,16 @@ class CmpInst : public Instruction {
   static CmpInst *Create(OtherOps Op, Predicate predicate, Value *S1,
                          Value *S2, const Twine &Name, BasicBlock *InsertAtEnd);
 
+  /// Construct a compare instruction, given the opcode, the predicate, the two
+  /// operands and the instruction to copy the IR flags from. Optionally (if
+  /// InsertBefore is specified) insert the instruction into a BasicBlock right
+  /// before the specified instruction. The specified Instruction is allowed to
+  /// be a dereferenced end iterator.
+  static CmpInst *CreateWithFlags(OtherOps Op, Predicate Pred, Value *S1,
+                                  Value *S2, const Instruction *FlagsSource,
+                                  const Twine &Name = "",
+                                  Instruction *InsertBefore = nullptr);
+
   /// Get the opcode casted to the right type
   OtherOps getOpcode() const {
     return static_cast<OtherOps>(Instruction::getOpcode());
diff --git a/llvm/lib/IR/Instructions.cpp b/llvm/lib/IR/Instructions.cpp
index 494d50f89e374..c9b160a127ced 100644
--- a/llvm/lib/IR/Instructions.cpp
+++ b/llvm/lib/IR/Instructions.cpp
@@ -4623,6 +4623,15 @@ CmpInst::Create(OtherOps Op, Predicate predicate, Value *S1, Value *S2,
                       S1, S2, Name);
 }
 
+CmpInst *CmpInst::CreateWithFlags(OtherOps Op, Predicate Pred, Value *S1,
+                                  Value *S2, const Instruction *FlagsSource,
+                                  const Twine &Name,
+                                  Instruction *InsertBefore) {
+  CmpInst *Inst = Create(Op, Pred, S1, S2, Name, InsertBefore);
+  Inst->copyIRFlags(FlagsSource);
+  return Inst;
+}
+
 void CmpInst::swapOperands() {
   if (ICmpInst *IC = dyn_cast<ICmpInst>(this))
     IC->swapOperands();
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index c7f4fb17648c8..a6558ea558fea 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -487,7 +487,8 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
     // extelt (cmp X, Y), Index --> cmp (extelt X, Index), (extelt Y, Index)
     Value *E0 = Builder.CreateExtractElement(X, Index);
     Value *E1 = Builder.CreateExtractElement(Y, Index);
-    return CmpInst::Create(cast<CmpInst>(SrcVec)->getOpcode(), Pred, E0, E1);
+    auto *SrcCmp = cast<CmpInst>(SrcVec);
+    return CmpInst::CreateWithFlags(SrcCmp->getOpcode(), Pred, E0, E1, SrcCmp);
   }
 
   if (auto *I = dyn_cast<Instruction>(SrcVec)) {
diff --git a/llvm/test/Transforms/InstCombine/scalarization.ll b/llvm/test/Transforms/InstCombine/scalarization.ll
index fe6dc526bd50e..5ab960ece54d9 100644
--- a/llvm/test/Transforms/InstCombine/scalarization.ll
+++ b/llvm/test/Transforms/InstCombine/scalarization.ll
@@ -341,6 +341,20 @@ define i1 @extractelt_vector_fcmp_constrhs_dynidx(<2 x float> %arg, i32 %idx) {
   ret i1 %ext
 }
 
+define i1 @extractelt_vector_fcmp_copy_flags(<4 x float> %x, <4 x i1> %y) {
+; CHECK-LABEL: @extractelt_vector_fcmp_copy_flags(
+; CHECK-NEXT:    [[TMP1:%.*]] = extractelement <4 x float> [[X:%.*]], i64 2
+; CHECK-NEXT:    [[TMP2:%.*]] = fcmp nsz arcp oeq float [[TMP1]], 0.000000e+00
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <4 x i1> [[Y:%.*]], i64 2
+; CHECK-NEXT:    [[R:%.*]] = and i1 [[TMP2]], [[TMP3]]
+; CHECK-NEXT:    ret i1 [[R]]
+;
+  %cmp = fcmp nsz arcp oeq <4 x float> %x, zeroinitializer
+  %and = and <4 x i1> %cmp, %y
+  %r = extractelement <4 x i1> %and, i32 2
+  ret i1 %r
+}
+
 define i1 @extractelt_vector_fcmp_not_cheap_to_scalarize_multi_use(<2 x float> %arg0, <2 x float> %arg1, <2 x float> %arg2, i32 %idx) {
 ;
 ; CHECK-LABEL: @extractelt_vector_fcmp_not_cheap_to_scalarize_multi_use(