[InstCombine] fold bitcasts around an extractelement (3rd try)
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineCasts.cpp
index 23bf40124b57ff5d936b6bba88a40929199f3537..dcd86db036b45080b965ed90e3df75677ff52b71 100644 (file)
@@ -1715,15 +1715,19 @@ static Value *optimizeIntegerToVectorInsertions(BitCastInst &CI,
   return Result;
 }
 
-/// Given a bitcasted vector fed into an extract element instruction and then
-/// bitcasted again, eliminate at least one bitcast by changing the vector type
-/// of the extractelement instruction.
+/// Given a bitcasted source operand fed into an extract element instruction and
+/// then bitcasted again to a scalar type, eliminate at least one bitcast by
+/// changing the vector type of the extractelement instruction.
 /// Example:
 ///   bitcast (extractelement (bitcast <2 x float> %X to <2 x i32>), 1) to float
 ///    --->
 ///   extractelement <2 x float> %X, i32 1
 static Instruction *foldBitCastExtElt(BitCastInst &BitCast, InstCombiner &IC,
                                       const DataLayout &DL) {
+  Type *DestType = BitCast.getType();
+  if (DestType->isVectorTy())
+    return nullptr;
+
   // TODO: Create and use a pattern matcher for ExtractElementInst.
   auto *ExtElt = dyn_cast<ExtractElementInst>(BitCast.getOperand(0));
   if (!ExtElt || !ExtElt->hasOneUse())
@@ -1733,17 +1737,15 @@ static Instruction *foldBitCastExtElt(BitCastInst &BitCast, InstCombiner &IC,
   if (!match(ExtElt->getOperand(0), m_BitCast(m_Value(InnerBitCast))))
     return nullptr;
 
-  VectorType *VecType = cast<VectorType>(InnerBitCast->getType());
-  Type *DestType = BitCast.getType();
-
-  // If the element type of the vector doesn't match the result type,
-  // bitcast it to a vector type that we can extract from.
-  if (VecType->getElementType() != DestType) {
-    unsigned VecWidth = VecType->getPrimitiveSizeInBits();
+  // If the source is not a vector or its element type doesn't match the result
+  // type, bitcast it to a vector type that we can extract from.
+  Type *SourceType = InnerBitCast->getType();
+  if (SourceType->getScalarType() != DestType) {
+    unsigned VecWidth = SourceType->getPrimitiveSizeInBits();
     unsigned DestWidth = DestType->getPrimitiveSizeInBits();
     unsigned NumElts = VecWidth / DestWidth;
-    VecType = VectorType::get(DestType, NumElts);
-    InnerBitCast = IC.Builder->CreateBitCast(InnerBitCast, VecType, "bc");
+    SourceType = VectorType::get(DestType, NumElts);
+    InnerBitCast = IC.Builder->CreateBitCast(InnerBitCast, SourceType, "bc");
   }
 
   return ExtractElementInst::Create(InnerBitCast, ExtElt->getOperand(1));