Don't generate redundant casts of constant values when lowering calls to
[oota-llvm.git] / lib / CodeGen / IntrinsicLowering.cpp
index 56b1736b4156f4e4e40ea04acfca5f10d74895f7..ce68c29612cbb5cf1d3fe892bdc5894643238b21 100644 (file)
 #include "llvm/ADT/STLExtras.h"
 using namespace llvm;
 
+// Return the integer value Val zero-extended or truncated (if necessary) to
+// type ITy. Any new instructions are inserted at InsertBefore.
+template<typename InsertType>
+static Value *getZExtOrTrunc(Value *Val, const IntegerType *ITy,
+                             InsertType InsertPoint) {
+  const IntegerType *ValTy = cast<IntegerType>(Val->getType());
+  if (ValTy == ITy)
+    return Val;
+  Constant *CVal = dyn_cast<Constant>(Val);
+  if (ValTy->getBitWidth() < ITy->getBitWidth()) {
+    if (CVal)
+      return ConstantExpr::getZExt(CVal, ITy);
+    return new ZExtInst(Val, ITy, "", InsertPoint);
+  } else {
+    if (CVal)
+      return ConstantExpr::getTrunc(CVal, ITy);
+    return new TruncInst(Val, ITy, "", InsertPoint);
+  }
+}
+
 template <class ArgIt>
 static void EnsureFunctionExists(Module &M, const char *Name,
                                  ArgIt ArgBegin, ArgIt ArgEnd,
@@ -504,7 +524,6 @@ static Instruction *LowerPartSet(CallInst *CI) {
     // Get some types we need
     const IntegerType* ValTy = cast<IntegerType>(Val->getType());
     const IntegerType* RepTy = cast<IntegerType>(Rep->getType());
-    uint32_t ValBits = ValTy->getBitWidth();
     uint32_t RepBits = RepTy->getBitWidth();
 
     // Constant Definitions
@@ -532,13 +551,7 @@ static Instruction *LowerPartSet(CallInst *CI) {
     BinaryOperator* NumBits = BinaryOperator::CreateSub(Hi_pn, Lo_pn, "",entry);
     NumBits = BinaryOperator::CreateAdd(NumBits, One, "", entry);
     // Now, convert Lo and Hi to ValTy bit width
-    if (ValBits > 32) {
-      Lo = new ZExtInst(Lo_pn, ValTy, "", entry);
-    } else if (ValBits < 32) {
-      Lo = new TruncInst(Lo_pn, ValTy, "", entry);
-    } else {
-      Lo = Lo_pn;
-    }
+    Lo = getZExtOrTrunc(Lo_pn, ValTy, entry);
     // Determine if the replacement bits are larger than the number of bits we
     // are replacing and deal with it.
     ICmpInst* is_large = 
@@ -560,11 +573,7 @@ static Instruction *LowerPartSet(CallInst *CI) {
     Rep3->reserveOperandSpace(2);
     Rep3->addIncoming(Rep2, large);
     Rep3->addIncoming(Rep, entry);
-    Value* Rep4 = Rep3;
-    if (ValBits > RepBits)
-      Rep4 = new ZExtInst(Rep3, ValTy, "", small);
-    else if (ValBits < RepBits)
-      Rep4 = new TruncInst(Rep3, ValTy, "", small);
+    Value* Rep4 = getZExtOrTrunc(Rep3, ValTy, small);
     BranchInst::Create(result, reverse, is_forward, small);
 
     // BASIC BLOCK: reverse (reverses the bits of the replacement)
@@ -788,14 +797,8 @@ void IntrinsicLowering::LowerIntrinsicCall(CallInst *CI) {
     
   case Intrinsic::memcpy: {
     static Constant *MemcpyFCache = 0;
-    Value *Size = CI->getOperand(3);
-    const Type *IntPtr = TD.getIntPtrType();
-    if (Size->getType()->getPrimitiveSizeInBits() <
-        IntPtr->getPrimitiveSizeInBits())
-      Size = new ZExtInst(Size, IntPtr, "", CI);
-    else if (Size->getType()->getPrimitiveSizeInBits() >
-             IntPtr->getPrimitiveSizeInBits())
-      Size = new TruncInst(Size, IntPtr, "", CI);
+    const IntegerType *IntPtr = TD.getIntPtrType();
+    Value *Size = getZExtOrTrunc(CI->getOperand(3), IntPtr, CI);
     Value *Ops[3];
     Ops[0] = CI->getOperand(1);
     Ops[1] = CI->getOperand(2);
@@ -806,14 +809,8 @@ void IntrinsicLowering::LowerIntrinsicCall(CallInst *CI) {
   }
   case Intrinsic::memmove: {
     static Constant *MemmoveFCache = 0;
-    Value *Size = CI->getOperand(3);
-    const Type *IntPtr = TD.getIntPtrType();
-    if (Size->getType()->getPrimitiveSizeInBits() <
-        IntPtr->getPrimitiveSizeInBits())
-      Size = new ZExtInst(Size, IntPtr, "", CI);
-    else if (Size->getType()->getPrimitiveSizeInBits() >
-             IntPtr->getPrimitiveSizeInBits())
-      Size = new TruncInst(Size, IntPtr, "", CI);
+    const IntegerType *IntPtr = TD.getIntPtrType();
+    Value *Size = getZExtOrTrunc(CI->getOperand(3), IntPtr, CI);
     Value *Ops[3];
     Ops[0] = CI->getOperand(1);
     Ops[1] = CI->getOperand(2);
@@ -824,18 +821,12 @@ void IntrinsicLowering::LowerIntrinsicCall(CallInst *CI) {
   }
   case Intrinsic::memset: {
     static Constant *MemsetFCache = 0;
-    Value *Size = CI->getOperand(3);
-    const Type *IntPtr = TD.getIntPtrType();
-    if (Size->getType()->getPrimitiveSizeInBits() <
-        IntPtr->getPrimitiveSizeInBits())
-      Size = new ZExtInst(Size, IntPtr, "", CI);
-    else if (Size->getType()->getPrimitiveSizeInBits() >
-             IntPtr->getPrimitiveSizeInBits())
-      Size = new TruncInst(Size, IntPtr, "", CI);
+    const IntegerType *IntPtr = TD.getIntPtrType();
+    Value *Size = getZExtOrTrunc(CI->getOperand(3), IntPtr, CI);
     Value *Ops[3];
     Ops[0] = CI->getOperand(1);
     // Extend the amount to i32.
-    Ops[1] = new ZExtInst(CI->getOperand(2), Type::Int32Ty, "", CI);
+    Ops[1] = getZExtOrTrunc(CI->getOperand(2), Type::Int32Ty, CI);
     Ops[2] = Size;
     ReplaceCallWith("memset", CI, Ops, Ops+3, CI->getOperand(1)->getType(),
                     MemsetFCache);