make indentations consistent (all tabs)
[c11llvm.git] / CDSPass.cpp
1 //===-- CDSPass.cpp - xxx -------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 // This file is distributed under the University of Illinois Open Source
7 // License. See LICENSE.TXT for details.
8 //
9 //===----------------------------------------------------------------------===//
10 //
11 // This file is a modified version of ThreadSanitizer.cpp, a part of a race detector.
12 //
13 // The tool is under development, for the details about previous versions see
14 // http://code.google.com/p/data-race-test
15 //
16 // The instrumentation phase is quite simple:
17 //   - Insert calls to run-time library before every memory access.
18 //      - Optimizations may apply to avoid instrumenting some of the accesses.
19 //   - Insert calls at function entry/exit.
20 // The rest is handled by the run-time library.
21 //===----------------------------------------------------------------------===//
22
23 #include "llvm/ADT/Statistic.h"
24 #include "llvm/ADT/StringExtras.h"
25 #include "llvm/ADT/SmallString.h"
26 #include "llvm/Analysis/ValueTracking.h"
27 #include "llvm/Analysis/CaptureTracking.h"
28 #include "llvm/IR/BasicBlock.h"
29 #include "llvm/IR/Function.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/Instructions.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/IR/LegacyPassManager.h"
34 #include "llvm/IR/Module.h"
35 #include "llvm/IR/PassManager.h"
36 #include "llvm/Pass.h"
37 #include "llvm/ProfileData/InstrProf.h"
38 #include "llvm/Support/raw_ostream.h"
39 #include "llvm/Support/AtomicOrdering.h"
40 #include "llvm/Support/Debug.h"
41 #include "llvm/Transforms/Scalar.h"
42 #include "llvm/Transforms/Utils/Local.h"
43 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
44 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
45 #include <vector>
46
47 using namespace llvm;
48
49 #define DEBUG_TYPE "CDS"
50 #include <llvm/IR/DebugLoc.h>
51
52 Value *getPosition( Instruction * I, IRBuilder <> IRB, bool print = false)
53 {
54         const DebugLoc & debug_location = I->getDebugLoc ();
55         std::string position_string;
56         {
57                 llvm::raw_string_ostream position_stream (position_string);
58                 debug_location . print (position_stream);
59         }
60
61         if (print) {
62                 errs() << position_string;
63         }
64
65         return IRB . CreateGlobalStringPtr (position_string);
66 }
67
68 STATISTIC(NumInstrumentedReads, "Number of instrumented reads");
69 STATISTIC(NumInstrumentedWrites, "Number of instrumented writes");
70 STATISTIC(NumAccessesWithBadSize, "Number of accesses with bad size");
71 // STATISTIC(NumInstrumentedVtableWrites, "Number of vtable ptr writes");
72 // STATISTIC(NumInstrumentedVtableReads, "Number of vtable ptr reads");
73
74 STATISTIC(NumOmittedReadsBeforeWrite,
75           "Number of reads ignored due to following writes");
76 STATISTIC(NumOmittedReadsFromConstantGlobals,
77           "Number of reads from constant globals");
78 STATISTIC(NumOmittedReadsFromVtable, "Number of vtable reads");
79 STATISTIC(NumOmittedNonCaptured, "Number of accesses ignored due to capturing");
80
81 Type * OrdTy;
82
83 Type * Int8PtrTy;
84 Type * Int16PtrTy;
85 Type * Int32PtrTy;
86 Type * Int64PtrTy;
87
88 Type * VoidTy;
89
90 static const size_t kNumberOfAccessSizes = 4;
91
92 int getAtomicOrderIndex(AtomicOrdering order){
93         switch (order) {
94                 case AtomicOrdering::Monotonic: 
95                         return (int)AtomicOrderingCABI::relaxed;
96                 //  case AtomicOrdering::Consume:         // not specified yet
97                 //    return AtomicOrderingCABI::consume;
98                 case AtomicOrdering::Acquire: 
99                         return (int)AtomicOrderingCABI::acquire;
100                 case AtomicOrdering::Release: 
101                         return (int)AtomicOrderingCABI::release;
102                 case AtomicOrdering::AcquireRelease: 
103                         return (int)AtomicOrderingCABI::acq_rel;
104                 case AtomicOrdering::SequentiallyConsistent: 
105                         return (int)AtomicOrderingCABI::seq_cst;
106                 default:
107                         // unordered or Not Atomic
108                         return -1;
109         }
110 }
111
112 namespace {
113         struct CDSPass : public FunctionPass {
114                 static char ID;
115                 CDSPass() : FunctionPass(ID) {}
116                 bool runOnFunction(Function &F) override; 
117
118         private:
119                 void initializeCallbacks(Module &M);
120                 bool instrumentLoadOrStore(Instruction *I, const DataLayout &DL);
121                 bool isAtomicCall(Instruction *I);
122                 bool instrumentAtomic(Instruction *I, const DataLayout &DL);
123                 bool instrumentAtomicCall(CallInst *CI, const DataLayout &DL);
124                 void chooseInstructionsToInstrument(SmallVectorImpl<Instruction *> &Local,
125                                                                                         SmallVectorImpl<Instruction *> &All,
126                                                                                         const DataLayout &DL);
127                 bool addrPointsToConstantData(Value *Addr);
128                 int getMemoryAccessFuncIndex(Value *Addr, const DataLayout &DL);
129
130                 // Callbacks to run-time library are computed in doInitialization.
131                 Constant * CDSFuncEntry;
132                 Constant * CDSFuncExit;
133
134                 Constant * CDSLoad[kNumberOfAccessSizes];
135                 Constant * CDSStore[kNumberOfAccessSizes];
136                 Constant * CDSAtomicInit[kNumberOfAccessSizes];
137                 Constant * CDSAtomicLoad[kNumberOfAccessSizes];
138                 Constant * CDSAtomicStore[kNumberOfAccessSizes];
139                 Constant * CDSAtomicRMW[AtomicRMWInst::LAST_BINOP + 1][kNumberOfAccessSizes];
140                 Constant * CDSAtomicCAS_V1[kNumberOfAccessSizes];
141                 Constant * CDSAtomicCAS_V2[kNumberOfAccessSizes];
142                 Constant * CDSAtomicThreadFence;
143         };
144 }
145
146 static bool isVtableAccess(Instruction *I) {
147         if (MDNode *Tag = I->getMetadata(LLVMContext::MD_tbaa))
148                 return Tag->isTBAAVtableAccess();
149         return false;
150 }
151
152 void CDSPass::initializeCallbacks(Module &M) {
153         LLVMContext &Ctx = M.getContext();
154
155         Type * Int1Ty = Type::getInt1Ty(Ctx);
156         OrdTy = Type::getInt32Ty(Ctx);
157
158         Int8PtrTy  = Type::getInt8PtrTy(Ctx);
159         Int16PtrTy = Type::getInt16PtrTy(Ctx);
160         Int32PtrTy = Type::getInt32PtrTy(Ctx);
161         Int64PtrTy = Type::getInt64PtrTy(Ctx);
162
163         VoidTy = Type::getVoidTy(Ctx);
164
165         // Get the function to call from our untime library.
166         for (unsigned i = 0; i < kNumberOfAccessSizes; i++) {
167                 const unsigned ByteSize = 1U << i;
168                 const unsigned BitSize = ByteSize * 8;
169
170                 std::string ByteSizeStr = utostr(ByteSize);
171                 std::string BitSizeStr = utostr(BitSize);
172
173                 Type *Ty = Type::getIntNTy(Ctx, BitSize);
174                 Type *PtrTy = Ty->getPointerTo();
175
176                 // uint8_t cds_atomic_load8 (void * obj, int atomic_index)
177                 // void cds_atomic_store8 (void * obj, int atomic_index, uint8_t val)
178                 SmallString<32> LoadName("cds_load" + BitSizeStr);
179                 SmallString<32> StoreName("cds_store" + BitSizeStr);
180                 SmallString<32> AtomicInitName("cds_atomic_init" + BitSizeStr);
181                 SmallString<32> AtomicLoadName("cds_atomic_load" + BitSizeStr);
182                 SmallString<32> AtomicStoreName("cds_atomic_store" + BitSizeStr);
183
184                 CDSLoad[i]  = M.getOrInsertFunction(LoadName, VoidTy, PtrTy);
185                 CDSStore[i] = M.getOrInsertFunction(StoreName, VoidTy, PtrTy);
186                 CDSAtomicInit[i] = M.getOrInsertFunction(AtomicInitName, 
187                                                                 VoidTy, PtrTy, Ty, Int8PtrTy);
188                 CDSAtomicLoad[i]  = M.getOrInsertFunction(AtomicLoadName, 
189                                                                 Ty, PtrTy, OrdTy, Int8PtrTy);
190                 CDSAtomicStore[i] = M.getOrInsertFunction(AtomicStoreName, 
191                                                                 VoidTy, PtrTy, Ty, OrdTy, Int8PtrTy);
192
193                 for (int op = AtomicRMWInst::FIRST_BINOP; 
194                         op <= AtomicRMWInst::LAST_BINOP; ++op) {
195                         CDSAtomicRMW[op][i] = nullptr;
196                         std::string NamePart;
197
198                         if (op == AtomicRMWInst::Xchg)
199                                 NamePart = "_exchange";
200                         else if (op == AtomicRMWInst::Add) 
201                                 NamePart = "_fetch_add";
202                         else if (op == AtomicRMWInst::Sub)
203                                 NamePart = "_fetch_sub";
204                         else if (op == AtomicRMWInst::And)
205                                 NamePart = "_fetch_and";
206                         else if (op == AtomicRMWInst::Or)
207                                 NamePart = "_fetch_or";
208                         else if (op == AtomicRMWInst::Xor)
209                                 NamePart = "_fetch_xor";
210                         else
211                                 continue;
212
213                         SmallString<32> AtomicRMWName("cds_atomic" + NamePart + BitSizeStr);
214                         CDSAtomicRMW[op][i] = M.getOrInsertFunction(AtomicRMWName, 
215                                                                                 Ty, PtrTy, Ty, OrdTy, Int8PtrTy);
216                 }
217
218                 // only supportes strong version
219                 SmallString<32> AtomicCASName_V1("cds_atomic_compare_exchange" + BitSizeStr + "_v1");
220                 SmallString<32> AtomicCASName_V2("cds_atomic_compare_exchange" + BitSizeStr + "_v2");
221                 CDSAtomicCAS_V1[i] = M.getOrInsertFunction(AtomicCASName_V1, 
222                                                                 Ty, PtrTy, Ty, Ty, OrdTy, OrdTy, Int8PtrTy);
223                 CDSAtomicCAS_V2[i] = M.getOrInsertFunction(AtomicCASName_V2, 
224                                                                 Int1Ty, PtrTy, PtrTy, Ty, OrdTy, OrdTy, Int8PtrTy);
225         }
226
227         CDSAtomicThreadFence = M.getOrInsertFunction("cds_atomic_thread_fence", 
228                                                                                                         VoidTy, OrdTy, Int8PtrTy);
229 }
230
231 static bool shouldInstrumentReadWriteFromAddress(const Module *M, Value *Addr) {
232         // Peel off GEPs and BitCasts.
233         Addr = Addr->stripInBoundsOffsets();
234
235         if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) {
236                 if (GV->hasSection()) {
237                         StringRef SectionName = GV->getSection();
238                         // Check if the global is in the PGO counters section.
239                         auto OF = Triple(M->getTargetTriple()).getObjectFormat();
240                         if (SectionName.endswith(
241                               getInstrProfSectionName(IPSK_cnts, OF, /*AddSegmentInfo=*/false)))
242                                 return false;
243                 }
244
245                 // Check if the global is private gcov data.
246                 if (GV->getName().startswith("__llvm_gcov") ||
247                 GV->getName().startswith("__llvm_gcda"))
248                 return false;
249         }
250
251         // Do not instrument acesses from different address spaces; we cannot deal
252         // with them.
253         if (Addr) {
254                 Type *PtrTy = cast<PointerType>(Addr->getType()->getScalarType());
255                 if (PtrTy->getPointerAddressSpace() != 0)
256                         return false;
257         }
258
259         return true;
260 }
261
262 bool CDSPass::addrPointsToConstantData(Value *Addr) {
263         // If this is a GEP, just analyze its pointer operand.
264         if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Addr))
265                 Addr = GEP->getPointerOperand();
266
267         if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Addr)) {
268                 if (GV->isConstant()) {
269                         // Reads from constant globals can not race with any writes.
270                         NumOmittedReadsFromConstantGlobals++;
271                         return true;
272                 }
273         } else if (LoadInst *L = dyn_cast<LoadInst>(Addr)) {
274                 if (isVtableAccess(L)) {
275                         // Reads from a vtable pointer can not race with any writes.
276                         NumOmittedReadsFromVtable++;
277                         return true;
278                 }
279         }
280         return false;
281 }
282
283 bool CDSPass::runOnFunction(Function &F) {
284         if (F.getName() == "main") {
285                 F.setName("user_main");
286                 errs() << "main replaced by user_main\n";
287         }
288
289         if (true) {
290                 initializeCallbacks( *F.getParent() );
291
292                 SmallVector<Instruction*, 8> AllLoadsAndStores;
293                 SmallVector<Instruction*, 8> LocalLoadsAndStores;
294                 SmallVector<Instruction*, 8> AtomicAccesses;
295
296                 std::vector<Instruction *> worklist;
297
298                 bool Res = false;
299                 const DataLayout &DL = F.getParent()->getDataLayout();
300
301                 // errs() << "--- " << F.getName() << "---\n";
302
303                 for (auto &B : F) {
304                         for (auto &I : B) {
305                                 if ( (&I)->isAtomic() || isAtomicCall(&I) ) {
306                                         AtomicAccesses.push_back(&I);
307                                 } else if (isa<LoadInst>(I) || isa<StoreInst>(I)) {
308                                         LocalLoadsAndStores.push_back(&I);
309                                 } else if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
310                                         // not implemented yet
311                                 }
312                         }
313
314                         chooseInstructionsToInstrument(LocalLoadsAndStores, AllLoadsAndStores, DL);
315                 }
316
317                 for (auto Inst : AllLoadsAndStores) {
318                         // Res |= instrumentLoadOrStore(Inst, DL);
319                         // errs() << "load and store are replaced\n";
320                 }
321
322                 for (auto Inst : AtomicAccesses) {
323                         Res |= instrumentAtomic(Inst, DL);
324                 }
325
326                 if (F.getName() == "user_main") {
327                         // F.dump();
328                 }
329         }
330
331         return false;
332 }
333
334 void CDSPass::chooseInstructionsToInstrument(
335         SmallVectorImpl<Instruction *> &Local, SmallVectorImpl<Instruction *> &All,
336         const DataLayout &DL) {
337         SmallPtrSet<Value*, 8> WriteTargets;
338         // Iterate from the end.
339         for (Instruction *I : reverse(Local)) {
340                 if (StoreInst *Store = dyn_cast<StoreInst>(I)) {
341                         Value *Addr = Store->getPointerOperand();
342                         if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
343                                 continue;
344                         WriteTargets.insert(Addr);
345                 } else {
346                         LoadInst *Load = cast<LoadInst>(I);
347                         Value *Addr = Load->getPointerOperand();
348                         if (!shouldInstrumentReadWriteFromAddress(I->getModule(), Addr))
349                                 continue;
350                         if (WriteTargets.count(Addr)) {
351                                 // We will write to this temp, so no reason to analyze the read.
352                                 NumOmittedReadsBeforeWrite++;
353                                 continue;
354                         }
355                         if (addrPointsToConstantData(Addr)) {
356                                 // Addr points to some constant data -- it can not race with any writes.
357                                 continue;
358                         }
359                 }
360                 Value *Addr = isa<StoreInst>(*I)
361                         ? cast<StoreInst>(I)->getPointerOperand()
362                         : cast<LoadInst>(I)->getPointerOperand();
363                 if (isa<AllocaInst>(GetUnderlyingObject(Addr, DL)) &&
364                                 !PointerMayBeCaptured(Addr, true, true)) {
365                         // The variable is addressable but not captured, so it cannot be
366                         // referenced from a different thread and participate in a data race
367                         // (see llvm/Analysis/CaptureTracking.h for details).
368                         NumOmittedNonCaptured++;
369                         continue;
370                 }
371                 All.push_back(I);
372         }
373         Local.clear();
374 }
375
376
377 bool CDSPass::instrumentLoadOrStore(Instruction *I,
378                                                                         const DataLayout &DL) {
379         IRBuilder<> IRB(I);
380         bool IsWrite = isa<StoreInst>(*I);
381         Value *Addr = IsWrite
382                 ? cast<StoreInst>(I)->getPointerOperand()
383                 : cast<LoadInst>(I)->getPointerOperand();
384
385         // swifterror memory addresses are mem2reg promoted by instruction selection.
386         // As such they cannot have regular uses like an instrumentation function and
387         // it makes no sense to track them as memory.
388         if (Addr->isSwiftError())
389         return false;
390
391         int Idx = getMemoryAccessFuncIndex(Addr, DL);
392
393 //  not supported by CDS yet
394 /*  if (IsWrite && isVtableAccess(I)) {
395     LLVM_DEBUG(dbgs() << "  VPTR : " << *I << "\n");
396     Value *StoredValue = cast<StoreInst>(I)->getValueOperand();
397     // StoredValue may be a vector type if we are storing several vptrs at once.
398     // In this case, just take the first element of the vector since this is
399     // enough to find vptr races.
400     if (isa<VectorType>(StoredValue->getType()))
401       StoredValue = IRB.CreateExtractElement(
402           StoredValue, ConstantInt::get(IRB.getInt32Ty(), 0));
403     if (StoredValue->getType()->isIntegerTy())
404       StoredValue = IRB.CreateIntToPtr(StoredValue, IRB.getInt8PtrTy());
405     // Call TsanVptrUpdate.
406     IRB.CreateCall(TsanVptrUpdate,
407                    {IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()),
408                     IRB.CreatePointerCast(StoredValue, IRB.getInt8PtrTy())});
409     NumInstrumentedVtableWrites++;
410     return true;
411   }
412
413   if (!IsWrite && isVtableAccess(I)) {
414     IRB.CreateCall(TsanVptrLoad,
415                    IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()));
416     NumInstrumentedVtableReads++;
417     return true;
418   }
419 */
420
421         Value *OnAccessFunc = nullptr;
422         OnAccessFunc = IsWrite ? CDSStore[Idx] : CDSLoad[Idx];
423
424         Type *ArgType = IRB.CreatePointerCast(Addr, Addr->getType())->getType();
425
426         if ( ArgType != Int8PtrTy && ArgType != Int16PtrTy && 
427                         ArgType != Int32PtrTy && ArgType != Int64PtrTy ) {
428                 //errs() << "A load or store of type ";
429                 //errs() << *ArgType;
430                 //errs() << " is passed in\n";
431                 return false;   // if other types of load or stores are passed in
432         }
433         IRB.CreateCall(OnAccessFunc, IRB.CreatePointerCast(Addr, Addr->getType()));
434         if (IsWrite) NumInstrumentedWrites++;
435         else         NumInstrumentedReads++;
436         return true;
437 }
438
439 bool CDSPass::instrumentAtomic(Instruction * I, const DataLayout &DL) {
440         IRBuilder<> IRB(I);
441         // LLVMContext &Ctx = IRB.getContext();
442
443         if (auto *CI = dyn_cast<CallInst>(I)) {
444                 return instrumentAtomicCall(CI, DL);
445         }
446
447         Value *position = getPosition(I, IRB);
448
449         if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
450                 Value *Addr = LI->getPointerOperand();
451                 int Idx=getMemoryAccessFuncIndex(Addr, DL);
452                 int atomic_order_index = getAtomicOrderIndex(LI->getOrdering());
453                 Value *order = ConstantInt::get(OrdTy, atomic_order_index);
454                 Value *args[] = {Addr, order, position};
455                 Instruction* funcInst=CallInst::Create(CDSAtomicLoad[Idx], args);
456                 ReplaceInstWithInst(LI, funcInst);
457         } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
458                 Value *Addr = SI->getPointerOperand();
459                 int Idx=getMemoryAccessFuncIndex(Addr, DL);
460                 int atomic_order_index = getAtomicOrderIndex(SI->getOrdering());
461                 Value *val = SI->getValueOperand();
462                 Value *order = ConstantInt::get(OrdTy, atomic_order_index);
463                 Value *args[] = {Addr, val, order, position};
464                 Instruction* funcInst=CallInst::Create(CDSAtomicStore[Idx], args);
465                 ReplaceInstWithInst(SI, funcInst);
466         } else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(I)) {
467                 Value *Addr = RMWI->getPointerOperand();
468                 int Idx=getMemoryAccessFuncIndex(Addr, DL);
469                 int atomic_order_index = getAtomicOrderIndex(RMWI->getOrdering());
470                 Value *val = RMWI->getValOperand();
471                 Value *order = ConstantInt::get(OrdTy, atomic_order_index);
472                 Value *args[] = {Addr, val, order, position};
473                 Instruction* funcInst = CallInst::Create(CDSAtomicRMW[RMWI->getOperation()][Idx], args);
474                 ReplaceInstWithInst(RMWI, funcInst);
475         } else if (AtomicCmpXchgInst *CASI = dyn_cast<AtomicCmpXchgInst>(I)) {
476                 IRBuilder<> IRB(CASI);
477
478                 Value *Addr = CASI->getPointerOperand();
479                 int Idx=getMemoryAccessFuncIndex(Addr, DL);
480
481                 const unsigned ByteSize = 1U << Idx;
482                 const unsigned BitSize = ByteSize * 8;
483                 Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
484                 Type *PtrTy = Ty->getPointerTo();
485
486                 Value *CmpOperand = IRB.CreateBitOrPointerCast(CASI->getCompareOperand(), Ty);
487                 Value *NewOperand = IRB.CreateBitOrPointerCast(CASI->getNewValOperand(), Ty);
488
489                 int atomic_order_index_succ = getAtomicOrderIndex(CASI->getSuccessOrdering());
490                 int atomic_order_index_fail = getAtomicOrderIndex(CASI->getFailureOrdering());
491                 Value *order_succ = ConstantInt::get(OrdTy, atomic_order_index_succ);
492                 Value *order_fail = ConstantInt::get(OrdTy, atomic_order_index_fail);
493
494                 Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy),
495                                                  CmpOperand, NewOperand,
496                                                  order_succ, order_fail, position};
497
498                 CallInst *funcInst = IRB.CreateCall(CDSAtomicCAS_V1[Idx], Args);
499                 Value *Success = IRB.CreateICmpEQ(funcInst, CmpOperand);
500
501                 Value *OldVal = funcInst;
502                 Type *OrigOldValTy = CASI->getNewValOperand()->getType();
503                 if (Ty != OrigOldValTy) {
504                         // The value is a pointer, so we need to cast the return value.
505                         OldVal = IRB.CreateIntToPtr(funcInst, OrigOldValTy);
506                 }
507
508                 Value *Res =
509                   IRB.CreateInsertValue(UndefValue::get(CASI->getType()), OldVal, 0);
510                 Res = IRB.CreateInsertValue(Res, Success, 1);
511
512                 I->replaceAllUsesWith(Res);
513                 I->eraseFromParent();
514         } else if (FenceInst *FI = dyn_cast<FenceInst>(I)) {
515                 int atomic_order_index = getAtomicOrderIndex(FI->getOrdering());
516                 Value *order = ConstantInt::get(OrdTy, atomic_order_index);
517                 Value *Args[] = {order, position};
518
519                 CallInst *funcInst = CallInst::Create(CDSAtomicThreadFence, Args);
520                 ReplaceInstWithInst(FI, funcInst);
521                 // errs() << "Thread Fences replaced\n";
522         }
523         return true;
524 }
525
526 bool CDSPass::isAtomicCall(Instruction *I) {
527         if ( auto *CI = dyn_cast<CallInst>(I) ) {
528                 Function *fun = CI->getCalledFunction();
529                 if (fun == NULL)
530                         return false;
531
532                 StringRef funName = fun->getName();
533                 // todo: come up with better rules for function name checking
534                 if ( funName.contains("atomic_") ) {
535                         return true;
536                 } else if (funName.contains("atomic") ) {
537                         return true;
538                 }
539         }
540
541         return false;
542 }
543
544 bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
545         IRBuilder<> IRB(CI);
546         Function *fun = CI->getCalledFunction();
547         StringRef funName = fun->getName();
548         std::vector<Value *> parameters;
549
550         User::op_iterator begin = CI->arg_begin();
551         User::op_iterator end = CI->arg_end();
552         for (User::op_iterator it = begin; it != end; ++it) {
553                 Value *param = *it;
554                 parameters.push_back(param);
555         }
556
557         // obtain source line number of the CallInst
558         Value *position = getPosition(CI, IRB);
559
560         // the pointer to the address is always the first argument
561         Value *OrigPtr = parameters[0];
562         int Idx = getMemoryAccessFuncIndex(OrigPtr, DL);
563         if (Idx < 0)
564                 return false;
565
566         const unsigned ByteSize = 1U << Idx;
567         const unsigned BitSize = ByteSize * 8;
568         Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
569         Type *PtrTy = Ty->getPointerTo();
570
571         // atomic_init; args = {obj, order}
572         if (funName.contains("atomic_init")) {
573                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
574                 Value *val = IRB.CreateBitOrPointerCast(parameters[1], Ty);
575                 Value *args[] = {ptr, val, position};
576
577                 Instruction* funcInst = CallInst::Create(CDSAtomicInit[Idx], args);
578                 ReplaceInstWithInst(CI, funcInst);
579
580                 return true;
581         }
582
583         // atomic_load; args = {obj, order}
584         if (funName.contains("atomic_load")) {
585                 bool isExplicit = funName.contains("atomic_load_explicit");
586
587                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
588                 Value *order;
589                 if (isExplicit)
590                         order = IRB.CreateBitOrPointerCast(parameters[1], OrdTy);
591                 else 
592                         order = ConstantInt::get(OrdTy, 
593                                                         (int) AtomicOrderingCABI::seq_cst);
594                 Value *args[] = {ptr, order, position};
595                 
596                 Instruction* funcInst = CallInst::Create(CDSAtomicLoad[Idx], args);
597                 ReplaceInstWithInst(CI, funcInst);
598
599                 return true;
600         } else if (funName.contains("atomic") && 
601                                         funName.contains("load")) {
602                 // does this version of call always have an atomic order as an argument?
603                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
604                 Value *order = IRB.CreateBitOrPointerCast(parameters[1], OrdTy);
605                 Value *args[] = {ptr, order, position};
606
607                 //Instruction* funcInst=CallInst::Create(CDSAtomicLoad[Idx], args);
608                 CallInst *funcInst = IRB.CreateCall(CDSAtomicLoad[Idx], args);
609                 Value *RetVal = IRB.CreateIntToPtr(funcInst, CI->getType());
610
611                 CI->replaceAllUsesWith(RetVal);
612                 CI->eraseFromParent();
613
614                 return true;
615         }
616
617         // atomic_store; args = {obj, val, order}
618         if (funName.contains("atomic_store")) {
619                 bool isExplicit = funName.contains("atomic_store_explicit");
620                 Value *OrigVal = parameters[1];
621
622                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
623                 Value *val = IRB.CreatePointerCast(OrigVal, Ty);
624                 Value *order;
625                 if (isExplicit)
626                         order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
627                 else 
628                         order = ConstantInt::get(OrdTy, 
629                                                         (int) AtomicOrderingCABI::seq_cst);
630                 Value *args[] = {ptr, val, order, position};
631                 
632                 Instruction* funcInst = CallInst::Create(CDSAtomicStore[Idx], args);
633                 ReplaceInstWithInst(CI, funcInst);
634
635                 return true;
636         } else if (funName.contains("atomic") && 
637                                         funName.contains("EEEE5store")) {
638                 // does this version of call always have an atomic order as an argument?
639                 Value *OrigVal = parameters[1];
640
641                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
642                 Value *val = IRB.CreatePointerCast(OrigVal, Ty);
643                 Value *order = IRB.CreateBitOrPointerCast(parameters[1], OrdTy);
644                 Value *args[] = {ptr, val, order, position};
645
646                 Instruction* funcInst = CallInst::Create(CDSAtomicStore[Idx], args);
647                 ReplaceInstWithInst(CI, funcInst);
648
649                 return true;
650         }
651
652         // atomic_fetch_*; args = {obj, val, order}
653         if (funName.contains("atomic_fetch_") || 
654                         funName.contains("atomic_exchange") ) {
655                 bool isExplicit = funName.contains("_explicit");
656                 Value *OrigVal = parameters[1];
657
658                 int op;
659                 if ( funName.contains("_fetch_add") )
660                         op = AtomicRMWInst::Add;
661                 else if ( funName.contains("_fetch_sub") )
662                         op = AtomicRMWInst::Sub;
663                 else if ( funName.contains("_fetch_and") )
664                         op = AtomicRMWInst::And;
665                 else if ( funName.contains("_fetch_or") )
666                         op = AtomicRMWInst::Or;
667                 else if ( funName.contains("_fetch_xor") )
668                         op = AtomicRMWInst::Xor;
669                 else if ( funName.contains("atomic_exchange") )
670                         op = AtomicRMWInst::Xchg;
671                 else {
672                         errs() << "Unknown atomic read modify write operation\n";
673                         return false;
674                 }
675
676                 Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
677                 Value *val = IRB.CreatePointerCast(OrigVal, Ty);
678                 Value *order;
679                 if (isExplicit)
680                         order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
681                 else 
682                         order = ConstantInt::get(OrdTy, 
683                                                         (int) AtomicOrderingCABI::seq_cst);
684                 Value *args[] = {ptr, val, order, position};
685                 
686                 Instruction* funcInst = CallInst::Create(CDSAtomicRMW[op][Idx], args);
687                 ReplaceInstWithInst(CI, funcInst);
688
689                 return true;
690         } else if (funName.contains("fetch")) {
691                 errs() << "atomic exchange captured. Not implemented yet. ";
692                 errs() << "See source file :";
693                 getPosition(CI, IRB, true);
694         } else if (funName.contains("exchange") &&
695                         !funName.contains("compare_exchange") ) {
696                 errs() << "atomic exchange captured. Not implemented yet. ";
697                 errs() << "See source file :";
698                 getPosition(CI, IRB, true);
699         }
700
701         /* atomic_compare_exchange_*; 
702            args = {obj, expected, new value, order1, order2}
703         */
704         if ( funName.contains("atomic_compare_exchange_") ) {
705                 bool isExplicit = funName.contains("_explicit");
706
707                 Value *Addr = IRB.CreatePointerCast(OrigPtr, PtrTy);
708                 Value *CmpOperand = IRB.CreatePointerCast(parameters[1], PtrTy);
709                 Value *NewOperand = IRB.CreateBitOrPointerCast(parameters[2], Ty);
710
711                 Value *order_succ, *order_fail;
712                 if (isExplicit) {
713                         order_succ = IRB.CreateBitOrPointerCast(parameters[3], OrdTy);
714                         order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
715                 } else  {
716                         order_succ = ConstantInt::get(OrdTy, 
717                                                         (int) AtomicOrderingCABI::seq_cst);
718                         order_fail = ConstantInt::get(OrdTy, 
719                                                         (int) AtomicOrderingCABI::seq_cst);
720                 }
721
722                 Value *args[] = {Addr, CmpOperand, NewOperand, 
723                                                         order_succ, order_fail, position};
724                 
725                 Instruction* funcInst = CallInst::Create(CDSAtomicCAS_V2[Idx], args);
726                 ReplaceInstWithInst(CI, funcInst);
727
728                 return true;
729         } else if ( funName.contains("compare_exchange_strong") || 
730                                 funName.contains("compare_exchange_weak") ) {
731                 Value *Addr = IRB.CreatePointerCast(OrigPtr, PtrTy);
732                 Value *CmpOperand = IRB.CreatePointerCast(parameters[1], PtrTy);
733                 Value *NewOperand = IRB.CreateBitOrPointerCast(parameters[2], Ty);
734
735                 Value *order_succ, *order_fail;
736                 order_succ = IRB.CreateBitOrPointerCast(parameters[3], OrdTy);
737                 order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
738
739                 Value *args[] = {Addr, CmpOperand, NewOperand, 
740                                                         order_succ, order_fail, position};
741                 Instruction* funcInst = CallInst::Create(CDSAtomicCAS_V2[Idx], args);
742                 ReplaceInstWithInst(CI, funcInst);
743
744                 return true;
745         }
746
747         return false;
748 }
749
750 int CDSPass::getMemoryAccessFuncIndex(Value *Addr,
751                                                                                 const DataLayout &DL) {
752         Type *OrigPtrTy = Addr->getType();
753         Type *OrigTy = cast<PointerType>(OrigPtrTy)->getElementType();
754         assert(OrigTy->isSized());
755         uint32_t TypeSize = DL.getTypeStoreSizeInBits(OrigTy);
756         if (TypeSize != 8  && TypeSize != 16 &&
757                 TypeSize != 32 && TypeSize != 64 && TypeSize != 128) {
758                 NumAccessesWithBadSize++;
759                 // Ignore all unusual sizes.
760                 return -1;
761         }
762         size_t Idx = countTrailingZeros(TypeSize / 8);
763         assert(Idx < kNumberOfAccessSizes);
764         return Idx;
765 }
766
767
// Out-of-class definition of the pass's static ID member, initialized to 0.
// NOTE(review): presumably LLVM identifies the pass by this member's address
// rather than by its value — confirm against the LLVM pass documentation.
768 char CDSPass::ID = 0;
769
770 // Automatically enable the pass.
771 static void registerCDSPass(const PassManagerBuilder &,
772                                                         legacy::PassManagerBase &PM) {
773         PM.add(new CDSPass());
774 }
775 static RegisterStandardPasses 
776         RegisterMyPass(PassManagerBuilder::EP_OptimizerLast,
777 registerCDSPass);