InstCombine: Fold more shuffles of shuffles.
[oota-llvm.git] / lib / Transforms / InstCombine / InstCombineVectorOps.cpp
index de8a3acdbd855150dc49e727295d12a7f97c5bf0..56243059a618a57d904433933bd09a3b771845d0 100644 (file)
@@ -614,11 +614,16 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
   // we are absolutely afraid of producing a shuffle mask not in the input
   // program, because the code gen may not be smart enough to turn a merged
   // shuffle into two specific shuffles: it may produce worse code.  As such,
-  // we only merge two shuffles if the result is either a splat or one of the
-  // input shuffle masks.  In this case, merging the shuffles just removes
-  // one instruction, which we know is safe.  This is good for things like
+  // we only merge two shuffles if the result is a splat, one of the input
+  // input shuffle masks, or if there's only one input to the shuffle.
+  // In this case, merging the shuffles just removes one instruction, which
+  // we know is safe.  This is good for things like
   // turning: (splat(splat)) -> splat, or
   // merge(V[0..n], V[n+1..2n]) -> V[0..2n]
+  //
+  // FIXME: This is almost certainly far, far too conservative. We should
+  // have a better model. Perhaps a TargetTransformInfo hook to ask whether
+  // a shuffle is considered OK?
   ShuffleVectorInst* LHSShuffle = dyn_cast<ShuffleVectorInst>(LHS);
   ShuffleVectorInst* RHSShuffle = dyn_cast<ShuffleVectorInst>(RHS);
   if (LHSShuffle)
@@ -743,8 +748,10 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
   }
 
   // If the result mask is equal to one of the original shuffle masks,
-  // or is a splat, do the replacement.
-  if (isSplat || newMask == LHSMask || newMask == RHSMask || newMask == Mask) {
+  // or is a splat, do the replacement. Similarly, if there is only one
+  // input vector, go ahead and do the folding.
+  if (isSplat || newMask == LHSMask || newMask == RHSMask || newMask == Mask ||
+      isa<UndefValue>(RHS)) {
     SmallVector<Constant*, 16> Elts;
     Type *Int32Ty = Type::getInt32Ty(SVI.getContext());
     for (unsigned i = 0, e = newMask.size(); i != e; ++i) {