LoopVectorizer: Preserve fast-math flags
[oota-llvm.git] / lib / Transforms / Vectorize / LoopVectorize.cpp
index 1ff347df6e7d59f13c3bf33c646f6d0ee6918af4..77633a562a6759549ea9c6fb77b40aa5ba0f7926 100644 (file)
@@ -433,11 +433,12 @@ public:
     InnerLoopVectorizer(OrigLoop, SE, LI, DT, DL, TLI, 1, UnrollFactor) { }
 
 private:
-  virtual void scalarizeInstruction(Instruction *Instr, bool IfPredicateStore = false);
-  virtual void vectorizeMemoryInstruction(Instruction *Instr);
-  virtual Value *getBroadcastInstrs(Value *V);
-  virtual Value *getConsecutiveVector(Value* Val, int StartIdx, bool Negate);
-  virtual Value *reverseVector(Value *Vec);
+  void scalarizeInstruction(Instruction *Instr,
+                            bool IfPredicateStore = false) override;
+  void vectorizeMemoryInstruction(Instruction *Instr) override;
+  Value *getBroadcastInstrs(Value *V) override;
+  Value *getConsecutiveVector(Value* Val, int StartIdx, bool Negate) override;
+  Value *reverseVector(Value *Vec) override;
 };
 
 /// \brief Look for a meaningful debug location on the instruction or it's
@@ -1020,7 +1021,7 @@ struct LoopVectorize : public FunctionPass {
 
   BlockFrequency ColdEntryFreq;
 
-  virtual bool runOnFunction(Function &F) {
+  bool runOnFunction(Function &F) override {
     SE = &getAnalysis<ScalarEvolution>();
     DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>();
     DL = DLP ? &DLP->getDataLayout() : 0;
@@ -1160,7 +1161,7 @@ struct LoopVectorize : public FunctionPass {
     return true;
   }
 
-  virtual void getAnalysisUsage(AnalysisUsage &AU) const {
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addRequiredID(LoopSimplifyID);
     AU.addRequiredID(LCSSAID);
     AU.addRequired<BlockFrequencyInfo>();
@@ -2488,6 +2489,16 @@ static void cse(SmallVector<BasicBlock *, 4> &BBs) {
   }
 }
 
+/// \brief Adds a 'fast' flag to floating point operations.
+static Value *addFastMathFlag(Value *V) {
+  if (isa<FPMathOperator>(V)){
+    FastMathFlags Flags;
+    Flags.setUnsafeAlgebra();
+    cast<Instruction>(V)->setFastMathFlags(Flags);
+  }
+  return V;
+}
+
 void InnerLoopVectorizer::vectorizeLoop() {
   //===------------------------------------------------===//
   //
@@ -2631,9 +2642,10 @@ void InnerLoopVectorizer::vectorizeLoop() {
     setDebugLocFromInst(Builder, ReducedPartRdx);
     for (unsigned part = 1; part < UF; ++part) {
       if (Op != Instruction::ICmp && Op != Instruction::FCmp)
-        ReducedPartRdx = Builder.CreateBinOp((Instruction::BinaryOps)Op,
-                                             RdxParts[part], ReducedPartRdx,
-                                             "bin.rdx");
+        // Floating point operations had to be 'fast' to enable the reduction.
+        ReducedPartRdx = addFastMathFlag(
+            Builder.CreateBinOp((Instruction::BinaryOps)Op, RdxParts[part],
+                                ReducedPartRdx, "bin.rdx"));
       else
         ReducedPartRdx = createMinMaxOp(Builder, RdxDesc.MinMaxKind,
                                         ReducedPartRdx, RdxParts[part]);
@@ -2663,8 +2675,9 @@ void InnerLoopVectorizer::vectorizeLoop() {
                                     "rdx.shuf");
 
         if (Op != Instruction::ICmp && Op != Instruction::FCmp)
-          TmpVec = Builder.CreateBinOp((Instruction::BinaryOps)Op, TmpVec, Shuf,
-                                       "bin.rdx");
+          // Floating point operations had to be 'fast' to enable the reduction.
+          TmpVec = addFastMathFlag(Builder.CreateBinOp(
+              (Instruction::BinaryOps)Op, TmpVec, Shuf, "bin.rdx"));
         else
           TmpVec = createMinMaxOp(Builder, RdxDesc.MinMaxKind, TmpVec, Shuf);
       }
@@ -2998,6 +3011,10 @@ void InnerLoopVectorizer::vectorizeBlockInLoop(BasicBlock *BB, PhiVector *PV) {
         if (VecOp && isa<PossiblyExactOperator>(VecOp))
           VecOp->setIsExact(BinOp->isExact());
 
+        // Copy the fast-math flags.
+        if (VecOp && isa<FPMathOperator>(V))
+          VecOp->setFastMathFlags(it->getFastMathFlags());
+
         Entry[Part] = V;
       }
       break;