Fix DAGCombiner::visitFP_EXTEND to ignore indexed loads
[oota-llvm.git] / lib / CodeGen / SelectionDAG / DAGCombiner.cpp
index 1de11e9b1fa293d3bae3d5b6933947e9fdb195c2..72e001af5f81a784bcdd0a053b5f89d050be42c9 100644 (file)
@@ -35,6 +35,7 @@
 #include "llvm/Target/TargetLowering.h"
 #include "llvm/Target/TargetMachine.h"
 #include "llvm/Target/TargetOptions.h"
+#include "llvm/Target/TargetSubtargetInfo.h"
 #include <algorithm>
 using namespace llvm;
 
@@ -1823,20 +1824,24 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
   // fold (mul x, 0) -> 0
   if (N1IsConst && ConstValue1 == 0)
     return N1;
+  // We require a splat of the entire scalar bit width for non-contiguous
+  // bit patterns.
+  bool IsFullSplat =
+    ConstValue1.getBitWidth() == VT.getScalarType().getSizeInBits();
   // fold (mul x, 1) -> x
-  if (N1IsConst && ConstValue1 == 1)
+  if (N1IsConst && ConstValue1 == 1 && IsFullSplat)
     return N0;
   // fold (mul x, -1) -> 0-x
   if (N1IsConst && ConstValue1.isAllOnesValue())
     return DAG.getNode(ISD::SUB, SDLoc(N), VT,
                        DAG.getConstant(0, VT), N0);
   // fold (mul x, (1 << c)) -> x << c
-  if (N1IsConst && ConstValue1.isPowerOf2())
+  if (N1IsConst && ConstValue1.isPowerOf2() && IsFullSplat)
     return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0,
                        DAG.getConstant(ConstValue1.logBase2(),
                                        getShiftAmountTy(N0.getValueType())));
   // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
