instrument volatile loads and stores
[c11llvm.git] / CDSPass.cpp
1 //===-- CDSPass.cpp - xxx -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 // This file is distributed under the University of Illinois Open Source
7 // License. See LICENSE.TXT for details.
8 //
9 //===----------------------------------------------------------------------===//
10 //
11 // This file is a modified version of ThreadSanitizer.cpp, a part of a race detector.
12 //
13 // The tool is under development, for the details about previous versions see
14 // http://code.google.com/p/data-race-test
15 //
16 // The instrumentation phase is quite simple:
17 //   - Insert calls to run-time library before every memory access.
18 //      - Optimizations may apply to avoid instrumenting some of the accesses.
19 //   - Insert calls at function entry/exit.
20 // The rest is handled by the run-time library.
21 //===----------------------------------------------------------------------===//
22
23 #include "llvm/ADT/Statistic.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/SmallString.h"
26 #include "llvm/Analysis/ValueTracking.h"
27 #include "llvm/Analysis/CaptureTracking.h"
28 #include "llvm/IR/BasicBlock.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/Instructions.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/IR/LegacyPassManager.h"
34 #include "llvm/IR/Module.h"
35 #include "llvm/IR/PassManager.h"
36 #include "llvm/Pass.h"
37 #include "llvm/ProfileData/InstrProf.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include "llvm/Support/AtomicOrdering.h"
40 #include "llvm/Support/Debug.h"
41 #include "llvm/Transforms/Scalar.h"
42 #include "llvm/Transforms/Utils/Local.h"
43 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
44 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
45 #include "llvm/Transforms/Utils/EscapeEnumerator.h"
46 #include <vector>
47
48 using namespace llvm;
49
50 #define DEBUG_TYPE "CDS"
51 #include <llvm/IR/DebugLoc.h>
52
53 Value *getPosition( Instruction * I, IRBuilder <> IRB, bool print = false)
54 {
55         const DebugLoc & debug_location = I->getDebugLoc ();
56         std::string position_string;
57         {
58                 llvm::raw_string_ostream position_stream (position_string);
59                 debug_location . print (position_stream);
60         }
61
62         if (print) {
63                 errs() << position_string << "\n";
64         }
65
66         return IRB.CreateGlobalStringPtr (position_string);
67 }
68
// Counters declared with LLVM's STATISTIC macro. The first group counts
// accesses that were instrumented (or rejected); the second group counts
// accesses that the pass deliberately left uninstrumented.
STATISTIC(NumInstrumentedReads, "Number of instrumented reads");
STATISTIC(NumInstrumentedWrites, "Number of instrumented writes");
STATISTIC(NumAccessesWithBadSize, "Number of accesses with bad size");
// The vtable counters are disabled together with the vtable instrumentation.
// STATISTIC(NumInstrumentedVtableWrites, "Number of vtable ptr writes");
// STATISTIC(NumInstrumentedVtableReads, "Number of vtable ptr reads");

STATISTIC(NumOmittedReadsBeforeWrite,
          "Number of reads ignored due to following writes");
STATISTIC(NumOmittedReadsFromConstantGlobals,
          "Number of reads from constant globals");
STATISTIC(NumOmittedReadsFromVtable, "Number of vtable reads");
STATISTIC(NumOmittedNonCaptured, "Number of accesses ignored due to capturing");
81
// Types shared by the instrumentation routines. They are all assigned in
// CDSPass::initializeCallbacks().

// Integer type of the memory-order argument taken by the run-time callbacks.
Type * OrdTy;

// Pointer-to-integer types for 8/16/32/64-bit accesses; instrumentLoadOrStore
// only instruments addresses that have exactly one of these types.
Type * Int8PtrTy;
Type * Int16PtrTy;
Type * Int32PtrTy;
Type * Int64PtrTy;

Type * VoidTy;

// Number of supported access widths: size index i stands for (1 << i) bytes.
static const size_t kNumberOfAccessSizes = 4;
// Order value passed to the volatile load/store callbacks.
// NOTE(review): 6 is not produced by getAtomicOrderIndex for any ordering it
// handles; presumably the run-time treats it as a "volatile" marker -- confirm.
static const int volatile_order = 6;
93
94 int getAtomicOrderIndex(AtomicOrdering order){
95         switch (order) {
96                 case AtomicOrdering::Monotonic: 
97                         return (int)AtomicOrderingCABI::relaxed;
98                 //  case AtomicOrdering::Consume:         // not specified yet
99                 //    return AtomicOrderingCABI::consume;
100                 case AtomicOrdering::Acquire: 
101                         return (int)AtomicOrderingCABI::acquire;
102                 case AtomicOrdering::Release: 
103                         return (int)AtomicOrderingCABI::release;
104                 case AtomicOrdering::AcquireRelease: 
105                         return (int)AtomicOrderingCABI::acq_rel;
106                 case AtomicOrdering::SequentiallyConsistent: 
107                         return (int)AtomicOrderingCABI::seq_cst;
108                 default:
109                         // unordered or Not Atomic
110                         return -1;
111         }
112 }
113
namespace {
        // Function pass that rewrites memory accesses into calls to the cds_*
        // run-time functions declared by initializeCallbacks().
        struct CDSPass : public FunctionPass {
                static char ID;
                CDSPass() : FunctionPass(ID) {}
                bool runOnFunction(Function &F) override; 

