Revert r247080, "LowerBitSets: Extend pass to support functions as bitset
[oota-llvm.git] / lib / Transforms / IPO / LowerBitSets.cpp
index f397c38a996714af5b28e8e2e845a1dc9b7c84d0..bf386a6c618636cd4043304a0334e71b199d0c78 100644 (file)
@@ -16,6 +16,7 @@
 #include "llvm/Transforms/IPO.h"
 #include "llvm/ADT/EquivalenceClasses.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/Triple.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/GlobalVariable.h"
@@ -25,6 +26,8 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 
 using namespace llvm;
@@ -37,6 +40,11 @@ STATISTIC(NumByteArraysCreated, "Number of byte arrays created");
 STATISTIC(NumBitSetCallsLowered, "Number of bitset calls lowered");
 STATISTIC(NumBitSetDisjointSets, "Number of disjoint sets of bitsets");
 
+static cl::opt<bool> AvoidReuse(
+    "lowerbitsets-avoid-reuse",
+    cl::desc("Try to avoid reuse of byte array addresses using aliases"),
+    cl::Hidden, cl::init(true));
+
 bool BitSetInfo::containsGlobalOffset(uint64_t Offset) const {
   if (Offset < ByteOffset)
     return false;
@@ -52,7 +60,7 @@ bool BitSetInfo::containsGlobalOffset(uint64_t Offset) const {
 }
 
 bool BitSetInfo::containsValue(
-    const DataLayout *DL,
+    const DataLayout &DL,
     const DenseMap<GlobalVariable *, uint64_t> &GlobalLayout, Value *V,
     uint64_t COffset) const {
   if (auto GV = dyn_cast<GlobalVariable>(V)) {
@@ -63,8 +71,8 @@ bool BitSetInfo::containsValue(
   }
 
   if (auto GEP = dyn_cast<GEPOperator>(V)) {
-    APInt APOffset(DL->getPointerSizeInBits(0), 0);
-    bool Result = GEP->accumulateConstantOffset(*DL, APOffset);
+    APInt APOffset(DL.getPointerSizeInBits(0), 0);
+    bool Result = GEP->accumulateConstantOffset(DL, APOffset);
     if (!Result)
       return false;
     COffset += APOffset.getZExtValue();
@@ -84,6 +92,22 @@ bool BitSetInfo::containsValue(
   return false;
 }
 
+void BitSetInfo::print(raw_ostream &OS) const {
+  OS << "offset " << ByteOffset << " size " << BitSize << " align "
+     << (1 << AlignLog2);
+
+  if (isAllOnes()) {
+    OS << " all-ones\n";
+    return;
+  }
+
+  OS << " { ";
+  for (uint64_t B : Bits)
+    OS << B << ' ';
+  OS << "}\n";
+  return;
+}
+
 BitSetInfo BitSetBuilder::build() {
   if (Min > Max)
     Min = 0;
@@ -186,7 +210,7 @@ struct LowerBitSets : public ModulePass {
 
   Module *M;
 
-  const DataLayout *DL;
+  bool LinkerSubsectionsViaSymbols;
   IntegerType *Int1Ty;
   IntegerType *Int8Ty;
   IntegerType *Int32Ty;
@@ -234,17 +258,17 @@ ModulePass *llvm::createLowerBitSetsPass() { return new LowerBitSets; }
 
 bool LowerBitSets::doInitialization(Module &Mod) {
   M = &Mod;
+  const DataLayout &DL = Mod.getDataLayout();
 
-  DL = M->getDataLayout();
-  if (!DL)
-    report_fatal_error("Data layout required");
+  Triple TargetTriple(M->getTargetTriple());
+  LinkerSubsectionsViaSymbols = TargetTriple.isMacOSX();
 
   Int1Ty = Type::getInt1Ty(M->getContext());
   Int8Ty = Type::getInt8Ty(M->getContext());
   Int32Ty = Type::getInt32Ty(M->getContext());
   Int32PtrTy = PointerType::getUnqual(Int32Ty);
   Int64Ty = Type::getInt64Ty(M->getContext());
-  IntPtrTy = DL->getIntPtrType(M->getContext(), 0);
+  IntPtrTy = DL.getIntPtrType(M->getContext(), 0);
 
   BitSetNM = M->getNamedMetadata("llvm.bitsets");
 
@@ -265,8 +289,10 @@ BitSetInfo LowerBitSets::buildBitSet(
     for (MDNode *Op : BitSetNM->operands()) {
       if (Op->getOperand(0) != BitSet || !Op->getOperand(1))
         continue;
-      auto OpGlobal = cast<GlobalVariable>(
+      auto OpGlobal = dyn_cast<GlobalVariable>(
           cast<ConstantAsMetadata>(Op->getOperand(1))->getValue());
+      if (!OpGlobal)
+        continue;
       uint64_t Offset =
           cast<ConstantInt>(cast<ConstantAsMetadata>(Op->getOperand(2))
                                 ->getValue())->getZExtValue();
@@ -343,14 +369,20 @@ void LowerBitSets::allocateByteArrays() {
 
     Constant *Idxs[] = {ConstantInt::get(IntPtrTy, 0),
                         ConstantInt::get(IntPtrTy, ByteArrayOffsets[I])};
-    Constant *GEP = ConstantExpr::getInBoundsGetElementPtr(ByteArray, Idxs);
+    Constant *GEP = ConstantExpr::getInBoundsGetElementPtr(
+        ByteArrayConst->getType(), ByteArray, Idxs);
 
     // Create an alias instead of RAUW'ing the gep directly. On x86 this ensures
     // that the pc-relative displacement is folded into the lea instead of the
     // test instruction getting another displacement.
-    GlobalAlias *Alias = GlobalAlias::create(
-        Int8Ty, 0, GlobalValue::PrivateLinkage, "bits", GEP, M);
-    BAI->ByteArray->replaceAllUsesWith(Alias);
+    if (LinkerSubsectionsViaSymbols) {
+      BAI->ByteArray->replaceAllUsesWith(GEP);
+    } else {
+      GlobalAlias *Alias =
+          GlobalAlias::create(PointerType::getUnqual(Int8Ty),
+                              GlobalValue::PrivateLinkage, "bits", GEP, M);
+      BAI->ByteArray->replaceAllUsesWith(Alias);
+    }
     BAI->ByteArray->eraseFromParent();
   }
 
@@ -384,7 +416,18 @@ Value *LowerBitSets::createBitSetTest(IRBuilder<> &B, BitSetInfo &BSI,
       BAI = createByteArray(BSI);
     }
 
-    Value *ByteAddr = B.CreateGEP(BAI->ByteArray, BitOffset);
+    Constant *ByteArray = BAI->ByteArray;
+    Type *Ty = BAI->ByteArray->getValueType();
+    if (!LinkerSubsectionsViaSymbols && AvoidReuse) {
+      // Each use of the byte array uses a different alias. This makes the
+      // backend less likely to reuse previously computed byte array addresses,
+      // improving the security of the CFI mechanism based on this pass.
+      ByteArray = GlobalAlias::create(BAI->ByteArray->getType(),
+                                      GlobalValue::PrivateLinkage, "bits_use",
+                                      ByteArray, M);
+    }
+
+    Value *ByteAddr = B.CreateGEP(Ty, ByteArray, BitOffset);
     Value *Byte = B.CreateLoad(ByteAddr);
 
     Value *ByteAndMask = B.CreateAnd(Byte, BAI->Mask);
@@ -399,6 +442,7 @@ Value *LowerBitSets::lowerBitSetCall(
     GlobalVariable *CombinedGlobal,
     const DenseMap<GlobalVariable *, uint64_t> &GlobalLayout) {
   Value *Ptr = CI->getArgOperand(0);
+  const DataLayout &DL = M->getDataLayout();
 
   if (BSI.containsValue(DL, GlobalLayout, Ptr))
     return ConstantInt::getTrue(CombinedGlobal->getParent()->getContext());
@@ -433,8 +477,8 @@ Value *LowerBitSets::lowerBitSetCall(
     Value *OffsetSHR =
         B.CreateLShr(PtrOffset, ConstantInt::get(IntPtrTy, BSI.AlignLog2));
     Value *OffsetSHL = B.CreateShl(
-        PtrOffset, ConstantInt::get(IntPtrTy, DL->getPointerSizeInBits(0) -
-                                                  BSI.AlignLog2));
+        PtrOffset,
+        ConstantInt::get(IntPtrTy, DL.getPointerSizeInBits(0) - BSI.AlignLog2));
     BitOffset = B.CreateOr(OffsetSHR, OffsetSHL);
   }
 
@@ -469,9 +513,10 @@ void LowerBitSets::buildBitSetsFromGlobals(
     const std::vector<GlobalVariable *> &Globals) {
   // Build a new global with the combined contents of the referenced globals.
   std::vector<Constant *> GlobalInits;
+  const DataLayout &DL = M->getDataLayout();
   for (GlobalVariable *G : Globals) {
     GlobalInits.push_back(G->getInitializer());
-    uint64_t InitSize = DL->getTypeAllocSize(G->getInitializer()->getType());
+    uint64_t InitSize = DL.getTypeAllocSize(G->getInitializer()->getType());
 
     // Compute the amount of padding required to align the next element to the
     // next power of 2.
@@ -493,7 +538,7 @@ void LowerBitSets::buildBitSetsFromGlobals(
                          GlobalValue::PrivateLinkage, NewInit);
 
   const StructLayout *CombinedGlobalLayout =
-      DL->getStructLayout(cast<StructType>(NewInit->getType()));
+      DL.getStructLayout(cast<StructType>(NewInit->getType()));
 
   // Compute the offsets of the original globals within the new global.
   DenseMap<GlobalVariable *, uint64_t> GlobalLayout;
@@ -505,6 +550,10 @@ void LowerBitSets::buildBitSetsFromGlobals(
   for (MDString *BS : BitSets) {
     // Build the bitset.
     BitSetInfo BSI = buildBitSet(BS, GlobalLayout);
+    DEBUG({
+      dbgs() << BS->getString() << ": ";
+      BSI.print(dbgs());
+    });
 
     ByteArrayInfo *BAI = 0;
 
@@ -524,14 +573,17 @@ void LowerBitSets::buildBitSetsFromGlobals(
     // Multiply by 2 to account for padding elements.
     Constant *CombinedGlobalIdxs[] = {ConstantInt::get(Int32Ty, 0),
                                       ConstantInt::get(Int32Ty, I * 2)};
-    Constant *CombinedGlobalElemPtr =
-        ConstantExpr::getGetElementPtr(CombinedGlobal, CombinedGlobalIdxs);
-    GlobalAlias *GAlias = GlobalAlias::create(
-        Globals[I]->getType()->getElementType(),
-        Globals[I]->getType()->getAddressSpace(), Globals[I]->getLinkage(),
-        "", CombinedGlobalElemPtr, M);
-    GAlias->takeName(Globals[I]);
-    Globals[I]->replaceAllUsesWith(GAlias);
+    Constant *CombinedGlobalElemPtr = ConstantExpr::getGetElementPtr(
+        NewInit->getType(), CombinedGlobal, CombinedGlobalIdxs);
+    if (LinkerSubsectionsViaSymbols) {
+      Globals[I]->replaceAllUsesWith(CombinedGlobalElemPtr);
+    } else {
+      GlobalAlias *GAlias =
+          GlobalAlias::create(Globals[I]->getType(), Globals[I]->getLinkage(),
+                              "", CombinedGlobalElemPtr, M);
+      GAlias->takeName(Globals[I]);
+      Globals[I]->replaceAllUsesWith(GAlias);
+    }
     Globals[I]->eraseFromParent();
   }
 }
@@ -593,7 +645,7 @@ bool LowerBitSets::buildBitSets() {
         report_fatal_error("Bit set element must be a constant");
       auto OpGlobal = dyn_cast<GlobalVariable>(OpConstMD->getValue());
       if (!OpGlobal)
-        report_fatal_error("Bit set element must refer to global");
+        continue;
 
       auto OffsetConstMD = dyn_cast<ConstantAsMetadata>(Op->getOperand(2));
       if (!OffsetConstMD)
@@ -647,8 +699,10 @@ bool LowerBitSets::buildBitSets() {
         if (I == BitSetIndices.end())
           continue;
 
-        auto OpGlobal = cast<GlobalVariable>(
+        auto OpGlobal = dyn_cast<GlobalVariable>(
             cast<ConstantAsMetadata>(Op->getOperand(1))->getValue());
+        if (!OpGlobal)
+          continue;
         BitSetMembers[I->second].insert(GlobalIndices[OpGlobal]);
       }
     }