[Statepoints] Refactor GCRelocateOperands into an intrinsic wrapper. NFC.
[oota-llvm.git] / include / llvm / IR / Statepoint.h
index 8159cde3425192e6ff72ef1e648e06db387b7ded..51a0951a97986f61567871f13f14e2b9bb0e77c0 100644 (file)
@@ -20,7 +20,9 @@
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CallSite.h"
 #include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/Support/Compiler.h"
 
@@ -35,17 +37,16 @@ enum class StatepointFlags {
   MaskAll = GCTransition ///< A bitmask that includes all valid flags.
 };
 
-class GCRelocateOperands;
+class GCRelocateInst;
 class ImmutableStatepoint;
 
 bool isStatepoint(const ImmutableCallSite &CS);
-bool isStatepoint(const Value *inst);
-bool isStatepoint(const Value &inst);
+bool isStatepoint(const Value *V);
+bool isStatepoint(const Value &V);
 
-bool isGCRelocate(const Value *inst);
 bool isGCRelocate(const ImmutableCallSite &CS);
 
-bool isGCResult(const Value *inst);
+bool isGCResult(const Value *V);
 bool isGCResult(const ImmutableCallSite &CS);
 
 /// Analogous to CallSiteBase, this provides most of the actual
@@ -54,20 +55,23 @@ bool isGCResult(const ImmutableCallSite &CS);
 /// concrete subtypes.  This is structured analogous to CallSite
 /// rather than the IntrinsicInst.h helpers since we want to support
 /// invokable statepoints in the near future.
