[RewriteStatepointsForGC] Extend base pointer inference to handle insertelement
[oota-llvm.git] / lib / Transforms / Scalar / RewriteStatepointsForGC.cpp
index 031f40e2caa6c0a11adc2700080893b2ce22d0d9..c70619cf27c5490522aa1e13dfabdc2a0c8cd27c 100644 (file)
@@ -328,7 +328,7 @@ static BaseDefiningValueResult findBaseDefiningValue(Value *I);
 /// If the later, the return pointer is a BDV (or possibly a base) for the
 /// particular element in 'I'.  
 static BaseDefiningValueResult
-findBaseDefiningValueOfVector(Value *I, Value *Index = nullptr) {
+findBaseDefiningValueOfVector(Value *I) {
   assert(I->getType()->isVectorTy() &&
          cast<VectorType>(I->getType())->getElementType()->isPointerTy() &&
          "Illegal to ask for the base pointer of a non-pointer type");
@@ -362,35 +362,12 @@ findBaseDefiningValueOfVector(Value *I, Value *Index = nullptr) {
   
   if (isa<LoadInst>(I))
     return BaseDefiningValueResult(I, true);
-  
-  // For an insert element, we might be able to look through it if we know
-  // something about the indexes.
-  if (InsertElementInst *IEI = dyn_cast<InsertElementInst>(I)) {
-    if (Index) {
-      Value *InsertIndex = IEI->getOperand(2);
-      // This index is inserting the value, look for its BDV
-      if (InsertIndex == Index)
-        return findBaseDefiningValue(IEI->getOperand(1));
-      // Both constant, and can't be equal per above. This insert is definitely
-      // not relevant, look back at the rest of the vector and keep trying.
-      if (isa<ConstantInt>(Index) && isa<ConstantInt>(InsertIndex))
-        return findBaseDefiningValueOfVector(IEI->getOperand(0), Index);
-    }
 
-    // If both inputs to the insertelement are known bases, then so is the
-    // insertelement itself.  NOTE: This should be handled within the generic
-    // base pointer inference code and after http://reviews.llvm.org/D12583,
-    // will be.  However, when strengthening asserts I needed to add this to
-    // keep an existing test passing which was 'working'. FIXME
-    if (findBaseDefiningValue(IEI->getOperand(0)).IsKnownBase &&
-        findBaseDefiningValue(IEI->getOperand(1)).IsKnownBase)
-      return BaseDefiningValueResult(IEI, true);
-    
+  if (isa<InsertElementInst>(I))
     // We don't know whether this vector contains entirely base pointers or
     // not.  To be conservatively correct, we treat it as a BDV and will
     // duplicate code as needed to construct a parallel vector of bases.
-    return BaseDefiningValueResult(IEI, false);
-  }
+    return BaseDefiningValueResult(I, false);
 
   if (isa<ShuffleVectorInst>(I))
     // We don't know whether this vector contains entirely base pointers or
@@ -528,27 +505,11 @@ static BaseDefiningValueResult findBaseDefiningValue(Value *I) {
   // We may need to insert a parallel instruction to extract the appropriate
   // element out of the base vector corresponding to the input. Given this,
   // it's analogous to the phi and select case even though it's not a merge.
-  if (auto *EEI = dyn_cast<ExtractElementInst>(I)) {
-    Value *VectorOperand = EEI->getVectorOperand();
-    Value *Index = EEI->getIndexOperand();
-    auto VecResult = findBaseDefiningValueOfVector(VectorOperand, Index);
-    Value *VectorBase = VecResult.BDV;
-    if (VectorBase->getType()->isPointerTy())
-      // We found a BDV for this specific element with the vector.  This is an
-      // optimization, but in practice it covers most of the useful cases
-      // created via scalarization. Note: The peephole optimization here is
-      // currently needed for correctness since the general algorithm doesn't
-      // yet handle insertelements.  That will change shortly.
-      return BaseDefiningValueResult(VectorBase, VecResult.IsKnownBase);
-    else {
-      assert(VectorBase->getType()->isVectorTy());
-      // Otherwise, we have an instruction which potentially produces a
-      // derived pointer and we need findBasePointers to clone code for us
-      // such that we can create an instruction which produces the
-      // accompanying base pointer.
-      return BaseDefiningValueResult(I, VecResult.IsKnownBase);
-    }
-  }
+  if (isa<ExtractElementInst>(I))
+    // Note: There a lot of obvious peephole cases here.  This are deliberately
+    // handled after the main base pointer inference algorithm to make writing
+    // test cases to exercise that code easier.
+    return BaseDefiningValueResult(I, false);
 
   // The last two cases here don't return a base pointer.  Instead, they
   // return a value which dynamically selects from among several base
@@ -587,7 +548,9 @@ static Value *findBaseOrBDV(Value *I, DefiningValueMapTy &Cache) {
 /// Given the result of a call to findBaseDefiningValue, or findBaseOrBDV,
 /// is it known to be a base pointer?  Or do we need to continue searching.
 static bool isKnownBaseResult(Value *V) {
-  if (!isa<PHINode>(V) && !isa<SelectInst>(V) && !isa<ExtractElementInst>(V)) {
+  if (!isa<PHINode>(V) && !isa<SelectInst>(V) &&
+      !isa<ExtractElementInst>(V) && !isa<InsertElementInst>(V) &&
+      !isa<ShuffleVectorInst>(V)) {
     // no recursion possible
     return true;
   }
@@ -755,7 +718,8 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
 
 #ifndef NDEBUG
   auto isExpectedBDVType = [](Value *BDV) {
-    return isa<PHINode>(BDV) || isa<SelectInst>(BDV) || isa<ExtractElementInst>(BDV);
+    return isa<PHINode>(BDV) || isa<SelectInst>(BDV) ||
+           isa<ExtractElementInst>(BDV) || isa<InsertElementInst>(BDV);
   };
 #endif
 
@@ -795,10 +759,12 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
         visitIncomingValue(Sel->getFalseValue());
       } else if (auto *EE = dyn_cast<ExtractElementInst>(Current)) {
         visitIncomingValue(EE->getVectorOperand());
+      } else if (auto *IE = dyn_cast<InsertElementInst>(Current)) {
+        visitIncomingValue(IE->getOperand(0)); // vector operand
+        visitIncomingValue(IE->getOperand(1)); // scalar operand
       } else {
-        // There are two classes of instructions we know we don't handle.
-        assert(isa<ShuffleVectorInst>(Current) ||
-               isa<InsertElementInst>(Current));
+        // There is one known class of instructions we know we don't handle.
+        assert(isa<ShuffleVectorInst>(Current));
         llvm_unreachable("unimplemented instruction case");
       }
     }
@@ -849,11 +815,16 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
       } else if (PHINode *Phi = dyn_cast<PHINode>(v)) {
         for (Value *Val : Phi->incoming_values())
           calculateMeet.meetWith(getStateForInput(Val));
-      } else {
+      } else if (auto *EE = dyn_cast<ExtractElementInst>(v)) {
         // The 'meet' for an extractelement is slightly trivial, but it's still
         // useful in that it drives us to conflict if our input is.
-        auto *EE = cast<ExtractElementInst>(v);
         calculateMeet.meetWith(getStateForInput(EE->getVectorOperand()));
+      } else {
+        // Given there's a inherent type mismatch between the operands, will
+        // *always* produce Conflict.
+        auto *IE = cast<InsertElementInst>(v);
+        calculateMeet.meetWith(getStateForInput(IE->getOperand(0)));
+        calculateMeet.meetWith(getStateForInput(IE->getOperand(1)));
       }
 
       BDVState oldState = states[v];
@@ -899,6 +870,13 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
       BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {}));
       states[I] = BDVState(BDVState::Base, BaseInst);
     }
+
+    // Since we're joining a vector and scalar base, they can never be the
+    // same.  As a result, we should always see insert element having reached
+    // the conflict state.
+    if (isa<InsertElementInst>(I)) {
+      assert(State.isConflict());
+    }
     
     if (!State.isConflict())
       continue;
@@ -920,14 +898,22 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
           (I->getName() + ".base").str() : "base_select";
         return SelectInst::Create(Sel->getCondition(), Undef,
                                   Undef, Name, Sel);
-      } else {
-        auto *EE = cast<ExtractElementInst>(I);
+      } else if (auto *EE = dyn_cast<ExtractElementInst>(I)) {
         UndefValue *Undef = UndefValue::get(EE->getVectorOperand()->getType());
         std::string Name = I->hasName() ?
           (I->getName() + ".base").str() : "base_ee";
         return ExtractElementInst::Create(Undef, EE->getIndexOperand(), Name,
                                           EE);
+      } else {
+        auto *IE = cast<InsertElementInst>(I);
+        UndefValue *VecUndef = UndefValue::get(IE->getOperand(0)->getType());
+        UndefValue *ScalarUndef = UndefValue::get(IE->getOperand(1)->getType());
+        std::string Name = I->hasName() ?
+          (I->getName() + ".base").str() : "base_ie";
+        return InsertElementInst::Create(VecUndef, ScalarUndef,
+                                         IE->getOperand(2), Name, IE);
       }
