[InstCombine] fold bitcasts around an extractelement (3rd try)
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineCasts.cpp
index 2ce86436411b47fdc9662a9e4ab61a835eea83f8..dcd86db036b45080b965ed90e3df75677ff52b71 100644 (file)
@@ -1715,9 +1715,9 @@ static Value *optimizeIntegerToVectorInsertions(BitCastInst &CI,
   return Result;
 }
 
-/// Given a bitcasted vector fed into an extract element instruction and then
-/// bitcasted again to a scalar type, eliminate at least one bitcast by changing
-/// the vector type of the extractelement instruction.
+/// Given a bitcasted source operand fed into an extract element instruction and
+/// then bitcasted again to a scalar type, eliminate at least one bitcast by
+/// changing the vector type of the extractelement instruction.
 /// Example:
 ///   bitcast (extractelement (bitcast <2 x float> %X to <2 x i32>), 1) to float
 ///    --->
@@ -1737,15 +1737,15 @@ static Instruction *foldBitCastExtElt(BitCastInst &BitCast, InstCombiner &IC,
   if (!match(ExtElt->getOperand(0), m_BitCast(m_Value(InnerBitCast))))
     return nullptr;
 
-  // If the element type of the vector doesn't match the result type,
-  // bitcast it to a vector type that we can extract from.
-  VectorType *VecType = cast<VectorType>(InnerBitCast->getType());
-  if (VecType->getElementType() != DestType) {
-    unsigned VecWidth = VecType->getPrimitiveSizeInBits();
+  // If the source is not a vector or its element type doesn't match the result
+  // type, bitcast it to a vector type that we can extract from.
+  Type *SourceType = InnerBitCast->getType();
+  if (SourceType->getScalarType() != DestType) {
+    unsigned VecWidth = SourceType->getPrimitiveSizeInBits();
     unsigned DestWidth = DestType->getPrimitiveSizeInBits();
     unsigned NumElts = VecWidth / DestWidth;
-    VecType = VectorType::get(DestType, NumElts);
-    InnerBitCast = IC.Builder->CreateBitCast(InnerBitCast, VecType, "bc");
+    SourceType = VectorType::get(DestType, NumElts);
+    InnerBitCast = IC.Builder->CreateBitCast(InnerBitCast, SourceType, "bc");
   }
 
   return ExtractElementInst::Create(InnerBitCast, ExtElt->getOperand(1));