[opaque pointer type] Pass explicit pointer type through GEP constant folding
[oota-llvm.git] / lib / IR / ConstantFold.cpp
index 2a524937391c5cddfafc5258685840ae6edd8fe8..8afb3e489de266927bfbc355eb817dc6787574e6 100644 (file)
@@ -2028,7 +2028,7 @@ static bool isIndexInRangeOfSequentialType(const SequentialType *STy,
 }
 
 template<typename IndexTy>
-static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
+static Constant *ConstantFoldGetElementPtrImpl(Type *PointeeTy, Constant *C,
                                                bool inBounds,
                                                ArrayRef<IndexTy> Idxs) {
   if (Idxs.empty()) return C;
@@ -2165,9 +2165,9 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
   // factored out into preceding dimensions.
   bool Unknown = false;
   SmallVector<Constant *, 8> NewIdxs;
-  Type *Ty = C->getType();
-  Type *Prev = nullptr;
-  for (unsigned i = 0, e = Idxs.size(); i != e;
+  Type *Ty = PointeeTy;
+  Type *Prev = C->getType();
+  for (unsigned i = 1, e = Idxs.size(); i != e;
        Prev = Ty, Ty = cast<CompositeType>(Ty)->getTypeAtIndex(Idxs[i]), ++i) {
     if (ConstantInt *CI = dyn_cast<ConstantInt>(Idxs[i])) {
       if (isa<ArrayType>(Ty) || isa<VectorType>(Ty))
@@ -2229,7 +2229,7 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
   if (!Unknown && !inBounds)
     if (auto *GV = dyn_cast<GlobalVariable>(C))
       if (!GV->hasExternalWeakLinkage() && isInBoundsIndices(Idxs))
-        return ConstantExpr::getInBoundsGetElementPtr(nullptr, C, Idxs);
+        return ConstantExpr::getInBoundsGetElementPtr(PointeeTy, C, Idxs);
 
   return nullptr;
 }
@@ -2237,11 +2237,27 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
 Constant *llvm::ConstantFoldGetElementPtr(Constant *C,
                                           bool inBounds,
                                           ArrayRef<Constant *> Idxs) {
-  return ConstantFoldGetElementPtrImpl(C, inBounds, Idxs);
+  return ConstantFoldGetElementPtrImpl(
+      cast<PointerType>(C->getType()->getScalarType())->getElementType(), C,
+      inBounds, Idxs);
 }
 
 Constant *llvm::ConstantFoldGetElementPtr(Constant *C,
                                           bool inBounds,
                                           ArrayRef<Value *> Idxs) {
-  return ConstantFoldGetElementPtrImpl(C, inBounds, Idxs);
+  return ConstantFoldGetElementPtrImpl(
+      cast<PointerType>(C->getType()->getScalarType())->getElementType(), C,
+      inBounds, Idxs);
+}
+
+Constant *llvm::ConstantFoldGetElementPtr(Type *Ty, Constant *C,
+                                          bool inBounds,
+                                          ArrayRef<Constant *> Idxs) {
+  return ConstantFoldGetElementPtrImpl(Ty, C, inBounds, Idxs);
+}
+
+Constant *llvm::ConstantFoldGetElementPtr(Type *Ty, Constant *C,
+                                          bool inBounds,
+                                          ArrayRef<Value *> Idxs) {
+  return ConstantFoldGetElementPtrImpl(Ty, C, inBounds, Idxs);
 }