Merge branch 'master' of /home/git/cds-llvm
[c11llvm.git] / CDSPass.cpp
1 //===-- CDSPass.cpp - xxx -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 // This file is distributed under the University of Illinois Open Source
7 // License. See LICENSE.TXT for details.
8 //
9 //===----------------------------------------------------------------------===//
10 //
11 // This file is a modified version of ThreadSanitizer.cpp, a part of a race detector.
12 //
13 // The tool is under development, for the details about previous versions see
14 // http://code.google.com/p/data-race-test
15 //
16 // The instrumentation phase is quite simple:
17 //   - Insert calls to run-time library before every memory access.
18 //      - Optimizations may apply to avoid instrumenting some of the accesses.
19 //   - Insert calls at function entry/exit.
20 // The rest is handled by the run-time library.
21 //===----------------------------------------------------------------------===//
22
23 #include "llvm/ADT/Statistic.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/SmallString.h"
26 #include "llvm/Analysis/ValueTracking.h"
27 #include "llvm/Analysis/CaptureTracking.h"
28 #include "llvm/IR/BasicBlock.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/Instructions.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/IR/LegacyPassManager.h"
34 #include "llvm/IR/Module.h"
35 #include "llvm/IR/PassManager.h"
36 #include "llvm/Pass.h"
37 #include "llvm/ProfileData/InstrProf.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include "llvm/Support/AtomicOrdering.h"
40 #include "llvm/Support/Debug.h"
41 #include "llvm/Transforms/Scalar.h"
42 #include "llvm/Transforms/Utils/Local.h"
43 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
44 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
45 #include "llvm/Transforms/Utils/EscapeEnumerator.h"
46 #include <vector>
47
48 using namespace llvm;
49
50 #define DEBUG_TYPE "CDS"
51 #include <llvm/IR/DebugLoc.h>
52
53 Value *getPosition( Instruction * I, IRBuilder <> IRB, bool print = false)
54 {
55         const DebugLoc & debug_location = I->getDebugLoc ();
56         std::string position_string;
57         {
58                 llvm::raw_string_ostream position_stream (position_string);
59                 debug_location . print (position_stream);
60         }
61
62         if (print) {
63                 errs() << position_string << "\n";
64         }
65
66         return IRB.CreateGlobalStringPtr (position_string);
67 }
68
// Pass-wide statistics; printed when LLVM is run with -stats.
STATISTIC(NumInstrumentedReads, "Number of instrumented reads");
STATISTIC(NumInstrumentedWrites, "Number of instrumented writes");
STATISTIC(NumAccessesWithBadSize, "Number of accesses with bad size");
// STATISTIC(NumInstrumentedVtableWrites, "Number of vtable ptr writes");
// STATISTIC(NumInstrumentedVtableReads, "Number of vtable ptr reads");

STATISTIC(NumOmittedReadsBeforeWrite,
          "Number of reads ignored due to following writes");
STATISTIC(NumOmittedReadsFromConstantGlobals,
          "Number of reads from constant globals");
STATISTIC(NumOmittedReadsFromVtable, "Number of vtable reads");
STATISTIC(NumOmittedNonCaptured, "Number of accesses ignored due to capturing");

// Cached LLVM types, filled in by CDSPass::initializeCallbacks().
// OrdTy is the i32 used to pass a memory-order index to the runtime.
Type * OrdTy;

// iN* pointer types used to recognize the access widths the runtime supports.
Type * Int8PtrTy;
Type * Int16PtrTy;
Type * Int32PtrTy;
Type * Int64PtrTy;

Type * VoidTy;

// Access sizes 1/2/4/8 bytes -> four entries in each callback table below.
static const size_t kNumberOfAccessSizes = 4;
92
93 int getAtomicOrderIndex(AtomicOrdering order){
94         switch (order) {
95                 case AtomicOrdering::Monotonic: 
96                         return (int)AtomicOrderingCABI::relaxed;
97                 //  case AtomicOrdering::Consume:         // not specified yet
98                 //    return AtomicOrderingCABI::consume;
99                 case AtomicOrdering::Acquire: 
100                         return (int)AtomicOrderingCABI::acquire;
101                 case AtomicOrdering::Release: 
102                         return (int)AtomicOrderingCABI::release;
103                 case AtomicOrdering::AcquireRelease: 
104                         return (int)AtomicOrderingCABI::acq_rel;
105                 case AtomicOrdering::SequentiallyConsistent: 
106                         return (int)AtomicOrderingCABI::seq_cst;
107                 default:
108                         // unordered or Not Atomic
109                         return -1;
110         }
111 }
112
namespace {
	// The instrumentation pass.  Runs per function and rewrites memory
	// accesses into calls to the CDS runtime library.
	struct CDSPass : public FunctionPass {
		static char ID;
		CDSPass() : FunctionPass(ID) {}
		bool runOnFunction(Function &F) override; 

