Teach instCombine to remove malloc+free if malloc's only uses are comparisons
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineCalls.cpp
index 43ff58d8ff15bc40112c779e16b908f1eb74d1b4..08a6ff41ebb23b1c9b82c06de95046ac3da26f4e 100644 (file)
@@ -59,29 +59,32 @@ static unsigned EnforceKnownAlignment(Value *V,
       // Treat this like a bitcast.
       return EnforceKnownAlignment(U->getOperand(0), Align, PrefAlign);
     }
-    break;
+    return Align;
+  }
+  case Instruction::Alloca: {
+    AllocaInst *AI = cast<AllocaInst>(V);
+    // If there is a requested alignment and if this is an alloca, round up.
+    if (AI->getAlignment() >= PrefAlign)
+      return AI->getAlignment();
+    AI->setAlignment(PrefAlign);
+    return PrefAlign;
   }
   }
 
   if (GlobalValue *GV = dyn_cast<GlobalValue>(V)) {
     // If there is a large requested alignment and we can, bump up the alignment
     // of the global.
-    if (!GV->isDeclaration()) {
-      if (GV->getAlignment() >= PrefAlign)
-        Align = GV->getAlignment();
-      else {
-        GV->setAlignment(PrefAlign);
-        Align = PrefAlign;
-      }
-    }
-  } else if (AllocaInst *AI = dyn_cast<AllocaInst>(V)) {
-    // If there is a requested alignment and if this is an alloca, round up.
-    if (AI->getAlignment() >= PrefAlign)
-      Align = AI->getAlignment();
-    else {
-      AI->setAlignment(PrefAlign);
-      Align = PrefAlign;
-    }
+    if (GV->isDeclaration()) return Align;
+    
+    if (GV->getAlignment() >= PrefAlign)
+      return GV->getAlignment();
+    // We can only increase the alignment of the global if it has no alignment
+    // specified or if it is not assigned a section.  If it is assigned a
+    // section, the global could be densely packed with other objects in the
+    // section, increasing the alignment could cause padding issues.
+    if (!GV->hasSection() || GV->getAlignment() == 0)
+      GV->setAlignment(PrefAlign);
+    return GV->getAlignment();
   }
 
   return Align;
@@ -136,8 +139,14 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) {
     return 0;  // If not 1/2/4/8 bytes, exit.
   
   // Use an integer load+store unless we can find something better.
-  Type *NewPtrTy =
-            PointerType::getUnqual(IntegerType::get(MI->getContext(), Size<<3));
+  unsigned SrcAddrSp =
+    cast<PointerType>(MI->getOperand(2)->getType())->getAddressSpace();
+  unsigned DstAddrSp =
+    cast<PointerType>(MI->getOperand(1)->getType())->getAddressSpace();
+
+  const IntegerType* IntType = IntegerType::get(MI->getContext(), Size<<3);
+  Type *NewSrcPtrTy = PointerType::get(IntType, SrcAddrSp);
+  Type *NewDstPtrTy = PointerType::get(IntType, DstAddrSp);
   
   // Memcpy forces the use of i8* for the source and destination.  That means
   // that if you're using memcpy to move one double around, you'll get a cast
@@ -167,8 +176,10 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) {
           break;
       }
       
-      if (SrcETy->isSingleValueType())
-        NewPtrTy = PointerType::getUnqual(SrcETy);
+      if (SrcETy->isSingleValueType()) {
+        NewSrcPtrTy = PointerType::get(SrcETy, SrcAddrSp);
+        NewDstPtrTy = PointerType::get(SrcETy, DstAddrSp);
+      }
     }
   }
   
