[Statepoints][NFC] Add Statepoint::operator bool()
authorSanjoy Das <sanjoy@playingwithpointers.com>
Thu, 2 Jul 2015 02:53:36 +0000 (02:53 +0000)
committerSanjoy Das <sanjoy@playingwithpointers.com>
Thu, 2 Jul 2015 02:53:36 +0000 (02:53 +0000)
Summary:
This allows the "if (Statepoint SP = Statepoint(I))" idiom.

(I don't think this change needs review, this was uploaded to
phabricator to provide context for later dependent changes.)

Differential Revision: http://reviews.llvm.org/D10629

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@241232 91177308-0d34-0410-b5e6-96231b3b80d8

include/llvm/IR/Statepoint.h

index 4e6fb4171a3827acbc649714269ad5f466041c94..6cbffc39e706e994f15dcba770d4a8ea9bb1e294 100644 (file)
@@ -54,8 +54,6 @@ bool isGCResult(const ImmutableCallSite &CS);
 /// concrete subtypes.  This is structured analogous to CallSite
 /// rather than the IntrinsicInst.h helpers since we want to support
 /// invokable statepoints in the near future.
-/// TODO: This does not currently allow the if(Statepoint S = ...)
-///   idiom used with CallSites.  Consider refactoring to support.
 template <typename InstructionTy, typename ValueTy, typename CallSiteTy>
 class StatepointBase {
   CallSiteTy StatepointCS;
@@ -63,11 +61,15 @@ class StatepointBase {
   void *operator new(size_t s) = delete;
 
 protected:
-  explicit StatepointBase(InstructionTy *I) : StatepointCS(I) {
-    assert(isStatepoint(I));
+  explicit StatepointBase(InstructionTy *I) {
+    if (isStatepoint(I)) {
+      StatepointCS = CallSiteTy(I);
+      assert(StatepointCS && "isStatepoint implies CallSite");
+    }
   }
-  explicit StatepointBase(CallSiteTy CS) : StatepointCS(CS) {
-    assert(isStatepoint(CS));
+  explicit StatepointBase(CallSiteTy CS) {
+    if (isStatepoint(CS))
+      StatepointCS = CS;
   }
 
 public:
@@ -82,23 +84,31 @@ public:
     CallArgsBeginPos = 5,
   };
 
+  operator bool() const {
+    // We do not assign non-statepoint CallSites to StatepointCS.
+    return (bool)StatepointCS;
+  }
+
   /// Return the underlying CallSite.
-  CallSiteTy getCallSite() { return StatepointCS; }
+  CallSiteTy getCallSite() const {
+    assert(*this && "check validity first!");
+    return StatepointCS;
+  }
 
   uint64_t getFlags() const {
-    return cast<ConstantInt>(StatepointCS.getArgument(FlagsPos))
+    return cast<ConstantInt>(getCallSite().getArgument(FlagsPos))
         ->getZExtValue();
   }
 
   /// Return the ID associated with this statepoint.
   uint64_t getID() {
-    const Value *IDVal = StatepointCS.getArgument(IDPos);
+    const Value *IDVal = getCallSite().getArgument(IDPos);
     return cast<ConstantInt>(IDVal)->getZExtValue();
   }
 
   /// Return the number of patchable bytes associated with this statepoint.
   uint32_t getNumPatchBytes() {
-    const Value *NumPatchBytesVal = StatepointCS.getArgument(NumPatchBytesPos);
+    const Value *NumPatchBytesVal = getCallSite().getArgument(NumPatchBytesPos);
     uint64_t NumPatchBytes =
       cast<ConstantInt>(NumPatchBytesVal)->getZExtValue();
     assert(isInt<32>(NumPatchBytes) && "should fit in 32 bits!");
@@ -107,7 +117,7 @@ public:
 
   /// Return the value actually being called or invoked.
   ValueTy *getActualCallee() {
-    return StatepointCS.getArgument(ActualCalleePos);
+    return getCallSite().getArgument(ActualCalleePos);
   }
 
   /// Return the type of the value returned by the call underlying the
@@ -120,17 +130,17 @@ public:
 
   /// Number of arguments to be passed to the actual callee.
   int getNumCallArgs() {
-    const Value *NumCallArgsVal = StatepointCS.getArgument(NumCallArgsPos);
+    const Value *NumCallArgsVal = getCallSite().getArgument(NumCallArgsPos);
     return cast<ConstantInt>(NumCallArgsVal)->getZExtValue();
   }
 
   typename CallSiteTy::arg_iterator call_args_begin() {
-    assert(CallArgsBeginPos <= (int)StatepointCS.arg_size());
-    return StatepointCS.arg_begin() + CallArgsBeginPos;
+    assert(CallArgsBeginPos <= (int)getCallSite().arg_size());
+    return getCallSite().arg_begin() + CallArgsBeginPos;
   }
   typename CallSiteTy::arg_iterator call_args_end() {
     auto I = call_args_begin() + getNumCallArgs();
-    assert((StatepointCS.arg_end() - I) >= 0);
+    assert((getCallSite().arg_end() - I) >= 0);
     return I;
   }
 
@@ -146,12 +156,12 @@ public:
   }
   typename CallSiteTy::arg_iterator gc_transition_args_begin() {
     auto I = call_args_end() + 1;
-    assert((StatepointCS.arg_end() - I) >= 0);
+    assert((getCallSite().arg_end() - I) >= 0);
     return I;
   }
   typename CallSiteTy::arg_iterator gc_transition_args_end() {
     auto I = gc_transition_args_begin() + getNumTotalGCTransitionArgs();
-    assert((StatepointCS.arg_end() - I) >= 0);
+    assert((getCallSite().arg_end() - I) >= 0);
     return I;
   }
 
@@ -170,12 +180,12 @@ public:
 
   typename CallSiteTy::arg_iterator vm_state_begin() {
     auto I = gc_transition_args_end() + 1;
-    assert((StatepointCS.arg_end() - I) >= 0);
+    assert((getCallSite().arg_end() - I) >= 0);
     return I;
   }
   typename CallSiteTy::arg_iterator vm_state_end() {
     auto I = vm_state_begin() + getNumTotalVMSArgs();
-    assert((StatepointCS.arg_end() - I) >= 0);
+    assert((getCallSite().arg_end() - I) >= 0);
     return I;
   }
 
@@ -186,7 +196,7 @@ public:
 
   typename CallSiteTy::arg_iterator gc_args_begin() { return vm_state_end(); }
   typename CallSiteTy::arg_iterator gc_args_end() {
-    return StatepointCS.arg_end();
+    return getCallSite().arg_end();
   }
 
   /// range adapter for gc arguments