CDSPass completed, able to replace atomic instructions with functional calls and...
[c11llvm.git] / CDSPass.cpp
1 //===-- CdsPass.cpp - xxx -------------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file is a modified version of ThreadSanitizer.cpp, a part of a race detector.
11 //
12 // The tool is under development, for the details about previous versions see
13 // http://code.google.com/p/data-race-test
14 //
15 // The instrumentation phase is quite simple:
16 //   - Insert calls to run-time library before every memory access.
17 //      - Optimizations may apply to avoid instrumenting some of the accesses.
18 //   - Insert calls at function entry/exit.
19 // The rest is handled by the run-time library.
20 //===----------------------------------------------------------------------===//
21
22 #include "llvm/ADT/Statistic.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/SmallString.h"
25 #include "llvm/Analysis/ValueTracking.h"
26 #include "llvm/Analysis/CaptureTracking.h"
27 #include "llvm/IR/BasicBlock.h"
28 #include "llvm/IR/CFG.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/Instructions.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/IR/LegacyPassManager.h"
34 #include "llvm/IR/Module.h"
35 #include "llvm/IR/PassManager.h"
36 #include "llvm/Pass.h"
37 #include "llvm/ProfileData/InstrProf.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include "llvm/Support/AtomicOrdering.h"
40 #include "llvm/Support/Debug.h"
41 #include "llvm/Transforms/Scalar.h"
42 #include "llvm/Transforms/Utils/Local.h"
43 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
44 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
45 #include <list>
46 #include <vector>
47 // #include "llvm/Support/MathExtras.h"
48
49 #define DEBUG_TYPE "CDS"
50 using namespace llvm;
51
52 #define FUNCARRAYSIZE 4
53
54 STATISTIC(NumInstrumentedReads, "Number of instrumented reads");
55 STATISTIC(NumInstrumentedWrites, "Number of instrumented writes");
56 // STATISTIC(NumInstrumentedVtableWrites, "Number of vtable ptr writes");
57 // STATISTIC(NumInstrumentedVtableReads, "Number of vtable ptr reads");
58
59 STATISTIC(NumOmittedReadsBeforeWrite,
60           "Number of reads ignored due to following writes");
61 STATISTIC(NumOmittedReadsFromConstantGlobals,
62           "Number of reads from constant globals");
63 STATISTIC(NumOmittedReadsFromVtable, "Number of vtable reads");
64 STATISTIC(NumOmittedNonCaptured, "Number of accesses ignored due to capturing");
65
66 Type * Int8Ty;
67 Type * Int16Ty;
68 Type * Int32Ty;
69 Type * Int64Ty;
70 Type * OrdTy;
71
72 Type * Int8PtrTy;
73 Type * Int16PtrTy;
74 Type * Int32PtrTy;
75 Type * Int64PtrTy;
76
77 Type * VoidTy;
78
79 Constant * CdsLoad[FUNCARRAYSIZE];
80 Constant * CdsStore[FUNCARRAYSIZE];
81 Constant * CdsAtomicLoad[FUNCARRAYSIZE];
82 Constant * CdsAtomicStore[FUNCARRAYSIZE];
83 Constant * CdsAtomicRMW[AtomicRMWInst::LAST_BINOP + 1][FUNCARRAYSIZE];
84 Constant * CdsAtomicCAS[FUNCARRAYSIZE];
85 Constant * CdsAtomicThreadFence;
86
87 int getAtomicOrderIndex(AtomicOrdering order){
88   switch (order) {
89     case AtomicOrdering::Monotonic: 
90       return (int)AtomicOrderingCABI::relaxed;
91 //  case AtomicOrdering::Consume:         // not specified yet
92 //    return AtomicOrderingCABI::consume;
93     case AtomicOrdering::Acquire: 
94       return (int)AtomicOrderingCABI::acquire;
95     case AtomicOrdering::Release: 
96       return (int)AtomicOrderingCABI::release;
97     case AtomicOrdering::AcquireRelease: 
98       return (int)AtomicOrderingCABI::acq_rel;
99     case AtomicOrdering::SequentiallyConsistent: 
100       return (int)AtomicOrderingCABI::seq_cst;
101     default:
102       // unordered or Not Atomic
103       return -1;
104   }
105 }
106
107 int getTypeSize(Type* type) {
108   if (type==Int32PtrTy) {
109     return sizeof(int)*8;
110   } else if (type==Int8PtrTy) {
111     return sizeof(char)*8;
112   } else if (type==Int16PtrTy) {
113     return sizeof(short)*8;
114   } else if (type==Int64PtrTy) {
115     return sizeof(long long int)*8;
116   } else {
117     return sizeof(void*)*8;
118   }
119
120   return -1;
121 }
122
123 static int sizetoindex(int size) {
124   switch(size) {
125     case 8:     return 0;
126     case 16:    return 1;
127     case 32:    return 2;
128     case 64:    return 3;
129   }
130   return -1;
131 }
132
133 namespace {
134   struct CdsPass : public FunctionPass {
135     static char ID;
136     CdsPass() : FunctionPass(ID) {}
137     bool runOnFunction(Function &F) override; 
138
139   private:
140     void initializeCallbacks(Module &M);
141     bool instrumentLoadOrStore(Instruction *I, const DataLayout &DL);
142     bool instrumentAtomic(Instruction *I);
143     void chooseInstructionsToInstrument(SmallVectorImpl<Instruction *> &Local,
144                                       SmallVectorImpl<Instruction *> &All,
145                                       const DataLayout &DL);
146     bool addrPointsToConstantData(Value *Addr);
147   };
148 }
149
150 void CdsPass::initializeCallbacks(Module &M) {
151   LLVMContext &Ctx = M.getContext();
152
153   Int8Ty  = Type::getInt8Ty(Ctx);
154   Int16Ty = Type::getInt16Ty(Ctx);
155   Int32Ty = Type::getInt32Ty(Ctx);
156   Int64Ty = Type::getInt64Ty(Ctx);
157   OrdTy = Type::getInt32Ty(Ctx);
158
159   Int8PtrTy  = Type::getInt8PtrTy(Ctx);
160   Int16PtrTy = Type::getInt16PtrTy(Ctx);
161   Int32PtrTy = Type::getInt32PtrTy(Ctx);
162   Int64PtrTy = Type::getInt64PtrTy(Ctx);
163
164   VoidTy = Type::getVoidTy(Ctx);
165   
166
167   // Get the function to call from our untime library.
168   for (unsigned i = 0; i < FUNCARRAYSIZE; i++) {
169     const unsigned ByteSize = 1U << i;
170     const unsigned BitSize = ByteSize * 8;
171 //    errs() << BitSize << "\n";
172     std::string ByteSizeStr = utostr(ByteSize);
173     std::string BitSizeStr = utostr(BitSize);
174
175     Type *Ty = Type::getIntNTy(Ctx, BitSize);
176     Type *PtrTy = Ty->getPointerTo();
177
178     // uint8_t cds_atomic_load8 (void * obj, int atomic_index)
179     // void cds_atomic_store8 (void * obj, int atomic_index, uint8_t val)
180     SmallString<32> LoadName("cds_load" + BitSizeStr);
181     SmallString<32> StoreName("cds_store" + BitSizeStr);
182     SmallString<32> AtomicLoadName("cds_atomic_load" + BitSizeStr);
183     SmallString<32> AtomicStoreName("cds_atomic_store" + BitSizeStr);
184
185 //    CdsLoad[i]  = M.getOrInsertFunction(LoadName, Ty, PtrTy);
186 //    CdsStore[i] = M.getOrInsertFunction(StoreName, VoidTy, PtrTy, Ty);
187     CdsLoad[i]  = M.getOrInsertFunction(LoadName, VoidTy, PtrTy);
188     CdsStore[i] = M.getOrInsertFunction(StoreName, VoidTy, PtrTy);
189     CdsAtomicLoad[i]  = M.getOrInsertFunction(AtomicLoadName, Ty, PtrTy, OrdTy);
190     CdsAtomicStore[i] = M.getOrInsertFunction(AtomicStoreName, VoidTy, PtrTy, OrdTy, Ty);
191
192     for (int op = AtomicRMWInst::FIRST_BINOP; op <= AtomicRMWInst::LAST_BINOP; ++op) {
193       CdsAtomicRMW[op][i] = nullptr;
194       std::string NamePart;
195
196       if (op == AtomicRMWInst::Xchg)
197         NamePart = "_exchange";
198       else if (op == AtomicRMWInst::Add) 
199         NamePart = "_fetch_add";
200       else if (op == AtomicRMWInst::Sub)
201         NamePart = "_fetch_sub";
202       else if (op == AtomicRMWInst::And)
203         NamePart = "_fetch_and";
204       else if (op == AtomicRMWInst::Or)
205         NamePart = "_fetch_or";
206       else if (op == AtomicRMWInst::Xor)
207         NamePart = "_fetch_xor";
208       else
209         continue;
210
211       SmallString<32> AtomicRMWName("cds_atomic" + NamePart + BitSizeStr);
212       CdsAtomicRMW[op][i] = M.getOrInsertFunction(AtomicRMWName, Ty, PtrTy, OrdTy, Ty);
213     }
214
215     // only supportes strong version
216     SmallString<32> AtomicCASName("cds_atomic_compare_exchange" + BitSizeStr);    
217     CdsAtomicCAS[i]   = M.getOrInsertFunction(AtomicCASName, Ty, PtrTy, Ty, Ty, OrdTy, OrdTy);
218   }
219
220   CdsAtomicThreadFence = M.getOrInsertFunction("cds_atomic_thread_fence", VoidTy, OrdTy);
221 }
222
223 static bool isVtableAccess(Instruction *I) {
224   if (MDNode *Tag = I->getMetadata(LLVMContext::MD_tbaa))
225     return Tag->isTBAAVtableAccess();
226   return false;
227 }
228
229 static bool shouldInstrumentReadWriteFromAddress(const Module *M, Value *Addr) {
230   // Peel off GEPs and BitCasts.
231   Addr = Addr->stripInBoundsOffsets();
232
233   if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) {
234     if (GV->hasSection()) {
235       StringRef SectionName = GV->getSection();
236       // Check if the global is in the PGO counters section.
237       auto OF = Triple(M->getTargetTriple()).getObjectFormat();
238       if (SectionName.endswith(
239               getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false)))
240         return false;
241     }
242
243     // Check if the global is private gcov data.
244     if (GV->getName().startswith("__llvm_gcov") ||
245         GV->getName().startswith("__llvm_gcda"))
246       return false;
247   }
248
249   // Do not instrument acesses from different address spaces; we cannot deal
250   // with them.
251   if (Addr) {
252     Type *PtrTy = cast<PointerType>(Addr->getType()->getScalarType());
253     if (PtrTy->getPointerAddressSpace() != 0)
254       return false;
255   }
256
257   return true;
258 }
259
260 bool CdsPass::addrPointsToConstantData(Value *Addr) {
261   // If this is a GEP, just analyze its pointer operand.
262   if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Addr))
263     Addr = GEP->getPointerOperand();
264
265   if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) {
266     if (GV->isConstant()) {
267       // Reads from constant globals can not race with any writes.
268       NumOmittedReadsFromConstantGlobals++;
269       return true;
270     }
271   } else if (LoadInst *L = dyn_cast<LoadInst>(Addr)) {
272     if (isVtableAccess(L)) {
273       // Reads from a vtable pointer can not race with any writes.
274       NumOmittedReadsFromVtable++;
275       return true;
276     }
277   }
278   return false;
279 }
280
281 bool CdsPass::runOnFunction(Function &F) {
282   if (F.getName() == "main") 
283     F.setName("user_main");
284
285   initializeCallbacks( *F.getParent() );
286
287   SmallVector<Instruction*, 8> AllLoadsAndStores;
288   SmallVector<Instruction*, 8> LocalLoadsAndStores;
289   SmallVector<Instruction*, 8> AtomicAccesses;
290
291   std::vector<Instruction *> worklist;
292
293   bool Res = false;
294   const DataLayout &DL = F.getParent()->getDataLayout();
295   
296   errs() << "Before\n";
297   F.dump();
298
299   for (auto &B : F) {
300     for (auto &I : B) {
301       if ( (&I)->isAtomic() ) {
302         AtomicAccesses.push_back(&I);
303       } else if (isa<LoadInst>(I) || isa<StoreInst>(I)) {
304         LocalLoadsAndStores.push_back(&I);
305       }
306     }
307     chooseInstructionsToInstrument(LocalLoadsAndStores, AllLoadsAndStores, DL);
308   }
309
310   for (auto Inst : AllLoadsAndStores) {
311     Res |= instrumentLoadOrStore(Inst, DL);
312   }
313
314   for (auto Inst : AtomicAccesses) {
315     Res |= instrumentAtomic(Inst);
316   } 
317
318   errs() << "After\n";
319   F.dump();
320   
321   return false;
322 }
323
324 void CdsPass::chooseInstructionsToInstrument(
325     SmallVectorImpl<Instruction *> &Local, SmallVectorImpl<Instruction *> &All,
326     const DataLayout &DL) {
327   SmallPtrSet<Value*, 8> WriteTargets;
328   // Iterate from the end.
329   for (Instruction *I : reverse(Local)) {
330     if (StoreInst *Store = dyn_cast<StoreInst>(I)) {
331       Value *Addr = Store->getPointerOperand();
332       if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
333         continue;
334       WriteTargets.insert(Addr);
335     } else {
336       LoadInst *Load = cast<LoadInst>(I);
337       Value *Addr = Load->getPointerOperand();
338       if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
339         continue;
340       if (WriteTargets.count(Addr)) {
341         // We will write to this temp, so no reason to analyze the read.
342         NumOmittedReadsBeforeWrite++;
343         continue;
344       }
345       if (addrPointsToConstantData(Addr)) {
346         // Addr points to some constant data -- it can not race with any writes.
347         continue;
348       }
349     }
350     Value *Addr = isa<StoreInst>(*I)
351         ? cast<StoreInst>(I)->getPointerOperand()
352         : cast<LoadInst>(I)->getPointerOperand();
353     if (isa<AllocaInst>(GetUnderlyingObject(Addr, DL)) &&
354         !PointerMayBeCaptured(Addr, true, true)) {
355       // The variable is addressable but not captured, so it cannot be
356       // referenced from a different thread and participate in a data race
357       // (see llvm/Analysis/CaptureTracking.h for details).
358       NumOmittedNonCaptured++;
359       continue;
360     }
361     All.push_back(I);
362   }
363   Local.clear();
364 }
365
366
367 bool CdsPass::instrumentLoadOrStore(Instruction *I,
368                                             const DataLayout &DL) {
369   IRBuilder<> IRB(I);
370   bool IsWrite = isa<StoreInst>(*I);
371   Value *Addr = IsWrite
372       ? cast<StoreInst>(I)->getPointerOperand()
373       : cast<LoadInst>(I)->getPointerOperand();
374
375   // swifterror memory addresses are mem2reg promoted by instruction selection.
376   // As such they cannot have regular uses like an instrumentation function and
377   // it makes no sense to track them as memory.
378   if (Addr->isSwiftError())
379     return false;
380
381   int size = getTypeSize(Addr->getType());
382   int index = sizetoindex(size);
383
384 //  not supported by Cds yet
385 /*  if (IsWrite && isVtableAccess(I)) {
386     LLVM_DEBUG(dbgs() << "  VPTR : " << *I << "\n");
387     Value *StoredValue = cast<StoreInst>(I)->getValueOperand();
388     // StoredValue may be a vector type if we are storing several vptrs at once.
389     // In this case, just take the first element of the vector since this is
390     // enough to find vptr races.
391     if (isa<VectorType>(StoredValue->getType()))
392       StoredValue = IRB.CreateExtractElement(
393           StoredValue, ConstantInt::get(IRB.getInt32Ty(), 0));
394     if (StoredValue->getType()->isIntegerTy())
395       StoredValue = IRB.CreateIntToPtr(StoredValue, IRB.getInt8PtrTy());
396     // Call TsanVptrUpdate.
397     IRB.CreateCall(TsanVptrUpdate,
398                    {IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()),
399                     IRB.CreatePointerCast(StoredValue, IRB.getInt8PtrTy())});
400     NumInstrumentedVtableWrites++;
401     return true;
402   }
403
404   if (!IsWrite && isVtableAccess(I)) {
405     IRB.CreateCall(TsanVptrLoad,
406                    IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()));
407     NumInstrumentedVtableReads++;
408     return true;
409   }
410 */
411
412   Value *OnAccessFunc = nullptr;
413   OnAccessFunc = IsWrite ? CdsStore[index] : CdsLoad[index];
414
415   IRB.CreateCall(OnAccessFunc, IRB.CreatePointerCast(Addr, Addr->getType()));
416   if (IsWrite) NumInstrumentedWrites++;
417   else         NumInstrumentedReads++;
418   return true;
419 }
420
421
422 bool CdsPass::instrumentAtomic(Instruction * I) {
423   IRBuilder<> IRB(I);
424   // LLVMContext &Ctx = IRB.getContext();
425
426   if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
427     int atomic_order_index = getAtomicOrderIndex(SI->getOrdering());
428
429     Value *val = SI->getValueOperand();
430     Value *ptr = SI->getPointerOperand();
431     Value *order = ConstantInt::get(OrdTy, atomic_order_index);
432     Value *args[] = {ptr, order, val};
433
434     int size=getTypeSize(ptr->getType());
435     int index=sizetoindex(size);
436
437     Instruction* funcInst=CallInst::Create(CdsAtomicStore[index], args,"");
438     ReplaceInstWithInst(SI, funcInst);
439     errs() << "Store replaced\n";
440   } else if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
441     int atomic_order_index = getAtomicOrderIndex(LI->getOrdering());
442
443     Value *ptr = LI->getPointerOperand();
444     Value *order = ConstantInt::get(OrdTy, atomic_order_index);
445     Value *args[] = {ptr, order};
446
447     int size=getTypeSize(ptr->getType());
448     int index=sizetoindex(size);
449
450     Instruction* funcInst=CallInst::Create(CdsAtomicLoad[index], args, "");
451     ReplaceInstWithInst(LI, funcInst);
452     errs() << "Load Replaced\n";
453   } else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(I)) {
454     int atomic_order_index = getAtomicOrderIndex(RMWI->getOrdering());
455
456     Value *val = RMWI->getValOperand();
457     Value *ptr = RMWI->getPointerOperand();
458     Value *order = ConstantInt::get(OrdTy, atomic_order_index);
459     Value *args[] = {ptr, order, val};
460
461     int size = getTypeSize(ptr->getType());
462     int index = sizetoindex(size);
463
464     Instruction* funcInst = CallInst::Create(CdsAtomicRMW[RMWI->getOperation()][index], args, "");
465     ReplaceInstWithInst(RMWI, funcInst);
466     errs() << RMWI->getOperationName(RMWI->getOperation());
467     errs() << " replaced\n";
468   } else if (AtomicCmpXchgInst *CASI = dyn_cast<AtomicCmpXchgInst>(I)) {
469     IRBuilder<> IRB(CASI);
470
471     Value *Addr = CASI->getPointerOperand();
472
473     int size = getTypeSize(Addr->getType());
474     int index = sizetoindex(size);
475     const unsigned ByteSize = 1U << index;
476     const unsigned BitSize = ByteSize * 8;
477     Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
478     Type *PtrTy = Ty->getPointerTo();
479
480     Value *CmpOperand = IRB.CreateBitOrPointerCast(CASI->getCompareOperand(), Ty);
481     Value *NewOperand = IRB.CreateBitOrPointerCast(CASI->getNewValOperand(), Ty);
482
483     int atomic_order_index_succ = getAtomicOrderIndex(CASI->getSuccessOrdering());
484     int atomic_order_index_fail = getAtomicOrderIndex(CASI->getFailureOrdering());
485     Value *order_succ = ConstantInt::get(OrdTy, atomic_order_index_succ);
486     Value *order_fail = ConstantInt::get(OrdTy, atomic_order_index_fail);
487
488     Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy),
489                      CmpOperand, NewOperand,
490                      order_succ, order_fail};
491
492     CallInst *funcInst = IRB.CreateCall(CdsAtomicCAS[index], Args);
493     Value *Success = IRB.CreateICmpEQ(funcInst, CmpOperand);
494
495     Value *OldVal = funcInst;
496     Type *OrigOldValTy = CASI->getNewValOperand()->getType();
497     if (Ty != OrigOldValTy) {
498       // The value is a pointer, so we need to cast the return value.
499       OldVal = IRB.CreateIntToPtr(funcInst, OrigOldValTy);
500     }
501
502     Value *Res =
503       IRB.CreateInsertValue(UndefValue::get(CASI->getType()), OldVal, 0);
504     Res = IRB.CreateInsertValue(Res, Success, 1);
505
506     I->replaceAllUsesWith(Res);
507     I->eraseFromParent();
508   } else if (FenceInst *FI = dyn_cast<FenceInst>(I)) {
509     int atomic_order_index = getAtomicOrderIndex(FI->getOrdering());
510     Value *order = ConstantInt::get(OrdTy, atomic_order_index);
511     Value *Args[] = {order};
512
513     CallInst *funcInst = CallInst::Create(CdsAtomicThreadFence, Args);
514     ReplaceInstWithInst(FI, funcInst);
515     errs() << "Thread Fences replaced\n";
516   }
517   return true;
518 }
519
520
521
522 char CdsPass::ID = 0;
523
524 // Automatically enable the pass.
525 // http://adriansampson.net/blog/clangpass.html
526 static void registerCdsPass(const PassManagerBuilder &,
527                          legacy::PassManagerBase &PM) {
528   PM.add(new CdsPass());
529 }
530 static RegisterStandardPasses 
531         RegisterMyPass(PassManagerBuilder::EP_EarlyAsPossible,
532 registerCdsPass);