Wrong parameter was passed in. Fixed now.
[c11llvm.git] / CDSPass.cpp
index 7546be0f0c779939772ca0e16779437e2d738548..c0bd2446930365fe8a40d0ca8d35489876e0967f 100644 (file)
@@ -42,6 +42,7 @@
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
+#include "llvm/Transforms/Utils/EscapeEnumerator.h"
 #include <vector>
 
 using namespace llvm;
@@ -59,10 +60,10 @@ Value *getPosition( Instruction * I, IRBuilder <> IRB, bool print = false)
        }
 
        if (print) {
-               errs() << position_string;
+               errs() << position_string << "\n";
        }
 
-       return IRB . CreateGlobalStringPtr (position_string);
+       return IRB.CreateGlobalStringPtr (position_string);
 }
 
 STATISTIC(NumInstrumentedReads, "Number of instrumented reads");
@@ -140,6 +141,9 @@ namespace {
                Constant * CDSAtomicCAS_V1[kNumberOfAccessSizes];
                Constant * CDSAtomicCAS_V2[kNumberOfAccessSizes];
                Constant * CDSAtomicThreadFence;
+
+               std::vector<StringRef> AtomicFuncNames;
+               std::vector<StringRef> PartialAtomicFuncNames;
        };
 }
 
@@ -162,6 +166,11 @@ void CDSPass::initializeCallbacks(Module &M) {
 
        VoidTy = Type::getVoidTy(Ctx);
 
+       CDSFuncEntry = M.getOrInsertFunction("cds_func_entry", 
+                                                               VoidTy, Int8PtrTy);
+       CDSFuncExit = M.getOrInsertFunction("cds_func_exit", 
+                                                               VoidTy, Int8PtrTy);
+
        // Get the function to call from our untime library.
        for (unsigned i = 0; i < kNumberOfAccessSizes; i++) {
                const unsigned ByteSize = 1U << i;
@@ -289,6 +298,17 @@ bool CDSPass::runOnFunction(Function &F) {
        if (true) {
                initializeCallbacks( *F.getParent() );
 
+               AtomicFuncNames = 
+               {
+                       "atomic_init", "atomic_load", "atomic_store", 
+                       "atomic_fetch_", "atomic_exchange", "atomic_compare_exchange_"
+               };
+
+               PartialAtomicFuncNames = 
+               { 
+                       "load", "store", "fetch", "exchange", "compare_exchange_" 
+               };
+
                SmallVector<Instruction*, 8> AllLoadsAndStores;
                SmallVector<Instruction*, 8> LocalLoadsAndStores;
                SmallVector<Instruction*, 8> AtomicAccesses;
@@ -296,6 +316,7 @@ bool CDSPass::runOnFunction(Function &F) {
                std::vector<Instruction *> worklist;
 
                bool Res = false;
+               bool HasAtomic = false;
                const DataLayout &DL = F.getParent()->getDataLayout();
 
                // errs() << "--- " << F.getName() << "---\n";
@@ -304,6 +325,7 @@ bool CDSPass::runOnFunction(Function &F) {
                        for (auto &I : B) {
                                if ( (&I)->isAtomic() || isAtomicCall(&I) ) {
                                        AtomicAccesses.push_back(&I);
+                                       HasAtomic = true;
                                } else if (isa<LoadInst>(I) || isa<StoreInst>(I)) {
                                        LocalLoadsAndStores.push_back(&I);
                                } else if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
@@ -315,16 +337,31 @@ bool CDSPass::runOnFunction(Function &F) {
                }
 
                for (auto Inst : AllLoadsAndStores) {
-                       // Res |= instrumentLoadOrStore(Inst, DL);
-                       // errs() << "load and store are replaced\n";
+                       Res |= instrumentLoadOrStore(Inst, DL);
                }
 
                for (auto Inst : AtomicAccesses) {
                        Res |= instrumentAtomic(Inst, DL);
                }
 
-               if (F.getName() == "user_main") {
-                       // F.dump();
+               // only instrument functions that contain atomics
+               if (Res && HasAtomic) {
+                       IRBuilder<> IRB(F.getEntryBlock().getFirstNonPHI());
+                       /* Unused for now
+                       Value *ReturnAddress = IRB.CreateCall(
+                               Intrinsic::getDeclaration(F.getParent(), Intrinsic::returnaddress),
+                               IRB.getInt32(0));
+                       */
+
+                       Value * FuncName = IRB.CreateGlobalStringPtr(F.getName());
+                       IRB.CreateCall(CDSFuncEntry, FuncName);
+
+                       EscapeEnumerator EE(F, "cds_cleanup", true);
+                       while (IRBuilder<> *AtExit = EE.Next()) {
+                         AtExit->CreateCall(CDSFuncExit, FuncName);
+                       }
+
+                       Res = true;
                }
        }
 
@@ -389,6 +426,8 @@ bool CDSPass::instrumentLoadOrStore(Instruction *I,
        return false;
 
        int Idx = getMemoryAccessFuncIndex(Addr, DL);
+       if (Idx < 0)
+               return false;
 
 //  not supported by CDS yet
 /*  if (IsWrite && isVtableAccess(I)) {
@@ -425,10 +464,8 @@ bool CDSPass::instrumentLoadOrStore(Instruction *I,
 
        if ( ArgType != Int8PtrTy && ArgType != Int16PtrTy && 
                        ArgType != Int32PtrTy && ArgType != Int64PtrTy ) {
-               //errs() << "A load or store of type ";
-               //errs() << *ArgType;
-               //errs() << " is passed in\n";
-               return false;   // if other types of load or stores are passed in
+               // if other types of load or stores are passed in
+               return false;   
        }
        IRB.CreateCall(OnAccessFunc, IRB.CreatePointerCast(Addr, Addr->getType()));
        if (IsWrite) NumInstrumentedWrites++;
@@ -438,7 +475,8 @@ bool CDSPass::instrumentLoadOrStore(Instruction *I,
 
 bool CDSPass::instrumentAtomic(Instruction * I, const DataLayout &DL) {
        IRBuilder<> IRB(I);
-       // LLVMContext &Ctx = IRB.getContext();
+
+       // errs() << "instrumenting: " << *I << "\n";
 
        if (auto *CI = dyn_cast<CallInst>(I)) {
                return instrumentAtomicCall(CI, DL);
@@ -449,6 +487,9 @@ bool CDSPass::instrumentAtomic(Instruction * I, const DataLayout &DL) {
        if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
                Value *Addr = LI->getPointerOperand();
                int Idx=getMemoryAccessFuncIndex(Addr, DL);
+               if (Idx < 0)
+                       return false;
+
                int atomic_order_index = getAtomicOrderIndex(LI->getOrdering());
                Value *order = ConstantInt::get(OrdTy, atomic_order_index);
                Value *args[] = {Addr, order, position};
@@ -457,6 +498,9 @@ bool CDSPass::instrumentAtomic(Instruction * I, const DataLayout &DL) {
        } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
                Value *Addr = SI->getPointerOperand();
                int Idx=getMemoryAccessFuncIndex(Addr, DL);
+               if (Idx < 0)
+                       return false;
+
                int atomic_order_index = getAtomicOrderIndex(SI->getOrdering());
                Value *val = SI->getValueOperand();
                Value *order = ConstantInt::get(OrdTy, atomic_order_index);
@@ -466,6 +510,9 @@ bool CDSPass::instrumentAtomic(Instruction * I, const DataLayout &DL) {
        } else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(I)) {
                Value *Addr = RMWI->getPointerOperand();
                int Idx=getMemoryAccessFuncIndex(Addr, DL);
+               if (Idx < 0)
+                       return false;
+
                int atomic_order_index = getAtomicOrderIndex(RMWI->getOrdering());
                Value *val = RMWI->getValOperand();
                Value *order = ConstantInt::get(OrdTy, atomic_order_index);
@@ -477,6 +524,8 @@ bool CDSPass::instrumentAtomic(Instruction * I, const DataLayout &DL) {
 
                Value *Addr = CASI->getPointerOperand();
                int Idx=getMemoryAccessFuncIndex(Addr, DL);
+               if (Idx < 0)
+                       return false;
 
                const unsigned ByteSize = 1U << Idx;
                const unsigned BitSize = ByteSize * 8;
@@ -530,11 +579,17 @@ bool CDSPass::isAtomicCall(Instruction *I) {
                        return false;
 
                StringRef funName = fun->getName();
+
                // todo: come up with better rules for function name checking
-               if ( funName.contains("atomic_") ) {
-                       return true;
-               } else if (funName.contains("atomic") ) {
-                       return true;
+               for (StringRef name : AtomicFuncNames) {
+                       if ( funName.contains(name) ) 
+                               return true;
+               }
+               
+               for (StringRef PartialName : PartialAtomicFuncNames) {
+                       if (funName.contains(PartialName) && 
+                                       funName.contains("atomic") )
+                               return true;
                }
        }
 
@@ -559,6 +614,7 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
 
        // the pointer to the address is always the first argument
        Value *OrigPtr = parameters[0];
+
        int Idx = getMemoryAccessFuncIndex(OrigPtr, DL);
        if (Idx < 0)
                return false;
@@ -598,13 +654,16 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
 
                return true;
        } else if (funName.contains("atomic") && 
-                                       funName.contains("load")) {
+                                       funName.contains("load") ) {
                // does this version of call always have an atomic order as an argument?
                Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
                Value *order = IRB.CreateBitOrPointerCast(parameters[1], OrdTy);
                Value *args[] = {ptr, order, position};
 
-               //Instruction* funcInst=CallInst::Create(CDSAtomicLoad[Idx], args);
+               if (!CI->getType()->isPointerTy()) {
+                       return false;   
+               } 
+
                CallInst *funcInst = IRB.CreateCall(CDSAtomicLoad[Idx], args);
                Value *RetVal = IRB.CreateIntToPtr(funcInst, CI->getType());
 
@@ -634,13 +693,13 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
 
                return true;
        } else if (funName.contains("atomic") && 
-                                       funName.contains("EEEE5store")) {
+                                       funName.contains("EEEE5store") ) {
                // does this version of call always have an atomic order as an argument?
                Value *OrigVal = parameters[1];
 
                Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
                Value *val = IRB.CreatePointerCast(OrigVal, Ty);
-               Value *order = IRB.CreateBitOrPointerCast(parameters[1], OrdTy);
+               Value *order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
                Value *args[] = {ptr, val, order, position};
 
                Instruction* funcInst = CallInst::Create(CDSAtomicStore[Idx], args);
@@ -669,7 +728,7 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
                else if ( funName.contains("atomic_exchange") )
                        op = AtomicRMWInst::Xchg;
                else {
-                       errs() << "Unknown atomic read modify write operation\n";
+                       errs() << "Unknown atomic read-modify-write operation\n";
                        return false;
                }
 
@@ -726,7 +785,7 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
                ReplaceInstWithInst(CI, funcInst);
 
                return true;
-       } else if ( funName.contains("compare_exchange_strong") || 
+       } else if ( funName.contains("compare_exchange_strong") ||
                                funName.contains("compare_exchange_weak") ) {
                Value *Addr = IRB.CreatePointerCast(OrigPtr, PtrTy);
                Value *CmpOperand = IRB.CreatePointerCast(parameters[1], PtrTy);
@@ -760,7 +819,10 @@ int CDSPass::getMemoryAccessFuncIndex(Value *Addr,
                return -1;
        }
        size_t Idx = countTrailingZeros(TypeSize / 8);
-       assert(Idx < kNumberOfAccessSizes);
+       //assert(Idx < kNumberOfAccessSizes);
+       if (Idx >= kNumberOfAccessSizes) {
+               return -1;
+       }
        return Idx;
 }
 
@@ -772,6 +834,13 @@ static void registerCDSPass(const PassManagerBuilder &,
                                                        legacy::PassManagerBase &PM) {
        PM.add(new CDSPass());
 }
+
+/* Enable the pass when opt level is greater than 0 */
+static RegisterStandardPasses 
+       RegisterMyPass1(PassManagerBuilder::EP_OptimizerLast,
+registerCDSPass);
+
+/* Enable the pass when opt level is 0 */
 static RegisterStandardPasses 
-       RegisterMyPass(PassManagerBuilder::EP_OptimizerLast,
+       RegisterMyPass2(PassManagerBuilder::EP_EnabledOnOptLevel0,
 registerCDSPass);