	private:
		// Declares (getOrInsertFunction) every runtime entry point on the module.
		void initializeCallbacks(Module &M);
		// Instrument a plain (non-atomic, non-volatile) load or store.
		bool instrumentLoadOrStore(Instruction *I, const DataLayout &DL);
		// Replace a volatile load/store with a runtime call.
		bool instrumentVolatile(Instruction *I, const DataLayout &DL);
		// Name-based check: is this a call into the C/C++ atomics API?
		bool isAtomicCall(Instruction *I);
		// Replace an atomic instruction (load/store/RMW/CAS/fence) with a runtime call.
		bool instrumentAtomic(Instruction *I, const DataLayout &DL);
		// Replace a recognized atomic library call with its CDS equivalent.
		bool instrumentAtomicCall(CallInst *CI, const DataLayout &DL);
		// Filter the per-block candidates in Local into the final list All.
		void chooseInstructionsToInstrument(SmallVectorImpl<Instruction *> &Local,
																	SmallVectorImpl<Instruction *> &All,
																	const DataLayout &DL);
		// True if Addr can only reference constant data (no race possible).
		bool addrPointsToConstantData(Value *Addr);
		// Map an access to its size-class index (0..3), or -1 if unsupported.
		int getMemoryAccessFuncIndex(Value *Addr, const DataLayout &DL);

		// Callbacks to run-time library; computed in initializeCallbacks().
		Constant * CDSFuncEntry;
		Constant * CDSFuncExit;

		// One callback per access size (1/2/4/8 bytes).
		Constant * CDSLoad[kNumberOfAccessSizes];
		Constant * CDSStore[kNumberOfAccessSizes];
		Constant * CDSVolatileLoad[kNumberOfAccessSizes];
		Constant * CDSVolatileStore[kNumberOfAccessSizes];
		Constant * CDSAtomicInit[kNumberOfAccessSizes];
		Constant * CDSAtomicLoad[kNumberOfAccessSizes];
		Constant * CDSAtomicStore[kNumberOfAccessSizes];
		// One callback per (RMW operation, access size) pair; unsupported
		// operations are left nullptr by initializeCallbacks().
		Constant * CDSAtomicRMW[AtomicRMWInst::LAST_BINOP + 1][kNumberOfAccessSizes];
		// CAS v1 returns the old value; v2 writes it back through a pointer
		// and returns an i1 success flag.
		Constant * CDSAtomicCAS_V1[kNumberOfAccessSizes];
		Constant * CDSAtomicCAS_V2[kNumberOfAccessSizes];
		Constant * CDSAtomicThreadFence;