-  if (N1IsConst && (-ConstValue1).isPowerOf2()) {
+  if (N1IsConst && (-ConstValue1).isPowerOf2() && IsFullSplat) {
     unsigned Log2Val = (-ConstValue1).logBase2();
     // FIXME: If the input is something that is easily negated (e.g. a
     // single-use add), we should put the negate there.
@@ -3109,7 +3114,7 @@ SDValue DAGCombiner::MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1) {
   SDValue BSwap = DAG.getNode(ISD::BSWAP, SDLoc(N), VT,
                               SDValue(Parts[0],0));
 
-  // Result of the bswap should be rotated by 16. If it's not legal, than
+  // Result of the bswap should be rotated by 16. If it's not legal, then
   // do  (x << 16) | (x >> 16).
   SDValue ShAmt = DAG.getConstant(16, getShiftAmountTy(VT));
   if (TLI.isOperationLegalOrCustom(ISD::ROTL, VT))
@@ -3336,6 +3341,7 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) {
   unsigned OpSizeInBits = VT.getSizeInBits();
   SDValue LHSShiftArg = LHSShift.getOperand(0);
   SDValue LHSShiftAmt = LHSShift.getOperand(1);
+  SDValue RHSShiftArg = RHSShift.getOperand(0);
   SDValue RHSShiftAmt = RHSShift.getOperand(1);
 
   // fold (or (shl x, C1), (srl x, C2)) -> (rotl x, C1)
@@ -3374,29 +3380,9 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) {
   if (LHSMask.getNode() || RHSMask.getNode())
     return 0;
 
-  // fold (or (shl x, y), (srl x, (sub 32, y))) -> (rotl x, y)
-  // fold (or (shl x, y), (srl x, (sub 32, y))) -> (rotr x, (sub 32, y))
-  if (RHSShiftAmt.getOpcode() == ISD::SUB &&
-      LHSShiftAmt == RHSShiftAmt.getOperand(1)) {
-    if (ConstantSDNode *SUBC =
-          dyn_cast<ConstantSDNode>(RHSShiftAmt.getOperand(0))) {
-      if (SUBC->getAPIntValue() == OpSizeInBits)
-        return DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT, LHSShiftArg,
-                           HasROTL ? LHSShiftAmt : RHSShiftAmt).getNode();
-    }
-  }
-
-  // fold (or (shl x, (sub 32, y)), (srl x, r)) -> (rotr x, y)
-  // fold (or (shl x, (sub 32, y)), (srl x, r)) -> (rotl x, (sub 32, y))
-  if (LHSShiftAmt.getOpcode() == ISD::SUB &&
-      RHSShiftAmt == LHSShiftAmt.getOperand(1))
-    if (ConstantSDNode *SUBC =
-          dyn_cast<ConstantSDNode>(LHSShiftAmt.getOperand(0)))
-      if (SUBC->getAPIntValue() == OpSizeInBits)
-        return DAG.getNode(HasROTR ? ISD::ROTR : ISD::ROTL, DL, VT, LHSShiftArg,
-                           HasROTR ? RHSShiftAmt : LHSShiftAmt).getNode();
-
-  // Look for sign/zext/any-extended or truncate cases:
+  // If the shift amount is sign/zext/any-extended just peel it off.
+  SDValue LExtOp0 = LHSShiftAmt;
+  SDValue RExtOp0 = RHSShiftAmt;
   if ((LHSShiftAmt.getOpcode() == ISD::SIGN_EXTEND ||
        LHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
        LHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
@@ -3405,32 +3391,74 @@ SDNode *DAGCombiner::MatchRotate(SDValue LHS, SDValue RHS, SDLoc DL) {
        RHSShiftAmt.getOpcode() == ISD::ZERO_EXTEND ||
        RHSShiftAmt.getOpcode() == ISD::ANY_EXTEND ||
        RHSShiftAmt.getOpcode() == ISD::TRUNCATE)) {
-    SDValue LExtOp0 = LHSShiftAmt.getOperand(0);
-    SDValue RExtOp0 = RHSShiftAmt.getOperand(0);
-    if (RExtOp0.getOpcode() == ISD::SUB &&
-        RExtOp0.getOperand(1) == LExtOp0) {
-      // fold (or (shl x, (*ext y)), (srl x, (*ext (sub 32, y)))) ->
-      //   (rotl x, y)
-      // fold (or (shl x, (*ext y)), (srl x, (*ext (sub 32, y)))) ->
-      //   (rotr x, (sub 32, y))
-      if (ConstantSDNode *SUBC =
-            dyn_cast<ConstantSDNode>(RExtOp0.getOperand(0)))
-        if (SUBC->getAPIntValue() == OpSizeInBits)
-          return DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT,
-                             LHSShiftArg,
-                             HasROTL ? LHSShiftAmt : RHSShiftAmt).getNode();
-    } else if (LExtOp0.getOpcode() == ISD::SUB &&
-               RExtOp0 == LExtOp0.getOperand(1)) {
-      // fold (or (shl x, (*ext (sub 32, y))), (srl x, (*ext y))) ->
-      //   (rotr x, y)
-      // fold (or (shl x, (*ext (sub 32, y))), (srl x, (*ext y))) ->
-      //   (rotl x, (sub 32, y))
-      if (ConstantSDNode *SUBC =
-            dyn_cast<ConstantSDNode>(LExtOp0.getOperand(0)))
-        if (SUBC->getAPIntValue() == OpSizeInBits)
-          return DAG.getNode(HasROTR ? ISD::ROTR : ISD::ROTL, DL, VT,
-                             LHSShiftArg,
-                             HasROTR ? RHSShiftAmt : LHSShiftAmt).getNode();
+    LExtOp0 = LHSShiftAmt.getOperand(0);
+    RExtOp0 = RHSShiftAmt.getOperand(0);
+  }
+
+  if (RExtOp0.getOpcode() == ISD::SUB && RExtOp0.getOperand(1) == LExtOp0) {
+    // fold (or (shl x, (*ext y)), (srl x, (*ext (sub 32, y)))) ->
+    //   (rotl x, y)
+    // fold (or (shl x, (*ext y)), (srl x, (*ext (sub 32, y)))) ->
+    //   (rotr x, (sub 32, y))
+    if (ConstantSDNode *SUBC =
+            dyn_cast<ConstantSDNode>(RExtOp0.getOperand(0))) {
+      if (SUBC->getAPIntValue() == OpSizeInBits) {
+        return DAG.getNode(HasROTL ? ISD::ROTL : ISD::ROTR, DL, VT, LHSShiftArg,
+                           HasROTL ? LHSShiftAmt : RHSShiftAmt).getNode();
+      } else if (LHSShiftArg.getOpcode() == ISD::ZERO_EXTEND ||
+                 LHSShiftArg.getOpcode() == ISD::ANY_EXTEND) {
+        // fold (or (shl (*ext x), (*ext y)),
+        //          (srl (*ext x), (*ext (sub 32, y)))) ->
+        //   (*ext (rotl x, y))
+        // fold (or (shl (*ext x), (*ext y)),
+        //          (srl (*ext x), (*ext (sub 32, y)))) ->
+        //   (*ext (rotr x, (sub 32, y)))
+        SDValue LArgExtOp0 = LHSShiftArg.getOperand(0);
+        EVT LArgVT = LArgExtOp0.getValueType();
+        bool HasROTRWithLArg = TLI.isOperationLegalOrCustom(ISD::ROTR, LArgVT);
+        bool HasROTLWithLArg = TLI.isOperationLegalOrCustom(ISD::ROTL, LArgVT);
+        if (HasROTRWithLArg || HasROTLWithLArg) {
+          if (LArgVT.getSizeInBits() == SUBC->getAPIntValue()) {
+            SDValue V =
+                DAG.getNode(HasROTLWithLArg ? ISD::ROTL : ISD::ROTR, DL, LArgVT,
+                            LArgExtOp0, HasROTL ? LHSShiftAmt : RHSShiftAmt);
+            return DAG.getNode(LHSShiftArg.getOpcode(), DL, VT, V).getNode();
+          }     
+        }     
+      }
+    }
+  } else if (LExtOp0.getOpcode() == ISD::SUB &&
+             RExtOp0 == LExtOp0.getOperand(1)) {
+    // fold (or (shl x, (*ext (sub 32, y))), (srl x, (*ext y))) ->
+    //   (rotr x, y)
+    // fold (or (shl x, (*ext (sub 32, y))), (srl x, (*ext y))) ->
+    //   (rotl x, (sub 32, y))
+    if (ConstantSDNode *SUBC =
+            dyn_cast<ConstantSDNode>(LExtOp0.getOperand(0))) {
+      if (SUBC->getAPIntValue() == OpSizeInBits) {
+        return DAG.getNode(HasROTR ? ISD::ROTR : ISD::ROTL, DL, VT, LHSShiftArg,
+                           HasROTR ? RHSShiftAmt : LHSShiftAmt).getNode();
+      } else if (RHSShiftArg.getOpcode() == ISD::ZERO_EXTEND ||
+                 RHSShiftArg.getOpcode() == ISD::ANY_EXTEND) {
+        // fold (or (shl (*ext x), (*ext (sub 32, y))),
+        //          (srl (*ext x), (*ext y))) ->
+        //   (*ext (rotl x, y))
+        // fold (or (shl (*ext x), (*ext (sub 32, y))),
+        //          (srl (*ext x), (*ext y))) ->
+        //   (*ext (rotr x, (sub 32, y)))
+        SDValue RArgExtOp0 = RHSShiftArg.getOperand(0);
+        EVT RArgVT = RArgExtOp0.getValueType();
+        bool HasROTRWithRArg = TLI.isOperationLegalOrCustom(ISD::ROTR, RArgVT);
+        bool HasROTLWithRArg = TLI.isOperationLegalOrCustom(ISD::ROTL, RArgVT);
+        if (HasROTRWithRArg || HasROTLWithRArg) {
+          if (RArgVT.getSizeInBits() == SUBC->getAPIntValue()) {
+            SDValue V =
+                DAG.getNode(HasROTRWithRArg ? ISD::ROTR : ISD::ROTL, DL, RArgVT,
+                            RArgExtOp0, HasROTR ? RHSShiftAmt : LHSShiftAmt);
+            return DAG.getNode(RHSShiftArg.getOpcode(), DL, VT, V).getNode();
+          }
+        }
+      }
     }
   }
 
@@ -3728,6 +3756,26 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
     }
   }
 
