Teach getCastOpcode about element-by-element vector casts. For example, "trunc"
[oota-llvm.git] / lib / VMCore / Instructions.cpp
index 61da9b6b8e0c2b9f6d2698ae1f1f73deaeb7b4e7..af1492d356a0c0e5187ef7a6c0aa02a8464aae01 100644 (file)
@@ -2254,6 +2254,14 @@ bool CastInst::isCastable(const Type *SrcTy, const Type *DestTy) {
   if (SrcTy == DestTy)
     return true;
 
+  if (const VectorType *SrcVecTy = dyn_cast<VectorType>(SrcTy))
+    if (const VectorType *DestVecTy = dyn_cast<VectorType>(DestTy))
+      if (SrcVecTy->getNumElements() == DestVecTy->getNumElements()) {
+        // An element by element cast.  Valid if casting the elements is valid.
+        SrcTy = SrcVecTy->getElementType();
+        DestTy = DestVecTy->getElementType();
+      }
+
   // Get the bit sizes, we'll need these
   unsigned SrcBits = SrcTy->getScalarSizeInBits();   // 0 for ptr
   unsigned DestBits = DestTy->getScalarSizeInBits(); // 0 for ptr
@@ -2322,14 +2330,27 @@ bool CastInst::isCastable(const Type *SrcTy, const Type *DestTy) {
 Instruction::CastOps
 CastInst::getCastOpcode(
   const Value *Src, bool SrcIsSigned, const Type *DestTy, bool DestIsSigned) {
-  // Get the bit sizes, we'll need these
   const Type *SrcTy = Src->getType();
-  unsigned SrcBits = SrcTy->getScalarSizeInBits();   // 0 for ptr
-  unsigned DestBits = DestTy->getScalarSizeInBits(); // 0 for ptr
 
   assert(SrcTy->isFirstClassType() && DestTy->isFirstClassType() &&
          "Only first class types are castable!");
 
+  if (SrcTy == DestTy)
+    return BitCast;
+
+  if (const VectorType *SrcVecTy = dyn_cast<VectorType>(SrcTy))
+    if (const VectorType *DestVecTy = dyn_cast<VectorType>(DestTy))
+      if (SrcVecTy->getNumElements() == DestVecTy->getNumElements()) {
+        // An element by element cast.  Find the appropriate opcode based on the
+        // element types.
+        SrcTy = SrcVecTy->getElementType();
+        DestTy = DestVecTy->getElementType();
+      }
+
+  // Get the bit sizes, we'll need these
+  unsigned SrcBits = SrcTy->getScalarSizeInBits();   // 0 for ptr
+  unsigned DestBits = DestTy->getScalarSizeInBits(); // 0 for ptr
+
   // Run through the possibilities ...
   if (DestTy->isIntegerTy()) {                      // Casting to integral
     if (SrcTy->isIntegerTy()) {                     // Casting from integral
@@ -2384,7 +2405,7 @@ CastInst::getCastOpcode(
     if (const VectorType *SrcPTy = dyn_cast<VectorType>(SrcTy)) {
       assert(DestPTy->getBitWidth() == SrcPTy->getBitWidth() &&
              "Casting vector to vector of different widths");
-      SrcPTy = NULL;
+      (void)SrcPTy;
       return BitCast;                             // vector -> vector
     } else if (DestPTy->getBitWidth() == SrcBits) {
       return BitCast;                               // float/int -> vector