reapply r148901 with a crucial fix.
[oota-llvm.git] / lib / VMCore / Constants.cpp
index df98d7586441d4a1d77b4b4812d0b164350f598e..dd13ace600e9914c265887c82c70d82ccea67fdc 100644 (file)
@@ -129,7 +129,7 @@ Constant *Constant::getIntegerValue(Type *Ty, const APInt &V) {
 
   // Broadcast a scalar to a vector, if necessary.
   if (VectorType *VTy = dyn_cast<VectorType>(Ty))
-    C = ConstantVector::get(std::vector<Constant *>(VTy->getNumElements(), C));
+    C = ConstantVector::getSplat(VTy->getNumElements(), C);
 
   return C;
 }
@@ -145,11 +145,9 @@ Constant *Constant::getAllOnesValue(Type *Ty) {
     return ConstantFP::get(Ty->getContext(), FL);
   }
 
-  SmallVector<Constant*, 16> Elts;
   VectorType *VTy = cast<VectorType>(Ty);
-  Elts.resize(VTy->getNumElements(), getAllOnesValue(VTy->getElementType()));
-  assert(Elts[0] && "Invalid AllOnes value!");
-  return cast<ConstantVector>(ConstantVector::get(Elts));
+  return ConstantVector::getSplat(VTy->getNumElements(),
+                                  getAllOnesValue(VTy->getElementType()));
 }
 
 void Constant::destroyConstantImpl() {
@@ -394,9 +392,8 @@ Constant *ConstantInt::getTrue(Type *Ty) {
   }
   assert(VTy->getElementType()->isIntegerTy(1) &&
          "True must be vector of i1 or i1.");
-  SmallVector<Constant*, 16> Splat(VTy->getNumElements(),
-                                   ConstantInt::getTrue(Ty->getContext()));
-  return ConstantVector::get(Splat);
+  return ConstantVector::getSplat(VTy->getNumElements(),
+                                  ConstantInt::getTrue(Ty->getContext()));
 }
 
 Constant *ConstantInt::getFalse(Type *Ty) {
@@ -407,9 +404,8 @@ Constant *ConstantInt::getFalse(Type *Ty) {
   }
   assert(VTy->getElementType()->isIntegerTy(1) &&
          "False must be vector of i1 or i1.");
-  SmallVector<Constant*, 16> Splat(VTy->getNumElements(),
-                                   ConstantInt::getFalse(Ty->getContext()));
-  return ConstantVector::get(Splat);
+  return ConstantVector::getSplat(VTy->getNumElements(),
+                                  ConstantInt::getFalse(Ty->getContext()));
 }
 
 
@@ -433,8 +429,7 @@ Constant *ConstantInt::get(Type *Ty, uint64_t V, bool isSigned) {
 
   // For vectors, broadcast the value.
   if (VectorType *VTy = dyn_cast<VectorType>(Ty))
-    return ConstantVector::get(SmallVector<Constant*,
-                                           16>(VTy->getNumElements(), C));
+    return ConstantVector::getSplat(VTy->getNumElements(), C);
 
   return C;
 }
@@ -459,8 +454,7 @@ Constant *ConstantInt::get(Type* Ty, const APInt& V) {
 
   // For vectors, broadcast the value.
   if (VectorType *VTy = dyn_cast<VectorType>(Ty))
-    return ConstantVector::get(
-      SmallVector<Constant *, 16>(VTy->getNumElements(), C));
+    return ConstantVector::getSplat(VTy->getNumElements(), C);
 
   return C;
 }
@@ -506,8 +500,7 @@ Constant *ConstantFP::get(Type* Ty, double V) {
 
   // For vectors, broadcast the value.
   if (VectorType *VTy = dyn_cast<VectorType>(Ty))
-    return ConstantVector::get(
-      SmallVector<Constant *, 16>(VTy->getNumElements(), C));
+    return ConstantVector::getSplat(VTy->getNumElements(), C);
 
   return C;
 }
@@ -521,31 +514,28 @@ Constant *ConstantFP::get(Type* Ty, StringRef Str) {
 
   // For vectors, broadcast the value.
   if (VectorType *VTy = dyn_cast<VectorType>(Ty))
-    return ConstantVector::get(
-      SmallVector<Constant *, 16>(VTy->getNumElements(), C));
+    return ConstantVector::getSplat(VTy->getNumElements(), C);
 
   return C; 
 }
 
 
