Revert r247080, "LowerBitSets: Extend pass to support functions as bitset
[oota-llvm.git] / lib / Transforms / IPO / LowerBitSets.cpp
index 9c1fc33c31034a1c6dccaae1b1acb4ad08b347f5..bf386a6c618636cd4043304a0334e71b199d0c78 100644 (file)
@@ -26,6 +26,8 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 
 using namespace llvm;
@@ -38,6 +40,11 @@ STATISTIC(NumByteArraysCreated, "Number of byte arrays created");
 STATISTIC(NumBitSetCallsLowered, "Number of bitset calls lowered");
 STATISTIC(NumBitSetDisjointSets, "Number of disjoint sets of bitsets");
 
+static cl::opt<bool> AvoidReuse(
+    "lowerbitsets-avoid-reuse",
+    cl::desc("Try to avoid reuse of byte array addresses using aliases"),
+    cl::Hidden, cl::init(true));
+
 bool BitSetInfo::containsGlobalOffset(uint64_t Offset) const {
   if (Offset < ByteOffset)
     return false;
@@ -85,6 +92,22 @@ bool BitSetInfo::containsValue(
   return false;
 }
 
+void BitSetInfo::print(raw_ostream &OS) const {
+  OS << "offset " << ByteOffset << " size " << BitSize << " align "
+     << (1 << AlignLog2);
+
+  if (isAllOnes()) {
+    OS << " all-ones\n";
+    return;
+  }
+
+  OS << " { ";
+  for (uint64_t B : Bits)
+    OS << B << ' ';
+  OS << "}\n";
+  return;
+}
+
 BitSetInfo BitSetBuilder::build() {
   if (Min > Max)
     Min = 0;
@@ -266,8 +289,10 @@ BitSetInfo LowerBitSets::buildBitSet(
     for (MDNode *Op : BitSetNM->operands()) {
       if (Op->getOperand(0) != BitSet || !Op->getOperand(1))
         continue;
-      auto OpGlobal = cast<GlobalVariable>(
+      auto OpGlobal = dyn_cast<GlobalVariable>(
           cast<ConstantAsMetadata>(Op->getOperand(1))->getValue());
+      if (!OpGlobal)
+        continue;
       uint64_t Offset =
           cast<ConstantInt>(cast<ConstantAsMetadata>(Op->getOperand(2))
                                 ->getValue())->getZExtValue();
@@ -344,7 +369,8 @@ void LowerBitSets::allocateByteArrays() {
 
     Constant *Idxs[] = {ConstantInt::get(IntPtrTy, 0),
                         ConstantInt::get(IntPtrTy, ByteArrayOffsets[I])};
-    Constant *GEP = ConstantExpr::getInBoundsGetElementPtr(ByteArray, Idxs);
+    Constant *GEP = ConstantExpr::getInBoundsGetElementPtr(
+        ByteArrayConst->getType(), ByteArray, Idxs);
 
     // Create an alias instead of RAUW'ing the gep directly. On x86 this ensures
     // that the pc-relative displacement is folded into the lea instead of the
@@ -352,8 +378,9 @@ void LowerBitSets::allocateByteArrays() {
     if (LinkerSubsectionsViaSymbols) {
       BAI->ByteArray->replaceAllUsesWith(GEP);
     } else {
-      GlobalAlias *Alias = GlobalAlias::create(
-          Int8Ty, 0, GlobalValue::PrivateLinkage, "bits", GEP, M);
+      GlobalAlias *Alias =
+          GlobalAlias::create(PointerType::getUnqual(Int8Ty),
+                              GlobalValue::PrivateLinkage, "bits", GEP, M);
       BAI->ByteArray->replaceAllUsesWith(Alias);
     }
     BAI->ByteArray->eraseFromParent();
@@ -389,7 +416,18 @@ Value *LowerBitSets::createBitSetTest(IRBuilder<> &B, BitSetInfo &BSI,
       BAI = createByteArray(BSI);
     }
 
-    Value *ByteAddr = B.CreateGEP(BAI->ByteArray, BitOffset);
+    Constant *ByteArray = BAI->ByteArray;
+    Type *Ty = BAI->ByteArray->getValueType();
+    if (!LinkerSubsectionsViaSymbols && AvoidReuse) {
+      // Each use of the byte array uses a different alias. This makes the
+      // backend less likely to reuse previously computed byte array addresses,
+      // improving the security of the CFI mechanism based on this pass.
+      ByteArray = GlobalAlias::create(BAI->ByteArray->getType(),
+                                      GlobalValue::PrivateLinkage, "bits_use",
+                                      ByteArray, M);
+    }
+
+    Value *ByteAddr = B.CreateGEP(Ty, ByteArray, BitOffset);
     Value *Byte = B.CreateLoad(ByteAddr);
 
     Value *ByteAndMask = B.CreateAnd(Byte, BAI->Mask);
@@ -512,6 +550,10 @@ void LowerBitSets::buildBitSetsFromGlobals(
   for (MDString *BS : BitSets) {
     // Build the bitset.
     BitSetInfo BSI = buildBitSet(BS, GlobalLayout);
+    DEBUG({
+      dbgs() << BS->getString() << ": ";
+      BSI.print(dbgs());
+    });
 
     ByteArrayInfo *BAI = 0;
 
@@ -531,15 +573,14 @@ void LowerBitSets::buildBitSetsFromGlobals(
     // Multiply by 2 to account for padding elements.
     Constant *CombinedGlobalIdxs[] = {ConstantInt::get(Int32Ty, 0),
                                       ConstantInt::get(Int32Ty, I * 2)};
-    Constant *CombinedGlobalElemPtr =
-        ConstantExpr::getGetElementPtr(CombinedGlobal, CombinedGlobalIdxs);
+    Constant *CombinedGlobalElemPtr = ConstantExpr::getGetElementPtr(
+        NewInit->getType(), CombinedGlobal, CombinedGlobalIdxs);
     if (LinkerSubsectionsViaSymbols) {
       Globals[I]->replaceAllUsesWith(CombinedGlobalElemPtr);
     } else {
-      GlobalAlias *GAlias = GlobalAlias::create(
-          Globals[I]->getType()->getElementType(),
-          Globals[I]->getType()->getAddressSpace(), Globals[I]->getLinkage(),
-          "", CombinedGlobalElemPtr, M);
+      GlobalAlias *GAlias =
+          GlobalAlias::create(Globals[I]->getType(), Globals[I]->getLinkage(),
+                              "", CombinedGlobalElemPtr, M);
       GAlias->takeName(Globals[I]);
       Globals[I]->replaceAllUsesWith(GAlias);
     }
@@ -604,7 +645,7 @@ bool LowerBitSets::buildBitSets() {
         report_fatal_error("Bit set element must be a constant");
       auto OpGlobal = dyn_cast<GlobalVariable>(OpConstMD->getValue());
       if (!OpGlobal)
-        report_fatal_error("Bit set element must refer to global");
+        continue;
 
       auto OffsetConstMD = dyn_cast<ConstantAsMetadata>(Op->getOperand(2));
       if (!OffsetConstMD)
@@ -658,8 +699,10 @@ bool LowerBitSets::buildBitSets() {
         if (I == BitSetIndices.end())
           continue;
 
-        auto OpGlobal = cast<GlobalVariable>(
+        auto OpGlobal = dyn_cast<GlobalVariable>(
             cast<ConstantAsMetadata>(Op->getOperand(1))->getValue());
+        if (!OpGlobal)
+          continue;
         BitSetMembers[I->second].insert(GlobalIndices[OpGlobal]);
       }
     }