f397c38a996714af5b28e8e2e845a1dc9b7c84d0
[oota-llvm.git] / lib / Transforms / IPO / LowerBitSets.cpp
1 //===-- LowerBitSets.cpp - Bitset lowering pass ---------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass lowers bitset metadata and calls to the llvm.bitset.test intrinsic.
11 // See http://llvm.org/docs/LangRef.html#bitsets for more information.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "llvm/Transforms/IPO/LowerBitSets.h"
16 #include "llvm/Transforms/IPO.h"
17 #include "llvm/ADT/EquivalenceClasses.h"
18 #include "llvm/ADT/Statistic.h"
19 #include "llvm/IR/Constant.h"
20 #include "llvm/IR/Constants.h"
21 #include "llvm/IR/GlobalVariable.h"
22 #include "llvm/IR/IRBuilder.h"
23 #include "llvm/IR/Instructions.h"
24 #include "llvm/IR/Intrinsics.h"
25 #include "llvm/IR/Module.h"
26 #include "llvm/IR/Operator.h"
27 #include "llvm/Pass.h"
28 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
29
30 using namespace llvm;
31
32 #define DEBUG_TYPE "lowerbitsets"
33
34 STATISTIC(ByteArraySizeBits, "Byte array size in bits");
35 STATISTIC(ByteArraySizeBytes, "Byte array size in bytes");
36 STATISTIC(NumByteArraysCreated, "Number of byte arrays created");
37 STATISTIC(NumBitSetCallsLowered, "Number of bitset calls lowered");
38 STATISTIC(NumBitSetDisjointSets, "Number of disjoint sets of bitsets");
39
40 bool BitSetInfo::containsGlobalOffset(uint64_t Offset) const {
41   if (Offset < ByteOffset)
42     return false;
43
44   if ((Offset - ByteOffset) % (uint64_t(1) << AlignLog2) != 0)
45     return false;
46
47   uint64_t BitOffset = (Offset - ByteOffset) >> AlignLog2;
48   if (BitOffset >= BitSize)
49     return false;
50
51   return Bits.count(BitOffset);
52 }
53
54 bool BitSetInfo::containsValue(
55     const DataLayout *DL,
56     const DenseMap<GlobalVariable *, uint64_t> &GlobalLayout, Value *V,
57     uint64_t COffset) const {
58   if (auto GV = dyn_cast<GlobalVariable>(V)) {
59     auto I = GlobalLayout.find(GV);
60     if (I == GlobalLayout.end())
61       return false;
62     return containsGlobalOffset(I->second + COffset);
63   }
64
65   if (auto GEP = dyn_cast<GEPOperator>(V)) {
66     APInt APOffset(DL->getPointerSizeInBits(0), 0);
67     bool Result = GEP->accumulateConstantOffset(*DL, APOffset);
68     if (!Result)
69       return false;
70     COffset += APOffset.getZExtValue();
71     return containsValue(DL, GlobalLayout, GEP->getPointerOperand(),
72                          COffset);
73   }
74
75   if (auto Op = dyn_cast<Operator>(V)) {
76     if (Op->getOpcode() == Instruction::BitCast)
77       return containsValue(DL, GlobalLayout, Op->getOperand(0), COffset);
78
79     if (Op->getOpcode() == Instruction::Select)
80       return containsValue(DL, GlobalLayout, Op->getOperand(1), COffset) &&
81              containsValue(DL, GlobalLayout, Op->getOperand(2), COffset);
82   }
83
84   return false;
85 }
86
87 BitSetInfo BitSetBuilder::build() {
88   if (Min > Max)
89     Min = 0;
90
91   // Normalize each offset against the minimum observed offset, and compute
92   // the bitwise OR of each of the offsets. The number of trailing zeros
93   // in the mask gives us the log2 of the alignment of all offsets, which
94   // allows us to compress the bitset by only storing one bit per aligned
95   // address.
96   uint64_t Mask = 0;
97   for (uint64_t &Offset : Offsets) {
98     Offset -= Min;
99     Mask |= Offset;
100   }
101
102   BitSetInfo BSI;
103   BSI.ByteOffset = Min;
104
105   BSI.AlignLog2 = 0;
106   if (Mask != 0)
107     BSI.AlignLog2 = countTrailingZeros(Mask, ZB_Undefined);
108
109   // Build the compressed bitset while normalizing the offsets against the
110   // computed alignment.
111   BSI.BitSize = ((Max - Min) >> BSI.AlignLog2) + 1;
112   for (uint64_t Offset : Offsets) {
113     Offset >>= BSI.AlignLog2;
114     BSI.Bits.insert(Offset);
115   }
116
117   return BSI;
118 }
119
120 void GlobalLayoutBuilder::addFragment(const std::set<uint64_t> &F) {
121   // Create a new fragment to hold the layout for F.
122   Fragments.emplace_back();
123   std::vector<uint64_t> &Fragment = Fragments.back();
124   uint64_t FragmentIndex = Fragments.size() - 1;
125
126   for (auto ObjIndex : F) {
127     uint64_t OldFragmentIndex = FragmentMap[ObjIndex];
128     if (OldFragmentIndex == 0) {
129       // We haven't seen this object index before, so just add it to the current
130       // fragment.
131       Fragment.push_back(ObjIndex);
132     } else {
133       // This index belongs to an existing fragment. Copy the elements of the
134       // old fragment into this one and clear the old fragment. We don't update
135       // the fragment map just yet, this ensures that any further references to
136       // indices from the old fragment in this fragment do not insert any more
137       // indices.
138       std::vector<uint64_t> &OldFragment = Fragments[OldFragmentIndex];
139       Fragment.insert(Fragment.end(), OldFragment.begin(), OldFragment.end());
140       OldFragment.clear();
141     }
142   }
143
144   // Update the fragment map to point our object indices to this fragment.
145   for (uint64_t ObjIndex : Fragment)
146     FragmentMap[ObjIndex] = FragmentIndex;
147 }
148
149 void ByteArrayBuilder::allocate(const std::set<uint64_t> &Bits,
150                                 uint64_t BitSize, uint64_t &AllocByteOffset,
151                                 uint8_t &AllocMask) {
152   // Find the smallest current allocation.
153   unsigned Bit = 0;
154   for (unsigned I = 1; I != BitsPerByte; ++I)
155     if (BitAllocs[I] < BitAllocs[Bit])
156       Bit = I;
157
158   AllocByteOffset = BitAllocs[Bit];
159
160   // Add our size to it.
161   unsigned ReqSize = AllocByteOffset + BitSize;
162   BitAllocs[Bit] = ReqSize;
163   if (Bytes.size() < ReqSize)
164     Bytes.resize(ReqSize);
165
166   // Set our bits.
167   AllocMask = 1 << Bit;
168   for (uint64_t B : Bits)
169     Bytes[AllocByteOffset + B] |= AllocMask;
170 }
171
172 namespace {
173
174 struct ByteArrayInfo {
175   std::set<uint64_t> Bits;
176   uint64_t BitSize;
177   GlobalVariable *ByteArray;
178   Constant *Mask;
179 };
180
181 struct LowerBitSets : public ModulePass {
182   static char ID;
183   LowerBitSets() : ModulePass(ID) {
184     initializeLowerBitSetsPass(*PassRegistry::getPassRegistry());
185   }
186
187   Module *M;
188
189   const DataLayout *DL;
190   IntegerType *Int1Ty;
191   IntegerType *Int8Ty;
192   IntegerType *Int32Ty;
193   Type *Int32PtrTy;
194   IntegerType *Int64Ty;
195   Type *IntPtrTy;
196
197   // The llvm.bitsets named metadata.
198   NamedMDNode *BitSetNM;
199
200   // Mapping from bitset mdstrings to the call sites that test them.
201   DenseMap<MDString *, std::vector<CallInst *>> BitSetTestCallSites;
202
203   std::vector<ByteArrayInfo> ByteArrayInfos;
204
205   BitSetInfo
206   buildBitSet(MDString *BitSet,
207               const DenseMap<GlobalVariable *, uint64_t> &GlobalLayout);
208   ByteArrayInfo *createByteArray(BitSetInfo &BSI);
209   void allocateByteArrays();
210   Value *createBitSetTest(IRBuilder<> &B, BitSetInfo &BSI, ByteArrayInfo *&BAI,
211                           Value *BitOffset);
212   Value *
213   lowerBitSetCall(CallInst *CI, BitSetInfo &BSI, ByteArrayInfo *&BAI,
214                   GlobalVariable *CombinedGlobal,
215                   const DenseMap<GlobalVariable *, uint64_t> &GlobalLayout);
216   void buildBitSetsFromGlobals(const std::vector<MDString *> &BitSets,
217                                const std::vector<GlobalVariable *> &Globals);
218   bool buildBitSets();
219   bool eraseBitSetMetadata();
220
221   bool doInitialization(Module &M) override;
222   bool runOnModule(Module &M) override;
223 };
224
225 } // namespace
226
227 INITIALIZE_PASS_BEGIN(LowerBitSets, "lowerbitsets",
228                 "Lower bitset metadata", false, false)
229 INITIALIZE_PASS_END(LowerBitSets, "lowerbitsets",
230                 "Lower bitset metadata", false, false)
231 char LowerBitSets::ID = 0;
232
233 ModulePass *llvm::createLowerBitSetsPass() { return new LowerBitSets; }
234
235 bool LowerBitSets::doInitialization(Module &Mod) {
236   M = &Mod;
237
238   DL = M->getDataLayout();
239   if (!DL)
240     report_fatal_error("Data layout required");
241
242   Int1Ty = Type::getInt1Ty(M->getContext());
243   Int8Ty = Type::getInt8Ty(M->getContext());
244   Int32Ty = Type::getInt32Ty(M->getContext());
245   Int32PtrTy = PointerType::getUnqual(Int32Ty);
246   Int64Ty = Type::getInt64Ty(M->getContext());
247   IntPtrTy = DL->getIntPtrType(M->getContext(), 0);
248
249   BitSetNM = M->getNamedMetadata("llvm.bitsets");
250
251   BitSetTestCallSites.clear();
252
253   return false;
254 }
255
256 /// Build a bit set for BitSet using the object layouts in
257 /// GlobalLayout.
258 BitSetInfo LowerBitSets::buildBitSet(
259     MDString *BitSet,
260     const DenseMap<GlobalVariable *, uint64_t> &GlobalLayout) {
261   BitSetBuilder BSB;
262
263   // Compute the byte offset of each element of this bitset.
264   if (BitSetNM) {
265     for (MDNode *Op : BitSetNM->operands()) {
266       if (Op->getOperand(0) != BitSet || !Op->getOperand(1))
267         continue;
268       auto OpGlobal = cast<GlobalVariable>(
269           cast<ConstantAsMetadata>(Op->getOperand(1))->getValue());
270       uint64_t Offset =
271           cast<ConstantInt>(cast<ConstantAsMetadata>(Op->getOperand(2))
272                                 ->getValue())->getZExtValue();
273
274       Offset += GlobalLayout.find(OpGlobal)->second;
275
276       BSB.addOffset(Offset);
277     }
278   }
279
280   return BSB.build();
281 }
282
283 /// Build a test that bit BitOffset mod sizeof(Bits)*8 is set in
284 /// Bits. This pattern matches to the bt instruction on x86.
285 static Value *createMaskedBitTest(IRBuilder<> &B, Value *Bits,
286                                   Value *BitOffset) {
287   auto BitsType = cast<IntegerType>(Bits->getType());
288   unsigned BitWidth = BitsType->getBitWidth();
289
290   BitOffset = B.CreateZExtOrTrunc(BitOffset, BitsType);
291   Value *BitIndex =
292       B.CreateAnd(BitOffset, ConstantInt::get(BitsType, BitWidth - 1));
293   Value *BitMask = B.CreateShl(ConstantInt::get(BitsType, 1), BitIndex);
294   Value *MaskedBits = B.CreateAnd(Bits, BitMask);
295   return B.CreateICmpNE(MaskedBits, ConstantInt::get(BitsType, 0));
296 }
297
298 ByteArrayInfo *LowerBitSets::createByteArray(BitSetInfo &BSI) {
299   // Create globals to stand in for byte arrays and masks. These never actually
300   // get initialized, we RAUW and erase them later in allocateByteArrays() once
301   // we know the offset and mask to use.
302   auto ByteArrayGlobal = new GlobalVariable(
303       *M, Int8Ty, /*isConstant=*/true, GlobalValue::PrivateLinkage, nullptr);
304   auto MaskGlobal = new GlobalVariable(
305       *M, Int8Ty, /*isConstant=*/true, GlobalValue::PrivateLinkage, nullptr);
306
307   ByteArrayInfos.emplace_back();
308   ByteArrayInfo *BAI = &ByteArrayInfos.back();
309
310   BAI->Bits = BSI.Bits;
311   BAI->BitSize = BSI.BitSize;
312   BAI->ByteArray = ByteArrayGlobal;
313   BAI->Mask = ConstantExpr::getPtrToInt(MaskGlobal, Int8Ty);
314   return BAI;
315 }
316
317 void LowerBitSets::allocateByteArrays() {
318   std::stable_sort(ByteArrayInfos.begin(), ByteArrayInfos.end(),
319                    [](const ByteArrayInfo &BAI1, const ByteArrayInfo &BAI2) {
320                      return BAI1.BitSize > BAI2.BitSize;
321                    });
322
323   std::vector<uint64_t> ByteArrayOffsets(ByteArrayInfos.size());
324
325   ByteArrayBuilder BAB;
326   for (unsigned I = 0; I != ByteArrayInfos.size(); ++I) {
327     ByteArrayInfo *BAI = &ByteArrayInfos[I];
328
329     uint8_t Mask;
330     BAB.allocate(BAI->Bits, BAI->BitSize, ByteArrayOffsets[I], Mask);
331
332     BAI->Mask->replaceAllUsesWith(ConstantInt::get(Int8Ty, Mask));
333     cast<GlobalVariable>(BAI->Mask->getOperand(0))->eraseFromParent();
334   }
335
336   Constant *ByteArrayConst = ConstantDataArray::get(M->getContext(), BAB.Bytes);
337   auto ByteArray =
338       new GlobalVariable(*M, ByteArrayConst->getType(), /*isConstant=*/true,
339                          GlobalValue::PrivateLinkage, ByteArrayConst);
340
341   for (unsigned I = 0; I != ByteArrayInfos.size(); ++I) {
342     ByteArrayInfo *BAI = &ByteArrayInfos[I];
343
344     Constant *Idxs[] = {ConstantInt::get(IntPtrTy, 0),
345                         ConstantInt::get(IntPtrTy, ByteArrayOffsets[I])};
346     Constant *GEP = ConstantExpr::getInBoundsGetElementPtr(ByteArray, Idxs);
347
348     // Create an alias instead of RAUW'ing the gep directly. On x86 this ensures
349     // that the pc-relative displacement is folded into the lea instead of the
350     // test instruction getting another displacement.
351     GlobalAlias *Alias = GlobalAlias::create(
352         Int8Ty, 0, GlobalValue::PrivateLinkage, "bits", GEP, M);
353     BAI->ByteArray->replaceAllUsesWith(Alias);
354     BAI->ByteArray->eraseFromParent();
355   }
356
357   ByteArraySizeBits = BAB.BitAllocs[0] + BAB.BitAllocs[1] + BAB.BitAllocs[2] +
358                       BAB.BitAllocs[3] + BAB.BitAllocs[4] + BAB.BitAllocs[5] +
359                       BAB.BitAllocs[6] + BAB.BitAllocs[7];
360   ByteArraySizeBytes = BAB.Bytes.size();
361 }
362
363 /// Build a test that bit BitOffset is set in BSI, where
364 /// BitSetGlobal is a global containing the bits in BSI.
365 Value *LowerBitSets::createBitSetTest(IRBuilder<> &B, BitSetInfo &BSI,
366                                       ByteArrayInfo *&BAI, Value *BitOffset) {
367   if (BSI.BitSize <= 64) {
368     // If the bit set is sufficiently small, we can avoid a load by bit testing
369     // a constant.
370     IntegerType *BitsTy;
371     if (BSI.BitSize <= 32)
372       BitsTy = Int32Ty;
373     else
374       BitsTy = Int64Ty;
375
376     uint64_t Bits = 0;
377     for (auto Bit : BSI.Bits)
378       Bits |= uint64_t(1) << Bit;
379     Constant *BitsConst = ConstantInt::get(BitsTy, Bits);
380     return createMaskedBitTest(B, BitsConst, BitOffset);
381   } else {
382     if (!BAI) {
383       ++NumByteArraysCreated;
384       BAI = createByteArray(BSI);
385     }
386
387     Value *ByteAddr = B.CreateGEP(BAI->ByteArray, BitOffset);
388     Value *Byte = B.CreateLoad(ByteAddr);
389
390     Value *ByteAndMask = B.CreateAnd(Byte, BAI->Mask);
391     return B.CreateICmpNE(ByteAndMask, ConstantInt::get(Int8Ty, 0));
392   }
393 }
394
395 /// Lower a llvm.bitset.test call to its implementation. Returns the value to
396 /// replace the call with.
397 Value *LowerBitSets::lowerBitSetCall(
398     CallInst *CI, BitSetInfo &BSI, ByteArrayInfo *&BAI,
399     GlobalVariable *CombinedGlobal,
400     const DenseMap<GlobalVariable *, uint64_t> &GlobalLayout) {
401   Value *Ptr = CI->getArgOperand(0);
402
403   if (BSI.containsValue(DL, GlobalLayout, Ptr))
404     return ConstantInt::getTrue(CombinedGlobal->getParent()->getContext());
405
406   Constant *GlobalAsInt = ConstantExpr::getPtrToInt(CombinedGlobal, IntPtrTy);
407   Constant *OffsetedGlobalAsInt = ConstantExpr::getAdd(
408       GlobalAsInt, ConstantInt::get(IntPtrTy, BSI.ByteOffset));
409
410   BasicBlock *InitialBB = CI->getParent();
411
412   IRBuilder<> B(CI);
413
414   Value *PtrAsInt = B.CreatePtrToInt(Ptr, IntPtrTy);
415
416   if (BSI.isSingleOffset())
417     return B.CreateICmpEQ(PtrAsInt, OffsetedGlobalAsInt);
418
419   Value *PtrOffset = B.CreateSub(PtrAsInt, OffsetedGlobalAsInt);
420
421   Value *BitOffset;
422   if (BSI.AlignLog2 == 0) {
423     BitOffset = PtrOffset;
424   } else {
425     // We need to check that the offset both falls within our range and is
426     // suitably aligned. We can check both properties at the same time by
427     // performing a right rotate by log2(alignment) followed by an integer
428     // comparison against the bitset size. The rotate will move the lower
429     // order bits that need to be zero into the higher order bits of the
430     // result, causing the comparison to fail if they are nonzero. The rotate
431     // also conveniently gives us a bit offset to use during the load from
432     // the bitset.
433     Value *OffsetSHR =
434         B.CreateLShr(PtrOffset, ConstantInt::get(IntPtrTy, BSI.AlignLog2));
435     Value *OffsetSHL = B.CreateShl(
436         PtrOffset, ConstantInt::get(IntPtrTy, DL->getPointerSizeInBits(0) -
437                                                   BSI.AlignLog2));
438     BitOffset = B.CreateOr(OffsetSHR, OffsetSHL);
439   }
440
441   Constant *BitSizeConst = ConstantInt::get(IntPtrTy, BSI.BitSize);
442   Value *OffsetInRange = B.CreateICmpULT(BitOffset, BitSizeConst);
443
444   // If the bit set is all ones, testing against it is unnecessary.
445   if (BSI.isAllOnes())
446     return OffsetInRange;
447
448   TerminatorInst *Term = SplitBlockAndInsertIfThen(OffsetInRange, CI, false);
449   IRBuilder<> ThenB(Term);
450
451   // Now that we know that the offset is in range and aligned, load the
452   // appropriate bit from the bitset.
453   Value *Bit = createBitSetTest(ThenB, BSI, BAI, BitOffset);
454
455   // The value we want is 0 if we came directly from the initial block
456   // (having failed the range or alignment checks), or the loaded bit if
457   // we came from the block in which we loaded it.
458   B.SetInsertPoint(CI);
459   PHINode *P = B.CreatePHI(Int1Ty, 2);
460   P->addIncoming(ConstantInt::get(Int1Ty, 0), InitialBB);
461   P->addIncoming(Bit, ThenB.GetInsertBlock());
462   return P;
463 }
464
465 /// Given a disjoint set of bitsets and globals, layout the globals, build the
466 /// bit sets and lower the llvm.bitset.test calls.
467 void LowerBitSets::buildBitSetsFromGlobals(
468     const std::vector<MDString *> &BitSets,
469     const std::vector<GlobalVariable *> &Globals) {
470   // Build a new global with the combined contents of the referenced globals.
471   std::vector<Constant *> GlobalInits;
472   for (GlobalVariable *G : Globals) {
473     GlobalInits.push_back(G->getInitializer());
474     uint64_t InitSize = DL->getTypeAllocSize(G->getInitializer()->getType());
475
476     // Compute the amount of padding required to align the next element to the
477     // next power of 2.
478     uint64_t Padding = NextPowerOf2(InitSize - 1) - InitSize;
479
480     // Cap at 128 was found experimentally to have a good data/instruction
481     // overhead tradeoff.
482     if (Padding > 128)
483       Padding = RoundUpToAlignment(InitSize, 128) - InitSize;
484
485     GlobalInits.push_back(
486         ConstantAggregateZero::get(ArrayType::get(Int8Ty, Padding)));
487   }
488   if (!GlobalInits.empty())
489     GlobalInits.pop_back();
490   Constant *NewInit = ConstantStruct::getAnon(M->getContext(), GlobalInits);
491   auto CombinedGlobal =
492       new GlobalVariable(*M, NewInit->getType(), /*isConstant=*/true,
493                          GlobalValue::PrivateLinkage, NewInit);
494
495   const StructLayout *CombinedGlobalLayout =
496       DL->getStructLayout(cast<StructType>(NewInit->getType()));
497
498   // Compute the offsets of the original globals within the new global.
499   DenseMap<GlobalVariable *, uint64_t> GlobalLayout;
500   for (unsigned I = 0; I != Globals.size(); ++I)
501     // Multiply by 2 to account for padding elements.
502     GlobalLayout[Globals[I]] = CombinedGlobalLayout->getElementOffset(I * 2);
503
504   // For each bitset in this disjoint set...
505   for (MDString *BS : BitSets) {
506     // Build the bitset.
507     BitSetInfo BSI = buildBitSet(BS, GlobalLayout);
508
509     ByteArrayInfo *BAI = 0;
510
511     // Lower each call to llvm.bitset.test for this bitset.
512     for (CallInst *CI : BitSetTestCallSites[BS]) {
513       ++NumBitSetCallsLowered;
514       Value *Lowered = lowerBitSetCall(CI, BSI, BAI, CombinedGlobal, GlobalLayout);
515       CI->replaceAllUsesWith(Lowered);
516       CI->eraseFromParent();
517     }
518   }
519
520   // Build aliases pointing to offsets into the combined global for each
521   // global from which we built the combined global, and replace references
522   // to the original globals with references to the aliases.
523   for (unsigned I = 0; I != Globals.size(); ++I) {
524     // Multiply by 2 to account for padding elements.
525     Constant *CombinedGlobalIdxs[] = {ConstantInt::get(Int32Ty, 0),
526                                       ConstantInt::get(Int32Ty, I * 2)};
527     Constant *CombinedGlobalElemPtr =
528         ConstantExpr::getGetElementPtr(CombinedGlobal, CombinedGlobalIdxs);
529     GlobalAlias *GAlias = GlobalAlias::create(
530         Globals[I]->getType()->getElementType(),
531         Globals[I]->getType()->getAddressSpace(), Globals[I]->getLinkage(),
532         "", CombinedGlobalElemPtr, M);
533     GAlias->takeName(Globals[I]);
534     Globals[I]->replaceAllUsesWith(GAlias);
535     Globals[I]->eraseFromParent();
536   }
537 }
538
539 /// Lower all bit sets in this module.
540 bool LowerBitSets::buildBitSets() {
541   Function *BitSetTestFunc =
542       M->getFunction(Intrinsic::getName(Intrinsic::bitset_test));
543   if (!BitSetTestFunc)
544     return false;
545
546   // Equivalence class set containing bitsets and the globals they reference.
547   // This is used to partition the set of bitsets in the module into disjoint
548   // sets.
549   typedef EquivalenceClasses<PointerUnion<GlobalVariable *, MDString *>>
550       GlobalClassesTy;
551   GlobalClassesTy GlobalClasses;
552
553   for (const Use &U : BitSetTestFunc->uses()) {
554     auto CI = cast<CallInst>(U.getUser());
555
556     auto BitSetMDVal = dyn_cast<MetadataAsValue>(CI->getArgOperand(1));
557     if (!BitSetMDVal || !isa<MDString>(BitSetMDVal->getMetadata()))
558       report_fatal_error(
559           "Second argument of llvm.bitset.test must be metadata string");
560     auto BitSet = cast<MDString>(BitSetMDVal->getMetadata());
561
562     // Add the call site to the list of call sites for this bit set. We also use
563     // BitSetTestCallSites to keep track of whether we have seen this bit set
564     // before. If we have, we don't need to re-add the referenced globals to the
565     // equivalence class.
566     std::pair<DenseMap<MDString *, std::vector<CallInst *>>::iterator,
567               bool> Ins =
568         BitSetTestCallSites.insert(
569             std::make_pair(BitSet, std::vector<CallInst *>()));
570     Ins.first->second.push_back(CI);
571     if (!Ins.second)
572       continue;
573
574     // Add the bitset to the equivalence class.
575     GlobalClassesTy::iterator GCI = GlobalClasses.insert(BitSet);
576     GlobalClassesTy::member_iterator CurSet = GlobalClasses.findLeader(GCI);
577
578     if (!BitSetNM)
579       continue;
580
581     // Verify the bitset metadata and add the referenced globals to the bitset's
582     // equivalence class.
583     for (MDNode *Op : BitSetNM->operands()) {
584       if (Op->getNumOperands() != 3)
585         report_fatal_error(
586             "All operands of llvm.bitsets metadata must have 3 elements");
587
588       if (Op->getOperand(0) != BitSet || !Op->getOperand(1))
589         continue;
590
591       auto OpConstMD = dyn_cast<ConstantAsMetadata>(Op->getOperand(1));
592       if (!OpConstMD)
593         report_fatal_error("Bit set element must be a constant");
594       auto OpGlobal = dyn_cast<GlobalVariable>(OpConstMD->getValue());
595       if (!OpGlobal)
596         report_fatal_error("Bit set element must refer to global");
597
598       auto OffsetConstMD = dyn_cast<ConstantAsMetadata>(Op->getOperand(2));
599       if (!OffsetConstMD)
600         report_fatal_error("Bit set element offset must be a constant");
601       auto OffsetInt = dyn_cast<ConstantInt>(OffsetConstMD->getValue());
602       if (!OffsetInt)
603         report_fatal_error(
604             "Bit set element offset must be an integer constant");
605
606       CurSet = GlobalClasses.unionSets(
607           CurSet, GlobalClasses.findLeader(GlobalClasses.insert(OpGlobal)));
608     }
609   }
610
611   if (GlobalClasses.empty())
612     return false;
613
614   // For each disjoint set we found...
615   for (GlobalClassesTy::iterator I = GlobalClasses.begin(),
616                                  E = GlobalClasses.end();
617        I != E; ++I) {
618     if (!I->isLeader()) continue;
619
620     ++NumBitSetDisjointSets;
621
622     // Build the list of bitsets and referenced globals in this disjoint set.
623     std::vector<MDString *> BitSets;
624     std::vector<GlobalVariable *> Globals;
625     llvm::DenseMap<MDString *, uint64_t> BitSetIndices;
626     llvm::DenseMap<GlobalVariable *, uint64_t> GlobalIndices;
627     for (GlobalClassesTy::member_iterator MI = GlobalClasses.member_begin(I);
628          MI != GlobalClasses.member_end(); ++MI) {
629       if ((*MI).is<MDString *>()) {
630         BitSetIndices[MI->get<MDString *>()] = BitSets.size();
631         BitSets.push_back(MI->get<MDString *>());
632       } else {
633         GlobalIndices[MI->get<GlobalVariable *>()] = Globals.size();
634         Globals.push_back(MI->get<GlobalVariable *>());
635       }
636     }
637
638     // For each bitset, build a set of indices that refer to globals referenced
639     // by the bitset.
640     std::vector<std::set<uint64_t>> BitSetMembers(BitSets.size());
641     if (BitSetNM) {
642       for (MDNode *Op : BitSetNM->operands()) {
643         // Op = { bitset name, global, offset }
644         if (!Op->getOperand(1))
645           continue;
646         auto I = BitSetIndices.find(cast<MDString>(Op->getOperand(0)));
647         if (I == BitSetIndices.end())
648           continue;
649
650         auto OpGlobal = cast<GlobalVariable>(
651             cast<ConstantAsMetadata>(Op->getOperand(1))->getValue());
652         BitSetMembers[I->second].insert(GlobalIndices[OpGlobal]);
653       }
654     }
655
656     // Order the sets of indices by size. The GlobalLayoutBuilder works best
657     // when given small index sets first.
658     std::stable_sort(
659         BitSetMembers.begin(), BitSetMembers.end(),
660         [](const std::set<uint64_t> &O1, const std::set<uint64_t> &O2) {
661           return O1.size() < O2.size();
662         });
663
664     // Create a GlobalLayoutBuilder and provide it with index sets as layout
665     // fragments. The GlobalLayoutBuilder tries to lay out members of fragments
666     // as close together as possible.
667     GlobalLayoutBuilder GLB(Globals.size());
668     for (auto &&MemSet : BitSetMembers)
669       GLB.addFragment(MemSet);
670
671     // Build a vector of globals with the computed layout.
672     std::vector<GlobalVariable *> OrderedGlobals(Globals.size());
673     auto OGI = OrderedGlobals.begin();
674     for (auto &&F : GLB.Fragments)
675       for (auto &&Offset : F)
676         *OGI++ = Globals[Offset];
677
678     // Order bitsets by name for determinism.
679     std::sort(BitSets.begin(), BitSets.end(), [](MDString *S1, MDString *S2) {
680       return S1->getString() < S2->getString();
681     });
682
683     // Build the bitsets from this disjoint set.
684     buildBitSetsFromGlobals(BitSets, OrderedGlobals);
685   }
686
687   allocateByteArrays();
688
689   return true;
690 }
691
692 bool LowerBitSets::eraseBitSetMetadata() {
693   if (!BitSetNM)
694     return false;
695
696   M->eraseNamedMetadata(BitSetNM);
697   return true;
698 }
699
700 bool LowerBitSets::runOnModule(Module &M) {
701   bool Changed = buildBitSets();
702   Changed |= eraseBitSetMetadata();
703   return Changed;
704 }