Build the correct range for loops with unusual bounds. Fix from Jay Foad.
[oota-llvm.git] / lib / Analysis / ConstantFolding.cpp
index 6fd8ff8b56318e6dfded7359aed960a469740348..6c828fa0042bf10bbfaefec099c425af5abac9bd 100644 (file)
@@ -19,6 +19,7 @@
 #include "llvm/Instructions.h"
 #include "llvm/Intrinsics.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringMap.h"
 #include "llvm/Target/TargetData.h"
 #include "llvm/Support/GetElementPtrTypeIterator.h"
 #include "llvm/Support/MathExtras.h"
@@ -72,8 +73,8 @@ static bool IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV,
         // N = N + Offset
         Offset += TD.getStructLayout(ST)->getElementOffset(CI->getZExtValue());
       } else {
-        const SequentialType *ST = cast<SequentialType>(*GTI);
-        Offset += TD.getTypeSize(ST->getElementType())*CI->getSExtValue();
+        const SequentialType *SQT = cast<SequentialType>(*GTI);
+        Offset += TD.getTypeSize(SQT->getElementType())*CI->getSExtValue();
       }
     }
     return true;
@@ -216,6 +217,23 @@ Constant *llvm::ConstantFoldInstOperands(const Instruction* I,
   case Instruction::FCmp:
     return ConstantExpr::getCompare(cast<CmpInst>(I)->getPredicate(), Ops[0], 
                                     Ops[1]);
+  case Instruction::PtrToInt:
+    // If the input is a inttoptr, eliminate the pair.  This requires knowing
+    // the width of a pointer, so it can't be done in ConstantExpr::getCast.
+    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ops[0])) {
+      if (TD && CE->getOpcode() == Instruction::IntToPtr) {
+        Constant *Input = CE->getOperand(0);
+        unsigned InWidth = Input->getType()->getPrimitiveSizeInBits();
+        Constant *Mask = 
+          ConstantInt::get(APInt::getLowBitsSet(InWidth,
+                                                TD->getPointerSizeInBits()));
+        Input = ConstantExpr::getAnd(Input, Mask);
+        // Do a zext or trunc to get to the dest size.
+        return ConstantExpr::getIntegerCast(Input, I->getType(), false);
+      }
+    }
+    // FALL THROUGH.
+  case Instruction::IntToPtr:
   case Instruction::Trunc:
   case Instruction::ZExt:
   case Instruction::SExt:
