Fix a bug related to atomic bool
[c11llvm.git] / CDSPass.cpp
index 432d98f2d0cd0c23e4b31c148f2fc86888320058..108d4deadd54bceb020085519ea88db0f9ed4a2f 100644 (file)
@@ -115,6 +115,7 @@ namespace {
                static char ID;
                CDSPass() : FunctionPass(ID) {}
                bool runOnFunction(Function &F) override; 
+               StringRef getPassName() const override;
 
        private:
                void initializeCallbacks(Module &M);
@@ -150,6 +151,10 @@ namespace {
        };
 }
 
+StringRef CDSPass::getPassName() const {
+       return "CDSPass";
+}
+
 static bool isVtableAccess(Instruction *I) {
        if (MDNode *Tag = I->getMetadata(LLVMContext::MD_tbaa))
                return Tag->isTBAAVtableAccess();
@@ -679,8 +684,15 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
 
        // atomic_init; args = {obj, order}
        if (funName.contains("atomic_init")) {
+               Value *OrigVal = parameters[1];
+
                Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
-               Value *val = IRB.CreateBitOrPointerCast(parameters[1], Ty);
+               Value *val;
+               if (OrigVal->getType()->isPtrOrPtrVectorTy())
+                       val = IRB.CreatePointerCast(OrigVal, Ty);
+               else
+                       val = IRB.CreateIntCast(OrigVal, Ty, true);
+
                Value *args[] = {ptr, val, position};
 
                Instruction* funcInst = CallInst::Create(CDSAtomicInit[Idx], args);
@@ -751,7 +763,12 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
                Value *OrigVal = parameters[1];
 
                Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
-               Value *val = IRB.CreatePointerCast(OrigVal, Ty);
+               Value *val;
+               if (OrigVal->getType()->isPtrOrPtrVectorTy())
+                       val = IRB.CreatePointerCast(OrigVal, Ty);
+               else
+                       val = IRB.CreateIntCast(OrigVal, Ty, true);
+
                Value *order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);
                Value *args[] = {ptr, val, order, position};
 
@@ -763,7 +780,12 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
 
        // atomic_fetch_*; args = {obj, val, order}
        if (funName.contains("atomic_fetch_") || 
-                       funName.contains("atomic_exchange") ) {
+               funName.contains("atomic_exchange")) {
+
+               /* TODO: implement stricter function name checking */
+               if (funName.contains("non"))
+                       return false;
+
                bool isExplicit = funName.contains("_explicit");
                Value *OrigVal = parameters[1];
 
@@ -786,7 +808,12 @@ bool CDSPass::instrumentAtomicCall(CallInst *CI, const DataLayout &DL) {
                }
 
                Value *ptr = IRB.CreatePointerCast(OrigPtr, PtrTy);
-               Value *val = IRB.CreatePointerCast(OrigVal, Ty);
+               Value *val;
+               if (OrigVal->getType()->isPtrOrPtrVectorTy())
+                       val = IRB.CreatePointerCast(OrigVal, Ty);
+               else
+                       val = IRB.CreateIntCast(OrigVal, Ty, true);
+
                Value *order;
                if (isExplicit)
                        order = IRB.CreateBitOrPointerCast(parameters[2], OrdTy);