[RewriteStatepointsForGC] Handle extractelement fully in the base pointer algorithm
authorPhilip Reames <listmail@philipreames.com>
Wed, 12 Aug 2015 21:00:20 +0000 (21:00 +0000)
committerPhilip Reames <listmail@philipreames.com>
Wed, 12 Aug 2015 21:00:20 +0000 (21:00 +0000)
When rewriting the IR such that base pointers are available for every live pointer, we potentially need to duplicate instructions to propagate the base. The original code had only handled PHI and Select under the belief those were the only instructions which would need duplicated. When I added support for vector instructions, I'd added a collection of hacks for ExtractElement which caught most of the common cases. Of course, I then found the one test case my hacks couldn't cover. :)

This change removes all of the early hacks for extract element. By defining extractelement as a BDV (rather than trying to look through it), we can extend the rewriting algorithm to duplicate the extract as needed.  Note that a couple of peephole optimizations were left in for the moment, because while we now handle extractelement as a first class citizen, we're not yet handling insertelement.  That change will follow in the near future.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@244808 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Transforms/Scalar/RewriteStatepointsForGC.cpp
test/Transforms/RewriteStatepointsForGC/base-vector.ll [new file with mode: 0644]

index 062c0d5..c0fada6 100644 (file)
@@ -377,8 +377,9 @@ findBaseDefiningValueOfVector(Value *I, Value *Index = nullptr) {
 static bool isKnownBaseResult(Value *V);
 
 /// Helper function for findBasePointer - Will return a value which either a)
-/// defines the base pointer for the input or b) blocks the simple search
-/// (i.e. a PHI or Select of two derived pointers)
+/// defines the base pointer for the input, b) blocks the simple search
+/// (i.e. a PHI or Select of two derived pointers), or c) involves a change
+/// from pointer to vector type or back.
 static Value *findBaseDefiningValue(Value *I) {
   if (I->getType()->isVectorTy())
     return findBaseDefiningValueOfVector(I).first;
@@ -386,48 +387,6 @@ static Value *findBaseDefiningValue(Value *I) {
   assert(I->getType()->isPointerTy() &&
          "Illegal to ask for the base pointer of a non-pointer type");
 
-  // This case is a bit of a hack - it only handles extracts from vectors which
-  // trivially contain only base pointers or cases where we can directly match
-  // the index of the original extract element to an insertion into the vector.
-  // See note inside the function for how to improve this.
-  if (auto *EEI = dyn_cast<ExtractElementInst>(I)) {
-    Value *VectorOperand = EEI->getVectorOperand();
-    Value *Index = EEI->getIndexOperand();
-    std::pair<Value *, bool> pair =
-      findBaseDefiningValueOfVector(VectorOperand, Index);
-    Value *VectorBase = pair.first;
-    if (VectorBase->getType()->isPointerTy())
-      // We found a BDV for this specific element with the vector.  This is an
-      // optimization, but in practice it covers most of the useful cases
-      // created via scalarization.
-      return VectorBase;
-    else {
-      assert(VectorBase->getType()->isVectorTy());
-      if (pair.second)
-        // If the entire vector returned is known to be entirely base pointers,
-        // then the extractelement is valid base for this value.
-        return EEI;
-      else {
-        // Otherwise, we have an instruction which potentially produces a
-        // derived pointer and we need findBasePointers to clone code for us
-        // such that we can create an instruction which produces the
-        // accompanying base pointer.
-        // Note: This code is currently rather incomplete.  We don't currently
-        // support the general form of shufflevector of insertelement.
-        // Conceptually, these are just 'base defining values' of the same
-        // variety as phi or select instructions.  We need to update the
-        // findBasePointers algorithm to insert new 'base-only' versions of the
-        // original instructions. This is relative straight forward to do, but
-        // the case which would motivate the work hasn't shown up in real
-        // workloads yet.  
-        assert((isa<PHINode>(VectorBase) || isa<SelectInst>(VectorBase)) &&
-               "need to extend findBasePointers for generic vector"
-               "instruction cases");
-        return VectorBase;
-      }
-    }
-  }
-
   if (isa<Argument>(I))
     // An incoming argument to the function is a base pointer
     // We should have never reached here if this argument isn't an gc value
@@ -532,6 +491,33 @@ static Value *findBaseDefiningValue(Value *I) {
   assert(!isa<InsertValueInst>(I) &&
          "Base pointer for a struct is meaningless");
 
+  // An extractelement produces a base result exactly when it's input does.
+  // We may need to insert a parallel instruction to extract the appropriate
+  // element out of the base vector corresponding to the input. Given this,
+  // it's analogous to the phi and select case even though it's not a merge.
+  if (auto *EEI = dyn_cast<ExtractElementInst>(I)) {
+    Value *VectorOperand = EEI->getVectorOperand();
+    Value *Index = EEI->getIndexOperand();
+    std::pair<Value *, bool> pair =
+      findBaseDefiningValueOfVector(VectorOperand, Index);
+    Value *VectorBase = pair.first;
+    if (VectorBase->getType()->isPointerTy())
+      // We found a BDV for this specific element with the vector.  This is an
+      // optimization, but in practice it covers most of the useful cases
+      // created via scalarization. Note: The peephole optimization here is
+      // currently needed for correctness since the general algorithm doesn't
+      // yet handle insertelements.  That will change shortly.
+      return VectorBase;
+    else {
+      assert(VectorBase->getType()->isVectorTy());
+      // Otherwise, we have an instruction which potentially produces a
+      // derived pointer and we need findBasePointers to clone code for us
+      // such that we can create an instruction which produces the
+      // accompanying base pointer.
+      return EEI;
+    }
+  }
+
   // The last two cases here don't return a base pointer.  Instead, they
   // return a value which dynamically selects from among several base
   // derived pointers (each with it's own base potentially).  It's the job of
@@ -569,7 +555,7 @@ static Value *findBaseOrBDV(Value *I, DefiningValueMapTy &Cache) {
 /// Given the result of a call to findBaseDefiningValue, or findBaseOrBDV,
 /// is it known to be a base pointer?  Or do we need to continue searching.
 static bool isKnownBaseResult(Value *V) {
-  if (!isa<PHINode>(V) && !isa<SelectInst>(V)) {
+  if (!isa<PHINode>(V) && !isa<SelectInst>(V) && !isa<ExtractElementInst>(V)) {
     // no recursion possible
     return true;
   }
@@ -722,7 +708,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
 
 #ifndef NDEBUG
   auto isExpectedBDVType = [](Value *BDV) {
-    return isa<PHINode>(BDV) || isa<SelectInst>(BDV);
+    return isa<PHINode>(BDV) || isa<SelectInst>(BDV) || isa<ExtractElementInst>(BDV);
   };
 #endif
 
@@ -754,10 +740,16 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
       if (PHINode *Phi = dyn_cast<PHINode>(Current)) {
         for (Value *InVal : Phi->incoming_values())
           visitIncomingValue(InVal);
-      } else {
-        SelectInst *Sel = cast<SelectInst>(Current);
+      } else if (SelectInst *Sel = dyn_cast<SelectInst>(Current)) {
         visitIncomingValue(Sel->getTrueValue());
         visitIncomingValue(Sel->getFalseValue());
+      } else if (auto *EE = dyn_cast<ExtractElementInst>(Current)) {
+        visitIncomingValue(EE->getVectorOperand());
+      } else {
+        // There are two classes of instructions we know we don't handle.
+        assert(isa<ShuffleVectorInst>(Current) ||
+               isa<InsertElementInst>(Current));
+        llvm_unreachable("unimplemented instruction case");
       }
     }
     // The frontier of visited instructions are the ones we might need to
@@ -771,7 +763,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
   if (TraceLSP) {
     errs() << "States after initialization:\n";
     for (auto Pair : states)
-      dbgs() << " " << Pair.second << " for " << Pair.first << "\n";
+      dbgs() << " " << Pair.second << " for " << *Pair.first << "\n";
   }
 
   // TODO: come back and revisit the state transitions around inputs which
@@ -809,9 +801,16 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
       if (SelectInst *select = dyn_cast<SelectInst>(v)) {
         calculateMeet.meetWith(getStateForInput(select->getTrueValue()));
         calculateMeet.meetWith(getStateForInput(select->getFalseValue()));
-      } else
-        for (Value *Val : cast<PHINode>(v)->incoming_values())
+      } else if (PHINode *Phi = dyn_cast<PHINode>(v)) {
+        for (Value *Val : Phi->incoming_values())
           calculateMeet.meetWith(getStateForInput(Val));
+      } else {
+        // The 'meet' for an extractelement is slightly trivial, but it's still
+        // useful in that it drives us to conflict if our input is.
+        auto *EE = cast<ExtractElementInst>(v);
+        calculateMeet.meetWith(getStateForInput(EE->getVectorOperand()));
+      }
+
 
       BDVState oldState = states[v];
       BDVState newState = calculateMeet.getResult();
@@ -828,7 +827,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
   if (TraceLSP) {
     errs() << "States after meet iteration:\n";
     for (auto Pair : states)
-      dbgs() << " " << Pair.second << " for " << Pair.first << "\n";
+      dbgs() << " " << Pair.second << " for " << *Pair.first << "\n";
   }
 
   // Insert Phis for all conflicts
@@ -848,6 +847,24 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
     BDVState State = states[I];
     assert(!isKnownBaseResult(I) && "why did it get added?");
     assert(!State.isUnknown() && "Optimistic algorithm didn't complete!");
+
+    // extractelement instructions are a bit special in that we may need to
+    // insert an extract even when we know an exact base for the instruction.
+    // The problem is that we need to convert from a vector base to a scalar
+    // base for the particular indice we're interested in.
+    if (State.isBase() && isa<ExtractElementInst>(I) &&
+        isa<VectorType>(State.getBase()->getType())) {
+      auto *EE = cast<ExtractElementInst>(I);
+      // TODO: In many cases, the new instruction is just EE itself.  We should
+      // exploit this, but can't do it here since it would break the invariant
+      // about the BDV not being known to be a base.
+      auto *BaseInst = ExtractElementInst::Create(State.getBase(),
+                                                  EE->getIndexOperand(),
+                                                  "base_ee", EE);
+      BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {}));
+      states[I] = BDVState(BDVState::Base, BaseInst);
+    }
+    
     if (!State.isConflict())
       continue;
 
@@ -861,14 +878,21 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
         std::string Name = I->hasName() ?
            (I->getName() + ".base").str() : "base_phi";
         return PHINode::Create(I->getType(), NumPreds, Name, I);
+      } else if (SelectInst *Sel = dyn_cast<SelectInst>(I)) {
+        // The undef will be replaced later
+        UndefValue *Undef = UndefValue::get(Sel->getType());
+        std::string Name = I->hasName() ?
+          (I->getName() + ".base").str() : "base_select";
+        return SelectInst::Create(Sel->getCondition(), Undef,
+                                  Undef, Name, Sel);
+      } else {
+        auto *EE = cast<ExtractElementInst>(I);
+        UndefValue *Undef = UndefValue::get(EE->getVectorOperand()->getType());
+        std::string Name = I->hasName() ?
+          (I->getName() + ".base").str() : "base_ee";
+        return ExtractElementInst::Create(Undef, EE->getIndexOperand(), Name,
+                                          EE);
       }
-      SelectInst *Sel = cast<SelectInst>(I);
-      // The undef will be replaced later
-      UndefValue *Undef = UndefValue::get(Sel->getType());
-      std::string Name = I->hasName() ?
-         (I->getName() + ".base").str() : "base_select";
-      return SelectInst::Create(Sel->getCondition(), Undef,
-                                Undef, Name, Sel);
     };
     Instruction *BaseInst = MakeBaseInstPlaceholder(I);
     // Add metadata marking this as a base value
@@ -947,8 +971,7 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
         basephi->addIncoming(base, InBB);
       }
       assert(basephi->getNumIncomingValues() == NumPHIValues);
