Taints the non-acquire RMW's store address with the load part
[oota-llvm.git] / lib / CodeGen / ImplicitNullChecks.cpp
index d7644a6676c6a1cd4af37fb2be3322fa584f69d2..39c1b9fb9a6677aeaddff871bd9413b098c5e6e3 100644 (file)
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
 #include "llvm/CodeGen/Passes.h"
 #include "llvm/CodeGen/MachineFunction.h"
+#include "llvm/CodeGen/MachineMemOperand.h"
 #include "llvm/CodeGen/MachineOperand.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
@@ -35,6 +38,7 @@
 #include "llvm/CodeGen/MachineModuleInfo.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Instruction.h"
+#include "llvm/IR/LLVMContext.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Target/TargetSubtargetInfo.h"
@@ -47,6 +51,11 @@ static cl::opt<unsigned> PageSize("imp-null-check-page-size",
                                            "bytes"),
                                   cl::init(4096));
 
+#define DEBUG_TYPE "implicit-null-checks"
+
+STATISTIC(NumImplicitNullChecks,
+          "Number of explicit null checks made implicit");
+
 namespace {
 
 class ImplicitNullChecks : public MachineFunctionPass {
@@ -99,6 +108,98 @@ public:
 
   bool runOnMachineFunction(MachineFunction &MF) override;
 };
+
+/// \brief Detect re-ordering hazards and dependencies.
+///
+/// This class keeps track of defs and uses, and can be queried if a given
+/// machine instruction can be re-ordered from after the machine instructions
+/// seen so far to before them.
+class HazardDetector {
+  DenseSet<unsigned> RegDefs;
+  DenseSet<unsigned> RegUses;
+  const TargetRegisterInfo &TRI;
+  bool hasSeenClobber;
+
+public:
+  explicit HazardDetector(const TargetRegisterInfo &TRI) :
+    TRI(TRI), hasSeenClobber(false) {}
+
+  /// \brief Make a note of \p MI for later queries to isSafeToHoist.
+  ///
+  /// May clobber this HazardDetector instance.  \see isClobbered.
+  void rememberInstruction(MachineInstr *MI);
+
+  /// \brief Return true if it is safe to hoist \p MI from after all the
+  /// instructions seen so far (via rememberInstruction) to before it.
+  bool isSafeToHoist(MachineInstr *MI);
+
+  /// \brief Return true if this instance of HazardDetector has been clobbered
+  /// (i.e. has no more useful information).
+  ///
+  /// A HazardDetecter is clobbered when it sees a construct it cannot
+  /// understand, and it would have to return a conservative answer for all
+  /// future queries.  Having a separate clobbered state lets the client code
+  /// bail early, without making queries about all of the future instructions
+  /// (which would have returned the most conservative answer anyway).
+  ///
+  /// Calling rememberInstruction or isSafeToHoist on a clobbered HazardDetector
+  /// is an error.
+  bool isClobbered() { return hasSeenClobber; }
+};
+}
+
+
+void HazardDetector::rememberInstruction(MachineInstr *MI) {
+  assert(!isClobbered() &&
+         "Don't add instructions to a clobbered hazard detector");
+
+  if (MI->mayStore() || MI->hasUnmodeledSideEffects()) {
+    hasSeenClobber = true;
+    return;
+  }
+
+  for (auto *MMO : MI->memoperands()) {
+    // Right now we don't want to worry about LLVM's memory model.
+    if (!MMO->isUnordered()) {
+      hasSeenClobber = true;
+      return;
+    }
+  }
+
+  for (auto &MO : MI->operands()) {
+    if (!MO.isReg() || !MO.getReg())
+      continue;
+
+    if (MO.isDef())
+      RegDefs.insert(MO.getReg());
+    else
+      RegUses.insert(MO.getReg());
+  }
+}
+
+bool HazardDetector::isSafeToHoist(MachineInstr *MI) {
+  assert(!isClobbered() && "isSafeToHoist cannot do anything useful!");
+
+  // Right now we don't want to worry about LLVM's memory model.  This can be
+  // made more precise later.
+  for (auto *MMO : MI->memoperands())
+    if (!MMO->isUnordered())
+      return false;
+
+  for (auto &MO : MI->operands()) {
+    if (MO.isReg() && MO.getReg()) {
+      for (unsigned Reg : RegDefs)
+        if (TRI.regsOverlap(Reg, MO.getReg()))
+          return false;  // We found a write-after-write or read-after-write
+
+      if (MO.isDef())
+        for (unsigned Reg : RegUses)
+          if (TRI.regsOverlap(Reg, MO.getReg()))
+            return false;  // We found a write-after-read
+    }
+  }
+
+  return true;
 }
 
 bool ImplicitNullChecks::runOnMachineFunction(MachineFunction &MF) {
@@ -124,6 +225,13 @@ bool ImplicitNullChecks::analyzeBlockForNullChecks(
     MachineBasicBlock &MBB, SmallVectorImpl<NullCheck> &NullCheckList) {
   typedef TargetInstrInfo::MachineBranchPredicate MachineBranchPredicate;
 
+  MDNode *BranchMD = nullptr;
+  if (auto *BB = MBB.getBasicBlock())
+    BranchMD = BB->getTerminator()->getMetadata(LLVMContext::MD_make_implicit);
+
+  if (!BranchMD)
+    return false;
+
   MachineBranchPredicate MBP;
 
   if (TII->AnalyzeBranchPredicate(MBB, MBP, true))
@@ -164,32 +272,72 @@ bool ImplicitNullChecks::analyzeBlockForNullChecks(
   //   callq throw_NullPointerException
   //
   //  LblNotNull:
+  //   Inst0
+  //   Inst1
+  //   ...
   //   Def = Load (%RAX + <offset>)
   //   ...
   //
   //
   // we want to end up with
   //
-  //   Def = TrappingLoad (%RAX + <offset>), LblNull
+  //   Def = FaultingLoad (%RAX + <offset>), LblNull
   //   jmp LblNotNull ;; explicit or fallthrough
   //
   //  LblNotNull:
+  //   Inst0
+  //   Inst1
   //   ...
   //
   //  LblNull:
   //   callq throw_NullPointerException
   //
+  //
+  // To see why this is legal, consider the two possibilities:
+  //
+  //  1. %RAX is null: since we constrain <offset> to be less than PageSize, the
+  //     load instruction dereferences the null page, causing a segmentation
+  //     fault.
+  //
+  //  2. %RAX is not null: in this case we know that the load cannot fault, as
+  //     otherwise the load would've faulted in the original program too and the
+  //     original program would've been undefined.
+  //
+  // This reasoning cannot be extended to justify hoisting through arbitrary
+  // control flow.  For instance, in the example below (in pseudo-C)
+  //
+  //    if (ptr == null) { throw_npe(); unreachable; }
+  //    if (some_cond) { return 42; }
+  //    v = ptr->field;  // LD
+  //    ...
+  //
+  // we cannot (without code duplication) use the load marked "LD" to null check
+  // ptr -- clause (2) above does not apply in this case.  In the above program
+  // the safety of ptr->field can be dependent on some_cond; and, for instance,
+  // ptr could be some non-null invalid reference that never gets loaded from
+  // because some_cond is always true.
 
   unsigned PointerReg = MBP.LHS.getReg();
-  MachineInstr *MemOp = &*NotNullSucc->begin();
-  unsigned BaseReg, Offset;
-  if (TII->getMemOpBaseRegImmOfs(MemOp, BaseReg, Offset, TRI))
-    if (MemOp->mayLoad() && !MemOp->isPredicable() && BaseReg == PointerReg &&
-        Offset < PageSize && MemOp->getDesc().getNumDefs() == 1) {
-      NullCheckList.emplace_back(MemOp, MBP.ConditionDef, &MBB, NotNullSucc,
-                                 NullSucc);
-      return true;
-    }
+
+  HazardDetector HD(*TRI);
+
+  for (auto MII = NotNullSucc->begin(), MIE = NotNullSucc->end(); MII != MIE;
+       ++MII) {
+    MachineInstr *MI = &*MII;
+    unsigned BaseReg, Offset;
+    if (TII->getMemOpBaseRegImmOfs(MI, BaseReg, Offset, TRI))
+      if (MI->mayLoad() && !MI->isPredicable() && BaseReg == PointerReg &&
+          Offset < PageSize && MI->getDesc().getNumDefs() <= 1 &&
+          HD.isSafeToHoist(MI)) {
+        NullCheckList.emplace_back(MI, MBP.ConditionDef, &MBB, NotNullSucc,
+                                   NullSucc);
+        return true;
+      }
+
+    HD.rememberInstruction(MI);
+    if (HD.isClobbered())
+      return false;
+  }
 
   return false;
 }
@@ -201,14 +349,19 @@ bool ImplicitNullChecks::analyzeBlockForNullChecks(
 MachineInstr *ImplicitNullChecks::insertFaultingLoad(MachineInstr *LoadMI,
                                                      MachineBasicBlock *MBB,
                                                      MCSymbol *HandlerLabel) {
+  const unsigned NoRegister = 0; // Guaranteed to be the NoRegister value for
+                                 // all targets.
+
   DebugLoc DL;
   unsigned NumDefs = LoadMI->getDesc().getNumDefs();
-  assert(NumDefs == 1 && "other cases unhandled!");
-  (void)NumDefs;
+  assert(NumDefs <= 1 && "other cases unhandled!");
 
-  unsigned DefReg = LoadMI->defs().begin()->getReg();
-  assert(std::distance(LoadMI->defs().begin(), LoadMI->defs().end()) == 1 &&
-         "expected exactly one def!");
+  unsigned DefReg = NoRegister;
+  if (NumDefs != 0) {
+    DefReg = LoadMI->defs().begin()->getReg();
+    assert(std::distance(LoadMI->defs().begin(), LoadMI->defs().end()) == 1 &&
+           "expected exactly one def!");
+  }
 
   auto MIB = BuildMI(MBB, DL, TII->get(TargetOpcode::FAULTING_LOAD_OP), DefReg)
                  .addSym(HandlerLabel)
@@ -240,7 +393,7 @@ void ImplicitNullChecks::rewriteNullChecks(
     // touch the successors list for any basic block since we haven't changed
     // control flow, we've just made it implicit.
     insertFaultingLoad(NC.MemOperation, NC.CheckBlock, HandlerLabel);
-    NC.MemOperation->removeFromParent();
+    NC.MemOperation->eraseFromParent();
     NC.CheckOperation->eraseFromParent();
 
     // Insert an *unconditional* branch to not-null successor.
@@ -250,6 +403,8 @@ void ImplicitNullChecks::rewriteNullChecks(
     // Emit the HandlerLabel as an EH_LABEL.
     BuildMI(*NC.NullSucc, NC.NullSucc->begin(), DL,
             TII->get(TargetOpcode::EH_LABEL)).addSym(HandlerLabel);
+
+    NumImplicitNullChecks++;
   }
 }