Signficantly generalize our ability to constant fold floating point intrinsics, inclu...
[oota-llvm.git] / lib / Analysis / ConstantFolding.cpp
index 2b7d3bd701d590fd768ae9dc0b7bddd0292c5d4c..e499c73566c649e21fcab9354914b496887352e9 100644 (file)
@@ -218,10 +218,10 @@ static Constant *FoldBitCast(Constant *C, Type *DestTy,
 /// from a global, return the global and the constant.  Because of
 /// constantexprs, this function is recursive.
 static bool IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV,
-                                       int64_t &Offset, const DataLayout &TD) {
+                                       APInt &Offset, const DataLayout &TD) {
   // Trivial case, constant is the global.
   if ((GV = dyn_cast<GlobalValue>(C))) {
-    Offset = 0;
+    Offset.clearAllBits();
     return true;
   }
 
@@ -235,34 +235,13 @@ static bool IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV,
     return IsConstantOffsetFromGlobal(CE->getOperand(0), GV, Offset, TD);
 
   // i32* getelementptr ([5 x i32]* @a, i32 0, i32 5)
-  if (CE->getOpcode() == Instruction::GetElementPtr) {
-    // Cannot compute this if the element type of the pointer is missing size
-    // info.
-    if (!cast<PointerType>(CE->getOperand(0)->getType())
-                 ->getElementType()->isSized())
-      return false;
-
+  if (GEPOperator *GEP = dyn_cast<GEPOperator>(CE)) {
     // If the base isn't a global+constant, we aren't either.
     if (!IsConstantOffsetFromGlobal(CE->getOperand(0), GV, Offset, TD))
       return false;
 
     // Otherwise, add any offset that our operands provide.
-    gep_type_iterator GTI = gep_type_begin(CE);
-    for (User::const_op_iterator i = CE->op_begin() + 1, e = CE->op_end();
-         i != e; ++i, ++GTI) {
-      ConstantInt *CI = dyn_cast<ConstantInt>(*i);
-      if (!CI) return false;  // Index isn't a simple constant?
-      if (CI->isZero()) continue;  // Not adding anything.
-
-      if (StructType *ST = dyn_cast<StructType>(*GTI)) {
-        // N = N + Offset
-        Offset += TD.getStructLayout(ST)->getElementOffset(CI->getZExtValue());
-      } else {
-        SequentialType *SQT = cast<SequentialType>(*GTI);
-        Offset += TD.getTypeAllocSize(SQT->getElementType())*CI->getSExtValue();
-      }
-    }
-    return true;
+    return GEP->accumulateConstantOffset(TD, Offset);
   }
 
   return false;
@@ -310,6 +289,10 @@ static bool ReadDataFromGlobal(Constant *C, uint64_t ByteOffset,
       C = FoldBitCast(C, Type::getInt32Ty(C->getContext()), TD);
       return ReadDataFromGlobal(C, ByteOffset, CurPtr, BytesLeft, TD);
     }
+    if (CFP->getType()->isHalfTy()){
+      C = FoldBitCast(C, Type::getInt16Ty(C->getContext()), TD);
+      return ReadDataFromGlobal(C, ByteOffset, CurPtr, BytesLeft, TD);
+    }
     return false;
   }
 
@@ -402,7 +385,9 @@ static Constant *FoldReinterpretLoadFromConstPtr(Constant *C,
     // that address spaces don't matter here since we're not going to result in
     // an actual new load.
     Type *MapTy;
-    if (LoadTy->isFloatTy())
+    if (LoadTy->isHalfTy())
+      MapTy = Type::getInt16PtrTy(C->getContext());
+    else if (LoadTy->isFloatTy())
       MapTy = Type::getInt32PtrTy(C->getContext());
     else if (LoadTy->isDoubleTy())
       MapTy = Type::getInt64PtrTy(C->getContext());
@@ -423,7 +408,7 @@ static Constant *FoldReinterpretLoadFromConstPtr(Constant *C,
   if (BytesLoaded > 32 || BytesLoaded == 0) return 0;
 
   GlobalValue *GVal;
-  int64_t Offset;
+  APInt Offset(TD.getPointerSizeInBits(), 0);
   if (!IsConstantOffsetFromGlobal(C, GVal, Offset, TD))
     return 0;
 
@@ -434,14 +419,15 @@ static Constant *FoldReinterpretLoadFromConstPtr(Constant *C,
 
   // If we're loading off the beginning of the global, some bytes may be valid,
   // but we don't try to handle this.
-  if (Offset < 0) return 0;
+  if (Offset.isNegative()) return 0;
 
   // If we're not accessing anything in this constant, the result is undefined.
-  if (uint64_t(Offset) >= TD.getTypeAllocSize(GV->getInitializer()->getType()))
+  if (Offset.getZExtValue() >=
+      TD.getTypeAllocSize(GV->getInitializer()->getType()))
     return UndefValue::get(IntType);
 
   unsigned char RawBytes[32] = {0};
-  if (!ReadDataFromGlobal(GV->getInitializer(), Offset, RawBytes,
+  if (!ReadDataFromGlobal(GV->getInitializer(), Offset.getZExtValue(), RawBytes,
                           BytesLoaded, TD))
     return 0;
 
@@ -565,13 +551,18 @@ static Constant *SymbolicallyEvaluateBinop(unsigned Opc, Constant *Op0,
   // constant.  This happens frequently when iterating over a global array.
   if (Opc == Instruction::Sub && TD) {
     GlobalValue *GV1, *GV2;
-    int64_t Offs1, Offs2;
+    unsigned PtrSize = TD->getPointerSizeInBits();
+    unsigned OpSize = TD->getTypeSizeInBits(Op0->getType());
+    APInt Offs1(PtrSize, 0), Offs2(PtrSize, 0);
 
     if (IsConstantOffsetFromGlobal(Op0, GV1, Offs1, *TD))
       if (IsConstantOffsetFromGlobal(Op1, GV2, Offs2, *TD) &&
           GV1 == GV2) {
         // (&GV+C1) - (&GV+C2) -> C1-C2, pointer arithmetic cannot overflow.
-        return ConstantInt::get(Op0->getType(), Offs1-Offs2);
+        // PtrToInt may change the bitwidth so we have convert to the right size
+        // first.
+        return ConstantInt::get(Op0->getType(), Offs1.zextOrTrunc(OpSize) -
+                                                Offs2.zextOrTrunc(OpSize));
       }
   }
 
@@ -1104,6 +1095,13 @@ Constant *llvm::ConstantFoldLoadThroughGEPIndices(Constant *C,
 bool
 llvm::canConstantFoldCallTo(const Function *F) {
   switch (F->getIntrinsicID()) {
+  case Intrinsic::fabs:
+  case Intrinsic::log:
+  case Intrinsic::log2:
+  case Intrinsic::log10:
+  case Intrinsic::exp:
+  case Intrinsic::exp2:
+  case Intrinsic::floor:
   case Intrinsic::sqrt:
   case Intrinsic::pow:
   case Intrinsic::powi:
@@ -1171,11 +1169,17 @@ static Constant *ConstantFoldFP(double (*NativeFP)(double), double V,
     return 0;
   }
 
+  if (Ty->isHalfTy()) {
+    APFloat APF(V);
+    bool unused;
+    APF.convert(APFloat::IEEEhalf, APFloat::rmNearestTiesToEven, &unused);
+    return ConstantFP::get(Ty->getContext(), APF);
+  }
   if (Ty->isFloatTy())
     return ConstantFP::get(Ty->getContext(), APFloat((float)V));
   if (Ty->isDoubleTy())
     return ConstantFP::get(Ty->getContext(), APFloat(V));
-  llvm_unreachable("Can only constant fold float/double");
+  llvm_unreachable("Can only constant fold half/float/double");
 }
 
 static Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double),
@@ -1187,11 +1191,17 @@ static Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double),
     return 0;
   }
 
+  if (Ty->isHalfTy()) {
+    APFloat APF(V);
+    bool unused;
+    APF.convert(APFloat::IEEEhalf, APFloat::rmNearestTiesToEven, &unused);
+    return ConstantFP::get(Ty->getContext(), APF);
+  }
   if (Ty->isFloatTy())
     return ConstantFP::get(Ty->getContext(), APFloat((float)V));
   if (Ty->isDoubleTy())
     return ConstantFP::get(Ty->getContext(), APFloat(V));
-  llvm_unreachable("Can only constant fold float/double");
+  llvm_unreachable("Can only constant fold half/float/double");
 }
 
 /// ConstantFoldConvertToInt - Attempt to an SSE floating point to integer
@@ -1243,7 +1253,7 @@ llvm::ConstantFoldCall(Function *F, ArrayRef<Constant *> Operands,
       if (!TLI)
         return 0;
 
-      if (!Ty->isFloatTy() && !Ty->isDoubleTy())
+      if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy())
         return 0;
 
       /// We only fold functions with finite arguments. Folding NaN and inf is
@@ -1256,8 +1266,36 @@ llvm::ConstantFoldCall(Function *F, ArrayRef<Constant *> Operands,
       /// the host native double versions.  Float versions are not called
       /// directly but for all these it is true (float)(f((double)arg)) ==
       /// f(arg).  Long double not supported yet.
-      double V = Ty->isFloatTy() ? (double)Op->getValueAPF().convertToFloat() :
-                                     Op->getValueAPF().convertToDouble();
+      double V;
+      if (Ty->isFloatTy())
+        V = Op->getValueAPF().convertToFloat();
+      else if (Ty->isDoubleTy())
+        V = Op->getValueAPF().convertToDouble();
+      else {
+        bool unused;
+        APFloat APF = Op->getValueAPF();
+        APF.convert(APFloat::IEEEdouble, APFloat::rmNearestTiesToEven, &unused);
+        V = APF.convertToDouble();
+      }
+
+      switch (F->getIntrinsicID()) {
+        default: break;
+        case Intrinsic::fabs:
+          return ConstantFoldFP(fabs, V, Ty);
+        case Intrinsic::log2:
+          return ConstantFoldFP(log2, V, Ty);
+        case Intrinsic::log:
+          return ConstantFoldFP(log, V, Ty);
+        case Intrinsic::log10:
+          return ConstantFoldFP(log10, V, Ty);
+        case Intrinsic::exp:
+          return ConstantFoldFP(exp, V, Ty);
+        case Intrinsic::exp2:
+          return ConstantFoldFP(exp2, V, Ty);
+        case Intrinsic::floor:
+          return ConstantFoldFP(floor, V, Ty);
+      }
+
       switch (Name[0]) {
       case 'a':
         if (Name == "acos" && TLI->has(LibFunc::acos))
@@ -1299,7 +1337,7 @@ llvm::ConstantFoldCall(Function *F, ArrayRef<Constant *> Operands,
         else if (Name == "log10" && V > 0 && TLI->has(LibFunc::log10))
           return ConstantFoldFP(log10, V, Ty);
         else if (F->getIntrinsicID() == Intrinsic::sqrt &&
-                 (Ty->isFloatTy() || Ty->isDoubleTy())) {
+                 (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy())) {
           if (V >= -0.0)
             return ConstantFoldFP(sqrt, V, Ty);
           else // Undefined
@@ -1337,7 +1375,7 @@ llvm::ConstantFoldCall(Function *F, ArrayRef<Constant *> Operands,
       case Intrinsic::ctpop:
         return ConstantInt::get(Ty, Op->getValue().countPopulation());
       case Intrinsic::convert_from_fp16: {
-        APFloat Val(Op->getValue());
+        APFloat Val(APFloat::IEEEhalf, Op->getValue());
 
         bool lost = false;
         APFloat::opStatus status =
@@ -1391,18 +1429,35 @@ llvm::ConstantFoldCall(Function *F, ArrayRef<Constant *> Operands,
 
   if (Operands.size() == 2) {
     if (ConstantFP *Op1 = dyn_cast<ConstantFP>(Operands[0])) {
-      if (!Ty->isFloatTy() && !Ty->isDoubleTy())
+      if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy())
         return 0;
-      double Op1V = Ty->isFloatTy() ?
-                      (double)Op1->getValueAPF().convertToFloat() :
-                      Op1->getValueAPF().convertToDouble();
+      double Op1V;
+      if (Ty->isFloatTy())
+        Op1V = Op1->getValueAPF().convertToFloat();
+      else if (Ty->isDoubleTy())
+        Op1V = Op1->getValueAPF().convertToDouble();
+      else {
+        bool unused;
+        APFloat APF = Op1->getValueAPF();
+        APF.convert(APFloat::IEEEdouble, APFloat::rmNearestTiesToEven, &unused);
+        Op1V = APF.convertToDouble();
+      }
+
       if (ConstantFP *Op2 = dyn_cast<ConstantFP>(Operands[1])) {
         if (Op2->getType() != Op1->getType())
           return 0;
 
-        double Op2V = Ty->isFloatTy() ?
-                      (double)Op2->getValueAPF().convertToFloat():
-                      Op2->getValueAPF().convertToDouble();
+        double Op2V;
+        if (Ty->isFloatTy())
+          Op2V = Op2->getValueAPF().convertToFloat();
+        else if (Ty->isDoubleTy())
+          Op2V = Op2->getValueAPF().convertToDouble();
+        else {
+          bool unused;
+          APFloat APF = Op2->getValueAPF();
+          APF.convert(APFloat::IEEEdouble, APFloat::rmNearestTiesToEven, &unused);
+          Op2V = APF.convertToDouble();
+        }
 
         if (F->getIntrinsicID() == Intrinsic::pow) {
           return ConstantFoldBinaryFP(pow, Op1V, Op2V, Ty);
@@ -1416,6 +1471,10 @@ llvm::ConstantFoldCall(Function *F, ArrayRef<Constant *> Operands,
         if (Name == "atan2" && TLI->has(LibFunc::atan2))
           return ConstantFoldBinaryFP(atan2, Op1V, Op2V, Ty);
       } else if (ConstantInt *Op2C = dyn_cast<ConstantInt>(Operands[1])) {
+        if (F->getIntrinsicID() == Intrinsic::powi && Ty->isHalfTy())
+          return ConstantFP::get(F->getContext(),
+                                 APFloat((float)std::pow((float)Op1V,
+                                                 (int)Op2C->getZExtValue())));
         if (F->getIntrinsicID() == Intrinsic::powi && Ty->isFloatTy())
           return ConstantFP::get(F->getContext(),
                                  APFloat((float)std::pow((float)Op1V,
@@ -1468,12 +1527,12 @@ llvm::ConstantFoldCall(Function *F, ArrayRef<Constant *> Operands,
           return ConstantStruct::get(cast<StructType>(F->getReturnType()), Ops);
         }
         case Intrinsic::cttz:
-          // FIXME: This should check for Op2 == 1, and become unreachable if
-          // Op1 == 0.
+          if (Op2->isOne() && Op1->isZero()) // cttz(0, 1) is undef.
+            return UndefValue::get(Ty);
           return ConstantInt::get(Ty, Op1->getValue().countTrailingZeros());
         case Intrinsic::ctlz:
-          // FIXME: This should check for Op2 == 1, and become unreachable if
-          // Op1 == 0.
+          if (Op2->isOne() && Op1->isZero()) // ctlz(0, 1) is undef.
+            return UndefValue::get(Ty);
           return ConstantInt::get(Ty, Op1->getValue().countLeadingZeros());
         }
       }