-/// TODO: This does not currently allow the if(Statepoint S = ...)
-///   idiom used with CallSites.  Consider refactoring to support.
-template <typename InstructionTy, typename ValueTy, typename CallSiteTy>
+template <typename FunTy, typename InstructionTy, typename ValueTy,
+          typename CallSiteTy>
 class StatepointBase {
   CallSiteTy StatepointCS;
   void *operator new(size_t, unsigned) = delete;
   void *operator new(size_t s) = delete;
 
 protected:
-  explicit StatepointBase(InstructionTy *I) : StatepointCS(I) {
-    assert(isStatepoint(I));
+  explicit StatepointBase(InstructionTy *I) {
+    if (isStatepoint(I)) {
+      StatepointCS = CallSiteTy(I);
+      assert(StatepointCS && "isStatepoint implies CallSite");
+    }
   }
-  explicit StatepointBase(CallSiteTy CS) : StatepointCS(CS) {
-    assert(isStatepoint(CS));
+  explicit StatepointBase(CallSiteTy CS) {
+    if (isStatepoint(CS))
+      StatepointCS = CS;
   }
 
 public:
@@ -76,29 +80,37 @@ public:
   enum {
     IDPos = 0,
     NumPatchBytesPos = 1,
-    ActualCalleePos = 2,
+    CalledFunctionPos = 2,
     NumCallArgsPos = 3,
     FlagsPos = 4,
     CallArgsBeginPos = 5,
   };
 
+  explicit operator bool() const {
+    // We do not assign non-statepoint CallSites to StatepointCS.
+    return (bool)StatepointCS;
+  }
+
   /// Return the underlying CallSite.
-  CallSiteTy getCallSite() { return StatepointCS; }
+  CallSiteTy getCallSite() const {
+    assert(*this && "check validity first!");
+    return StatepointCS;
+  }
 
   uint64_t getFlags() const {
-    return cast<ConstantInt>(StatepointCS.getArgument(FlagsPos))
+    return cast<ConstantInt>(getCallSite().getArgument(FlagsPos))
         ->getZExtValue();
   }
 
   /// Return the ID associated with this statepoint.
-  uint64_t getID() {
-    const Value *IDVal = StatepointCS.getArgument(IDPos);
+  uint64_t getID() const {
+    const Value *IDVal = getCallSite().getArgument(IDPos);
     return cast<ConstantInt>(IDVal)->getZExtValue();
   }
 
   /// Return the number of patchable bytes associated with this statepoint.
-  uint32_t getNumPatchBytes() {
-    const Value *NumPatchBytesVal = StatepointCS.getArgument(NumPatchBytesPos);
+  uint32_t getNumPatchBytes() const {
+    const Value *NumPatchBytesVal = getCallSite().getArgument(NumPatchBytesPos);
     uint64_t NumPatchBytes =
       cast<ConstantInt>(NumPatchBytesVal)->getZExtValue();
     assert(isInt<32>(NumPatchBytes) && "should fit in 32 bits!");
@@ -106,99 +118,147 @@ public:
   }
 
   /// Return the value actually being called or invoked.
-  ValueTy *getActualCallee() {
-    return StatepointCS.getArgument(ActualCalleePos);
+  ValueTy *getCalledValue() const {
+    return getCallSite().getArgument(CalledFunctionPos);
+  }
+
+  InstructionTy *getInstruction() const {
+    return getCallSite().getInstruction();
+  }
+
+  /// Return the function being called if this is a direct call, otherwise
+  /// return null (if it's an indirect call).
+  FunTy *getCalledFunction() const {
+    return dyn_cast<Function>(getCalledValue());
+  }
+
+  /// Return the caller function for this statepoint.
+  FunTy *getCaller() const { return getCallSite().getCaller(); }
+
+  /// Determine if the statepoint cannot unwind.
+  bool doesNotThrow() const {
+    Function *F = getCalledFunction();
+    return getCallSite().doesNotThrow() || (F ? F->doesNotThrow() : false);
   }
 
   /// Return the type of the value returned by the call underlying the
   /// statepoint.
-  Type *getActualReturnType() {
+  Type *getActualReturnType() const {
     auto *FTy = cast<FunctionType>(
-        cast<PointerType>(getActualCallee()->getType())->getElementType());
+        cast<PointerType>(getCalledValue()->getType())->getElementType());
     return FTy->getReturnType();
   }
 
   /// Number of arguments to be passed to the actual callee.
-  int getNumCallArgs() {
-    const Value *NumCallArgsVal = StatepointCS.getArgument(NumCallArgsPos);
+  int getNumCallArgs() const {
+    const Value *NumCallArgsVal = getCallSite().getArgument(NumCallArgsPos);
     return cast<ConstantInt>(NumCallArgsVal)->getZExtValue();
   }
 
-  typename CallSiteTy::arg_iterator call_args_begin() {
-    assert(CallArgsBeginPos <= (int)StatepointCS.arg_size());
-    return StatepointCS.arg_begin() + CallArgsBeginPos;
+  size_t arg_size() const { return getNumCallArgs(); }
+  typename CallSiteTy::arg_iterator arg_begin() const {
+    assert(CallArgsBeginPos <= (int)getCallSite().arg_size());
+    return getCallSite().arg_begin() + CallArgsBeginPos;
   }
-  typename CallSiteTy::arg_iterator call_args_end() {
-    auto I = call_args_begin() + getNumCallArgs();
-    assert((StatepointCS.arg_end() - I) >= 0);
+  typename CallSiteTy::arg_iterator arg_end() const {
+    auto I = arg_begin() + arg_size();
+    assert((getCallSite().arg_end() - I) >= 0);
     return I;
   }
 
+  ValueTy *getArgument(unsigned Index) {
+    assert(Index < arg_size() && "out of bounds!");
+    return *(arg_begin() + Index);
+  }
+
   /// range adapter for call arguments
-  iterator_range<arg_iterator> call_args() {
-    return iterator_range<arg_iterator>(call_args_begin(), call_args_end());
+  iterator_range<arg_iterator> call_args() const {
+    return make_range(arg_begin(), arg_end());
+  }
+
+  /// \brief Return true if the call or the callee has the given attribute.
+  bool paramHasAttr(unsigned i, Attribute::AttrKind A) const {
+    Function *F = getCalledFunction();
+    return getCallSite().paramHasAttr(i + CallArgsBeginPos, A) ||
+          (F ? F->getAttributes().hasAttribute(i, A) : false);
   }
 
   /// Number of GC transition args.
-  int getNumTotalGCTransitionArgs() {
-    const Value *NumGCTransitionArgs = *call_args_end();
+  int getNumTotalGCTransitionArgs() const {
+    const Value *NumGCTransitionArgs = *arg_end();
     return cast<ConstantInt>(NumGCTransitionArgs)->getZExtValue();
   }
-  typename CallSiteTy::arg_iterator gc_transition_args_begin() {
-    auto I = call_args_end() + 1;
-    assert((StatepointCS.arg_end() - I) >= 0);
+  typename CallSiteTy::arg_iterator gc_transition_args_begin() const {
+    auto I = arg_end() + 1;
+    assert((getCallSite().arg_end() - I) >= 0);
     return I;
   }
-  typename CallSiteTy::arg_iterator gc_transition_args_end() {
+  typename CallSiteTy::arg_iterator gc_transition_args_end() const {
     auto I = gc_transition_args_begin() + getNumTotalGCTransitionArgs();
-    assert((StatepointCS.arg_end() - I) >= 0);
+    assert((getCallSite().arg_end() - I) >= 0);
     return I;
   }
 
   /// range adapter for GC transition arguments
-  iterator_range<arg_iterator> gc_transition_args() {
-    return iterator_range<arg_iterator>(gc_transition_args_begin(),
-                                        gc_transition_args_end());
+  iterator_range<arg_iterator> gc_transition_args() const {
+    return make_range(gc_transition_args_begin(), gc_transition_args_end());
   }
 
   /// Number of additional arguments excluding those intended
   /// for garbage collection.
-  int getNumTotalVMSArgs() {
+  int getNumTotalVMSArgs() const {
     const Value *NumVMSArgs = *gc_transition_args_end();
     return cast<ConstantInt>(NumVMSArgs)->getZExtValue();
   }
 
-  typename CallSiteTy::arg_iterator vm_state_begin() {
+  typename CallSiteTy::arg_iterator vm_state_begin() const {
     auto I = gc_transition_args_end() + 1;
-    assert((StatepointCS.arg_end() - I) >= 0);
+    assert((getCallSite().arg_end() - I) >= 0);
     return I;
   }
-  typename CallSiteTy::arg_iterator vm_state_end() {
+  typename CallSiteTy::arg_iterator vm_state_end() const {
     auto I = vm_state_begin() + getNumTotalVMSArgs();
-    assert((StatepointCS.arg_end() - I) >= 0);
+    assert((getCallSite().arg_end() - I) >= 0);
     return I;
   }
 
   /// range adapter for vm state arguments
-  iterator_range<arg_iterator> vm_state_args() {
-    return iterator_range<arg_iterator>(vm_state_begin(), vm_state_end());
+  iterator_range<arg_iterator> vm_state_args() const {
+    return make_range(vm_state_begin(), vm_state_end());
   }
 
-  typename CallSiteTy::arg_iterator gc_args_begin() { return vm_state_end(); }
-  typename CallSiteTy::arg_iterator gc_args_end() {
-    return StatepointCS.arg_end();
+  typename CallSiteTy::arg_iterator gc_args_begin() const {
+    return vm_state_end();
+  }
+  typename CallSiteTy::arg_iterator gc_args_end() const {
+    return getCallSite().arg_end();
+  }
+
+  unsigned gcArgsStartIdx() const {
+    return gc_args_begin() - getInstruction()->op_begin();
   }
 
   /// range adapter for gc arguments
-  iterator_range<arg_iterator> gc_args() {
-    return iterator_range<arg_iterator>(gc_args_begin(), gc_args_end());
+  iterator_range<arg_iterator> gc_args() const {
+    return make_range(gc_args_begin(), gc_args_end());
   }
 
   /// Get list of all gc reloactes linked to this statepoint
   /// May contain several relocations for the same base/derived pair.
   /// For example this could happen due to relocations on unwinding
   /// path of invoke.
-  std::vector<GCRelocateOperands> getRelocates();
+  std::vector<const GCRelocateInst *> getRelocates() const;
+
+  /// Get the experimental_gc_result call tied to this statepoint.  Can be
+  /// nullptr if there isn't a gc_result tied to this statepoint.  Guaranteed to
+  /// be a CallInst if non-null.
+  InstructionTy *getGCResult() const {
+    for (auto *U : getInstruction()->users())
+      if (isGCResult(U))
+        return cast<CallInst>(U);
+
+    return nullptr;
+  }
 
 #ifndef NDEBUG
   /// Asserts if this statepoint is malformed.  Common cases for failure
@@ -209,8 +269,8 @@ public:
            "number of arguments to actually callee can't be negative");
 
     // The internal asserts in the iterator accessors do the rest.
-    (void)call_args_begin();
-    (void)call_args_end();
+    (void)arg_begin();
+    (void)arg_end();
     (void)gc_transition_args_begin();
     (void)gc_transition_args_end();
     (void)vm_state_begin();
@@ -224,9 +284,10 @@ public:
 /// A specialization of it's base class for read only access
 /// to a gc.statepoint.
 class ImmutableStatepoint
-    : public StatepointBase<const Instruction, const Value, ImmutableCallSite> {
-  typedef StatepointBase<const Instruction, const Value, ImmutableCallSite>
-      Base;
+    : public StatepointBase<const Function, const Instruction, const Value,
+                            ImmutableCallSite> {
+  typedef StatepointBase<const Function, const Instruction, const Value,
+                         ImmutableCallSite> Base;
 
 public:
   explicit ImmutableStatepoint(const Instruction *I) : Base(I) {}
@@ -235,45 +296,40 @@ public:
 
 /// A specialization of it's base class for read-write access
 /// to a gc.statepoint.
-class Statepoint : public StatepointBase<Instruction, Value, CallSite> {
-  typedef StatepointBase<Instruction, Value, CallSite> Base;
+class Statepoint
+    : public StatepointBase<Function, Instruction, Value, CallSite> {
+  typedef StatepointBase<Function, Instruction, Value, CallSite> Base;
 
 public:
   explicit Statepoint(Instruction *I) : Base(I) {}
   explicit Statepoint(CallSite CS) : Base(CS) {}
 };
 
-/// Wraps a call to a gc.relocate and provides access to it's operands.
-/// TODO: This should likely be refactored to resememble the wrappers in
-/// InstrinsicInst.h.
-class GCRelocateOperands {
-  ImmutableCallSite RelocateCS;
-
+/// This represents the gc.relocate intrinsic.
+class GCRelocateInst : public IntrinsicInst {
 public:
-  GCRelocateOperands(const User *U) : RelocateCS(U) { assert(isGCRelocate(U)); }
-  GCRelocateOperands(const Instruction *inst) : RelocateCS(inst) {
-    assert(isGCRelocate(inst));
+  static inline bool classof(const IntrinsicInst *I) {
+    return I->getIntrinsicID() == Intrinsic::experimental_gc_relocate;
+  }
+  static inline bool classof(const Value *V) {
+    return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
   }
-  GCRelocateOperands(CallSite CS) : RelocateCS(CS) { assert(isGCRelocate(CS)); }
 
   /// Return true if this relocate is tied to the invoke statepoint.
   /// This includes relocates which are on the unwinding path.
   bool isTiedToInvoke() const {
-    const Value *Token = RelocateCS.getArgument(0);
+    const Value *Token = getArgOperand(0);
 
-    return isa<ExtractValueInst>(Token) || isa<InvokeInst>(Token);
+    return isa<LandingPadInst>(Token) || isa<InvokeInst>(Token);
   }
 
-  /// Get enclosed relocate intrinsic
-  ImmutableCallSite getUnderlyingCallSite() { return RelocateCS; }
-
   /// The statepoint with which this gc.relocate is associated.
-  const Instruction *getStatepoint() {
-    const Value *Token = RelocateCS.getArgument(0);
+  const Instruction *getStatepoint() const {
+    const Value *Token = getArgOperand(0);
 
     // This takes care both of relocates for call statepoints and relocates
     // on normal path of invoke statepoint.
-    if (!isa<ExtractValueInst>(Token)) {
+    if (!isa<LandingPadInst>(Token)) {
       return cast<Instruction>(Token);
     }
 
@@ -292,62 +348,58 @@ public:
   /// The index into the associate statepoint's argument list
   /// which contains the base pointer of the pointer whose
   /// relocation this gc.relocate describes.
-  unsigned getBasePtrIndex() {
-    return cast<ConstantInt>(RelocateCS.getArgument(1))->getZExtValue();
+  unsigned getBasePtrIndex() const {
+    return cast<ConstantInt>(getArgOperand(1))->getZExtValue();
   }
 
   /// The index into the associate statepoint's argument list which
   /// contains the pointer whose relocation this gc.relocate describes.
-  unsigned getDerivedPtrIndex() {
-    return cast<ConstantInt>(RelocateCS.getArgument(2))->getZExtValue();
+  unsigned getDerivedPtrIndex() const {
+    return cast<ConstantInt>(getArgOperand(2))->getZExtValue();
   }
 
-  Value *getBasePtr() {
+  Value *getBasePtr() const {
     ImmutableCallSite CS(getStatepoint());
     return *(CS.arg_begin() + getBasePtrIndex());
   }
 
-  Value *getDerivedPtr() {
+  Value *getDerivedPtr() const {
     ImmutableCallSite CS(getStatepoint());
     return *(CS.arg_begin() + getDerivedPtrIndex());
   }
 };
 
-template <typename InstructionTy, typename ValueTy, typename CallSiteTy>
-std::vector<GCRelocateOperands>
-StatepointBase<InstructionTy, ValueTy, CallSiteTy>::getRelocates() {
+template <typename FunTy, typename InstructionTy, typename ValueTy,
+          typename CallSiteTy>
+std::vector<const GCRelocateInst *>
+StatepointBase<FunTy, InstructionTy, ValueTy, CallSiteTy>::getRelocates()
+    const {
 
-  std::vector<GCRelocateOperands> Result;
+  std::vector<const GCRelocateInst *> Result;
 
   CallSiteTy StatepointCS = getCallSite();
 
   // Search for relocated pointers.  Note that working backwards from the
   // gc_relocates ensures that we only get pairs which are actually relocated
   // and used after the statepoint.
-  for (const User *U : StatepointCS.getInstruction()->users())
-    if (isGCRelocate(U))
-      Result.push_back(GCRelocateOperands(U));
+  for (const User *U : getInstruction()->users())
+    if (auto *Relocate = dyn_cast<GCRelocateInst>(U))
+      Result.push_back(Relocate);
 
   if (!StatepointCS.isInvoke())
     return Result;
 
   // We need to scan thorough exceptional relocations if it is invoke statepoint
   LandingPadInst *LandingPad =
-      cast<InvokeInst>(StatepointCS.getInstruction())->getLandingPadInst();
+      cast<InvokeInst>(getInstruction())->getLandingPadInst();
 
-  // Search for extract value from landingpad instruction to which
-  // gc relocates will be attached
+  // Search for gc relocates that are attached to this landingpad.
   for (const User *LandingPadUser : LandingPad->users()) {
-    if (!isa<ExtractValueInst>(LandingPadUser))
-      continue;
-
-    // gc relocates should be attached to this extract value
-    for (const User *U : LandingPadUser->users())
-      if (isGCRelocate(U))
-        Result.push_back(GCRelocateOperands(U));
+    if (auto *Relocate = dyn_cast<GCRelocateInst>(LandingPadUser))
+      Result.push_back(Relocate);
   }
   return Result;
 }
-} // namespace llvm
+}
 
 #endif