Fix missed normal loads/stores
[c11llvm.git] / CDSPass.cpp
index 7546be0f0c779939772ca0e16779437e2d738548..a33738a96d26880395afc7beaf3b284e1bf1e705 100644 (file)
 #include "llvm/ADT/SmallString.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/CaptureTracking.h"
+#include "llvm/Analysis/LoopInfo.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/LegacyPassManager.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/AtomicOrdering.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Transforms/Scalar.h"
-#include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/EscapeEnumerator.h"
 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
 #include <vector>
 
 using namespace llvm;
 
+#define CDS_DEBUG
 #define DEBUG_TYPE "CDS"
 #include <llvm/IR/DebugLoc.h>
 
-Value *getPosition( Instruction * I, IRBuilder <> IRB, bool print = false)
+static inline Value *getPosition( Instruction * I, IRBuilder <> IRB, bool print = false)
 {
        const DebugLoc & debug_location = I->getDebugLoc ();
        std::string position_string;
@@ -59,27 +62,44 @@ Value *getPosition( Instruction * I, IRBuilder <> IRB, bool print = false)
        }
 
        if (print) {
-               errs() << position_string;
+               errs() << position_string << "\n";
        }
 
-       return IRB . CreateGlobalStringPtr (position_string);
+       return IRB.CreateGlobalStringPtr (position_string);
+}
+
+static inline bool checkSignature(Function * func, Value * args[]) {
+       FunctionType * FType = func->getFunctionType();
+       for (unsigned i = 0 ; i < FType->getNumParams(); i++) {
+               if (FType->getParamType(i) != args[i]->getType()) {
+#ifdef CDS_DEBUG
+                       errs() << "expects: " << *FType->getParamType(i)
+                                       << "\tbut receives: " << *args[i]->getType() << "\n";
+#endif
+                       return false;
+               }
+       }
+
+       return true;
 }
 
 STATISTIC(NumInstrumentedReads, "Number of instrumented reads");
 STATISTIC(NumInstrumentedWrites, "Number of instrumented writes");
+STATISTIC(NumOmittedReadsBeforeWrite,
+          "Number of reads ignored due to following writes");
 STATISTIC(NumAccessesWithBadSize, "Number of accesses with bad size");
 // STATISTIC(NumInstrumentedVtableWrites, "Number of vtable ptr writes");
 // STATISTIC(NumInstrumentedVtableReads, "Number of vtable ptr reads");
-
-STATISTIC(NumOmittedReadsBeforeWrite,
-          "Number of reads ignored due to following writes");
 STATISTIC(NumOmittedReadsFromConstantGlobals,
           "Number of reads from constant globals");
 STATISTIC(NumOmittedReadsFromVtable, "Number of vtable reads");
 STATISTIC(NumOmittedNonCaptured, "Number of accesses ignored due to capturing");
 
-Type * OrdTy;
+// static const char *const kCDSModuleCtorName = "cds.module_ctor";
+// static const char *const kCDSInitName = "cds_init";
 
+Type * OrdTy;
+Type * IntPtrTy;
 Type * Int8PtrTy;
 Type * Int16PtrTy;
 Type * Int32PtrTy;
@@ -89,12 +109,12 @@ Type * VoidTy;
 
 static const size_t kNumberOfAccessSizes = 4;
 