+  // fold (shl (zext (srl x, C)), C) -> (zext (shl (srl x, C), C))
+  // Only fold this if the inner zext has no other uses to avoid increasing
+  // the total number of instructions.
+  if (N1C && N0.getOpcode() == ISD::ZERO_EXTEND && N0.hasOneUse() &&
+      N0.getOperand(0).getOpcode() == ISD::SRL &&
+      isa<ConstantSDNode>(N0.getOperand(0)->getOperand(1))) {
+    uint64_t c1 =
+      cast<ConstantSDNode>(N0.getOperand(0)->getOperand(1))->getZExtValue();
+    if (c1 < VT.getSizeInBits()) {
+      uint64_t c2 = N1C->getZExtValue();
+      if (c1 == c2) {
+        SDValue NewOp0 = N0.getOperand(0);
+        EVT CountVT = NewOp0.getOperand(1).getValueType();
+        SDValue NewSHL = DAG.getNode(ISD::SHL, SDLoc(N), NewOp0.getValueType(),
+                                     NewOp0, DAG.getConstant(c2, CountVT));
+        return DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N0), VT, NewSHL);
+      }
+    }
+  }
+
   // fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1), MASK) or
   //                               (and (srl x, (sub c1, c2), MASK)
   // Only fold this if the inner shift has no other uses -- if it does, folding
