[RS4GC] Use "deopt" operand bundles
[oota-llvm.git] / lib / Transforms / Scalar / RewriteStatepointsForGC.cpp
index f5797164355d2a481d93059c96c7d2d2ab33bfe4..98ed46fb12ed4e950bb103c5ec3b077f1bc0a321 100644 (file)
@@ -72,6 +72,12 @@ static cl::opt<bool, true> ClobberNonLiveOverride("rs4gc-clobber-non-live",
                                                   cl::location(ClobberNonLive),
                                                   cl::Hidden);
 
+static cl::opt<bool> UseDeoptBundles("rs4gc-use-deopt-bundles", cl::Hidden,
+                                     cl::init(false));
+static cl::opt<bool>
+    AllowStatepointWithNoDeoptInfo("rs4gc-allow-statepoint-with-no-deopt-info",
+                                   cl::Hidden, cl::init(true));
+
 namespace {
 struct RewriteStatepointsForGC : public ModulePass {
   static char ID; // Pass identification, replacement for typeid
@@ -184,6 +190,20 @@ struct PartiallyConstructedSafepointRecord {
 };
 }
 
+static ArrayRef<Use> GetDeoptBundleOperands(ImmutableCallSite CS) {
+  assert(UseDeoptBundles && "Should not be called otherwise!");
+
+  Optional<OperandBundleUse> DeoptBundle = CS.getOperandBundle("deopt");
+
+  if (!DeoptBundle.hasValue()) {
+    assert(AllowStatepointWithNoDeoptInfo &&
+           "Found non-leaf call without deopt info!");
+    return None;
+  }
+
+  return DeoptBundle.getValue().Inputs;
+}
+
 /// Compute the live-in set for every basic block in the function
 static void computeLiveInValues(DominatorTree &DT, Function &F,
                                 GCPtrLivenessData &Data);
@@ -1330,13 +1350,45 @@ static void CreateGCRelocates(ArrayRef<Value *> LiveVariables,
   }
 }
 
