s/isReturnStruct()/hasStructRetAttr()/g
[oota-llvm.git] / lib / VMCore / ConstantFold.cpp
index 6660c8d3e2f3168242aaaf29c3d02c96384810aa..762a24a26330165ccee1d95cd5adbcc62b585c4b 100644 (file)
@@ -2,8 +2,8 @@
 //
 //                     The LLVM Compiler Infrastructure
 //
-// This file was developed by the LLVM research group and is distributed under
-// the University of Illinois Open Source License. See LICENSE.TXT for details.
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
 //
 //===----------------------------------------------------------------------===//
 //
@@ -170,12 +170,12 @@ static Constant *FoldBitCast(Constant *V, const Type *DestTy) {
 
 Constant *llvm::ConstantFoldCastInstruction(unsigned opc, const Constant *V,
                                             const Type *DestTy) {
-  const Type *SrcTy = V->getType();
-
   if (isa<UndefValue>(V)) {
     // zext(undef) = 0, because the top bits will be zero.
     // sext(undef) = 0, because the top bits will all be the same.
-    if (opc == Instruction::ZExt || opc == Instruction::SExt)
+    // [us]itofp(undef) = 0, because the result value is bounded.
+    if (opc == Instruction::ZExt || opc == Instruction::SExt ||
+        opc == Instruction::UIToFP || opc == Instruction::SIToFP)
       return Constant::getNullValue(DestTy);
     return UndefValue::get(DestTy);
   }
@@ -255,12 +255,11 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, const Constant *V,
     if (const ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
       APInt api = CI->getValue();
       const uint64_t zero[] = {0, 0};
-      uint32_t BitWidth = cast<IntegerType>(SrcTy)->getBitWidth();
       APFloat apf = APFloat(APInt(DestTy->getPrimitiveSizeInBits(),
                                   2, zero));
-      (void)apf.convertFromZeroExtendedInteger(api.getRawData(), BitWidth
-                                   opc==Instruction::SIToFP,
-                                   APFloat::rmNearestTiesToEven);
+      (void)apf.convertFromAPInt(api
+                                 opc==Instruction::SIToFP,
+                                 APFloat::rmNearestTiesToEven);
       return ConstantFP::get(DestTy, apf);
     }
     if (const ConstantVector *CV = dyn_cast<ConstantVector>(V)) {
@@ -396,11 +395,54 @@ Constant *llvm::ConstantFoldInsertElementInstruction(const Constant *Val,
   return 0;
 }
 
+/// GetVectorElement - If C is a ConstantVector, ConstantAggregateZero or Undef
+/// return the specified element value.  Otherwise return null.
+static Constant *GetVectorElement(const Constant *C, unsigned EltNo) {
+  if (const ConstantVector *CV = dyn_cast<ConstantVector>(C))
+    return const_cast<Constant*>(CV->getOperand(EltNo));
+  
+  const Type *EltTy = cast<VectorType>(C->getType())->getElementType();
+  if (isa<ConstantAggregateZero>(C))
+    return Constant::getNullValue(EltTy);
+  if (isa<UndefValue>(C))
+    return UndefValue::get(EltTy);
+  return 0;
+}
+
 Constant *llvm::ConstantFoldShuffleVectorInstruction(const Constant *V1,
                                                      const Constant *V2,
                                                      const Constant *Mask) {
-  // TODO:
-  return 0;
+  // Undefined shuffle mask -> undefined value.
+  if (isa<UndefValue>(Mask)) return UndefValue::get(V1->getType());
+  
+  unsigned NumElts = cast<VectorType>(V1->getType())->getNumElements();
+  const Type *EltTy = cast<VectorType>(V1->getType())->getElementType();
+  
+  // Loop over the shuffle mask, evaluating each element.
+  SmallVector<Constant*, 32> Result;
+  for (unsigned i = 0; i != NumElts; ++i) {
+    Constant *InElt = GetVectorElement(Mask, i);
+    if (InElt == 0) return 0;
+    
+    if (isa<UndefValue>(InElt))
+      InElt = UndefValue::get(EltTy);
+    else if (ConstantInt *CI = dyn_cast<ConstantInt>(InElt)) {
+      unsigned Elt = CI->getZExtValue();
+      if (Elt >= NumElts*2)
+        InElt = UndefValue::get(EltTy);
+      else if (Elt >= NumElts)
+        InElt = GetVectorElement(V2, Elt-NumElts);
+      else
+        InElt = GetVectorElement(V1, Elt);
+      if (InElt == 0) return 0;
+    } else {
+      // Unknown value.
+      return 0;
+    }
+    Result.push_back(InElt);
+  }
+  
+  return ConstantVector::get(&Result[0], Result.size());
 }
 
 /// EvalVectorOp - Given two vector constants and a function pointer, apply the
@@ -615,25 +657,28 @@ Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
       case Instruction::Xor:
         return ConstantInt::get(C1V ^ C2V);
       case Instruction::Shl:
-        if (uint32_t shiftAmt = C2V.getZExtValue())
+        if (uint32_t shiftAmt = C2V.getZExtValue()) {
           if (shiftAmt < C1V.getBitWidth())
             return ConstantInt::get(C1V.shl(shiftAmt));
           else
             return UndefValue::get(C1->getType()); // too big shift is undef
+        }
         return const_cast<ConstantInt*>(CI1); // Zero shift is identity
       case Instruction::LShr:
-        if (uint32_t shiftAmt = C2V.getZExtValue())
+        if (uint32_t shiftAmt = C2V.getZExtValue()) {
           if (shiftAmt < C1V.getBitWidth())
             return ConstantInt::get(C1V.lshr(shiftAmt));
           else
             return UndefValue::get(C1->getType()); // too big shift is undef
+        }
         return const_cast<ConstantInt*>(CI1); // Zero shift is identity
       case Instruction::AShr:
-        if (uint32_t shiftAmt = C2V.getZExtValue())
+        if (uint32_t shiftAmt = C2V.getZExtValue()) {
           if (shiftAmt < C1V.getBitWidth())
             return ConstantInt::get(C1V.ashr(shiftAmt));
           else
             return UndefValue::get(C1->getType()); // too big shift is undef
+        }
         return const_cast<ConstantInt*>(CI1); // Zero shift is identity
       }
     }
@@ -1038,18 +1083,20 @@ static ICmpInst::Predicate evaluateICmpRelation(const Constant *V1,
             // Ok, we ran out of things they have in common.  If any leftovers
             // are non-zero then we have a difference, otherwise we are equal.
             for (; i < CE1->getNumOperands(); ++i)
-              if (!CE1->getOperand(i)->isNullValue())
+              if (!CE1->getOperand(i)->isNullValue()) {
                 if (isa<ConstantInt>(CE1->getOperand(i)))
                   return isSigned ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT;
                 else
                   return ICmpInst::BAD_ICMP_PREDICATE; // Might be equal.
+              }
 
             for (; i < CE2->getNumOperands(); ++i)
-              if (!CE2->getOperand(i)->isNullValue())
+              if (!CE2->getOperand(i)->isNullValue()) {
                 if (isa<ConstantInt>(CE2->getOperand(i)))
                   return isSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
                 else
                   return ICmpInst::BAD_ICMP_PREDICATE; // Might be equal.
+              }
             return ICmpInst::ICMP_EQ;
           }
         }