-ConstantFP* ConstantFP::getNegativeZero(Type* Ty) {
+ConstantFP *ConstantFP::getNegativeZero(Type *Ty) {
   LLVMContext &Context = Ty->getContext();
-  APFloat apf = cast <ConstantFP>(Constant::getNullValue(Ty))->getValueAPF();
+  APFloat apf = cast<ConstantFP>(Constant::getNullValue(Ty))->getValueAPF();
   apf.changeSign();
   return get(Context, apf);
 }
 
 
-Constant *ConstantFP::getZeroValueForNegation(Type* Ty) {
-  if (VectorType *PTy = dyn_cast<VectorType>(Ty))
-    if (PTy->getElementType()->isFloatingPointTy()) {
-      SmallVector<Constant*, 16> zeros(PTy->getNumElements(),
-                           getNegativeZero(PTy->getElementType()));
-      return ConstantVector::get(zeros);
-    }
-
-  if (Ty->isFloatingPointTy()) 
-    return getNegativeZero(Ty);
+Constant *ConstantFP::getZeroValueForNegation(Type *Ty) {
+  Type *ScalarTy = Ty->getScalarType();
+  if (ScalarTy->isFloatingPointTy()) {
+    Constant *C = getNegativeZero(ScalarTy);
+    if (VectorType *VTy = dyn_cast<VectorType>(Ty))
+      return ConstantVector::getSplat(VTy->getNumElements(), C);
+    return C;
+  }
 
   return Constant::getNullValue(Ty);
 }
@@ -624,6 +614,15 @@ Constant *ConstantAggregateZero::getElementValue(Constant *C) {
   return getStructElement(cast<ConstantInt>(C)->getZExtValue());
 }
 
+/// getElementValue - Return a zero of the right value for the specified GEP
+/// index.
+Constant *ConstantAggregateZero::getElementValue(unsigned Idx) {
+  if (isa<SequentialType>(getType()))
+    return getSequentialElement();
+  return getStructElement(Idx);
+}
+
+
 //===----------------------------------------------------------------------===//
 //                         UndefValue Implementation
 //===----------------------------------------------------------------------===//
@@ -648,6 +647,15 @@ UndefValue *UndefValue::getElementValue(Constant *C) {
   return getStructElement(cast<ConstantInt>(C)->getZExtValue());
 }
 
+/// getElementValue - Return an undef of the right value for the specified GEP
+/// index.
+UndefValue *UndefValue::getElementValue(unsigned Idx) {
+  if (isa<SequentialType>(getType()))
+    return getSequentialElement();
+  return getStructElement(Idx);
+}
+
+
 
 //===----------------------------------------------------------------------===//
 //                            ConstantXXX Classes
@@ -700,9 +708,8 @@ Constant *ConstantArray::get(LLVMContext &Context, StringRef Str,
     ElementVals.push_back(ConstantInt::get(Type::getInt8Ty(Context), Str[i]));
 
   // Add a null terminator to the string...
-  if (AddNull) {
+  if (AddNull)
     ElementVals.push_back(ConstantInt::get(Type::getInt8Ty(Context), 0));
-  }
 
   ArrayType *ATy = ArrayType::get(Type::getInt8Ty(Context), ElementVals.size());
   return get(ATy, ElementVals);
@@ -801,6 +808,12 @@ Constant *ConstantVector::get(ArrayRef<Constant*> V) {
   return pImpl->VectorConstants.getOrCreate(T, V);
 }
 
+Constant *ConstantVector::getSplat(unsigned NumElts, Constant *V) {
+  SmallVector<Constant*, 32> Elts(NumElts, V);
+  return get(Elts);
+}
+
+
 // Utility function for determining if a ConstantExpr is a CastOp or not. This
 // can't be inline because we don't want to #include Instruction.h into
 // Constant.h
@@ -1488,8 +1501,11 @@ Constant *ConstantExpr::getPtrToInt(Constant *C, Type *DstTy) {
          "PtrToInt source must be pointer or pointer vector");
   assert(DstTy->getScalarType()->isIntegerTy() && 
          "PtrToInt destination must be integer or integer vector");
-  assert(C->getType()->getNumElements() == DstTy->getNumElements() &&
-    "Invalid cast between a different number of vector elements");
+  assert(isa<VectorType>(C->getType()) == isa<VectorType>(DstTy));
+  if (isa<VectorType>(C->getType()))
+    assert(cast<VectorType>(C->getType())->getNumElements() ==
+           cast<VectorType>(DstTy)->getNumElements() &&
+           "Invalid cast between a different number of vector elements");
   return getFoldedCast(Instruction::PtrToInt, C, DstTy);
 }
 
@@ -1498,8 +1514,11 @@ Constant *ConstantExpr::getIntToPtr(Constant *C, Type *DstTy) {
          "IntToPtr source must be integer or integer vector");
   assert(DstTy->getScalarType()->isPointerTy() &&
          "IntToPtr destination must be a pointer or pointer vector");
-  assert(C->getType()->getNumElements() == DstTy->getNumElements() &&
-    "Invalid cast between a different number of vector elements");
+  assert(isa<VectorType>(C->getType()) == isa<VectorType>(DstTy));
+  if (isa<VectorType>(C->getType()))
+    assert(cast<VectorType>(C->getType())->getNumElements() ==
+           cast<VectorType>(DstTy)->getNumElements() &&
+           "Invalid cast between a different number of vector elements");
   return getFoldedCast(Instruction::IntToPtr, C, DstTy);
 }
 
@@ -1976,6 +1995,10 @@ Type *ConstantDataSequential::getElementType() const {
   return getType()->getElementType();
 }
 
+StringRef ConstantDataSequential::getRawDataValues() const {
+  return StringRef(DataElements, getNumElements()*getElementByteSize());
+}
+
 /// isElementTypeCompatible - Return true if a ConstantDataSequential can be
 /// formed with a vector or array of the specified element type.
 /// ConstantDataArray only works with normal float and int types that are
@@ -1995,6 +2018,14 @@ bool ConstantDataSequential::isElementTypeCompatible(const Type *Ty) {
   return false;
 }
 
+/// getNumElements - Return the number of elements in the array or vector.
+unsigned ConstantDataSequential::getNumElements() const {
+  if (ArrayType *AT = dyn_cast<ArrayType>(getType()))
+    return AT->getNumElements();
+  return cast<VectorType>(getType())->getNumElements();
+}
+
+
 /// getElementByteSize - Return the size in bytes of the elements in the data.
 uint64_t ConstantDataSequential::getElementByteSize() const {
   return getElementType()->getPrimitiveSizeInBits()/8;
@@ -2002,7 +2033,7 @@ uint64_t ConstantDataSequential::getElementByteSize() const {
 
 /// getElementPointer - Return the start of the specified element.
 const char *ConstantDataSequential::getElementPointer(unsigned Elt) const {
-  assert(Elt < getElementType()->getNumElements() && "Invalid Elt");
+  assert(Elt < getNumElements() && "Invalid Elt");
   return DataElements+Elt*getElementByteSize();
 }
 
@@ -2021,7 +2052,8 @@ static bool isAllZeros(StringRef Arr) {
 /// we *want* an underlying "char*" to avoid TBAA type punning violations.
 Constant *ConstantDataSequential::getImpl(StringRef Elements, Type *Ty) {
   assert(isElementTypeCompatible(cast<SequentialType>(Ty)->getElementType()));
-  // If the elements are all zero, return a CAZ, which is more dense.
+  // If the elements are all zero or there are no elements, return a CAZ, which
+  // is more dense and canonical.
   if (isAllZeros(Elements))
     return ConstantAggregateZero::get(Ty);
 
@@ -2049,14 +2081,12 @@ Constant *ConstantDataSequential::getImpl(StringRef Elements, Type *Ty) {
 }
 
 void ConstantDataSequential::destroyConstant() {
-  uint64_t ByteSize = getElementByteSize() * getElementType()->getNumElements();
-  
   // Remove the constant from the StringMap.
   StringMap<ConstantDataSequential*> &CDSConstants = 
     getType()->getContext().pImpl->CDSConstants;
   
   StringMap<ConstantDataSequential*>::iterator Slot =
-    CDSConstants.find(StringRef(DataElements, ByteSize));
+    CDSConstants.find(getRawDataValues());
 
   assert(Slot != CDSConstants.end() && "CDS not found in uniquing table");
 
@@ -2093,60 +2123,107 @@ void ConstantDataSequential::destroyConstant() {
 /// get() constructors - Return a constant with array type with an element
 /// count and element type matching the ArrayRef passed in.  Note that this
 /// can return a ConstantAggregateZero object.
-Constant *ConstantDataArray::get(ArrayRef<uint8_t> Elts, LLVMContext &Context) {
+Constant *ConstantDataArray::get(LLVMContext &Context, ArrayRef<uint8_t> Elts) {
   Type *Ty = ArrayType::get(Type::getInt8Ty(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*1), Ty);
 }
-Constant *ConstantDataArray::get(ArrayRef<uint16_t> Elts, LLVMContext &Context){
+Constant *ConstantDataArray::get(LLVMContext &Context, ArrayRef<uint16_t> Elts){
   Type *Ty = ArrayType::get(Type::getInt16Ty(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*2), Ty);
 }
-Constant *ConstantDataArray::get(ArrayRef<uint32_t> Elts, LLVMContext &Context){
+Constant *ConstantDataArray::get(LLVMContext &Context, ArrayRef<uint32_t> Elts){
   Type *Ty = ArrayType::get(Type::getInt32Ty(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*4), Ty);
 }
-Constant *ConstantDataArray::get(ArrayRef<uint64_t> Elts, LLVMContext &Context){
+Constant *ConstantDataArray::get(LLVMContext &Context, ArrayRef<uint64_t> Elts){
   Type *Ty = ArrayType::get(Type::getInt64Ty(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*8), Ty);
 }
-Constant *ConstantDataArray::get(ArrayRef<float> Elts, LLVMContext &Context) {
+Constant *ConstantDataArray::get(LLVMContext &Context, ArrayRef<float> Elts) {
   Type *Ty = ArrayType::get(Type::getFloatTy(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*4), Ty);
 }
-Constant *ConstantDataArray::get(ArrayRef<double> Elts, LLVMContext &Context) {
+Constant *ConstantDataArray::get(LLVMContext &Context, ArrayRef<double> Elts) {
   Type *Ty = ArrayType::get(Type::getDoubleTy(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*8), Ty);
 }
 
+/// getString - This method constructs a CDS and initializes it with a text
+/// string. The default behavior (AddNull==true) causes a null terminator to
+/// be placed at the end of the array (increasing the length of the string by
+/// one more than the StringRef would normally indicate.  Pass AddNull=false
+/// to disable this behavior.
+Constant *ConstantDataArray::getString(LLVMContext &Context,
+                                       StringRef Str, bool AddNull) {
+  if (!AddNull)
+    return get(Context, ArrayRef<uint8_t>((uint8_t*)Str.data(), Str.size()));
+  
+  SmallVector<uint8_t, 64> ElementVals;
+  ElementVals.append(Str.begin(), Str.end());
+  ElementVals.push_back(0);
+  return get(Context, ElementVals);
+}
 
 /// get() constructors - Return a constant with vector type with an element
 /// count and element type matching the ArrayRef passed in.  Note that this
 /// can return a ConstantAggregateZero object.
-Constant *ConstantDataVector::get(ArrayRef<uint8_t> Elts, LLVMContext &Context) {
+Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<uint8_t> Elts){
   Type *Ty = VectorType::get(Type::getInt8Ty(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*1), Ty);
 }
-Constant *ConstantDataVector::get(ArrayRef<uint16_t> Elts, LLVMContext &Context){
+Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<uint16_t> Elts){
   Type *Ty = VectorType::get(Type::getInt16Ty(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*2), Ty);
 }
-Constant *ConstantDataVector::get(ArrayRef<uint32_t> Elts, LLVMContext &Context){
+Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<uint32_t> Elts){
   Type *Ty = VectorType::get(Type::getInt32Ty(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*4), Ty);
 }
-Constant *ConstantDataVector::get(ArrayRef<uint64_t> Elts, LLVMContext &Context){
+Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<uint64_t> Elts){
   Type *Ty = VectorType::get(Type::getInt64Ty(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*8), Ty);
 }
-Constant *ConstantDataVector::get(ArrayRef<float> Elts, LLVMContext &Context) {
+Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<float> Elts) {
   Type *Ty = VectorType::get(Type::getFloatTy(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*4), Ty);
 }
-Constant *ConstantDataVector::get(ArrayRef<double> Elts, LLVMContext &Context) {
+Constant *ConstantDataVector::get(LLVMContext &Context, ArrayRef<double> Elts) {
   Type *Ty = VectorType::get(Type::getDoubleTy(Context), Elts.size());
   return getImpl(StringRef((char*)Elts.data(), Elts.size()*8), Ty);
 }
 
+Constant *ConstantDataVector::getSplat(unsigned NumElts, Constant *V) {
+  assert(isElementTypeCompatible(V->getType()) &&
+         "Element type not compatible with ConstantData");
+  if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
+    if (CI->getType()->isIntegerTy(8)) {
+      SmallVector<uint8_t, 16> Elts(NumElts, CI->getZExtValue());
+      return get(V->getContext(), Elts);
+    }
+    if (CI->getType()->isIntegerTy(16)) {
+      SmallVector<uint16_t, 16> Elts(NumElts, CI->getZExtValue());
+      return get(V->getContext(), Elts);
+    }
+    if (CI->getType()->isIntegerTy(32)) {
+      SmallVector<uint32_t, 16> Elts(NumElts, CI->getZExtValue());
+      return get(V->getContext(), Elts);
+    }
+    assert(CI->getType()->isIntegerTy(64) && "Unsupported ConstantData type");
+    SmallVector<uint64_t, 16> Elts(NumElts, CI->getZExtValue());
+    return get(V->getContext(), Elts);
+  }
+
+  ConstantFP *CFP = cast<ConstantFP>(V);
+  if (CFP->getType()->isFloatTy()) {
+    SmallVector<float, 16> Elts(NumElts, CFP->getValueAPF().convertToFloat());
+    return get(V->getContext(), Elts);
+  }
+  assert(CFP->getType()->isDoubleTy() && "Unsupported ConstantData type");
+  SmallVector<double, 16> Elts(NumElts, CFP->getValueAPF().convertToDouble());
+  return get(V->getContext(), Elts);
+}
+
+
 /// getElementAsInteger - If this is a sequential container of integers (of
 /// any size), return the specified element in the low bits of a uint64_t.
 uint64_t ConstantDataSequential::getElementAsInteger(unsigned Elt) const {
@@ -2171,7 +2248,8 @@ APFloat ConstantDataSequential::getElementAsAPFloat(unsigned Elt) const {
   const char *EltPtr = getElementPointer(Elt);
 
   switch (getElementType()->getTypeID()) {
-  default: assert("Accessor can only be used when element is float/double!");
+  default:
+    assert(0 && "Accessor can only be used when element is float/double!");
   case Type::FloatTyID: return APFloat(*(float*)EltPtr);
   case Type::DoubleTyID: return APFloat(*(double*)EltPtr);
   }
@@ -2203,7 +2281,25 @@ Constant *ConstantDataSequential::getElementAsConstant(unsigned Elt) const {
   return ConstantInt::get(getElementType(), getElementAsInteger(Elt));
 }
 
+/// isString - This method returns true if this is an array of i8.
+bool ConstantDataSequential::isString() const {
+  return isa<ArrayType>(getType()) && getElementType()->isIntegerTy(8);
+}
 
+/// isCString - This method returns true if the array "isString", ends with a
+/// nul byte, and does not contains any other nul bytes.
+bool ConstantDataSequential::isCString() const {
+  if (!isString())
+    return false;
+  
+  StringRef Str = getAsString();
+  
+  // The last value must be nul.
+  if (Str.back() != 0) return false;
+  
+  // Other elements must be non-nul.
+  return Str.drop_back().find(0) == StringRef::npos;
+}
 
 
 //===----------------------------------------------------------------------===//