+namespace {
+
+/// This struct is used to defer RAUWs and `eraseFromParent` s.  Using this
+/// avoids having to worry about keeping around dangling pointers to Values.
+class DeferredReplacement {
+  AssertingVH<Instruction> Old;
+  AssertingVH<Instruction> New;
+
+public:
+  explicit DeferredReplacement(Instruction *Old, Instruction *New) :
+    Old(Old), New(New) {
+    assert(Old != New && "Not allowed!");
+  }
+
+  /// Does the task represented by this instance.
+  void doReplacement() {
+    Instruction *OldI = Old;
+    Instruction *NewI = New;
+
+    assert(OldI != NewI && "Disallowed at construction?!");
+
+    Old = nullptr;
+    New = nullptr;
+
+    if (NewI)
+      OldI->replaceAllUsesWith(NewI);
+    OldI->eraseFromParent();
+  }
+};
+}
+
 static void
 makeStatepointExplicitImpl(const CallSite CS, /* to replace */
                            const SmallVectorImpl<Value *> &BasePtrs,
                            const SmallVectorImpl<Value *> &LiveVariables,
-                           PartiallyConstructedSafepointRecord &Result) {
+                           PartiallyConstructedSafepointRecord &Result,
+                           std::vector<DeferredReplacement> &Replacements) {
   assert(BasePtrs.size() == LiveVariables.size());
-  assert(isStatepoint(CS) &&
+  assert((UseDeoptBundles || isStatepoint(CS)) &&
          "This method expects to be rewriting a statepoint");
 
   // Then go ahead and use the builder do actually do the inserts.  We insert
@@ -1346,18 +1398,49 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */
   Instruction *InsertBefore = CS.getInstruction();
   IRBuilder<> Builder(InsertBefore);
 
-  Statepoint OldSP(CS);
-
   ArrayRef<Value *> GCArgs(LiveVariables);
-  uint64_t StatepointID = OldSP.getID();
-  uint32_t NumPatchBytes = OldSP.getNumPatchBytes();
-  uint32_t Flags = OldSP.getFlags();
+  uint64_t StatepointID = 0xABCDEF00;
+  uint32_t NumPatchBytes = 0;
+  uint32_t Flags = uint32_t(StatepointFlags::None);
+
+  ArrayRef<Use> CallArgs;
+  ArrayRef<Use> DeoptArgs;
+  ArrayRef<Use> TransitionArgs;
+
+  Value *CallTarget = nullptr;
+
+  if (UseDeoptBundles) {
+    CallArgs = {CS.arg_begin(), CS.arg_end()};
+    DeoptArgs = GetDeoptBundleOperands(CS);
+    // TODO: we don't fill in TransitionArgs or Flags in this branch, but we
+    // could have an operand bundle for that too.
+    AttributeSet OriginalAttrs = CS.getAttributes();
+
+    Attribute AttrID = OriginalAttrs.getAttribute(AttributeSet::FunctionIndex,
+                                                  "statepoint-id");
+    if (AttrID.isStringAttribute())
+      AttrID.getValueAsString().getAsInteger(10, StatepointID);
+
+    Attribute AttrNumPatchBytes = OriginalAttrs.getAttribute(
+        AttributeSet::FunctionIndex, "statepoint-num-patch-bytes");
+    if (AttrNumPatchBytes.isStringAttribute())
+      AttrNumPatchBytes.getValueAsString().getAsInteger(10, NumPatchBytes);
+
+    CallTarget = CS.getCalledValue();
+  } else {
+    // This branch will be gone soon, and we will soon only support the
+    // UseDeoptBundles == true configuration.
+    Statepoint OldSP(CS);
+    StatepointID = OldSP.getID();
+    NumPatchBytes = OldSP.getNumPatchBytes();
+    Flags = OldSP.getFlags();
 
-  ArrayRef<Use> CallArgs(OldSP.arg_begin(), OldSP.arg_end());
-  ArrayRef<Use> DeoptArgs(OldSP.vm_state_begin(), OldSP.vm_state_end());
-  ArrayRef<Use> TransitionArgs(OldSP.gc_transition_args_begin(),
-                               OldSP.gc_transition_args_end());
-  Value *CallTarget = OldSP.getCalledValue();
+    CallArgs = {OldSP.arg_begin(), OldSP.arg_end()};
+    DeoptArgs = {OldSP.vm_state_begin(), OldSP.vm_state_end()};
+    TransitionArgs = {OldSP.gc_transition_args_begin(),
+                      OldSP.gc_transition_args_end()};
+    CallTarget = OldSP.getCalledValue();
+  }
 
   // Create the statepoint given all the arguments
   Instruction *Token = nullptr;
@@ -1442,22 +1525,39 @@ makeStatepointExplicitImpl(const CallSite CS, /* to replace */
   }
   assert(Token && "Should be set in one of the above branches!");
 
-  // Take the name of the original value call if it had one.
-  Token->takeName(CS.getInstruction());
+  if (UseDeoptBundles) {
+    Token->setName("statepoint_token");
+    if (!CS.getType()->isVoidTy() && !CS.getInstruction()->use_empty()) {
+      StringRef Name =
+          CS.getInstruction()->hasName() ? CS.getInstruction()->getName() : "";
+      CallInst *GCResult = Builder.CreateGCResult(Token, CS.getType(), Name);
+      GCResult->setAttributes(CS.getAttributes().getRetAttributes());
+
+      // We cannot RAUW or delete CS.getInstruction() because it could be in the
+      // live set of some other safepoint, in which case that safepoint's
+      // PartiallyConstructedSafepointRecord will hold a raw pointer to this
+      // llvm::Instruction.  Instead, we defer the replacement and deletion to
+      // after the live sets have been made explicit in the IR, and we no longer
+      // have raw pointers to worry about.
+      Replacements.emplace_back(CS.getInstruction(), GCResult);
+    } else {
+      Replacements.emplace_back(CS.getInstruction(), nullptr);
+    }
+  } else {
+    assert(!CS.getInstruction()->hasNUsesOrMore(2) &&
+           "only valid use before rewrite is gc.result");
+    assert(!CS.getInstruction()->hasOneUse() ||
+           isGCResult(cast<Instruction>(*CS.getInstruction()->user_begin())));
 
-// The GCResult is already inserted, we just need to find it
-#ifndef NDEBUG
-  Instruction *ToReplace = CS.getInstruction();
-  assert(!ToReplace->hasNUsesOrMore(2) &&
-         "only valid use before rewrite is gc.result");
-  assert(!ToReplace->hasOneUse() ||
-         isGCResult(cast<Instruction>(*ToReplace->user_begin())));
-#endif
+    // Take the name of the original statepoint token if there was one.
+    Token->takeName(CS.getInstruction());
 
-  // Update the gc.result of the original statepoint (if any) to use the newly
-  // inserted statepoint.  This is safe to do here since the token can't be
-  // considered a live reference.
-  CS.getInstruction()->replaceAllUsesWith(Token);
+    // Update the gc.result of the original statepoint (if any) to use the newly
+    // inserted statepoint.  This is safe to do here since the token can't be
+    // considered a live reference.
+    CS.getInstruction()->replaceAllUsesWith(Token);
+    CS.getInstruction()->eraseFromParent();
+  }
 
   Result.StatepointToken = Token;
 
@@ -1503,7 +1603,8 @@ static void StabilizeOrder(SmallVectorImpl<Value *> &BaseVec,
 // values.  That's the callers responsibility.
 static void
 makeStatepointExplicit(DominatorTree &DT, const CallSite &CS,
-                       PartiallyConstructedSafepointRecord &Result) {
+                       PartiallyConstructedSafepointRecord &Result,
+                       std::vector<DeferredReplacement> &Replacements) {
   const auto &LiveSet = Result.LiveSet;
   const auto &PointerToBase = Result.PointerToBase;
 
@@ -1525,8 +1626,7 @@ makeStatepointExplicit(DominatorTree &DT, const CallSite &CS,
   StabilizeOrder(BaseVec, LiveVec);
 
   // Do the actual rewriting and delete the old statepoint
-  makeStatepointExplicitImpl(CS, BaseVec, LiveVec, Result);
-  CS.getInstruction()->eraseFromParent();
+  makeStatepointExplicitImpl(CS, BaseVec, LiveVec, Result, Replacements);
 }
 
 // Helper function for the relocationViaAlloca.
@@ -2182,7 +2282,8 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, Pass *P,
 
   for (CallSite CS : ToUpdate) {
     assert(CS.getInstruction()->getParent()->getParent() == &F);
-    assert(isStatepoint(CS) && "expected to already be a deopt statepoint");
+    assert((UseDeoptBundles || isStatepoint(CS)) &&
+           "expected to already be a deopt statepoint");
   }
 #endif
 
@@ -2207,16 +2308,20 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, Pass *P,
   // the deopt argument list are considered live through the safepoint (and
   // thus makes sure they get relocated.)
   for (CallSite CS : ToUpdate) {
-    Statepoint StatepointCS(CS);
-
     SmallVector<Value *, 64> DeoptValues;
-    for (Use &U : StatepointCS.vm_state_args()) {
-      Value *Arg = cast<Value>(&U);
+
+    iterator_range<const Use *> DeoptStateRange =
+        UseDeoptBundles
+            ? iterator_range<const Use *>(GetDeoptBundleOperands(CS))
+            : iterator_range<const Use *>(Statepoint(CS).vm_state_args());
+
+    for (Value *Arg : DeoptStateRange) {
       assert(!isUnhandledGCPointerType(Arg->getType()) &&
              "support for FCA unimplemented");
       if (isHandledGCPointerType(Arg->getType()))
         DeoptValues.push_back(Arg);
     }
+
     insertUseHolderAfter(CS, DeoptValues, Holders);
   }
 
@@ -2303,6 +2408,11 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, Pass *P,
   for (size_t i = 0; i < Records.size(); i++)
     rematerializeLiveValues(ToUpdate[i], Records[i], TTI);
 
+  // We need this to safely RAUW and delete call or invoke return values that
+  // may themselves be live over a statepoint.  For details, please see usage in
+  // makeStatepointExplicitImpl.
+  std::vector<DeferredReplacement> Replacements;
+
   // Now run through and replace the existing statepoints with new ones with
   // the live variables listed.  We do not yet update uses of the values being
   // relocated. We have references to live variables that need to
@@ -2310,14 +2420,33 @@ static bool insertParsePoints(Function &F, DominatorTree &DT, Pass *P,
   // previous statepoint can not be a live variable, thus we can and remove
   // the old statepoint calls as we go.)
   for (size_t i = 0; i < Records.size(); i++)
-    makeStatepointExplicit(DT, ToUpdate[i], Records[i]);
+    makeStatepointExplicit(DT, ToUpdate[i], Records[i], Replacements);
 
   ToUpdate.clear(); // prevent accident use of invalid CallSites
 
+  for (auto &PR : Replacements)
+    PR.doReplacement();
+
+  Replacements.clear();
+
+  for (auto &Info : Records) {
+    // These live sets may contain state Value pointers, since we replaced calls
+    // with operand bundles with calls wrapped in gc.statepoint, and some of
+    // those calls may have been def'ing live gc pointers.  Clear these out to
+    // avoid accidentally using them.
+    //
+    // TODO: We should create a separate data structure that does not contain
+    // these live sets, and migrate to using that data structure from this point
+    // onward.
+    Info.LiveSet.clear();
+    Info.PointerToBase.clear();
+  }
+
   // Do all the fixups of the original live variables to their relocated selves
   SmallVector<Value *, 128> Live;
   for (size_t i = 0; i < Records.size(); i++) {
     PartiallyConstructedSafepointRecord &Info = Records[i];
+
     // We can't simply save the live set from the original insertion.  One of
     // the live values might be the result of a call which needs a safepoint.
     // That Value* no longer exists and we need to use the new gc_result.
@@ -2462,6 +2591,16 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F) {
 
   DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
 
+  auto NeedsRewrite = [](Instruction &I) {
+    if (UseDeoptBundles) {
+      if (ImmutableCallSite CS = ImmutableCallSite(&I))
+        return !callsGCLeafFunction(CS);
+      return false;
+    }
+
+    return isStatepoint(I);
+  };
+
   // Gather all the statepoints which need rewritten.  Be careful to only
   // consider those in reachable code since we need to ask dominance queries
   // when rewriting.  We'll delete the unreachable ones in a moment.
@@ -2469,7 +2608,7 @@ bool RewriteStatepointsForGC::runOnFunction(Function &F) {
   bool HasUnreachableStatepoint = false;
   for (Instruction &I : instructions(F)) {
     // TODO: only the ones with the flag set!
-    if (isStatepoint(I)) {
+    if (NeedsRewrite(I)) {
       if (DT.isReachableFromEntry(I.getParent()))
         ParsePointNeeded.push_back(CallSite(&I));
       else