        private:
                // Declares the run-time functions in the module and caches them below.
                void initializeCallbacks(Module &M);
                bool instrumentLoadOrStore(Instruction *I, const DataLayout &DL);
                bool instrumentVolatile(Instruction *I, const DataLayout &DL);
                // True when I is a direct call whose callee name matches one of the
                // atomic name lists at the bottom of this struct.
                bool isAtomicCall(Instruction *I);
                bool instrumentAtomic(Instruction *I, const DataLayout &DL);
                bool instrumentAtomicCall(CallInst *CI, const DataLayout &DL);
                // Moves the accesses worth instrumenting from Local to All and
                // clears Local.
                void chooseInstructionsToInstrument(SmallVectorImpl<Instruction *> &Local,
                                                                                        SmallVectorImpl<Instruction *> &All,
                                                                                        const DataLayout &DL);
                bool addrPointsToConstantData(Value *Addr);
                int getMemoryAccessFuncIndex(Value *Addr, const DataLayout &DL);

                // Callbacks to the run-time library; filled in by
                // initializeCallbacks(), which runOnFunction calls for every function.
                Constant * CDSFuncEntry;
                Constant * CDSFuncExit;

                // Per-size callbacks: entry i handles accesses of (1 << i) bytes.
                Constant * CDSLoad[kNumberOfAccessSizes];
                Constant * CDSStore[kNumberOfAccessSizes];
                Constant * CDSVolatileLoad[kNumberOfAccessSizes];
                Constant * CDSVolatileStore[kNumberOfAccessSizes];
                Constant * CDSAtomicInit[kNumberOfAccessSizes];
                Constant * CDSAtomicLoad[kNumberOfAccessSizes];
                Constant * CDSAtomicStore[kNumberOfAccessSizes];
                // Indexed by [atomicrmw operation][size]; entries stay nullptr for
                // operations that initializeCallbacks() has no run-time name for.
                Constant * CDSAtomicRMW[AtomicRMWInst::LAST_BINOP + 1][kNumberOfAccessSizes];
                Constant * CDSAtomicCAS_V1[kNumberOfAccessSizes];
                Constant * CDSAtomicCAS_V2[kNumberOfAccessSizes];
                Constant * CDSAtomicThreadFence;

