IR: Allow vectors of halfs to be ConstantDataVectors
[oota-llvm.git] / lib / IR / Constants.cpp
index b237c913dcf1e6fc763afc94e8f3d53158d300ee..509783fff8bdf8f404d07fe9e5809df944121bea 100644 (file)
@@ -53,6 +53,11 @@ bool Constant::isNegativeZeroValue() const {
       if (SplatCFP && SplatCFP->isZero() && SplatCFP->isNegative())
         return true;
 
+  if (const ConstantVector *CV = dyn_cast<ConstantVector>(this))
+    if (ConstantFP *SplatCFP = dyn_cast_or_null<ConstantFP>(CV->getSplatValue()))
+      if (SplatCFP && SplatCFP->isZero() && SplatCFP->isNegative())
+        return true;
+
   // We've already handled true FP case; any other FP vectors can't represent -0.0.
   if (getType()->isFPOrFPVectorTy())
     return false;
@@ -74,6 +79,11 @@ bool Constant::isZeroValue() const {
       if (SplatCFP && SplatCFP->isZero())
         return true;
 
+  if (const ConstantVector *CV = dyn_cast<ConstantVector>(this))
+    if (ConstantFP *SplatCFP = dyn_cast_or_null<ConstantFP>(CV->getSplatValue()))
+      if (SplatCFP && SplatCFP->isZero())
+        return true;
+
   // Otherwise, just use +0.0.
   return isNullValue();
 }
@@ -847,6 +857,59 @@ static bool rangeOnlyContains(ItTy Start, ItTy End, EltTy Elt) {
   return true;
 }
 
+template <typename SequentialTy, typename ElementTy>
+static Constant *getIntSequenceIfElementsMatch(ArrayRef<Constant *> V) {
+  assert(!V.empty() && "Cannot get empty int sequence.");
+
+  SmallVector<ElementTy, 16> Elts;
+  for (Constant *C : V)
+    if (auto *CI = dyn_cast<ConstantInt>(C))
+      Elts.push_back(CI->getZExtValue());
+    else
+      return nullptr;
+  return SequentialTy::get(V[0]->getContext(), Elts);
+}
+
+template <typename SequentialTy, typename ElementTy>
+static Constant *getFPSequenceIfElementsMatch(ArrayRef<Constant *> V) {
+  assert(!V.empty() && "Cannot get empty FP sequence.");
+
+  SmallVector<ElementTy, 16> Elts;
+  for (Constant *C : V)
+    if (auto *CFP = dyn_cast<ConstantFP>(C))
+      Elts.push_back(CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
+    else
+      return nullptr;
+  return SequentialTy::getFP(V[0]->getContext(), Elts);
+}
+
+template <typename SequenceTy>
+static Constant *getSequenceIfElementsMatch(Constant *C,
+                                            ArrayRef<Constant *> V) {
+  // We speculatively build the elements here even if it turns out that there is
+  // a constantexpr or something else weird, since it is so uncommon for that to
+  // happen.
+  if (ConstantInt *CI = dyn_cast<ConstantInt>(C)) {
+    if (CI->getType()->isIntegerTy(8))
+      return getIntSequenceIfElementsMatch<SequenceTy, uint8_t>(V);
+    else if (CI->getType()->isIntegerTy(16))
+      return getIntSequenceIfElementsMatch<SequenceTy, uint16_t>(V);
+    else if (CI->getType()->isIntegerTy(32))
+      return getIntSequenceIfElementsMatch<SequenceTy, uint32_t>(V);
+    else if (CI->getType()->isIntegerTy(64))
+      return getIntSequenceIfElementsMatch<SequenceTy, uint64_t>(V);
+  } else if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
+    if (CFP->getType()->isHalfTy())
+      return getFPSequenceIfElementsMatch<SequenceTy, uint16_t>(V);
+    else if (CFP->getType()->isFloatTy())
+      return getFPSequenceIfElementsMatch<SequenceTy, uint32_t>(V);
+    else if (CFP->getType()->isDoubleTy())
+      return getFPSequenceIfElementsMatch<SequenceTy, uint64_t>(V);
+  }
+
+  return nullptr;
+}
+
 ConstantArray::ConstantArray(ArrayType *T, ArrayRef<Constant *> V)
   : Constant(T, ConstantArrayVal,
              OperandTraits<ConstantArray>::op_end(this) - V.size(),
@@ -864,6 +927,7 @@ Constant *ConstantArray::get(ArrayType *Ty, ArrayRef<Constant*> V) {
     return C;
   return Ty->getContext().pImpl->ArrayConstants.getOrCreate(Ty, V);
 }
+
 Constant *ConstantArray::getImpl(ArrayType *Ty, ArrayRef<Constant*> V) {
   // Empty arrays are canonicalized to ConstantAggregateZero.
   if (V.empty())
@@ -886,74 +950,8 @@ Constant *ConstantArray::getImpl(ArrayType *Ty, ArrayRef<Constant*> V) {
 
   // Check to see if all of the elements are ConstantFP or ConstantInt and if
   // the element type is compatible with ConstantDataVector.  If so, use it.
-  if (ConstantDataSequential::isElementTypeCompatible(C->getType())) {
-    // We speculatively build the elements here even if it turns out that there
-    // is a constantexpr or something else weird in the array, since it is so
-    // uncommon for that to happen.
-    if (ConstantInt *CI = dyn_cast<ConstantInt>(C)) {
-      if (CI->getType()->isIntegerTy(8)) {
-        SmallVector<uint8_t, 16> Elts;
-        for (unsigned i = 0, e = V.size(); i != e; ++i)
-          if (ConstantInt *CI = dyn_cast<ConstantInt>(V[i]))
-            Elts.push_back(CI->getZExtValue());
-          else
-            break;
-        if (Elts.size() == V.size())
-          return ConstantDataArray::get(C->getContext(), Elts);
-      } else if (CI->getType()->isIntegerTy(16)) {
-        SmallVector<uint16_t, 16> Elts;
-        for (unsigned i = 0, e = V.size(); i != e; ++i)
-          if (ConstantInt *CI = dyn_cast<ConstantInt>(V[i]))
-            Elts.push_back(CI->getZExtValue());
-          else
-            break;
-        if (Elts.size() == V.size())
-          return ConstantDataArray::get(C->getContext(), Elts);
-      } else if (CI->getType()->isIntegerTy(32)) {
-        SmallVector<uint32_t, 16> Elts;
-        for (unsigned i = 0, e = V.size(); i != e; ++i)
-          if (ConstantInt *CI = dyn_cast<ConstantInt>(V[i]))
-            Elts.push_back(CI->getZExtValue());
-          else
-            break;
-        if (Elts.size() == V.size())
-          return ConstantDataArray::get(C->getContext(), Elts);
-      } else if (CI->getType()->isIntegerTy(64)) {
-        SmallVector<uint64_t, 16> Elts;
-        for (unsigned i = 0, e = V.size(); i != e; ++i)
-          if (ConstantInt *CI = dyn_cast<ConstantInt>(V[i]))
-            Elts.push_back(CI->getZExtValue());
-          else
-            break;
-        if (Elts.size() == V.size())
-          return ConstantDataArray::get(C->getContext(), Elts);
-      }
-    }
-
-    if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
-      if (CFP->getType()->isFloatTy()) {
-        SmallVector<uint32_t, 16> Elts;
-        for (unsigned i = 0, e = V.size(); i != e; ++i)
-          if (ConstantFP *CFP = dyn_cast<ConstantFP>(V[i]))
-            Elts.push_back(
-                CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
-          else
-            break;
-        if (Elts.size() == V.size())
-          return ConstantDataArray::getFP(C->getContext(), Elts);
-      } else if (CFP->getType()->isDoubleTy()) {
-        SmallVector<uint64_t, 16> Elts;
-        for (unsigned i = 0, e = V.size(); i != e; ++i)
-          if (ConstantFP *CFP = dyn_cast<ConstantFP>(V[i]))
-            Elts.push_back(
-                CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
-          else
-            break;
-        if (Elts.size() == V.size())
-          return ConstantDataArray::getFP(C->getContext(), Elts);
-      }
-    }
-  }
+  if (ConstantDataSequential::isElementTypeCompatible(C->getType()))
+    return getSequenceIfElementsMatch<ConstantDataArray>(C, V);
 
   // Otherwise, we really do want to create a ConstantArray.
   return nullptr;
@@ -1049,6 +1047,7 @@ Constant *ConstantVector::get(ArrayRef<Constant*> V) {
   VectorType *Ty = VectorType::get(V.front()->getType(), V.size());
   return Ty->getContext().pImpl->VectorConstants.getOrCreate(Ty, V);
 }
+
 Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {
   assert(!V.empty() && "Vectors can't be empty");
   VectorType *T = VectorType::get(V.front()->getType(), V.size());
@@ -1074,74 +1073,8 @@ Constant *ConstantVector::getImpl(ArrayRef<Constant*> V) {
 
   // Check to see if all of the elements are ConstantFP or ConstantInt and if
   // the element type is compatible with ConstantDataVector.  If so, use it.
-  if (ConstantDataSequential::isElementTypeCompatible(C->getType())) {
-    // We speculatively build the elements here even if it turns out that there
-    // is a constantexpr or something else weird in the array, since it is so
-    // uncommon for that to happen.
-    if (ConstantInt *CI = dyn_cast<ConstantInt>(C)) {
-      if (CI->getType()->isIntegerTy(8)) {
-        SmallVector<uint8_t, 16> Elts;
-        for (unsigned i = 0, e = V.size(); i != e; ++i)
-          if (ConstantInt *CI = dyn_cast<ConstantInt>(V[i]))
-            Elts.push_back(CI->getZExtValue());
-          else
-            break;
-        if (Elts.size() == V.size())
-          return ConstantDataVector::get(C->getContext(), Elts);
-      } else if (CI->getType()->isIntegerTy(16)) {
-        SmallVector<uint16_t, 16> Elts;
-        for (unsigned i = 0, e = V.size(); i != e; ++i)
-          if (ConstantInt *CI = dyn_cast<ConstantInt>(V[i]))
-            Elts.push_back(CI->getZExtValue());
-          else
-            break;
-        if (Elts.size() == V.size())
-          return ConstantDataVector::get(C->getContext(), Elts);
-      } else if (CI->getType()->isIntegerTy(32)) {
-        SmallVector<uint32_t, 16> Elts;
-        for (unsigned i = 0, e = V.size(); i != e; ++i)
-          if (ConstantInt *CI = dyn_cast<ConstantInt>(V[i]))
-            Elts.push_back(CI->getZExtValue());
-          else
-            break;
-        if (Elts.size() == V.size())
-          return ConstantDataVector::get(C->getContext(), Elts);
-      } else if (CI->getType()->isIntegerTy(64)) {
-        SmallVector<uint64_t, 16> Elts;
-        for (unsigned i = 0, e = V.size(); i != e; ++i)
-          if (ConstantInt *CI = dyn_cast<ConstantInt>(V[i]))
-            Elts.push_back(CI->getZExtValue());
-          else
-            break;
-        if (Elts.size() == V.size())
-          return ConstantDataVector::get(C->getContext(), Elts);
-      }
-    }
-
-    if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) {
-      if (CFP->getType()->isFloatTy()) {
-        SmallVector<uint32_t, 16> Elts;
-        for (unsigned i = 0, e = V.size(); i != e; ++i)
-          if (ConstantFP *CFP = dyn_cast<ConstantFP>(V[i]))
-            Elts.push_back(
-                CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
-          else
-            break;
-        if (Elts.size() == V.size())
-          return ConstantDataVector::getFP(C->getContext(), Elts);
-      } else if (CFP->getType()->isDoubleTy()) {
-        SmallVector<uint64_t, 16> Elts;
-        for (unsigned i = 0, e = V.size(); i != e; ++i)
-          if (ConstantFP *CFP = dyn_cast<ConstantFP>(V[i]))
-            Elts.push_back(
-                CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
-          else
-            break;
-        if (Elts.size() == V.size())
-          return ConstantDataVector::getFP(C->getContext(), Elts);
-      }
-    }
-  }
+  if (ConstantDataSequential::isElementTypeCompatible(C->getType()))
+    return getSequenceIfElementsMatch<ConstantDataVector>(C, V);
 
   // Otherwise, the element type isn't compatible with ConstantDataVector, or
   // the operand list constants a ConstantExpr or something else strange.
@@ -2434,7 +2367,7 @@ StringRef ConstantDataSequential::getRawDataValues() const {
 /// ConstantDataArray only works with normal float and int types that are
 /// stored densely in memory, not with things like i42 or x86_f80.
 bool ConstantDataSequential::isElementTypeCompatible(Type *Ty) {
-  if (Ty->isFloatTy() || Ty->isDoubleTy()) return true;
+  if (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy()) return true;
   if (auto *IT = dyn_cast<IntegerType>(Ty)) {
     switch (IT->getBitWidth()) {
     case 8:
@@ -2706,6 +2639,11 @@ Constant *ConstantDataVector::getSplat(unsigned NumElts, Constant *V) {
   }
 
   if (ConstantFP *CFP = dyn_cast<ConstantFP>(V)) {
+    if (CFP->getType()->isHalfTy()) {
+      SmallVector<uint16_t, 16> Elts(
+          NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
+      return getFP(V->getContext(), Elts);
+    }
     if (CFP->getType()->isFloatTy()) {
       SmallVector<uint32_t, 16> Elts(
           NumElts, CFP->getValueAPF().bitcastToAPInt().getLimitedValue());
@@ -2751,6 +2689,10 @@ APFloat ConstantDataSequential::getElementAsAPFloat(unsigned Elt) const {
   switch (getElementType()->getTypeID()) {
   default:
     llvm_unreachable("Accessor can only be used when element is float/double!");
+  case Type::HalfTyID: {
+    auto EltVal = *reinterpret_cast<const uint16_t *>(EltPtr);
+    return APFloat(APFloat::IEEEhalf, APInt(16, EltVal));
+  }
   case Type::FloatTyID: {
     auto EltVal = *reinterpret_cast<const uint32_t *>(EltPtr);
     return APFloat(APFloat::IEEEsingle, APInt(32, EltVal));
@@ -2785,7 +2727,8 @@ double ConstantDataSequential::getElementAsDouble(unsigned Elt) const {
 /// Note that this has to compute a new constant to return, so it isn't as
 /// efficient as getElementAsInteger/Float/Double.
 Constant *ConstantDataSequential::getElementAsConstant(unsigned Elt) const {
-  if (getElementType()->isFloatTy() || getElementType()->isDoubleTy())
+  if (getElementType()->isHalfTy() || getElementType()->isFloatTy() ||
+      getElementType()->isDoubleTy())
     return ConstantFP::get(getContext(), getElementAsAPFloat(Elt));
 
   return ConstantInt::get(getElementType(), getElementAsInteger(Elt));