Add comment as follow up to r245712
[oota-llvm.git] / lib / IR / ConstantFold.cpp
index 49ef3024033f98d88c5debedf7d213f8c95860c9..f63ce9bbf038a1fda1da0a987341727618729f24 100644 (file)
@@ -132,7 +132,8 @@ static Constant *FoldBitCast(Constant *V, Type *DestTy) {
 
         if (ElTy == DPTy->getElementType())
           // This GEP is inbounds because all indices are zero.
-          return ConstantExpr::getInBoundsGetElementPtr(V, IdxList);
+          return ConstantExpr::getInBoundsGetElementPtr(PTy->getElementType(),
+                                                        V, IdxList);
       }
 
   // Handle casts from one vector constant to another.  We know that the src 
@@ -169,7 +170,8 @@ static Constant *FoldBitCast(Constant *V, Type *DestTy) {
       // be the same. Consequently, we just fold to V.
       return V;
 
-    if (DestTy->isFloatingPointTy())
+    // See note below regarding the PPC_FP128 restriction.
+    if (DestTy->isFloatingPointTy() && !DestTy->isPPC_FP128Ty())
       return ConstantFP::get(DestTy->getContext(),
                              APFloat(DestTy->getFltSemantics(),
                                      CI->getValue()));
@@ -179,9 +181,19 @@ static Constant *FoldBitCast(Constant *V, Type *DestTy) {
   }
 
   // Handle ConstantFP input: FP -> Integral.
-  if (ConstantFP *FP = dyn_cast<ConstantFP>(V))
+  if (ConstantFP *FP = dyn_cast<ConstantFP>(V)) {
+    // PPC_FP128 is really the sum of two consecutive doubles, where the first
+    // double is always stored first in memory, regardless of the target
+    // endianness. The memory layout of i128, however, depends on the target
+    // endianness, and so we can't fold this without target endianness
+    // information. This should instead be handled by
+    // Analysis/ConstantFolding.cpp
+    if (FP->getType()->isPPC_FP128Ty())
+      return nullptr;
+
     return ConstantInt::get(FP->getContext(),
                             FP->getValueAPF().bitcastToAPInt());
+  }
 
   return nullptr;
 }
