IR: Split Metadata from Value
[oota-llvm.git] / lib / IR / Verifier.cpp
index 887631b01716b91c4c8df115346a4f40b4c37019..6545361793e61a0c28b688aa02e331f1d04f0a99 100644 (file)
@@ -68,6 +68,7 @@
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
+#include "llvm/IR/Statepoint.h"
 #include "llvm/Pass.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
@@ -101,6 +102,13 @@ struct VerifierSupport {
     }
   }
 
+  void WriteMetadata(const Metadata *MD) {
+    if (!MD)
+      return;
+    MD->printAsOperand(OS, true, M);
+    OS << '\n';
+  }
+
   void WriteType(Type *T) {
     if (!T)
       return;
@@ -127,6 +135,24 @@ struct VerifierSupport {
     Broken = true;
   }
 
+  void CheckFailed(const Twine &Message, const Metadata *V1, const Metadata *V2,
+                   const Metadata *V3 = nullptr, const Metadata *V4 = nullptr) {
+    OS << Message.str() << "\n";
+    WriteMetadata(V1);
+    WriteMetadata(V2);
+    WriteMetadata(V3);
+    WriteMetadata(V4);
+    Broken = true;
+  }
+
+  void CheckFailed(const Twine &Message, const Metadata *V1,
+                   const Value *V2 = nullptr) {
+    OS << Message.str() << "\n";
+    WriteMetadata(V1);
+    WriteValue(V2);
+    Broken = true;
+  }
+
   void CheckFailed(const Twine &Message, const Value *V1, Type *T2,
                    const Value *V3 = nullptr) {
     OS << Message.str() << "\n";
@@ -166,7 +192,7 @@ class Verifier : public InstVisitor<Verifier>, VerifierSupport {
   SmallPtrSet<Instruction *, 16> InstsInThisBlock;
 
   /// \brief Keep track of the metadata nodes that have been checked already.
-  SmallPtrSet<MDNode *, 32> MDNodes;
+  SmallPtrSet<Metadata *, 32> MDNodes;
 
   /// \brief The personality function referenced by the LandingPadInsts.
   /// All LandingPadInsts within the same function must use the same
@@ -260,7 +286,9 @@ private:
   void visitAliaseeSubExpr(SmallPtrSetImpl<const GlobalAlias *> &Visited,
                            const GlobalAlias &A, const Constant &C);
   void visitNamedMDNode(const NamedMDNode &NMD);
-  void visitMDNode(MDNode &MD, Function *F);
+  void visitMDNode(MDNode &MD);
+  void visitMetadataAsValue(MetadataAsValue &MD, Function *F);
+  void visitValueAsMetadata(ValueAsMetadata &MD, Function *F);
   void visitComdat(const Comdat &C);
   void visitModuleIdents(const Module &M);
   void visitModuleFlags(const Module &M);
@@ -559,46 +587,77 @@ void Verifier::visitNamedMDNode(const NamedMDNode &NMD) {
     if (!MD)
       continue;
 
-    Assert1(!MD->isFunctionLocal(),
-            "Named metadata operand cannot be function local!", MD);
-    visitMDNode(*MD, nullptr);
+    visitMDNode(*MD);
   }
 }
 
-void Verifier::visitMDNode(MDNode &MD, Function *F) {
+void Verifier::visitMDNode(MDNode &MD) {
   // Only visit each node once.  Metadata can be mutually recursive, so this
   // avoids infinite recursion here, as well as being an optimization.
   if (!MDNodes.insert(&MD).second)
     return;
 
   for (unsigned i = 0, e = MD.getNumOperands(); i != e; ++i) {
-    Value *Op = MD.getOperand(i);
+    Metadata *Op = MD.getOperand(i);
     if (!Op)
       continue;
-    if (isa<Constant>(Op) || isa<MDString>(Op))
+    Assert2(!isa<LocalAsMetadata>(Op), "Invalid operand for global metadata!",
+            &MD, Op);
+    if (auto *N = dyn_cast<MDNode>(Op)) {
+      visitMDNode(*N);
       continue;
-    if (MDNode *N = dyn_cast<MDNode>(Op)) {
-      Assert2(MD.isFunctionLocal() || !N->isFunctionLocal(),
-              "Global metadata operand cannot be function local!", &MD, N);
-      visitMDNode(*N, F);
+    }
+    if (auto *V = dyn_cast<ValueAsMetadata>(Op)) {
+      visitValueAsMetadata(*V, nullptr);
       continue;
     }
-    Assert2(MD.isFunctionLocal(), "Invalid operand for global metadata!", &MD, Op);
-
-    // If this was an instruction, bb, or argument, verify that it is in the
-    // function that we expect.
-    Function *ActualF = nullptr;
-    if (Instruction *I = dyn_cast<Instruction>(Op))
-      ActualF = I->getParent()->getParent();
-    else if (BasicBlock *BB = dyn_cast<BasicBlock>(Op))
-      ActualF = BB->getParent();
-    else if (Argument *A = dyn_cast<Argument>(Op))
-      ActualF = A->getParent();
-    assert(ActualF && "Unimplemented function local metadata case!");
-
-    Assert2(ActualF == F, "function-local metadata used in wrong function",
-            &MD, Op);
   }
+
+  // Check these last, so we diagnose problems in operands first.
+  Assert1(!isa<MDNodeFwdDecl>(MD), "Expected no forward declarations!", &MD);
+  Assert1(MD.isResolved(), "All nodes should be resolved!", &MD);
+}
+
+void Verifier::visitValueAsMetadata(ValueAsMetadata &MD, Function *F) {
+  Assert1(MD.getValue(), "Expected valid value", &MD);
+  Assert2(!MD.getValue()->getType()->isMetadataTy(),
+          "Unexpected metadata round-trip through values", &MD, MD.getValue());
+
+  auto *L = dyn_cast<LocalAsMetadata>(&MD);
+  if (!L)
+    return;
+
+  Assert1(F, "function-local metadata used outside a function", L);
+
+  // If this was an instruction, bb, or argument, verify that it is in the
+  // function that we expect.
+  Function *ActualF = nullptr;
+  if (Instruction *I = dyn_cast<Instruction>(L->getValue())) {
+    Assert2(I->getParent(), "function-local metadata not in basic block", L, I);
+    ActualF = I->getParent()->getParent();
+  } else if (BasicBlock *BB = dyn_cast<BasicBlock>(L->getValue()))
+    ActualF = BB->getParent();
+  else if (Argument *A = dyn_cast<Argument>(L->getValue()))
+    ActualF = A->getParent();
+  assert(ActualF && "Unimplemented function local metadata case!");
+
+  Assert1(ActualF == F, "function-local metadata used in wrong function", L);
+}
+
+void Verifier::visitMetadataAsValue(MetadataAsValue &MDV, Function *F) {
+  Metadata *MD = MDV.getMetadata();
+  if (auto *N = dyn_cast<MDNode>(MD)) {
+    visitMDNode(*N);
+    return;
+  }
+
+  // Only visit each node once.  Metadata can be mutually recursive, so this
+  // avoids infinite recursion here, as well as being an optimization.
+  if (!MDNodes.insert(MD).second)
+    return;
+
+  if (auto *V = dyn_cast<ValueAsMetadata>(MD))
+    visitValueAsMetadata(*V, F);
 }
 
 void Verifier::visitComdat(const Comdat &C) {
@@ -649,7 +708,7 @@ void Verifier::visitModuleFlags(const Module &M) {
   for (unsigned I = 0, E = Requirements.size(); I != E; ++I) {
     const MDNode *Requirement = Requirements[I];
     const MDString *Flag = cast<MDString>(Requirement->getOperand(0));
-    const Value *ReqValue = Requirement->getOperand(1);
+    const Metadata *ReqValue = Requirement->getOperand(1);
 
     const MDNode *Op = SeenIDs.lookup(Flag);
     if (!Op) {
@@ -678,7 +737,7 @@ Verifier::visitModuleFlag(const MDNode *Op,
   Module::ModFlagBehavior MFB;
   if (!Module::isValidModFlagBehavior(Op->getOperand(0), MFB)) {
     Assert1(
-        dyn_cast<ConstantInt>(Op->getOperand(0)),
+        mdconst::dyn_extract<ConstantInt>(Op->getOperand(0)),
         "invalid behavior operand in module flag (expected constant integer)",
         Op->getOperand(0));
     Assert1(false,
@@ -1906,9 +1965,11 @@ void Verifier::visitRangeMetadata(Instruction& I,
   
   ConstantRange LastRange(1); // Dummy initial value
   for (unsigned i = 0; i < NumRanges; ++i) {
-    ConstantInt *Low = dyn_cast<ConstantInt>(Range->getOperand(2*i));
+    ConstantInt *Low =
+        mdconst::dyn_extract<ConstantInt>(Range->getOperand(2 * i));
     Assert1(Low, "The lower limit must be an integer!", Low);
-    ConstantInt *High = dyn_cast<ConstantInt>(Range->getOperand(2*i + 1));
+    ConstantInt *High =
+        mdconst::dyn_extract<ConstantInt>(Range->getOperand(2 * i + 1));
     Assert1(High, "The upper limit must be an integer!", High);
     Assert1(High->getType() == Low->getType() &&
             High->getType() == Ty, "Range types must match instruction type!",
@@ -1931,9 +1992,9 @@ void Verifier::visitRangeMetadata(Instruction& I,
   }
   if (NumRanges > 2) {
     APInt FirstLow =
-      dyn_cast<ConstantInt>(Range->getOperand(0))->getValue();
+        mdconst::dyn_extract<ConstantInt>(Range->getOperand(0))->getValue();
     APInt FirstHigh =
-      dyn_cast<ConstantInt>(Range->getOperand(1))->getValue();
+        mdconst::dyn_extract<ConstantInt>(Range->getOperand(1))->getValue();
     ConstantRange FirstRange(FirstLow, FirstHigh);
     Assert1(FirstRange.intersectWith(LastRange).isEmptySet(),
             "Intervals are overlapping", Range);
@@ -2277,8 +2338,8 @@ void Verifier::visitInstruction(Instruction &I) {
     Assert1(I.getType()->isFPOrFPVectorTy(),
             "fpmath requires a floating point result!", &I);
     Assert1(MD->getNumOperands() == 1, "fpmath takes one operand!", &I);
-    Value *Op0 = MD->getOperand(0);
-    if (ConstantFP *CFP0 = dyn_cast_or_null<ConstantFP>(Op0)) {
+    if (ConstantFP *CFP0 =
+            mdconst::dyn_extract_or_null<ConstantFP>(MD->getOperand(0))) {
       APFloat Accuracy = CFP0->getValueAPF();
       Assert1(Accuracy.isFiniteNonZero() && !Accuracy.isNegative(),
               "fpmath accuracy not a positive number!", &I);
@@ -2405,6 +2466,19 @@ bool Verifier::VerifyIntrinsicType(Type *Ty,
            !isa<VectorType>(ArgTys[D.getArgumentNumber()]) ||
            VectorType::getHalfElementsVectorType(
                          cast<VectorType>(ArgTys[D.getArgumentNumber()])) != Ty;
+  case IITDescriptor::SameVecWidthArgument: {
+    if (D.getArgumentNumber() >= ArgTys.size())
+      return true;
+    VectorType * ReferenceType =
+      dyn_cast<VectorType>(ArgTys[D.getArgumentNumber()]);
+    VectorType *ThisArgType = dyn_cast<VectorType>(Ty);
+    if (!ThisArgType || !ReferenceType || 
+        (ReferenceType->getVectorNumElements() !=
+         ThisArgType->getVectorNumElements()))
+      return true;
+    return VerifyIntrinsicType(ThisArgType->getVectorElementType(),
+                               Infos, ArgTys);
+  }
   }
   llvm_unreachable("unhandled");
 }
@@ -2482,8 +2556,8 @@ void Verifier::visitIntrinsicFunctionCall(Intrinsic::ID ID, CallInst &CI) {
   // If the intrinsic takes MDNode arguments, verify that they are either global
   // or are local to *this* function.
   for (unsigned i = 0, e = CI.getNumArgOperands(); i != e; ++i)
-    if (MDNode *MD = dyn_cast<MDNode>(CI.getArgOperand(i)))
-      visitMDNode(*MD, CI.getParent()->getParent());
+    if (auto *MD = dyn_cast<MetadataAsValue>(CI.getArgOperand(i)))
+      visitMetadataAsValue(*MD, CI.getParent()->getParent());
 
   switch (ID) {
   default:
@@ -2495,11 +2569,8 @@ void Verifier::visitIntrinsicFunctionCall(Intrinsic::ID ID, CallInst &CI) {
             "constant int", &CI);
     break;
   case Intrinsic::dbg_declare: {  // llvm.dbg.declare
-    Assert1(CI.getArgOperand(0) && isa<MDNode>(CI.getArgOperand(0)),
-                "invalid llvm.dbg.declare intrinsic call 1", &CI);
-    MDNode *MD = cast<MDNode>(CI.getArgOperand(0));
-    Assert1(MD->getNumOperands() == 1,
-                "invalid llvm.dbg.declare intrinsic call 2", &CI);
+    Assert1(CI.getArgOperand(0) && isa<MetadataAsValue>(CI.getArgOperand(0)),
+            "invalid llvm.dbg.declare intrinsic call 1", &CI);
   } break;
   case Intrinsic::memcpy:
   case Intrinsic::memmove:
@@ -2561,28 +2632,78 @@ void Verifier::visitIntrinsicFunctionCall(Intrinsic::ID ID, CallInst &CI) {
     break;
  
   case Intrinsic::experimental_gc_statepoint: {
-    // target, # call args = 0, # deopt args = 0, #gc args = 0 -> 4 args
-    assert(CI.getNumArgOperands() >= 4 &&
-           "not enough arguments to statepoint");
-    for (User* U : CI.users()) {
-      const CallInst* GCRelocCall = cast<const CallInst>(U);
-      const Function *GCRelocFn = GCRelocCall->getCalledFunction();
-      Assert1(GCRelocFn && GCRelocFn->isDeclaration() &&
-              (GCRelocFn->getIntrinsicID() == Intrinsic::experimental_gc_result_int ||
-               GCRelocFn->getIntrinsicID() == Intrinsic::experimental_gc_result_float ||
-               GCRelocFn->getIntrinsicID() == Intrinsic::experimental_gc_result_ptr ||
-               GCRelocFn->getIntrinsicID() == Intrinsic::experimental_gc_relocate),
-              "gc.result or gc.relocate are the only value uses of statepoint", &CI);
-      if (GCRelocFn->getIntrinsicID() == Intrinsic::experimental_gc_result_int ||
-          GCRelocFn->getIntrinsicID() == Intrinsic::experimental_gc_result_float ||
-          GCRelocFn->getIntrinsicID() == Intrinsic::experimental_gc_result_ptr ) {
-        Assert1(GCRelocCall->getNumArgOperands() == 1, "wrong number of arguments", &CI);
-        Assert2(GCRelocCall->getArgOperand(0) == &CI, "connected to wrong statepoint", &CI, GCRelocCall);
-      } else if (GCRelocFn->getIntrinsicID() == Intrinsic::experimental_gc_relocate) {
-        Assert1(GCRelocCall->getNumArgOperands() == 3, "wrong number of arguments", &CI);
-        Assert2(GCRelocCall->getArgOperand(0) == &CI, "connected to wrong statepoint", &CI, GCRelocCall);
-      } else {
-        llvm_unreachable("unsupported use type - how'd we get past the assert?");
+    Assert1(!CI.doesNotAccessMemory() &&
+            !CI.onlyReadsMemory(),
+            "gc.statepoint must read and write memory to preserve "
+            "reordering restrictions required by safepoint semantics", &CI);
+    Assert1(!CI.isInlineAsm(),
+            "gc.statepoint support for inline assembly unimplemented", &CI);
+    
+    const Value *Target = CI.getArgOperand(0);
+    const PointerType *PT = dyn_cast<PointerType>(Target->getType());
+    Assert2(PT && PT->getElementType()->isFunctionTy(),
+            "gc.statepoint callee must be of function pointer type",
+            &CI, Target);
+    FunctionType *TargetFuncType = cast<FunctionType>(PT->getElementType());
+    Assert1(!TargetFuncType->isVarArg(),
+            "gc.statepoint support for var arg functions not implemented", &CI);
+
+    const Value *NumCallArgsV = CI.getArgOperand(1);
+    Assert1(isa<ConstantInt>(NumCallArgsV),
+            "gc.statepoint number of arguments to underlying call "
+            "must be constant integer", &CI);
+    const int NumCallArgs = cast<ConstantInt>(NumCallArgsV)->getZExtValue();
+    Assert1(NumCallArgs >= 0,
+            "gc.statepoint number of arguments to underlying call "
+            "must be positive", &CI);
+    Assert1(NumCallArgs == (int)TargetFuncType->getNumParams(),
+            "gc.statepoint mismatch in number of call args", &CI);
+
+    const Value *Unused = CI.getArgOperand(2);
+    Assert1(isa<ConstantInt>(Unused) &&
+            cast<ConstantInt>(Unused)->isNullValue(),
+            "gc.statepoint parameter #3 must be zero", &CI);
+
+    // Verify that the types of the call parameter arguments match
+    // the type of the wrapped callee.
+    for (int i = 0; i < NumCallArgs; i++) {
+      Type *ParamType = TargetFuncType->getParamType(i);
+      Type *ArgType = CI.getArgOperand(3+i)->getType();
+      Assert1(ArgType == ParamType,
+              "gc.statepoint call argument does not match wrapped "
+              "function type", &CI);
+    }
+    const int EndCallArgsInx = 2+NumCallArgs;
+    const Value *NumDeoptArgsV = CI.getArgOperand(EndCallArgsInx+1);
+    Assert1(isa<ConstantInt>(NumDeoptArgsV),
+            "gc.statepoint number of deoptimization arguments "
+            "must be constant integer", &CI);
+    const int NumDeoptArgs = cast<ConstantInt>(NumDeoptArgsV)->getZExtValue();
+    Assert1(NumDeoptArgs >= 0,
+            "gc.statepoint number of deoptimization arguments "
+            "must be positive", &CI);
+
+    Assert1(4 + NumCallArgs + NumDeoptArgs <= (int)CI.getNumArgOperands(),
+            "gc.statepoint too few arguments according to length fields", &CI);
+    
+    // Check that the only uses of this gc.statepoint are gc.result or 
+    // gc.relocate calls which are tied to this statepoint and thus part
+    // of the same statepoint sequence
+    for (User *U : CI.users()) {
+      const CallInst *Call = dyn_cast<const CallInst>(U);
+      Assert2(Call, "illegal use of statepoint token", &CI, U);
+      if (!Call) continue;
+      Assert2(isGCRelocate(Call) || isGCResult(Call),
+              "gc.result or gc.relocate are the only value uses"
+              "of a gc.statepoint", &CI, U);
+      if (isGCResult(Call)) {
+        Assert2(Call->getArgOperand(0) == &CI,
+                "gc.result connected to wrong gc.statepoint",
+                &CI, Call);
+      } else if (isGCRelocate(Call)) {
+        Assert2(Call->getArgOperand(0) == &CI,
+                "gc.relocate connected to wrong gc.statepoint",
+                &CI, Call);
       }
     }
 
@@ -2599,21 +2720,24 @@ void Verifier::visitIntrinsicFunctionCall(Intrinsic::ID ID, CallInst &CI) {
   case Intrinsic::experimental_gc_result_int:
   case Intrinsic::experimental_gc_result_float:
   case Intrinsic::experimental_gc_result_ptr: {
-    Assert1(CI.getNumArgOperands() == 1, "wrong number of arguments", &CI);
-
     // Are we tied to a statepoint properly?
     CallSite StatepointCS(CI.getArgOperand(0));
     const Function *StatepointFn = StatepointCS.getCalledFunction();
     Assert2(StatepointFn && StatepointFn->isDeclaration() &&
             StatepointFn->getIntrinsicID() == Intrinsic::experimental_gc_statepoint,
             "token must be from a statepoint", &CI, CI.getArgOperand(0));
+
+    // Assert that result type matches wrapped callee.
+    const Value *Target = StatepointCS.getArgument(0);
+    const PointerType *PT = cast<PointerType>(Target->getType());
+    const FunctionType *TargetFuncType =
+      cast<FunctionType>(PT->getElementType());
+    Assert1(CI.getType() == TargetFuncType->getReturnType(),
+            "gc.result result type does not match wrapped callee",
+            &CI);
     break;
   }
   case Intrinsic::experimental_gc_relocate: {
-    // Some checks to ensure gc.relocate has the correct set of
-    // parameters.  TODO: we can make these tests much stricter.
-    Assert1(CI.getNumArgOperands() == 3, "wrong number of arguments", &CI);
-
     // Are we tied to a statepoint properly?
     CallSite StatepointCS(CI.getArgOperand(0));
     const Function *StatepointFn =
@@ -2638,6 +2762,12 @@ void Verifier::visitIntrinsicFunctionCall(Intrinsic::ID ID, CallInst &CI) {
     Assert1(0 <= DerivedIndex &&
             DerivedIndex < (int)StatepointCS.arg_size(),
             "index out of bounds", &CI);
+
+    // Assert that the result type matches the type of the relocated pointer
+    GCRelocateOperands Operands(&CI);
+    Assert1(Operands.derivedPtr()->getType() == CI.getType(),
+            "gc.relocate: relocating a pointer shouldn't change it's type",
+            &CI);
     break;
   }
   };