Address Sanjoy's review comments to r256326
[oota-llvm.git] / include / llvm / IR / Statepoint.h
index e3c4243e9d81b8e0a8c36b58de185b5a4ee1011e..21b98a97a83c5598d9ec00e96285583fc954ef96 100644 (file)
 //
 //===----------------------------------------------------------------------===//
 
-#ifndef __LLVM_IR_STATEPOINT_H
-#define __LLVM_IR_STATEPOINT_H
+#ifndef LLVM_IR_STATEPOINT_H
+#define LLVM_IR_STATEPOINT_H
 
 #include "llvm/ADT/iterator_range.h"
+#include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CallSite.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/Support/Compiler.h"
 
 namespace llvm {
+/// The statepoint intrinsic accepts a set of flags as its third argument.
+/// Valid values come out of this set.
+enum class StatepointFlags {
+  None = 0,
+  GCTransition = 1, ///< Indicates that this statepoint is a transition from
+                    ///< GC-aware code to code that is not GC-aware.
+
+  MaskAll = GCTransition ///< A bitmask that includes all valid flags.
+};
+
+class GCRelocateOperands;
+class ImmutableStatepoint;
 
 bool isStatepoint(const ImmutableCallSite &CS);
-bool isStatepoint(const Instruction *inst);
-bool isStatepoint(const Instruction &inst);
+bool isStatepoint(const Value *V);
+bool isStatepoint(const Value &V);
 
-bool isGCRelocate(const Instruction *inst);
+bool isGCRelocate(const Value *V);
 bool isGCRelocate(const ImmutableCallSite &CS);
 
-bool isGCResult(const Instruction *inst);
+bool isGCResult(const Value *V);
 bool isGCResult(const ImmutableCallSite &CS);
 
 /// Analogous to CallSiteBase, this provides most of the actual
@@ -40,105 +55,224 @@ bool isGCResult(const ImmutableCallSite &CS);
 /// concrete subtypes.  This is structured analogous to CallSite
 /// rather than the IntrinsicInst.h helpers since we want to support
 /// invokable statepoints in the near future.
-/// TODO: This does not currently allow the if(Statepoint S = ...)
-///   idiom used with CallSites.  Consider refactoring to support.
-template <typename InstructionTy, typename ValueTy, typename CallSiteTy>
+template <typename FunTy, typename InstructionTy, typename ValueTy,
+          typename CallSiteTy>
 class StatepointBase {
   CallSiteTy StatepointCS;
-  void *operator new(size_t, unsigned) LLVM_DELETED_FUNCTION;
-  void *operator new(size_t s) LLVM_DELETED_FUNCTION;
+  void *operator new(size_t, unsigned) = delete;
+  void *operator new(size_t s) = delete;
 
- protected:
-  explicit StatepointBase(InstructionTy *I) : StatepointCS(I) {
-    assert(isStatepoint(I));
+protected:
+  explicit StatepointBase(InstructionTy *I) {
+    if (isStatepoint(I)) {
+      StatepointCS = CallSiteTy(I);
+      assert(StatepointCS && "isStatepoint implies CallSite");
+    }
   }
-  explicit StatepointBase(CallSiteTy CS) : StatepointCS(CS) {
-    assert(isStatepoint(CS));
+  explicit StatepointBase(CallSiteTy CS) {
+    if (isStatepoint(CS))
+      StatepointCS = CS;
   }
 
- public:
+public:
   typedef typename CallSiteTy::arg_iterator arg_iterator;
 
+  enum {
+    IDPos = 0,
+    NumPatchBytesPos = 1,
+    CalledFunctionPos = 2,
+    NumCallArgsPos = 3,
+    FlagsPos = 4,
+    CallArgsBeginPos = 5,
+  };
+
+  explicit operator bool() const {
+    // We do not assign non-statepoint CallSites to StatepointCS.
+    return (bool)StatepointCS;
+  }
+
   /// Return the underlying CallSite.
-  CallSiteTy getCallSite() {
+  CallSiteTy getCallSite() const {
+    assert(*this && "check validity first!");
     return StatepointCS;
   }
 
+  uint64_t getFlags() const {
+    return cast<ConstantInt>(getCallSite().getArgument(FlagsPos))
+        ->getZExtValue();
+  }
+
+  /// Return the ID associated with this statepoint.
+  uint64_t getID() const {
+    const Value *IDVal = getCallSite().getArgument(IDPos);
+    return cast<ConstantInt>(IDVal)->getZExtValue();
+  }
+
+  /// Return the number of patchable bytes associated with this statepoint.
+  uint32_t getNumPatchBytes() const {
+    const Value *NumPatchBytesVal = getCallSite().getArgument(NumPatchBytesPos);
+    uint64_t NumPatchBytes =
+      cast<ConstantInt>(NumPatchBytesVal)->getZExtValue();
+    assert(isInt<32>(NumPatchBytes) && "should fit in 32 bits!");
+    return NumPatchBytes;
+  }
+
   /// Return the value actually being called or invoked.
-  ValueTy *actualCallee() {
-    return StatepointCS.getArgument(0);
+  ValueTy *getCalledValue() const {
+    return getCallSite().getArgument(CalledFunctionPos);
   }
-  /// Number of arguments to be passed to the actual callee.
-  int numCallArgs() {
-    return cast<ConstantInt>(StatepointCS.getArgument(1))->getZExtValue();
+
+  InstructionTy *getInstruction() const {
+    return getCallSite().getInstruction();
   }
-  /// Number of additional arguments excluding those intended
-  /// for garbage collection.
-  int numTotalVMSArgs() {
-    return cast<ConstantInt>(StatepointCS.getArgument(3 + numCallArgs()))->getZExtValue();
+
+  /// Return the function being called if this is a direct call, otherwise
+  /// return null (if it's an indirect call).
+  FunTy *getCalledFunction() const {
+    return dyn_cast<Function>(getCalledValue());
   }
 
-  typename CallSiteTy::arg_iterator call_args_begin() {
-    // 3 = callTarget, #callArgs, flag
-    int Offset = 3;
-    assert(Offset <= (int)StatepointCS.arg_size());
-    return StatepointCS.arg_begin() + Offset;
+  /// Return the caller function for this statepoint.
+  FunTy *getCaller() const { return getCallSite().getCaller(); }
+
+  /// Determine if the statepoint cannot unwind.
+  bool doesNotThrow() const {
+    Function *F = getCalledFunction();
+    return getCallSite().doesNotThrow() || (F ? F->doesNotThrow() : false);
+  }
+
+  /// Return the type of the value returned by the call underlying the
+  /// statepoint.
+  Type *getActualReturnType() const {
+    auto *FTy = cast<FunctionType>(
+        cast<PointerType>(getCalledValue()->getType())->getElementType());
+    return FTy->getReturnType();
   }
-  typename CallSiteTy::arg_iterator call_args_end() {
-    int Offset = 3 + numCallArgs();
-    assert(Offset <= (int)StatepointCS.arg_size());
-    return StatepointCS.arg_begin() + Offset;
+
+  /// Number of arguments to be passed to the actual callee.
+  int getNumCallArgs() const {
+    const Value *NumCallArgsVal = getCallSite().getArgument(NumCallArgsPos);
+    return cast<ConstantInt>(NumCallArgsVal)->getZExtValue();
+  }
+
+  size_t arg_size() const { return getNumCallArgs(); }
+  typename CallSiteTy::arg_iterator arg_begin() const {
+    assert(CallArgsBeginPos <= (int)getCallSite().arg_size());
+    return getCallSite().arg_begin() + CallArgsBeginPos;
+  }
+  typename CallSiteTy::arg_iterator arg_end() const {
+    auto I = arg_begin() + arg_size();
+    assert((getCallSite().arg_end() - I) >= 0);
+    return I;
+  }
+
+  ValueTy *getArgument(unsigned Index) {
+    assert(Index < arg_size() && "out of bounds!");
+    return *(arg_begin() + Index);
   }
 
   /// range adapter for call arguments
-  iterator_range<arg_iterator> call_args() {
-    return iterator_range<arg_iterator>(call_args_begin(), call_args_end());
+  iterator_range<arg_iterator> call_args() const {
+    return make_range(arg_begin(), arg_end());
+  }
+
+  /// \brief Return true if the call or the callee has the given attribute.
+  bool paramHasAttr(unsigned i, Attribute::AttrKind A) const {
+    Function *F = getCalledFunction();
+    return getCallSite().paramHasAttr(i + CallArgsBeginPos, A) ||
+          (F ? F->getAttributes().hasAttribute(i, A) : false);
   }
 
-  typename CallSiteTy::arg_iterator vm_state_begin() {
-    return call_args_end();
+  /// Number of GC transition args.
+  int getNumTotalGCTransitionArgs() const {
+    const Value *NumGCTransitionArgs = *arg_end();
+    return cast<ConstantInt>(NumGCTransitionArgs)->getZExtValue();
+  }
+  typename CallSiteTy::arg_iterator gc_transition_args_begin() const {
+    auto I = arg_end() + 1;
+    assert((getCallSite().arg_end() - I) >= 0);
+    return I;
   }
-  typename CallSiteTy::arg_iterator vm_state_end() {
-    int Offset = 3 + numCallArgs() + 1 + numTotalVMSArgs();
-    assert(Offset <= (int)StatepointCS.arg_size());
-    return StatepointCS.arg_begin() + Offset;
+  typename CallSiteTy::arg_iterator gc_transition_args_end() const {
+    auto I = gc_transition_args_begin() + getNumTotalGCTransitionArgs();
+    assert((getCallSite().arg_end() - I) >= 0);
+    return I;
   }
 
-  /// range adapter for vm state arguments
-  iterator_range<arg_iterator> vm_state_args() {
-    return iterator_range<arg_iterator>(vm_state_begin(), vm_state_end());
+  /// range adapter for GC transition arguments
+  iterator_range<arg_iterator> gc_transition_args() const {
+    return make_range(gc_transition_args_begin(), gc_transition_args_end());
   }
 
-  typename CallSiteTy::arg_iterator first_vm_state_stack_begin() {
-    // 6 = numTotalVMSArgs, 1st_objectID, 1st_bci,
-    //     1st_#stack, 1st_#local, 1st_#monitor
-    return vm_state_begin() + 6;
+  /// Number of additional arguments excluding those intended
+  /// for garbage collection.
+  int getNumTotalVMSArgs() const {
+    const Value *NumVMSArgs = *gc_transition_args_end();
+    return cast<ConstantInt>(NumVMSArgs)->getZExtValue();
   }
 
-  typename CallSiteTy::arg_iterator gc_args_begin() {
+  typename CallSiteTy::arg_iterator vm_state_begin() const {
+    auto I = gc_transition_args_end() + 1;
+    assert((getCallSite().arg_end() - I) >= 0);
+    return I;
+  }
+  typename CallSiteTy::arg_iterator vm_state_end() const {
+    auto I = vm_state_begin() + getNumTotalVMSArgs();
+    assert((getCallSite().arg_end() - I) >= 0);
+    return I;
+  }
+
+  /// range adapter for vm state arguments
+  iterator_range<arg_iterator> vm_state_args() const {
+    return make_range(vm_state_begin(), vm_state_end());
+  }
+
+  typename CallSiteTy::arg_iterator gc_args_begin() const {
     return vm_state_end();
   }
-  typename CallSiteTy::arg_iterator gc_args_end() {
-    return StatepointCS.arg_end();
+  typename CallSiteTy::arg_iterator gc_args_end() const {
+    return getCallSite().arg_end();
+  }
+
+  unsigned gcArgsStartIdx() const {
+    return gc_args_begin() - getInstruction()->op_begin();
   }
 
   /// range adapter for gc arguments
-  iterator_range<arg_iterator> gc_args() {
-    return iterator_range<arg_iterator>(gc_args_begin(), gc_args_end());
+  iterator_range<arg_iterator> gc_args() const {
+    return make_range(gc_args_begin(), gc_args_end());
   }
 
+  /// Get list of all gc reloactes linked to this statepoint
+  /// May contain several relocations for the same base/derived pair.
+  /// For example this could happen due to relocations on unwinding
+  /// path of invoke.
+  std::vector<GCRelocateOperands> getRelocates() const;
+
+  /// Get the experimental_gc_result call tied to this statepoint.  Can be
+  /// nullptr if there isn't a gc_result tied to this statepoint.  Guaranteed to
+  /// be a CallInst if non-null.
+  InstructionTy *getGCResult() const {
+    for (auto *U : getInstruction()->users())
+      if (isGCResult(U))
+        return cast<CallInst>(U);
+
+    return nullptr;
+  }
 
 #ifndef NDEBUG
   /// Asserts if this statepoint is malformed.  Common cases for failure
   /// include incorrect length prefixes for variable length sections or
   /// illegal values for parameters.
   void verify() {
-    assert(numCallArgs() >= 0 &&
+    assert(getNumCallArgs() >= 0 &&
            "number of arguments to actually callee can't be negative");
 
     // The internal asserts in the iterator accessors do the rest.
-    (void)call_args_begin();
-    (void)call_args_end();
+    (void)arg_begin();
+    (void)arg_end();
+    (void)gc_transition_args_begin();
+    (void)gc_transition_args_end();
     (void)vm_state_begin();
     (void)vm_state_end();
     (void)gc_args_begin();
@@ -150,10 +284,10 @@ class StatepointBase {
 /// A specialization of it's base class for read only access
 /// to a gc.statepoint.
 class ImmutableStatepoint
-    : public StatepointBase<const Instruction, const Value,
+    : public StatepointBase<const Function, const Instruction, const Value,
                             ImmutableCallSite> {
-  typedef StatepointBase<const Instruction, const Value, ImmutableCallSite>
-      Base;
+  typedef StatepointBase<const Function, const Instruction, const Value,
+                         ImmutableCallSite> Base;
 
 public:
   explicit ImmutableStatepoint(const Instruction *I) : Base(I) {}
@@ -162,8 +296,9 @@ public:
 
 /// A specialization of it's base class for read-write access
 /// to a gc.statepoint.
-class Statepoint : public StatepointBase<Instruction, Value, CallSite> {
-  typedef StatepointBase<Instruction, Value, CallSite> Base;
+class Statepoint
+    : public StatepointBase<Function, Instruction, Value, CallSite> {
+  typedef StatepointBase<Function, Instruction, Value, CallSite> Base;
 
 public:
   explicit Statepoint(Instruction *I) : Base(I) {}
@@ -176,40 +311,107 @@ public:
 class GCRelocateOperands {
   ImmutableCallSite RelocateCS;
 
- public:
-  GCRelocateOperands(const User* U) : RelocateCS(U) {
-    assert(isGCRelocate(U));
-  }
+public:
+  GCRelocateOperands(const User *U) : RelocateCS(U) { assert(isGCRelocate(U)); }
   GCRelocateOperands(const Instruction *inst) : RelocateCS(inst) {
     assert(isGCRelocate(inst));
   }
-  GCRelocateOperands(CallSite CS) : RelocateCS(CS) {
-    assert(isGCRelocate(CS));
+  GCRelocateOperands(CallSite CS) : RelocateCS(CS) { assert(isGCRelocate(CS)); }
+
+  /// Return true if this relocate is tied to the invoke statepoint.
+  /// This includes relocates which are on the unwinding path.
+  bool isTiedToInvoke() const {
+    const Value *Token = RelocateCS.getArgument(0);
+
+    return isa<ExtractValueInst>(Token) || isa<InvokeInst>(Token);
   }
 
+  /// Get enclosed relocate intrinsic
+  ImmutableCallSite getUnderlyingCallSite() { return RelocateCS; }
+
   /// The statepoint with which this gc.relocate is associated.
-  const Instruction *statepoint() {
-    return cast<Instruction>(RelocateCS.getArgument(0));
+  const Instruction *getStatepoint() {
+    const Value *Token = RelocateCS.getArgument(0);
+
+    // This takes care both of relocates for call statepoints and relocates
+    // on normal path of invoke statepoint.
+    if (!isa<ExtractValueInst>(Token)) {
+      return cast<Instruction>(Token);
+    }
+
+    // This relocate is on exceptional path of an invoke statepoint
+    const BasicBlock *InvokeBB =
+        cast<Instruction>(Token)->getParent()->getUniquePredecessor();
+
+    assert(InvokeBB && "safepoints should have unique landingpads");
+    assert(InvokeBB->getTerminator() &&
+           "safepoint block should be well formed");
+    assert(isStatepoint(InvokeBB->getTerminator()));
+
+    return InvokeBB->getTerminator();
   }
+
   /// The index into the associate statepoint's argument list
   /// which contains the base pointer of the pointer whose
   /// relocation this gc.relocate describes.
-  int basePtrIndex() {
+  unsigned getBasePtrIndex() {
     return cast<ConstantInt>(RelocateCS.getArgument(1))->getZExtValue();
   }
+
   /// The index into the associate statepoint's argument list which
   /// contains the pointer whose relocation this gc.relocate describes.
-  int derivedPtrIndex() {
+  unsigned getDerivedPtrIndex() {
     return cast<ConstantInt>(RelocateCS.getArgument(2))->getZExtValue();
   }
-  Value *basePtr() {
-    ImmutableCallSite CS(statepoint());
-    return *(CS.arg_begin() + basePtrIndex());
+
+  Value *getBasePtr() {
+    ImmutableCallSite CS(getStatepoint());
+    return *(CS.arg_begin() + getBasePtrIndex());
   }
-  Value *derivedPtr() {
-    ImmutableCallSite CS(statepoint());
-    return *(CS.arg_begin() + derivedPtrIndex());
+
+  Value *getDerivedPtr() {
+    ImmutableCallSite CS(getStatepoint());
+    return *(CS.arg_begin() + getDerivedPtrIndex());
   }
 };
+
+template <typename FunTy, typename InstructionTy, typename ValueTy,
+          typename CallSiteTy>
+std::vector<GCRelocateOperands>
+StatepointBase<FunTy, InstructionTy, ValueTy, CallSiteTy>::getRelocates()
+    const {
+
+  std::vector<GCRelocateOperands> Result;
+
+  CallSiteTy StatepointCS = getCallSite();
+
+  // Search for relocated pointers.  Note that working backwards from the
+  // gc_relocates ensures that we only get pairs which are actually relocated
+  // and used after the statepoint.
+  for (const User *U : getInstruction()->users())
+    if (isGCRelocate(U))
+      Result.push_back(GCRelocateOperands(U));
+
+  if (!StatepointCS.isInvoke())
+    return Result;
+
+  // We need to scan thorough exceptional relocations if it is invoke statepoint
+  LandingPadInst *LandingPad =
+      cast<InvokeInst>(getInstruction())->getLandingPadInst();
+
+  // Search for extract value from landingpad instruction to which
+  // gc relocates will be attached
+  for (const User *LandingPadUser : LandingPad->users()) {
+    if (!isa<ExtractValueInst>(LandingPadUser))
+      continue;
+
+    // gc relocates should be attached to this extract value
+    for (const User *U : LandingPadUser->users())
+      if (isGCRelocate(U))
+        Result.push_back(GCRelocateOperands(U));
+  }
+  return Result;
 }
+}
+
 #endif