Add comment as follow up to r245712
[oota-llvm.git] / lib / IR / ConstantFold.cpp
index 94a39441a1311fcb7399fc7b4e521e41c133d30a..f63ce9bbf038a1fda1da0a987341727618729f24 100644 (file)
@@ -132,7 +132,8 @@ static Constant *FoldBitCast(Constant *V, Type *DestTy) {
 
         if (ElTy == DPTy->getElementType())
           // This GEP is inbounds because all indices are zero.
-          return ConstantExpr::getInBoundsGetElementPtr(V, IdxList);
+          return ConstantExpr::getInBoundsGetElementPtr(PTy->getElementType(),
+                                                        V, IdxList);
       }
 
   // Handle casts from one vector constant to another.  We know that the src 
@@ -631,8 +632,8 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
     if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
       if (CE->getOpcode() == Instruction::GetElementPtr &&
           CE->getOperand(0)->isNullValue()) {
-        Type *Ty =
-          cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
+        GEPOperator *GEPO = cast<GEPOperator>(CE);
+        Type *Ty = GEPO->getSourceElementType();
         if (CE->getNumOperands() == 2) {
           // Handle a sizeof-like expression.
           Constant *Idx = CE->getOperand(1);
@@ -788,11 +789,10 @@ Constant *llvm::ConstantFoldExtractElementInstruction(Constant *Val,
     return UndefValue::get(Val->getType()->getVectorElementType());
 
   if (ConstantInt *CIdx = dyn_cast<ConstantInt>(Idx)) {
-    uint64_t Index = CIdx->getZExtValue();
     // ee({w,x,y,z}, wrong_value) -> undef
-    if (Index >= Val->getType()->getVectorNumElements())
+    if (CIdx->uge(Val->getType()->getVectorNumElements()))
       return UndefValue::get(Val->getType()->getVectorElementType());
-    return Val->getAggregateElement(Index);
+    return Val->getAggregateElement(CIdx->getZExtValue());
   }
   return nullptr;
 }
@@ -800,23 +800,30 @@ Constant *llvm::ConstantFoldExtractElementInstruction(Constant *Val,
 Constant *llvm::ConstantFoldInsertElementInstruction(Constant *Val,
                                                      Constant *Elt,
                                                      Constant *Idx) {
+  if (isa<UndefValue>(Idx))
+    return UndefValue::get(Val->getType());
+
   ConstantInt *CIdx = dyn_cast<ConstantInt>(Idx);
   if (!CIdx) return nullptr;
-  const APInt &IdxVal = CIdx->getValue();
-  
+
+  unsigned NumElts = Val->getType()->getVectorNumElements();
+  if (CIdx->uge(NumElts))
+    return UndefValue::get(Val->getType());
+
   SmallVector<Constant*, 16> Result;
-  Type *Ty = IntegerType::get(Val->getContext(), 32);
-  for (unsigned i = 0, e = Val->getType()->getVectorNumElements(); i != e; ++i){
+  Result.reserve(NumElts);
+  auto *Ty = Type::getInt32Ty(Val->getContext());
+  uint64_t IdxVal = CIdx->getZExtValue();
+  for (unsigned i = 0; i != NumElts; ++i) {    
     if (i == IdxVal) {
       Result.push_back(Elt);
       continue;
     }
     
-    Constant *C =
-      ConstantExpr::getExtractElement(Val, ConstantInt::get(Ty, i));
+    Constant *C = ConstantExpr::getExtractElement(Val, ConstantInt::get(Ty, i));
     Result.push_back(C);
   }
-  
+
   return ConstantVector::get(Result);
 }
 
@@ -1378,7 +1385,7 @@ static ICmpInst::Predicate areGlobalsPotentiallyEqual(const GlobalValue *GV1,
     if (GV->hasExternalWeakLinkage() || GV->hasWeakAnyLinkage())
       return true;
     if (const auto *GVar = dyn_cast<GlobalVariable>(GV)) {
-      Type *Ty = GVar->getType()->getPointerElementType();
+      Type *Ty = GVar->getValueType();
       // A global with opaque type might end up being zero sized.
       if (!Ty->isSized())
         return true;
@@ -1990,17 +1997,17 @@ static bool isInBoundsIndices(ArrayRef<IndexTy> Idxs) {
 }
 
 /// \brief Test whether a given ConstantInt is in-range for a SequentialType.
-static bool isIndexInRangeOfSequentialType(const SequentialType *STy,
+static bool isIndexInRangeOfSequentialType(SequentialType *STy,
                                            const ConstantInt *CI) {
-  if (const PointerType *PTy = dyn_cast<PointerType>(STy))
-    // Only handle pointers to sized types, not pointers to functions.
-    return PTy->getElementType()->isSized();
+  // And indicies are valid when indexing along a pointer
+  if (isa<PointerType>(STy))
+    return true;
 
   uint64_t NumElements = 0;
   // Determine the number of elements in our sequential type.
-  if (const ArrayType *ATy = dyn_cast<ArrayType>(STy))
+  if (auto *ATy = dyn_cast<ArrayType>(STy))
     NumElements = ATy->getNumElements();
-  else if (const VectorType *VTy = dyn_cast<VectorType>(STy))
+  else if (auto *VTy = dyn_cast<VectorType>(STy))
     NumElements = VTy->getNumElements();
 
   assert((isa<ArrayType>(STy) || NumElements > 0) &&
@@ -2021,7 +2028,7 @@ static bool isIndexInRangeOfSequentialType(const SequentialType *STy,
 }
 
 template<typename IndexTy>
-static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
+static Constant *ConstantFoldGetElementPtrImpl(Type *PointeeTy, Constant *C,
                                                bool inBounds,
                                                ArrayRef<IndexTy> Idxs) {
   if (Idxs.empty()) return C;
@@ -2120,10 +2127,9 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
 
         NewIndices.push_back(Combined);
         NewIndices.append(Idxs.begin() + 1, Idxs.end());
-        return
-          ConstantExpr::getGetElementPtr(CE->getOperand(0), NewIndices,
-                                         inBounds &&
-                                           cast<GEPOperator>(CE)->isInBounds());
+        return ConstantExpr::getGetElementPtr(
+            cast<GEPOperator>(CE)->getSourceElementType(), CE->getOperand(0),
+            NewIndices, inBounds && cast<GEPOperator>(CE)->isInBounds());
       }
     }
 
@@ -2148,8 +2154,8 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
         if (SrcArrayTy && DstArrayTy
             && SrcArrayTy->getElementType() == DstArrayTy->getElementType()
             && SrcPtrTy->getAddressSpace() == DstPtrTy->getAddressSpace())
-          return ConstantExpr::getGetElementPtr((Constant*)CE->getOperand(0),
-                                                Idxs, inBounds);
+          return ConstantExpr::getGetElementPtr(
+              SrcArrayTy, (Constant *)CE->getOperand(0), Idxs, inBounds);
       }
     }
   }
@@ -2157,11 +2163,11 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
   // Check to see if any array indices are not within the corresponding
   // notional array or vector bounds. If so, try to determine if they can be
   // factored out into preceding dimensions.
-  bool Unknown = false;
   SmallVector<Constant *, 8> NewIdxs;
-  Type *Ty = C->getType();
-  Type *Prev = nullptr;
-  for (unsigned i = 0, e = Idxs.size(); i != e;
+  Type *Ty = PointeeTy;
+  Type *Prev = C->getType();
+  bool Unknown = !isa<ConstantInt>(Idxs[0]);
+  for (unsigned i = 1, e = Idxs.size(); i != e;
        Prev = Ty, Ty = cast<CompositeType>(Ty)->getTypeAtIndex(Idxs[i]), ++i) {
     if (ConstantInt *CI = dyn_cast<ConstantInt>(Idxs[i])) {
       if (isa<ArrayType>(Ty) || isa<VectorType>(Ty))
@@ -2172,7 +2178,7 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
             // dimension.
             NewIdxs.resize(Idxs.size());
             uint64_t NumElements = 0;
-            if (const ArrayType *ATy = dyn_cast<ArrayType>(Ty))
+            if (auto *ATy = dyn_cast<ArrayType>(Ty))
               NumElements = ATy->getNumElements();
             else
               NumElements = cast<VectorType>(Ty)->getNumElements();
@@ -2215,7 +2221,7 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
   if (!NewIdxs.empty()) {
     for (unsigned i = 0, e = Idxs.size(); i != e; ++i)
       if (!NewIdxs[i]) NewIdxs[i] = cast<Constant>(Idxs[i]);
-    return ConstantExpr::getGetElementPtr(C, NewIdxs, inBounds);
+    return ConstantExpr::getGetElementPtr(PointeeTy, C, NewIdxs, inBounds);
   }
 
   // If all indices are known integers and normalized, we can do a simple
@@ -2223,7 +2229,7 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
   if (!Unknown && !inBounds)
     if (auto *GV = dyn_cast<GlobalVariable>(C))
       if (!GV->hasExternalWeakLinkage() && isInBoundsIndices(Idxs))
-        return ConstantExpr::getInBoundsGetElementPtr(C, Idxs);
+        return ConstantExpr::getInBoundsGetElementPtr(PointeeTy, C, Idxs);
 
   return nullptr;
 }
@@ -2231,11 +2237,27 @@ static Constant *ConstantFoldGetElementPtrImpl(Constant *C,
 Constant *llvm::ConstantFoldGetElementPtr(Constant *C,
                                           bool inBounds,
                                           ArrayRef<Constant *> Idxs) {
-  return ConstantFoldGetElementPtrImpl(C, inBounds, Idxs);
+  return ConstantFoldGetElementPtrImpl(
+      cast<PointerType>(C->getType()->getScalarType())->getElementType(), C,
+      inBounds, Idxs);
 }
 
 Constant *llvm::ConstantFoldGetElementPtr(Constant *C,
                                           bool inBounds,
                                           ArrayRef<Value *> Idxs) {
-  return ConstantFoldGetElementPtrImpl(C, inBounds, Idxs);
+  return ConstantFoldGetElementPtrImpl(
+      cast<PointerType>(C->getType()->getScalarType())->getElementType(), C,
+      inBounds, Idxs);
+}
+
+Constant *llvm::ConstantFoldGetElementPtr(Type *Ty, Constant *C,
+                                          bool inBounds,
+                                          ArrayRef<Constant *> Idxs) {
+  return ConstantFoldGetElementPtrImpl(Ty, C, inBounds, Idxs);
+}
+
+Constant *llvm::ConstantFoldGetElementPtr(Type *Ty, Constant *C,
+                                          bool inBounds,
+                                          ArrayRef<Value *> Idxs) {
+  return ConstantFoldGetElementPtrImpl(Ty, C, inBounds, Idxs);
 }