@@ -6683,7 +6731,7 @@ SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
   }
 
   // fold (fpext (load x)) -> (fpext (fptrunc (extload x)))
-  if (ISD::isNON_EXTLoad(N0.getNode()) && N0.hasOneUse() &&
+  if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
       ((!LegalOperations && !cast<LoadSDNode>(N0)->isVolatile()) ||
        TLI.isLoadExtLegal(ISD::EXTLOAD, N0.getValueType()))) {
     LoadSDNode *LN0 = cast<LoadSDNode>(N0);
@@ -7488,7 +7536,9 @@ SDValue DAGCombiner::visitLOAD(SDNode *N) {
     }
   }
 
-  if (CombinerAA) {
+  bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA :
+    TLI.getTargetMachine().getSubtarget<TargetSubtargetInfo>().useAA();
+  if (UseAA) {
     // Walk up chain skipping non-aliasing memory nodes.
     SDValue BetterChain = FindBetterChain(N, Chain);
 
@@ -8519,7 +8569,9 @@ SDValue DAGCombiner::visitSTORE(SDNode *N) {
   if (NewST.getNode())
     return NewST;
 
-  if (CombinerAA) {
+  bool UseAA = CombinerAA.getNumOccurrences() > 0 ? CombinerAA :
+    TLI.getTargetMachine().getSubtarget<TargetSubtargetInfo>().useAA();
+  if (UseAA) {
     // Walk up chain skipping non-aliasing memory nodes.
     SDValue BetterChain = FindBetterChain(N, Chain);
 
@@ -10229,7 +10281,9 @@ bool DAGCombiner::isAlias(SDValue Ptr1, int64_t Size1,
       return false;
   }
 
-  if (CombinerGlobalAA) {
+  bool UseAA = CombinerGlobalAA.getNumOccurrences() > 0 ? CombinerGlobalAA :
+    TLI.getTargetMachine().getSubtarget<TargetSubtargetInfo>().useAA();
+  if (UseAA && SrcValue1 && SrcValue2) {
     // Use alias analysis information.
     int64_t MinOffset = std::min(SrcValueOffset1, SrcValueOffset2);
     int64_t Overlap1 = Size1 + SrcValueOffset1 - MinOffset;