@@ -1078,20 +1125,22 @@ Constant *llvm::ConstantFoldCompareInstruction(unsigned short pred,
   if (C1->isNullValue()) {
     if (const GlobalValue *GV = dyn_cast<GlobalValue>(C2))
       // Don't try to evaluate aliases.  External weak GV can be null.
-      if (!isa<GlobalAlias>(GV) && !GV->hasExternalWeakLinkage())
+      if (!isa<GlobalAlias>(GV) && !GV->hasExternalWeakLinkage()) {
         if (pred == ICmpInst::ICMP_EQ)
           return ConstantInt::getFalse();
         else if (pred == ICmpInst::ICMP_NE)
           return ConstantInt::getTrue();
+      }
   // icmp eq/ne(GV,null) -> false/true
   } else if (C2->isNullValue()) {
     if (const GlobalValue *GV = dyn_cast<GlobalValue>(C1))
       // Don't try to evaluate aliases.  External weak GV can be null.
-      if (!isa<GlobalAlias>(GV) && !GV->hasExternalWeakLinkage())
+      if (!isa<GlobalAlias>(GV) && !GV->hasExternalWeakLinkage()) {
         if (pred == ICmpInst::ICMP_EQ)
           return ConstantInt::getFalse();
         else if (pred == ICmpInst::ICMP_NE)
           return ConstantInt::getTrue();
+      }
   }
 
   if (isa<ConstantInt>(C1) && isa<ConstantInt>(C2)) {
@@ -1340,12 +1389,13 @@ Constant *llvm::ConstantFoldGetElementPtr(const Constant *C,
     return const_cast<Constant*>(C);
 
   if (isa<UndefValue>(C)) {
-    const Type *Ty = GetElementPtrInst::getIndexedType(C->getType(),
+    const PointerType *Ptr = cast<PointerType>(C->getType());
+    const Type *Ty = GetElementPtrInst::getIndexedType(Ptr,
                                                        (Value **)Idxs,
                                                        (Value **)Idxs+NumIdx,
                                                        true);
     assert(Ty != 0 && "Invalid indices for GEP!");
-    return UndefValue::get(PointerType::get(Ty));
+    return UndefValue::get(PointerType::get(Ty, Ptr->getAddressSpace()));
   }
 
   Constant *Idx0 = Idxs[0];
@@ -1357,12 +1407,14 @@ Constant *llvm::ConstantFoldGetElementPtr(const Constant *C,
         break;
       }
     if (isNull) {
-      const Type *Ty = GetElementPtrInst::getIndexedType(C->getType(),
+      const PointerType *Ptr = cast<PointerType>(C->getType());
+      const Type *Ty = GetElementPtrInst::getIndexedType(Ptr,
                                                          (Value**)Idxs,
                                                          (Value**)Idxs+NumIdx,
                                                          true);
       assert(Ty != 0 && "Invalid indices for GEP!");
-      return ConstantPointerNull::get(PointerType::get(Ty));
+      return 
+        ConstantPointerNull::get(PointerType::get(Ty,Ptr->getAddressSpace()));
     }
   }