@@ -620,8 +632,8 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
     if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
       if (CE->getOpcode() == Instruction::GetElementPtr &&
           CE->getOperand(0)->isNullValue()) {
-        Type *Ty =
-          cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
+        GEPOperator *GEPO = cast<GEPOperator>(CE);
+        Type *Ty = GEPO->getSourceElementType();
         if (CE->getNumOperands() == 2) {
           // Handle a sizeof-like expression.
           Constant *Idx = CE->getOperand(1);
@@ -777,11 +789,10 @@ Constant *llvm::ConstantFoldExtractElementInstruction(Constant *Val,
     return UndefValue::get(Val->getType()->getVectorElementType());
 
   if (ConstantInt *CIdx = dyn_cast<ConstantInt>(Idx)) {
-    uint64_t Index = CIdx->getZExtValue();
     // ee({w,x,y,z}, wrong_value) -> undef
-    if (Index >= Val->getType()->getVectorNumElements())
+    if (CIdx->uge(Val->getType()->getVectorNumElements()))
       return UndefValue::get(Val->getType()->getVectorElementType());
-    return Val->getAggregateElement(Index);
+    return Val->getAggregateElement(CIdx->getZExtValue());
   }
   return nullptr;
 }
@@ -789,23 +800,30 @@ Constant *llvm::ConstantFoldExtractElementInstruction(Constant *Val,
 Constant *llvm::ConstantFoldInsertElementInstruction(Constant *Val,
                                                      Constant *Elt,
                                                      Constant *Idx) {
+  if (isa<UndefValue>(Idx))
+    return UndefValue::get(Val->getType());
+
   ConstantInt *CIdx = dyn_cast<ConstantInt>(Idx);
   if (!CIdx) return nullptr;
-  const APInt &IdxVal = CIdx->getValue();
-  
+
+  unsigned NumElts = Val->getType()->getVectorNumElements();
+  if (CIdx->uge(NumElts))
+    return UndefValue::get(Val->getType());
+
   SmallVector<Constant*, 16> Result;
-  Type *Ty = IntegerType::get(Val->getContext(), 32);
-  for (unsigned i = 0, e = Val->getType()->getVectorNumElements(); i != e; ++i){
+  Result.reserve(NumElts);
+  auto *Ty = Type::getInt32Ty(Val->getContext());
+  uint64_t IdxVal = CIdx->getZExtValue();
+  for (unsigned i = 0; i != NumElts; ++i) {    
     if (i == IdxVal) {
       Result.push_back(Elt);
       continue;
     }
     
-    Constant *C =
-      ConstantExpr::getExtractElement(Val, ConstantInt::get(Ty, i));
+    Constant *C = ConstantExpr::getExtractElement(Val, ConstantInt::get(Ty, i));
     Result.push_back(C);
   }
-  
+
   return ConstantVector::get(Result);
 }
 
@@ -915,12 +933,14 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
         return C1;
       return Constant::getNullValue(C1->getType());   // undef & X -> 0
     case Instruction::Mul: {
-      ConstantInt *CI;
-      // X * undef -> undef   if X is odd or undef
-      if (((CI = dyn_cast<ConstantInt>(C1)) && CI->getValue()[0]) ||
-          ((CI = dyn_cast<ConstantInt>(C2)) && CI->getValue()[0]) ||
-          (isa<UndefValue>(C1) && isa<UndefValue>(C2)))
-        return UndefValue::get(C1->getType());
+      // undef * undef -> undef
+      if (isa<UndefValue>(C1) && isa<UndefValue>(C2))
+        return C1;
+      const APInt *CV;
+      // X * undef -> undef   if X is odd
+      if (match(C1, m_APInt(CV)) || match(C2, m_APInt(CV)))
+        if ((*CV)[0])
+          return UndefValue::get(C1->getType());
 
       // X * undef -> 0       otherwise
       return Constant::getNullValue(C1->getType());
@@ -954,18 +974,28 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
       // X >>l undef -> undef
       if (isa<UndefValue>(C2))
         return C2;
+      // undef >>l 0 -> undef
+      if (match(C2, m_Zero()))
+        return C1;
       // undef >>l X -> 0
       return Constant::getNullValue(C1->getType());
     case Instruction::AShr:
       // X >>a undef -> undef
       if (isa<UndefValue>(C2))
         return C2;
-      // undef >>a X -> all ones
-      return Constant::getAllOnesValue(C1->getType());
+      // undef >>a 0 -> undef
+      if (match(C2, m_Zero()))
+        return C1;
+      // TODO: undef >>a X -> undef if the shift is exact
+      // undef >>a X -> 0
+      return Constant::getNullValue(C1->getType());
     case Instruction::Shl:
       // X << undef -> undef
       if (isa<UndefValue>(C2))
         return C2;
+      // undef << 0 -> undef
+      if (match(C2, m_Zero()))
+        return C1;
       // undef << X -> 0
       return Constant::getNullValue(C1->getType());
     }
@@ -1108,27 +1138,18 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
         return ConstantInt::get(CI1->getContext(), C1V | C2V);
       case Instruction::Xor:
         return ConstantInt::get(CI1->getContext(), C1V ^ C2V);
-      case Instruction::Shl: {
-        uint32_t shiftAmt = C2V.getZExtValue();
-        if (shiftAmt < C1V.getBitWidth())
-          return ConstantInt::get(CI1->getContext(), C1V.shl(shiftAmt));
-        else
-          return UndefValue::get(C1->getType()); // too big shift is undef
-      }
-      case Instruction::LShr: {
-        uint32_t shiftAmt = C2V.getZExtValue();
-        if (shiftAmt < C1V.getBitWidth())
-          return ConstantInt::get(CI1->getContext(), C1V.lshr(shiftAmt));
-        else
-          return UndefValue::get(C1->getType()); // too big shift is undef
-      }
-      case Instruction::AShr: {
-        uint32_t shiftAmt = C2V.getZExtValue();
-        if (shiftAmt < C1V.getBitWidth())
-          return ConstantInt::get(CI1->getContext(), C1V.ashr(shiftAmt));
-        else
-          return UndefValue::get(C1->getType()); // too big shift is undef
-      }
+      case Instruction::Shl:
+        if (C2V.ult(C1V.getBitWidth()))
+          return ConstantInt::get(CI1->getContext(), C1V.shl(C2V));
+        return UndefValue::get(C1->getType()); // too big shift is undef
+      case Instruction::LShr:
+        if (C2V.ult(C1V.getBitWidth()))
+          return ConstantInt::get(CI1->getContext(), C1V.lshr(C2V));
+        return UndefValue::get(C1->getType()); // too big shift is undef
+      case Instruction::AShr:
+        if (C2V.ult(C1V.getBitWidth()))
+          return ConstantInt::get(CI1->getContext(), C1V.ashr(C2V));
+        return UndefValue::get(C1->getType()); // too big shift is undef
       }
     }
 
@@ -1270,15 +1291,17 @@ static int IdxCompare(Constant *C1, Constant *C2, Type *ElTy) {
   if (!isa<ConstantInt>(C1) || !isa<ConstantInt>(C2))
     return -2; // don't know!
 
-  // Ok, we have two differing integer indices.  Sign extend them to be the same
-  // type.  Long is always big enough, so we use it.
-  if (!C1->getType()->isIntegerTy(64))
-    C1 = ConstantExpr::getSExt(C1, Type::getInt64Ty(C1->getContext()));
+  // We cannot compare the indices if they don't fit in an int64_t.
+  if (cast<ConstantInt>(C1)->getValue().getActiveBits() > 64 ||
+      cast<ConstantInt>(C2)->getValue().getActiveBits() > 64)
+    return -2; // don't know!
 
-  if (!C2->getType()->isIntegerTy(64))
-    C2 = ConstantExpr::getSExt(C2, Type::getInt64Ty(C1->getContext()));
+  // Ok, we have two differing integer indices.  Sign extend them to be the same
+  // type.
+  int64_t C1Val = cast<ConstantInt>(C1)->getSExtValue();
+  int64_t C2Val = cast<ConstantInt>(C2)->getSExtValue();
 
-  if (C1 == C2) return 0;  // They are equal
+  if (C1Val == C2Val) return 0;  // They are equal
 
   // If the type being indexed over is really just a zero sized type, there is
   // no pointer difference being made here.
@@ -1287,8 +1310,7 @@ static int IdxCompare(Constant *C1, Constant *C2, Type *ElTy) {
 
   // If they are really different, now that they are the same type, then we
   // found a difference!
-  if (cast<ConstantInt>(C1)->getSExtValue() < 
-      cast<ConstantInt>(C2)->getSExtValue())
+  if (C1Val < C2Val)
     return -1;
   else
     return 1;
@@ -1314,7 +1336,7 @@ static FCmpInst::Predicate evaluateFCmpRelation(Constant *V1, Constant *V2) {
 
   if (!isa<ConstantExpr>(V1)) {
     if (!isa<ConstantExpr>(V2)) {
-      // We distilled thisUse the standard constant folder for a few cases
+      // Simple case, use the standard constant folder.
       ConstantInt *R = nullptr;
       R = dyn_cast<ConstantInt>(
                       ConstantExpr::getFCmp(FCmpInst::FCMP_OEQ, V1, V2));
@@ -1363,7 +1385,7 @@ static ICmpInst::Predicate areGlobalsPotentiallyEqual(const GlobalValue *GV1,
     if (GV->hasExternalWeakLinkage() || GV->hasWeakAnyLinkage())
       return true;
     if (const auto *GVar = dyn_cast<GlobalVariable>(GV)) {
-      Type *Ty = GVar->getType()->getPointerElementType();
+      Type *Ty = GVar->getValueType();
       // A global with opaque type might end up being zero sized.
       if (!Ty->isSized())
         return true;
@@ -1652,15 +1674,22 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
 
   // Handle some degenerate cases first
   if (isa<UndefValue>(C1) || isa<UndefValue>(C2)) {
+    CmpInst::Predicate Predicate = CmpInst::Predicate(pred);
+    bool isIntegerPredicate = ICmpInst::isIntPredicate(Predicate);
     // For EQ and NE, we can always pick a value for the undef to make the
     // predicate pass or fail, so we can return undef.
-    // Also, if both operands are undef, we can return undef.
-    if (ICmpInst::isEquality(ICmpInst::Predicate(pred)) ||
-        (isa<UndefValue>(C1) && isa<UndefValue>(C2)))
+    // Also, if both operands are undef, we can return undef for int comparison.
+    if (ICmpInst::isEquality(Predicate) || (isIntegerPredicate && C1 == C2))
       return UndefValue::get(ResultTy);
-    // Otherwise, pick the same value as the non-undef operand, and fold
-    // it to true or false.
-    return ConstantInt::get(ResultTy, CmpInst::isTrueWhenEqual(pred));
+
+    // Otherwise, for integer compare, pick the same value as the non-undef
+    // operand, and fold it to true or false.
+    if (isIntegerPredicate)
+      return ConstantInt::get(ResultTy, CmpInst::isTrueWhenEqual(pred));
+
+    // Choosing NaN for the undef will always make unordered comparison succeed
+    // and ordered comparison fails.
+    return ConstantInt::get(ResultTy, CmpInst::isUnordered(Predicate));
   }
 
   // icmp eq/ne(null,GV) -> false/true
@@ -1776,7 +1805,10 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
     return ConstantVector::get(ResElts);
   }
 
-  if (C1->getType()->isFloatingPointTy()) {
+  if (C1->getType()->isFloatingPointTy() &&
+      // Only call evaluateFCmpRelation if we have a constant expr to avoid
+      // infinite recursive loop
+      (isa<ConstantExpr>(C1) || isa<ConstantExpr>(C2))) {
     int Result = -1;  // -1 = unknown, 0 = known false, 1 = known true.
     switch (evaluateFCmpRelation(C1, C2)) {
     default: llvm_unreachable("Unknown relation!");
@@ -1965,17 +1997,17 @@ static bool isInBoundsIndices(ArrayRef<IndexTy> Idxs) {
 }
 
 /// \brief Test whether a given ConstantInt is in-range for a SequentialType.
-static bool isIndexInRangeOfSequentialType(const SequentialType *STy,
+static bool isIndexInRangeOfSequentialType(SequentialType *STy,
                                            const ConstantInt *CI) {
-  if (const PointerType *PTy = dyn_cast<PointerType>(STy))
-    // Only handle pointers to sized types, not pointers to functions.
-    return PTy->getElementType()->isSized();
+  // And indicies are valid when indexing along a pointer
+  if (isa<PointerType>(STy))
+    return true;
 
   uint64_t NumElements = 0;
   // Determine the number of elements in our sequential type.
-  if (const ArrayType *ATy = dyn_cast<ArrayType>(STy))
+  if (auto *ATy = dyn_cast<ArrayType>(STy))
     NumElements = ATy->getNumElements();
-  else if (const VectorType *VTy = dyn_cast<VectorType>(STy))
+  else if (auto *VTy = dyn_cast<VectorType>(STy))
     NumElements = VTy->getNumElements();
 
   assert((isa<ArrayType>(STy) || NumElements > 0) &&
@@ -1996,7 +2028,7 @@ static bool isIndexInRangeOfSequentialType(const SequentialType *STy,
 }
 
 template<typename IndexTy>
-static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
+static Constant *ConstantFoldGetElementPtrImpl(Type *PointeeTy, Constant *C,
                                                bool inBounds,
                                                ArrayRef<IndexTy> Idxs) {
   if (Idxs.empty()) return C;
@@ -2006,7 +2038,8 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
 
   if (isa<UndefValue>(C)) {
     PointerType *Ptr = cast<PointerType>(C->getType());
-    Type *Ty = GetElementPtrInst::getIndexedType(Ptr, Idxs);
+    Type *Ty = GetElementPtrInst::getIndexedType(
+        cast<PointerType>(Ptr->getScalarType())->getElementType(), Idxs);
     assert(Ty && "Invalid indices for GEP!");
     return UndefValue::get(PointerType::get(Ty, Ptr->getAddressSpace()));
   }
@@ -2020,7 +2053,8 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
       }
     if (isNull) {
       PointerType *Ptr = cast<PointerType>(C->getType());
-      Type *Ty = GetElementPtrInst::getIndexedType(Ptr, Idxs);
+      Type *Ty = GetElementPtrInst::getIndexedType(
+          cast<PointerType>(Ptr->getScalarType())->getElementType(), Idxs);
       assert(Ty && "Invalid indices for GEP!");
       return ConstantPointerNull::get(PointerType::get(Ty,
                                                        Ptr->getAddressSpace()));
@@ -2066,8 +2100,7 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
       if (PerformFold) {
         SmallVector<Value*, 16> NewIndices;
         NewIndices.reserve(Idxs.size() + CE->getNumOperands());
-        for (unsigned i = 1, e = CE->getNumOperands()-1; i != e; ++i)
-          NewIndices.push_back(CE->getOperand(i));
+        NewIndices.append(CE->op_begin() + 1, CE->op_end() - 1);
 
         // Add the last index of the source with the first index of the new GEP.
         // Make sure to handle the case when they are actually different types.
@@ -2076,9 +2109,15 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
         if (!Idx0->isNullValue()) {
           Type *IdxTy = Combined->getType();
           if (IdxTy != Idx0->getType()) {
-            Type *Int64Ty = Type::getInt64Ty(IdxTy->getContext());
-            Constant *C1 = ConstantExpr::getSExtOrBitCast(Idx0, Int64Ty);
-            Constant *C2 = ConstantExpr::getSExtOrBitCast(Combined, Int64Ty);
+            unsigned CommonExtendedWidth =
+                std::max(IdxTy->getIntegerBitWidth(),
+                         Idx0->getType()->getIntegerBitWidth());
+            CommonExtendedWidth = std::max(CommonExtendedWidth, 64U);
+
+            Type *CommonTy =
+                Type::getIntNTy(IdxTy->getContext(), CommonExtendedWidth);
+            Constant *C1 = ConstantExpr::getSExtOrBitCast(Idx0, CommonTy);
+            Constant *C2 = ConstantExpr::getSExtOrBitCast(Combined, CommonTy);
             Combined = ConstantExpr::get(Instruction::Add, C1, C2);
           } else {
             Combined =
@@ -2088,10 +2127,9 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
 
         NewIndices.push_back(Combined);
         NewIndices.append(Idxs.begin() + 1, Idxs.end());
-        return
-          ConstantExpr::getGetElementPtr(CE->getOperand(0), NewIndices,
-                                         inBounds &&
-                                           cast<GEPOperator>(CE)->isInBounds());
+        return ConstantExpr::getGetElementPtr(
+            cast<GEPOperator>(CE)->getSourceElementType(), CE->getOperand(0),
+            NewIndices, inBounds && cast<GEPOperator>(CE)->isInBounds());
       }
     }
 
@@ -2116,8 +2154,8 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
         if (SrcArrayTy && DstArrayTy
             && SrcArrayTy->getElementType() == DstArrayTy->getElementType()
             && SrcPtrTy->getAddressSpace() == DstPtrTy->getAddressSpace())
-          return ConstantExpr::getGetElementPtr((Constant*)CE->getOperand(0),
-                                                Idxs, inBounds);
+          return ConstantExpr::getGetElementPtr(
+              SrcArrayTy, (Constant *)CE->getOperand(0), Idxs, inBounds);
       }
     }
   }
@@ -2125,11 +2163,11 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
   // Check to see if any array indices are not within the corresponding
   // notional array or vector bounds. If so, try to determine if they can be
   // factored out into preceding dimensions.
-  bool Unknown = false;
   SmallVector<Constant *, 8> NewIdxs;
-  Type *Ty = C->getType();
-  Type *Prev = nullptr;
-  for (unsigned i = 0, e = Idxs.size(); i != e;
+  Type *Ty = PointeeTy;
+  Type *Prev = C->getType();
+  bool Unknown = !isa<ConstantInt>(Idxs[0]);
+  for (unsigned i = 1, e = Idxs.size(); i != e;
        Prev = Ty, Ty = cast<CompositeType>(Ty)->getTypeAtIndex(Idxs[i]), ++i) {
     if (ConstantInt *CI = dyn_cast<ConstantInt>(Idxs[i])) {
       if (isa<ArrayType>(Ty) || isa<VectorType>(Ty))
@@ -2140,7 +2178,7 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
             // dimension.
             NewIdxs.resize(Idxs.size());
             uint64_t NumElements = 0;
-            if (const ArrayType *ATy = dyn_cast<ArrayType>(Ty))
+            if (auto *ATy = dyn_cast<ArrayType>(Ty))
               NumElements = ATy->getNumElements();
             else
               NumElements = cast<VectorType>(Ty)->getNumElements();
@@ -2151,14 +2189,20 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
             Constant *PrevIdx = cast<Constant>(Idxs[i-1]);
             Constant *Div = ConstantExpr::getSDiv(CI, Factor);
 
+            unsigned CommonExtendedWidth =
+                std::max(PrevIdx->getType()->getIntegerBitWidth(),
+                         Div->getType()->getIntegerBitWidth());
+            CommonExtendedWidth = std::max(CommonExtendedWidth, 64U);
+
             // Before adding, extend both operands to i64 to avoid
             // overflow trouble.
-            if (!PrevIdx->getType()->isIntegerTy(64))
-              PrevIdx = ConstantExpr::getSExt(PrevIdx,
-                                           Type::getInt64Ty(Div->getContext()));
-            if (!Div->getType()->isIntegerTy(64))
-              Div = ConstantExpr::getSExt(Div,
-                                          Type::getInt64Ty(Div->getContext()));
+            if (!PrevIdx->getType()->isIntegerTy(CommonExtendedWidth))
+              PrevIdx = ConstantExpr::getSExt(
+                  PrevIdx,
+                  Type::getIntNTy(Div->getContext(), CommonExtendedWidth));
+            if (!Div->getType()->isIntegerTy(CommonExtendedWidth))
+              Div = ConstantExpr::getSExt(
+                  Div, Type::getIntNTy(Div->getContext(), CommonExtendedWidth));
 
             NewIdxs[i-1] = ConstantExpr::getAdd(PrevIdx, Div);
           } else {
@@ -2177,7 +2221,7 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
   if (!NewIdxs.empty()) {
     for (unsigned i = 0, e = Idxs.size(); i != e; ++i)
       if (!NewIdxs[i]) NewIdxs[i] = cast<Constant>(Idxs[i]);
-    return ConstantExpr::getGetElementPtr(C, NewIdxs, inBounds);
+    return ConstantExpr::getGetElementPtr(PointeeTy, C, NewIdxs, inBounds);
   }
 
   // If all indices are known integers and normalized, we can do a simple
@@ -2185,7 +2229,7 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
   if (!Unknown && !inBounds)
     if (auto *GV = dyn_cast<GlobalVariable>(C))
       if (!GV->hasExternalWeakLinkage() && isInBoundsIndices(Idxs))
-        return ConstantExpr::getInBoundsGetElementPtr(C, Idxs);
+        return ConstantExpr::getInBoundsGetElementPtr(PointeeTy, C, Idxs);
 
   return nullptr;
 }
@@ -2193,11 +2237,27 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
 Constant *llvm::ConstantFoldGetElementPtr(Constant *C,
                                           bool inBounds,
                                           ArrayRef<Constant *> Idxs) {
-  return ConstantFoldGetElementPtrImpl(C, inBounds, Idxs);
+  return ConstantFoldGetElementPtrImpl(
+      cast<PointerType>(C->getType()->getScalarType())->getElementType(), C,
+      inBounds, Idxs);
 }
 
 Constant *llvm::ConstantFoldGetElementPtr(Constant *C,
                                           bool inBounds,
                                           ArrayRef<Value *> Idxs) {
-  return ConstantFoldGetElementPtrImpl(C, inBounds, Idxs);
+  return ConstantFoldGetElementPtrImpl(
+      cast<PointerType>(C->getType()->getScalarType())->getElementType(), C,
+      inBounds, Idxs);
+}
+
+Constant *llvm::ConstantFoldGetElementPtr(Type *Ty, Constant *C,
+                                          bool inBounds,
+                                          ArrayRef<Constant *> Idxs) {
+  return ConstantFoldGetElementPtrImpl(Ty, C, inBounds, Idxs);
+}
+
+Constant *llvm::ConstantFoldGetElementPtr(Type *Ty, Constant *C,
+                                          bool inBounds,
+                                          ArrayRef<Value *> Idxs) {
+  return ConstantFoldGetElementPtrImpl(Ty, C, inBounds, Idxs);
 }