		// Name fragments consulted by isAtomicCall(); populated in runOnFunction.
		std::vector<StringRef> AtomicFuncNames;
		std::vector<StringRef> PartialAtomicFuncNames;
	};
}
152
153 static bool isVtableAccess(Instruction *I) {
154         if (MDNode *Tag = I->getMetadata(LLVMContext::MD_tbaa))
155                 return Tag->isTBAAVtableAccess();
156         return false;
157 }
158
// Declare every CDS runtime entry point on the module and cache the
// Constant* handles (and the shared types) in the pass's member tables.
// Idempotent: getOrInsertFunction reuses an existing declaration.
void CDSPass::initializeCallbacks(Module &M) {
	LLVMContext &Ctx = M.getContext();

	Type * Int1Ty = Type::getInt1Ty(Ctx);
	OrdTy = Type::getInt32Ty(Ctx);

	Int8PtrTy  = Type::getInt8PtrTy(Ctx);
	Int16PtrTy = Type::getInt16PtrTy(Ctx);
	Int32PtrTy = Type::getInt32PtrTy(Ctx);
	Int64PtrTy = Type::getInt64PtrTy(Ctx);

	VoidTy = Type::getVoidTy(Ctx);

	// Function entry/exit hooks take the function-name string (i8*).
	CDSFuncEntry = M.getOrInsertFunction("cds_func_entry", 
								VoidTy, Int8PtrTy);
	CDSFuncExit = M.getOrInsertFunction("cds_func_exit", 
								VoidTy, Int8PtrTy);

	// Declare the functions to call from our runtime library, one set per
	// access size (1/2/4/8 bytes).
	for (unsigned i = 0; i < kNumberOfAccessSizes; i++) {
		const unsigned ByteSize = 1U << i;
		const unsigned BitSize = ByteSize * 8;

		std::string ByteSizeStr = utostr(ByteSize);
		std::string BitSizeStr = utostr(BitSize);

		Type *Ty = Type::getIntNTy(Ctx, BitSize);
		Type *PtrTy = Ty->getPointerTo();

		// Runtime signatures, e.g. for the 8-bit variants:
		//   uint8_t cds_atomic_load8 (void * obj, int atomic_index)
		//   void cds_atomic_store8 (void * obj, int atomic_index, uint8_t val)
		SmallString<32> LoadName("cds_load" + BitSizeStr);
		SmallString<32> StoreName("cds_store" + BitSizeStr);
		SmallString<32> VolatileLoadName("cds_volatile_load" + BitSizeStr);
		SmallString<32> VolatileStoreName("cds_volatile_store" + BitSizeStr);
		SmallString<32> AtomicInitName("cds_atomic_init" + BitSizeStr);
		SmallString<32> AtomicLoadName("cds_atomic_load" + BitSizeStr);
		SmallString<32> AtomicStoreName("cds_atomic_store" + BitSizeStr);

		CDSLoad[i]  = M.getOrInsertFunction(LoadName, VoidTy, PtrTy);
		CDSStore[i] = M.getOrInsertFunction(StoreName, VoidTy, PtrTy);
		// Volatile/atomic variants additionally take the source-position
		// string (i8*) and, where relevant, a memory-order index (OrdTy).
		CDSVolatileLoad[i]  = M.getOrInsertFunction(VolatileLoadName,
									Ty, PtrTy, Int8PtrTy);
		CDSVolatileStore[i] = M.getOrInsertFunction(VolatileStoreName, 
									VoidTy, PtrTy, Ty, Int8PtrTy);
		CDSAtomicInit[i] = M.getOrInsertFunction(AtomicInitName, 
								VoidTy, PtrTy, Ty, Int8PtrTy);
		CDSAtomicLoad[i]  = M.getOrInsertFunction(AtomicLoadName, 
								Ty, PtrTy, OrdTy, Int8PtrTy);
		CDSAtomicStore[i] = M.getOrInsertFunction(AtomicStoreName, 
								VoidTy, PtrTy, Ty, OrdTy, Int8PtrTy);

		// RMW callbacks: only the six operations named below are supported;
		// the rest of the table stays nullptr.
		for (int op = AtomicRMWInst::FIRST_BINOP; 
			op <= AtomicRMWInst::LAST_BINOP; ++op) {
			CDSAtomicRMW[op][i] = nullptr;
			std::string NamePart;

			if (op == AtomicRMWInst::Xchg)
				NamePart = "_exchange";
			else if (op == AtomicRMWInst::Add) 
				NamePart = "_fetch_add";
			else if (op == AtomicRMWInst::Sub)
				NamePart = "_fetch_sub";
			else if (op == AtomicRMWInst::And)
				NamePart = "_fetch_and";
			else if (op == AtomicRMWInst::Or)
				NamePart = "_fetch_or";
			else if (op == AtomicRMWInst::Xor)
				NamePart = "_fetch_xor";
			else
				continue;

			SmallString<32> AtomicRMWName("cds_atomic" + NamePart + BitSizeStr);
			CDSAtomicRMW[op][i] = M.getOrInsertFunction(AtomicRMWName, 
										Ty, PtrTy, Ty, OrdTy, Int8PtrTy);
		}

		// Only the strong version of compare-exchange is supported.
		// v1 returns the old value; v2 updates *expected and returns i1.
		SmallString<32> AtomicCASName_V1("cds_atomic_compare_exchange" + BitSizeStr + "_v1");
		SmallString<32> AtomicCASName_V2("cds_atomic_compare_exchange" + BitSizeStr + "_v2");
		CDSAtomicCAS_V1[i] = M.getOrInsertFunction(AtomicCASName_V1, 
								Ty, PtrTy, Ty, Ty, OrdTy, OrdTy, Int8PtrTy);
		CDSAtomicCAS_V2[i] = M.getOrInsertFunction(AtomicCASName_V2, 
								Int1Ty, PtrTy, PtrTy, Ty, OrdTy, OrdTy, Int8PtrTy);
	}

	CDSAtomicThreadFence = M.getOrInsertFunction("cds_atomic_thread_fence", 
													VoidTy, OrdTy, Int8PtrTy);
}
248
249 static bool shouldInstrumentReadWriteFromAddress(const Module *M, Value *Addr) {
250         // Peel off GEPs and BitCasts.
251         Addr = Addr->stripInBoundsOffsets();
252
253         if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) {
254                 if (GV->hasSection()) {
255                         StringRef SectionName = GV->getSection();
256                         // Check if the global is in the PGO counters section.
257                         auto OF = Triple(M->getTargetTriple()).getObjectFormat();
258                         if (SectionName.endswith(
259                               getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false)))
260                                 return false;
261                 }
262
263                 // Check if the global is private gcov data.
264                 if (GV->getName().startswith("__llvm_gcov") ||
265                 GV->getName().startswith("__llvm_gcda"))
266                 return false;
267         }
268
269         // Do not instrument acesses from different address spaces; we cannot deal
270         // with them.
271         if (Addr) {
272                 Type *PtrTy = cast<PointerType>(Addr->getType()->getScalarType());
273                 if (PtrTy->getPointerAddressSpace() != 0)
274                         return false;
275         }
276
277         return true;
278 }
279
280 bool CDSPass::addrPointsToConstantData(Value *Addr) {
281         // If this is a GEP, just analyze its pointer operand.
282         if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Addr))
283                 Addr = GEP->getPointerOperand();
284
285         if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) {
286                 if (GV->isConstant()) {
287                         // Reads from constant globals can not race with any writes.
288                         NumOmittedReadsFromConstantGlobals++;
289                         return true;
290                 }
291         } else if (LoadInst *L = dyn_cast<LoadInst>(Addr)) {
292                 if (isVtableAccess(L)) {
293                         // Reads from a vtable pointer can not race with any writes.
294                         NumOmittedReadsFromVtable++;
295                         return true;
296                 }
297         }
298         return false;
299 }
300
301 bool CDSPass::runOnFunction(Function &F) {
302         if (F.getName() == "main") {
303                 F.setName("user_main");
304                 errs() << "main replaced by user_main\n";
305         }
306
307         if (true) {
308                 initializeCallbacks( *F.getParent() );
309
310                 AtomicFuncNames = 
311                 {
312                         "atomic_init", "atomic_load", "atomic_store", 
313                         "atomic_fetch_", "atomic_exchange", "atomic_compare_exchange_"
314                 };
315
316                 PartialAtomicFuncNames = 
317                 { 
318                         "load", "store", "fetch", "exchange", "compare_exchange_" 
319                 };
320
321                 SmallVector<Instruction*, 8> AllLoadsAndStores;
322                 SmallVector<Instruction*, 8> LocalLoadsAndStores;
323                 SmallVector<Instruction*, 8> VolatileLoadsAndStores;
324                 SmallVector<Instruction*, 8> AtomicAccesses;
325
326                 std::vector<Instruction *> worklist;
327
328                 bool Res = false;
329                 bool HasAtomic = false;
330                 const DataLayout &DL = F.getParent()->getDataLayout();
331
332                 // errs() << "--- " << F.getName() << "---\n";
333
334                 for (auto &B : F) {
335                         for (auto &I : B) {
336                                 if ( (&I)->isAtomic() || isAtomicCall(&I) ) {
337                                         AtomicAccesses.push_back(&I);
338                                         HasAtomic = true;
339                                 } else if (isa<LoadInst>(I) || isa<StoreInst>(I)) {
340                                         LoadInst *LI = dyn_cast<LoadInst>(&I);
341                                         StoreInst *SI = dyn_cast<StoreInst>(&I);
342                                         bool isVolatile = ( LI ? LI->isVolatile() : SI->isVolatile() );
343
344                                         if (isVolatile)
345                                                 VolatileLoadsAndStores.push_back(&I);
346                                         else
347                                                 LocalLoadsAndStores.push_back(&I);
348                                 } else if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
349                                         // not implemented yet
350                                 }
351                         }
352
353                         chooseInstructionsToInstrument(LocalLoadsAndStores, AllLoadsAndStores, DL);
354                 }
355
356                 for (auto Inst : AllLoadsAndStores) {
357                         Res |= instrumentLoadOrStore(Inst, DL);
358                 }
359
360                 for (auto Inst : VolatileLoadsAndStores) {
361                         Res |= instrumentVolatile(Inst, DL);
362                 }
363
364                 for (auto Inst : AtomicAccesses) {
365                         Res |= instrumentAtomic(Inst, DL);
366                 }
367
368                 // only instrument functions that contain atomics
369                 if (Res && HasAtomic) {
370                         IRBuilder<> IRB(F.getEntryBlock().getFirstNonPHI());
371                         /* Unused for now
372                         Value *ReturnAddress = IRB.CreateCall(
373                                 Intrinsic::getDeclaration(F.getParent(), Intrinsic::returnaddress),
374                                 IRB.getInt32(0));
375                         */
376
377                         Value * FuncName = IRB.CreateGlobalStringPtr(F.getName());
378                         IRB.CreateCall(CDSFuncEntry, FuncName);
379
380                         EscapeEnumerator EE(F, "cds_cleanup", true);
381                         while (IRBuilder<> *AtExit = EE.Next()) {
382                           AtExit->CreateCall(CDSFuncExit, FuncName);
383                         }
384
385                         Res = true;
386                 }
387         }
388
389         return false;
390 }
391
392 void CDSPass::chooseInstructionsToInstrument(
393         SmallVectorImpl<Instruction *> &Local, SmallVectorImpl<Instruction *> &All,
394         const DataLayout &DL) {
395         SmallPtrSet<Value*, 8> WriteTargets;
396         // Iterate from the end.
397         for (Instruction *I : reverse(Local)) {
398                 if (StoreInst *Store = dyn_cast<StoreInst>(I)) {
399                         Value *Addr = Store->getPointerOperand();
400                         if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
401                                 continue;
402                         WriteTargets.insert(Addr);
403                 } else {
404                         LoadInst *Load = cast<LoadInst>(I);
405                         Value *Addr = Load->getPointerOperand();
406                         if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
407                                 continue;
408                         if (WriteTargets.count(Addr)) {
409                                 // We will write to this temp, so no reason to analyze the read.
410                                 NumOmittedReadsBeforeWrite++;
411                                 continue;
412                         }
413                         if (addrPointsToConstantData(Addr)) {
414                                 // Addr points to some constant data -- it can not race with any writes.
415                                 continue;
416                         }
417                 }
418                 Value *Addr = isa<StoreInst>(*I)
419                         ? cast<StoreInst>(I)->getPointerOperand()
420                         : cast<LoadInst>(I)->getPointerOperand();
421                 if (isa<AllocaInst>(GetUnderlyingObject(Addr, DL)) &&
422                                 !PointerMayBeCaptured(Addr, true, true)) {
423                         // The variable is addressable but not captured, so it cannot be
424                         // referenced from a different thread and participate in a data race
425                         // (see llvm/Analysis/CaptureTracking.h for details).
426                         NumOmittedNonCaptured++;
427                         continue;
428                 }
429                 All.push_back(I);
430         }
431         Local.clear();
432 }
433
434
435 bool CDSPass::instrumentLoadOrStore(Instruction *I,
436                                                                         const DataLayout &DL) {
437         IRBuilder<> IRB(I);
438         bool IsWrite = isa<StoreInst>(*I);
439         Value *Addr = IsWrite
440                 ? cast<StoreInst>(I)->getPointerOperand()
441                 : cast<LoadInst>(I)->getPointerOperand();
442
443         // swifterror memory addresses are mem2reg promoted by instruction selection.
444         // As such they cannot have regular uses like an instrumentation function and
445         // it makes no sense to track them as memory.
446         if (Addr->isSwiftError())
447         return false;
448
449         int Idx = getMemoryAccessFuncIndex(Addr, DL);
450         if (Idx < 0)
451                 return false;
452
453 //  not supported by CDS yet
454 /*  if (IsWrite && isVtableAccess(I)) {
455     LLVM_DEBUG(dbgs() << "  VPTR : " << *I << "\n");
456     Value *StoredValue = cast<StoreInst>(I)->getValueOperand();
457     // StoredValue may be a vector type if we are storing several vptrs at once.
458     // In this case, just take the first element of the vector since this is
459     // enough to find vptr races.
460     if (isa<VectorType>(StoredValue->getType()))
461       StoredValue = IRB.CreateExtractElement(
462           StoredValue, ConstantInt::get(IRB.getInt32Ty(), 0));
463     if (StoredValue->getType()->isIntegerTy())
464       StoredValue = IRB.CreateIntToPtr(StoredValue, IRB.getInt8PtrTy());
465     // Call TsanVptrUpdate.
466     IRB.CreateCall(TsanVptrUpdate,
467                    {IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()),
468                     IRB.CreatePointerCast(StoredValue, IRB.getInt8PtrTy())});
469     NumInstrumentedVtableWrites++;
470     return true;
471   }
472
473   if (!IsWrite && isVtableAccess(I)) {
474     IRB.CreateCall(TsanVptrLoad,
475                    IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()));
476     NumInstrumentedVtableReads++;
477     return true;
478   }
479 */
480
481         Value *OnAccessFunc = nullptr;
482         OnAccessFunc = IsWrite ? CDSStore[Idx] : CDSLoad[Idx];
483
484         Type *ArgType = IRB.CreatePointerCast(Addr, Addr->getType())->getType();
485
486         if ( ArgType != Int8PtrTy && ArgType != Int16PtrTy && 
487                         ArgType != Int32PtrTy && ArgType != Int64PtrTy ) {
488                 // if other types of load or stores are passed in
489                 return false;   
490         }
491         IRB.CreateCall(OnAccessFunc, IRB.CreatePointerCast(Addr, Addr->getType()));
492         if (IsWrite) NumInstrumentedWrites++;
493         else         NumInstrumentedReads++;
494         return true;
495 }
496
497 bool CDSPass::instrumentVolatile(Instruction * I, const DataLayout &DL) {
498         IRBuilder<> IRB(I);
499         Value *position = getPosition(I, IRB);
500
501         if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
502                 assert( LI->isVolatile() );
503                 Value *Addr = LI->getPointerOperand();
504                 int Idx=getMemoryAccessFuncIndex(Addr, DL);
505                 if (Idx < 0)
506                         return false;
507
508                 Value *args[] = {Addr, position};
509                 Instruction* funcInst=CallInst::Create(CDSVolatileLoad[Idx], args);
510                 ReplaceInstWithInst(LI, funcInst);
511         } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
512                 assert( SI->isVolatile() );
513                 Value *Addr = SI->getPointerOperand();
514                 int Idx=getMemoryAccessFuncIndex(Addr, DL);
515                 if (Idx < 0)
516                         return false;
517
518                 Value *val = SI->getValueOperand();
519                 Value *args[] = {Addr, val, position};
520                 Instruction* funcInst=CallInst::Create(CDSVolatileStore[Idx], args);
521                 ReplaceInstWithInst(SI, funcInst);
522         } else {
523                 return false;
524         }
525
526         return true;
527 }
528
// Replace an atomic IR instruction (load / store / RMW / cmpxchg / fence)
// with the corresponding call into the CDS runtime.  Calls into the atomics
// library API are delegated to instrumentAtomicCall().
//
// Returns true when the instruction was handled.
// NOTE(review): an atomic instruction matching none of the branches below
// still returns true -- confirm this is intended.
bool CDSPass::instrumentAtomic(Instruction * I, const DataLayout &DL) {
	IRBuilder<> IRB(I);

	if (auto *CI = dyn_cast<CallInst>(I)) {
		return instrumentAtomicCall(CI, DL);
	}

	// Source position string, passed as the last argument of every callback.
	Value *position = getPosition(I, IRB);

	if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
		// atomic load -> cds_atomic_load<N>(addr, order, position)
		Value *Addr = LI->getPointerOperand();
		int Idx=getMemoryAccessFuncIndex(Addr, DL);
		if (Idx < 0)
			return false;

		int atomic_order_index = getAtomicOrderIndex(LI->getOrdering());
		Value *order = ConstantInt::get(OrdTy, atomic_order_index);
		Value *args[] = {Addr, order, position};
		Instruction* funcInst=CallInst::Create(CDSAtomicLoad[Idx], args);
		ReplaceInstWithInst(LI, funcInst);
	} else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
		// atomic store -> cds_atomic_store<N>(addr, val, order, position)
		Value *Addr = SI->getPointerOperand();
		int Idx=getMemoryAccessFuncIndex(Addr, DL);
		if (Idx < 0)
			return false;

		int atomic_order_index = getAtomicOrderIndex(SI->getOrdering());
		Value *val = SI->getValueOperand();
		Value *order = ConstantInt::get(OrdTy, atomic_order_index);
		Value *args[] = {Addr, val, order, position};
		Instruction* funcInst=CallInst::Create(CDSAtomicStore[Idx], args);
		ReplaceInstWithInst(SI, funcInst);
	} else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(I)) {
		// atomicrmw -> cds_atomic_<op><N>(addr, val, order, position)
		// NOTE(review): CDSAtomicRMW is nullptr for unsupported operations
		// (e.g. Nand, Max); this is not checked here -- confirm upstream
		// filtering guarantees a supported op.
		Value *Addr = RMWI->getPointerOperand();
		int Idx=getMemoryAccessFuncIndex(Addr, DL);
		if (Idx < 0)
			return false;

		int atomic_order_index = getAtomicOrderIndex(RMWI->getOrdering());
		Value *val = RMWI->getValOperand();
		Value *order = ConstantInt::get(OrdTy, atomic_order_index);
		Value *args[] = {Addr, val, order, position};
		Instruction* funcInst = CallInst::Create(CDSAtomicRMW[RMWI->getOperation()][Idx], args);
		ReplaceInstWithInst(RMWI, funcInst);
	} else if (AtomicCmpXchgInst *CASI = dyn_cast<AtomicCmpXchgInst>(I)) {
		// cmpxchg -> cds_atomic_compare_exchange<N>_v1(...), then rebuild the
		// { oldval, success } struct the cmpxchg instruction produced.
		IRBuilder<> IRB(CASI);

		Value *Addr = CASI->getPointerOperand();
		int Idx=getMemoryAccessFuncIndex(Addr, DL);
		if (Idx < 0)
			return false;

		const unsigned ByteSize = 1U << Idx;
		const unsigned BitSize = ByteSize * 8;
		Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
		Type *PtrTy = Ty->getPointerTo();

		// Operands may be pointers; run them through the integer type the
		// runtime expects.
		Value *CmpOperand = IRB.CreateBitOrPointerCast(CASI->getCompareOperand(), Ty);
		Value *NewOperand = IRB.CreateBitOrPointerCast(CASI->getNewValOperand(), Ty);

		int atomic_order_index_succ = getAtomicOrderIndex(CASI->getSuccessOrdering());
		int atomic_order_index_fail = getAtomicOrderIndex(CASI->getFailureOrdering());
		Value *order_succ = ConstantInt::get(OrdTy, atomic_order_index_succ);
		Value *order_fail = ConstantInt::get(OrdTy, atomic_order_index_fail);

		Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy),
						 CmpOperand, NewOperand,
						 order_succ, order_fail, position};

		// v1 returns the old value; success == (old value == expected).
		CallInst *funcInst = IRB.CreateCall(CDSAtomicCAS_V1[Idx], Args);
		Value *Success = IRB.CreateICmpEQ(funcInst, CmpOperand);

		Value *OldVal = funcInst;
		Type *OrigOldValTy = CASI->getNewValOperand()->getType();
		if (Ty != OrigOldValTy) {
			// The value is a pointer, so we need to cast the return value.
			OldVal = IRB.CreateIntToPtr(funcInst, OrigOldValTy);
		}

		// Reassemble the two-element result struct of the original cmpxchg.
		Value *Res =
		  IRB.CreateInsertValue(UndefValue::get(CASI->getType()), OldVal, 0);
		Res = IRB.CreateInsertValue(Res, Success, 1);

		I->replaceAllUsesWith(Res);
		I->eraseFromParent();
	} else if (FenceInst *FI = dyn_cast<FenceInst>(I)) {
		// fence -> cds_atomic_thread_fence(order, position)
		int atomic_order_index = getAtomicOrderIndex(FI->getOrdering());
		Value *order = ConstantInt::get(OrdTy, atomic_order_index);
		Value *Args[] = {order, position};

		CallInst *funcInst = CallInst::Create(CDSAtomicThreadFence, Args);
		ReplaceInstWithInst(FI, funcInst);
	}
	return true;
}
625
626 bool CDSPass::isAtomicCall(Instruction *I) {
627         if ( auto *CI = dyn_cast<CallInst>(I) ) {
628                 Function *fun = CI->getCalledFunction();
629                 if (fun == NULL)
630                         return false;
631
632                 StringRef funName = fun->getName();
633
634                 // todo: come up with better rules for function name checking
635                 for (StringRef name : AtomicFuncNames) {
636                         if ( funName.contains(name) ) 
637                                 return true;
638                 }
639                 
640                 for (StringRef PartialName : PartialAtomicFuncNames) {
641                         if (funName.contains(PartialName) && 
642                                         funName.contains("atomic") )
643                                 return true;
644                 }
645         }
646
647         return false;
648 }
649
650 bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
651         IRBuilder<> IRB(CI);
652         Function *fun = CI->getCalledFunction();
653         StringRef funName = fun->getName();
654         std::vector<Value *> parameters;
655
656         User::op_iterator begin = CI->arg_begin();
657         User::op_iterator end = CI->arg_end();
658         for (User::op_iterator it = begin; it != end; ++it) {
659                 Value *param = *it;
660                 parameters.push_back(param);
661         }
662
663         // obtain source line number of the CallInst
664         Value *position = getPosition(CI, IRB);
665
666         // the pointer to the address is always the first argument
667         Value *OrigPtr = parameters[0];
668
669         int Idx = getMemoryAccessFuncIndex(OrigPtr, DL);
670         if (Idx < 0)
671                 return false;
672
673         const unsigned ByteSize = 1U << Idx;
674         const unsigned BitSize = ByteSize * 8;
675         Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
676         Type *PtrTy = Ty->getPointerTo();
677
678         // atomic_init; args = {obj, order}
679         if (funName.contains("atomic_init")) {
680                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
681                 Value *val = IRB.CreateBitOrPointerCast(parameters[1], Ty);
682                 Value *args[] = {ptr, val, position};
683
684                 Instruction* funcInst = CallInst::Create(CDSAtomicInit[Idx], args);
685                 ReplaceInstWithInst(CI, funcInst);
686
687                 return true;
688         }
689
690         // atomic_load; args = {obj, order}
691         if (funName.contains("atomic_load")) {
692                 bool isExplicit = funName.contains("atomic_load_explicit");
693
694                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
695                 Value *order;
696                 if (isExplicit)
697                         order = IRB.CreateBitOrPointerCast(parameters[1], OrdTy);
698                 else 
699                         order = ConstantInt::get(OrdTy, 
700                                                         (int) AtomicOrderingCABI::seq_cst);
701                 Value *args[] = {ptr, order, position};
702                 
703                 Instruction* funcInst = CallInst::Create(CDSAtomicLoad[Idx], args);
704                 ReplaceInstWithInst(CI, funcInst);
705
706                 return true;
707         } else if (funName.contains("atomic") && 
708                                         funName.contains("load") ) {
709                 // does this version of call always have an atomic order as an argument?
710                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
711                 Value *order = IRB.CreateBitOrPointerCast(parameters[1], OrdTy);
712                 Value *args[] = {ptr, order, position};
713
714                 if (!CI->getType()->isPointerTy()) {
715                         return false;   
716                 } 
717
718                 CallInst *funcInst = IRB.CreateCall(CDSAtomicLoad[Idx], args);
719                 Value *RetVal = IRB.CreateIntToPtr(funcInst, CI->getType());
720
721                 CI->replaceAllUsesWith(RetVal);
722                 CI->eraseFromParent();
723
724                 return true;
725         }
726
727         // atomic_store; args = {obj, val, order}
728         if (funName.contains("atomic_store")) {
729                 bool isExplicit = funName.contains("atomic_store_explicit");
730                 Value *OrigVal = parameters[1];
731
732                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
733                 Value *val = IRB.CreatePointerCast(OrigVal, Ty);
734                 Value *order;
735                 if (isExplicit)
736                         order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
737                 else 
738                         order = ConstantInt::get(OrdTy, 
739                                                         (int) AtomicOrderingCABI::seq_cst);
740                 Value *args[] = {ptr, val, order, position};
741                 
742                 Instruction* funcInst = CallInst::Create(CDSAtomicStore[Idx], args);
743                 ReplaceInstWithInst(CI, funcInst);
744
745                 return true;
746         } else if (funName.contains("atomic") && 
747                                         funName.contains("EEEE5store") ) {
748                 // does this version of call always have an atomic order as an argument?
749                 Value *OrigVal = parameters[1];
750
751                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
752                 Value *val = IRB.CreatePointerCast(OrigVal, Ty);
753                 Value *order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
754                 Value *args[] = {ptr, val, order, position};
755
756                 Instruction* funcInst = CallInst::Create(CDSAtomicStore[Idx], args);
757                 ReplaceInstWithInst(CI, funcInst);
758
759                 return true;
760         }
761
762         // atomic_fetch_*; args = {obj, val, order}
763         if (funName.contains("atomic_fetch_") || 
764                         funName.contains("atomic_exchange") ) {
765                 bool isExplicit = funName.contains("_explicit");
766                 Value *OrigVal = parameters[1];
767
768                 int op;
769                 if ( funName.contains("_fetch_add") )
770                         op = AtomicRMWInst::Add;
771                 else if ( funName.contains("_fetch_sub") )
772                         op = AtomicRMWInst::Sub;
773                 else if ( funName.contains("_fetch_and") )
774                         op = AtomicRMWInst::And;
775                 else if ( funName.contains("_fetch_or") )
776                         op = AtomicRMWInst::Or;
777                 else if ( funName.contains("_fetch_xor") )
778                         op = AtomicRMWInst::Xor;
779                 else if ( funName.contains("atomic_exchange") )
780                         op = AtomicRMWInst::Xchg;
781                 else {
782                         errs() << "Unknown atomic read-modify-write operation\n";
783                         return false;
784                 }
785
786                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
787                 Value *val = IRB.CreatePointerCast(OrigVal, Ty);
788                 Value *order;
789                 if (isExplicit)
790                         order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
791                 else 
792                         order = ConstantInt::get(OrdTy, 
793                                                         (int) AtomicOrderingCABI::seq_cst);
794                 Value *args[] = {ptr, val, order, position};
795                 
796                 Instruction* funcInst = CallInst::Create(CDSAtomicRMW[op][Idx], args);
797                 ReplaceInstWithInst(CI, funcInst);
798
799                 return true;
800         } else if (funName.contains("fetch")) {
801                 errs() << "atomic exchange captured. Not implemented yet. ";
802                 errs() << "See source file :";
803                 getPosition(CI, IRB, true);
804         } else if (funName.contains("exchange") &&
805                         !funName.contains("compare_exchange") ) {
806                 errs() << "atomic exchange captured. Not implemented yet. ";
807                 errs() << "See source file :";
808                 getPosition(CI, IRB, true);
809         }
810
811         /* atomic_compare_exchange_*; 
812            args = {obj, expected, new value, order1, order2}
813         */
814         if ( funName.contains("atomic_compare_exchange_") ) {
815                 bool isExplicit = funName.contains("_explicit");
816
817                 Value *Addr = IRB.CreatePointerCast(OrigPtr, PtrTy);
818                 Value *CmpOperand = IRB.CreatePointerCast(parameters[1], PtrTy);
819                 Value *NewOperand = IRB.CreateBitOrPointerCast(parameters[2], Ty);
820
821                 Value *order_succ, *order_fail;
822                 if (isExplicit) {
823                         order_succ = IRB.CreateBitOrPointerCast(parameters[3], OrdTy);
824                         order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
825                 } else  {
826                         order_succ = ConstantInt::get(OrdTy, 
827                                                         (int) AtomicOrderingCABI::seq_cst);
828                         order_fail = ConstantInt::get(OrdTy, 
829                                                         (int) AtomicOrderingCABI::seq_cst);
830                 }
831
832                 Value *args[] = {Addr, CmpOperand, NewOperand, 
833                                                         order_succ, order_fail, position};
834                 
835                 Instruction* funcInst = CallInst::Create(CDSAtomicCAS_V2[Idx], args);
836                 ReplaceInstWithInst(CI, funcInst);
837
838                 return true;
839         } else if ( funName.contains("compare_exchange_strong") ||
840                                 funName.contains("compare_exchange_weak") ) {
841                 Value *Addr = IRB.CreatePointerCast(OrigPtr, PtrTy);
842                 Value *CmpOperand = IRB.CreatePointerCast(parameters[1], PtrTy);
843                 Value *NewOperand = IRB.CreateBitOrPointerCast(parameters[2], Ty);
844
845                 Value *order_succ, *order_fail;
846                 order_succ = IRB.CreateBitOrPointerCast(parameters[3], OrdTy);
847                 order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
848
849                 Value *args[] = {Addr, CmpOperand, NewOperand, 
850                                                         order_succ, order_fail, position};
851                 Instruction* funcInst = CallInst::Create(CDSAtomicCAS_V2[Idx], args);
852                 ReplaceInstWithInst(CI, funcInst);
853
854                 return true;
855         }
856
857         return false;
858 }
859
860 int CDSPass::getMemoryAccessFuncIndex(Value *Addr,
861                                                                                 const DataLayout &DL) {
862         Type *OrigPtrTy = Addr->getType();
863         Type *OrigTy = cast<PointerType>(OrigPtrTy)->getElementType();
864         assert(OrigTy->isSized());
865         uint32_t TypeSize = DL.getTypeStoreSizeInBits(OrigTy);
866         if (TypeSize != 8  && TypeSize != 16 &&
867                 TypeSize != 32 && TypeSize != 64 && TypeSize != 128) {
868                 NumAccessesWithBadSize++;
869                 // Ignore all unusual sizes.
870                 return -1;
871         }
872         size_t Idx = countTrailingZeros(TypeSize / 8);
873         //assert(Idx < kNumberOfAccessSizes);
874         if (Idx >= kNumberOfAccessSizes) {
875                 return -1;
876         }
877         return Idx;
878 }
879
880
// Unique pass identifier: LLVM keys the pass off this variable's address,
// so the stored value is irrelevant.
char CDSPass::ID = 0;

// Automatically enable the pass.
// Extension-point callback: appends a fresh CDSPass instance to the legacy
// pass pipeline whenever one of the registrations below fires.
static void registerCDSPass(const PassManagerBuilder &,
							legacy::PassManagerBase &PM) {
	PM.add(new CDSPass());
}

/* Enable the pass when opt level is greater than 0 */
// EP_OptimizerLast schedules the pass after the optimization pipeline, so
// instrumentation sees already-optimized IR.
static RegisterStandardPasses 
	RegisterMyPass1(PassManagerBuilder::EP_OptimizerLast,
registerCDSPass);

/* Enable the pass when opt level is 0 */
// A second registration is required because EP_OptimizerLast does not fire
// at -O0; EP_EnabledOnOptLevel0 covers that case.
static RegisterStandardPasses 
	RegisterMyPass2(PassManagerBuilder::EP_EnabledOnOptLevel0,
registerCDSPass);