eef87699e927fbb820880a6dfddbcceea31a987b
[c11llvm.git] / CDSPass.cpp
//===-- CDSPass.cpp - CDS run-time instrumentation pass -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 // This file is distributed under the University of Illinois Open Source
7 // License. See LICENSE.TXT for details.
8 //
9 //===----------------------------------------------------------------------===//
10 //
// This file is a modified version of ThreadSanitizer.cpp, the instrumentation
// pass of the ThreadSanitizer race detector.
//
// ThreadSanitizer's original header notes that the tool was under development
// and points to http://code.google.com/p/data-race-test for details about its
// previous versions.
//
// The instrumentation phase is quite simple:
//   - Replace atomic operations (atomic instructions and calls to atomic
//     library functions) with calls to the CDS run-time library.
//   - Insert calls at entry/exit of every function that was instrumented.
//   - Plain (non-atomic) loads and stores are still selected for
//     instrumentation, but inserting the calls for them is currently disabled.
// The rest is handled by the run-time library.
21 //===----------------------------------------------------------------------===//
22
23 #include "llvm/ADT/Statistic.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/SmallString.h"
26 #include "llvm/Analysis/ValueTracking.h"
27 #include "llvm/Analysis/CaptureTracking.h"
28 #include "llvm/IR/BasicBlock.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/Instructions.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/IR/LegacyPassManager.h"
34 #include "llvm/IR/Module.h"
35 #include "llvm/IR/PassManager.h"
36 #include "llvm/Pass.h"
37 #include "llvm/ProfileData/InstrProf.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include "llvm/Support/AtomicOrdering.h"
40 #include "llvm/Support/Debug.h"
41 #include "llvm/Transforms/Scalar.h"
42 #include "llvm/Transforms/Utils/Local.h"
43 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
44 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
45 #include "llvm/Transforms/Utils/EscapeEnumerator.h"
46 #include <vector>
47
48 using namespace llvm;
49
50 #define DEBUG_TYPE "CDS"
51 #include <llvm/IR/DebugLoc.h>
52
53 Value *getPosition( Instruction * I, IRBuilder <> IRB, bool print = false)
54 {
55         const DebugLoc & debug_location = I->getDebugLoc ();
56         std::string position_string;
57         {
58                 llvm::raw_string_ostream position_stream (position_string);
59                 debug_location . print (position_stream);
60         }
61
62         if (print) {
63                 errs() << position_string;
64         }
65
66         return IRB.CreateGlobalStringPtr (position_string);
67 }
68
69 STATISTIC(NumInstrumentedReads, "Number of instrumented reads");
70 STATISTIC(NumInstrumentedWrites, "Number of instrumented writes");
71 STATISTIC(NumAccessesWithBadSize, "Number of accesses with bad size");
72 // STATISTIC(NumInstrumentedVtableWrites, "Number of vtable ptr writes");
73 // STATISTIC(NumInstrumentedVtableReads, "Number of vtable ptr reads");
74
75 STATISTIC(NumOmittedReadsBeforeWrite,
76           "Number of reads ignored due to following writes");
77 STATISTIC(NumOmittedReadsFromConstantGlobals,
78           "Number of reads from constant globals");
79 STATISTIC(NumOmittedReadsFromVtable, "Number of vtable reads");
80 STATISTIC(NumOmittedNonCaptured, "Number of accesses ignored due to capturing");
81
82 Type * OrdTy;
83
84 Type * Int8PtrTy;
85 Type * Int16PtrTy;
86 Type * Int32PtrTy;
87 Type * Int64PtrTy;
88
89 Type * VoidTy;
90
91 static const size_t kNumberOfAccessSizes = 4;
92
93 int getAtomicOrderIndex(AtomicOrdering order){
94         switch (order) {
95                 case AtomicOrdering::Monotonic: 
96                         return (int)AtomicOrderingCABI::relaxed;
97                 //  case AtomicOrdering::Consume:         // not specified yet
98                 //    return AtomicOrderingCABI::consume;
99                 case AtomicOrdering::Acquire: 
100                         return (int)AtomicOrderingCABI::acquire;
101                 case AtomicOrdering::Release: 
102                         return (int)AtomicOrderingCABI::release;
103                 case AtomicOrdering::AcquireRelease: 
104                         return (int)AtomicOrderingCABI::acq_rel;
105                 case AtomicOrdering::SequentiallyConsistent: 
106                         return (int)AtomicOrderingCABI::seq_cst;
107                 default:
108                         // unordered or Not Atomic
109                         return -1;
110         }
111 }
112
113 namespace {
114         struct CDSPass : public FunctionPass {
115                 static char ID;
116                 CDSPass() : FunctionPass(ID) {}
117                 bool runOnFunction(Function &F) override; 
118
119         private:
120                 void initializeCallbacks(Module &M);
121                 bool instrumentLoadOrStore(Instruction *I, const DataLayout &DL);
122                 bool isAtomicCall(Instruction *I);
123                 bool instrumentAtomic(Instruction *I, const DataLayout &DL);
124                 bool instrumentAtomicCall(CallInst *CI, const DataLayout &DL);
125                 void chooseInstructionsToInstrument(SmallVectorImpl<Instruction *> &Local,
126                                                                                         SmallVectorImpl<Instruction *> &All,
127                                                                                         const DataLayout &DL);
128                 bool addrPointsToConstantData(Value *Addr);
129                 int getMemoryAccessFuncIndex(Value *Addr, const DataLayout &DL);
130
131                 // Callbacks to run-time library are computed in doInitialization.
132                 Constant * CDSFuncEntry;
133                 Constant * CDSFuncExit;
134
135                 Constant * CDSLoad[kNumberOfAccessSizes];
136                 Constant * CDSStore[kNumberOfAccessSizes];
137                 Constant * CDSAtomicInit[kNumberOfAccessSizes];
138                 Constant * CDSAtomicLoad[kNumberOfAccessSizes];
139                 Constant * CDSAtomicStore[kNumberOfAccessSizes];
140                 Constant * CDSAtomicRMW[AtomicRMWInst::LAST_BINOP + 1][kNumberOfAccessSizes];
141                 Constant * CDSAtomicCAS_V1[kNumberOfAccessSizes];
142                 Constant * CDSAtomicCAS_V2[kNumberOfAccessSizes];
143                 Constant * CDSAtomicThreadFence;
144
145                 std::vector<StringRef> AtomicFuncNames;
146                 std::vector<StringRef> PartialAtomicFuncNames;
147         };
148 }
149
150 static bool isVtableAccess(Instruction *I) {
151         if (MDNode *Tag = I->getMetadata(LLVMContext::MD_tbaa))
152                 return Tag->isTBAAVtableAccess();
153         return false;
154 }
155
156 void CDSPass::initializeCallbacks(Module &M) {
157         LLVMContext &Ctx = M.getContext();
158
159         Type * Int1Ty = Type::getInt1Ty(Ctx);
160         OrdTy = Type::getInt32Ty(Ctx);
161
162         Int8PtrTy  = Type::getInt8PtrTy(Ctx);
163         Int16PtrTy = Type::getInt16PtrTy(Ctx);
164         Int32PtrTy = Type::getInt32PtrTy(Ctx);
165         Int64PtrTy = Type::getInt64PtrTy(Ctx);
166
167         VoidTy = Type::getVoidTy(Ctx);
168
169         CDSFuncEntry = M.getOrInsertFunction("cds_func_entry", 
170                                                                 VoidTy, Int8PtrTy);
171         CDSFuncExit = M.getOrInsertFunction("cds_func_exit", 
172                                                                 VoidTy, Int8PtrTy);
173
174         // Get the function to call from our untime library.
175         for (unsigned i = 0; i < kNumberOfAccessSizes; i++) {
176                 const unsigned ByteSize = 1U << i;
177                 const unsigned BitSize = ByteSize * 8;
178
179                 std::string ByteSizeStr = utostr(ByteSize);
180                 std::string BitSizeStr = utostr(BitSize);
181
182                 Type *Ty = Type::getIntNTy(Ctx, BitSize);
183                 Type *PtrTy = Ty->getPointerTo();
184
185                 // uint8_t cds_atomic_load8 (void * obj, int atomic_index)
186                 // void cds_atomic_store8 (void * obj, int atomic_index, uint8_t val)
187                 SmallString<32> LoadName("cds_load" + BitSizeStr);
188                 SmallString<32> StoreName("cds_store" + BitSizeStr);
189                 SmallString<32> AtomicInitName("cds_atomic_init" + BitSizeStr);
190                 SmallString<32> AtomicLoadName("cds_atomic_load" + BitSizeStr);
191                 SmallString<32> AtomicStoreName("cds_atomic_store" + BitSizeStr);
192
193                 CDSLoad[i]  = M.getOrInsertFunction(LoadName, VoidTy, PtrTy);
194                 CDSStore[i] = M.getOrInsertFunction(StoreName, VoidTy, PtrTy);
195                 CDSAtomicInit[i] = M.getOrInsertFunction(AtomicInitName, 
196                                                                 VoidTy, PtrTy, Ty, Int8PtrTy);
197                 CDSAtomicLoad[i]  = M.getOrInsertFunction(AtomicLoadName, 
198                                                                 Ty, PtrTy, OrdTy, Int8PtrTy);
199                 CDSAtomicStore[i] = M.getOrInsertFunction(AtomicStoreName, 
200                                                                 VoidTy, PtrTy, Ty, OrdTy, Int8PtrTy);
201
202                 for (int op = AtomicRMWInst::FIRST_BINOP; 
203                         op <= AtomicRMWInst::LAST_BINOP; ++op) {
204                         CDSAtomicRMW[op][i] = nullptr;
205                         std::string NamePart;
206
207                         if (op == AtomicRMWInst::Xchg)
208                                 NamePart = "_exchange";
209                         else if (op == AtomicRMWInst::Add) 
210                                 NamePart = "_fetch_add";
211                         else if (op == AtomicRMWInst::Sub)
212                                 NamePart = "_fetch_sub";
213                         else if (op == AtomicRMWInst::And)
214                                 NamePart = "_fetch_and";
215                         else if (op == AtomicRMWInst::Or)
216                                 NamePart = "_fetch_or";
217                         else if (op == AtomicRMWInst::Xor)
218                                 NamePart = "_fetch_xor";
219                         else
220                                 continue;
221
222                         SmallString<32> AtomicRMWName("cds_atomic" + NamePart + BitSizeStr);
223                         CDSAtomicRMW[op][i] = M.getOrInsertFunction(AtomicRMWName, 
224                                                                                 Ty, PtrTy, Ty, OrdTy, Int8PtrTy);
225                 }
226
227                 // only supportes strong version
228                 SmallString<32> AtomicCASName_V1("cds_atomic_compare_exchange" + BitSizeStr + "_v1");
229                 SmallString<32> AtomicCASName_V2("cds_atomic_compare_exchange" + BitSizeStr + "_v2");
230                 CDSAtomicCAS_V1[i] = M.getOrInsertFunction(AtomicCASName_V1, 
231                                                                 Ty, PtrTy, Ty, Ty, OrdTy, OrdTy, Int8PtrTy);
232                 CDSAtomicCAS_V2[i] = M.getOrInsertFunction(AtomicCASName_V2, 
233                                                                 Int1Ty, PtrTy, PtrTy, Ty, OrdTy, OrdTy, Int8PtrTy);
234         }
235
236         CDSAtomicThreadFence = M.getOrInsertFunction("cds_atomic_thread_fence", 
237                                                                                                         VoidTy, OrdTy, Int8PtrTy);
238 }
239
240 static bool shouldInstrumentReadWriteFromAddress(const Module *M, Value *Addr) {
241         // Peel off GEPs and BitCasts.
242         Addr = Addr->stripInBoundsOffsets();
243
244         if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) {
245                 if (GV->hasSection()) {
246                         StringRef SectionName = GV->getSection();
247                         // Check if the global is in the PGO counters section.
248                         auto OF = Triple(M->getTargetTriple()).getObjectFormat();
249                         if (SectionName.endswith(
250                               getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false)))
251                                 return false;
252                 }
253
254                 // Check if the global is private gcov data.
255                 if (GV->getName().startswith("__llvm_gcov") ||
256                 GV->getName().startswith("__llvm_gcda"))
257                 return false;
258         }
259
260         // Do not instrument acesses from different address spaces; we cannot deal
261         // with them.
262         if (Addr) {
263                 Type *PtrTy = cast<PointerType>(Addr->getType()->getScalarType());
264                 if (PtrTy->getPointerAddressSpace() != 0)
265                         return false;
266         }
267
268         return true;
269 }
270
271 bool CDSPass::addrPointsToConstantData(Value *Addr) {
272         // If this is a GEP, just analyze its pointer operand.
273         if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Addr))
274                 Addr = GEP->getPointerOperand();
275
276         if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) {
277                 if (GV->isConstant()) {
278                         // Reads from constant globals can not race with any writes.
279                         NumOmittedReadsFromConstantGlobals++;
280                         return true;
281                 }
282         } else if (LoadInst *L = dyn_cast<LoadInst>(Addr)) {
283                 if (isVtableAccess(L)) {
284                         // Reads from a vtable pointer can not race with any writes.
285                         NumOmittedReadsFromVtable++;
286                         return true;
287                 }
288         }
289         return false;
290 }
291
292 bool CDSPass::runOnFunction(Function &F) {
293         if (F.getName() == "main") {
294                 F.setName("user_main");
295                 errs() << "main replaced by user_main\n";
296         }
297
298         if (true) {
299                 initializeCallbacks( *F.getParent() );
300
301                 AtomicFuncNames = 
302                 {
303                         "atomic_init", "atomic_load", "atomic_store", 
304                         "atomic_fetch_", "atomic_exchange", "atomic_compare_exchange_"
305                 };
306
307                 PartialAtomicFuncNames = 
308                 { 
309                         "load", "store", "fetch", "exchange", "compare_exchange_" 
310                 };
311
312                 SmallVector<Instruction*, 8> AllLoadsAndStores;
313                 SmallVector<Instruction*, 8> LocalLoadsAndStores;
314                 SmallVector<Instruction*, 8> AtomicAccesses;
315
316                 std::vector<Instruction *> worklist;
317
318                 bool Res = false;
319                 bool HasAtomic = false;
320                 const DataLayout &DL = F.getParent()->getDataLayout();
321
322                 // errs() << "--- " << F.getName() << "---\n";
323
324                 for (auto &B : F) {
325                         for (auto &I : B) {
326                                 if ( (&I)->isAtomic() || isAtomicCall(&I) ) {
327                                         AtomicAccesses.push_back(&I);
328                                         HasAtomic = true;
329                                 } else if (isa<LoadInst>(I) || isa<StoreInst>(I)) {
330                                         LocalLoadsAndStores.push_back(&I);
331                                 } else if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
332                                         // not implemented yet
333                                 }
334                         }
335
336                         chooseInstructionsToInstrument(LocalLoadsAndStores, AllLoadsAndStores, DL);
337                 }
338
339                 for (auto Inst : AllLoadsAndStores) {
340                         // Res |= instrumentLoadOrStore(Inst, DL);
341                         // errs() << "load and store are replaced\n";
342                 }
343
344                 for (auto Inst : AtomicAccesses) {
345                         Res |= instrumentAtomic(Inst, DL);
346                 }
347
348                 // only instrument functions that contain atomics
349                 if (Res && HasAtomic) {
350                         IRBuilder<> IRB(F.getEntryBlock().getFirstNonPHI());
351                         Value *ReturnAddress = IRB.CreateCall(
352                                 Intrinsic::getDeclaration(F.getParent(), Intrinsic::returnaddress),
353                                 IRB.getInt32(0));
354
355                         Value * FuncName = IRB.CreateGlobalStringPtr(F.getName());
356                         
357                         errs() << "function name: " << F.getName() << "\n";
358                         IRB.CreateCall(CDSFuncEntry, FuncName);
359
360                         EscapeEnumerator EE(F, "cds_cleanup", true);
361                         while (IRBuilder<> *AtExit = EE.Next()) {
362                           AtExit->CreateCall(CDSFuncExit, FuncName);
363                         }
364
365                         Res = true;
366                 }
367         }
368
369         return false;
370 }
371
372 void CDSPass::chooseInstructionsToInstrument(
373         SmallVectorImpl<Instruction *> &Local, SmallVectorImpl<Instruction *> &All,
374         const DataLayout &DL) {
375         SmallPtrSet<Value*, 8> WriteTargets;
376         // Iterate from the end.
377         for (Instruction *I : reverse(Local)) {
378                 if (StoreInst *Store = dyn_cast<StoreInst>(I)) {
379                         Value *Addr = Store->getPointerOperand();
380                         if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
381                                 continue;
382                         WriteTargets.insert(Addr);
383                 } else {
384                         LoadInst *Load = cast<LoadInst>(I);
385                         Value *Addr = Load->getPointerOperand();
386                         if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
387                                 continue;
388                         if (WriteTargets.count(Addr)) {
389                                 // We will write to this temp, so no reason to analyze the read.
390                                 NumOmittedReadsBeforeWrite++;
391                                 continue;
392                         }
393                         if (addrPointsToConstantData(Addr)) {
394                                 // Addr points to some constant data -- it can not race with any writes.
395                                 continue;
396                         }
397                 }
398                 Value *Addr = isa<StoreInst>(*I)
399                         ? cast<StoreInst>(I)->getPointerOperand()
400                         : cast<LoadInst>(I)->getPointerOperand();
401                 if (isa<AllocaInst>(GetUnderlyingObject(Addr, DL)) &&
402                                 !PointerMayBeCaptured(Addr, true, true)) {
403                         // The variable is addressable but not captured, so it cannot be
404                         // referenced from a different thread and participate in a data race
405                         // (see llvm/Analysis/CaptureTracking.h for details).
406                         NumOmittedNonCaptured++;
407                         continue;
408                 }
409                 All.push_back(I);
410         }
411         Local.clear();
412 }
413
414
415 bool CDSPass::instrumentLoadOrStore(Instruction *I,
416                                                                         const DataLayout &DL) {
417         IRBuilder<> IRB(I);
418         bool IsWrite = isa<StoreInst>(*I);
419         Value *Addr = IsWrite
420                 ? cast<StoreInst>(I)->getPointerOperand()
421                 : cast<LoadInst>(I)->getPointerOperand();
422
423         // swifterror memory addresses are mem2reg promoted by instruction selection.
424         // As such they cannot have regular uses like an instrumentation function and
425         // it makes no sense to track them as memory.
426         if (Addr->isSwiftError())
427         return false;
428
429         int Idx = getMemoryAccessFuncIndex(Addr, DL);
430
431 //  not supported by CDS yet
432 /*  if (IsWrite && isVtableAccess(I)) {
433     LLVM_DEBUG(dbgs() << "  VPTR : " << *I << "\n");
434     Value *StoredValue = cast<StoreInst>(I)->getValueOperand();
435     // StoredValue may be a vector type if we are storing several vptrs at once.
436     // In this case, just take the first element of the vector since this is
437     // enough to find vptr races.
438     if (isa<VectorType>(StoredValue->getType()))
439       StoredValue = IRB.CreateExtractElement(
440           StoredValue, ConstantInt::get(IRB.getInt32Ty(), 0));
441     if (StoredValue->getType()->isIntegerTy())
442       StoredValue = IRB.CreateIntToPtr(StoredValue, IRB.getInt8PtrTy());
443     // Call TsanVptrUpdate.
444     IRB.CreateCall(TsanVptrUpdate,
445                    {IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()),
446                     IRB.CreatePointerCast(StoredValue, IRB.getInt8PtrTy())});
447     NumInstrumentedVtableWrites++;
448     return true;
449   }
450
451   if (!IsWrite && isVtableAccess(I)) {
452     IRB.CreateCall(TsanVptrLoad,
453                    IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()));
454     NumInstrumentedVtableReads++;
455     return true;
456   }
457 */
458
459         Value *OnAccessFunc = nullptr;
460         OnAccessFunc = IsWrite ? CDSStore[Idx] : CDSLoad[Idx];
461
462         Type *ArgType = IRB.CreatePointerCast(Addr, Addr->getType())->getType();
463
464         if ( ArgType != Int8PtrTy && ArgType != Int16PtrTy && 
465                         ArgType != Int32PtrTy && ArgType != Int64PtrTy ) {
466                 // if other types of load or stores are passed in
467                 return false;   
468         }
469         IRB.CreateCall(OnAccessFunc, IRB.CreatePointerCast(Addr, Addr->getType()));
470         if (IsWrite) NumInstrumentedWrites++;
471         else         NumInstrumentedReads++;
472         return true;
473 }
474
475 bool CDSPass::instrumentAtomic(Instruction * I, const DataLayout &DL) {
476         IRBuilder<> IRB(I);
477
478         // errs() << "instrumenting: " << *I << "\n";
479
480         if (auto *CI = dyn_cast<CallInst>(I)) {
481                 return instrumentAtomicCall(CI, DL);
482         }
483
484         Value *position = getPosition(I, IRB);
485
486         if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
487                 Value *Addr = LI->getPointerOperand();
488                 int Idx=getMemoryAccessFuncIndex(Addr, DL);
489                 int atomic_order_index = getAtomicOrderIndex(LI->getOrdering());
490                 Value *order = ConstantInt::get(OrdTy, atomic_order_index);
491                 Value *args[] = {Addr, order, position};
492                 Instruction* funcInst=CallInst::Create(CDSAtomicLoad[Idx], args);
493                 ReplaceInstWithInst(LI, funcInst);
494         } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
495                 Value *Addr = SI->getPointerOperand();
496                 int Idx=getMemoryAccessFuncIndex(Addr, DL);
497                 int atomic_order_index = getAtomicOrderIndex(SI->getOrdering());
498                 Value *val = SI->getValueOperand();
499                 Value *order = ConstantInt::get(OrdTy, atomic_order_index);
500                 Value *args[] = {Addr, val, order, position};
501                 Instruction* funcInst=CallInst::Create(CDSAtomicStore[Idx], args);
502                 ReplaceInstWithInst(SI, funcInst);
503         } else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(I)) {
504                 Value *Addr = RMWI->getPointerOperand();
505                 int Idx=getMemoryAccessFuncIndex(Addr, DL);
506                 int atomic_order_index = getAtomicOrderIndex(RMWI->getOrdering());
507                 Value *val = RMWI->getValOperand();
508                 Value *order = ConstantInt::get(OrdTy, atomic_order_index);
509                 Value *args[] = {Addr, val, order, position};
510                 Instruction* funcInst = CallInst::Create(CDSAtomicRMW[RMWI->getOperation()][Idx], args);
511                 ReplaceInstWithInst(RMWI, funcInst);
512         } else if (AtomicCmpXchgInst *CASI = dyn_cast<AtomicCmpXchgInst>(I)) {
513                 IRBuilder<> IRB(CASI);
514
515                 Value *Addr = CASI->getPointerOperand();
516                 int Idx=getMemoryAccessFuncIndex(Addr, DL);
517
518                 const unsigned ByteSize = 1U << Idx;
519                 const unsigned BitSize = ByteSize * 8;
520                 Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
521                 Type *PtrTy = Ty->getPointerTo();
522
523                 Value *CmpOperand = IRB.CreateBitOrPointerCast(CASI->getCompareOperand(), Ty);
524                 Value *NewOperand = IRB.CreateBitOrPointerCast(CASI->getNewValOperand(), Ty);
525
526                 int atomic_order_index_succ = getAtomicOrderIndex(CASI->getSuccessOrdering());
527                 int atomic_order_index_fail = getAtomicOrderIndex(CASI->getFailureOrdering());
528                 Value *order_succ = ConstantInt::get(OrdTy, atomic_order_index_succ);
529                 Value *order_fail = ConstantInt::get(OrdTy, atomic_order_index_fail);
530
531                 Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy),
532                                                  CmpOperand, NewOperand,
533                                                  order_succ, order_fail, position};
534
535                 CallInst *funcInst = IRB.CreateCall(CDSAtomicCAS_V1[Idx], Args);
536                 Value *Success = IRB.CreateICmpEQ(funcInst, CmpOperand);
537
538                 Value *OldVal = funcInst;
539                 Type *OrigOldValTy = CASI->getNewValOperand()->getType();
540                 if (Ty != OrigOldValTy) {
541                         // The value is a pointer, so we need to cast the return value.
542                         OldVal = IRB.CreateIntToPtr(funcInst, OrigOldValTy);
543                 }
544
545                 Value *Res =
546                   IRB.CreateInsertValue(UndefValue::get(CASI->getType()), OldVal, 0);
547                 Res = IRB.CreateInsertValue(Res, Success, 1);
548
549                 I->replaceAllUsesWith(Res);
550                 I->eraseFromParent();
551         } else if (FenceInst *FI = dyn_cast<FenceInst>(I)) {
552                 int atomic_order_index = getAtomicOrderIndex(FI->getOrdering());
553                 Value *order = ConstantInt::get(OrdTy, atomic_order_index);
554                 Value *Args[] = {order, position};
555
556                 CallInst *funcInst = CallInst::Create(CDSAtomicThreadFence, Args);
557                 ReplaceInstWithInst(FI, funcInst);
558                 // errs() << "Thread Fences replaced\n";
559         }
560         return true;
561 }
562
563 bool CDSPass::isAtomicCall(Instruction *I) {
564         if ( auto *CI = dyn_cast<CallInst>(I) ) {
565                 Function *fun = CI->getCalledFunction();
566                 if (fun == NULL)
567                         return false;
568
569                 StringRef funName = fun->getName();
570
571                 // todo: come up with better rules for function name checking
572                 for (StringRef name : AtomicFuncNames) {
573                         if ( funName.contains(name) ) 
574                                 return true;
575                 }
576                 
577                 for (StringRef PartialName : PartialAtomicFuncNames) {
578                         if (funName.contains(PartialName) && 
579                                         funName.contains("atomic") )
580                                 return true;
581                 }
582         }
583
584         return false;
585 }
586
587 bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
588         IRBuilder<> IRB(CI);
589         Function *fun = CI->getCalledFunction();
590         StringRef funName = fun->getName();
591         std::vector<Value *> parameters;
592
593         User::op_iterator begin = CI->arg_begin();
594         User::op_iterator end = CI->arg_end();
595         for (User::op_iterator it = begin; it != end; ++it) {
596                 Value *param = *it;
597                 parameters.push_back(param);
598         }
599
600         // obtain source line number of the CallInst
601         Value *position = getPosition(CI, IRB);
602
603         // the pointer to the address is always the first argument
604         Value *OrigPtr = parameters[0];
605
606         int Idx = getMemoryAccessFuncIndex(OrigPtr, DL);
607         if (Idx < 0)
608                 return false;
609
610         const unsigned ByteSize = 1U << Idx;
611         const unsigned BitSize = ByteSize * 8;
612         Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
613         Type *PtrTy = Ty->getPointerTo();
614
615         // atomic_init; args = {obj, order}
616         if (funName.contains("atomic_init")) {
617                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
618                 Value *val = IRB.CreateBitOrPointerCast(parameters[1], Ty);
619                 Value *args[] = {ptr, val, position};
620
621                 Instruction* funcInst = CallInst::Create(CDSAtomicInit[Idx], args);
622                 ReplaceInstWithInst(CI, funcInst);
623
624                 return true;
625         }
626
627         // atomic_load; args = {obj, order}
628         if (funName.contains("atomic_load")) {
629                 bool isExplicit = funName.contains("atomic_load_explicit");
630
631                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
632                 Value *order;
633                 if (isExplicit)
634                         order = IRB.CreateBitOrPointerCast(parameters[1], OrdTy);
635                 else 
636                         order = ConstantInt::get(OrdTy, 
637                                                         (int) AtomicOrderingCABI::seq_cst);
638                 Value *args[] = {ptr, order, position};
639                 
640                 Instruction* funcInst = CallInst::Create(CDSAtomicLoad[Idx], args);
641                 ReplaceInstWithInst(CI, funcInst);
642
643                 return true;
644         } else if (funName.contains("atomic") && 
645                                         funName.contains("load") ) {
646                 // does this version of call always have an atomic order as an argument?
647                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
648                 Value *order = IRB.CreateBitOrPointerCast(parameters[1], OrdTy);
649                 Value *args[] = {ptr, order, position};
650
651                 if (!CI->getType()->isPointerTy()) {
652                         return false;   
653                 } 
654
655                 CallInst *funcInst = IRB.CreateCall(CDSAtomicLoad[Idx], args);
656                 Value *RetVal = IRB.CreateIntToPtr(funcInst, CI->getType());
657
658                 CI->replaceAllUsesWith(RetVal);
659                 CI->eraseFromParent();
660
661                 return true;
662         }
663
664         // atomic_store; args = {obj, val, order}
665         if (funName.contains("atomic_store")) {
666                 bool isExplicit = funName.contains("atomic_store_explicit");
667                 Value *OrigVal = parameters[1];
668
669                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
670                 Value *val = IRB.CreatePointerCast(OrigVal, Ty);
671                 Value *order;
672                 if (isExplicit)
673                         order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
674                 else 
675                         order = ConstantInt::get(OrdTy, 
676                                                         (int) AtomicOrderingCABI::seq_cst);
677                 Value *args[] = {ptr, val, order, position};
678                 
679                 Instruction* funcInst = CallInst::Create(CDSAtomicStore[Idx], args);
680                 ReplaceInstWithInst(CI, funcInst);
681
682                 return true;
683         } else if (funName.contains("atomic") && 
684                                         funName.contains("EEEE5store") ) {
685                 // does this version of call always have an atomic order as an argument?
686                 Value *OrigVal = parameters[1];
687
688                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
689                 Value *val = IRB.CreatePointerCast(OrigVal, Ty);
690                 Value *order = IRB.CreateBitOrPointerCast(parameters[1], OrdTy);
691                 Value *args[] = {ptr, val, order, position};
692
693                 Instruction* funcInst = CallInst::Create(CDSAtomicStore[Idx], args);
694                 ReplaceInstWithInst(CI, funcInst);
695
696                 return true;
697         }
698
699         // atomic_fetch_*; args = {obj, val, order}
700         if (funName.contains("atomic_fetch_") || 
701                         funName.contains("atomic_exchange") ) {
702                 bool isExplicit = funName.contains("_explicit");
703                 Value *OrigVal = parameters[1];
704
705                 int op;
706                 if ( funName.contains("_fetch_add") )
707                         op = AtomicRMWInst::Add;
708                 else if ( funName.contains("_fetch_sub") )
709                         op = AtomicRMWInst::Sub;
710                 else if ( funName.contains("_fetch_and") )
711                         op = AtomicRMWInst::And;
712                 else if ( funName.contains("_fetch_or") )
713                         op = AtomicRMWInst::Or;
714                 else if ( funName.contains("_fetch_xor") )
715                         op = AtomicRMWInst::Xor;
716                 else if ( funName.contains("atomic_exchange") )
717                         op = AtomicRMWInst::Xchg;
718                 else {
719                         errs() << "Unknown atomic read-modify-write operation\n";
720                         return false;
721                 }
722
723                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
724                 Value *val = IRB.CreatePointerCast(OrigVal, Ty);
725                 Value *order;
726                 if (isExplicit)
727                         order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
728                 else 
729                         order = ConstantInt::get(OrdTy, 
730                                                         (int) AtomicOrderingCABI::seq_cst);
731                 Value *args[] = {ptr, val, order, position};
732                 
733                 Instruction* funcInst = CallInst::Create(CDSAtomicRMW[op][Idx], args);
734                 ReplaceInstWithInst(CI, funcInst);
735
736                 return true;
737         } else if (funName.contains("fetch")) {
738                 errs() << "atomic exchange captured. Not implemented yet. ";
739                 errs() << "See source file :";
740                 getPosition(CI, IRB, true);
741         } else if (funName.contains("exchange") &&
742                         !funName.contains("compare_exchange") ) {
743                 errs() << "atomic exchange captured. Not implemented yet. ";
744                 errs() << "See source file :";
745                 getPosition(CI, IRB, true);
746         }
747
748         /* atomic_compare_exchange_*; 
749            args = {obj, expected, new value, order1, order2}
750         */
751         if ( funName.contains("atomic_compare_exchange_") ) {
752                 bool isExplicit = funName.contains("_explicit");
753
754                 Value *Addr = IRB.CreatePointerCast(OrigPtr, PtrTy);
755                 Value *CmpOperand = IRB.CreatePointerCast(parameters[1], PtrTy);
756                 Value *NewOperand = IRB.CreateBitOrPointerCast(parameters[2], Ty);
757
758                 Value *order_succ, *order_fail;
759                 if (isExplicit) {
760                         order_succ = IRB.CreateBitOrPointerCast(parameters[3], OrdTy);
761                         order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
762                 } else  {
763                         order_succ = ConstantInt::get(OrdTy, 
764                                                         (int) AtomicOrderingCABI::seq_cst);
765                         order_fail = ConstantInt::get(OrdTy, 
766                                                         (int) AtomicOrderingCABI::seq_cst);
767                 }
768
769                 Value *args[] = {Addr, CmpOperand, NewOperand, 
770                                                         order_succ, order_fail, position};
771                 
772                 Instruction* funcInst = CallInst::Create(CDSAtomicCAS_V2[Idx], args);
773                 ReplaceInstWithInst(CI, funcInst);
774
775                 return true;
776         } else if ( funName.contains("compare_exchange_strong") ||
777                                 funName.contains("compare_exchange_weak") ) {
778                 Value *Addr = IRB.CreatePointerCast(OrigPtr, PtrTy);
779                 Value *CmpOperand = IRB.CreatePointerCast(parameters[1], PtrTy);
780                 Value *NewOperand = IRB.CreateBitOrPointerCast(parameters[2], Ty);
781
782                 Value *order_succ, *order_fail;
783                 order_succ = IRB.CreateBitOrPointerCast(parameters[3], OrdTy);
784                 order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
785
786                 Value *args[] = {Addr, CmpOperand, NewOperand, 
787                                                         order_succ, order_fail, position};
788                 Instruction* funcInst = CallInst::Create(CDSAtomicCAS_V2[Idx], args);
789                 ReplaceInstWithInst(CI, funcInst);
790
791                 return true;
792         }
793
794         return false;
795 }
796
797 int CDSPass::getMemoryAccessFuncIndex(Value *Addr,
798                                                                                 const DataLayout &DL) {
799         Type *OrigPtrTy = Addr->getType();
800         Type *OrigTy = cast<PointerType>(OrigPtrTy)->getElementType();
801         assert(OrigTy->isSized());
802         uint32_t TypeSize = DL.getTypeStoreSizeInBits(OrigTy);
803         if (TypeSize != 8  && TypeSize != 16 &&
804                 TypeSize != 32 && TypeSize != 64 && TypeSize != 128) {
805                 NumAccessesWithBadSize++;
806                 // Ignore all unusual sizes.
807                 return -1;
808         }
809         size_t Idx = countTrailingZeros(TypeSize / 8);
810         assert(Idx < kNumberOfAccessSizes);
811         return Idx;
812 }
813
814
815 char CDSPass::ID = 0;
816
817 // Automatically enable the pass.
818 static void registerCDSPass(const PassManagerBuilder &,
819                                                         legacy::PassManagerBase &PM) {
820         PM.add(new CDSPass());
821 }
822 static RegisterStandardPasses 
823         RegisterMyPass(PassManagerBuilder::EP_OptimizerLast,
824 registerCDSPass);