Add instrumentation for memory intrinsic instructions
authorweiyu <weiyuluo1232@gmail.com>
Tue, 26 Nov 2019 01:39:15 +0000 (17:39 -0800)
committerweiyu <weiyuluo1232@gmail.com>
Tue, 26 Nov 2019 01:39:15 +0000 (17:39 -0800)
CDSPass.cpp

index ecbb9576548d0d1fed211fd6e6e924012358292e..cd5c8927ef607da86f1d6ec840b8e1ef58325453 100644 (file)
@@ -29,6 +29,7 @@
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/LegacyPassManager.h"
 #include "llvm/IR/Module.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/LegacyPassManager.h"
 #include "llvm/IR/Module.h"
@@ -172,6 +173,7 @@ namespace {
                void initializeCallbacks(Module &M);
                bool instrumentLoadOrStore(Instruction *I, const DataLayout &DL);
                bool instrumentVolatile(Instruction *I, const DataLayout &DL);
                void initializeCallbacks(Module &M);
                bool instrumentLoadOrStore(Instruction *I, const DataLayout &DL);
                bool instrumentVolatile(Instruction *I, const DataLayout &DL);
+               bool instrumentMemIntrinsic(Instruction *I);
                bool isAtomicCall(Instruction *I);
                bool instrumentAtomic(Instruction *I, const DataLayout &DL);
                bool instrumentAtomicCall(CallInst *CI, const DataLayout &DL);
                bool isAtomicCall(Instruction *I);
                bool instrumentAtomic(Instruction *I, const DataLayout &DL);
                bool instrumentAtomicCall(CallInst *CI, const DataLayout &DL);
@@ -195,7 +197,8 @@ namespace {
                Function * CDSAtomicCAS_V1[kNumberOfAccessSizes];
                Function * CDSAtomicCAS_V2[kNumberOfAccessSizes];
                Function * CDSAtomicThreadFence;
                Function * CDSAtomicCAS_V1[kNumberOfAccessSizes];
                Function * CDSAtomicCAS_V2[kNumberOfAccessSizes];
                Function * CDSAtomicThreadFence;
-               Function * CDSCtorFunction;
+               Function * MemmoveFn, * MemcpyFn, * MemsetFn;
+               // Function * CDSCtorFunction;
 
                std::vector<StringRef> AtomicFuncNames;
                std::vector<StringRef> PartialAtomicFuncNames;
 
                std::vector<StringRef> AtomicFuncNames;
                std::vector<StringRef> PartialAtomicFuncNames;
@@ -208,8 +211,12 @@ StringRef CDSPass::getPassName() const {
 
 void CDSPass::initializeCallbacks(Module &M) {
        LLVMContext &Ctx = M.getContext();
 
 void CDSPass::initializeCallbacks(Module &M) {
        LLVMContext &Ctx = M.getContext();
+       AttributeList Attr;
+       Attr = Attr.addAttribute(Ctx, AttributeList::FunctionIndex,
+                       Attribute::NoUnwind);
 
        Type * Int1Ty = Type::getInt1Ty(Ctx);
 
        Type * Int1Ty = Type::getInt1Ty(Ctx);
+       Type * Int32Ty = Type::getInt32Ty(Ctx);
        OrdTy = Type::getInt32Ty(Ctx);
 
        Int8PtrTy  = Type::getInt8PtrTy(Ctx);
        OrdTy = Type::getInt32Ty(Ctx);
 
        Int8PtrTy  = Type::getInt8PtrTy(Ctx);
@@ -221,10 +228,10 @@ void CDSPass::initializeCallbacks(Module &M) {
 
        CDSFuncEntry = checkCDSPassInterfaceFunction(
                                                M.getOrInsertFunction("cds_func_entry", 
 
        CDSFuncEntry = checkCDSPassInterfaceFunction(
                                                M.getOrInsertFunction("cds_func_entry", 
-                                               VoidTy, Int8PtrTy));
+                                               Attr, VoidTy, Int8PtrTy));
        CDSFuncExit = checkCDSPassInterfaceFunction(
                                                M.getOrInsertFunction("cds_func_exit", 
        CDSFuncExit = checkCDSPassInterfaceFunction(
                                                M.getOrInsertFunction("cds_func_exit", 
-                                               VoidTy, Int8PtrTy));
+                                               Attr, VoidTy, Int8PtrTy));
 
        // Get the function to call from our untime library.
        for (unsigned i = 0; i < kNumberOfAccessSizes; i++) {
 
        // Get the function to call from our untime library.
        for (unsigned i = 0; i < kNumberOfAccessSizes; i++) {
@@ -248,24 +255,24 @@ void CDSPass::initializeCallbacks(Module &M) {
                SmallString<32> AtomicStoreName("cds_atomic_store" + BitSizeStr);
 
                CDSLoad[i]  = checkCDSPassInterfaceFunction(
                SmallString<32> AtomicStoreName("cds_atomic_store" + BitSizeStr);
 
                CDSLoad[i]  = checkCDSPassInterfaceFunction(
-                                                       M.getOrInsertFunction(LoadName, VoidTy, PtrTy));
+                                                       M.getOrInsertFunction(LoadName, Attr, VoidTy, PtrTy));
                CDSStore[i] = checkCDSPassInterfaceFunction(
                CDSStore[i] = checkCDSPassInterfaceFunction(
-                                                       M.getOrInsertFunction(StoreName, VoidTy, PtrTy));
+                                                       M.getOrInsertFunction(StoreName, Attr, VoidTy, PtrTy));
                CDSVolatileLoad[i]  = checkCDSPassInterfaceFunction(
                                                                M.getOrInsertFunction(VolatileLoadName,
                CDSVolatileLoad[i]  = checkCDSPassInterfaceFunction(
                                                                M.getOrInsertFunction(VolatileLoadName,
-                                                               Ty, PtrTy, Int8PtrTy));
+                                                               Attr, Ty, PtrTy, Int8PtrTy));
                CDSVolatileStore[i] = checkCDSPassInterfaceFunction(
                                                                M.getOrInsertFunction(VolatileStoreName, 
                CDSVolatileStore[i] = checkCDSPassInterfaceFunction(
                                                                M.getOrInsertFunction(VolatileStoreName, 
-                                                               VoidTy, PtrTy, Ty, Int8PtrTy));
+                                                               Attr, VoidTy, PtrTy, Ty, Int8PtrTy));
                CDSAtomicInit[i] = checkCDSPassInterfaceFunction(
                                                        M.getOrInsertFunction(AtomicInitName, 
                CDSAtomicInit[i] = checkCDSPassInterfaceFunction(
                                                        M.getOrInsertFunction(AtomicInitName, 
-                                                       VoidTy, PtrTy, Ty, Int8PtrTy));
+                                                       Attr, VoidTy, PtrTy, Ty, Int8PtrTy));
                CDSAtomicLoad[i]  = checkCDSPassInterfaceFunction(
                                                                M.getOrInsertFunction(AtomicLoadName, 
                CDSAtomicLoad[i]  = checkCDSPassInterfaceFunction(
                                                                M.getOrInsertFunction(AtomicLoadName, 
-                                                               Ty, PtrTy, OrdTy, Int8PtrTy));
+                                                               Attr, Ty, PtrTy, OrdTy, Int8PtrTy));
                CDSAtomicStore[i] = checkCDSPassInterfaceFunction(
                                                                M.getOrInsertFunction(AtomicStoreName, 
                CDSAtomicStore[i] = checkCDSPassInterfaceFunction(
                                                                M.getOrInsertFunction(AtomicStoreName, 
-                                                               VoidTy, PtrTy, Ty, OrdTy, Int8PtrTy));
+                                                               Attr, VoidTy, PtrTy, Ty, OrdTy, Int8PtrTy));
 
                for (int op = AtomicRMWInst::FIRST_BINOP; 
                        op <= AtomicRMWInst::LAST_BINOP; ++op) {
 
                for (int op = AtomicRMWInst::FIRST_BINOP; 
                        op <= AtomicRMWInst::LAST_BINOP; ++op) {
@@ -290,7 +297,7 @@ void CDSPass::initializeCallbacks(Module &M) {
                        SmallString<32> AtomicRMWName("cds_atomic" + NamePart + BitSizeStr);
                        CDSAtomicRMW[op][i] = checkCDSPassInterfaceFunction(
                                                                        M.getOrInsertFunction(AtomicRMWName, 
                        SmallString<32> AtomicRMWName("cds_atomic" + NamePart + BitSizeStr);
                        CDSAtomicRMW[op][i] = checkCDSPassInterfaceFunction(
                                                                        M.getOrInsertFunction(AtomicRMWName, 
-                                                                       Ty, PtrTy, Ty, OrdTy, Int8PtrTy));
+                                                                       Attr, Ty, PtrTy, Ty, OrdTy, Int8PtrTy));
                }
 
                // only supportes strong version
                }
 
                // only supportes strong version
@@ -298,15 +305,24 @@ void CDSPass::initializeCallbacks(Module &M) {
                SmallString<32> AtomicCASName_V2("cds_atomic_compare_exchange" + BitSizeStr + "_v2");
                CDSAtomicCAS_V1[i] = checkCDSPassInterfaceFunction(
                                                                M.getOrInsertFunction(AtomicCASName_V1, 
                SmallString<32> AtomicCASName_V2("cds_atomic_compare_exchange" + BitSizeStr + "_v2");
                CDSAtomicCAS_V1[i] = checkCDSPassInterfaceFunction(
                                                                M.getOrInsertFunction(AtomicCASName_V1, 
-                                                               Ty, PtrTy, Ty, Ty, OrdTy, OrdTy, Int8PtrTy));
+                                                               Attr, Ty, PtrTy, Ty, Ty, OrdTy, OrdTy, Int8PtrTy));
                CDSAtomicCAS_V2[i] = checkCDSPassInterfaceFunction(
                                                                M.getOrInsertFunction(AtomicCASName_V2, 
                CDSAtomicCAS_V2[i] = checkCDSPassInterfaceFunction(
                                                                M.getOrInsertFunction(AtomicCASName_V2, 
-                                                               Int1Ty, PtrTy, PtrTy, Ty, OrdTy, OrdTy, Int8PtrTy));
+                                                               Attr, Int1Ty, PtrTy, PtrTy, Ty, OrdTy, OrdTy, Int8PtrTy));
        }
 
        CDSAtomicThreadFence = checkCDSPassInterfaceFunction(
        }
 
        CDSAtomicThreadFence = checkCDSPassInterfaceFunction(
-                                                               M.getOrInsertFunction("cds_atomic_thread_fence", 
-                                                               VoidTy, OrdTy, Int8PtrTy));
+                       M.getOrInsertFunction("cds_atomic_thread_fence", Attr, VoidTy, OrdTy, Int8PtrTy));
+
+       MemmoveFn = checkCDSPassInterfaceFunction(
+                                       M.getOrInsertFunction("memmove", Attr, Int8PtrTy, Int8PtrTy,
+                                       Int8PtrTy, IntPtrTy));
+       MemcpyFn = checkCDSPassInterfaceFunction(
+                                       M.getOrInsertFunction("memcpy", Attr, Int8PtrTy, Int8PtrTy,
+                                       Int8PtrTy, IntPtrTy));
+       MemsetFn = checkCDSPassInterfaceFunction(
+                                       M.getOrInsertFunction("memset", Attr, Int8PtrTy, Int8PtrTy,
+                                       Int32Ty, IntPtrTy));
 }
 
 bool CDSPass::doInitialization(Module &M) {
 }
 
 bool CDSPass::doInitialization(Module &M) {
@@ -459,9 +475,9 @@ bool CDSPass::runOnFunction(Function &F) {
        SmallVector<Instruction*, 8> LocalLoadsAndStores;
        SmallVector<Instruction*, 8> VolatileLoadsAndStores;
        SmallVector<Instruction*, 8> AtomicAccesses;
        SmallVector<Instruction*, 8> LocalLoadsAndStores;
        SmallVector<Instruction*, 8> VolatileLoadsAndStores;
        SmallVector<Instruction*, 8> AtomicAccesses;
+       SmallVector<Instruction*, 8> MemIntrinCalls;
 
        bool Res = false;
 
        bool Res = false;
-       bool HasCall = false;
        bool HasAtomic = false;
        bool HasVolatile = false;
        const DataLayout &DL = F.getParent()->getDataLayout();
        bool HasAtomic = false;
        bool HasVolatile = false;
        const DataLayout &DL = F.getParent()->getDataLayout();
@@ -482,15 +498,15 @@ bool CDSPass::runOnFunction(Function &F) {
                                } else
                                        LocalLoadsAndStores.push_back(&Inst);
                        } else if (isa<CallInst>(Inst) || isa<InvokeInst>(Inst)) {
                                } else
                                        LocalLoadsAndStores.push_back(&Inst);
                        } else if (isa<CallInst>(Inst) || isa<InvokeInst>(Inst)) {
-                               /* TODO: To be added
-                               if (CallInst *CI = dyn_cast<CallInst>(&Inst))
-                                       maybeMarkSanitizerLibraryCallNoBuiltin(CI, TLI);
                                if (isa<MemIntrinsic>(Inst))
                                        MemIntrinCalls.push_back(&Inst);
                                if (isa<MemIntrinsic>(Inst))
                                        MemIntrinCalls.push_back(&Inst);
-                               HasCalls = true;
+
+                               /*if (CallInst *CI = dyn_cast<CallInst>(&Inst))
+                                       maybeMarkSanitizerLibraryCallNoBuiltin(CI, TLI);
+                               */
+
                                chooseInstructionsToInstrument(LocalLoadsAndStores, AllLoadsAndStores,
                                        DL);
                                chooseInstructionsToInstrument(LocalLoadsAndStores, AllLoadsAndStores,
                                        DL);
-                               */
                        }
                }
 
                        }
                }
 
@@ -509,11 +525,9 @@ bool CDSPass::runOnFunction(Function &F) {
                Res |= instrumentAtomic(Inst, DL);
        }
 
                Res |= instrumentAtomic(Inst, DL);
        }
 
-       /* TODO
        for (auto Inst : MemIntrinCalls) {
                Res |= instrumentMemIntrinsic(Inst);
        }
        for (auto Inst : MemIntrinCalls) {
                Res |= instrumentMemIntrinsic(Inst);
        }
-       */
 
        // Only instrument functions that contain atomics or volatiles
        if (Res && ( HasAtomic || HasVolatile) ) {
 
        // Only instrument functions that contain atomics or volatiles
        if (Res && ( HasAtomic || HasVolatile) ) {
@@ -633,6 +647,26 @@ bool CDSPass::instrumentVolatile(Instruction * I, const DataLayout &DL) {
        return true;
 }
 
        return true;
 }
 
+bool CDSPass::instrumentMemIntrinsic(Instruction *I) {
+       IRBuilder<> IRB(I);
+       if (MemSetInst *M = dyn_cast<MemSetInst>(I)) {
+               IRB.CreateCall(
+                       MemsetFn,
+                       {IRB.CreatePointerCast(M->getArgOperand(0), IRB.getInt8PtrTy()),
+                        IRB.CreateIntCast(M->getArgOperand(1), IRB.getInt32Ty(), false),
+                        IRB.CreateIntCast(M->getArgOperand(2), IntPtrTy, false)});
+               I->eraseFromParent();
+       } else if (MemTransferInst *M = dyn_cast<MemTransferInst>(I)) {
+               IRB.CreateCall(
+                       isa<MemCpyInst>(M) ? MemcpyFn : MemmoveFn,
+                       {IRB.CreatePointerCast(M->getArgOperand(0), IRB.getInt8PtrTy()),
+                        IRB.CreatePointerCast(M->getArgOperand(1), IRB.getInt8PtrTy()),
+                        IRB.CreateIntCast(M->getArgOperand(2), IntPtrTy, false)});
+               I->eraseFromParent();
+       }
+       return false;
+}
+
 bool CDSPass::instrumentAtomic(Instruction * I, const DataLayout &DL) {
        IRBuilder<> IRB(I);
 
 bool CDSPass::instrumentAtomic(Instruction * I, const DataLayout &DL) {
        IRBuilder<> IRB(I);
 
@@ -738,7 +772,7 @@ bool CDSPass::isAtomicCall(Instruction *I) {
 
                StringRef funName = fun->getName();
 
 
                StringRef funName = fun->getName();
 
-               // todo: come up with better rules for function name checking
+               // TODO: come up with better rules for function name checking
                for (StringRef name : AtomicFuncNames) {
                        if ( funName.contains(name) ) 
                                return true;
                for (StringRef name : AtomicFuncNames) {
                        if ( funName.contains(name) ) 
                                return true;