@@ -225,8 +243,6 @@ Constant *llvm::ConstantFoldInstOperands(const Instruction* I,
   case Instruction::SIToFP:
   case Instruction::FPToUI:
   case Instruction::FPToSI:
-  case Instruction::PtrToInt:
-  case Instruction::IntToPtr:
   case Instruction::BitCast:
     return ConstantExpr::getCast(Opc, Ops[0], DestTy);
   case Instruction::Select:
@@ -283,10 +299,10 @@ Constant *llvm::ConstantFoldLoadThroughGEPConstantExpr(Constant *C,
           C = UndefValue::get(ATy->getElementType());
         else
           return 0;
-      } else if (const PackedType *PTy = dyn_cast<PackedType>(*I)) {
+      } else if (const VectorType *PTy = dyn_cast<VectorType>(*I)) {
         if (CI->getZExtValue() >= PTy->getNumElements())
           return 0;
-        if (ConstantPacked *CP = dyn_cast<ConstantPacked>(C))
+        if (ConstantVector *CP = dyn_cast<ConstantVector>(C))
           C = CP->getOperand(CI->getZExtValue());
         else if (isa<ConstantAggregateZero>(C))
           C = Constant::getNullValue(PTy->getElementType());
@@ -312,56 +328,78 @@ Constant *llvm::ConstantFoldLoadThroughGEPConstantExpr(Constant *C,
 /// the specified function.
 bool
 llvm::canConstantFoldCallTo(Function *F) {
-  const std::string &Name = F->getName();
-
   switch (F->getIntrinsicID()) {
   case Intrinsic::sqrt_f32:
   case Intrinsic::sqrt_f64:
-  case Intrinsic::bswap_i16:
-  case Intrinsic::bswap_i32:
-  case Intrinsic::bswap_i64:
   case Intrinsic::powi_f32:
   case Intrinsic::powi_f64:
-  // FIXME: these should be constant folded as well
-  //case Intrinsic::ctpop_i8:
-  //case Intrinsic::ctpop_i16:
-  //case Intrinsic::ctpop_i32:
-  //case Intrinsic::ctpop_i64:
-  //case Intrinsic::ctlz_i8:
-  //case Intrinsic::ctlz_i16:
-  //case Intrinsic::ctlz_i32:
-  //case Intrinsic::ctlz_i64:
-  //case Intrinsic::cttz_i8:
-  //case Intrinsic::cttz_i16:
-  //case Intrinsic::cttz_i32:
-  //case Intrinsic::cttz_i64:
+  case Intrinsic::bswap:
+  case Intrinsic::ctpop:
+  case Intrinsic::ctlz:
+  case Intrinsic::cttz:
     return true;
   default: break;
   }
 
-  switch (Name[0])
-  {
-    case 'a':
-      return Name == "acos" || Name == "asin" || Name == "atan" ||
-             Name == "atan2";
-    case 'c':
-      return Name == "ceil" || Name == "cos" || Name == "cosf" ||
-             Name == "cosh";
-    case 'e':
-      return Name == "exp";
-    case 'f':
-      return Name == "fabs" || Name == "fmod" || Name == "floor";
-    case 'l':
-      return Name == "log" || Name == "log10";
-    case 'p':
-      return Name == "pow";
-    case 's':
-      return Name == "sin" || Name == "sinh" || 
-             Name == "sqrt" || Name == "sqrtf";
-    case 't':
-      return Name == "tan" || Name == "tanh";
-    default:
-      return false;
+  const ValueName *NameVal = F->getValueName();
+  if (NameVal == 0) return false;
+  const char *Str = NameVal->getKeyData();
+  unsigned Len = NameVal->getKeyLength();
+  
+  // In these cases, the check of the length is required.  We don't want to
+  // return true for a name like "cos\0blah" which strcmp would return equal to
+  // "cos", but has length 8.
+  switch (Str[0]) {
+  default: return false;
+  case 'a':
+    if (Len == 4)
+      return !strcmp(Str, "acos") || !strcmp(Str, "asin") ||
+             !strcmp(Str, "atan");
+    else if (Len == 5)
+      return !strcmp(Str, "atan2");
+    return false;
+  case 'c':
+    if (Len == 3)
+      return !strcmp(Str, "cos");
+    else if (Len == 4)
+      return !strcmp(Str, "ceil") || !strcmp(Str, "cosf") ||
+             !strcmp(Str, "cosh");
+    return false;
+  case 'e':
+    if (Len == 3)
+      return !strcmp(Str, "exp");
+    return false;
+  case 'f':
+    if (Len == 4)
+      return !strcmp(Str, "fabs") || !strcmp(Str, "fmod");
+    else if (Len == 5)
+      return !strcmp(Str, "floor");
+    return false;
+    break;
+  case 'l':
+    if (Len == 3 && !strcmp(Str, "log"))
+      return true;
+    if (Len == 5 && !strcmp(Str, "log10"))
+      return true;
+    return false;
+  case 'p':
+    if (Len == 3 && !strcmp(Str, "pow"))
+      return true;
+    return false;
+  case 's':
+    if (Len == 3)
+      return !strcmp(Str, "sin");
+    if (Len == 4)
+      return !strcmp(Str, "sinh") || !strcmp(Str, "sqrt");
+    if (Len == 5)
+      return !strcmp(Str, "sqrtf");
+    return false;
+  case 't':
+    if (Len == 3 && !strcmp(Str, "tan"))
+      return true;
+    else if (Len == 4 && !strcmp(Str, "tanh"))
+      return true;
+    return false;
   }
 }
 
@@ -369,116 +407,156 @@ static Constant *ConstantFoldFP(double (*NativeFP)(double), double V,
                                 const Type *Ty) {
   errno = 0;
   V = NativeFP(V);
-  if (errno == 0)
-    return ConstantFP::get(Ty, V);
+  if (errno == 0) {
+    if (Ty==Type::FloatTy)
+      return ConstantFP::get(Ty, APFloat((float)V));
+    else if (Ty==Type::DoubleTy)
+      return ConstantFP::get(Ty, APFloat(V));
+    else
+      assert(0);
+  }
+  errno = 0;
+  return 0;
+}
+
+static Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double),
+                                      double V, double W,
+                                      const Type *Ty) {
+  errno = 0;
+  V = NativeFP(V, W);
+  if (errno == 0) {
+    if (Ty==Type::FloatTy)
+      return ConstantFP::get(Ty, APFloat((float)V));
+    else if (Ty==Type::DoubleTy)
+      return ConstantFP::get(Ty, APFloat(V));
+    else
+      assert(0);
+  }
   errno = 0;
   return 0;
 }
 
 /// ConstantFoldCall - Attempt to constant fold a call to the specified function
 /// with the specified arguments, returning null if unsuccessful.
+
 Constant *
 llvm::ConstantFoldCall(Function *F, Constant** Operands, unsigned NumOperands) {
-  const std::string &Name = F->getName();
+  const ValueName *NameVal = F->getValueName();
+  if (NameVal == 0) return 0;
+  const char *Str = NameVal->getKeyData();
+  unsigned Len = NameVal->getKeyLength();
+  
   const Type *Ty = F->getReturnType();
-
   if (NumOperands == 1) {
     if (ConstantFP *Op = dyn_cast<ConstantFP>(Operands[0])) {
-      double V = Op->getValue();
-      switch (Name[0])
-      {
-        case 'a':
-          if (Name == "acos")
-            return ConstantFoldFP(acos, V, Ty);
-          else if (Name == "asin")
-            return ConstantFoldFP(asin, V, Ty);
-          else if (Name == "atan")
-            return ConstantFP::get(Ty, atan(V));
-          break;
-        case 'c':
-          if (Name == "ceil")
-            return ConstantFoldFP(ceil, V, Ty);
-          else if (Name == "cos")
-            return ConstantFP::get(Ty, cos(V));
-          else if (Name == "cosh")
-            return ConstantFP::get(Ty, cosh(V));
-          break;
-        case 'e':
-          if (Name == "exp")
-            return ConstantFP::get(Ty, exp(V));
-          break;
-        case 'f':
-          if (Name == "fabs")
-            return ConstantFP::get(Ty, fabs(V));
-          else if (Name == "floor")
-            return ConstantFoldFP(floor, V, Ty);
-          break;
-        case 'l':
-          if (Name == "log" && V > 0)
-            return ConstantFP::get(Ty, log(V));
-          else if (Name == "log10" && V > 0)
-            return ConstantFoldFP(log10, V, Ty);
-          else if (Name == "llvm.sqrt.f32" || Name == "llvm.sqrt.f64") {
-            if (V >= -0.0)
-              return ConstantFP::get(Ty, sqrt(V));
-            else // Undefined
-              return ConstantFP::get(Ty, 0.0);
-          }
-          break;
-        case 's':
-          if (Name == "sin")
-            return ConstantFP::get(Ty, sin(V));
-          else if (Name == "sinh")
-            return ConstantFP::get(Ty, sinh(V));
-          else if (Name == "sqrt" && V >= 0)
-            return ConstantFP::get(Ty, sqrt(V));
-          else if (Name == "sqrtf" && V >= 0)
-            return ConstantFP::get(Ty, sqrt((float)V));
-          break;
-        case 't':
-          if (Name == "tan")
-            return ConstantFP::get(Ty, tan(V));
-          else if (Name == "tanh")
-            return ConstantFP::get(Ty, tanh(V));
-          break;
-        default:
-          break;
+      if (Ty!=Type::FloatTy && Ty!=Type::DoubleTy)
+        return 0;
+      /// Currently APFloat versions of these functions do not exist, so we use
+      /// the host native double versions.  Float versions are not called
+      /// directly but for all these it is true (float)(f((double)arg)) ==
+      /// f(arg).  Long double not supported yet.
+      double V = Ty==Type::FloatTy ? (double)Op->getValueAPF().convertToFloat():
+                                     Op->getValueAPF().convertToDouble();
+      switch (Str[0]) {
+      case 'a':
+        if (Len == 4 && !strcmp(Str, "acos"))
+          return ConstantFoldFP(acos, V, Ty);
+        else if (Len == 4 && !strcmp(Str, "asin"))
+          return ConstantFoldFP(asin, V, Ty);
+        else if (Len == 4 && !strcmp(Str, "atan"))
+          return ConstantFoldFP(atan, V, Ty);
+        break;
+      case 'c':
+        if (Len == 4 && !strcmp(Str, "ceil"))
+          return ConstantFoldFP(ceil, V, Ty);
+        else if (Len == 3 && !strcmp(Str, "cos"))
+          return ConstantFoldFP(cos, V, Ty);
+        else if (Len == 4 && !strcmp(Str, "cosh"))
+          return ConstantFoldFP(cosh, V, Ty);
+        break;
+      case 'e':
+        if (Len == 3 && !strcmp(Str, "exp"))
+          return ConstantFoldFP(exp, V, Ty);
+        break;
+      case 'f':
+        if (Len == 4 && !strcmp(Str, "fabs"))
+          return ConstantFoldFP(fabs, V, Ty);
+        else if (Len == 5 && !strcmp(Str, "floor"))
+          return ConstantFoldFP(floor, V, Ty);
+        break;
+      case 'l':
+        if (Len == 3 && !strcmp(Str, "log") && V > 0)
+          return ConstantFoldFP(log, V, Ty);
+        else if (Len == 5 && !strcmp(Str, "log10") && V > 0)
+          return ConstantFoldFP(log10, V, Ty);
+        else if (!strcmp(Str, "llvm.sqrt.f32") ||
+                 !strcmp(Str, "llvm.sqrt.f64")) {
+          if (V >= -0.0)
+            return ConstantFoldFP(sqrt, V, Ty);
+          else // Undefined
+            return ConstantFP::get(Ty, Ty==Type::FloatTy ? APFloat(0.0f) :
+                                       APFloat(0.0));
+        }
+        break;
+      case 's':
+        if (Len == 3 && !strcmp(Str, "sin"))
+          return ConstantFoldFP(sin, V, Ty);
+        else if (Len == 4 && !strcmp(Str, "sinh"))
+          return ConstantFoldFP(sinh, V, Ty);
+        else if (Len == 4 && !strcmp(Str, "sqrt") && V >= 0)
+          return ConstantFoldFP(sqrt, V, Ty);
+        else if (Len == 5 && !strcmp(Str, "sqrtf") && V >= 0)
+          return ConstantFoldFP(sqrt, V, Ty);
+        break;
+      case 't':
+        if (Len == 3 && !strcmp(Str, "tan"))
+          return ConstantFoldFP(tan, V, Ty);
+        else if (Len == 4 && !strcmp(Str, "tanh"))
+          return ConstantFoldFP(tanh, V, Ty);
+        break;
+      default:
+        break;
       }
     } else if (ConstantInt *Op = dyn_cast<ConstantInt>(Operands[0])) {
-      uint64_t V = Op->getZExtValue();
-      if (Name == "llvm.bswap.i16")
-        return ConstantInt::get(Ty, ByteSwap_16(V));
-      else if (Name == "llvm.bswap.i32")
-        return ConstantInt::get(Ty, ByteSwap_32(V));
-      else if (Name == "llvm.bswap.i64")
-        return ConstantInt::get(Ty, ByteSwap_64(V));
+      if (Len > 11 && !memcmp(Str, "llvm.bswap", 10)) {
+        return ConstantInt::get(Op->getValue().byteSwap());
+      } else if (Len > 11 && !memcmp(Str, "llvm.ctpop", 10)) {
+        uint64_t ctpop = Op->getValue().countPopulation();
+        return ConstantInt::get(Ty, ctpop);
+      } else if (Len > 10 && !memcmp(Str, "llvm.cttz", 9)) {
+        uint64_t cttz = Op->getValue().countTrailingZeros();
+        return ConstantInt::get(Ty, cttz);
+      } else if (Len > 10 && !memcmp(Str, "llvm.ctlz", 9)) {
+        uint64_t ctlz = Op->getValue().countLeadingZeros();
+        return ConstantInt::get(Ty, ctlz);
+      }
     }
   } else if (NumOperands == 2) {
     if (ConstantFP *Op1 = dyn_cast<ConstantFP>(Operands[0])) {
-      double Op1V = Op1->getValue();
+      double Op1V = Ty==Type::FloatTy ? 
+                      (double)Op1->getValueAPF().convertToFloat():
+                      Op1->getValueAPF().convertToDouble();
       if (ConstantFP *Op2 = dyn_cast<ConstantFP>(Operands[1])) {
-        double Op2V = Op2->getValue();
+        if (Ty!=Type::FloatTy && Ty!=Type::DoubleTy)
+          return 0;
+        double Op2V = Ty==Type::FloatTy ? 
+                      (double)Op2->getValueAPF().convertToFloat():
+                      Op2->getValueAPF().convertToDouble();
 
-        if (Name == "pow") {
-          errno = 0;
-          double V = pow(Op1V, Op2V);
-          if (errno == 0)
-            return ConstantFP::get(Ty, V);
-        } else if (Name == "fmod") {
-          errno = 0;
-          double V = fmod(Op1V, Op2V);
-          if (errno == 0)
-            return ConstantFP::get(Ty, V);
-        } else if (Name == "atan2") {
-          return ConstantFP::get(Ty, atan2(Op1V,Op2V));
+        if (Len == 3 && !strcmp(Str, "pow")) {
+          return ConstantFoldBinaryFP(pow, Op1V, Op2V, Ty);
+        } else if (Len == 4 && !strcmp(Str, "fmod")) {
+          return ConstantFoldBinaryFP(fmod, Op1V, Op2V, Ty);
+        } else if (Len == 5 && !strcmp(Str, "atan2")) {
+          return ConstantFoldBinaryFP(atan2, Op1V, Op2V, Ty);
         }
       } else if (ConstantInt *Op2C = dyn_cast<ConstantInt>(Operands[1])) {
-        if (Name == "llvm.powi.f32") {
-          return ConstantFP::get(Ty, std::pow((float)Op1V,
-                                              (int)Op2C->getZExtValue()));
-        } else if (Name == "llvm.powi.f64") {
-          return ConstantFP::get(Ty, std::pow((double)Op1V,
-                                              (int)Op2C->getZExtValue()));
+        if (!strcmp(Str, "llvm.powi.f32")) {
+          return ConstantFP::get(Ty, APFloat((float)std::pow((float)Op1V,
+                                              (int)Op2C->getZExtValue())));
+        } else if (!strcmp(Str, "llvm.powi.f64")) {
+          return ConstantFP::get(Ty, APFloat((double)std::pow((double)Op1V,
+                                              (int)Op2C->getZExtValue())));
         }
       }
     }