[RewriteStatepointsForGC] Handle extractelement fully in the base pointer algorithm
[oota-llvm.git] / lib / Transforms / Scalar / RewriteStatepointsForGC.cpp
index 062c0d5612be63c6884f8e0ddc228d67a03e3055..c0fada6864bdf0731766f752d54c2debe06b7293 100644 (file)
@@ -377,8 +377,9 @@ findBaseDefiningValueOfVector(Value *I, Value *Index = nullptr) {
 static bool isKnownBaseResult(Value *V);
 
 /// Helper function for findBasePointer - Will return a value which either a)
-/// defines the base pointer for the input or b) blocks the simple search
-/// (i.e. a PHI or Select of two derived pointers)
+/// defines the base pointer for the input, b) blocks the simple search
+/// (i.e. a PHI or Select of two derived pointers), or c) involves a change
+/// from pointer to vector type or back.
 static Value *findBaseDefiningValue(Value *I) {
   if (I->getType()->isVectorTy())
     return findBaseDefiningValueOfVector(I).first;
@@ -386,48 +387,6 @@ static Value *findBaseDefiningValue(Value *I) {
   assert(I->getType()->isPointerTy() &&
          "Illegal to ask for the base pointer of a non-pointer type");
 
-  // This case is a bit of a hack - it only handles extracts from vectors which
-  // trivially contain only base pointers or cases where we can directly match
-  // the index of the original extract element to an insertion into the vector.
-  // See note inside the function for how to improve this.
-  if (auto *EEI = dyn_cast<ExtractElementInst>(I)) {
-    Value *VectorOperand = EEI->getVectorOperand();
-    Value *Index = EEI->getIndexOperand();
-    std::pair<Value *, bool> pair =
-      findBaseDefiningValueOfVector(VectorOperand, Index);
-    Value *VectorBase = pair.first;
-    if (VectorBase->getType()->isPointerTy())
-      // We found a BDV for this specific element with the vector.  This is an
-      // optimization, but in practice it covers most of the useful cases
-      // created via scalarization.
-      return VectorBase;
-    else {
-      assert(VectorBase->getType()->isVectorTy());
-      if (pair.second)
-        // If the entire vector returned is known to be entirely base pointers,
-        // then the extractelement is valid base for this value.
-        return EEI;
-      else {
-        // Otherwise, we have an instruction which potentially produces a
-        // derived pointer and we need findBasePointers to clone code for us
-        // such that we can create an instruction which produces the
-        // accompanying base pointer.
-        // Note: This code is currently rather incomplete.  We don't currently
-        // support the general form of shufflevector of insertelement.
-        // Conceptually, these are just 'base defining values' of the same
-        // variety as phi or select instructions.  We need to update the
-        // findBasePointers algorithm to insert new 'base-only' versions of the
-        // original instructions. This is relative straight forward to do, but
-        // the case which would motivate the work hasn't shown up in real
-        // workloads yet.  
-        assert((isa<PHINode>(VectorBase) || isa<SelectInst>(VectorBase)) &&
-               "need to extend findBasePointers for generic vector"
-               "instruction cases");
-        return VectorBase;
-      }
-    }
-  }
-
   if (isa<Argument>(I))
     // An incoming argument to the function is a base pointer
     // We should have never reached here if this argument isn't an gc value
@@ -532,6 +491,33 @@ static Value *findBaseDefiningValue(Value *I) {
   assert(!isa<InsertValueInst>(I) &&
          "Base pointer for a struct is meaningless");
 
+  // An extractelement produces a base result exactly when it's input does.
+  // We may need to insert a parallel instruction to extract the appropriate
+  // element out of the base vector corresponding to the input. Given this,
+  // it's analogous to the phi and select case even though it's not a merge.
+  if (auto *EEI = dyn_cast<ExtractElementInst>(I)) {
+    Value *VectorOperand = EEI->getVectorOperand();
+    Value *Index = EEI->getIndexOperand();
+    std::pair<Value *, bool> pair =
+      findBaseDefiningValueOfVector(VectorOperand, Index);
+    Value *VectorBase = pair.first;
+    if (VectorBase->getType()->isPointerTy())
+      // We found a BDV for this specific element with the vector.  This is an
+      // optimization, but in practice it covers most of the useful cases
+      // created via scalarization. Note: The peephole optimization here is
+      // currently needed for correctness since the general algorithm doesn't
+      // yet handle insertelements.  That will change shortly.
+      return VectorBase;
+    else {
+      assert(VectorBase->getType()->isVectorTy());
+      // Otherwise, we have an instruction which potentially produces a
+      // derived pointer and we need findBasePointers to clone code for us
+      // such that we can create an instruction which produces the
+      // accompanying base pointer.
+      return EEI;
+    }
+  }
+
   // The last two cases here don't return a base pointer.  Instead, they
   // return a value which dynamically selects from among several base
   // derived pointers (each with it's own base potentially).  It's the job of
@@ -569,7 +555,7 @@ static Value *findBaseOrBDV(Value *I, DefiningValueMapTy &Cache) {
 /// Given the result of a call to findBaseDefiningValue, or findBaseOrBDV,
 /// is it known to be a base pointer?  Or do we need to continue searching.
 static bool isKnownBaseResult(Value *V) {
-  if (!isa<PHINode>(V) && !isa<SelectInst>(V)) {
+  if (!isa<PHINode>(V) && !isa<SelectInst>(V) && !isa<ExtractElementInst>(V)) {
     // no recursion possible
     return true;
   }
@@ -722,7 +708,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
 
 #ifndef NDEBUG
   auto isExpectedBDVType = [](Value *BDV) {
-    return isa<PHINode>(BDV) || isa<SelectInst>(BDV);
+    return isa<PHINode>(BDV) || isa<SelectInst>(BDV) || isa<ExtractElementInst>(BDV);
   };
 #endif
 
@@ -754,10 +740,16 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
       if (PHINode *Phi = dyn_cast<PHINode>(Current)) {
         for (Value *InVal : Phi->incoming_values())
           visitIncomingValue(InVal);
-      } else {
-        SelectInst *Sel = cast<SelectInst>(Current);
+      } else if (SelectInst *Sel = dyn_cast<SelectInst>(Current)) {
         visitIncomingValue(Sel->getTrueValue());
         visitIncomingValue(Sel->getFalseValue());
+      } else if (auto *EE = dyn_cast<ExtractElementInst>(Current)) {
+        visitIncomingValue(EE->getVectorOperand());
+      } else {
+        // There are two classes of instructions we know we don't handle.
+        assert(isa<ShuffleVectorInst>(Current) ||
+               isa<InsertElementInst>(Current));
+        llvm_unreachable("unimplemented instruction case");
       }
     }
     // The frontier of visited instructions are the ones we might need to
@@ -771,7 +763,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
   if (TraceLSP) {
     errs() << "States after initialization:\n";
     for (auto Pair : states)
-      dbgs() << " " << Pair.second << " for " << Pair.first << "\n";
+      dbgs() << " " << Pair.second << " for " << *Pair.first << "\n";
   }
 
   // TODO: come back and revisit the state transitions around inputs which
@@ -809,9 +801,16 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
       if (SelectInst *select = dyn_cast<SelectInst>(v)) {
         calculateMeet.meetWith(getStateForInput(select->getTrueValue()));
         calculateMeet.meetWith(getStateForInput(select->getFalseValue()));
-      } else
-        for (Value *Val : cast<PHINode>(v)->incoming_values())
+      } else if (PHINode *Phi = dyn_cast<PHINode>(v)) {
+        for (Value *Val : Phi->incoming_values())
           calculateMeet.meetWith(getStateForInput(Val));
+      } else {
+        // The 'meet' for an extractelement is slightly trivial, but it's still
+        // useful in that it drives us to conflict if our input is.
+        auto *EE = cast<ExtractElementInst>(v);
+        calculateMeet.meetWith(getStateForInput(EE->getVectorOperand()));
+      }
+
 
       BDVState oldState = states[v];
       BDVState newState = calculateMeet.getResult();
@@ -828,7 +827,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
   if (TraceLSP) {
     errs() << "States after meet iteration:\n";
     for (auto Pair : states)
-      dbgs() << " " << Pair.second << " for " << Pair.first << "\n";
+      dbgs() << " " << Pair.second << " for " << *Pair.first << "\n";
   }
 
   // Insert Phis for all conflicts
@@ -848,6 +847,24 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
     BDVState State = states[I];
     assert(!isKnownBaseResult(I) && "why did it get added?");
     assert(!State.isUnknown() && "Optimistic algorithm didn't complete!");
+
+    // extractelement instructions are a bit special in that we may need to
+    // insert an extract even when we know an exact base for the instruction.
+    // The problem is that we need to convert from a vector base to a scalar
+    // base for the particular indice we're interested in.
+    if (State.isBase() && isa<ExtractElementInst>(I) &&
+        isa<VectorType>(State.getBase()->getType())) {
+      auto *EE = cast<ExtractElementInst>(I);
+      // TODO: In many cases, the new instruction is just EE itself.  We should
+      // exploit this, but can't do it here since it would break the invariant
+      // about the BDV not being known to be a base.
+      auto *BaseInst = ExtractElementInst::Create(State.getBase(),
+                                                  EE->getIndexOperand(),
+                                                  "base_ee", EE);
+      BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {}));
+      states[I] = BDVState(BDVState::Base, BaseInst);
+    }
+    
     if (!State.isConflict())
       continue;
 
@@ -861,14 +878,21 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
         std::string Name = I->hasName() ?
            (I->getName() + ".base").str() : "base_phi";
         return PHINode::Create(I->getType(), NumPreds, Name, I);
+      } else if (SelectInst *Sel = dyn_cast<SelectInst>(I)) {
+        // The undef will be replaced later
+        UndefValue *Undef = UndefValue::get(Sel->getType());
+        std::string Name = I->hasName() ?
+          (I->getName() + ".base").str() : "base_select";
+        return SelectInst::Create(Sel->getCondition(), Undef,
+                                  Undef, Name, Sel);
+      } else {
+        auto *EE = cast<ExtractElementInst>(I);
+        UndefValue *Undef = UndefValue::get(EE->getVectorOperand()->getType());
+        std::string Name = I->hasName() ?
+          (I->getName() + ".base").str() : "base_ee";
+        return ExtractElementInst::Create(Undef, EE->getIndexOperand(), Name,
+                                          EE);
       }
-      SelectInst *Sel = cast<SelectInst>(I);
-      // The undef will be replaced later
-      UndefValue *Undef = UndefValue::get(Sel->getType());
-      std::string Name = I->hasName() ?
-         (I->getName() + ".base").str() : "base_select";
-      return SelectInst::Create(Sel->getCondition(), Undef,
-                                Undef, Name, Sel);
     };
     Instruction *BaseInst = MakeBaseInstPlaceholder(I);
     // Add metadata marking this as a base value
@@ -947,8 +971,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
         basephi->addIncoming(base, InBB);
       }
       assert(basephi->getNumIncomingValues() == NumPHIValues);
-    } else {
-      SelectInst *basesel = cast<SelectInst>(state.getBase());
+    } else if (SelectInst *basesel = dyn_cast<SelectInst>(state.getBase())) {
       SelectInst *sel = cast<SelectInst>(v);
       // Operand 1 & 2 are true, false path respectively. TODO: refactor to
       // something more safe and less hacky.
@@ -971,6 +994,18 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
         }
         basesel->setOperand(i, base);
       }
+    } else {
+      auto *BaseEE = cast<ExtractElementInst>(state.getBase());
+      Value *InVal = cast<ExtractElementInst>(v)->getVectorOperand();
+      Value *Base = findBaseOrBDV(InVal, cache);
+      if (!isKnownBaseResult(Base)) {
+        // Either conflict or base.
+        assert(states.count(Base));
+        Base = states[Base].getBase();
+        assert(Base != nullptr && "unknown BDVState!");
+      }
+      assert(Base && "can't be null");
+      BaseEE->setOperand(0, Base);
     }
   }