reapply r148901 with a crucial fix.
[oota-llvm.git] / lib / VMCore / Constants.cpp
index 9f1abbd17bbd186338243cc31351e85becdb2a9d..dd13ace600e9914c265887c82c70d82ccea67fdc 100644 (file)
@@ -129,7 +129,7 @@ Constant *Constant::getIntegerValue(Type *Ty, const APInt &V) {
 
   // Broadcast a scalar to a vector, if necessary.
   if (VectorType *VTy = dyn_cast<VectorType>(Ty))
-    C = ConstantVector::get(std::vector<Constant *>(VTy->getNumElements(), C));
+    C = ConstantVector::getSplat(VTy->getNumElements(), C);
 
   return C;
 }
@@ -145,11 +145,9 @@ Constant *Constant::getAllOnesValue(Type *Ty) {
     return ConstantFP::get(Ty->getContext(), FL);
   }
 
-  SmallVector<Constant*, 16> Elts;
   VectorType *VTy = cast<VectorType>(Ty);
-  Elts.resize(VTy->getNumElements(), getAllOnesValue(VTy->getElementType()));
-  assert(Elts[0] && "Invalid AllOnes value!");
-  return cast<ConstantVector>(ConstantVector::get(Elts));
+  return ConstantVector::getSplat(VTy->getNumElements(),
+                                  getAllOnesValue(VTy->getElementType()));
 }
 
 void Constant::destroyConstantImpl() {
@@ -394,9 +392,8 @@ Constant *ConstantInt::getTrue(Type *Ty) {
   }
   assert(VTy->getElementType()->isIntegerTy(1) &&
          "True must be vector of i1 or i1.");
-  SmallVector<Constant*, 16> Splat(VTy->getNumElements(),
-                                   ConstantInt::getTrue(Ty->getContext()));
-  return ConstantVector::get(Splat);
+  return ConstantVector::getSplat(VTy->getNumElements(),
+                                  ConstantInt::getTrue(Ty->getContext()));
 }
 
 Constant *ConstantInt::getFalse(Type *Ty) {
@@ -407,9 +404,8 @@ Constant *ConstantInt::getFalse(Type *Ty) {
   }
   assert(VTy->getElementType()->isIntegerTy(1) &&
          "False must be vector of i1 or i1.");
-  SmallVector<Constant*, 16> Splat(VTy->getNumElements(),
-                                   ConstantInt::getFalse(Ty->getContext()));
-  return ConstantVector::get(Splat);
+  return ConstantVector::getSplat(VTy->getNumElements(),
+                                  ConstantInt::getFalse(Ty->getContext()));
 }
 
 
@@ -433,8 +429,7 @@ Constant *ConstantInt::get(Type *Ty, uint64_t V, bool isSigned) {
 
   // For vectors, broadcast the value.
   if (VectorType *VTy = dyn_cast<VectorType>(Ty))
-    return ConstantVector::get(SmallVector<Constant*,
-                                           16>(VTy->getNumElements(), C));
+    return ConstantVector::getSplat(VTy->getNumElements(), C);
 
   return C;
 }
@@ -459,8 +454,7 @@ Constant *ConstantInt::get(Type* Ty, const APInt& V) {
 
   // For vectors, broadcast the value.
   if (VectorType *VTy = dyn_cast<VectorType>(Ty))
-    return ConstantVector::get(
-      SmallVector<Constant *, 16>(VTy->getNumElements(), C));
+    return ConstantVector::getSplat(VTy->getNumElements(), C);
 
   return C;
 }
@@ -506,8 +500,7 @@ Constant *ConstantFP::get(Type* Ty, double V) {
 
   // For vectors, broadcast the value.
   if (VectorType *VTy = dyn_cast<VectorType>(Ty))
-    return ConstantVector::get(
-      SmallVector<Constant *, 16>(VTy->getNumElements(), C));
+    return ConstantVector::getSplat(VTy->getNumElements(), C);
 
   return C;
 }
@@ -521,31 +514,28 @@ Constant *ConstantFP::get(Type* Ty, StringRef Str) {
 
   // For vectors, broadcast the value.
   if (VectorType *VTy = dyn_cast<VectorType>(Ty))
-    return ConstantVector::get(
-      SmallVector<Constant *, 16>(VTy->getNumElements(), C));
+    return ConstantVector::getSplat(VTy->getNumElements(), C);
 
   return C; 
 }
 
 
