Type safety for Constants.cpp! Some of this is temporary, as I'm planning to push...
[oota-llvm.git] / lib / VMCore / ConstantFold.cpp
index 9fa41b373cd7ee65a3c4873d39fe59255c064bc9..6c392145a504b671a5e7cbfb119c52ccd82ac27f 100644 (file)
@@ -153,26 +153,20 @@ static Constant *FoldBitCast(Constant *V, const Type *DestTy) {
       // Integral -> Integral. This is a no-op because the bit widths must
       // be the same. Consequently, we just fold to V.
       return V;
-    
-    if (DestTy->isFloatingPoint()) {
-      assert((DestTy == Type::DoubleTy || DestTy == Type::FloatTy) && 
-             "Unknown FP type!");
-      return ConstantFP::get(APFloat(CI->getValue()));
-    }
+
+    if (DestTy->isFloatingPoint())
+      return ConstantFP::get(APFloat(CI->getValue(),
+                                     DestTy != Type::PPC_FP128Ty));
+
     // Otherwise, can't fold this (vector?)
     return 0;
   }
-  
+
   // Handle ConstantFP input.
-  if (const ConstantFP *FP = dyn_cast<ConstantFP>(V)) {
+  if (const ConstantFP *FP = dyn_cast<ConstantFP>(V))
     // FP -> Integral.
-    if (DestTy == Type::Int32Ty) {
-      return ConstantInt::get(FP->getValueAPF().bitcastToAPInt());
-    } else {
-      assert(DestTy == Type::Int64Ty && "only support f32/f64 for now!");
-      return ConstantInt::get(FP->getValueAPF().bitcastToAPInt());
-    }
-  }
+    return ConstantInt::get(FP->getValueAPF().bitcastToAPInt());
+
   return 0;
 }
 
@@ -214,6 +208,22 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, const Constant *V,
     }
   }
 
+  // If the cast operand is a constant vector, perform the cast by
+  // operating on each element. In the cast of bitcasts, the element
+  // count may be mismatched; don't attempt to handle that here.
+  if (const ConstantVector *CV = dyn_cast<ConstantVector>(V))
+    if (isa<VectorType>(DestTy) &&
+        cast<VectorType>(DestTy)->getNumElements() ==
+        CV->getType()->getNumElements()) {
+      std::vector<Constant*> res;
+      const VectorType *DestVecTy = cast<VectorType>(DestTy);
+      const Type *DstEltTy = DestVecTy->getElementType();
+      for (unsigned i = 0, e = CV->getType()->getNumElements(); i != e; ++i)
+        res.push_back(ConstantExpr::getCast(opc,
+                                            CV->getOperand(i), DstEltTy));
+      return ConstantVector::get(DestVecTy, res);
+    }
+
   // We actually have to do a cast now. Perform the cast according to the
   // opcode specified.
   switch (opc) {
@@ -243,14 +253,6 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, const Constant *V,
       APInt Val(DestBitWidth, 2, x);
       return ConstantInt::get(Val);
     }
-    if (const ConstantVector *CV = dyn_cast<ConstantVector>(V)) {
-      std::vector<Constant*> res;
-      const VectorType *DestVecTy = cast<VectorType>(DestTy);
-      const Type *DstEltTy = DestVecTy->getElementType();
-      for (unsigned i = 0, e = CV->getType()->getNumElements(); i != e; ++i)
-        res.push_back(ConstantExpr::getCast(opc, CV->getOperand(i), DstEltTy));
-      return ConstantVector::get(DestVecTy, res);
-    }
     return 0; // Can't fold.
   case Instruction::IntToPtr:   //always treated as unsigned
     if (V->isNullValue())       // Is it an integral null value?
@@ -272,14 +274,6 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, const Constant *V,
                                  APFloat::rmNearestTiesToEven);
       return ConstantFP::get(apf);
     }
