354e5bd49b7c00e1845f646c9e98860bbc730486
[c11llvm.git] / CDSPass.cpp
1 //===-- CDSPass.cpp - xxx -------------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file is a modified version of ThreadSanitizer.cpp, a part of a race detector.
11 //
12 // The tool is under development, for the details about previous versions see
13 // http://code.google.com/p/data-race-test
14 //
15 // The instrumentation phase is quite simple:
16 //   - Insert calls to run-time library before every memory access.
17 //      - Optimizations may apply to avoid instrumenting some of the accesses.
18 //   - Insert calls at function entry/exit.
19 // The rest is handled by the run-time library.
20 //===----------------------------------------------------------------------===//
21
22 #include "llvm/ADT/Statistic.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/SmallString.h"
25 #include "llvm/Analysis/ValueTracking.h"
26 #include "llvm/Analysis/CaptureTracking.h"
27 #include "llvm/IR/BasicBlock.h"
28 #include "llvm/IR/Function.h"
29 #include "llvm/IR/IRBuilder.h"
30 #include "llvm/IR/Instructions.h"
31 #include "llvm/IR/LLVMContext.h"
32 #include "llvm/IR/LegacyPassManager.h"
33 #include "llvm/IR/Module.h"
34 #include "llvm/IR/PassManager.h"
35 #include "llvm/Pass.h"
36 #include "llvm/ProfileData/InstrProf.h"
37 #include "llvm/Support/raw_ostream.h"
38 #include "llvm/Support/AtomicOrdering.h"
39 #include "llvm/Support/Debug.h"
40 #include "llvm/Transforms/Scalar.h"
41 #include "llvm/Transforms/Utils/Local.h"
42 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
43 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
44 #include <vector>
45
46 #define DEBUG_TYPE "CDS"
47 using namespace llvm;
48
49 #include <llvm/IR/DebugLoc.h>
50
51 Value *getPosition( Instruction * I, IRBuilder <> IRB)
52 {
53         const DebugLoc & debug_location = I->getDebugLoc ();
54         std::string position_string;
55         {
56                 llvm::raw_string_ostream position_stream (position_string);
57                 debug_location . print (position_stream);
58         }
59
60         return IRB . CreateGlobalStringPtr (position_string);
61 }
62
63 #define FUNCARRAYSIZE 4
64
65 STATISTIC(NumInstrumentedReads, "Number of instrumented reads");
66 STATISTIC(NumInstrumentedWrites, "Number of instrumented writes");
67 // STATISTIC(NumInstrumentedVtableWrites, "Number of vtable ptr writes");
68 // STATISTIC(NumInstrumentedVtableReads, "Number of vtable ptr reads");
69
70 STATISTIC(NumOmittedReadsBeforeWrite,
71           "Number of reads ignored due to following writes");
72 STATISTIC(NumOmittedReadsFromConstantGlobals,
73           "Number of reads from constant globals");
74 STATISTIC(NumOmittedReadsFromVtable, "Number of vtable reads");
75 STATISTIC(NumOmittedNonCaptured, "Number of accesses ignored due to capturing");
76
77 Type * Int8Ty;
78 Type * Int16Ty;
79 Type * Int32Ty;
80 Type * Int64Ty;
81 Type * OrdTy;
82
83 Type * Int8PtrTy;
84 Type * Int16PtrTy;
85 Type * Int32PtrTy;
86 Type * Int64PtrTy;
87
88 Type * VoidTy;
89
90 Constant * CDSLoad[FUNCARRAYSIZE];
91 Constant * CDSStore[FUNCARRAYSIZE];
92 Constant * CDSAtomicInit[FUNCARRAYSIZE];
93 Constant * CDSAtomicLoad[FUNCARRAYSIZE];
94 Constant * CDSAtomicStore[FUNCARRAYSIZE];
95 Constant * CDSAtomicRMW[AtomicRMWInst::LAST_BINOP + 1][FUNCARRAYSIZE];
96 Constant * CDSAtomicCAS_V1[FUNCARRAYSIZE];
97 Constant * CDSAtomicCAS_V2[FUNCARRAYSIZE];
98 Constant * CDSAtomicThreadFence;
99
100 bool start = false;
101
102 int getAtomicOrderIndex(AtomicOrdering order){
103   switch (order) {
104     case AtomicOrdering::Monotonic: 
105       return (int)AtomicOrderingCABI::relaxed;
106 //  case AtomicOrdering::Consume:         // not specified yet
107 //    return AtomicOrderingCABI::consume;
108     case AtomicOrdering::Acquire: 
109       return (int)AtomicOrderingCABI::acquire;
110     case AtomicOrdering::Release: 
111       return (int)AtomicOrderingCABI::release;
112     case AtomicOrdering::AcquireRelease: 
113       return (int)AtomicOrderingCABI::acq_rel;
114     case AtomicOrdering::SequentiallyConsistent: 
115       return (int)AtomicOrderingCABI::seq_cst;
116     default:
117       // unordered or Not Atomic
118       return -1;
119   }
120 }
121
122 int getTypeSize(Type* type) {
123   if (type == Int8PtrTy) {
124     return sizeof(char)*8;
125   } else if (type == Int16PtrTy) {
126     return sizeof(short)*8;
127   } else if (type == Int32PtrTy) {
128     return sizeof(int)*8;
129   } else if (type == Int64PtrTy) {
130     return sizeof(long long int)*8;
131   } else {
132     return sizeof(void*)*8;
133   }
134
135   return -1;
136 }
137
/// Map an access width in bits to its slot in the CDS* callback tables:
/// 8 -> 0, 16 -> 1, 32 -> 2, 64 -> 3. Any other width yields -1.
static int sizetoindex(int size) {
  static const int kSlotBits[] = {8, 16, 32, 64};
  const int kSlots = static_cast<int>(sizeof(kSlotBits) / sizeof(kSlotBits[0]));
  for (int slot = 0; slot < kSlots; ++slot) {
    if (kSlotBits[slot] == size)
      return slot;
  }
  return -1;
}
147
148 namespace {
149   struct CDSPass : public FunctionPass {
150     static char ID;
151     CDSPass() : FunctionPass(ID) {}
152     bool runOnFunction(Function &F) override; 
153
154   private:
155     void initializeCallbacks(Module &M);
156     bool instrumentLoadOrStore(Instruction *I, const DataLayout &DL);
157     bool instrumentAtomic(Instruction *I, const DataLayout &DL);
158     bool instrumentAtomicCall(CallInst *CI, const DataLayout &DL);
159     void chooseInstructionsToInstrument(SmallVectorImpl<Instruction *> &Local,
160                                       SmallVectorImpl<Instruction *> &All,
161                                       const DataLayout &DL);
162     bool addrPointsToConstantData(Value *Addr);
163     int getMemoryAccessFuncIndex(Value *Addr, const DataLayout &DL);
164   };
165 }
166
167 static bool isVtableAccess(Instruction *I) {
168   if (MDNode *Tag = I->getMetadata(LLVMContext::MD_tbaa))
169     return Tag->isTBAAVtableAccess();
170   return false;
171 }
172
173 void CDSPass::initializeCallbacks(Module &M) {
174         LLVMContext &Ctx = M.getContext();
175
176         Type * Int1Ty = Type::getInt1Ty(Ctx);
177         Int8Ty  = Type::getInt8Ty(Ctx);
178         Int16Ty = Type::getInt16Ty(Ctx);
179         Int32Ty = Type::getInt32Ty(Ctx);
180         Int64Ty = Type::getInt64Ty(Ctx);
181         OrdTy = Type::getInt32Ty(Ctx);
182
183         Int8PtrTy  = Type::getInt8PtrTy(Ctx);
184         Int16PtrTy = Type::getInt16PtrTy(Ctx);
185         Int32PtrTy = Type::getInt32PtrTy(Ctx);
186         Int64PtrTy = Type::getInt64PtrTy(Ctx);
187
188         VoidTy = Type::getVoidTy(Ctx);
189   
190         // Get the function to call from our untime library.
191         for (unsigned i = 0; i < FUNCARRAYSIZE; i++) {
192                 const unsigned ByteSize = 1U << i;
193                 const unsigned BitSize = ByteSize * 8;
194
195                 std::string ByteSizeStr = utostr(ByteSize);
196                 std::string BitSizeStr = utostr(BitSize);
197
198                 Type *Ty = Type::getIntNTy(Ctx, BitSize);
199                 Type *PtrTy = Ty->getPointerTo();
200
201                 // uint8_t cds_atomic_load8 (void * obj, int atomic_index)
202                 // void cds_atomic_store8 (void * obj, int atomic_index, uint8_t val)
203                 SmallString<32> LoadName("cds_load" + BitSizeStr);
204                 SmallString<32> StoreName("cds_store" + BitSizeStr);
205                 SmallString<32> AtomicInitName("cds_atomic_init" + BitSizeStr);
206                 SmallString<32> AtomicLoadName("cds_atomic_load" + BitSizeStr);
207                 SmallString<32> AtomicStoreName("cds_atomic_store" + BitSizeStr);
208
209                 CDSLoad[i]  = M.getOrInsertFunction(LoadName, VoidTy, PtrTy);
210                 CDSStore[i] = M.getOrInsertFunction(StoreName, VoidTy, PtrTy);
211                 CDSAtomicInit[i] = M.getOrInsertFunction(AtomicInitName, 
212                                                                 VoidTy, PtrTy, Ty, Int8PtrTy);
213                 CDSAtomicLoad[i]  = M.getOrInsertFunction(AtomicLoadName, 
214                                                                 Ty, PtrTy, OrdTy, Int8PtrTy);
215                 CDSAtomicStore[i] = M.getOrInsertFunction(AtomicStoreName, 
216                                                                 VoidTy, PtrTy, Ty, OrdTy, Int8PtrTy);
217
218                 for (int op = AtomicRMWInst::FIRST_BINOP; 
219                         op <= AtomicRMWInst::LAST_BINOP; ++op) {
220                         CDSAtomicRMW[op][i] = nullptr;
221                         std::string NamePart;
222
223                         if (op == AtomicRMWInst::Xchg)
224                                 NamePart = "_exchange";
225                         else if (op == AtomicRMWInst::Add) 
226                                 NamePart = "_fetch_add";
227                         else if (op == AtomicRMWInst::Sub)
228                                 NamePart = "_fetch_sub";
229                         else if (op == AtomicRMWInst::And)
230                                 NamePart = "_fetch_and";
231                         else if (op == AtomicRMWInst::Or)
232                                 NamePart = "_fetch_or";
233                         else if (op == AtomicRMWInst::Xor)
234                                 NamePart = "_fetch_xor";
235                         else
236                                 continue;
237
238                         SmallString<32> AtomicRMWName("cds_atomic" + NamePart + BitSizeStr);
239                         CDSAtomicRMW[op][i] = M.getOrInsertFunction(AtomicRMWName, 
240                                                                                 Ty, PtrTy, Ty, OrdTy, Int8PtrTy);
241                 }
242
243                 // only supportes strong version
244                 SmallString<32> AtomicCASName_V1("cds_atomic_compare_exchange" + BitSizeStr + "_v1");
245                 SmallString<32> AtomicCASName_V2("cds_atomic_compare_exchange" + BitSizeStr + "_v2");
246                 CDSAtomicCAS_V1[i] = M.getOrInsertFunction(AtomicCASName_V1, 
247                                                                 Ty, PtrTy, Ty, Ty, OrdTy, OrdTy, Int8PtrTy);
248                 CDSAtomicCAS_V2[i] = M.getOrInsertFunction(AtomicCASName_V2, 
249                                                                 Int1Ty, PtrTy, PtrTy, Ty, OrdTy, OrdTy, Int8PtrTy);
250         }
251
252         CDSAtomicThreadFence = M.getOrInsertFunction("cds_atomic_thread_fence", 
253                                                                                                         VoidTy, OrdTy, Int8PtrTy);
254 }
255
256 void printArgs(CallInst *);
257
258 bool isAtomicCall(Instruction *I)
259 {
260         if ( auto *CI = dyn_cast<CallInst>(I) ) {
261                 Function *fun = CI->getCalledFunction();
262                 if (fun == NULL)
263                         return false;
264
265                 StringRef funName = fun->getName();
266
267                 if ( (CI->isTailCall() && funName.contains("atomic_")) ||
268                         funName.contains("atomic_compare_exchange_") ) {
269                         // printArgs(CI);
270                         return true;
271                 }
272         }
273
274         return false;
275 }
276
277 void printArgs (CallInst *CI)
278 {
279         Function *fun = CI->getCalledFunction();
280         StringRef funName = fun->getName();
281
282         User::op_iterator begin = CI->arg_begin();
283         User::op_iterator end = CI->arg_end();
284
285         if ( funName.contains("atomic_") ) {
286                 std::vector<Value *> parameters;
287
288                 for (User::op_iterator it = begin; it != end; ++it) {
289                         Value *param = *it;
290                         parameters.push_back(param);
291                         errs() << *param << " type: " << *param->getType()  << "\n";
292                 }
293         }
294
295 }
296
297 bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
298         IRBuilder<> IRB(CI);
299         Function *fun = CI->getCalledFunction();
300         StringRef funName = fun->getName();
301         std::vector<Value *> parameters;
302
303         User::op_iterator begin = CI->arg_begin();
304         User::op_iterator end = CI->arg_end();
305         for (User::op_iterator it = begin; it != end; ++it) {
306                 Value *param = *it;
307                 parameters.push_back(param);
308         }
309
310         // obtain source line number of the CallInst
311         Value *position = getPosition(CI, IRB);
312
313         // the pointer to the address is always the first argument
314         Value *OrigPtr = parameters[0];
315         int Idx = getMemoryAccessFuncIndex(OrigPtr, DL);
316         if (Idx < 0)
317                 return false;
318
319         const unsigned ByteSize = 1U << Idx;
320         const unsigned BitSize = ByteSize * 8;
321         Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
322         Type *PtrTy = Ty->getPointerTo();
323
324         // atomic_init; args = {obj, order}
325         if (funName.contains("atomic_init")) {
326                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
327                 Value *val = IRB.CreateBitOrPointerCast(parameters[1], Ty);
328                 Value *args[] = {ptr, val, position};
329
330                 Instruction* funcInst=CallInst::Create(CDSAtomicInit[Idx], args);
331                 ReplaceInstWithInst(CI, funcInst);
332
333                 return true;
334         }
335
336         // atomic_load; args = {obj, order}
337         if (funName.contains("atomic_load")) {
338                 bool isExplicit = funName.contains("atomic_load_explicit");
339
340                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
341                 Value *order;
342                 if (isExplicit)
343                         order = IRB.CreateBitOrPointerCast(parameters[1], OrdTy);
344                 else 
345                         order = ConstantInt::get(OrdTy, 
346                                                         (int) AtomicOrderingCABI::seq_cst);
347                 Value *args[] = {ptr, order, position};
348                 
349                 Instruction* funcInst=CallInst::Create(CDSAtomicLoad[Idx], args);
350                 ReplaceInstWithInst(CI, funcInst);
351
352                 return true;
353         }
354
355         // atomic_store; args = {obj, val, order}
356         if (funName.contains("atomic_store")) {
357                 bool isExplicit = funName.contains("atomic_store_explicit");
358                 Value *OrigVal = parameters[1];
359
360                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
361                 Value *val = IRB.CreatePointerCast(OrigVal, Ty);
362                 Value *order;
363                 if (isExplicit)
364                         order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
365                 else 
366                         order = ConstantInt::get(OrdTy, 
367                                                         (int) AtomicOrderingCABI::seq_cst);
368                 Value *args[] = {ptr, val, order, position};
369                 
370                 Instruction* funcInst=CallInst::Create(CDSAtomicStore[Idx], args);
371                 ReplaceInstWithInst(CI, funcInst);
372
373                 return true;
374         }
375
376         // atomic_fetch_*; args = {obj, val, order}
377         if (funName.contains("atomic_fetch_") || 
378                         funName.contains("atomic_exchange") ) {
379                 bool isExplicit = funName.contains("_explicit");
380                 Value *OrigVal = parameters[1];
381
382                 int op;
383                 if ( funName.contains("_fetch_add") )
384                         op = AtomicRMWInst::Add;
385                 else if ( funName.contains("_fetch_sub") )
386                         op = AtomicRMWInst::Sub;
387                 else if ( funName.contains("_fetch_and") )
388                         op = AtomicRMWInst::And;
389                 else if ( funName.contains("_fetch_or") )
390                         op = AtomicRMWInst::Or;
391                 else if ( funName.contains("_fetch_xor") )
392                         op = AtomicRMWInst::Xor;
393                 else if ( funName.contains("atomic_exchange") )
394                         op = AtomicRMWInst::Xchg;
395                 else {
396                         errs() << "Unknown atomic read modify write operation\n";
397                         return false;
398                 }
399
400                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
401                 Value *val = IRB.CreatePointerCast(OrigVal, Ty);
402                 Value *order;
403                 if (isExplicit)
404                         order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
405                 else 
406                         order = ConstantInt::get(OrdTy, 
407                                                         (int) AtomicOrderingCABI::seq_cst);
408                 Value *args[] = {ptr, val, order, position};
409                 
410                 Instruction* funcInst=CallInst::Create(CDSAtomicRMW[op][Idx], args);
411                 ReplaceInstWithInst(CI, funcInst);
412
413                 return true;
414         }
415
416         /* atomic_compare_exchange_*; 
417            args = {obj, expected, new value, order1, order2}
418         */
419         if ( funName.contains("atomic_compare_exchange_") ) {
420                 bool isExplicit = funName.contains("_explicit");
421
422                 Value *Addr = IRB.CreatePointerCast(OrigPtr, PtrTy);
423                 Value *CmpOperand = IRB.CreatePointerCast(parameters[1], PtrTy);
424                 Value *NewOperand = IRB.CreateBitOrPointerCast(parameters[2], Ty);
425
426                 Value *order_succ, *order_fail;
427                 if (isExplicit) {
428                         order_succ = IRB.CreateBitOrPointerCast(parameters[3], OrdTy);
429                         order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
430                 } else  {
431                         order_succ = ConstantInt::get(OrdTy, 
432                                                         (int) AtomicOrderingCABI::seq_cst);
433                         order_fail = ConstantInt::get(OrdTy, 
434                                                         (int) AtomicOrderingCABI::seq_cst);
435                 }
436
437                 Value *args[] = {Addr, CmpOperand, NewOperand, 
438                                                         order_succ, order_fail, position};
439                 
440                 Instruction* funcInst=CallInst::Create(CDSAtomicCAS_V2[Idx], args);
441                 ReplaceInstWithInst(CI, funcInst);
442
443                 return true;
444         }
445
446         return false;
447 }
448
449 static bool shouldInstrumentReadWriteFromAddress(const Module *M, Value *Addr) {
450   // Peel off GEPs and BitCasts.
451   Addr = Addr->stripInBoundsOffsets();
452
453   if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) {
454     if (GV->hasSection()) {
455       StringRef SectionName = GV->getSection();
456       // Check if the global is in the PGO counters section.
457       auto OF = Triple(M->getTargetTriple()).getObjectFormat();
458       if (SectionName.endswith(
459               getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false)))
460         return false;
461     }
462
463     // Check if the global is private gcov data.
464     if (GV->getName().startswith("__llvm_gcov") ||
465         GV->getName().startswith("__llvm_gcda"))
466       return false;
467   }
468
469   // Do not instrument acesses from different address spaces; we cannot deal
470   // with them.
471   if (Addr) {
472     Type *PtrTy = cast<PointerType>(Addr->getType()->getScalarType());
473     if (PtrTy->getPointerAddressSpace() != 0)
474       return false;
475   }
476
477   return true;
478 }
479
480 bool CDSPass::addrPointsToConstantData(Value *Addr) {
481   // If this is a GEP, just analyze its pointer operand.
482   if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Addr))
483     Addr = GEP->getPointerOperand();
484
485   if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) {
486     if (GV->isConstant()) {
487       // Reads from constant globals can not race with any writes.
488       NumOmittedReadsFromConstantGlobals++;
489       return true;
490     }
491   } else if (LoadInst *L = dyn_cast<LoadInst>(Addr)) {
492     if (isVtableAccess(L)) {
493       // Reads from a vtable pointer can not race with any writes.
494       NumOmittedReadsFromVtable++;
495       return true;
496     }
497   }
498   return false;
499 }
500
501 bool CDSPass::runOnFunction(Function &F) {
502   if (F.getName() == "main") {
503     F.setName("user_main");
504     errs() << "main replaced by user_main\n";
505   }
506
507   if (true) {
508     initializeCallbacks( *F.getParent() );
509
510     SmallVector<Instruction*, 8> AllLoadsAndStores;
511     SmallVector<Instruction*, 8> LocalLoadsAndStores;
512     SmallVector<Instruction*, 8> AtomicAccesses;
513
514     std::vector<Instruction *> worklist;
515
516     bool Res = false;
517     const DataLayout &DL = F.getParent()->getDataLayout();
518
519     errs() << "--- " << F.getName() << "---\n";
520
521     for (auto &B : F) {
522       for (auto &I : B) {
523         if ( (&I)->isAtomic() || isAtomicCall(&I) ) {
524           AtomicAccesses.push_back(&I);
525         } else if (isa<LoadInst>(I) || isa<StoreInst>(I)) {
526           LocalLoadsAndStores.push_back(&I);
527         } else if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
528           // not implemented yet
529         }
530       }
531
532       chooseInstructionsToInstrument(LocalLoadsAndStores, AllLoadsAndStores, DL);
533     }
534
535     for (auto Inst : AllLoadsAndStores) {
536 //      Res |= instrumentLoadOrStore(Inst, DL);
537 //      errs() << "load and store are replaced\n";
538     }
539
540     for (auto Inst : AtomicAccesses) {
541       Res |= instrumentAtomic(Inst, DL);
542     }
543
544     if (F.getName() == "user_main") {
545       // F.dump();
546     }
547
548   }
549
550   return false;
551 }
552
553 void CDSPass::chooseInstructionsToInstrument(
554     SmallVectorImpl<Instruction *> &Local, SmallVectorImpl<Instruction *> &All,
555     const DataLayout &DL) {
556   SmallPtrSet<Value*, 8> WriteTargets;
557   // Iterate from the end.
558   for (Instruction *I : reverse(Local)) {
559     if (StoreInst *Store = dyn_cast<StoreInst>(I)) {
560       Value *Addr = Store->getPointerOperand();
561       if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
562         continue;
563       WriteTargets.insert(Addr);
564     } else {
565       LoadInst *Load = cast<LoadInst>(I);
566       Value *Addr = Load->getPointerOperand();
567       if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
568         continue;
569       if (WriteTargets.count(Addr)) {
570         // We will write to this temp, so no reason to analyze the read.
571         NumOmittedReadsBeforeWrite++;
572         continue;
573       }
574       if (addrPointsToConstantData(Addr)) {
575         // Addr points to some constant data -- it can not race with any writes.
576         continue;
577       }
578     }
579     Value *Addr = isa<StoreInst>(*I)
580         ? cast<StoreInst>(I)->getPointerOperand()
581         : cast<LoadInst>(I)->getPointerOperand();
582     if (isa<AllocaInst>(GetUnderlyingObject(Addr, DL)) &&
583         !PointerMayBeCaptured(Addr, true, true)) {
584       // The variable is addressable but not captured, so it cannot be
585       // referenced from a different thread and participate in a data race
586       // (see llvm/Analysis/CaptureTracking.h for details).
587       NumOmittedNonCaptured++;
588       continue;
589     }
590     All.push_back(I);
591   }
592   Local.clear();
593 }
594
595
596 bool CDSPass::instrumentLoadOrStore(Instruction *I,
597                                             const DataLayout &DL) {
598   IRBuilder<> IRB(I);
599   bool IsWrite = isa<StoreInst>(*I);
600   Value *Addr = IsWrite
601       ? cast<StoreInst>(I)->getPointerOperand()
602       : cast<LoadInst>(I)->getPointerOperand();
603
604   // swifterror memory addresses are mem2reg promoted by instruction selection.
605   // As such they cannot have regular uses like an instrumentation function and
606   // it makes no sense to track them as memory.
607   if (Addr->isSwiftError())
608     return false;
609
610   int size = getTypeSize(Addr->getType());
611   int index = sizetoindex(size);
612
613 //  not supported by CDS yet
614 /*  if (IsWrite && isVtableAccess(I)) {
615     LLVM_DEBUG(dbgs() << "  VPTR : " << *I << "\n");
616     Value *StoredValue = cast<StoreInst>(I)->getValueOperand();
617     // StoredValue may be a vector type if we are storing several vptrs at once.
618     // In this case, just take the first element of the vector since this is
619     // enough to find vptr races.
620     if (isa<VectorType>(StoredValue->getType()))
621       StoredValue = IRB.CreateExtractElement(
622           StoredValue, ConstantInt::get(IRB.getInt32Ty(), 0));
623     if (StoredValue->getType()->isIntegerTy())
624       StoredValue = IRB.CreateIntToPtr(StoredValue, IRB.getInt8PtrTy());
625     // Call TsanVptrUpdate.
626     IRB.CreateCall(TsanVptrUpdate,
627                    {IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()),
628                     IRB.CreatePointerCast(StoredValue, IRB.getInt8PtrTy())});
629     NumInstrumentedVtableWrites++;
630     return true;
631   }
632
633   if (!IsWrite && isVtableAccess(I)) {
634     IRB.CreateCall(TsanVptrLoad,
635                    IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()));
636     NumInstrumentedVtableReads++;
637     return true;
638   }
639 */
640
641   Value *OnAccessFunc = nullptr;
642   OnAccessFunc = IsWrite ? CDSStore[index] : CDSLoad[index];
643   
644   Type *ArgType = IRB.CreatePointerCast(Addr, Addr->getType())->getType();
645
646   if ( ArgType != Int8PtrTy && ArgType != Int16PtrTy && 
647                 ArgType != Int32PtrTy && ArgType != Int64PtrTy ) {
648         //errs() << "A load or store of type ";
649         //errs() << *ArgType;
650         //errs() << " is passed in\n";
651         return false;   // if other types of load or stores are passed in
652   }
653   IRB.CreateCall(OnAccessFunc, IRB.CreatePointerCast(Addr, Addr->getType()));
654   if (IsWrite) NumInstrumentedWrites++;
655   else         NumInstrumentedReads++;
656   return true;
657 }
658
659 // todo: replace getTypeSize with the getMemoryAccessFuncIndex
660 bool CDSPass::instrumentAtomic(Instruction * I, const DataLayout &DL) {
661   IRBuilder<> IRB(I);
662   // LLVMContext &Ctx = IRB.getContext();
663
664   if (auto *CI = dyn_cast<CallInst>(I)) {
665     return instrumentAtomicCall(CI, DL);
666   }
667
668   Value *position = getPosition(I, IRB);
669
670   if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
671     int atomic_order_index = getAtomicOrderIndex(SI->getOrdering());
672
673     Value *val = SI->getValueOperand();
674     Value *ptr = SI->getPointerOperand();
675     Value *order = ConstantInt::get(OrdTy, atomic_order_index);
676     Value *args[] = {ptr, val, order, position};
677
678     int size=getTypeSize(ptr->getType());
679     int index=sizetoindex(size);
680
681     Instruction* funcInst=CallInst::Create(CDSAtomicStore[index], args);
682     ReplaceInstWithInst(SI, funcInst);
683 //    errs() << "Store replaced\n";
684   } else if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
685     int atomic_order_index = getAtomicOrderIndex(LI->getOrdering());
686
687     Value *ptr = LI->getPointerOperand();
688     Value *order = ConstantInt::get(OrdTy, atomic_order_index);
689     Value *args[] = {ptr, order, position};
690
691     int size=getTypeSize(ptr->getType());
692     int index=sizetoindex(size);
693
694     Instruction* funcInst=CallInst::Create(CDSAtomicLoad[index], args);
695     ReplaceInstWithInst(LI, funcInst);
696 //    errs() << "Load Replaced\n";
697   } else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(I)) {
698     int atomic_order_index = getAtomicOrderIndex(RMWI->getOrdering());
699
700     Value *val = RMWI->getValOperand();
701     Value *ptr = RMWI->getPointerOperand();
702     Value *order = ConstantInt::get(OrdTy, atomic_order_index);
703     Value *args[] = {ptr, val, order, position};
704
705     int size = getTypeSize(ptr->getType());
706     int index = sizetoindex(size);
707
708     Instruction* funcInst = CallInst::Create(CDSAtomicRMW[RMWI->getOperation()][index], args);
709     ReplaceInstWithInst(RMWI, funcInst);
710 //    errs() << RMWI->getOperationName(RMWI->getOperation());
711 //    errs() << " replaced\n";
712   } else if (AtomicCmpXchgInst *CASI = dyn_cast<AtomicCmpXchgInst>(I)) {
713     IRBuilder<> IRB(CASI);
714
715     Value *Addr = CASI->getPointerOperand();
716
717     int size = getTypeSize(Addr->getType());
718     int index = sizetoindex(size);
719     const unsigned ByteSize = 1U << index;
720     const unsigned BitSize = ByteSize * 8;
721     Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
722     Type *PtrTy = Ty->getPointerTo();
723
724     Value *CmpOperand = IRB.CreateBitOrPointerCast(CASI->getCompareOperand(), Ty);
725     Value *NewOperand = IRB.CreateBitOrPointerCast(CASI->getNewValOperand(), Ty);
726
727     int atomic_order_index_succ = getAtomicOrderIndex(CASI->getSuccessOrdering());
728     int atomic_order_index_fail = getAtomicOrderIndex(CASI->getFailureOrdering());
729     Value *order_succ = ConstantInt::get(OrdTy, atomic_order_index_succ);
730     Value *order_fail = ConstantInt::get(OrdTy, atomic_order_index_fail);
731
732     Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy),
733                      CmpOperand, NewOperand,
734                      order_succ, order_fail, position};
735
736     CallInst *funcInst = IRB.CreateCall(CDSAtomicCAS_V1[index], Args);
737     Value *Success = IRB.CreateICmpEQ(funcInst, CmpOperand);
738
739     Value *OldVal = funcInst;
740     Type *OrigOldValTy = CASI->getNewValOperand()->getType();
741     if (Ty != OrigOldValTy) {
742       // The value is a pointer, so we need to cast the return value.
743       OldVal = IRB.CreateIntToPtr(funcInst, OrigOldValTy);
744     }
745
746     Value *Res =
747       IRB.CreateInsertValue(UndefValue::get(CASI->getType()), OldVal, 0);
748     Res = IRB.CreateInsertValue(Res, Success, 1);
749
750     I->replaceAllUsesWith(Res);
751     I->eraseFromParent();
752   } else if (FenceInst *FI = dyn_cast<FenceInst>(I)) {
753     int atomic_order_index = getAtomicOrderIndex(FI->getOrdering());
754     Value *order = ConstantInt::get(OrdTy, atomic_order_index);
755     Value *Args[] = {order, position};
756
757     CallInst *funcInst = CallInst::Create(CDSAtomicThreadFence, Args);
758     ReplaceInstWithInst(FI, funcInst);
759 //    errs() << "Thread Fences replaced\n";
760   }
761   return true;
762 }
763
764 int CDSPass::getMemoryAccessFuncIndex(Value *Addr,
765                                               const DataLayout &DL) {
766   Type *OrigPtrTy = Addr->getType();
767   Type *OrigTy = cast<PointerType>(OrigPtrTy)->getElementType();
768   assert(OrigTy->isSized());
769   uint32_t TypeSize = DL.getTypeStoreSizeInBits(OrigTy);
770   if (TypeSize != 8  && TypeSize != 16 &&
771       TypeSize != 32 && TypeSize != 64 && TypeSize != 128) {
772     // NumAccessesWithBadSize++;
773     // Ignore all unusual sizes.
774     return -1;
775   }
776   size_t Idx = countTrailingZeros(TypeSize / 8);
777   // assert(Idx < FUNCARRAYSIZE);
778   return Idx;
779 }
780
781
782 char CDSPass::ID = 0;
783
784 // Automatically enable the pass.
785 static void registerCDSPass(const PassManagerBuilder &,
786                          legacy::PassManagerBase &PM) {
787   PM.add(new CDSPass());
788 }
789 static RegisterStandardPasses 
790         RegisterMyPass(PassManagerBuilder::EP_OptimizerLast,
791 registerCDSPass);