refactor codes in CDSPass.cpp and also add support for function-like atomic operation...
[c11llvm.git] / CDSPass.cpp
1 //===-- CDSPass.cpp - xxx -------------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file is a modified version of ThreadSanitizer.cpp, a part of a race detector.
11 //
12 // The tool is under development, for the details about previous versions see
13 // http://code.google.com/p/data-race-test
14 //
15 // The instrumentation phase is quite simple:
16 //   - Insert calls to run-time library before every memory access.
17 //      - Optimizations may apply to avoid instrumenting some of the accesses.
18 //   - Insert calls at function entry/exit.
19 // The rest is handled by the run-time library.
20 //===----------------------------------------------------------------------===//
21
22 #include "llvm/ADT/Statistic.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/SmallString.h"
25 #include "llvm/Analysis/ValueTracking.h"
26 #include "llvm/Analysis/CaptureTracking.h"
27 #include "llvm/IR/BasicBlock.h"
28 #include "llvm/IR/CFG.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/Instructions.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/IR/LegacyPassManager.h"
34 #include "llvm/IR/Module.h"
35 #include "llvm/IR/PassManager.h"
36 #include "llvm/IR/DebugLoc.h"
37 #include "llvm/Pass.h"
38 #include "llvm/ProfileData/InstrProf.h"
39 #include "llvm/Support/raw_ostream.h"
40 #include "llvm/Support/AtomicOrdering.h"
41 #include "llvm/Support/Debug.h"
42 #include "llvm/Transforms/Scalar.h"
43 #include "llvm/Transforms/Utils/Local.h"
44 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
45 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
46 #include <vector>
47
48 #define DEBUG_TYPE "CDS"
49 using namespace llvm;
50
51 #define FUNCARRAYSIZE 4
52
53 STATISTIC(NumInstrumentedReads, "Number of instrumented reads");
54 STATISTIC(NumInstrumentedWrites, "Number of instrumented writes");
55 // STATISTIC(NumInstrumentedVtableWrites, "Number of vtable ptr writes");
56 // STATISTIC(NumInstrumentedVtableReads, "Number of vtable ptr reads");
57
58 STATISTIC(NumOmittedReadsBeforeWrite,
59           "Number of reads ignored due to following writes");
60 STATISTIC(NumOmittedReadsFromConstantGlobals,
61           "Number of reads from constant globals");
62 STATISTIC(NumOmittedReadsFromVtable, "Number of vtable reads");
63 STATISTIC(NumOmittedNonCaptured, "Number of accesses ignored due to capturing");
64
65 Type * Int8Ty;
66 Type * Int16Ty;
67 Type * Int32Ty;
68 Type * Int64Ty;
69 Type * OrdTy;
70
71 Type * Int8PtrTy;
72 Type * Int16PtrTy;
73 Type * Int32PtrTy;
74 Type * Int64PtrTy;
75
76 Type * VoidTy;
77
78 Constant * CDSLoad[FUNCARRAYSIZE];
79 Constant * CDSStore[FUNCARRAYSIZE];
80 Constant * CDSAtomicInit[FUNCARRAYSIZE];
81 Constant * CDSAtomicLoad[FUNCARRAYSIZE];
82 Constant * CDSAtomicStore[FUNCARRAYSIZE];
83 Constant * CDSAtomicRMW[AtomicRMWInst::LAST_BINOP + 1][FUNCARRAYSIZE];
84 Constant * CDSAtomicCAS_V1[FUNCARRAYSIZE];
85 Constant * CDSAtomicCAS_V2[FUNCARRAYSIZE];
86 Constant * CDSAtomicThreadFence;
87
88 bool start = false;
89
90 int getAtomicOrderIndex(AtomicOrdering order){
91   switch (order) {
92     case AtomicOrdering::Monotonic: 
93       return (int)AtomicOrderingCABI::relaxed;
94 //  case AtomicOrdering::Consume:         // not specified yet
95 //    return AtomicOrderingCABI::consume;
96     case AtomicOrdering::Acquire: 
97       return (int)AtomicOrderingCABI::acquire;
98     case AtomicOrdering::Release: 
99       return (int)AtomicOrderingCABI::release;
100     case AtomicOrdering::AcquireRelease: 
101       return (int)AtomicOrderingCABI::acq_rel;
102     case AtomicOrdering::SequentiallyConsistent: 
103       return (int)AtomicOrderingCABI::seq_cst;
104     default:
105       // unordered or Not Atomic
106       return -1;
107   }
108 }
109
110 int getTypeSize(Type* type) {
111   if (type == Int8PtrTy) {
112     return sizeof(char)*8;
113   } else if (type == Int16PtrTy) {
114     return sizeof(short)*8;
115   } else if (type == Int32PtrTy) {
116     return sizeof(int)*8;
117   } else if (type == Int64PtrTy) {
118     return sizeof(long long int)*8;
119   } else {
120     return sizeof(void*)*8;
121   }
122
123   return -1;
124 }
125
126 static int sizetoindex(int size) {
127   switch(size) {
128     case 8:     return 0;
129     case 16:    return 1;
130     case 32:    return 2;
131     case 64:    return 3;
132   }
133   return -1;
134 }
135
136 namespace {
137   struct CDSPass : public FunctionPass {
138     static char ID;
139     CDSPass() : FunctionPass(ID) {}
140     bool runOnFunction(Function &F) override; 
141
142   private:
143     void initializeCallbacks(Module &M);
144     bool instrumentLoadOrStore(Instruction *I, const DataLayout &DL);
145     bool instrumentAtomic(Instruction *I, const DataLayout &DL);
146     bool instrumentAtomicCall(CallInst *CI, const DataLayout &DL);
147     void chooseInstructionsToInstrument(SmallVectorImpl<Instruction *> &Local,
148                                       SmallVectorImpl<Instruction *> &All,
149                                       const DataLayout &DL);
150     bool addrPointsToConstantData(Value *Addr);
151     int getMemoryAccessFuncIndex(Value *Addr, const DataLayout &DL);
152   };
153 }
154
155 static bool isVtableAccess(Instruction *I) {
156   if (MDNode *Tag = I->getMetadata(LLVMContext::MD_tbaa))
157     return Tag->isTBAAVtableAccess();
158   return false;
159 }
160
161 #include "initializeCallbacks.hpp"
162 #include "isAtomicCall.hpp"
163 #include "instrumentAtomicCall.hpp"
164
165 static bool shouldInstrumentReadWriteFromAddress(const Module *M, Value *Addr) {
166   // Peel off GEPs and BitCasts.
167   Addr = Addr->stripInBoundsOffsets();
168
169   if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) {
170     if (GV->hasSection()) {
171       StringRef SectionName = GV->getSection();
172       // Check if the global is in the PGO counters section.
173       auto OF = Triple(M->getTargetTriple()).getObjectFormat();
174       if (SectionName.endswith(
175               getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false)))
176         return false;
177     }
178
179     // Check if the global is private gcov data.
180     if (GV->getName().startswith("__llvm_gcov") ||
181         GV->getName().startswith("__llvm_gcda"))
182       return false;
183   }
184
185   // Do not instrument acesses from different address spaces; we cannot deal
186   // with them.
187   if (Addr) {
188     Type *PtrTy = cast<PointerType>(Addr->getType()->getScalarType());
189     if (PtrTy->getPointerAddressSpace() != 0)
190       return false;
191   }
192
193   return true;
194 }
195
196 bool CDSPass::addrPointsToConstantData(Value *Addr) {
197   // If this is a GEP, just analyze its pointer operand.
198   if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Addr))
199     Addr = GEP->getPointerOperand();
200
201   if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) {
202     if (GV->isConstant()) {
203       // Reads from constant globals can not race with any writes.
204       NumOmittedReadsFromConstantGlobals++;
205       return true;
206     }
207   } else if (LoadInst *L = dyn_cast<LoadInst>(Addr)) {
208     if (isVtableAccess(L)) {
209       // Reads from a vtable pointer can not race with any writes.
210       NumOmittedReadsFromVtable++;
211       return true;
212     }
213   }
214   return false;
215 }
216
217 bool CDSPass::runOnFunction(Function &F) {
218   if (F.getName() == "main") {
219     F.setName("user_main");
220     errs() << "main replaced by user_main\n";
221   }
222
223   if (true) {
224     initializeCallbacks( *F.getParent() );
225
226     SmallVector<Instruction*, 8> AllLoadsAndStores;
227     SmallVector<Instruction*, 8> LocalLoadsAndStores;
228     SmallVector<Instruction*, 8> AtomicAccesses;
229
230     std::vector<Instruction *> worklist;
231
232     bool Res = false;
233     const DataLayout &DL = F.getParent()->getDataLayout();
234
235     errs() << "--- " << F.getName() << "---\n";
236
237     for (auto &B : F) {
238       for (auto &I : B) {
239         if ( (&I)->isAtomic() || isAtomicCall(&I) ) {
240           AtomicAccesses.push_back(&I);
241
242           const llvm::DebugLoc & debug_location = I.getDebugLoc();
243           std::string position_string;
244           {
245             llvm::raw_string_ostream position_stream (position_string);
246             debug_location . print (position_stream);
247           }
248
249           errs() << I << "\n" << (position_string) << "\n\n";
250         } else if (isa<LoadInst>(I) || isa<StoreInst>(I)) {
251           LocalLoadsAndStores.push_back(&I);
252         } else if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
253           // not implemented yet
254         }
255       }
256
257       chooseInstructionsToInstrument(LocalLoadsAndStores, AllLoadsAndStores, DL);
258     }
259
260     for (auto Inst : AllLoadsAndStores) {
261 //      Res |= instrumentLoadOrStore(Inst, DL);
262 //      errs() << "load and store are replaced\n";
263     }
264
265     for (auto Inst : AtomicAccesses) {
266       Res |= instrumentAtomic(Inst, DL);
267     }
268
269     if (F.getName() == "user_main") {
270       // F.dump();
271     }
272
273   }
274
275   return false;
276 }
277
278 void CDSPass::chooseInstructionsToInstrument(
279     SmallVectorImpl<Instruction *> &Local, SmallVectorImpl<Instruction *> &All,
280     const DataLayout &DL) {
281   SmallPtrSet<Value*, 8> WriteTargets;
282   // Iterate from the end.
283   for (Instruction *I : reverse(Local)) {
284     if (StoreInst *Store = dyn_cast<StoreInst>(I)) {
285       Value *Addr = Store->getPointerOperand();
286       if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
287         continue;
288       WriteTargets.insert(Addr);
289     } else {
290       LoadInst *Load = cast<LoadInst>(I);
291       Value *Addr = Load->getPointerOperand();
292       if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
293         continue;
294       if (WriteTargets.count(Addr)) {
295         // We will write to this temp, so no reason to analyze the read.
296         NumOmittedReadsBeforeWrite++;
297         continue;
298       }
299       if (addrPointsToConstantData(Addr)) {
300         // Addr points to some constant data -- it can not race with any writes.
301         continue;
302       }
303     }
304     Value *Addr = isa<StoreInst>(*I)
305         ? cast<StoreInst>(I)->getPointerOperand()
306         : cast<LoadInst>(I)->getPointerOperand();
307     if (isa<AllocaInst>(GetUnderlyingObject(Addr, DL)) &&
308         !PointerMayBeCaptured(Addr, true, true)) {
309       // The variable is addressable but not captured, so it cannot be
310       // referenced from a different thread and participate in a data race
311       // (see llvm/Analysis/CaptureTracking.h for details).
312       NumOmittedNonCaptured++;
313       continue;
314     }
315     All.push_back(I);
316   }
317   Local.clear();
318 }
319
320
321 bool CDSPass::instrumentLoadOrStore(Instruction *I,
322                                             const DataLayout &DL) {
323   IRBuilder<> IRB(I);
324   bool IsWrite = isa<StoreInst>(*I);
325   Value *Addr = IsWrite
326       ? cast<StoreInst>(I)->getPointerOperand()
327       : cast<LoadInst>(I)->getPointerOperand();
328
329   // swifterror memory addresses are mem2reg promoted by instruction selection.
330   // As such they cannot have regular uses like an instrumentation function and
331   // it makes no sense to track them as memory.
332   if (Addr->isSwiftError())
333     return false;
334
335   int size = getTypeSize(Addr->getType());
336   int index = sizetoindex(size);
337
338 //  not supported by CDS yet
339 /*  if (IsWrite && isVtableAccess(I)) {
340     LLVM_DEBUG(dbgs() << "  VPTR : " << *I << "\n");
341     Value *StoredValue = cast<StoreInst>(I)->getValueOperand();
342     // StoredValue may be a vector type if we are storing several vptrs at once.
343     // In this case, just take the first element of the vector since this is
344     // enough to find vptr races.
345     if (isa<VectorType>(StoredValue->getType()))
346       StoredValue = IRB.CreateExtractElement(
347           StoredValue, ConstantInt::get(IRB.getInt32Ty(), 0));
348     if (StoredValue->getType()->isIntegerTy())
349       StoredValue = IRB.CreateIntToPtr(StoredValue, IRB.getInt8PtrTy());
350     // Call TsanVptrUpdate.
351     IRB.CreateCall(TsanVptrUpdate,
352                    {IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()),
353                     IRB.CreatePointerCast(StoredValue, IRB.getInt8PtrTy())});
354     NumInstrumentedVtableWrites++;
355     return true;
356   }
357
358   if (!IsWrite && isVtableAccess(I)) {
359     IRB.CreateCall(TsanVptrLoad,
360                    IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()));
361     NumInstrumentedVtableReads++;
362     return true;
363   }
364 */
365
366   Value *OnAccessFunc = nullptr;
367   OnAccessFunc = IsWrite ? CDSStore[index] : CDSLoad[index];
368   
369   Type *ArgType = IRB.CreatePointerCast(Addr, Addr->getType())->getType();
370
371   if ( ArgType != Int8PtrTy && ArgType != Int16PtrTy && 
372                 ArgType != Int32PtrTy && ArgType != Int64PtrTy ) {
373         //errs() << "A load or store of type ";
374         //errs() << *ArgType;
375         //errs() << " is passed in\n";
376         return false;   // if other types of load or stores are passed in
377   }
378   IRB.CreateCall(OnAccessFunc, IRB.CreatePointerCast(Addr, Addr->getType()));
379   if (IsWrite) NumInstrumentedWrites++;
380   else         NumInstrumentedReads++;
381   return true;
382 }
383
384 // todo: replace getTypeSize with the getMemoryAccessFuncIndex
385 bool CDSPass::instrumentAtomic(Instruction * I, const DataLayout &DL) {
386   IRBuilder<> IRB(I);
387   // LLVMContext &Ctx = IRB.getContext();
388
389   if (auto *CI = dyn_cast<CallInst>(I)) {
390     return instrumentAtomicCall(CI, DL);
391   }
392
393   if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
394     int atomic_order_index = getAtomicOrderIndex(SI->getOrdering());
395
396     Value *val = SI->getValueOperand();
397     Value *ptr = SI->getPointerOperand();
398     Value *order = ConstantInt::get(OrdTy, atomic_order_index);
399     Value *args[] = {ptr, val, order};
400
401     int size=getTypeSize(ptr->getType());
402     int index=sizetoindex(size);
403
404     Instruction* funcInst=CallInst::Create(CDSAtomicStore[index], args,"");
405     ReplaceInstWithInst(SI, funcInst);
406 //    errs() << "Store replaced\n";
407   } else if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
408     int atomic_order_index = getAtomicOrderIndex(LI->getOrdering());
409
410     Value *ptr = LI->getPointerOperand();
411     Value *order = ConstantInt::get(OrdTy, atomic_order_index);
412     Value *args[] = {ptr, order};
413
414     int size=getTypeSize(ptr->getType());
415     int index=sizetoindex(size);
416
417     Instruction* funcInst=CallInst::Create(CDSAtomicLoad[index], args, "");
418     ReplaceInstWithInst(LI, funcInst);
419 //    errs() << "Load Replaced\n";
420   } else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(I)) {
421     int atomic_order_index = getAtomicOrderIndex(RMWI->getOrdering());
422
423     Value *val = RMWI->getValOperand();
424     Value *ptr = RMWI->getPointerOperand();
425     Value *order = ConstantInt::get(OrdTy, atomic_order_index);
426     Value *args[] = {ptr, val, order};
427
428     int size = getTypeSize(ptr->getType());
429     int index = sizetoindex(size);
430
431     Instruction* funcInst = CallInst::Create(CDSAtomicRMW[RMWI->getOperation()][index], args, "");
432     ReplaceInstWithInst(RMWI, funcInst);
433 //    errs() << RMWI->getOperationName(RMWI->getOperation());
434 //    errs() << " replaced\n";
435   } else if (AtomicCmpXchgInst *CASI = dyn_cast<AtomicCmpXchgInst>(I)) {
436     IRBuilder<> IRB(CASI);
437
438     Value *Addr = CASI->getPointerOperand();
439
440     int size = getTypeSize(Addr->getType());
441     int index = sizetoindex(size);
442     const unsigned ByteSize = 1U << index;
443     const unsigned BitSize = ByteSize * 8;
444     Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
445     Type *PtrTy = Ty->getPointerTo();
446
447     Value *CmpOperand = IRB.CreateBitOrPointerCast(CASI->getCompareOperand(), Ty);
448     Value *NewOperand = IRB.CreateBitOrPointerCast(CASI->getNewValOperand(), Ty);
449
450     int atomic_order_index_succ = getAtomicOrderIndex(CASI->getSuccessOrdering());
451     int atomic_order_index_fail = getAtomicOrderIndex(CASI->getFailureOrdering());
452     Value *order_succ = ConstantInt::get(OrdTy, atomic_order_index_succ);
453     Value *order_fail = ConstantInt::get(OrdTy, atomic_order_index_fail);
454
455     Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy),
456                      CmpOperand, NewOperand,
457                      order_succ, order_fail};
458
459     CallInst *funcInst = IRB.CreateCall(CDSAtomicCAS_V1[index], Args);
460     Value *Success = IRB.CreateICmpEQ(funcInst, CmpOperand);
461
462     Value *OldVal = funcInst;
463     Type *OrigOldValTy = CASI->getNewValOperand()->getType();
464     if (Ty != OrigOldValTy) {
465       // The value is a pointer, so we need to cast the return value.
466       OldVal = IRB.CreateIntToPtr(funcInst, OrigOldValTy);
467     }
468
469     Value *Res =
470       IRB.CreateInsertValue(UndefValue::get(CASI->getType()), OldVal, 0);
471     Res = IRB.CreateInsertValue(Res, Success, 1);
472
473     I->replaceAllUsesWith(Res);
474     I->eraseFromParent();
475   } else if (FenceInst *FI = dyn_cast<FenceInst>(I)) {
476     int atomic_order_index = getAtomicOrderIndex(FI->getOrdering());
477     Value *order = ConstantInt::get(OrdTy, atomic_order_index);
478     Value *Args[] = {order};
479
480     CallInst *funcInst = CallInst::Create(CDSAtomicThreadFence, Args);
481     ReplaceInstWithInst(FI, funcInst);
482 //    errs() << "Thread Fences replaced\n";
483   }
484   return true;
485 }
486
487 int CDSPass::getMemoryAccessFuncIndex(Value *Addr,
488                                               const DataLayout &DL) {
489   Type *OrigPtrTy = Addr->getType();
490   Type *OrigTy = cast<PointerType>(OrigPtrTy)->getElementType();
491   assert(OrigTy->isSized());
492   uint32_t TypeSize = DL.getTypeStoreSizeInBits(OrigTy);
493   if (TypeSize != 8  && TypeSize != 16 &&
494       TypeSize != 32 && TypeSize != 64 && TypeSize != 128) {
495     // NumAccessesWithBadSize++;
496     // Ignore all unusual sizes.
497     return -1;
498   }
499   size_t Idx = countTrailingZeros(TypeSize / 8);
500   // assert(Idx < FUNCARRAYSIZE);
501   return Idx;
502 }
503
504
505 char CDSPass::ID = 0;
506
507 // Automatically enable the pass.
508 static void registerCDSPass(const PassManagerBuilder &,
509                          legacy::PassManagerBase &PM) {
510   PM.add(new CDSPass());
511 }
512 static RegisterStandardPasses 
513         RegisterMyPass(PassManagerBuilder::EP_OptimizerLast,
514 registerCDSPass);