+
     };
     Instruction *BaseInst = MakeBaseInstPlaceholder(I);
     // Add metadata marking this as a base value
@@ -1029,14 +1015,31 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
         Value *Base = getBaseForInput(InVal, BaseSel);
         BaseSel->setOperand(i, Base);
       }
-    } else {
-      auto *BaseEE = cast<ExtractElementInst>(state.getBase());
+    } else if (auto *BaseEE = dyn_cast<ExtractElementInst>(state.getBase())) {
       Value *InVal = cast<ExtractElementInst>(v)->getVectorOperand();
       // Find the instruction which produces the base for each input.  We may
       // need to insert a bitcast.
       Value *Base = getBaseForInput(InVal, BaseEE);
       BaseEE->setOperand(0, Base);
+    } else {
+      auto *BaseIE = cast<InsertElementInst>(state.getBase());
+      auto *BdvIE = cast<InsertElementInst>(v);
+      auto UpdateOperand = [&](int OperandIdx) {
+        Value *InVal = BdvIE->getOperand(OperandIdx);
+        Value *Base = findBaseOrBDV(InVal, cache);
+        if (!isKnownBaseResult(Base)) {
+          // Either conflict or base.
+          assert(states.count(Base));
+          Base = states[Base].getBase();
+          assert(Base != nullptr && "unknown BDVState!");
+        }
+        assert(Base && "can't be null");
+        BaseIE->setOperand(OperandIdx, Base);
+      };
+      UpdateOperand(0); // vector operand
+      UpdateOperand(1); // scalar operand
     }
+
   }
 
   // Now that we're done with the algorithm, see if we can optimize the