@@ -178,11 +189,12 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) {
   SrcAlign = std::max(SrcAlign, CopyAlign);
   DstAlign = std::max(DstAlign, CopyAlign);
   
-  Value *Src = Builder->CreateBitCast(MI->getOperand(2), NewPtrTy);
-  Value *Dest = Builder->CreateBitCast(MI->getOperand(1), NewPtrTy);
-  Instruction *L = new LoadInst(Src, "tmp", false, SrcAlign);
+  Value *Src = Builder->CreateBitCast(MI->getOperand(2), NewSrcPtrTy);
+  Value *Dest = Builder->CreateBitCast(MI->getOperand(1), NewDstPtrTy);
+  Instruction *L = new LoadInst(Src, "tmp", MI->isVolatile(), SrcAlign);
   InsertNewInstBefore(L, *MI);
-  InsertNewInstBefore(new StoreInst(L, Dest, false, DstAlign), *MI);
+  InsertNewInstBefore(new StoreInst(L, Dest, MI->isVolatile(), DstAlign),
+                      *MI);
 
   // Set the size of the copy to 0, it will be deleted on the next iteration.
   MI->setOperand(3, Constant::getNullValue(MemOpLength->getType()));
@@ -238,6 +250,8 @@ Instruction *InstCombiner::SimplifyMemSet(MemSetInst *MI) {
 Instruction *InstCombiner::visitCallInst(CallInst &CI) {
   if (isFreeCall(&CI))
     return visitFree(CI);
+  if (isMalloc(&CI))
+    return visitMalloc(CI);
 
   // If the caller function is nounwind, mark the call as nounwind, even if the
   // callee isn't.
@@ -275,10 +289,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
         if (GVSrc->isConstant()) {
           Module *M = CI.getParent()->getParent()->getParent();
           Intrinsic::ID MemCpyID = Intrinsic::memcpy;
-          const Type *Tys[1];
-          Tys[0] = CI.getOperand(3)->getType();
-          CI.setOperand(0, 
-                        Intrinsic::getDeclaration(M, MemCpyID, Tys, 1));
+          const Type *Tys[3] = { CI.getOperand(1)->getType(),
+                                 CI.getOperand(2)->getType(),
+                                 CI.getOperand(3)->getType() };
+          CI.setCalledFunction( 
+                        Intrinsic::getDeclaration(M, MemCpyID, Tys, 3));
           Changed = true;
         }
     }
@@ -516,7 +531,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
       // X + 0 -> {X, false}
       if (RHS->isZero()) {
         Constant *V[] = {
-          UndefValue::get(II->getOperand(0)->getType()),
+          UndefValue::get(II->getCalledValue()->getType()),
           ConstantInt::getFalse(II->getContext())
         };
         Constant *Struct = ConstantStruct::get(II->getContext(), V, 2, false);
@@ -751,134 +766,41 @@ static bool isSafeToEliminateVarargsCast(const CallSite CS,
   return true;
 }
 
+namespace {
+class InstCombineFortifiedLibCalls : public SimplifyFortifiedLibCalls {
+  InstCombiner *IC;
+protected:
+  void replaceCall(Value *With) {
+    NewInstruction = IC->ReplaceInstUsesWith(*CI, With);
+  }
+  bool isFoldable(unsigned SizeCIOp, unsigned SizeArgOp, bool isString) const {
+    if (ConstantInt *SizeCI = dyn_cast<ConstantInt>(CI->getOperand(SizeCIOp))) {
+      if (SizeCI->isAllOnesValue())
+        return true;
+      if (isString)
+        return SizeCI->getZExtValue() >=
+               GetStringLength(CI->getOperand(SizeArgOp));
+      if (ConstantInt *Arg = dyn_cast<ConstantInt>(CI->getOperand(SizeArgOp)))
+        return SizeCI->getZExtValue() >= Arg->getZExtValue();
+    }
+    return false;
+  }
+public:
+  InstCombineFortifiedLibCalls(InstCombiner *IC) : IC(IC), NewInstruction(0) { }
+  Instruction *NewInstruction;
+};
+} // end anonymous namespace
+
 // Try to fold some different type of calls here.
 // Currently we're only working with the checking functions, memcpy_chk, 
 // mempcpy_chk, memmove_chk, memset_chk, strcpy_chk, stpcpy_chk, strncpy_chk,
 // strcat_chk and strncat_chk.
 Instruction *InstCombiner::tryOptimizeCall(CallInst *CI, const TargetData *TD) {
   if (CI->getCalledFunction() == 0) return 0;
-  
-  StringRef Name = CI->getCalledFunction()->getName();
-  BasicBlock *BB = CI->getParent();
-  IRBuilder<> B(CI->getParent()->getContext());
-  
-  // Set the builder to the instruction after the call.
-  B.SetInsertPoint(BB, CI);
-
-  if (Name == "__memcpy_chk") {
-    ConstantInt *SizeCI = dyn_cast<ConstantInt>(CI->getOperand(4));
-    if (!SizeCI)
-      return 0;
-    ConstantInt *SizeArg = dyn_cast<ConstantInt>(CI->getOperand(3));
-    if (!SizeArg)
-      return 0;
-    if (SizeCI->isAllOnesValue() ||
-        SizeCI->getZExtValue() <= SizeArg->getZExtValue()) {
-      EmitMemCpy(CI->getOperand(1), CI->getOperand(2), CI->getOperand(3),
-                 1, B, TD);
-      return ReplaceInstUsesWith(*CI, CI->getOperand(1));
-    }
-    return 0;
-  }
 
-  // Should be similar to memcpy.
-  if (Name == "__mempcpy_chk") {
-    return 0;
-  }
-
-  if (Name == "__memmove_chk") {
-    ConstantInt *SizeCI = dyn_cast<ConstantInt>(CI->getOperand(4));
-    if (!SizeCI)
-      return 0;
-    ConstantInt *SizeArg = dyn_cast<ConstantInt>(CI->getOperand(3));
-    if (!SizeArg)
-      return 0;
-    if (SizeCI->isAllOnesValue() ||
-        SizeCI->getZExtValue() <= SizeArg->getZExtValue()) {
-      EmitMemMove(CI->getOperand(1), CI->getOperand(2), CI->getOperand(3),
-                  1, B, TD);
-      return ReplaceInstUsesWith(*CI, CI->getOperand(1));
-    }
-    return 0;
-  }
-
-  if (Name == "__memset_chk") {
-    ConstantInt *SizeCI = dyn_cast<ConstantInt>(CI->getOperand(4));
-    if (!SizeCI)
-      return 0;
-    ConstantInt *SizeArg = dyn_cast<ConstantInt>(CI->getOperand(3));
-    if (!SizeArg)
-      return 0;
-    if (SizeCI->isAllOnesValue() ||
-        SizeCI->getZExtValue() <= SizeArg->getZExtValue()) {
-      Value *Val = B.CreateIntCast(CI->getOperand(2), B.getInt8Ty(),
-                                   false);
-      EmitMemSet(CI->getOperand(1), Val,  CI->getOperand(3), B, TD);
-      return ReplaceInstUsesWith(*CI, CI->getOperand(1));
-    }
-    return 0;
-  }
-
-  if (Name == "__strcpy_chk") {
-    ConstantInt *SizeCI = dyn_cast<ConstantInt>(CI->getOperand(3));
-    if (!SizeCI)
-      return 0;
-    // If a) we don't have any length information, or b) we know this will
-    // fit then just lower to a plain strcpy. Otherwise we'll keep our
-    // strcpy_chk call which may fail at runtime if the size is too long.
-    // TODO: It might be nice to get a maximum length out of the possible
-    // string lengths for varying.
-    if (SizeCI->isAllOnesValue() ||
-      SizeCI->getZExtValue() >= GetStringLength(CI->getOperand(2))) {
-      Value *Ret = EmitStrCpy(CI->getOperand(1), CI->getOperand(2), B, TD);
-      return ReplaceInstUsesWith(*CI, Ret);
-    }
-    return 0;
-  }
-
-  // Should be similar to strcpy.
-  if (Name == "__stpcpy_chk") {
-    ConstantInt *SizeCI = dyn_cast<ConstantInt>(CI->getOperand(3));
-    if (!SizeCI)
-      return 0;
-    // If a) we don't have any length information, or b) we know this will
-    // fit then just lower to a plain stpcpy. Otherwise we'll keep our
-    // stpcpy_chk call which may fail at runtime if the size is too long.
-    // TODO: It might be nice to get a maximum length out of the possible
-    // string lengths for varying.
-    if (SizeCI->isAllOnesValue() ||
-        SizeCI->getZExtValue() >= GetStringLength(CI->getOperand(2))) {
-      Value *Ret = EmitStpCpy(CI->getOperand(1), CI->getOperand(2), B, TD);
-      return ReplaceInstUsesWith(*CI, Ret);
-    }
-    return 0;
-  }
-
-  if (Name == "__strncpy_chk") {
-    ConstantInt *SizeCI = dyn_cast<ConstantInt>(CI->getOperand(4));
-    if (!SizeCI)
-      return 0;
-    ConstantInt *SizeArg = dyn_cast<ConstantInt>(CI->getOperand(3));
-    if (!SizeArg)
-      return 0;
-    if (SizeCI->isAllOnesValue() ||
-        SizeCI->getZExtValue() <= SizeArg->getZExtValue()) {
-      Value *Ret = EmitStrNCpy(CI->getOperand(1), CI->getOperand(2),
-                               CI->getOperand(3), B, TD);
-      return ReplaceInstUsesWith(*CI, Ret);
-    }
-    return 0; 
-  }
-
-  if (Name == "__strcat_chk") {
-    return 0;
-  }
-
-  if (Name == "__strncat_chk") {
-    return 0;
-  }
-
-  return 0;
+  InstCombineFortifiedLibCalls Simplifier(this);
+  Simplifier.fold(CI, TD);
+  return Simplifier.NewInstruction;
 }
 
 // visitCallSite - Improvements for call and invoke instructions.
@@ -913,7 +835,7 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) {
       
       // We cannot remove an invoke, because it would change the CFG, just
       // change the callee to a null pointer.
-      cast<InvokeInst>(OldCall)->setOperand(0,
+      cast<InvokeInst>(OldCall)->setCalledFunction(
                                     Constant::getNullValue(CalleeF->getType()));
       return 0;
     }