LLParser: Handle BlockAddresses on-the-fly
authorDuncan P. N. Exon Smith <dexonsmith@apple.com>
Tue, 19 Aug 2014 00:13:19 +0000 (00:13 +0000)
committerDuncan P. N. Exon Smith <dexonsmith@apple.com>
Tue, 19 Aug 2014 00:13:19 +0000 (00:13 +0000)
Previously all `blockaddress()` constants were treated as forward
references.  They were resolved twice:  once at the end of the function
in question, and again at the end of the module.  Furthermore, if the
same blockaddress was referenced N times, the parser created N distinct
`GlobalVariable`s (one for each reference).

Instead, resolve all block addresses at the beginning of the function,
creating the standard `BasicBlock` forward references used for all other
basic block references.  After the function, all references can be
resolved immediately.  To check for the condition of parsing block
addresses from within the same function, I created a reference to the
current per-function-state in `BlockAddressPFS`.

Also, create only one forward-reference per basic block.  Because
forward references to block addresses are rare, the data structure here
shouldn't matter.  If somehow it does someday, this can be pretty easily
changed to a `DenseMap<std::pair<ValID, ValID>, GV>`.

This is part of PR20515.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@215952 91177308-0d34-0410-b5e6-96231b3b80d8

lib/AsmParser/LLParser.cpp
lib/AsmParser/LLParser.h

index 0fa53019df1726841f14cf75b8c1f6f942d02cb2..13b648ce9ebcf932cdfee8102ff2157744c96521 100644 (file)
@@ -24,6 +24,7 @@
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/ValueSymbolTable.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/SaveAndRestore.h"
 #include "llvm/Support/raw_ostream.h"
 using namespace llvm;
 
@@ -129,28 +130,11 @@ bool LLParser::ValidateEndOfModule() {
     }
   }
 
-  // If there are entries in ForwardRefBlockAddresses at this point, they are
-  // references after the function was defined.  Resolve those now.
-  while (!ForwardRefBlockAddresses.empty()) {
-    // Okay, we are referencing an already-parsed function, resolve them now.
-    Function *TheFn = nullptr;
-    const ValID &Fn = ForwardRefBlockAddresses.begin()->first;
-    if (Fn.Kind == ValID::t_GlobalName)
-      TheFn = M->getFunction(Fn.StrVal);
-    else if (Fn.UIntVal < NumberedVals.size())
-      TheFn = dyn_cast<Function>(NumberedVals[Fn.UIntVal]);
-
-    if (!TheFn)
-      return Error(Fn.Loc, "unknown function referenced by blockaddress");
-
-    // Resolve all these references.
-    if (ResolveForwardRefBlockAddresses(TheFn,
-                                      ForwardRefBlockAddresses.begin()->second,
-                                        nullptr))
-      return true;
-
-    ForwardRefBlockAddresses.erase(ForwardRefBlockAddresses.begin());
-  }
+  // If there are entries in ForwardRefBlockAddresses at this point, the
+  // function was never defined.
+  if (!ForwardRefBlockAddresses.empty())
+    return Error(ForwardRefBlockAddresses.begin()->first.Loc,
+                 "expected function name in blockaddress");
 
   for (unsigned i = 0, e = NumberedTypes.size(); i != e; ++i)
     if (NumberedTypes[i].second.isValid())
@@ -193,38 +177,6 @@ bool LLParser::ValidateEndOfModule() {
   return false;
 }
 