-    if (const ConstantVector *CV = dyn_cast<ConstantVector>(V)) {
-      std::vector<Constant*> res;
-      const VectorType *DestVecTy = cast<VectorType>(DestTy);
-      const Type *DstEltTy = DestVecTy->getElementType();
-      for (unsigned i = 0, e = CV->getType()->getNumElements(); i != e; ++i)
-        res.push_back(ConstantExpr::getCast(opc, CV->getOperand(i), DstEltTy));
-      return ConstantVector::get(DestVecTy, res);
-    }
     return 0;
   case Instruction::ZExt:
     if (const ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
@@ -608,10 +602,8 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
       return Constant::getNullValue(C1->getType());
     case Instruction::UDiv:
     case Instruction::SDiv:
-    case Instruction::FDiv:
     case Instruction::URem:
     case Instruction::SRem:
-    case Instruction::FRem:
       if (!isa<UndefValue>(C2))                    // undef / X -> 0
         return Constant::getNullValue(C1->getType());
       return const_cast<Constant*>(C2);            // X / undef -> undef
@@ -655,11 +647,15 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
     case Instruction::SDiv:
       if (CI2->equalsInt(1))
         return const_cast<Constant*>(C1);                     // X / 1 == X
+      if (CI2->equalsInt(0))
+        return UndefValue::get(CI2->getType());               // X / 0 == undef
       break;
     case Instruction::URem:
     case Instruction::SRem:
       if (CI2->equalsInt(1))
         return Constant::getNullValue(CI2->getType());        // X % 1 == 0
+      if (CI2->equalsInt(0))
+        return UndefValue::get(CI2->getType());               // X % 0 == undef
       break;
     case Instruction::And:
       if (CI2->isZero()) return const_cast<Constant*>(C2);    // X & 0 == 0
@@ -733,24 +729,20 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
       case Instruction::Mul:     
         return ConstantInt::get(C1V * C2V);
       case Instruction::UDiv:
-        if (CI2->isNullValue())                  
-          return 0;        // X / 0 -> can't fold
+        assert(!CI2->isNullValue() && "Div by zero handled above");
         return ConstantInt::get(C1V.udiv(C2V));
       case Instruction::SDiv:
-        if (CI2->isNullValue()) 
-          return 0;        // X / 0 -> can't fold
+        assert(!CI2->isNullValue() && "Div by zero handled above");
         if (C2V.isAllOnesValue() && C1V.isMinSignedValue())
-          return 0;        // MIN_INT / -1 -> overflow
+          return UndefValue::get(CI1->getType());   // MIN_INT / -1 -> undef
         return ConstantInt::get(C1V.sdiv(C2V));
       case Instruction::URem:
-        if (C2->isNullValue()) 
-          return 0;        // X / 0 -> can't fold
+        assert(!CI2->isNullValue() && "Div by zero handled above");
         return ConstantInt::get(C1V.urem(C2V));
-      case Instruction::SRem:    
-        if (CI2->isNullValue()) 
-          return 0;        // X % 0 -> can't fold
+      case Instruction::SRem:
+        assert(!CI2->isNullValue() && "Div by zero handled above");
         if (C2V.isAllOnesValue() && C1V.isMinSignedValue())
-          return 0;        // MIN_INT % -1 -> overflow
+          return UndefValue::get(CI1->getType());   // MIN_INT % -1 -> undef
         return ConstantInt::get(C1V.srem(C2V));
       case Instruction::And:
         return ConstantInt::get(C1V & C2V);
@@ -789,29 +781,19 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
       switch (Opcode) {
       default:                   
         break;
-      case Instruction::Add:
+      case Instruction::FAdd:
         (void)C3V.add(C2V, APFloat::rmNearestTiesToEven);
         return ConstantFP::get(C3V);
-      case Instruction::Sub:     
+      case Instruction::FSub:
         (void)C3V.subtract(C2V, APFloat::rmNearestTiesToEven);
         return ConstantFP::get(C3V);
-      case Instruction::Mul:
+      case Instruction::FMul:
         (void)C3V.multiply(C2V, APFloat::rmNearestTiesToEven);
         return ConstantFP::get(C3V);
       case Instruction::FDiv:
         (void)C3V.divide(C2V, APFloat::rmNearestTiesToEven);
         return ConstantFP::get(C3V);
       case Instruction::FRem:
-        if (C2V.isZero()) {
-          // IEEE 754, Section 7.1, #5
-          if (CFP1->getType() == Type::DoubleTy)
-            return ConstantFP::get(APFloat(std::numeric_limits<double>::
-                                           quiet_NaN()));
-          if (CFP1->getType() == Type::FloatTy)
-            return ConstantFP::get(APFloat(std::numeric_limits<float>::
-                                           quiet_NaN()));
-          break;
-        }
         (void)C3V.mod(C2V, APFloat::rmNearestTiesToEven);
         return ConstantFP::get(C3V);
       }
@@ -824,12 +806,18 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
       switch (Opcode) {
       default:
         break;
-      case Instruction::Add: 
+      case Instruction::Add:
         return EvalVectorOp(CP1, CP2, VTy, ConstantExpr::getAdd);
-      case Instruction::Sub: 
+      case Instruction::FAdd:
+        return EvalVectorOp(CP1, CP2, VTy, ConstantExpr::getFAdd);
+      case Instruction::Sub:
         return EvalVectorOp(CP1, CP2, VTy, ConstantExpr::getSub);
-      case Instruction::Mul: 
+      case Instruction::FSub:
+        return EvalVectorOp(CP1, CP2, VTy, ConstantExpr::getFSub);
+      case Instruction::Mul:
         return EvalVectorOp(CP1, CP2, VTy, ConstantExpr::getMul);
+      case Instruction::FMul:
+        return EvalVectorOp(CP1, CP2, VTy, ConstantExpr::getFMul);
       case Instruction::UDiv:
         return EvalVectorOp(CP1, CP2, VTy, ConstantExpr::getUDiv);
       case Instruction::SDiv:
@@ -848,6 +836,12 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
         return EvalVectorOp(CP1, CP2, VTy, ConstantExpr::getOr);
       case Instruction::Xor: 
         return EvalVectorOp(CP1, CP2, VTy, ConstantExpr::getXor);
+      case Instruction::LShr:
+        return EvalVectorOp(CP1, CP2, VTy, ConstantExpr::getLShr);
+      case Instruction::AShr:
+        return EvalVectorOp(CP1, CP2, VTy, ConstantExpr::getAShr);
+      case Instruction::Shl:
+        return EvalVectorOp(CP1, CP2, VTy, ConstantExpr::getShl);
       }
     }
   }
@@ -861,7 +855,9 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
     // other way if possible.
     switch (Opcode) {
     case Instruction::Add:
+    case Instruction::FAdd:
     case Instruction::Mul:
+    case Instruction::FMul:
     case Instruction::And:
     case Instruction::Or:
     case Instruction::Xor:
@@ -872,6 +868,7 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
     case Instruction::LShr:
     case Instruction::AShr:
     case Instruction::Sub:
+    case Instruction::FSub:
     case Instruction::SDiv:
     case Instruction::UDiv:
     case Instruction::FDiv:
@@ -1680,7 +1677,7 @@ Constant *llvm::ConstantFoldGetElementPtr(const Constant *C,
         Offset = ConstantExpr::getSExt(Offset, Base->getType());
       else if (Base->getType()->getPrimitiveSizeInBits() <
                Offset->getType()->getPrimitiveSizeInBits())
-        Base = ConstantExpr::getZExt(Base, Base->getType());
+        Base = ConstantExpr::getZExt(Base, Offset->getType());
       
       Base = ConstantExpr::getAdd(Base, Offset);
       return ConstantExpr::getIntToPtr(Base, CE->getType());