Remove an unnecessary use of pointee types introduced in r194220
[oota-llvm.git] / lib / IR / ConstantFold.cpp
index 612aba01897926e1b3fb0135c1e6d049f173b532..131c5c51790414c6b1c81e27aaa3a1710c22d042 100644 (file)
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Operator.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/ManagedStatic.h"
 #include "llvm/Support/MathExtras.h"
 #include <limits>
 using namespace llvm;
+using namespace llvm::PatternMatch;
 
 //===----------------------------------------------------------------------===//
 //                ConstantFold*Instruction Implementations
@@ -51,7 +53,7 @@ static Constant *BitCastConstantVector(Constant *CV, VectorType *DstTy) {
   // Analysis/ConstantFolding.cpp
   unsigned NumElts = DstTy->getNumElements();
   if (NumElts != CV->getType()->getVectorNumElements())
-    return 0;
+    return nullptr;
   
   Type *DstEltTy = DstTy->getElementType();
 
@@ -94,7 +96,7 @@ foldConstantCastPair(
 
   // Let CastInst::isEliminableCastPair do the heavy lifting.
   return CastInst::isEliminableCastPair(firstOp, secondOp, SrcTy, MidTy, DstTy,
-                                        0, FakeIntPtrTy, 0);
+                                        nullptr, FakeIntPtrTy, nullptr);
 }
 
 static Constant *FoldBitCast(Constant *V, Type *DestTy) {
@@ -130,7 +132,8 @@ static Constant *FoldBitCast(Constant *V, Type *DestTy) {
 
         if (ElTy == DPTy->getElementType())
           // This GEP is inbounds because all indices are zero.
-          return ConstantExpr::getInBoundsGetElementPtr(V, IdxList);
+          return ConstantExpr::getInBoundsGetElementPtr(PTy->getElementType(),
+                                                        V, IdxList);
       }
 
   // Handle casts from one vector constant to another.  We know that the src 
@@ -139,7 +142,7 @@ static Constant *FoldBitCast(Constant *V, Type *DestTy) {
     if (VectorType *SrcTy = dyn_cast<VectorType>(V->getType())) {
       assert(DestPTy->getBitWidth() == SrcTy->getBitWidth() &&
              "Not cast between same sized vectors!");
-      SrcTy = NULL;
+      SrcTy = nullptr;
       // First, check for null.  Undef is already handled.
       if (isa<ConstantAggregateZero>(V))
         return Constant::getNullValue(DestTy);
@@ -167,21 +170,32 @@ static Constant *FoldBitCast(Constant *V, Type *DestTy) {
       // be the same. Consequently, we just fold to V.
       return V;
 
-    if (DestTy->isFloatingPointTy())
+    // See note below regarding the PPC_FP128 restriction.
+    if (DestTy->isFloatingPointTy() && !DestTy->isPPC_FP128Ty())
       return ConstantFP::get(DestTy->getContext(),
                              APFloat(DestTy->getFltSemantics(),
                                      CI->getValue()));
 
     // Otherwise, can't fold this (vector?)
-    return 0;
+    return nullptr;
   }
 
   // Handle ConstantFP input: FP -> Integral.
-  if (ConstantFP *FP = dyn_cast<ConstantFP>(V))
+  if (ConstantFP *FP = dyn_cast<ConstantFP>(V)) {
+    // PPC_FP128 is really the sum of two consecutive doubles, where the first
+    // double is always stored first in memory, regardless of the target
+    // endianness. The memory layout of i128, however, depends on the target
+    // endianness, and so we can't fold this without target endianness
+    // information. This should instead be handled by
+    // Analysis/ConstantFolding.cpp
+    if (FP->getType()->isPPC_FP128Ty())
+      return nullptr;
+
     return ConstantInt::get(FP->getContext(),
                             FP->getValueAPF().bitcastToAPInt());
+  }
 
-  return 0;
+  return nullptr;
 }
 
 
@@ -216,14 +230,14 @@ static Constant *ExtractConstantBytes(Constant *C, unsigned ByteStart,
   // In the input is a constant expr, we might be able to recursively simplify.
   // If not, we definitely can't do anything.
   ConstantExpr *CE = dyn_cast<ConstantExpr>(C);
-  if (CE == 0) return 0;
-  
+  if (!CE) return nullptr;
+
   switch (CE->getOpcode()) {
-  default: return 0;
+  default: return nullptr;
   case Instruction::Or: {
     Constant *RHS = ExtractConstantBytes(CE->getOperand(1), ByteStart,ByteSize);
-    if (RHS == 0)
-      return 0;
+    if (!RHS)
+      return nullptr;
     
     // X | -1 -> -1.
     if (ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS))
@@ -231,32 +245,32 @@ static Constant *ExtractConstantBytes(Constant *C, unsigned ByteStart,
         return RHSC;
     
     Constant *LHS = ExtractConstantBytes(CE->getOperand(0), ByteStart,ByteSize);
-    if (LHS == 0)
-      return 0;
+    if (!LHS)
+      return nullptr;
     return ConstantExpr::getOr(LHS, RHS);
   }
   case Instruction::And: {
     Constant *RHS = ExtractConstantBytes(CE->getOperand(1), ByteStart,ByteSize);
-    if (RHS == 0)
-      return 0;
+    if (!RHS)
+      return nullptr;
     
     // X & 0 -> 0.
     if (RHS->isNullValue())
       return RHS;
     
     Constant *LHS = ExtractConstantBytes(CE->getOperand(0), ByteStart,ByteSize);
-    if (LHS == 0)
-      return 0;
+    if (!LHS)
+      return nullptr;
     return ConstantExpr::getAnd(LHS, RHS);
   }
   case Instruction::LShr: {
     ConstantInt *Amt = dyn_cast<ConstantInt>(CE->getOperand(1));
-    if (Amt == 0)
-      return 0;
+    if (!Amt)
+      return nullptr;
     unsigned ShAmt = Amt->getZExtValue();
     // Cannot analyze non-byte shifts.
     if ((ShAmt & 7) != 0)
-      return 0;
+      return nullptr;
     ShAmt >>= 3;
     
     // If the extract is known to be all zeros, return zero.
@@ -268,17 +282,17 @@ static Constant *ExtractConstantBytes(Constant *C, unsigned ByteStart,
       return ExtractConstantBytes(CE->getOperand(0), ByteStart+ShAmt, ByteSize);
     
     // TODO: Handle the 'partially zero' case.
-    return 0;
+    return nullptr;
   }
     
   case Instruction::Shl: {
     ConstantInt *Amt = dyn_cast<ConstantInt>(CE->getOperand(1));
-    if (Amt == 0)
-      return 0;
+    if (!Amt)
+      return nullptr;
     unsigned ShAmt = Amt->getZExtValue();
     // Cannot analyze non-byte shifts.
     if ((ShAmt & 7) != 0)
-      return 0;
+      return nullptr;
     ShAmt >>= 3;
     
     // If the extract is known to be all zeros, return zero.
@@ -290,7 +304,7 @@ static Constant *ExtractConstantBytes(Constant *C, unsigned ByteStart,
       return ExtractConstantBytes(CE->getOperand(0), ByteStart-ShAmt, ByteSize);
     
     // TODO: Handle the 'partially zero' case.
-    return 0;
+    return nullptr;
   }
       
   case Instruction::ZExt: {
@@ -324,7 +338,7 @@ static Constant *ExtractConstantBytes(Constant *C, unsigned ByteStart,
     }
     
     // TODO: Handle the 'partially zero' case.
-    return 0;
+    return nullptr;
   }
   }
 }
@@ -376,7 +390,7 @@ static Constant *getFoldedSizeOf(Type *Ty, Type *DestTy,
   // If there's no interesting folding happening, bail so that we don't create
   // a constant that looks like it needs folding but really doesn't.
   if (!Folded)
-    return 0;
+    return nullptr;
 
   // Base case: Get a regular sizeof expression.
   Constant *C = ConstantExpr::getSizeOf(Ty);
@@ -442,7 +456,7 @@ static Constant *getFoldedAlignOf(Type *Ty, Type *DestTy,
   // If there's no interesting folding happening, bail so that we don't create
   // a constant that looks like it needs folding but really doesn't.
   if (!Folded)
-    return 0;
+    return nullptr;
 
   // Base case: Get a regular alignof expression.
   Constant *C = ConstantExpr::getAlignOf(Ty);
@@ -473,7 +487,7 @@ static Constant *getFoldedOffsetOf(Type *Ty, Constant *FieldNo,
       unsigned NumElems = STy->getNumElements();
       // An empty struct has no members.
       if (NumElems == 0)
-        return 0;
+        return nullptr;
       // Check for a struct with all members having the same size.
       Constant *MemberSize =
         getFoldedSizeOf(STy->getElementType(0), DestTy, true);
@@ -497,7 +511,7 @@ static Constant *getFoldedOffsetOf(Type *Ty, Constant *FieldNo,
   // If there's no interesting folding happening, bail so that we don't create
   // a constant that looks like it needs folding but really doesn't.
   if (!Folded)
-    return 0;
+    return nullptr;
 
   // Base case: Get a regular offsetof expression.
   Constant *C = ConstantExpr::getOffsetOf(Ty, FieldNo);
@@ -529,7 +543,10 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
       // Try hard to fold cast of cast because they are often eliminable.
       if (unsigned newOpc = foldConstantCastPair(opc, CE, DestTy))
         return ConstantExpr::getCast(newOpc, CE->getOperand(0), DestTy);
-    } else if (CE->getOpcode() == Instruction::GetElementPtr) {
+    } else if (CE->getOpcode() == Instruction::GetElementPtr &&
+               // Do not fold addrspacecast (gep 0, .., 0). It might make the
+               // addrspacecast uncanonicalized.
+               opc != Instruction::AddrSpaceCast) {
       // If all of the indexes in the GEP are null values, there is no pointer
       // adjustment going on.  We might as well cast the source pointer.
       bool isAllNull = true;
@@ -582,7 +599,7 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
                   APFloat::rmNearestTiesToEven, &ignored);
       return ConstantFP::get(V->getContext(), Val);
     }
-    return 0; // Can't fold.
+    return nullptr; // Can't fold.
   case Instruction::FPToUI: 
   case Instruction::FPToSI:
     if (ConstantFP *FPC = dyn_cast<ConstantFP>(V)) {
@@ -590,16 +607,21 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
       bool ignored;
       uint64_t x[2]; 
       uint32_t DestBitWidth = cast<IntegerType>(DestTy)->getBitWidth();
-      (void) V.convertToInteger(x, DestBitWidth, opc==Instruction::FPToSI,
-                                APFloat::rmTowardZero, &ignored);
+      if (APFloat::opInvalidOp ==
+          V.convertToInteger(x, DestBitWidth, opc==Instruction::FPToSI,
+                             APFloat::rmTowardZero, &ignored)) {
+        // Undefined behavior invoked - the destination type can't represent
+        // the input constant.
+        return UndefValue::get(DestTy);
+      }
       APInt Val(DestBitWidth, x);
       return ConstantInt::get(FPC->getContext(), Val);
     }
-    return 0; // Can't fold.
+    return nullptr; // Can't fold.
   case Instruction::IntToPtr:   //always treated as unsigned
     if (V->isNullValue())       // Is it an integral null value?
       return ConstantPointerNull::get(cast<PointerType>(DestTy));
-    return 0;                   // Other pointer types cannot be casted
+    return nullptr;                   // Other pointer types cannot be casted
   case Instruction::PtrToInt:   // always treated as unsigned
     // Is it a null pointer value?
     if (V->isNullValue())
@@ -610,8 +632,8 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
     if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
       if (CE->getOpcode() == Instruction::GetElementPtr &&
           CE->getOperand(0)->isNullValue()) {
-        Type *Ty =
-          cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
+        GEPOperator *GEPO = cast<GEPOperator>(CE);
+        Type *Ty = GEPO->getSourceElementType();
         if (CE->getNumOperands() == 2) {
           // Handle a sizeof-like expression.
           Constant *Idx = CE->getOperand(1);
@@ -643,34 +665,41 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
         }
       }
     // Other pointer types cannot be casted
-    return 0;
+    return nullptr;
   case Instruction::UIToFP:
   case Instruction::SIToFP:
     if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
       APInt api = CI->getValue();
       APFloat apf(DestTy->getFltSemantics(),
                   APInt::getNullValue(DestTy->getPrimitiveSizeInBits()));
-      (void)apf.convertFromAPInt(api, 
-                                 opc==Instruction::SIToFP,
-                                 APFloat::rmNearestTiesToEven);
+      if (APFloat::opOverflow &
+          apf.convertFromAPInt(api, opc==Instruction::SIToFP,
+                              APFloat::rmNearestTiesToEven)) {
+        // Undefined behavior invoked - the destination type can't represent
+        // the input constant.
+        return UndefValue::get(DestTy);
+      }
       return ConstantFP::get(V->getContext(), apf);
     }
-    return 0;
+    return nullptr;
   case Instruction::ZExt:
     if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
       uint32_t BitWidth = cast<IntegerType>(DestTy)->getBitWidth();
       return ConstantInt::get(V->getContext(),
                               CI->getValue().zext(BitWidth));
     }
-    return 0;
+    return nullptr;
   case Instruction::SExt:
     if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
       uint32_t BitWidth = cast<IntegerType>(DestTy)->getBitWidth();
       return ConstantInt::get(V->getContext(),
                               CI->getValue().sext(BitWidth));
     }
-    return 0;
+    return nullptr;
   case Instruction::Trunc: {
+    if (V->getType()->isVectorTy())
+      return nullptr;
+
     uint32_t DestBitWidth = cast<IntegerType>(DestTy)->getBitWidth();
     if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
       return ConstantInt::get(V->getContext(),
@@ -685,12 +714,12 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
       if (Constant *Res = ExtractConstantBytes(V, 0, DestBitWidth / 8))
         return Res;
       
-    return 0;
+    return nullptr;
   }
   case Instruction::BitCast:
     return FoldBitCast(V, DestTy);
   case Instruction::AddrSpaceCast:
-    return 0;
+    return nullptr;
   }
 }
 
@@ -746,7 +775,7 @@ Constant *llvm::ConstantFoldSelectInstruction(Constant *Cond,
         return ConstantExpr::getSelect(Cond, V1, FalseVal->getOperand(2));
   }
 
-  return 0;
+  return nullptr;
 }
 
 Constant *llvm::ConstantFoldExtractElementInstruction(Constant *Val,
@@ -760,35 +789,41 @@ Constant *llvm::ConstantFoldExtractElementInstruction(Constant *Val,
     return UndefValue::get(Val->getType()->getVectorElementType());
 
   if (ConstantInt *CIdx = dyn_cast<ConstantInt>(Idx)) {
-    uint64_t Index = CIdx->getZExtValue();
     // ee({w,x,y,z}, wrong_value) -> undef
-    if (Index >= Val->getType()->getVectorNumElements())
+    if (CIdx->uge(Val->getType()->getVectorNumElements()))
       return UndefValue::get(Val->getType()->getVectorElementType());
-    return Val->getAggregateElement(Index);
+    return Val->getAggregateElement(CIdx->getZExtValue());
   }
-  return 0;
+  return nullptr;
 }
 
 Constant *llvm::ConstantFoldInsertElementInstruction(Constant *Val,
                                                      Constant *Elt,
                                                      Constant *Idx) {
+  if (isa<UndefValue>(Idx))
+    return UndefValue::get(Val->getType());
+
   ConstantInt *CIdx = dyn_cast<ConstantInt>(Idx);
-  if (!CIdx) return 0;
-  const APInt &IdxVal = CIdx->getValue();
-  
+  if (!CIdx) return nullptr;
+
+  unsigned NumElts = Val->getType()->getVectorNumElements();
+  if (CIdx->uge(NumElts))
+    return UndefValue::get(Val->getType());
+
   SmallVector<Constant*, 16> Result;
-  Type *Ty = IntegerType::get(Val->getContext(), 32);
-  for (unsigned i = 0, e = Val->getType()->getVectorNumElements(); i != e; ++i){
+  Result.reserve(NumElts);
+  auto *Ty = Type::getInt32Ty(Val->getContext());
+  uint64_t IdxVal = CIdx->getZExtValue();
+  for (unsigned i = 0; i != NumElts; ++i) {    
     if (i == IdxVal) {
       Result.push_back(Elt);
       continue;
     }
     
-    Constant *C =
-      ConstantExpr::getExtractElement(Val, ConstantInt::get(Ty, i));
+    Constant *C = ConstantExpr::getExtractElement(Val, ConstantInt::get(Ty, i));
     Result.push_back(C);
   }
-  
+
   return ConstantVector::get(Result);
 }
 
@@ -803,7 +838,7 @@ Constant *llvm::ConstantFoldShuffleVectorInstruction(Constant *V1,
     return UndefValue::get(VectorType::get(EltTy, MaskNumElts));
 
   // Don't break the bitcode reader hack.
-  if (isa<ConstantExpr>(Mask)) return 0;
+  if (isa<ConstantExpr>(Mask)) return nullptr;
   
   unsigned SrcNumElts = V1->getType()->getVectorNumElements();
 
@@ -842,7 +877,7 @@ Constant *llvm::ConstantFoldExtractValueInstruction(Constant *Agg,
   if (Constant *C = Agg->getAggregateElement(Idxs[0]))
     return ConstantFoldExtractValueInstruction(C, Idxs.slice(1));
 
-  return 0;
+  return nullptr;
 }
 
 Constant *llvm::ConstantFoldInsertValueInstruction(Constant *Agg,
@@ -863,8 +898,8 @@ Constant *llvm::ConstantFoldInsertValueInstruction(Constant *Agg,
   SmallVector<Constant*, 32> Result;
   for (unsigned i = 0; i != NumElts; ++i) {
     Constant *C = Agg->getAggregateElement(i);
-    if (C == 0) return 0;
-    
+    if (!C) return nullptr;
+
     if (Idxs[0] == i)
       C = ConstantFoldInsertValueInstruction(C, Val, Idxs.slice(1));
     
@@ -898,49 +933,70 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
         return C1;
       return Constant::getNullValue(C1->getType());   // undef & X -> 0
     case Instruction::Mul: {
-      ConstantInt *CI;
-      // X * undef -> undef   if X is odd or undef
-      if (((CI = dyn_cast<ConstantInt>(C1)) && CI->getValue()[0]) ||
-          ((CI = dyn_cast<ConstantInt>(C2)) && CI->getValue()[0]) ||
-          (isa<UndefValue>(C1) && isa<UndefValue>(C2)))
-        return UndefValue::get(C1->getType());
+      // undef * undef -> undef
+      if (isa<UndefValue>(C1) && isa<UndefValue>(C2))
+        return C1;
+      const APInt *CV;
+      // X * undef -> undef   if X is odd
+      if (match(C1, m_APInt(CV)) || match(C2, m_APInt(CV)))
+        if ((*CV)[0])
+          return UndefValue::get(C1->getType());
 
       // X * undef -> 0       otherwise
       return Constant::getNullValue(C1->getType());
     }
-    case Instruction::UDiv:
     case Instruction::SDiv:
+    case Instruction::UDiv:
+      // X / undef -> undef
+      if (match(C1, m_Zero()))
+        return C2;
+      // undef / 0 -> undef
       // undef / 1 -> undef
-      if (Opcode == Instruction::UDiv || Opcode == Instruction::SDiv)
-        if (ConstantInt *CI2 = dyn_cast<ConstantInt>(C2))
-          if (CI2->isOne())
-            return C1;
-      // FALL THROUGH
+      if (match(C2, m_Zero()) || match(C2, m_One()))
+        return C1;
+      // undef / X -> 0       otherwise
+      return Constant::getNullValue(C1->getType());
     case Instruction::URem:
     case Instruction::SRem:
-      if (!isa<UndefValue>(C2))                    // undef / X -> 0
-        return Constant::getNullValue(C1->getType());
-      return C2;                                   // X / undef -> undef
+      // X % undef -> undef
+      if (match(C2, m_Undef()))
+        return C2;
+      // undef % 0 -> undef
+      if (match(C2, m_Zero()))
+        return C1;
+      // undef % X -> 0       otherwise
+      return Constant::getNullValue(C1->getType());
     case Instruction::Or:                          // X | undef -> -1
       if (isa<UndefValue>(C1) && isa<UndefValue>(C2)) // undef | undef -> undef
         return C1;
       return Constant::getAllOnesValue(C1->getType()); // undef | X -> ~0
     case Instruction::LShr:
-      if (isa<UndefValue>(C2) && isa<UndefValue>(C1))
-        return C1;                                  // undef lshr undef -> undef
-      return Constant::getNullValue(C1->getType()); // X lshr undef -> 0
-                                                    // undef lshr X -> 0
+      // X >>l undef -> undef
+      if (isa<UndefValue>(C2))
+        return C2;
+      // undef >>l 0 -> undef
+      if (match(C2, m_Zero()))
+        return C1;
+      // undef >>l X -> 0
+      return Constant::getNullValue(C1->getType());
     case Instruction::AShr:
-      if (!isa<UndefValue>(C2))                     // undef ashr X --> all ones
-        return Constant::getAllOnesValue(C1->getType());
-      else if (isa<UndefValue>(C1)) 
-        return C1;                                  // undef ashr undef -> undef
-      else
-        return C1;                                  // X ashr undef --> X
+      // X >>a undef -> undef
+      if (isa<UndefValue>(C2))
+        return C2;
+      // undef >>a 0 -> undef
+      if (match(C2, m_Zero()))
+        return C1;
+      // TODO: undef >>a X -> undef if the shift is exact
+      // undef >>a X -> 0
+      return Constant::getNullValue(C1->getType());
     case Instruction::Shl:
-      if (isa<UndefValue>(C2) && isa<UndefValue>(C1))
-        return C1;                                  // undef shl undef -> undef
-      // undef << X -> 0   or   X << undef -> 0
+      // X << undef -> undef
+      if (isa<UndefValue>(C2))
+        return C2;
+      // undef << 0 -> undef
+      if (match(C2, m_Zero()))
+        return C1;
+      // undef << X -> 0
       return Constant::getNullValue(C1->getType());
     }
   }
@@ -1082,27 +1138,18 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
         return ConstantInt::get(CI1->getContext(), C1V | C2V);
       case Instruction::Xor:
         return ConstantInt::get(CI1->getContext(), C1V ^ C2V);
-      case Instruction::Shl: {
-        uint32_t shiftAmt = C2V.getZExtValue();
-        if (shiftAmt < C1V.getBitWidth())
-          return ConstantInt::get(CI1->getContext(), C1V.shl(shiftAmt));
-        else
-          return UndefValue::get(C1->getType()); // too big shift is undef
-      }
-      case Instruction::LShr: {
-        uint32_t shiftAmt = C2V.getZExtValue();
-        if (shiftAmt < C1V.getBitWidth())
-          return ConstantInt::get(CI1->getContext(), C1V.lshr(shiftAmt));
-        else
-          return UndefValue::get(C1->getType()); // too big shift is undef
-      }
-      case Instruction::AShr: {
-        uint32_t shiftAmt = C2V.getZExtValue();
-        if (shiftAmt < C1V.getBitWidth())
-          return ConstantInt::get(CI1->getContext(), C1V.ashr(shiftAmt));
-        else
-          return UndefValue::get(C1->getType()); // too big shift is undef
-      }
+      case Instruction::Shl:
+        if (C2V.ult(C1V.getBitWidth()))
+          return ConstantInt::get(CI1->getContext(), C1V.shl(C2V));
+        return UndefValue::get(C1->getType()); // too big shift is undef
+      case Instruction::LShr:
+        if (C2V.ult(C1V.getBitWidth()))
+          return ConstantInt::get(CI1->getContext(), C1V.lshr(C2V));
+        return UndefValue::get(C1->getType()); // too big shift is undef
+      case Instruction::AShr:
+        if (C2V.ult(C1V.getBitWidth()))
+          return ConstantInt::get(CI1->getContext(), C1V.ashr(C2V));
+        return UndefValue::get(C1->getType()); // too big shift is undef
       }
     }
 
@@ -1209,7 +1256,7 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
   }
 
   // We don't know how to fold this.
-  return 0;
+  return nullptr;
 }
 
 /// isZeroSizedType - This type is zero sized if its an array or structure of
@@ -1244,15 +1291,17 @@ static int IdxCompare(Constant *C1, Constant *C2, Type *ElTy) {
   if (!isa<ConstantInt>(C1) || !isa<ConstantInt>(C2))
     return -2; // don't know!
 
-  // Ok, we have two differing integer indices.  Sign extend them to be the same
-  // type.  Long is always big enough, so we use it.
-  if (!C1->getType()->isIntegerTy(64))
-    C1 = ConstantExpr::getSExt(C1, Type::getInt64Ty(C1->getContext()));
+  // We cannot compare the indices if they don't fit in an int64_t.
+  if (cast<ConstantInt>(C1)->getValue().getActiveBits() > 64 ||
+      cast<ConstantInt>(C2)->getValue().getActiveBits() > 64)
+    return -2; // don't know!
 
-  if (!C2->getType()->isIntegerTy(64))
-    C2 = ConstantExpr::getSExt(C2, Type::getInt64Ty(C1->getContext()));
+  // Ok, we have two differing integer indices.  Sign extend them to be the same
+  // type.
+  int64_t C1Val = cast<ConstantInt>(C1)->getSExtValue();
+  int64_t C2Val = cast<ConstantInt>(C2)->getSExtValue();
 
-  if (C1 == C2) return 0;  // They are equal
+  if (C1Val == C2Val) return 0;  // They are equal
 
   // If the type being indexed over is really just a zero sized type, there is
   // no pointer difference being made here.
@@ -1261,8 +1310,7 @@ static int IdxCompare(Constant *C1, Constant *C2, Type *ElTy) {
 
   // If they are really different, now that they are the same type, then we
   // found a difference!
-  if (cast<ConstantInt>(C1)->getSExtValue() < 
-      cast<ConstantInt>(C2)->getSExtValue())
+  if (C1Val < C2Val)
     return -1;
   else
     return 1;
@@ -1288,8 +1336,8 @@ static FCmpInst::Predicate evaluateFCmpRelation(Constant *V1, Constant *V2) {
 
   if (!isa<ConstantExpr>(V1)) {
     if (!isa<ConstantExpr>(V2)) {
-      // We distilled thisUse the standard constant folder for a few cases
-      ConstantInt *R = 0;
+      // Simple case, use the standard constant folder.
+      ConstantInt *R = nullptr;
       R = dyn_cast<ConstantInt>(
                       ConstantExpr::getFCmp(FCmpInst::FCMP_OEQ, V1, V2));
       if (R && !R->isZero()) 
@@ -1331,6 +1379,30 @@ static FCmpInst::Predicate evaluateFCmpRelation(Constant *V1, Constant *V2) {
   return FCmpInst::BAD_FCMP_PREDICATE;
 }
 
+static ICmpInst::Predicate areGlobalsPotentiallyEqual(const GlobalValue *GV1,
+                                                      const GlobalValue *GV2) {
+  auto isGlobalUnsafeForEquality = [](const GlobalValue *GV) {
+    if (GV->hasExternalWeakLinkage() || GV->hasWeakAnyLinkage())
+      return true;
+    if (const auto *GVar = dyn_cast<GlobalVariable>(GV)) {
+      Type *Ty = GVar->getValueType();
+      // A global with opaque type might end up being zero sized.
+      if (!Ty->isSized())
+        return true;
+      // A global with an empty type might lie at the address of any other
+      // global.
+      if (Ty->isEmptyTy())
+        return true;
+    }
+    return false;
+  };
+  // Don't try to decide equality of aliases.
+  if (!isa<GlobalAlias>(GV1) && !isa<GlobalAlias>(GV2))
+    if (!isGlobalUnsafeForEquality(GV1) && !isGlobalUnsafeForEquality(GV2))
+      return ICmpInst::ICMP_NE;
+  return ICmpInst::BAD_ICMP_PREDICATE;
+}
+
 /// evaluateICmpRelation - This function determines if there is anything we can
 /// decide about the two constants provided.  This doesn't need to handle simple
 /// things like integer comparisons, but should instead handle ConstantExprs
@@ -1355,7 +1427,7 @@ static ICmpInst::Predicate evaluateICmpRelation(Constant *V1, Constant *V2,
         !isa<BlockAddress>(V2)) {
       // We distilled this down to a simple case, use the standard constant
       // folder.
-      ConstantInt *R = 0;
+      ConstantInt *R = nullptr;
       ICmpInst::Predicate pred = ICmpInst::ICMP_EQ;
       R = dyn_cast<ConstantInt>(ConstantExpr::getICmp(pred, V1, V2));
       if (R && !R->isZero()) 
@@ -1392,10 +1464,7 @@ static ICmpInst::Predicate evaluateICmpRelation(Constant *V1, Constant *V2,
     // constant (which, since the types must match, means that it's a
     // ConstantPointerNull).
     if (const GlobalValue *GV2 = dyn_cast<GlobalValue>(V2)) {
-      // Don't try to decide equality of aliases.
-      if (!isa<GlobalAlias>(GV) && !isa<GlobalAlias>(GV2))
-        if (!GV->hasExternalWeakLinkage() || !GV2->hasExternalWeakLinkage())
-          return ICmpInst::ICMP_NE;
+      return areGlobalsPotentiallyEqual(GV, GV2);
     } else if (isa<BlockAddress>(V2)) {
       return ICmpInst::ICMP_NE; // Globals never equal labels.
     } else {
@@ -1460,7 +1529,8 @@ static ICmpInst::Predicate evaluateICmpRelation(Constant *V1, Constant *V2,
       }
       break;
 
-    case Instruction::GetElementPtr:
+    case Instruction::GetElementPtr: {
+      GEPOperator *CE1GEP = cast<GEPOperator>(CE1);
       // Ok, since this is a getelementptr, we know that the constant has a
       // pointer type.  Check the various cases.
       if (isa<ConstantPointerNull>(V2)) {
@@ -1507,7 +1577,8 @@ static ICmpInst::Predicate evaluateICmpRelation(Constant *V1, Constant *V2,
                    "Surprising getelementptr!");
             return isSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
           } else {
-            // If they are different globals, we don't know what the value is.
+            if (CE1GEP->hasAllZeroIndices())
+              return areGlobalsPotentiallyEqual(GV, GV2);
             return ICmpInst::BAD_ICMP_PREDICATE;
           }
         }
@@ -1523,8 +1594,14 @@ static ICmpInst::Predicate evaluateICmpRelation(Constant *V1, Constant *V2,
           // By far the most common case to handle is when the base pointers are
           // obviously to the same global.
           if (isa<GlobalValue>(CE1Op0) && isa<GlobalValue>(CE2Op0)) {
-            if (CE1Op0 != CE2Op0) // Don't know relative ordering.
+            // Don't know relative ordering, but check for inequality.
+            if (CE1Op0 != CE2Op0) {
+              GEPOperator *CE2GEP = cast<GEPOperator>(CE2);
+              if (CE1GEP->hasAllZeroIndices() && CE2GEP->hasAllZeroIndices())
+                return areGlobalsPotentiallyEqual(cast<GlobalValue>(CE1Op0),
+                                                  cast<GlobalValue>(CE2Op0));
               return ICmpInst::BAD_ICMP_PREDICATE;
+            }
             // Ok, we know that both getelementptr instructions are based on the
             // same global.  From this, we can precisely determine the relative
             // ordering of the resultant pointers.
@@ -1570,6 +1647,7 @@ static ICmpInst::Predicate evaluateICmpRelation(Constant *V1, Constant *V2,
           }
         }
       }
+    }
     default:
       break;
     }
@@ -1596,15 +1674,22 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
 
   // Handle some degenerate cases first
   if (isa<UndefValue>(C1) || isa<UndefValue>(C2)) {
+    CmpInst::Predicate Predicate = CmpInst::Predicate(pred);
+    bool isIntegerPredicate = ICmpInst::isIntPredicate(Predicate);
     // For EQ and NE, we can always pick a value for the undef to make the
     // predicate pass or fail, so we can return undef.
-    // Also, if both operands are undef, we can return undef.
-    if (ICmpInst::isEquality(ICmpInst::Predicate(pred)) ||
-        (isa<UndefValue>(C1) && isa<UndefValue>(C2)))
+    // Also, if both operands are undef, we can return undef for int comparison.
+    if (ICmpInst::isEquality(Predicate) || (isIntegerPredicate && C1 == C2))
       return UndefValue::get(ResultTy);
-    // Otherwise, pick the same value as the non-undef operand, and fold
-    // it to true or false.
-    return ConstantInt::get(ResultTy, CmpInst::isTrueWhenEqual(pred));
+
+    // Otherwise, for integer compare, pick the same value as the non-undef
+    // operand, and fold it to true or false.
+    if (isIntegerPredicate)
+      return ConstantInt::get(ResultTy, CmpInst::isTrueWhenEqual(pred));
+
+    // Choosing NaN for the undef will always make unordered comparison succeed
+    // and ordered comparison fails.
+    return ConstantInt::get(ResultTy, CmpInst::isUnordered(Predicate));
   }
 
   // icmp eq/ne(null,GV) -> false/true
@@ -1720,7 +1805,10 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
     return ConstantVector::get(ResElts);
   }
 
-  if (C1->getType()->isFloatingPointTy()) {
+  if (C1->getType()->isFloatingPointTy() &&
+      // Only call evaluateFCmpRelation if we have a constant expr to avoid
+      // infinite recursive loop
+      (isa<ConstantExpr>(C1) || isa<ConstantExpr>(C2))) {
     int Result = -1;  // -1 = unknown, 0 = known false, 1 = known true.
     switch (evaluateFCmpRelation(C1, C2)) {
     default: llvm_unreachable("Unknown relation!");
@@ -1885,7 +1973,7 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
       return ConstantExpr::getICmp(pred, C2, C1);
     }
   }
-  return 0;
+  return nullptr;
 }
 
 /// isInBoundsIndices - Test whether the given sequence of *normalized* indices
@@ -1909,17 +1997,16 @@ static bool isInBoundsIndices(ArrayRef<IndexTy> Idxs) {
 }
 
 /// \brief Test whether a given ConstantInt is in-range for a SequentialType.
-static bool isIndexInRangeOfSequentialType(const SequentialType *STy,
+static bool isIndexInRangeOfSequentialType(SequentialType *STy,
                                            const ConstantInt *CI) {
-  if (const PointerType *PTy = dyn_cast<PointerType>(STy))
-    // Only handle pointers to sized types, not pointers to functions.
-    return PTy->getElementType()->isSized();
+  if (isa<PointerType>(STy))
+    return true;
 
   uint64_t NumElements = 0;
   // Determine the number of elements in our sequential type.
-  if (const ArrayType *ATy = dyn_cast<ArrayType>(STy))
+  if (auto *ATy = dyn_cast<ArrayType>(STy))
     NumElements = ATy->getNumElements();
-  else if (const VectorType *VTy = dyn_cast<VectorType>(STy))
+  else if (auto *VTy = dyn_cast<VectorType>(STy))
     NumElements = VTy->getNumElements();
 
   assert((isa<ArrayType>(STy) || NumElements > 0) &&
@@ -1940,7 +2027,7 @@ static bool isIndexInRangeOfSequentialType(const SequentialType *STy,
 }
 
 template<typename IndexTy>
-static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
+static Constant *ConstantFoldGetElementPtrImpl(Type *PointeeTy, Constant *C,
                                                bool inBounds,
                                                ArrayRef<IndexTy> Idxs) {
   if (Idxs.empty()) return C;
@@ -1950,8 +2037,9 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
 
   if (isa<UndefValue>(C)) {
     PointerType *Ptr = cast<PointerType>(C->getType());
-    Type *Ty = GetElementPtrInst::getIndexedType(Ptr, Idxs);
-    assert(Ty != 0 && "Invalid indices for GEP!");
+    Type *Ty = GetElementPtrInst::getIndexedType(
+        cast<PointerType>(Ptr->getScalarType())->getElementType(), Idxs);
+    assert(Ty && "Invalid indices for GEP!");
     return UndefValue::get(PointerType::get(Ty, Ptr->getAddressSpace()));
   }
 
@@ -1964,8 +2052,9 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
       }
     if (isNull) {
       PointerType *Ptr = cast<PointerType>(C->getType());
-      Type *Ty = GetElementPtrInst::getIndexedType(Ptr, Idxs);
-      assert(Ty != 0 && "Invalid indices for GEP!");
+      Type *Ty = GetElementPtrInst::getIndexedType(
+          cast<PointerType>(Ptr->getScalarType())->getElementType(), Idxs);
+      assert(Ty && "Invalid indices for GEP!");
       return ConstantPointerNull::get(PointerType::get(Ty,
                                                        Ptr->getAddressSpace()));
     }
@@ -1977,7 +2066,7 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
     // getelementptr instructions into a single instruction.
     //
     if (CE->getOpcode() == Instruction::GetElementPtr) {
-      Type *LastTy = 0;
+      Type *LastTy = nullptr;
       for (gep_type_iterator I = gep_type_begin(CE), E = gep_type_end(CE);
            I != E; ++I)
         LastTy = *I;
@@ -2010,8 +2099,7 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
       if (PerformFold) {
         SmallVector<Value*, 16> NewIndices;
         NewIndices.reserve(Idxs.size() + CE->getNumOperands());
-        for (unsigned i = 1, e = CE->getNumOperands()-1; i != e; ++i)
-          NewIndices.push_back(CE->getOperand(i));
+        NewIndices.append(CE->op_begin() + 1, CE->op_end() - 1);
 
         // Add the last index of the source with the first index of the new GEP.
         // Make sure to handle the case when they are actually different types.
@@ -2020,9 +2108,15 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
         if (!Idx0->isNullValue()) {
           Type *IdxTy = Combined->getType();
           if (IdxTy != Idx0->getType()) {
-            Type *Int64Ty = Type::getInt64Ty(IdxTy->getContext());
-            Constant *C1 = ConstantExpr::getSExtOrBitCast(Idx0, Int64Ty);
-            Constant *C2 = ConstantExpr::getSExtOrBitCast(Combined, Int64Ty);
+            unsigned CommonExtendedWidth =
+                std::max(IdxTy->getIntegerBitWidth(),
+                         Idx0->getType()->getIntegerBitWidth());
+            CommonExtendedWidth = std::max(CommonExtendedWidth, 64U);
+
+            Type *CommonTy =
+                Type::getIntNTy(IdxTy->getContext(), CommonExtendedWidth);
+            Constant *C1 = ConstantExpr::getSExtOrBitCast(Idx0, CommonTy);
+            Constant *C2 = ConstantExpr::getSExtOrBitCast(Combined, CommonTy);
             Combined = ConstantExpr::get(Instruction::Add, C1, C2);
           } else {
             Combined =
@@ -2032,10 +2126,9 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
 
         NewIndices.push_back(Combined);
         NewIndices.append(Idxs.begin() + 1, Idxs.end());
-        return
-          ConstantExpr::getGetElementPtr(CE->getOperand(0), NewIndices,
-                                         inBounds &&
-                                           cast<GEPOperator>(CE)->isInBounds());
+        return ConstantExpr::getGetElementPtr(
+            cast<GEPOperator>(CE)->getSourceElementType(), CE->getOperand(0),
+            NewIndices, inBounds && cast<GEPOperator>(CE)->isInBounds());
       }
     }
 
@@ -2060,8 +2153,8 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
         if (SrcArrayTy && DstArrayTy
             && SrcArrayTy->getElementType() == DstArrayTy->getElementType()
             && SrcPtrTy->getAddressSpace() == DstPtrTy->getAddressSpace())
-          return ConstantExpr::getGetElementPtr((Constant*)CE->getOperand(0),
-                                                Idxs, inBounds);
+          return ConstantExpr::getGetElementPtr(
+              SrcArrayTy, (Constant *)CE->getOperand(0), Idxs, inBounds);
       }
     }
   }
@@ -2069,11 +2162,11 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
   // Check to see if any array indices are not within the corresponding
   // notional array or vector bounds. If so, try to determine if they can be
   // factored out into preceding dimensions.
-  bool Unknown = false;
   SmallVector<Constant *, 8> NewIdxs;
-  Type *Ty = C->getType();
-  Type *Prev = 0;
-  for (unsigned i = 0, e = Idxs.size(); i != e;
+  Type *Ty = PointeeTy;
+  Type *Prev = C->getType();
+  bool Unknown = !isa<ConstantInt>(Idxs[0]);
+  for (unsigned i = 1, e = Idxs.size(); i != e;
        Prev = Ty, Ty = cast<CompositeType>(Ty)->getTypeAtIndex(Idxs[i]), ++i) {
     if (ConstantInt *CI = dyn_cast<ConstantInt>(Idxs[i])) {
       if (isa<ArrayType>(Ty) || isa<VectorType>(Ty))
@@ -2084,7 +2177,7 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
             // dimension.
             NewIdxs.resize(Idxs.size());
             uint64_t NumElements = 0;
-            if (const ArrayType *ATy = dyn_cast<ArrayType>(Ty))
+            if (auto *ATy = dyn_cast<ArrayType>(Ty))
               NumElements = ATy->getNumElements();
             else
               NumElements = cast<VectorType>(Ty)->getNumElements();
@@ -2095,14 +2188,20 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
             Constant *PrevIdx = cast<Constant>(Idxs[i-1]);
             Constant *Div = ConstantExpr::getSDiv(CI, Factor);
 
+            unsigned CommonExtendedWidth =
+                std::max(PrevIdx->getType()->getIntegerBitWidth(),
+                         Div->getType()->getIntegerBitWidth());
+            CommonExtendedWidth = std::max(CommonExtendedWidth, 64U);
+
             // Before adding, extend both operands to i64 to avoid
             // overflow trouble.
-            if (!PrevIdx->getType()->isIntegerTy(64))
-              PrevIdx = ConstantExpr::getSExt(PrevIdx,
-                                           Type::getInt64Ty(Div->getContext()));
-            if (!Div->getType()->isIntegerTy(64))
-              Div = ConstantExpr::getSExt(Div,
-                                          Type::getInt64Ty(Div->getContext()));
+            if (!PrevIdx->getType()->isIntegerTy(CommonExtendedWidth))
+              PrevIdx = ConstantExpr::getSExt(
+                  PrevIdx,
+                  Type::getIntNTy(Div->getContext(), CommonExtendedWidth));
+            if (!Div->getType()->isIntegerTy(CommonExtendedWidth))
+              Div = ConstantExpr::getSExt(
+                  Div, Type::getIntNTy(Div->getContext(), CommonExtendedWidth));
 
             NewIdxs[i-1] = ConstantExpr::getAdd(PrevIdx, Div);
           } else {
@@ -2121,26 +2220,43 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
   if (!NewIdxs.empty()) {
     for (unsigned i = 0, e = Idxs.size(); i != e; ++i)
       if (!NewIdxs[i]) NewIdxs[i] = cast<Constant>(Idxs[i]);
-    return ConstantExpr::getGetElementPtr(C, NewIdxs, inBounds);
+    return ConstantExpr::getGetElementPtr(PointeeTy, C, NewIdxs, inBounds);
   }
 
   // If all indices are known integers and normalized, we can do a simple
   // check for the "inbounds" property.
-  if (!Unknown && !inBounds &&
-      isa<GlobalVariable>(C) && isInBoundsIndices(Idxs))
-    return ConstantExpr::getInBoundsGetElementPtr(C, Idxs);
+  if (!Unknown && !inBounds)
+    if (auto *GV = dyn_cast<GlobalVariable>(C))
+      if (!GV->hasExternalWeakLinkage() && isInBoundsIndices(Idxs))
+        return ConstantExpr::getInBoundsGetElementPtr(PointeeTy, C, Idxs);
 
-  return 0;
+  return nullptr;
 }
 
 Constant *llvm::ConstantFoldGetElementPtr(Constant *C,
                                           bool inBounds,
                                           ArrayRef<Constant *> Idxs) {
-  return ConstantFoldGetElementPtrImpl(C, inBounds, Idxs);
+  return ConstantFoldGetElementPtrImpl(
+      cast<PointerType>(C->getType()->getScalarType())->getElementType(), C,
+      inBounds, Idxs);
 }
 
 Constant *llvm::ConstantFoldGetElementPtr(Constant *C,
                                           bool inBounds,
                                           ArrayRef<Value *> Idxs) {
-  return ConstantFoldGetElementPtrImpl(C, inBounds, Idxs);
+  return ConstantFoldGetElementPtrImpl(
+      cast<PointerType>(C->getType()->getScalarType())->getElementType(), C,
+      inBounds, Idxs);
+}
+
+Constant *llvm::ConstantFoldGetElementPtr(Type *Ty, Constant *C,
+                                          bool inBounds,
+                                          ArrayRef<Constant *> Idxs) {
+  return ConstantFoldGetElementPtrImpl(Ty, C, inBounds, Idxs);
+}
+
+Constant *llvm::ConstantFoldGetElementPtr(Type *Ty, Constant *C,
+                                          bool inBounds,
+                                          ArrayRef<Value *> Idxs) {
+  return ConstantFoldGetElementPtrImpl(Ty, C, inBounds, Idxs);
 }