-bool LLParser::ResolveForwardRefBlockAddresses(Function *TheFn,
-                             std::vector<std::pair<ValID, GlobalValue*> > &Refs,
-                                               PerFunctionState *PFS) {
-  // Loop over all the references, resolving them.
-  for (unsigned i = 0, e = Refs.size(); i != e; ++i) {
-    BasicBlock *Res;
-    if (PFS) {
-      if (Refs[i].first.Kind == ValID::t_LocalName)
-        Res = PFS->GetBB(Refs[i].first.StrVal, Refs[i].first.Loc);
-      else
-        Res = PFS->GetBB(Refs[i].first.UIntVal, Refs[i].first.Loc);
-    } else if (Refs[i].first.Kind == ValID::t_LocalID) {
-      return Error(Refs[i].first.Loc,
-       "cannot take address of numeric label after the function is defined");
-    } else {
-      Res = dyn_cast_or_null<BasicBlock>(
-                     TheFn->getValueSymbolTable().lookup(Refs[i].first.StrVal));
-    }
-
-    if (!Res)
-      return Error(Refs[i].first.Loc,
-                   "referenced value is not a basic block");
-
-    // Get the BlockAddress for this and update references to use it.
-    BlockAddress *BA = BlockAddress::get(TheFn, Res);
-    Refs[i].second->replaceAllUsesWith(BA);
-    Refs[i].second->eraseFromParent();
-  }
-  return false;
-}
-
-
 //===----------------------------------------------------------------------===//
 // Top-Level Entities
 //===----------------------------------------------------------------------===//