                // Name fragments consulted by isAtomicCall(); (re)assigned at the
                // start of runOnFunction.
                std::vector<StringRef> AtomicFuncNames;
                std::vector<StringRef> PartialAtomicFuncNames;
        };
}
153
154 static bool isVtableAccess(Instruction *I) {
155         if (MDNode *Tag = I->getMetadata(LLVMContext::MD_tbaa))
156                 return Tag->isTBAAVtableAccess();
157         return false;
158 }
159
160 void CDSPass::initializeCallbacks(Module &M) {
161         LLVMContext &Ctx = M.getContext();
162
163         Type * Int1Ty = Type::getInt1Ty(Ctx);
164         OrdTy = Type::getInt32Ty(Ctx);
165
166         Int8PtrTy  = Type::getInt8PtrTy(Ctx);
167         Int16PtrTy = Type::getInt16PtrTy(Ctx);
168         Int32PtrTy = Type::getInt32PtrTy(Ctx);
169         Int64PtrTy = Type::getInt64PtrTy(Ctx);
170
171         VoidTy = Type::getVoidTy(Ctx);
172
173         CDSFuncEntry = M.getOrInsertFunction("cds_func_entry", 
174                                                                 VoidTy, Int8PtrTy);
175         CDSFuncExit = M.getOrInsertFunction("cds_func_exit", 
176                                                                 VoidTy, Int8PtrTy);
177
178         // Get the function to call from our untime library.
179         for (unsigned i = 0; i < kNumberOfAccessSizes; i++) {
180                 const unsigned ByteSize = 1U << i;
181                 const unsigned BitSize = ByteSize * 8;
182
183                 std::string ByteSizeStr = utostr(ByteSize);
184                 std::string BitSizeStr = utostr(BitSize);
185
186                 Type *Ty = Type::getIntNTy(Ctx, BitSize);
187                 Type *PtrTy = Ty->getPointerTo();
188
189                 // uint8_t cds_atomic_load8 (void * obj, int atomic_index)
190                 // void cds_atomic_store8 (void * obj, int atomic_index, uint8_t val)
191                 SmallString<32> LoadName("cds_load" + BitSizeStr);
192                 SmallString<32> StoreName("cds_store" + BitSizeStr);
193                 SmallString<32> VolatileLoadName("cds_volatile_load" + BitSizeStr);
194                 SmallString<32> VolatileStoreName("cds_volatile_store" + BitSizeStr);
195                 SmallString<32> AtomicInitName("cds_atomic_init" + BitSizeStr);
196                 SmallString<32> AtomicLoadName("cds_atomic_load" + BitSizeStr);
197                 SmallString<32> AtomicStoreName("cds_atomic_store" + BitSizeStr);
198
199                 CDSLoad[i]  = M.getOrInsertFunction(LoadName, VoidTy, PtrTy);
200                 CDSStore[i] = M.getOrInsertFunction(StoreName, VoidTy, PtrTy);
201                 CDSVolatileLoad[i]  = M.getOrInsertFunction(VolatileLoadName,
202                                                                         Ty, PtrTy, OrdTy, Int8PtrTy);
203                 CDSVolatileStore[i] = M.getOrInsertFunction(VolatileStoreName, 
204                                                                         VoidTy, PtrTy, Ty, OrdTy, Int8PtrTy);
205                 CDSAtomicInit[i] = M.getOrInsertFunction(AtomicInitName, 
206                                                                 VoidTy, PtrTy, Ty, Int8PtrTy);
207                 CDSAtomicLoad[i]  = M.getOrInsertFunction(AtomicLoadName, 
208                                                                 Ty, PtrTy, OrdTy, Int8PtrTy);
209                 CDSAtomicStore[i] = M.getOrInsertFunction(AtomicStoreName, 
210                                                                 VoidTy, PtrTy, Ty, OrdTy, Int8PtrTy);
211
212                 for (int op = AtomicRMWInst::FIRST_BINOP; 
213                         op <= AtomicRMWInst::LAST_BINOP; ++op) {
214                         CDSAtomicRMW[op][i] = nullptr;
215                         std::string NamePart;
216
217                         if (op == AtomicRMWInst::Xchg)
218                                 NamePart = "_exchange";
219                         else if (op == AtomicRMWInst::Add) 
220                                 NamePart = "_fetch_add";
221                         else if (op == AtomicRMWInst::Sub)
222                                 NamePart = "_fetch_sub";
223                         else if (op == AtomicRMWInst::And)
224                                 NamePart = "_fetch_and";
225                         else if (op == AtomicRMWInst::Or)
226                                 NamePart = "_fetch_or";
227                         else if (op == AtomicRMWInst::Xor)
228                                 NamePart = "_fetch_xor";
229                         else
230                                 continue;
231
232                         SmallString<32> AtomicRMWName("cds_atomic" + NamePart + BitSizeStr);
233                         CDSAtomicRMW[op][i] = M.getOrInsertFunction(AtomicRMWName, 
234                                                                                 Ty, PtrTy, Ty, OrdTy, Int8PtrTy);
235                 }
236
237                 // only supportes strong version
238                 SmallString<32> AtomicCASName_V1("cds_atomic_compare_exchange" + BitSizeStr + "_v1");
239                 SmallString<32> AtomicCASName_V2("cds_atomic_compare_exchange" + BitSizeStr + "_v2");
240                 CDSAtomicCAS_V1[i] = M.getOrInsertFunction(AtomicCASName_V1, 
241                                                                 Ty, PtrTy, Ty, Ty, OrdTy, OrdTy, Int8PtrTy);
242                 CDSAtomicCAS_V2[i] = M.getOrInsertFunction(AtomicCASName_V2, 
243                                                                 Int1Ty, PtrTy, PtrTy, Ty, OrdTy, OrdTy, Int8PtrTy);
244         }
245
246         CDSAtomicThreadFence = M.getOrInsertFunction("cds_atomic_thread_fence", 
247                                                                                                         VoidTy, OrdTy, Int8PtrTy);
248 }
249
250 static bool shouldInstrumentReadWriteFromAddress(const Module *M, Value *Addr) {
251         // Peel off GEPs and BitCasts.
252         Addr = Addr->stripInBoundsOffsets();
253
254         if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) {
255                 if (GV->hasSection()) {
256                         StringRef SectionName = GV->getSection();
257                         // Check if the global is in the PGO counters section.
258                         auto OF = Triple(M->getTargetTriple()).getObjectFormat();
259                         if (SectionName.endswith(
260                               getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false)))
261                                 return false;
262                 }
263
264                 // Check if the global is private gcov data.
265                 if (GV->getName().startswith("__llvm_gcov") ||
266                 GV->getName().startswith("__llvm_gcda"))
267                 return false;
268         }
269
270         // Do not instrument acesses from different address spaces; we cannot deal
271         // with them.
272         if (Addr) {
273                 Type *PtrTy = cast<PointerType>(Addr->getType()->getScalarType());
274                 if (PtrTy->getPointerAddressSpace() != 0)
275                         return false;
276         }
277
278         return true;
279 }
280
281 bool CDSPass::addrPointsToConstantData(Value *Addr) {
282         // If this is a GEP, just analyze its pointer operand.
283         if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Addr))
284                 Addr = GEP->getPointerOperand();
285
286         if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) {
287                 if (GV->isConstant()) {
288                         // Reads from constant globals can not race with any writes.
289                         NumOmittedReadsFromConstantGlobals++;
290                         return true;
291                 }
292         } else if (LoadInst *L = dyn_cast<LoadInst>(Addr)) {
293                 if (isVtableAccess(L)) {
294                         // Reads from a vtable pointer can not race with any writes.
295                         NumOmittedReadsFromVtable++;
296                         return true;
297                 }
298         }
299         return false;
300 }
301
302 bool CDSPass::runOnFunction(Function &F) {
303         if (F.getName() == "main") {
304                 F.setName("user_main");
305                 errs() << "main replaced by user_main\n";
306         }
307
308         if (true) {
309                 initializeCallbacks( *F.getParent() );
310
311                 AtomicFuncNames = 
312                 {
313                         "atomic_init", "atomic_load", "atomic_store", 
314                         "atomic_fetch_", "atomic_exchange", "atomic_compare_exchange_"
315                 };
316
317                 PartialAtomicFuncNames = 
318                 { 
319                         "load", "store", "fetch", "exchange", "compare_exchange_" 
320                 };
321
322                 SmallVector<Instruction*, 8> AllLoadsAndStores;
323                 SmallVector<Instruction*, 8> LocalLoadsAndStores;
324                 SmallVector<Instruction*, 8> VolatileLoadsAndStores;
325                 SmallVector<Instruction*, 8> AtomicAccesses;
326
327                 std::vector<Instruction *> worklist;
328
329                 bool Res = false;
330                 bool HasAtomic = false;
331                 const DataLayout &DL = F.getParent()->getDataLayout();
332
333                 // errs() << "--- " << F.getName() << "---\n";
334
335                 for (auto &B : F) {
336                         for (auto &I : B) {
337                                 if ( (&I)->isAtomic() || isAtomicCall(&I) ) {
338                                         AtomicAccesses.push_back(&I);
339                                         HasAtomic = true;
340                                 } else if (isa<LoadInst>(I) || isa<StoreInst>(I)) {
341                                         LoadInst *LI = dyn_cast<LoadInst>(&I);
342                                         StoreInst *SI = dyn_cast<StoreInst>(&I);
343                                         bool isVolatile = ( LI ? LI->isVolatile() : SI->isVolatile() );
344
345                                         if (isVolatile)
346                                                 VolatileLoadsAndStores.push_back(&I);
347                                         else
348                                                 LocalLoadsAndStores.push_back(&I);
349                                 } else if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
350                                         // not implemented yet
351                                 }
352                         }
353
354                         chooseInstructionsToInstrument(LocalLoadsAndStores, AllLoadsAndStores, DL);
355                 }
356
357                 for (auto Inst : AllLoadsAndStores) {
358                         Res |= instrumentLoadOrStore(Inst, DL);
359                 }
360
361                 for (auto Inst : VolatileLoadsAndStores) {
362                         Res |= instrumentVolatile(Inst, DL);
363                 }
364
365                 for (auto Inst : AtomicAccesses) {
366                         Res |= instrumentAtomic(Inst, DL);
367                 }
368
369                 // only instrument functions that contain atomics
370                 if (Res && HasAtomic) {
371                         IRBuilder<> IRB(F.getEntryBlock().getFirstNonPHI());
372                         /* Unused for now
373                         Value *ReturnAddress = IRB.CreateCall(
374                                 Intrinsic::getDeclaration(F.getParent(), Intrinsic::returnaddress),
375                                 IRB.getInt32(0));
376                         */
377
378                         Value * FuncName = IRB.CreateGlobalStringPtr(F.getName());
379                         IRB.CreateCall(CDSFuncEntry, FuncName);
380
381                         EscapeEnumerator EE(F, "cds_cleanup", true);
382                         while (IRBuilder<> *AtExit = EE.Next()) {
383                           AtExit->CreateCall(CDSFuncExit, FuncName);
384                         }
385
386                         Res = true;
387                 }
388
389                 F.dump();
390         }
391
392         return false;
393 }
394
395 void CDSPass::chooseInstructionsToInstrument(
396         SmallVectorImpl<Instruction *> &Local, SmallVectorImpl<Instruction *> &All,
397         const DataLayout &DL) {
398         SmallPtrSet<Value*, 8> WriteTargets;
399         // Iterate from the end.
400         for (Instruction *I : reverse(Local)) {
401                 if (StoreInst *Store = dyn_cast<StoreInst>(I)) {
402                         Value *Addr = Store->getPointerOperand();
403                         if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
404                                 continue;
405                         WriteTargets.insert(Addr);
406                 } else {
407                         LoadInst *Load = cast<LoadInst>(I);
408                         Value *Addr = Load->getPointerOperand();
409                         if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
410                                 continue;
411                         if (WriteTargets.count(Addr)) {
412                                 // We will write to this temp, so no reason to analyze the read.
413                                 NumOmittedReadsBeforeWrite++;
414                                 continue;
415                         }
416                         if (addrPointsToConstantData(Addr)) {
417                                 // Addr points to some constant data -- it can not race with any writes.
418                                 continue;
419                         }
420                 }
421                 Value *Addr = isa<StoreInst>(*I)
422                         ? cast<StoreInst>(I)->getPointerOperand()
423                         : cast<LoadInst>(I)->getPointerOperand();
424                 if (isa<AllocaInst>(GetUnderlyingObject(Addr, DL)) &&
425                                 !PointerMayBeCaptured(Addr, true, true)) {
426                         // The variable is addressable but not captured, so it cannot be
427                         // referenced from a different thread and participate in a data race
428                         // (see llvm/Analysis/CaptureTracking.h for details).
429                         NumOmittedNonCaptured++;
430                         continue;
431                 }
432                 All.push_back(I);
433         }
434         Local.clear();
435 }
436
437
438 bool CDSPass::instrumentLoadOrStore(Instruction *I,
439                                                                         const DataLayout &DL) {
440         IRBuilder<> IRB(I);
441         bool IsWrite = isa<StoreInst>(*I);
442         Value *Addr = IsWrite
443                 ? cast<StoreInst>(I)->getPointerOperand()
444                 : cast<LoadInst>(I)->getPointerOperand();
445
446         // swifterror memory addresses are mem2reg promoted by instruction selection.
447         // As such they cannot have regular uses like an instrumentation function and
448         // it makes no sense to track them as memory.
449         if (Addr->isSwiftError())
450         return false;
451
452         int Idx = getMemoryAccessFuncIndex(Addr, DL);
453         if (Idx < 0)
454                 return false;
455
456 //  not supported by CDS yet
457 /*  if (IsWrite && isVtableAccess(I)) {
458     LLVM_DEBUG(dbgs() << "  VPTR : " << *I << "\n");
459     Value *StoredValue = cast<StoreInst>(I)->getValueOperand();
460     // StoredValue may be a vector type if we are storing several vptrs at once.
461     // In this case, just take the first element of the vector since this is
462     // enough to find vptr races.
463     if (isa<VectorType>(StoredValue->getType()))
464       StoredValue = IRB.CreateExtractElement(
465           StoredValue, ConstantInt::get(IRB.getInt32Ty(), 0));
466     if (StoredValue->getType()->isIntegerTy())
467       StoredValue = IRB.CreateIntToPtr(StoredValue, IRB.getInt8PtrTy());
468     // Call TsanVptrUpdate.
469     IRB.CreateCall(TsanVptrUpdate,
470                    {IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()),
471                     IRB.CreatePointerCast(StoredValue, IRB.getInt8PtrTy())});
472     NumInstrumentedVtableWrites++;
473     return true;
474   }
475
476   if (!IsWrite && isVtableAccess(I)) {
477     IRB.CreateCall(TsanVptrLoad,
478                    IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()));
479     NumInstrumentedVtableReads++;
480     return true;
481   }
482 */
483
484         Value *OnAccessFunc = nullptr;
485         OnAccessFunc = IsWrite ? CDSStore[Idx] : CDSLoad[Idx];
486
487         Type *ArgType = IRB.CreatePointerCast(Addr, Addr->getType())->getType();
488
489         if ( ArgType != Int8PtrTy && ArgType != Int16PtrTy && 
490                         ArgType != Int32PtrTy && ArgType != Int64PtrTy ) {
491                 // if other types of load or stores are passed in
492                 return false;   
493         }
494         IRB.CreateCall(OnAccessFunc, IRB.CreatePointerCast(Addr, Addr->getType()));
495         if (IsWrite) NumInstrumentedWrites++;
496         else         NumInstrumentedReads++;
497         return true;
498 }
499
500 bool CDSPass::instrumentVolatile(Instruction * I, const DataLayout &DL) {
501         IRBuilder<> IRB(I);
502         Value *position = getPosition(I, IRB);
503
504         if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
505                 assert( LI->isVolatile() );
506                 Value *Addr = LI->getPointerOperand();
507                 int Idx=getMemoryAccessFuncIndex(Addr, DL);
508                 if (Idx < 0)
509                         return false;
510
511                 Value *order = ConstantInt::get(OrdTy, volatile_order);
512                 Value *args[] = {Addr, order, position};
513                 Instruction* funcInst=CallInst::Create(CDSVolatileLoad[Idx], args);
514                 ReplaceInstWithInst(LI, funcInst);
515         } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
516                 assert( SI->isVolatile() );
517                 Value *Addr = SI->getPointerOperand();
518                 int Idx=getMemoryAccessFuncIndex(Addr, DL);
519                 if (Idx < 0)
520                         return false;
521
522                 Value *val = SI->getValueOperand();
523                 Value *order = ConstantInt::get(OrdTy, volatile_order);
524                 Value *args[] = {Addr, val, order, position};
525                 Instruction* funcInst=CallInst::Create(CDSVolatileStore[Idx], args);
526                 ReplaceInstWithInst(SI, funcInst);
527         } else {
528                 return false;
529         }
530
531         return true;
532 }
533
534 bool CDSPass::instrumentAtomic(Instruction * I, const DataLayout &DL) {
535         IRBuilder<> IRB(I);
536
537         if (auto *CI = dyn_cast<CallInst>(I)) {
538                 return instrumentAtomicCall(CI, DL);
539         }
540
541         Value *position = getPosition(I, IRB);
542
543         if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
544                 Value *Addr = LI->getPointerOperand();
545                 int Idx=getMemoryAccessFuncIndex(Addr, DL);
546                 if (Idx < 0)
547                         return false;
548
549                 int atomic_order_index = getAtomicOrderIndex(LI->getOrdering());
550                 Value *order = ConstantInt::get(OrdTy, atomic_order_index);
551                 Value *args[] = {Addr, order, position};
552                 Instruction* funcInst=CallInst::Create(CDSAtomicLoad[Idx], args);
553                 ReplaceInstWithInst(LI, funcInst);
554         } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
555                 Value *Addr = SI->getPointerOperand();
556                 int Idx=getMemoryAccessFuncIndex(Addr, DL);
557                 if (Idx < 0)
558                         return false;
559
560                 int atomic_order_index = getAtomicOrderIndex(SI->getOrdering());
561                 Value *val = SI->getValueOperand();
562                 Value *order = ConstantInt::get(OrdTy, atomic_order_index);
563                 Value *args[] = {Addr, val, order, position};
564                 Instruction* funcInst=CallInst::Create(CDSAtomicStore[Idx], args);
565                 ReplaceInstWithInst(SI, funcInst);
566         } else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(I)) {
567                 Value *Addr = RMWI->getPointerOperand();
568                 int Idx=getMemoryAccessFuncIndex(Addr, DL);
569                 if (Idx < 0)
570                         return false;
571
572                 int atomic_order_index = getAtomicOrderIndex(RMWI->getOrdering());
573                 Value *val = RMWI->getValOperand();
574                 Value *order = ConstantInt::get(OrdTy, atomic_order_index);
575                 Value *args[] = {Addr, val, order, position};
576                 Instruction* funcInst = CallInst::Create(CDSAtomicRMW[RMWI->getOperation()][Idx], args);
577                 ReplaceInstWithInst(RMWI, funcInst);
578         } else if (AtomicCmpXchgInst *CASI = dyn_cast<AtomicCmpXchgInst>(I)) {
579                 IRBuilder<> IRB(CASI);
580
581                 Value *Addr = CASI->getPointerOperand();
582                 int Idx=getMemoryAccessFuncIndex(Addr, DL);
583                 if (Idx < 0)
584                         return false;
585
586                 const unsigned ByteSize = 1U << Idx;
587                 const unsigned BitSize = ByteSize * 8;
588                 Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
589                 Type *PtrTy = Ty->getPointerTo();
590
591                 Value *CmpOperand = IRB.CreateBitOrPointerCast(CASI->getCompareOperand(), Ty);
592                 Value *NewOperand = IRB.CreateBitOrPointerCast(CASI->getNewValOperand(), Ty);
593
594                 int atomic_order_index_succ = getAtomicOrderIndex(CASI->getSuccessOrdering());
595                 int atomic_order_index_fail = getAtomicOrderIndex(CASI->getFailureOrdering());
596                 Value *order_succ = ConstantInt::get(OrdTy, atomic_order_index_succ);
597                 Value *order_fail = ConstantInt::get(OrdTy, atomic_order_index_fail);
598
599                 Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy),
600                                                  CmpOperand, NewOperand,
601                                                  order_succ, order_fail, position};
602
603                 CallInst *funcInst = IRB.CreateCall(CDSAtomicCAS_V1[Idx], Args);
604                 Value *Success = IRB.CreateICmpEQ(funcInst, CmpOperand);
605
606                 Value *OldVal = funcInst;
607                 Type *OrigOldValTy = CASI->getNewValOperand()->getType();
608                 if (Ty != OrigOldValTy) {
609                         // The value is a pointer, so we need to cast the return value.
610                         OldVal = IRB.CreateIntToPtr(funcInst, OrigOldValTy);
611                 }
612
613                 Value *Res =
614                   IRB.CreateInsertValue(UndefValue::get(CASI->getType()), OldVal, 0);
615                 Res = IRB.CreateInsertValue(Res, Success, 1);
616
617                 I->replaceAllUsesWith(Res);
618                 I->eraseFromParent();
619         } else if (FenceInst *FI = dyn_cast<FenceInst>(I)) {
620                 int atomic_order_index = getAtomicOrderIndex(FI->getOrdering());
621                 Value *order = ConstantInt::get(OrdTy, atomic_order_index);
622                 Value *Args[] = {order, position};
623
624                 CallInst *funcInst = CallInst::Create(CDSAtomicThreadFence, Args);
625                 ReplaceInstWithInst(FI, funcInst);
626                 // errs() << "Thread Fences replaced\n";
627         }
628         return true;
629 }
630
631 bool CDSPass::isAtomicCall(Instruction *I) {
632         if ( auto *CI = dyn_cast<CallInst>(I) ) {
633                 Function *fun = CI->getCalledFunction();
634                 if (fun == NULL)
635                         return false;
636
637                 StringRef funName = fun->getName();
638
639                 // todo: come up with better rules for function name checking
640                 for (StringRef name : AtomicFuncNames) {
641                         if ( funName.contains(name) ) 
642                                 return true;
643                 }
644                 
645                 for (StringRef PartialName : PartialAtomicFuncNames) {
646                         if (funName.contains(PartialName) && 
647                                         funName.contains("atomic") )
648                                 return true;
649                 }
650         }
651
652         return false;
653 }
654
655 bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
656         IRBuilder<> IRB(CI);
657         Function *fun = CI->getCalledFunction();
658         StringRef funName = fun->getName();
659         std::vector<Value *> parameters;
660
661         User::op_iterator begin = CI->arg_begin();
662         User::op_iterator end = CI->arg_end();
663         for (User::op_iterator it = begin; it != end; ++it) {
664                 Value *param = *it;
665                 parameters.push_back(param);
666         }
667
668         // obtain source line number of the CallInst
669         Value *position = getPosition(CI, IRB);
670
671         // the pointer to the address is always the first argument
672         Value *OrigPtr = parameters[0];
673
674         int Idx = getMemoryAccessFuncIndex(OrigPtr, DL);
675         if (Idx < 0)
676                 return false;
677
678         const unsigned ByteSize = 1U << Idx;
679         const unsigned BitSize = ByteSize * 8;
680         Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
681         Type *PtrTy = Ty->getPointerTo();
682
683         // atomic_init; args = {obj, order}
684         if (funName.contains("atomic_init")) {
685                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
686                 Value *val = IRB.CreateBitOrPointerCast(parameters[1], Ty);
687                 Value *args[] = {ptr, val, position};
688
689                 Instruction* funcInst = CallInst::Create(CDSAtomicInit[Idx], args);
690                 ReplaceInstWithInst(CI, funcInst);
691
692                 return true;
693         }
694
695         // atomic_load; args = {obj, order}
696         if (funName.contains("atomic_load")) {
697                 bool isExplicit = funName.contains("atomic_load_explicit");
698
699                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
700                 Value *order;
701                 if (isExplicit)
702                         order = IRB.CreateBitOrPointerCast(parameters[1], OrdTy);
703                 else 
704                         order = ConstantInt::get(OrdTy, 
705                                                         (int) AtomicOrderingCABI::seq_cst);
706                 Value *args[] = {ptr, order, position};
707                 
708                 Instruction* funcInst = CallInst::Create(CDSAtomicLoad[Idx], args);
709                 ReplaceInstWithInst(CI, funcInst);
710
711                 return true;
712         } else if (funName.contains("atomic") && 
713                                         funName.contains("load") ) {
714                 // does this version of call always have an atomic order as an argument?
715                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
716                 Value *order = IRB.CreateBitOrPointerCast(parameters[1], OrdTy);
717                 Value *args[] = {ptr, order, position};
718
719                 if (!CI->getType()->isPointerTy()) {
720                         return false;   
721                 } 
722
723                 CallInst *funcInst = IRB.CreateCall(CDSAtomicLoad[Idx], args);
724                 Value *RetVal = IRB.CreateIntToPtr(funcInst, CI->getType());
725
726                 CI->replaceAllUsesWith(RetVal);
727                 CI->eraseFromParent();
728
729                 return true;
730         }
731
732         // atomic_store; args = {obj, val, order}
733         if (funName.contains("atomic_store")) {
734                 bool isExplicit = funName.contains("atomic_store_explicit");
735                 Value *OrigVal = parameters[1];
736
737                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
738                 Value *val = IRB.CreatePointerCast(OrigVal, Ty);
739                 Value *order;
740                 if (isExplicit)
741                         order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
742                 else 
743                         order = ConstantInt::get(OrdTy, 
744                                                         (int) AtomicOrderingCABI::seq_cst);
745                 Value *args[] = {ptr, val, order, position};
746                 
747                 Instruction* funcInst = CallInst::Create(CDSAtomicStore[Idx], args);
748                 ReplaceInstWithInst(CI, funcInst);
749
750                 return true;
751         } else if (funName.contains("atomic") && 
752                                         funName.contains("EEEE5store") ) {
753                 // does this version of call always have an atomic order as an argument?
754                 Value *OrigVal = parameters[1];
755
756                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
757                 Value *val = IRB.CreatePointerCast(OrigVal, Ty);
758                 Value *order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
759                 Value *args[] = {ptr, val, order, position};
760
761                 Instruction* funcInst = CallInst::Create(CDSAtomicStore[Idx], args);
762                 ReplaceInstWithInst(CI, funcInst);
763
764                 return true;
765         }
766
767         // atomic_fetch_*; args = {obj, val, order}
768         if (funName.contains("atomic_fetch_") || 
769                         funName.contains("atomic_exchange") ) {
770                 bool isExplicit = funName.contains("_explicit");
771                 Value *OrigVal = parameters[1];
772
773                 int op;
774                 if ( funName.contains("_fetch_add") )
775                         op = AtomicRMWInst::Add;
776                 else if ( funName.contains("_fetch_sub") )
777                         op = AtomicRMWInst::Sub;
778                 else if ( funName.contains("_fetch_and") )
779                         op = AtomicRMWInst::And;
780                 else if ( funName.contains("_fetch_or") )
781                         op = AtomicRMWInst::Or;
782                 else if ( funName.contains("_fetch_xor") )
783                         op = AtomicRMWInst::Xor;
784                 else if ( funName.contains("atomic_exchange") )
785                         op = AtomicRMWInst::Xchg;
786                 else {
787                         errs() << "Unknown atomic read-modify-write operation\n";
788                         return false;
789                 }
790
791                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
792                 Value *val = IRB.CreatePointerCast(OrigVal, Ty);
793                 Value *order;
794                 if (isExplicit)
795                         order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
796                 else 
797                         order = ConstantInt::get(OrdTy, 
798                                                         (int) AtomicOrderingCABI::seq_cst);
799                 Value *args[] = {ptr, val, order, position};
800                 
801                 Instruction* funcInst = CallInst::Create(CDSAtomicRMW[op][Idx], args);
802                 ReplaceInstWithInst(CI, funcInst);
803
804                 return true;
805         } else if (funName.contains("fetch")) {
806                 errs() << "atomic exchange captured. Not implemented yet. ";
807                 errs() << "See source file :";
808                 getPosition(CI, IRB, true);
809         } else if (funName.contains("exchange") &&
810                         !funName.contains("compare_exchange") ) {
811                 errs() << "atomic exchange captured. Not implemented yet. ";
812                 errs() << "See source file :";
813                 getPosition(CI, IRB, true);
814         }
815
816         /* atomic_compare_exchange_*; 
817            args = {obj, expected, new value, order1, order2}
818         */
819         if ( funName.contains("atomic_compare_exchange_") ) {
820                 bool isExplicit = funName.contains("_explicit");
821
822                 Value *Addr = IRB.CreatePointerCast(OrigPtr, PtrTy);
823                 Value *CmpOperand = IRB.CreatePointerCast(parameters[1], PtrTy);
824                 Value *NewOperand = IRB.CreateBitOrPointerCast(parameters[2], Ty);
825
826                 Value *order_succ, *order_fail;
827                 if (isExplicit) {
828                         order_succ = IRB.CreateBitOrPointerCast(parameters[3], OrdTy);
829                         order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
830                 } else  {
831                         order_succ = ConstantInt::get(OrdTy, 
832                                                         (int) AtomicOrderingCABI::seq_cst);
833                         order_fail = ConstantInt::get(OrdTy, 
834                                                         (int) AtomicOrderingCABI::seq_cst);
835                 }
836
837                 Value *args[] = {Addr, CmpOperand, NewOperand, 
838                                                         order_succ, order_fail, position};
839                 
840                 Instruction* funcInst = CallInst::Create(CDSAtomicCAS_V2[Idx], args);
841                 ReplaceInstWithInst(CI, funcInst);
842
843                 return true;
844         } else if ( funName.contains("compare_exchange_strong") ||
845                                 funName.contains("compare_exchange_weak") ) {
846                 Value *Addr = IRB.CreatePointerCast(OrigPtr, PtrTy);
847                 Value *CmpOperand = IRB.CreatePointerCast(parameters[1], PtrTy);
848                 Value *NewOperand = IRB.CreateBitOrPointerCast(parameters[2], Ty);
849
850                 Value *order_succ, *order_fail;
851                 order_succ = IRB.CreateBitOrPointerCast(parameters[3], OrdTy);
852                 order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
853
854                 Value *args[] = {Addr, CmpOperand, NewOperand, 
855                                                         order_succ, order_fail, position};
856                 Instruction* funcInst = CallInst::Create(CDSAtomicCAS_V2[Idx], args);
857                 ReplaceInstWithInst(CI, funcInst);
858
859                 return true;
860         }
861
862         return false;
863 }
864
865 int CDSPass::getMemoryAccessFuncIndex(Value *Addr,
866                                                                                 const DataLayout &DL) {
867         Type *OrigPtrTy = Addr->getType();
868         Type *OrigTy = cast<PointerType>(OrigPtrTy)->getElementType();
869         assert(OrigTy->isSized());
870         uint32_t TypeSize = DL.getTypeStoreSizeInBits(OrigTy);
871         if (TypeSize != 8  && TypeSize != 16 &&
872                 TypeSize != 32 && TypeSize != 64 && TypeSize != 128) {
873                 NumAccessesWithBadSize++;
874                 // Ignore all unusual sizes.
875                 return -1;
876         }
877         size_t Idx = countTrailingZeros(TypeSize / 8);
878         //assert(Idx < kNumberOfAccessSizes);
879         if (Idx >= kNumberOfAccessSizes) {
880                 return -1;
881         }
882         return Idx;
883 }
884
885
886 char CDSPass::ID = 0;
887
888 // Automatically enable the pass.
889 static void registerCDSPass(const PassManagerBuilder &,
890                                                         legacy::PassManagerBase &PM) {
891         PM.add(new CDSPass());
892 }
893
894 /* Enable the pass when opt level is greater than 0 */
895 static RegisterStandardPasses 
896         RegisterMyPass1(PassManagerBuilder::EP_OptimizerLast,
897 registerCDSPass);
898
899 /* Enable the pass when opt level is 0 */
900 static RegisterStandardPasses 
901         RegisterMyPass2(PassManagerBuilder::EP_EnabledOnOptLevel0,
902 registerCDSPass);