-ConstantFP* ConstantFP::getNegativeZero(Type* Ty) {
+ConstantFP *ConstantFP::getNegativeZero(Type *Ty) {
   LLVMContext &Context = Ty->getContext();
-  APFloat apf = cast <ConstantFP>(Constant::getNullValue(Ty))->getValueAPF();
+  APFloat apf = cast<ConstantFP>(Constant::getNullValue(Ty))->getValueAPF();
   apf.changeSign();
   return get(Context, apf);
 }
 
 
-Constant *ConstantFP::getZeroValueForNegation(Type* Ty) {
-  if (VectorType *PTy = dyn_cast<VectorType>(Ty))
-    if (PTy->getElementType()->isFloatingPointTy()) {
-      SmallVector<Constant*, 16> zeros(PTy->getNumElements(),
-                           getNegativeZero(PTy->getElementType()));
-      return ConstantVector::get(zeros);
-    }
-
-  if (Ty->isFloatingPointTy()) 
-    return getNegativeZero(Ty);
+Constant *ConstantFP::getZeroValueForNegation(Type *Ty) {
+  Type *ScalarTy = Ty->getScalarType();
+  if (ScalarTy->isFloatingPointTy()) {
+    Constant *C = getNegativeZero(ScalarTy);
+    if (VectorType *VTy = dyn_cast<VectorType>(Ty))
+      return ConstantVector::getSplat(VTy->getNumElements(), C);
+    return C;
+  }
 
   return Constant::getNullValue(Ty);
 }
@@ -718,9 +708,8 @@ Constant *ConstantArray::get(LLVMContext &Context, StringRef Str,
     ElementVals.push_back(ConstantInt::get(Type::getInt8Ty(Context), Str[i]));
 
   // Add a null terminator to the string...
-  if (AddNull) {
+  if (AddNull)
     ElementVals.push_back(ConstantInt::get(Type::getInt8Ty(Context), 0));
-  }
 
   ArrayType *ATy = ArrayType::get(Type::getInt8Ty(Context), ElementVals.size());
   return get(ATy, ElementVals);
@@ -819,6 +808,12 @@ Constant *ConstantVector::get(ArrayRef<Constant*> V) {
   return pImpl->VectorConstants.getOrCreate(T, V);
 }
 
+Constant *ConstantVector::getSplat(unsigned NumElts, Constant *V) {
+  SmallVector<Constant*, 32> Elts(NumElts, V);
+  return get(Elts);
+}
+
+
 // Utility function for determining if a ConstantExpr is a CastOp or not. This
 // can't be inline because we don't want to #include Instruction.h into
 // Constant.h
@@ -1506,8 +1501,11 @@ Constant *ConstantExpr::getPtrToInt(Constant *C, Type *DstTy) {
          "PtrToInt source must be pointer or pointer vector");
   assert(DstTy->getScalarType()->isIntegerTy() && 
          "PtrToInt destination must be integer or integer vector");
-  assert(C->getType()->getNumElements() == DstTy->getNumElements() &&
-    "Invalid cast between a different number of vector elements");
+  assert(isa<VectorType>(C->getType()) == isa<VectorType>(DstTy));
+  if (isa<VectorType>(C->getType()))
+    assert(cast<VectorType>(C->getType())->getNumElements() ==
+           cast<VectorType>(DstTy)->getNumElements() &&
+           "Invalid cast between a different number of vector elements");
   return getFoldedCast(Instruction::PtrToInt, C, DstTy);
 }
 
@@ -1516,8 +1514,11 @@ Constant *ConstantExpr::getIntToPtr(Constant *C, Type *DstTy) {
          "IntToPtr source must be integer or integer vector");
   assert(DstTy->getScalarType()->isPointerTy() &&
          "IntToPtr destination must be a pointer or pointer vector");
-  assert(C->getType()->getNumElements() == DstTy->getNumElements() &&
-    "Invalid cast between a different number of vector elements");
+  assert(isa<VectorType>(C->getType()) == isa<VectorType>(DstTy));
+  if (isa<VectorType>(C->getType()))
+    assert(cast<VectorType>(C->getType())->getNumElements() ==
+           cast<VectorType>(DstTy)->getNumElements() &&
+           "Invalid cast between a different number of vector elements");
   return getFoldedCast(Instruction::IntToPtr, C, DstTy);
 }
 
@@ -1995,8 +1996,7 @@ Type *ConstantDataSequential::getElementType() const {
 }
 
 StringRef ConstantDataSequential::getRawDataValues() const {
-  return StringRef(DataElements,
-                   getType()->getNumElements()*getElementByteSize());
+  return StringRef(DataElements, getNumElements()*getElementByteSize());
 }
 
 /// isElementTypeCompatible - Return true if a ConstantDataSequential can be
@@ -2018,6 +2018,14 @@ bool ConstantDataSequential::isElementTypeCompatible(const Type *Ty) {
   return false;
 }
 
+/// getNumElements - Return the number of elements in the array or vector.
+unsigned ConstantDataSequential::getNumElements() const {
+  if (ArrayType *AT = dyn_cast<ArrayType>(getType()))
+    return AT->getNumElements();
+  return cast<VectorType>(getType())->getNumElements();
+}
+
+
 /// getElementByteSize - Return the size in bytes of the elements in the data.
 uint64_t ConstantDataSequential::getElementByteSize() const {
   return getElementType()->getPrimitiveSizeInBits()/8;
@@ -2025,7 +2033,7 @@ uint64_t ConstantDataSequential::getElementByteSize() const {
 
 /// getElementPointer - Return the start of the specified element.
 const char *ConstantDataSequential::getElementPointer(unsigned Elt) const {
-  assert(Elt < getElementType()->getNumElements() && "Invalid Elt");
+  assert(Elt < getNumElements() && "Invalid Elt");
   return DataElements+Elt*getElementByteSize();
 }
 
@@ -2044,7 +2052,8 @@ static bool isAllZeros(StringRef Arr) {
 /// we *want* an underlying "char*" to avoid TBAA type punning violations.
 Constant *ConstantDataSequential::getImpl(StringRef Elements, Type *Ty) {
   assert(isElementTypeCompatible(cast<SequentialType>(Ty)->getElementType()));
-  // If the elements are all zero, return a CAZ, which is more dense.
+  // If the elements are all zero or there are no elements, return a CAZ, which
+  // is more dense and canonical.
   if (isAllZeros(Elements))
     return ConstantAggregateZero::get(Ty);
 
@@ -2114,60 +2123,107 @@ void ConstantDataSequential::destroyConstant() {
 /// get() constructors - Return a constant with array type with an element
 /// count and element type matching the ArrayRef passed in.  Note that this
 /// can return a ConstantAggregateZero object.
-Constant *ConstantDataArray::get(ArrayRef<uint8_t> Elts, LLVMContext &Context) {
+Constant *ConstantDataArray::get(LLVMContext &Context, ArrayRef<uint8_t> Elts) {
   Type *Ty = ArrayType::get(Type::getInt8Ty(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*1), Ty);
 }
-Constant *ConstantDataArray::get(ArrayRef<uint16_t> Elts, LLVMContext &Context){
+Constant *ConstantDataArray::get(LLVMContext &Context, ArrayRef<uint16_t> Elts){
   Type *Ty = ArrayType::get(Type::getInt16Ty(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*2), Ty);
 }
-Constant *ConstantDataArray::get(ArrayRef<uint32_t> Elts, LLVMContext &Context){
+Constant *ConstantDataArray::get(LLVMContext &Context, ArrayRef<uint32_t> Elts){
   Type *Ty = ArrayType::get(Type::getInt32Ty(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*4), Ty);
 }
-Constant *ConstantDataArray::get(ArrayRef<uint64_t> Elts, LLVMContext &Context){
+Constant *ConstantDataArray::get(LLVMContext &Context, ArrayRef<uint64_t> Elts){
   Type *Ty = ArrayType::get(Type::getInt64Ty(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*8), Ty);
 }
-Constant *ConstantDataArray::get(ArrayRef<float> Elts, LLVMContext &Context) {
+Constant *ConstantDataArray::get(LLVMContext &Context, ArrayRef<float> Elts) {
   Type *Ty = ArrayType::get(Type::getFloatTy(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*4), Ty);
 }
-Constant *ConstantDataArray::get(ArrayRef<double> Elts, LLVMContext &Context) {
+Constant *ConstantDataArray::get(LLVMContext &Context, ArrayRef<double> Elts) {
   Type *Ty = ArrayType::get(Type::getDoubleTy(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*8), Ty);
 }
 
+/// getString - This method constructs a CDS and initializes it with a text
+/// string. The default behavior (AddNull==true) causes a null terminator to
+/// be placed at the end of the array (increasing the length of the string by
+/// one more than the StringRef would normally indicate.  Pass AddNull=false
+/// to disable this behavior.
+Constant *ConstantDataArray::getString(LLVMContext &Context,
+                                       StringRef Str, bool AddNull) {
+  if (!AddNull)
+    return get(Context, ArrayRef<uint8_t>((uint8_t*)Str.data(), Str.size()));
+  
+  SmallVector<uint8_t, 64> ElementVals;
+  ElementVals.append(Str.begin(), Str.end());
+  ElementVals.push_back(0);
+  return get(Context, ElementVals);
+}
 
 /// get() constructors - Return a constant with vector type with an element
 /// count and element type matching the ArrayRef passed in.  Note that this
 /// can return a ConstantAggregateZero object.
-Constant *ConstantDataVector::get(ArrayRef<uint8_t> Elts, LLVMContext &Context) {
+Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<uint8_t> Elts){
   Type *Ty = VectorType::get(Type::getInt8Ty(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*1), Ty);
 }
-Constant *ConstantDataVector::get(ArrayRef<uint16_t> Elts, LLVMContext &Context){
+Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<uint16_t> Elts){
   Type *Ty = VectorType::get(Type::getInt16Ty(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*2), Ty);
 }
-Constant *ConstantDataVector::get(ArrayRef<uint32_t> Elts, LLVMContext &Context){
+Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<uint32_t> Elts){
   Type *Ty = VectorType::get(Type::getInt32Ty(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*4), Ty);
 }
-Constant *ConstantDataVector::get(ArrayRef<uint64_t> Elts, LLVMContext &Context){
+Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<uint64_t> Elts){
   Type *Ty = VectorType::get(Type::getInt64Ty(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*8), Ty);
 }
-Constant *ConstantDataVector::get(ArrayRef<float> Elts, LLVMContext &Context) {
+Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<float> Elts) {
   Type *Ty = VectorType::get(Type::getFloatTy(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*4), Ty);
 }
-Constant *ConstantDataVector::get(ArrayRef<double> Elts, LLVMContext &Context) {
+Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<double> Elts) {
   Type *Ty = VectorType::get(Type::getDoubleTy(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*8), Ty);
 }
 
+Constant *ConstantDataVector::getSplat(unsigned NumElts, Constant *V) {
+  assert(isElementTypeCompatible(V->getType()) &&
+         "Element type not compatible with ConstantData");
+  if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
+    if (CI->getType()->isIntegerTy(8)) {
+      SmallVector<uint8_t, 16> Elts(NumElts, CI->getZExtValue());
+      return get(V->getContext(), Elts);
+    }
+    if (CI->getType()->isIntegerTy(16)) {
+      SmallVector<uint16_t, 16> Elts(NumElts, CI->getZExtValue());
+      return get(V->getContext(), Elts);
+    }
+    if (CI->getType()->isIntegerTy(32)) {
+      SmallVector<uint32_t, 16> Elts(NumElts, CI->getZExtValue());
+      return get(V->getContext(), Elts);
+    }
+    assert(CI->getType()->isIntegerTy(64) && "Unsupported ConstantData type");
+    SmallVector<uint64_t, 16> Elts(NumElts, CI->getZExtValue());
+    return get(V->getContext(), Elts);
+  }
+
+  ConstantFP *CFP = cast<ConstantFP>(V);
+  if (CFP->getType()->isFloatTy()) {
+    SmallVector<float, 16> Elts(NumElts, CFP->getValueAPF().convertToFloat());
+    return get(V->getContext(), Elts);
+  }
+  assert(CFP->getType()->isDoubleTy() && "Unsupported ConstantData type");
+  SmallVector<double, 16> Elts(NumElts, CFP->getValueAPF().convertToDouble());
+  return get(V->getContext(), Elts);
+}
+
+
 /// getElementAsInteger - If this is a sequential container of integers (of
 /// any size), return the specified element in the low bits of a uint64_t.
 uint64_t ConstantDataSequential::getElementAsInteger(unsigned Elt) const {
@@ -2192,7 +2248,8 @@ APFloat ConstantDataSequential::getElementAsAPFloat(unsigned Elt) const {
   const char *EltPtr = getElementPointer(Elt);
 
   switch (getElementType()->getTypeID()) {
-  default: assert("Accessor can only be used when element is float/double!");
+  default:
+    assert(0 && "Accessor can only be used when element is float/double!");
   case Type::FloatTyID: return APFloat(*(float*)EltPtr);
   case Type::DoubleTyID: return APFloat(*(double*)EltPtr);
   }