@@ -2186,28 +2138,6 @@ LLParser::PerFunctionState::~PerFunctionState() {
 }
 
 bool LLParser::PerFunctionState::FinishFunction() {
-  // Check to see if someone took the address of labels in this block.
-  if (!P.ForwardRefBlockAddresses.empty()) {
-    ValID FunctionID;
-    if (!F.getName().empty()) {
-      FunctionID.Kind = ValID::t_GlobalName;
-      FunctionID.StrVal = F.getName();
-    } else {
-      FunctionID.Kind = ValID::t_GlobalID;
-      FunctionID.UIntVal = FunctionNumber;
-    }
-
-    std::map<ValID, std::vector<std::pair<ValID, GlobalValue*> > >::iterator
-      FRBAI = P.ForwardRefBlockAddresses.find(FunctionID);
-    if (FRBAI != P.ForwardRefBlockAddresses.end()) {
-      // Resolve all these references.
-      if (P.ResolveForwardRefBlockAddresses(&F, FRBAI->second, this))
-        return true;
-
-      P.ForwardRefBlockAddresses.erase(FRBAI);
-    }
-  }
-
   if (!ForwardRefVals.empty())
     return P.Error(ForwardRefVals.begin()->second.second,
                    "use of undefined value '%" + ForwardRefVals.begin()->first +
@@ -2593,12 +2523,56 @@ bool LLParser::ParseValID(ValID &ID, PerFunctionState *PFS) {
     if (Label.Kind != ValID::t_LocalID && Label.Kind != ValID::t_LocalName)
       return Error(Label.Loc, "expected basic block name in blockaddress");
 
-    // Make a global variable as a placeholder for this reference.
-    GlobalVariable *FwdRef = new GlobalVariable(*M, Type::getInt8Ty(Context),
-                                           false, GlobalValue::InternalLinkage,
-                                                nullptr, "");
-    ForwardRefBlockAddresses[Fn].push_back(std::make_pair(Label, FwdRef));
-    ID.ConstantVal = FwdRef;
+    // Try to find the function (but skip it if it's forward-referenced).
+    GlobalValue *GV = nullptr;
+    if (Fn.Kind == ValID::t_GlobalID) {
+      if (Fn.UIntVal < NumberedVals.size())
+        GV = NumberedVals[Fn.UIntVal];
+    } else if (!ForwardRefVals.count(Fn.StrVal)) {
+      GV = M->getNamedValue(Fn.StrVal);
+    }
+    Function *F = nullptr;
+    if (GV) {
+      // Confirm that it's actually a function with a definition.
+      if (!isa<Function>(GV))
+        return Error(Fn.Loc, "expected function name in blockaddress");
+      F = cast<Function>(GV);
+      if (F->isDeclaration())
+        return Error(Fn.Loc, "cannot take blockaddress inside a declaration");
+    }
+
+    if (!F) {
+      // Make a global variable as a placeholder for this reference.
+      GlobalValue *&FwdRef = ForwardRefBlockAddresses[Fn][Label];
+      if (!FwdRef)
+        FwdRef = new GlobalVariable(*M, Type::getInt8Ty(Context), false,
+                                    GlobalValue::InternalLinkage, nullptr, "");
+      ID.ConstantVal = FwdRef;
+      ID.Kind = ValID::t_Constant;
+      return false;
+    }
+
+    // We found the function; now find the basic block.  Don't use PFS, since we
+    // might be inside a constant expression.
+    BasicBlock *BB;
+    if (BlockAddressPFS && F == &BlockAddressPFS->getFunction()) {
+      if (Label.Kind == ValID::t_LocalID)
+        BB = BlockAddressPFS->GetBB(Label.UIntVal, Label.Loc);
+      else
+        BB = BlockAddressPFS->GetBB(Label.StrVal, Label.Loc);
+      if (!BB)
+        return Error(Label.Loc, "referenced value is not a basic block");
+    } else {
+      if (Label.Kind == ValID::t_LocalID)
+        return Error(Label.Loc, "cannot take address of numeric label after "
+                                "the function is defined");
+      BB = dyn_cast_or_null<BasicBlock>(
+          F->getValueSymbolTable().lookup(Label.StrVal));
+      if (!BB)
+        return Error(Label.Loc, "referenced value is not a basic block");
+    }
+
+    ID.ConstantVal = BlockAddress::get(F, BB);
     ID.Kind = ValID::t_Constant;
     return false;
   }
@@ -2913,7 +2887,7 @@ bool LLParser::parseOptionalComdat(Comdat *&C) {
 /// ParseGlobalValueVector
 ///   ::= /*empty*/
 ///   ::= TypeAndValue (',' TypeAndValue)*
-bool LLParser::ParseGlobalValueVector(SmallVectorImpl<Constant*> &Elts) {
+bool LLParser::ParseGlobalValueVector(SmallVectorImpl<Constant *> &Elts) {
   // Empty list.
   if (Lex.getKind() == lltok::rbrace ||
       Lex.getKind() == lltok::rsquare ||
@@ -3341,9 +3315,60 @@ bool LLParser::ParseFunctionHeader(Function *&Fn, bool isDefine) {
                    ArgList[i].Name + "'");
   }
 
+  if (isDefine)
+    return false;
+
+  // Check the a declaration has no block address forward references.
+  ValID ID;
+  if (FunctionName.empty()) {
+    ID.Kind = ValID::t_GlobalID;
+    ID.UIntVal = NumberedVals.size() - 1;
+  } else {
+    ID.Kind = ValID::t_GlobalName;
+    ID.StrVal = FunctionName;
+  }
+  auto Blocks = ForwardRefBlockAddresses.find(ID);
+  if (Blocks != ForwardRefBlockAddresses.end())
+    return Error(Blocks->first.Loc,
+                 "cannot take blockaddress inside a declaration");
   return false;
 }
 
+bool LLParser::PerFunctionState::resolveForwardRefBlockAddresses() {
+  ValID ID;
+  if (FunctionNumber == -1) {
+    ID.Kind = ValID::t_GlobalName;
+    ID.StrVal = F.getName();
+  } else {
+    ID.Kind = ValID::t_GlobalID;
+    ID.UIntVal = FunctionNumber;
+  }
+
+  auto Blocks = P.ForwardRefBlockAddresses.find(ID);
+  if (Blocks == P.ForwardRefBlockAddresses.end())
+    return false;
+
+  for (const auto &I : Blocks->second) {
+    const ValID &BBID = I.first;
+    GlobalValue *GV = I.second;
+
+    assert((BBID.Kind == ValID::t_LocalID || BBID.Kind == ValID::t_LocalName) &&
+           "Expected local id or name");
+    BasicBlock *BB;
+    if (BBID.Kind == ValID::t_LocalName)
+      BB = GetBB(BBID.StrVal, BBID.Loc);
+    else
+      BB = GetBB(BBID.UIntVal, BBID.Loc);
+    if (!BB)
+      return P.Error(BBID.Loc, "referenced value is not a basic block");
+
+    GV->replaceAllUsesWith(BlockAddress::get(&F, BB));
+    GV->eraseFromParent();
+  }
+
+  P.ForwardRefBlockAddresses.erase(Blocks);
+  return false;
+}
 
 /// ParseFunctionBody
 ///   ::= '{' BasicBlock+ '}'
@@ -3358,6 +3383,12 @@ bool LLParser::ParseFunctionBody(Function &Fn) {
 
   PerFunctionState PFS(*this, Fn, FunctionNumber);
 
+  // Resolve block addresses and allow basic blocks to be forward-declared
+  // within this function.
+  if (PFS.resolveForwardRefBlockAddresses())
+    return true;
+  SaveAndRestore<PerFunctionState *> ScopeExit(BlockAddressPFS, &PFS);
+
   // We need at least one basic block.
   if (Lex.getKind() == lltok::rbrace)
     return TokError("function body requires at least one basic block");
index a3eb8d3ed3869797cb4e8d8c6cb5d57a5e05ac55..556b385c3c4201be4f8b2abe648fd3ca71207283 100644 (file)
@@ -128,17 +128,21 @@ namespace llvm {
 
     // References to blockaddress.  The key is the function ValID, the value is
     // a list of references to blocks in that function.
-    std::map<ValID, std::vector<std::pair<ValID, GlobalValue*> > >
-      ForwardRefBlockAddresses;
+    std::map<ValID, std::map<ValID, GlobalValue *>> ForwardRefBlockAddresses;
+    class PerFunctionState;
+    /// Reference to per-function state to allow basic blocks to be
+    /// forward-referenced by blockaddress instructions within the same
+    /// function.
+    PerFunctionState *BlockAddressPFS;
 
     // Attribute builder reference information.
     std::map<Value*, std::vector<unsigned> > ForwardRefAttrGroups;
     std::map<unsigned, AttrBuilder> NumberedAttrBuilders;
 
   public:
-    LLParser(StringRef F, SourceMgr &SM, SMDiagnostic &Err, Module *m) :
-      Context(m->getContext()), Lex(F, SM, Err, m->getContext()),
-      M(m) {}
+    LLParser(StringRef F, SourceMgr &SM, SMDiagnostic &Err, Module *m)
+        : Context(m->getContext()), Lex(F, SM, Err, m->getContext()), M(m),
+          BlockAddressPFS(nullptr) {}
     bool Run();
 
     LLVMContext &getContext() { return Context; }
@@ -327,6 +331,8 @@ namespace llvm {
       /// unnamed.  If there is an error, this returns null otherwise it returns
       /// the block being defined.
       BasicBlock *DefineBB(const std::string &Name, LocTy Loc);
+
+      bool resolveForwardRefBlockAddresses();
     };
 
     bool ConvertValIDToValue(Type *Ty, ValID &ID, Value *&V,
@@ -372,7 +378,7 @@ namespace llvm {
     bool ParseValID(ValID &ID, PerFunctionState *PFS = nullptr);
     bool ParseGlobalValue(Type *Ty, Constant *&V);
     bool ParseGlobalTypeAndValue(Constant *&V);
-    bool ParseGlobalValueVector(SmallVectorImpl<Constant*> &Elts);
+    bool ParseGlobalValueVector(SmallVectorImpl<Constant *> &Elts);
     bool parseOptionalComdat(Comdat *&C);
     bool ParseMetadataListValue(ValID &ID, PerFunctionState *PFS);
     bool ParseMetadataValue(ValID &ID, PerFunctionState *PFS);
@@ -432,10 +438,6 @@ namespace llvm {
     int ParseGetElementPtr(Instruction *&I, PerFunctionState &PFS);
     int ParseExtractValue(Instruction *&I, PerFunctionState &PFS);
     int ParseInsertValue(Instruction *&I, PerFunctionState &PFS);
-
-    bool ResolveForwardRefBlockAddresses(Function *TheFn,
-                             std::vector<std::pair<ValID, GlobalValue*> > &Refs,
-                                         PerFunctionState *PFS);
   };
 } // End llvm namespace