-int getAtomicOrderIndex(AtomicOrdering order){
+int getAtomicOrderIndex(AtomicOrdering order) {
        switch (order) {
                case AtomicOrdering::Monotonic: 
                        return (int)AtomicOrderingCABI::relaxed;
-               //  case AtomicOrdering::Consume:         // not specified yet
-               //    return AtomicOrderingCABI::consume;
+               //case AtomicOrdering::Consume:         // not specified yet
+               //      return AtomicOrderingCABI::consume;
                case AtomicOrdering::Acquire: 
                        return (int)AtomicOrderingCABI::acquire;
                case AtomicOrdering::Release: 
@@ -109,50 +129,112 @@ int getAtomicOrderIndex(AtomicOrdering order){
        }
 }
 
+AtomicOrderingCABI indexToAtomicOrder(int index) {
+       switch (index) {
+               case 0:
+                       return AtomicOrderingCABI::relaxed;
+               case 1:
+                       return AtomicOrderingCABI::consume;
+               case 2:
+                       return AtomicOrderingCABI::acquire;
+               case 3:
+                       return AtomicOrderingCABI::release;
+               case 4:
+                       return AtomicOrderingCABI::acq_rel;
+               case 5:
+                       return AtomicOrderingCABI::seq_cst;
+               default:
+                       errs() << "Bad Atomic index\n";
+                       return AtomicOrderingCABI::seq_cst;
+       }
+}
+
+/* According to atomic_base.h: __cmpexch_failure_order */
+int AtomicCasFailureOrderIndex(int index) {
+       AtomicOrderingCABI succ_order = indexToAtomicOrder(index);
+       AtomicOrderingCABI fail_order;
+       if (succ_order == AtomicOrderingCABI::acq_rel)
+               fail_order = AtomicOrderingCABI::acquire;
+       else if (succ_order == AtomicOrderingCABI::release) 
+               fail_order = AtomicOrderingCABI::relaxed;
+       else
+               fail_order = succ_order;
+
+       return (int) fail_order;
+}
+
+/* The original function checkSanitizerInterfaceFunction was defined
+ * in llvm/Transforms/Utils/ModuleUtils.h
+ */
+static Function * checkCDSPassInterfaceFunction(Constant *FuncOrBitcast) {
+       if (isa<Function>(FuncOrBitcast))
+               return cast<Function>(FuncOrBitcast);
+       FuncOrBitcast->print(errs());
+       errs() << "\n";
+       std::string Err;
+       raw_string_ostream Stream(Err);
+       Stream << "CDSPass interface function redefined: " << *FuncOrBitcast;
+       report_fatal_error(Err);
+}
+
 namespace {
        struct CDSPass : public FunctionPass {
-               static char ID;
                CDSPass() : FunctionPass(ID) {}
-               bool runOnFunction(Function &F) override; 
+               StringRef getPassName() const override;
+               bool runOnFunction(Function &F) override;
+               bool doInitialization(Module &M) override;
+               static char ID;
 
        private:
                void initializeCallbacks(Module &M);
                bool instrumentLoadOrStore(Instruction *I, const DataLayout &DL);
+               bool instrumentVolatile(Instruction *I, const DataLayout &DL);
+               bool instrumentMemIntrinsic(Instruction *I);
                bool isAtomicCall(Instruction *I);
                bool instrumentAtomic(Instruction *I, const DataLayout &DL);
                bool instrumentAtomicCall(CallInst *CI, const DataLayout &DL);
+               bool shouldInstrumentBeforeAtomics(Instruction *I);
                void chooseInstructionsToInstrument(SmallVectorImpl<Instruction *> &Local,
                                                                                        SmallVectorImpl<Instruction *> &All,
                                                                                        const DataLayout &DL);
                bool addrPointsToConstantData(Value *Addr);
                int getMemoryAccessFuncIndex(Value *Addr, const DataLayout &DL);
-
-               // Callbacks to run-time library are computed in doInitialization.
-               Constant * CDSFuncEntry;
-               Constant * CDSFuncExit;
-
-               Constant * CDSLoad[kNumberOfAccessSizes];
-               Constant * CDSStore[kNumberOfAccessSizes];
-               Constant * CDSAtomicInit[kNumberOfAccessSizes];
-               Constant * CDSAtomicLoad[kNumberOfAccessSizes];
-               Constant * CDSAtomicStore[kNumberOfAccessSizes];
-               Constant * CDSAtomicRMW[AtomicRMWInst::LAST_BINOP + 1][kNumberOfAccessSizes];
-               Constant * CDSAtomicCAS_V1[kNumberOfAccessSizes];
-               Constant * CDSAtomicCAS_V2[kNumberOfAccessSizes];
-               Constant * CDSAtomicThreadFence;
+               bool instrumentLoops(Function &F);
+
+               Function * CDSFuncEntry;
+               Function * CDSFuncExit;
+
+               Function * CDSLoad[kNumberOfAccessSizes];
+               Function * CDSStore[kNumberOfAccessSizes];
+               Function * CDSVolatileLoad[kNumberOfAccessSizes];
+               Function * CDSVolatileStore[kNumberOfAccessSizes];
+               Function * CDSAtomicInit[kNumberOfAccessSizes];
+               Function * CDSAtomicLoad[kNumberOfAccessSizes];
+               Function * CDSAtomicStore[kNumberOfAccessSizes];
+               Function * CDSAtomicRMW[AtomicRMWInst::LAST_BINOP + 1][kNumberOfAccessSizes];
+               Function * CDSAtomicCAS_V1[kNumberOfAccessSizes];
+               Function * CDSAtomicCAS_V2[kNumberOfAccessSizes];
+               Function * CDSAtomicThreadFence;
+               Function * MemmoveFn, * MemcpyFn, * MemsetFn;
+               // Function * CDSCtorFunction;
+
+               std::vector<StringRef> AtomicFuncNames;
+               std::vector<StringRef> PartialAtomicFuncNames;
        };
 }
 
-static bool isVtableAccess(Instruction *I) {
-       if (MDNode *Tag = I->getMetadata(LLVMContext::MD_tbaa))
-               return Tag->isTBAAVtableAccess();
-       return false;
+StringRef CDSPass::getPassName() const {
+       return "CDSPass";
 }
 
 void CDSPass::initializeCallbacks(Module &M) {
        LLVMContext &Ctx = M.getContext();
+       AttributeList Attr;
+       Attr = Attr.addAttribute(Ctx, AttributeList::FunctionIndex,
+                       Attribute::NoUnwind);
 
        Type * Int1Ty = Type::getInt1Ty(Ctx);
+       Type * Int32Ty = Type::getInt32Ty(Ctx);
        OrdTy = Type::getInt32Ty(Ctx);
 
        Int8PtrTy  = Type::getInt8PtrTy(Ctx);
@@ -162,6 +244,13 @@ void CDSPass::initializeCallbacks(Module &M) {
 
        VoidTy = Type::getVoidTy(Ctx);
 
+       CDSFuncEntry = checkCDSPassInterfaceFunction(
+                                               M.getOrInsertFunction("cds_func_entry", 
+                                               Attr, VoidTy, Int8PtrTy));
+       CDSFuncExit = checkCDSPassInterfaceFunction(
+                                               M.getOrInsertFunction("cds_func_exit", 
+                                               Attr, VoidTy, Int8PtrTy));
+
        // Get the function to call from our untime library.
        for (unsigned i = 0; i < kNumberOfAccessSizes; i++) {
                const unsigned ByteSize = 1U << i;
@@ -177,18 +266,31 @@ void CDSPass::initializeCallbacks(Module &M) {
                // void cds_atomic_store8 (void * obj, int atomic_index, uint8_t val)
                SmallString<32> LoadName("cds_load" + BitSizeStr);
                SmallString<32> StoreName("cds_store" + BitSizeStr);
+               SmallString<32> VolatileLoadName("cds_volatile_load" + BitSizeStr);
+               SmallString<32> VolatileStoreName("cds_volatile_store" + BitSizeStr);
                SmallString<32> AtomicInitName("cds_atomic_init" + BitSizeStr);
                SmallString<32> AtomicLoadName("cds_atomic_load" + BitSizeStr);
                SmallString<32> AtomicStoreName("cds_atomic_store" + BitSizeStr);
 
-               CDSLoad[i]  = M.getOrInsertFunction(LoadName, VoidTy, PtrTy);
-               CDSStore[i] = M.getOrInsertFunction(StoreName, VoidTy, PtrTy);
-               CDSAtomicInit[i] = M.getOrInsertFunction(AtomicInitName, 
-                                                               VoidTy, PtrTy, Ty, Int8PtrTy);
-               CDSAtomicLoad[i]  = M.getOrInsertFunction(AtomicLoadName, 
-                                                               Ty, PtrTy, OrdTy, Int8PtrTy);
-               CDSAtomicStore[i] = M.getOrInsertFunction(AtomicStoreName, 
-                                                               VoidTy, PtrTy, Ty, OrdTy, Int8PtrTy);
+               CDSLoad[i]  = checkCDSPassInterfaceFunction(
+                                                       M.getOrInsertFunction(LoadName, Attr, VoidTy, Int8PtrTy));
+               CDSStore[i] = checkCDSPassInterfaceFunction(
+                                                       M.getOrInsertFunction(StoreName, Attr, VoidTy, Int8PtrTy));
+               CDSVolatileLoad[i]  = checkCDSPassInterfaceFunction(
+                                                               M.getOrInsertFunction(VolatileLoadName,
+                                                               Attr, Ty, PtrTy, Int8PtrTy));
+               CDSVolatileStore[i] = checkCDSPassInterfaceFunction(
+                                                               M.getOrInsertFunction(VolatileStoreName, 
+                                                               Attr, VoidTy, PtrTy, Ty, Int8PtrTy));
+               CDSAtomicInit[i] = checkCDSPassInterfaceFunction(
+                                                       M.getOrInsertFunction(AtomicInitName, 
+                                                       Attr, VoidTy, PtrTy, Ty, Int8PtrTy));
+               CDSAtomicLoad[i]  = checkCDSPassInterfaceFunction(
+                                                               M.getOrInsertFunction(AtomicLoadName, 
+                                                               Attr, Ty, PtrTy, OrdTy, Int8PtrTy));
+               CDSAtomicStore[i] = checkCDSPassInterfaceFunction(
+                                                               M.getOrInsertFunction(AtomicStoreName, 
+                                                               Attr, VoidTy, PtrTy, Ty, OrdTy, Int8PtrTy));
 
                for (int op = AtomicRMWInst::FIRST_BINOP; 
                        op <= AtomicRMWInst::LAST_BINOP; ++op) {
@@ -211,23 +313,71 @@ void CDSPass::initializeCallbacks(Module &M) {
                                continue;
 
                        SmallString<32> AtomicRMWName("cds_atomic" + NamePart + BitSizeStr);
-                       CDSAtomicRMW[op][i] = M.getOrInsertFunction(AtomicRMWName, 
-                                                                               Ty, PtrTy, Ty, OrdTy, Int8PtrTy);
+                       CDSAtomicRMW[op][i] = checkCDSPassInterfaceFunction(
+                                                                       M.getOrInsertFunction(AtomicRMWName, 
+                                                                       Attr, Ty, PtrTy, Ty, OrdTy, Int8PtrTy));
                }
 
                // only supportes strong version
                SmallString<32> AtomicCASName_V1("cds_atomic_compare_exchange" + BitSizeStr + "_v1");
                SmallString<32> AtomicCASName_V2("cds_atomic_compare_exchange" + BitSizeStr + "_v2");
-               CDSAtomicCAS_V1[i] = M.getOrInsertFunction(AtomicCASName_V1, 
-                                                               Ty, PtrTy, Ty, Ty, OrdTy, OrdTy, Int8PtrTy);
-               CDSAtomicCAS_V2[i] = M.getOrInsertFunction(AtomicCASName_V2, 
-                                                               Int1Ty, PtrTy, PtrTy, Ty, OrdTy, OrdTy, Int8PtrTy);
+               CDSAtomicCAS_V1[i] = checkCDSPassInterfaceFunction(
+                                                               M.getOrInsertFunction(AtomicCASName_V1, 
+                                                               Attr, Ty, PtrTy, Ty, Ty, OrdTy, OrdTy, Int8PtrTy));
+               CDSAtomicCAS_V2[i] = checkCDSPassInterfaceFunction(
+                                                               M.getOrInsertFunction(AtomicCASName_V2, 
+                                                               Attr, Int1Ty, PtrTy, PtrTy, Ty, OrdTy, OrdTy, Int8PtrTy));
        }
 
-       CDSAtomicThreadFence = M.getOrInsertFunction("cds_atomic_thread_fence", 
-                                                                                                       VoidTy, OrdTy, Int8PtrTy);
+       CDSAtomicThreadFence = checkCDSPassInterfaceFunction(
+                       M.getOrInsertFunction("cds_atomic_thread_fence", Attr, VoidTy, OrdTy, Int8PtrTy));
+
+       MemmoveFn = checkCDSPassInterfaceFunction(
+                                       M.getOrInsertFunction("memmove", Attr, Int8PtrTy, Int8PtrTy,
+                                       Int8PtrTy, IntPtrTy));
+       MemcpyFn = checkCDSPassInterfaceFunction(
+                                       M.getOrInsertFunction("memcpy", Attr, Int8PtrTy, Int8PtrTy,
+                                       Int8PtrTy, IntPtrTy));
+       MemsetFn = checkCDSPassInterfaceFunction(
+                                       M.getOrInsertFunction("memset", Attr, Int8PtrTy, Int8PtrTy,
+                                       Int32Ty, IntPtrTy));
 }
 
+bool CDSPass::doInitialization(Module &M) {
+       const DataLayout &DL = M.getDataLayout();
+       IntPtrTy = DL.getIntPtrType(M.getContext());
+       
+       // createSanitizerCtorAndInitFunctions is defined in "llvm/Transforms/Utils/ModuleUtils.h"
+       // We do not support it yet
+       /*
+       std::tie(CDSCtorFunction, std::ignore) = createSanitizerCtorAndInitFunctions(
+                       M, kCDSModuleCtorName, kCDSInitName, {}, {});
+
+       appendToGlobalCtors(M, CDSCtorFunction, 0);
+       */
+
+       AtomicFuncNames = 
+       {
+               "atomic_init", "atomic_load", "atomic_store", 
+               "atomic_fetch_", "atomic_exchange", "atomic_compare_exchange_"
+       };
+
+       PartialAtomicFuncNames = 
+       { 
+               "load", "store", "fetch", "exchange", "compare_exchange_" 
+       };
+
+       return true;
+}
+
+static bool isVtableAccess(Instruction *I) {
+       if (MDNode *Tag = I->getMetadata(LLVMContext::MD_tbaa))
+               return Tag->isTBAAVtableAccess();
+       return false;
+}
+
+// Do not instrument known races/"benign races" that come from compiler
+// instrumentatin. The user has no way of suppressing them.
 static bool shouldInstrumentReadWriteFromAddress(const Module *M, Value *Addr) {
        // Peel off GEPs and BitCasts.
        Addr = Addr->stripInBoundsOffsets();
@@ -280,52 +430,27 @@ bool CDSPass::addrPointsToConstantData(Value *Addr) {
        return false;
 }
 
-bool CDSPass::runOnFunction(Function &F) {
-       if (F.getName() == "main") {
-               F.setName("user_main");
-               errs() << "main replaced by user_main\n";
-       }
-
-       if (true) {
-               initializeCallbacks( *F.getParent() );
-
-               SmallVector<Instruction*, 8> AllLoadsAndStores;
-               SmallVector<Instruction*, 8> LocalLoadsAndStores;
-               SmallVector<Instruction*, 8> AtomicAccesses;
-
-               std::vector<Instruction *> worklist;
-
-               bool Res = false;
-               const DataLayout &DL = F.getParent()->getDataLayout();
-
-               // errs() << "--- " << F.getName() << "---\n";
-
-               for (auto &B : F) {
-                       for (auto &I : B) {
-                               if ( (&I)->isAtomic() || isAtomicCall(&I) ) {
-                                       AtomicAccesses.push_back(&I);
-                               } else if (isa<LoadInst>(I) || isa<StoreInst>(I)) {
-                                       LocalLoadsAndStores.push_back(&I);
-                               } else if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
-                                       // not implemented yet
-                               }
-                       }
-
-                       chooseInstructionsToInstrument(LocalLoadsAndStores, AllLoadsAndStores, DL);
-               }
-
-               for (auto Inst : AllLoadsAndStores) {
-                       // Res |= instrumentLoadOrStore(Inst, DL);
-                       // errs() << "load and store are replaced\n";
-               }
-
-               for (auto Inst : AtomicAccesses) {
-                       Res |= instrumentAtomic(Inst, DL);
-               }
-
-               if (F.getName() == "user_main") {
-                       // F.dump();
-               }
+bool CDSPass::shouldInstrumentBeforeAtomics(Instruction * Inst) {
+       if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
+               AtomicOrdering ordering = LI->getOrdering();
+               if ( isAtLeastOrStrongerThan(ordering, AtomicOrdering::Acquire) )
+                       return true;
+       } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
+               AtomicOrdering ordering = SI->getOrdering();
+               if ( isAtLeastOrStrongerThan(ordering, AtomicOrdering::Acquire) )
+                       return true;
+       } else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(Inst)) {
+               AtomicOrdering ordering = RMWI->getOrdering();
+               if ( isAtLeastOrStrongerThan(ordering, AtomicOrdering::Acquire) )
+                       return true;
+       } else if (AtomicCmpXchgInst *CASI = dyn_cast<AtomicCmpXchgInst>(Inst)) {
+               AtomicOrdering ordering = CASI->getSuccessOrdering();
+               if ( isAtLeastOrStrongerThan(ordering, AtomicOrdering::Acquire) )
+                       return true;
+       } else if (FenceInst *FI = dyn_cast<FenceInst>(Inst)) {
+               AtomicOrdering ordering = FI->getOrdering();
+               if ( isAtLeastOrStrongerThan(ordering, AtomicOrdering::Acquire) )
+                       return true;
        }
 
        return false;
@@ -373,6 +498,110 @@ void CDSPass::chooseInstructionsToInstrument(
        Local.clear();
 }
 
+/* Not implemented
+void CDSPass::InsertRuntimeIgnores(Function &F) {
+       IRBuilder<> IRB(F.getEntryBlock().getFirstNonPHI());
+       IRB.CreateCall(CDSIgnoreBegin);
+       EscapeEnumerator EE(F, "cds_ignore_cleanup", ClHandleCxxExceptions);
+       while (IRBuilder<> *AtExit = EE.Next()) {
+               AtExit->CreateCall(CDSIgnoreEnd);
+       }
+}*/
+
+bool CDSPass::runOnFunction(Function &F) {
+       initializeCallbacks( *F.getParent() );
+       SmallVector<Instruction*, 8> AllLoadsAndStores;
+       SmallVector<Instruction*, 8> LocalLoadsAndStores;
+       SmallVector<Instruction*, 8> VolatileLoadsAndStores;
+       SmallVector<Instruction*, 8> AtomicAccesses;
+       SmallVector<Instruction*, 8> MemIntrinCalls;
+
+       bool Res = false;
+       bool HasAtomic = false;
+       bool HasVolatile = false;
+       const DataLayout &DL = F.getParent()->getDataLayout();
+
+       // instrumentLoops(F);
+
+       for (auto &BB : F) {
+               for (auto &Inst : BB) {
+                       if ( (&Inst)->isAtomic() ) {
+                               AtomicAccesses.push_back(&Inst);
+                               HasAtomic = true;
+
+                               if (shouldInstrumentBeforeAtomics(&Inst)) {
+                                       chooseInstructionsToInstrument(LocalLoadsAndStores, AllLoadsAndStores,
+                                               DL);
+                               }
+                       } else if (isAtomicCall(&Inst) ) {
+                               AtomicAccesses.push_back(&Inst);
+                               HasAtomic = true;
+                               chooseInstructionsToInstrument(LocalLoadsAndStores, AllLoadsAndStores,
+                                       DL);
+                       } else if (isa<LoadInst>(Inst) || isa<StoreInst>(Inst)) {
+                               LoadInst *LI = dyn_cast<LoadInst>(&Inst);
+                               StoreInst *SI = dyn_cast<StoreInst>(&Inst);
+                               bool isVolatile = ( LI ? LI->isVolatile() : SI->isVolatile() );
+
+                               if (isVolatile) {
+                                       VolatileLoadsAndStores.push_back(&Inst);
+                                       HasVolatile = true;
+                               } else
+                                       LocalLoadsAndStores.push_back(&Inst);
+                       } else if (isa<CallInst>(Inst) || isa<InvokeInst>(Inst)) {
+                               if (isa<MemIntrinsic>(Inst))
+                                       MemIntrinCalls.push_back(&Inst);
+
+                               /*if (CallInst *CI = dyn_cast<CallInst>(&Inst))
+                                       maybeMarkSanitizerLibraryCallNoBuiltin(CI, TLI);
+                               */
+
+                               chooseInstructionsToInstrument(LocalLoadsAndStores, AllLoadsAndStores,
+                                       DL);
+                       }
+               }
+
+               chooseInstructionsToInstrument(LocalLoadsAndStores, AllLoadsAndStores, DL);
+       }
+
+       for (auto Inst : AllLoadsAndStores) {
+               Res |= instrumentLoadOrStore(Inst, DL);
+       }
+
+       for (auto Inst : VolatileLoadsAndStores) {
+               Res |= instrumentVolatile(Inst, DL);
+       }
+
+       for (auto Inst : AtomicAccesses) {
+               Res |= instrumentAtomic(Inst, DL);
+       }
+
+       for (auto Inst : MemIntrinCalls) {
+               Res |= instrumentMemIntrinsic(Inst);
+       }
+
+       // Instrument function entry and exit for functions containing atomics or volatiles
+       if (Res && ( HasAtomic || HasVolatile) ) {
+               IRBuilder<> IRB(F.getEntryBlock().getFirstNonPHI());
+               /* Unused for now
+               Value *ReturnAddress = IRB.CreateCall(
+                       Intrinsic::getDeclaration(F.getParent(), Intrinsic::returnaddress),
+                       IRB.getInt32(0));
+               */
+
+               Value * FuncName = IRB.CreateGlobalStringPtr(F.getName());
+               IRB.CreateCall(CDSFuncEntry, FuncName);
+
+               EscapeEnumerator EE(F, "cds_cleanup", true);
+               while (IRBuilder<> *AtExit = EE.Next()) {
+                 AtExit->CreateCall(CDSFuncExit, FuncName);
+               }
+
+               Res = true;
+       }
+
+       return false;
+}
 
 bool CDSPass::instrumentLoadOrStore(Instruction *I,
                                                                        const DataLayout &DL) {
@@ -386,97 +615,161 @@ bool CDSPass::instrumentLoadOrStore(Instruction *I,
        // As such they cannot have regular uses like an instrumentation function and
        // it makes no sense to track them as memory.
        if (Addr->isSwiftError())
-       return false;
+               return false;
 
        int Idx = getMemoryAccessFuncIndex(Addr, DL);
+       if (Idx < 0)
+               return false;
 
-//  not supported by CDS yet
-/*  if (IsWrite && isVtableAccess(I)) {
-    LLVM_DEBUG(dbgs() << "  VPTR : " << *I << "\n");
-    Value *StoredValue = cast<StoreInst>(I)->getValueOperand();
-    // StoredValue may be a vector type if we are storing several vptrs at once.
-    // In this case, just take the first element of the vector since this is
-    // enough to find vptr races.
-    if (isa<VectorType>(StoredValue->getType()))
-      StoredValue = IRB.CreateExtractElement(
-          StoredValue, ConstantInt::get(IRB.getInt32Ty(), 0));
-    if (StoredValue->getType()->isIntegerTy())
-      StoredValue = IRB.CreateIntToPtr(StoredValue, IRB.getInt8PtrTy());
-    // Call TsanVptrUpdate.
-    IRB.CreateCall(TsanVptrUpdate,
-                   {IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()),
-                    IRB.CreatePointerCast(StoredValue, IRB.getInt8PtrTy())});
-    NumInstrumentedVtableWrites++;
-    return true;
-  }
-
-  if (!IsWrite && isVtableAccess(I)) {
-    IRB.CreateCall(TsanVptrLoad,
-                   IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()));
-    NumInstrumentedVtableReads++;
-    return true;
-  }
-*/
+       if (IsWrite && isVtableAccess(I)) {
+               /* TODO
+               LLVM_DEBUG(dbgs() << "  VPTR : " << *I << "\n");
+               Value *StoredValue = cast<StoreInst>(I)->getValueOperand();
+               // StoredValue may be a vector type if we are storing several vptrs at once.
+               // In this case, just take the first element of the vector since this is
+               // enough to find vptr races.
+               if (isa<VectorType>(StoredValue->getType()))
+                       StoredValue = IRB.CreateExtractElement(
+                                       StoredValue, ConstantInt::get(IRB.getInt32Ty(), 0));
+               if (StoredValue->getType()->isIntegerTy())
+                       StoredValue = IRB.CreateIntToPtr(StoredValue, IRB.getInt8PtrTy());
+               // Call TsanVptrUpdate.
+               IRB.CreateCall(TsanVptrUpdate,
+                                               {IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()),
+                                                       IRB.CreatePointerCast(StoredValue, IRB.getInt8PtrTy())});
+               NumInstrumentedVtableWrites++;
+               */
+               return true;
+       }
 
+       if (!IsWrite && isVtableAccess(I)) {
+               /* TODO
+               IRB.CreateCall(TsanVptrLoad,
+                                                IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()));
+               NumInstrumentedVtableReads++;
+               */
+               return true;
+       }
+
+       // TODO: unaligned reads and writes
        Value *OnAccessFunc = nullptr;
        OnAccessFunc = IsWrite ? CDSStore[Idx] : CDSLoad[Idx];
+       IRB.CreateCall(OnAccessFunc, IRB.CreatePointerCast(Addr, IRB.getInt8PtrTy()));
+       if (IsWrite) NumInstrumentedWrites++;
+       else         NumInstrumentedReads++;
+       return true;
+}
 
-       Type *ArgType = IRB.CreatePointerCast(Addr, Addr->getType())->getType();
+bool CDSPass::instrumentVolatile(Instruction * I, const DataLayout &DL) {
+       IRBuilder<> IRB(I);
+       Value *position = getPosition(I, IRB);
+
+       if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
+               Value *Addr = LI->getPointerOperand();
+               int Idx=getMemoryAccessFuncIndex(Addr, DL);
+               if (Idx < 0)
+                       return false;
+               const unsigned ByteSize = 1U << Idx;
+               const unsigned BitSize = ByteSize * 8;
+               Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
+               Type *PtrTy = Ty->getPointerTo();
+               Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy), position};
 
-       if ( ArgType != Int8PtrTy && ArgType != Int16PtrTy && 
-                       ArgType != Int32PtrTy && ArgType != Int64PtrTy ) {
-               //errs() << "A load or store of type ";
-               //errs() << *ArgType;
-               //errs() << " is passed in\n";
-               return false;   // if other types of load or stores are passed in
+               Type *OrigTy = cast<PointerType>(Addr->getType())->getElementType();
+               Value *C = IRB.CreateCall(CDSVolatileLoad[Idx], Args);
+               Value *Cast = IRB.CreateBitOrPointerCast(C, OrigTy);
+               I->replaceAllUsesWith(Cast);
+       } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
+               Value *Addr = SI->getPointerOperand();
+               int Idx=getMemoryAccessFuncIndex(Addr, DL);
+               if (Idx < 0)
+                       return false;
+               const unsigned ByteSize = 1U << Idx;
+               const unsigned BitSize = ByteSize * 8;
+               Type *Ty = Type::getIntNTy(IRB.getContext(), BitSize);
+               Type *PtrTy = Ty->getPointerTo();
+               Value *Args[] = {IRB.CreatePointerCast(Addr, PtrTy),
+                                         IRB.CreateBitOrPointerCast(SI->getValueOperand(), Ty),
+                                         position};
+               CallInst *C = CallInst::Create(CDSVolatileStore[Idx], Args);
+               ReplaceInstWithInst(I, C);
+       } else {
+               return false;
        }
-       IRB.CreateCall(OnAccessFunc, IRB.CreatePointerCast(Addr, Addr->getType()));
-       if (IsWrite) NumInstrumentedWrites++;
-       else         NumInstrumentedReads++;
+
        return true;
 }
 
+bool CDSPass::instrumentMemIntrinsic(Instruction *I) {
+       IRBuilder<> IRB(I);
+       if (MemSetInst *M = dyn_cast<MemSetInst>(I)) {
+               IRB.CreateCall(
+                       MemsetFn,
+                       {IRB.CreatePointerCast(M->getArgOperand(0), IRB.getInt8PtrTy()),
+                        IRB.CreateIntCast(M->getArgOperand(1), IRB.getInt32Ty(), false),
+                        IRB.CreateIntCast(M->getArgOperand(2), IntPtrTy, false)});
+               I->eraseFromParent();
+       } else if (MemTransferInst *M = dyn_cast<MemTransferInst>(I)) {
+               IRB.CreateCall(
+                       isa<MemCpyInst>(M) ? MemcpyFn : MemmoveFn,
+                       {IRB.CreatePointerCast(M->getArgOperand(0), IRB.getInt8PtrTy()),
+                        IRB.CreatePointerCast(M->getArgOperand(1), IRB.getInt8PtrTy()),
+                        IRB.CreateIntCast(M->getArgOperand(2), IntPtrTy, false)});
+               I->eraseFromParent();
+       }
+       return false;
+}
+
 bool CDSPass::instrumentAtomic(Instruction * I, const DataLayout &DL) {
        IRBuilder<> IRB(I);
-       // LLVMContext &Ctx = IRB.getContext();
 
        if (auto *CI = dyn_cast<CallInst>(I)) {
                return instrumentAtomicCall(CI, DL);
        }
 
        Value *position = getPosition(I, IRB);
-
        if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
                Value *Addr = LI->getPointerOperand();
                int Idx=getMemoryAccessFuncIndex(Addr, DL);
+               if (Idx < 0)
+                       return false;
+
                int atomic_order_index = getAtomicOrderIndex(LI->getOrdering());
                Value *order = ConstantInt::get(OrdTy, atomic_order_index);
-               Value *args[] = {Addr, order, position};
-               Instruction* funcInst=CallInst::Create(CDSAtomicLoad[Idx], args);
+               Value *Args[] = {Addr, order, position};
+               Instruction* funcInst = CallInst::Create(CDSAtomicLoad[Idx], Args);
                ReplaceInstWithInst(LI, funcInst);
        } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
                Value *Addr = SI->getPointerOperand();
                int Idx=getMemoryAccessFuncIndex(Addr, DL);
+               if (Idx < 0)
+                       return false;
+
                int atomic_order_index = getAtomicOrderIndex(SI->getOrdering());
                Value *val = SI->getValueOperand();
                Value *order = ConstantInt::get(OrdTy, atomic_order_index);
-               Value *args[] = {Addr, val, order, position};
-               Instruction* funcInst=CallInst::Create(CDSAtomicStore[Idx], args);
+               Value *Args[] = {Addr, val, order, position};
+               Instruction* funcInst = CallInst::Create(CDSAtomicStore[Idx], Args);
                ReplaceInstWithInst(SI, funcInst);
        } else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(I)) {
                Value *Addr = RMWI->getPointerOperand();
                int Idx=getMemoryAccessFuncIndex(Addr, DL);
+               if (Idx < 0)
+                       return false;
+
                int atomic_order_index = getAtomicOrderIndex(RMWI->getOrdering());
                Value *val = RMWI->getValOperand();
                Value *order = ConstantInt::get(OrdTy, atomic_order_index);
-               Value *args[] = {Addr, val, order, position};
-               Instruction* funcInst = CallInst::Create(CDSAtomicRMW[RMWI->getOperation()][Idx], args);
+               Value *Args[] = {Addr, val, order, position};
+               Instruction* funcInst = CallInst::Create(CDSAtomicRMW[RMWI->getOperation()][Idx], Args);
                ReplaceInstWithInst(RMWI, funcInst);
        } else if (AtomicCmpXchgInst *CASI = dyn_cast<AtomicCmpXchgInst>(I)) {
                IRBuilder<> IRB(CASI);
 
                Value *Addr = CASI->getPointerOperand();
                int Idx=getMemoryAccessFuncIndex(Addr, DL);
+               if (Idx < 0)
+                       return false;
 
                const unsigned ByteSize = 1U << Idx;
                const unsigned BitSize = ByteSize * 8;
@@ -530,11 +823,17 @@ bool CDSPass::isAtomicCall(Instruction *I) {
                        return false;
 
                StringRef funName = fun->getName();
-               // todo: come up with better rules for function name checking
-               if ( funName.contains("atomic_") ) {
-                       return true;
-               } else if (funName.contains("atomic") ) {
-                       return true;
+
+               // TODO: come up with better rules for function name checking
+               for (StringRef name : AtomicFuncNames) {
+                       if ( funName.contains(name) ) 
+                               return true;
+               }
+               
+               for (StringRef PartialName : PartialAtomicFuncNames) {
+                       if (funName.contains(PartialName) && 
+                                       funName.contains("atomic") )
+                               return true;
                }
        }
 
@@ -559,6 +858,7 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
 
        // the pointer to the address is always the first argument
        Value *OrigPtr = parameters[0];
+
        int Idx = getMemoryAccessFuncIndex(OrigPtr, DL);
        if (Idx < 0)
                return false;
@@ -570,13 +870,22 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
 
        // atomic_init; args = {obj, order}
        if (funName.contains("atomic_init")) {
+               Value *OrigVal = parameters[1];
+
                Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
-               Value *val = IRB.CreateBitOrPointerCast(parameters[1], Ty);
+               Value *val;
+               if (OrigVal->getType()->isPtrOrPtrVectorTy())
+                       val = IRB.CreatePointerCast(OrigVal, Ty);
+               else
+                       val = IRB.CreateIntCast(OrigVal, Ty, true);
+
                Value *args[] = {ptr, val, position};
 
+               if (!checkSignature(CDSAtomicInit[Idx], args))
+                       return false;
+
                Instruction* funcInst = CallInst::Create(CDSAtomicInit[Idx], args);
                ReplaceInstWithInst(CI, funcInst);
-
                return true;
        }
 
@@ -592,19 +901,29 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
                        order = ConstantInt::get(OrdTy, 
                                                        (int) AtomicOrderingCABI::seq_cst);
                Value *args[] = {ptr, order, position};
-               
+
+               if (!checkSignature(CDSAtomicLoad[Idx], args))
+                       return false;
+
                Instruction* funcInst = CallInst::Create(CDSAtomicLoad[Idx], args);
                ReplaceInstWithInst(CI, funcInst);
 
                return true;
        } else if (funName.contains("atomic") && 
-                                       funName.contains("load")) {
+                                       funName.contains("load") ) {
                // does this version of call always have an atomic order as an argument?
                Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
                Value *order = IRB.CreateBitOrPointerCast(parameters[1], OrdTy);
                Value *args[] = {ptr, order, position};
 
-               //Instruction* funcInst=CallInst::Create(CDSAtomicLoad[Idx], args);
+               // Without this check, gdax does not compile :(
+               if (!CI->getType()->isPointerTy()) {
+                       return false;   
+               } 
+
+               if (!checkSignature(CDSAtomicLoad[Idx], args))
+                       return false;
+
                CallInst *funcInst = IRB.CreateCall(CDSAtomicLoad[Idx], args);
                Value *RetVal = IRB.CreateIntToPtr(funcInst, CI->getType());
 
@@ -628,21 +947,35 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
                        order = ConstantInt::get(OrdTy, 
                                                        (int) AtomicOrderingCABI::seq_cst);
                Value *args[] = {ptr, val, order, position};
-               
+
+               if (!checkSignature(CDSAtomicStore[Idx], args))
+                       return false;
+
                Instruction* funcInst = CallInst::Create(CDSAtomicStore[Idx], args);
                ReplaceInstWithInst(CI, funcInst);
 
                return true;
        } else if (funName.contains("atomic") && 
-                                       funName.contains("EEEE5store")) {
-               // does this version of call always have an atomic order as an argument?
-               Value *OrigVal = parameters[1];
+                                       funName.contains("store") ) {
+               // Does this version of call always have an atomic order as an argument?
+               if (parameters.size() < 3)
+                       return false;
 
+               Value *OrigVal = parameters[1];
                Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
-               Value *val = IRB.CreatePointerCast(OrigVal, Ty);
-               Value *order = IRB.CreateBitOrPointerCast(parameters[1], OrdTy);
+
+               Value *val;
+               if (OrigVal->getType()->isPtrOrPtrVectorTy())
+                       val = IRB.CreatePointerCast(OrigVal, Ty);
+               else
+                       val = IRB.CreateIntCast(OrigVal, Ty, true);
+
+               Value *order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
                Value *args[] = {ptr, val, order, position};
 
+               if (!checkSignature(CDSAtomicStore[Idx], args))
+                       return false;
+
                Instruction* funcInst = CallInst::Create(CDSAtomicStore[Idx], args);
                ReplaceInstWithInst(CI, funcInst);
 
@@ -651,7 +984,8 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
 
        // atomic_fetch_*; args = {obj, val, order}
        if (funName.contains("atomic_fetch_") || 
-                       funName.contains("atomic_exchange") ) {
+               funName.contains("atomic_exchange")) {
+
                bool isExplicit = funName.contains("_explicit");
                Value *OrigVal = parameters[1];
 
@@ -669,12 +1003,17 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
                else if ( funName.contains("atomic_exchange") )
                        op = AtomicRMWInst::Xchg;
                else {
-                       errs() << "Unknown atomic read modify write operation\n";
+                       errs() << "Unknown atomic read-modify-write operation\n";
                        return false;
                }
 
                Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
-               Value *val = IRB.CreatePointerCast(OrigVal, Ty);
+               Value *val;
+               if (OrigVal->getType()->isPtrOrPtrVectorTy())
+                       val = IRB.CreatePointerCast(OrigVal, Ty);
+               else
+                       val = IRB.CreateIntCast(OrigVal, Ty, true);
+
                Value *order;
                if (isExplicit)
                        order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
@@ -682,20 +1021,57 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
                        order = ConstantInt::get(OrdTy, 
                                                        (int) AtomicOrderingCABI::seq_cst);
                Value *args[] = {ptr, val, order, position};
-               
+
+               if (!checkSignature(CDSAtomicRMW[op][Idx], args))
+                       return false;
+
                Instruction* funcInst = CallInst::Create(CDSAtomicRMW[op][Idx], args);
                ReplaceInstWithInst(CI, funcInst);
 
                return true;
        } else if (funName.contains("fetch")) {
-               errs() << "atomic exchange captured. Not implemented yet. ";
+               errs() << "atomic fetch captured. Not implemented yet. ";
                errs() << "See source file :";
                getPosition(CI, IRB, true);
+               return false;
        } else if (funName.contains("exchange") &&
                        !funName.contains("compare_exchange") ) {
-               errs() << "atomic exchange captured. Not implemented yet. ";
-               errs() << "See source file :";
-               getPosition(CI, IRB, true);
+               if (CI->getType()->isPointerTy()) {
+                       /**
+                        * TODO: instrument the following case
+                        * mcs-lock.h
+                        * std::atomic<struct T *> m_tail;
+                        * struct T * me;
+                        * struct T * pred = m_tail.exchange(me, memory_order_*);
+                        */
+                       errs() << "atomic exchange captured. Not implemented yet. ";
+                       errs() << "See source file :";
+                       getPosition(CI, IRB, true);
+
+                       return false;
+               }
+
+               Value *OrigVal = parameters[1];
+
+               Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
+               Value *val;
+               if (OrigVal->getType()->isPtrOrPtrVectorTy())
+                       val = IRB.CreatePointerCast(OrigVal, Ty);
+               else
+                       val = IRB.CreateIntCast(OrigVal, Ty, true);
+
+               Value *order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
+               Value *args[] = {ptr, val, order, position};
+
+               int op = AtomicRMWInst::Xchg;
+
+               if (!checkSignature(CDSAtomicRMW[op][Idx], args))
+                       return false;
+
+               Instruction* funcInst = CallInst::Create(CDSAtomicRMW[op][Idx], args);
+               ReplaceInstWithInst(CI, funcInst);
+
+               return true;
        }
 
        /* atomic_compare_exchange_*; 
@@ -711,7 +1087,18 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
                Value *order_succ, *order_fail;
                if (isExplicit) {
                        order_succ = IRB.CreateBitOrPointerCast(parameters[3], OrdTy);
-                       order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
+
+                       if (parameters.size() > 4) {
+                               order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
+                       } else {
+                               /* The failure order is not provided */
+                               order_fail = order_succ;
+                               ConstantInt * order_succ_cast = dyn_cast<ConstantInt>(order_succ);
+                               int index = order_succ_cast->getSExtValue();
+
+                               order_fail = ConstantInt::get(OrdTy,
+                                                               AtomicCasFailureOrderIndex(index));
+                       }
                } else  {
                        order_succ = ConstantInt::get(OrdTy, 
                                                        (int) AtomicOrderingCABI::seq_cst);
@@ -721,12 +1108,15 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
 
                Value *args[] = {Addr, CmpOperand, NewOperand, 
                                                        order_succ, order_fail, position};
-               
+
+               if (!checkSignature(CDSAtomicCAS_V2[Idx], args))
+                       return false;
+
                Instruction* funcInst = CallInst::Create(CDSAtomicCAS_V2[Idx], args);
                ReplaceInstWithInst(CI, funcInst);
 
                return true;
-       } else if ( funName.contains("compare_exchange_strong") || 
+       } else if ( funName.contains("compare_exchange_strong") ||
                                funName.contains("compare_exchange_weak") ) {
                Value *Addr = IRB.CreatePointerCast(OrigPtr, PtrTy);
                Value *CmpOperand = IRB.CreatePointerCast(parameters[1], PtrTy);
@@ -734,10 +1124,25 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
 
                Value *order_succ, *order_fail;
                order_succ = IRB.CreateBitOrPointerCast(parameters[3], OrdTy);
-               order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
+
+               if (parameters.size() > 4) {
+                       order_fail = IRB.CreateBitOrPointerCast(parameters[4], OrdTy);
+               } else {
+                       /* The failure order is not provided */
+                       order_fail = order_succ;
+                       ConstantInt * order_succ_cast = dyn_cast<ConstantInt>(order_succ);
+                       int index = order_succ_cast->getSExtValue();
+
+                       order_fail = ConstantInt::get(OrdTy,
+                                                       AtomicCasFailureOrderIndex(index));
+               }
 
                Value *args[] = {Addr, CmpOperand, NewOperand, 
                                                        order_succ, order_fail, position};
+
+               if (!checkSignature(CDSAtomicCAS_V2[Idx], args))
+                       return false;
+
                Instruction* funcInst = CallInst::Create(CDSAtomicCAS_V2[Idx], args);
                ReplaceInstWithInst(CI, funcInst);
 
@@ -760,10 +1165,65 @@ int CDSPass::getMemoryAccessFuncIndex(Value *Addr,
                return -1;
        }
        size_t Idx = countTrailingZeros(TypeSize / 8);
-       assert(Idx < kNumberOfAccessSizes);
+       //assert(Idx < kNumberOfAccessSizes);
+       if (Idx >= kNumberOfAccessSizes) {
+               return -1;
+       }
        return Idx;
 }
 
+bool CDSPass::instrumentLoops(Function &F)
+{
+       DominatorTree DT(F);
+       LoopInfo LI(DT);
+
+       SmallVector<Loop *, 4> Loops = LI.getLoopsInPreorder();
+       bool instrumented = false;
+
+       // Do a post-order traversal of the loops so that counter updates can be
+       // iteratively hoisted outside the loop nest.
+       for (auto *Loop : llvm::reverse(Loops)) {
+               bool instrument_loop = false;
+
+               // Iterator over loop blocks and search for atomics and volatiles
+               Loop::block_iterator it;
+               for (it = Loop->block_begin(); it != Loop->block_end(); it++) {
+                       BasicBlock * block = *it;
+                       for (auto &Inst : *block) {
+                               if ( (&Inst)->isAtomic() ) {
+                                       instrument_loop = true;
+                                       break;
+                               } else if (isAtomicCall(&Inst)) {
+                                       instrument_loop = true;
+                                       break;
+                               } else if (isa<LoadInst>(Inst) || isa<StoreInst>(Inst)) {
+                                       LoadInst *LI = dyn_cast<LoadInst>(&Inst);
+                                       StoreInst *SI = dyn_cast<StoreInst>(&Inst);
+                                       bool isVolatile = ( LI ? LI->isVolatile() : SI->isVolatile() );
+
+                                       if (isVolatile) {
+                                               instrument_loop = true;
+                                               break;
+                                       }
+                               }
+                       }
+
+                       if (instrument_loop)
+                               break;
+               }
+
+               if (instrument_loop) {
+                       // TODO: what to instrument?
+                       errs() << "Function: " << F.getName() << "\n";
+                       BasicBlock * header = Loop->getHeader();
+                       header->dump();
+
+                       instrumented = true;
+               }
+       }
+
+       return instrumented;
+}
 
 char CDSPass::ID = 0;
 
@@ -772,6 +1232,13 @@ static void registerCDSPass(const PassManagerBuilder &,
                                                        legacy::PassManagerBase &PM) {
        PM.add(new CDSPass());
 }
+
+/* Enable the pass when opt level is greater than 0 */
+static RegisterStandardPasses 
+       RegisterMyPass1(PassManagerBuilder::EP_OptimizerLast,
+registerCDSPass);
+
+/* Enable the pass when opt level is 0 */
 static RegisterStandardPasses 
-       RegisterMyPass(PassManagerBuilder::EP_OptimizerLast,
+       RegisterMyPass2(PassManagerBuilder::EP_EnabledOnOptLevel0,
 registerCDSPass);