-    } else {
-      SelectInst *basesel = cast<SelectInst>(state.getBase());
+    } else if (SelectInst *basesel = dyn_cast<SelectInst>(state.getBase())) {
       SelectInst *sel = cast<SelectInst>(v);
       // Operand 1 & 2 are true, false path respectively. TODO: refactor to
       // something more safe and less hacky.
@@ -971,6 +994,18 @@ static Value *findBasePointer(Value *I, DefiningValueMapTy &cache) {
         }
         basesel->setOperand(i, base);
       }
+    } else {
+      auto *BaseEE = cast<ExtractElementInst>(state.getBase());
+      Value *InVal = cast<ExtractElementInst>(v)->getVectorOperand();
+      Value *Base = findBaseOrBDV(InVal, cache);
+      if (!isKnownBaseResult(Base)) {
+        // Either conflict or base.
+        assert(states.count(Base));
+        Base = states[Base].getBase();
+        assert(Base != nullptr && "unknown BDVState!");
+      }
+      assert(Base && "can't be null");
+      BaseEE->setOperand(0, Base);
     }
   }
 
diff --git a/test/Transforms/RewriteStatepointsForGC/base-vector.ll b/test/Transforms/RewriteStatepointsForGC/base-vector.ll
new file mode 100644 (file)
index 0000000..2cba038
--- /dev/null
@@ -0,0 +1,88 @@
+; RUN: opt %s -rewrite-statepoints-for-gc -S | FileCheck  %s
+
+define i64 addrspace(1)* @test(<2 x i64 addrspace(1)*> %vec, i32 %idx) gc "statepoint-example" {
+; CHECK-LABEL: @test
+; CHECK: extractelement
+; CHECK: extractelement
+; CHECK: statepoint
+; CHECK: gc.relocate
+; CHECK-DAG: ; (%base_ee, %base_ee)
+; CHECK: gc.relocate
+; CHECK-DAG: ; (%base_ee, %obj)
+; Note that the second extractelement is actually redundant here.  A correct output would
+; be to reuse the existing obj as a base since it is actually a base pointer.
+entry:
+  %obj = extractelement <2 x i64 addrspace(1)*> %vec, i32 %idx
+  %safepoint_token = call i32 (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 0, i32 0, void ()* @do_safepoint, i32 0, i32 0, i32 0, i32 0)
+
+  ret i64 addrspace(1)* %obj
+}
+
+define i64 addrspace(1)* @test2(<2 x i64 addrspace(1)*>* %ptr, i1 %cnd, i32 %idx1, i32 %idx2) 
+    gc "statepoint-example" {
+; CHECK-LABEL: test2
+entry:
+  br i1 %cnd, label %taken, label %untaken
+taken:
+  %obja = load <2 x i64 addrspace(1)*>, <2 x i64 addrspace(1)*>* %ptr
+  br label %merge
+untaken:
+  %objb = load <2 x i64 addrspace(1)*>, <2 x i64 addrspace(1)*>* %ptr
+  br label %merge
+merge:
+  %vec = phi <2 x i64 addrspace(1)*> [%obja, %taken], [%objb, %untaken]
+  br i1 %cnd, label %taken2, label %untaken2
+taken2:
+  %obj0 = extractelement <2 x i64 addrspace(1)*> %vec, i32 %idx1
+  br label %merge2
+untaken2:
+  %obj1 = extractelement <2 x i64 addrspace(1)*> %vec, i32 %idx2
+  br label %merge2
+merge2:
+; CHECK-LABEL: merge2:
+; CHECK: %obj.base = phi i64 addrspace(1)*
+; CHECK: %obj = phi i64 addrspace(1)*
+; CHECK: statepoint
+; CHECK: gc.relocate
+; CHECK-DAG: ; (%obj.base, %obj)
+; CHECK: gc.relocate
+; CHECK-DAG: ; (%obj.base, %obj.base)
+  %obj = phi i64 addrspace(1)* [%obj0, %taken2], [%obj1, %untaken2]
+  %safepoint_token = call i32 (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 0, i32 0, void ()* @do_safepoint, i32 0, i32 0, i32 0, i32 0)
+  ret i64 addrspace(1)* %obj
+}
+
+define i64 addrspace(1)* @test3(i64 addrspace(1)* %ptr) 
+    gc "statepoint-example" {
+; CHECK-LABEL: test3
+entry:
+  %vec = insertelement <2 x i64 addrspace(1)*> undef, i64 addrspace(1)* %ptr, i32 0
+  %obj = extractelement <2 x i64 addrspace(1)*> %vec, i32 0
+; CHECK: insertelement
+; CHECK: extractelement
+; CHECK: statepoint
+; CHECK: gc.relocate
+; CHECK-DAG: ; (%ptr, %obj)
+   %safepoint_token = call i32 (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 0, i32 0, void ()* @do_safepoint, i32 0, i32 0, i32 0, i32 0)
+  ret i64 addrspace(1)* %obj
+}
+define i64 addrspace(1)* @test4(i64 addrspace(1)* %ptr) 
+    gc "statepoint-example" {
+; CHECK-LABEL: test4
+entry:
+  %derived = getelementptr i64, i64 addrspace(1)* %ptr, i64 16
+  %veca = insertelement <2 x i64 addrspace(1)*> undef, i64 addrspace(1)* %derived, i32 0
+  %vec = insertelement <2 x i64 addrspace(1)*> %veca, i64 addrspace(1)* %ptr, i32 1
+  %obj = extractelement <2 x i64 addrspace(1)*> %vec, i32 0
+; CHECK: statepoint
+; CHECK: gc.relocate
+; CHECK-DAG: ; (%ptr, %obj)
+; CHECK: gc.relocate
+; CHECK-DAG: ; (%ptr, %ptr)
+   %safepoint_token = call i32 (i64, i32, void ()*, i32, i32, ...) @llvm.experimental.gc.statepoint.p0f_isVoidf(i64 0, i32 0, void ()* @do_safepoint, i32 0, i32 0, i32 0, i32 0)
+  ret i64 addrspace(1)* %obj
+}
+
+declare void @do_safepoint()
+
+declare i32 @llvm.experimental.gc.statepoint.p0f_isVoidf(i64, i32, void ()*, i32, i32, ...)