[NaryReassociate] allow candidate to have a different type
[oota-llvm.git] / lib / Transforms / Scalar / NaryReassociate.cpp
index 930552d2f9043855e2625f8a482d31953d553644..c8f885e7eec53d095e1f802b42eb33b3f90364b2 100644 (file)
@@ -421,19 +421,20 @@ GetElementPtrInst *NaryReassociate::tryReassociateGEPAtIndex(
       GEP->getSourceElementType(), SE->getSCEV(GEP->getPointerOperand()),
       IndexExprs, GEP->isInBounds());
 
-  auto *Candidate = findClosestMatchingDominator(CandidateExpr, GEP);
+  Value *Candidate = findClosestMatchingDominator(CandidateExpr, GEP);
   if (Candidate == nullptr)
     return nullptr;
 
-  PointerType *TypeOfCandidate = dyn_cast<PointerType>(Candidate->getType());
-  // Pretty rare but theoretically possible when a numeric value happens to
-  // share CandidateExpr.
-  if (TypeOfCandidate == nullptr)
-    return nullptr;
+  IRBuilder<> Builder(GEP);
+  // Candidate does not necessarily have the same pointer type as GEP. Use
+  // bitcast or pointer cast to make sure they have the same type, so that the
+  // later RAUW doesn't complain.
+  Candidate = Builder.CreateBitOrPointerCast(Candidate, GEP->getType());
+  assert(Candidate->getType() == GEP->getType());
 
   // NewGEP = (char *)Candidate + RHS * sizeof(IndexedType)
   uint64_t IndexedSize = DL->getTypeAllocSize(IndexedType);
-  Type *ElementType = TypeOfCandidate->getElementType();
+  Type *ElementType = GEP->getType()->getElementType();
   uint64_t ElementSize = DL->getTypeAllocSize(ElementType);
   // Another less rare case: because I is not necessarily the last index of the
   // GEP, the size of the type at the I-th index (IndexedSize) is not
@@ -453,8 +454,7 @@ GetElementPtrInst *NaryReassociate::tryReassociateGEPAtIndex(
     return nullptr;
 
   // NewGEP = &Candidate[RHS * (sizeof(IndexedType) / sizeof(Candidate[0])));
-  IRBuilder<> Builder(GEP);
-  Type *IntPtrTy = DL->getIntPtrType(TypeOfCandidate);
+  Type *IntPtrTy = DL->getIntPtrType(GEP->getType());
   if (RHS->getType() != IntPtrTy)
     RHS = Builder.CreateSExtOrTrunc(RHS, IntPtrTy);
   if (IndexedSize != ElementSize) {