AArch64: fix 128-bit shifts
[oota-llvm.git] / lib / Target / AArch64 / AArch64ISelLowering.cpp
index 2232e419a6190739cf88500f038f8685ad743d6d..f0fb03451b2a6f72dbdd7ff7a65f42c41f02091c 100644 (file)
@@ -4380,46 +4380,57 @@ SDValue AArch64TargetLowering::LowerShiftRightParts(SDValue Op,
   SDValue ShOpLo = Op.getOperand(0);
   SDValue ShOpHi = Op.getOperand(1);
   SDValue ShAmt = Op.getOperand(2);
-  SDValue ARMcc;
   unsigned Opc = (Op.getOpcode() == ISD::SRA_PARTS) ? ISD::SRA : ISD::SRL;
 
   assert(Op.getOpcode() == ISD::SRA_PARTS || Op.getOpcode() == ISD::SRL_PARTS);
 
   SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i64,
                                  DAG.getConstant(VTBits, dl, MVT::i64), ShAmt);
-  SDValue Tmp1 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, ShAmt);
+  SDValue HiBitsForLo = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, RevShAmt);
+
+  // Unfortunately, if ShAmt == 0, we just calculated "(SHL ShOpHi, 64)" which
+  // is "undef". We wanted 0, so CSEL it directly.
+  SDValue Cmp = emitComparison(ShAmt, DAG.getConstant(0, dl, MVT::i64),
+                               ISD::SETEQ, dl, DAG);
+  SDValue CCVal = DAG.getConstant(AArch64CC::EQ, dl, MVT::i32);
+  HiBitsForLo =
+      DAG.getNode(AArch64ISD::CSEL, dl, VT, DAG.getConstant(0, dl, MVT::i64),
+                  HiBitsForLo, CCVal, Cmp);
+
   SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i64, ShAmt,
                                    DAG.getConstant(VTBits, dl, MVT::i64));
-  SDValue Tmp2 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, RevShAmt);
 
-  SDValue Cmp = emitComparison(ExtraShAmt, DAG.getConstant(0, dl, MVT::i64),
-                               ISD::SETGE, dl, DAG);
-  SDValue CCVal = DAG.getConstant(AArch64CC::GE, dl, MVT::i32);
+  SDValue LoBitsForLo = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, ShAmt);
+  SDValue LoForNormalShift =
+      DAG.getNode(ISD::OR, dl, VT, LoBitsForLo, HiBitsForLo);
 
-  SDValue FalseValLo = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
-  SDValue TrueValLo = DAG.getNode(Opc, dl, VT, ShOpHi, ExtraShAmt);
-  SDValue Lo =
-      DAG.getNode(AArch64ISD::CSEL, dl, VT, TrueValLo, FalseValLo, CCVal, Cmp);
+  Cmp = emitComparison(ExtraShAmt, DAG.getConstant(0, dl, MVT::i64), ISD::SETGE,
+                       dl, DAG);
+  CCVal = DAG.getConstant(AArch64CC::GE, dl, MVT::i32);
+  SDValue LoForBigShift = DAG.getNode(Opc, dl, VT, ShOpHi, ExtraShAmt);
+  SDValue Lo = DAG.getNode(AArch64ISD::CSEL, dl, VT, LoForBigShift,
+                           LoForNormalShift, CCVal, Cmp);
 
   // AArch64 shifts larger than the register width are wrapped rather than
   // clamped, so we can't just emit "hi >> x".
-  SDValue FalseValHi = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
-  SDValue TrueValHi = Opc == ISD::SRA
-                          ? DAG.getNode(Opc, dl, VT, ShOpHi,
-                                        DAG.getConstant(VTBits - 1, dl,
-                                                        MVT::i64))
-                          : DAG.getConstant(0, dl, VT);
-  SDValue Hi =
-      DAG.getNode(AArch64ISD::CSEL, dl, VT, TrueValHi, FalseValHi, CCVal, Cmp);
+  SDValue HiForNormalShift = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt);
+  SDValue HiForBigShift =
+      Opc == ISD::SRA
+          ? DAG.getNode(Opc, dl, VT, ShOpHi,
+                        DAG.getConstant(VTBits - 1, dl, MVT::i64))
+          : DAG.getConstant(0, dl, VT);
+  SDValue Hi = DAG.getNode(AArch64ISD::CSEL, dl, VT, HiForBigShift,
+                           HiForNormalShift, CCVal, Cmp);
 
   SDValue Ops[2] = { Lo, Hi };
   return DAG.getMergeValues(Ops, dl);
 }
 
+
 /// LowerShiftLeftParts - Lower SHL_PARTS, which returns two
 /// i64 values and take a 2 x i64 value to shift plus a shift amount.
 SDValue AArch64TargetLowering::LowerShiftLeftParts(SDValue Op,
-                                                 SelectionDAG &DAG) const {
+                                                   SelectionDAG &DAG) const {
   assert(Op.getNumOperands() == 3 && "Not a double-shift!");
   EVT VT = Op.getValueType();
   unsigned VTBits = VT.getSizeInBits();
@@ -4427,31 +4438,41 @@ SDValue AArch64TargetLowering::LowerShiftLeftParts(SDValue Op,
   SDValue ShOpLo = Op.getOperand(0);
   SDValue ShOpHi = Op.getOperand(1);
   SDValue ShAmt = Op.getOperand(2);
-  SDValue ARMcc;
 
   assert(Op.getOpcode() == ISD::SHL_PARTS);
   SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i64,
                                  DAG.getConstant(VTBits, dl, MVT::i64), ShAmt);
-  SDValue Tmp1 = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, RevShAmt);
+  SDValue LoBitsForHi = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, RevShAmt);
+
+  // Unfortunately, if ShAmt == 0, we just calculated "(SRL ShOpLo, 64)" which
+  // is "undef". We wanted 0, so CSEL it directly.
+  SDValue Cmp = emitComparison(ShAmt, DAG.getConstant(0, dl, MVT::i64),
+                               ISD::SETEQ, dl, DAG);
+  SDValue CCVal = DAG.getConstant(AArch64CC::EQ, dl, MVT::i32);
+  LoBitsForHi =
+      DAG.getNode(AArch64ISD::CSEL, dl, VT, DAG.getConstant(0, dl, MVT::i64),
+                  LoBitsForHi, CCVal, Cmp);
+
   SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i64, ShAmt,
                                    DAG.getConstant(VTBits, dl, MVT::i64));
-  SDValue Tmp2 = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, ShAmt);
-  SDValue Tmp3 = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ExtraShAmt);
+  SDValue HiBitsForHi = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, ShAmt);
+  SDValue HiForNormalShift =
+      DAG.getNode(ISD::OR, dl, VT, LoBitsForHi, HiBitsForHi);
 
-  SDValue FalseVal = DAG.getNode(ISD::OR, dl, VT, Tmp1, Tmp2);
+  SDValue HiForBigShift = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ExtraShAmt);
 
-  SDValue Cmp = emitComparison(ExtraShAmt, DAG.getConstant(0, dl, MVT::i64),
-                               ISD::SETGE, dl, DAG);
-  SDValue CCVal = DAG.getConstant(AArch64CC::GE, dl, MVT::i32);
-  SDValue Hi =
-      DAG.getNode(AArch64ISD::CSEL, dl, VT, Tmp3, FalseVal, CCVal, Cmp);
+  Cmp = emitComparison(ExtraShAmt, DAG.getConstant(0, dl, MVT::i64), ISD::SETGE,
+                       dl, DAG);
+  CCVal = DAG.getConstant(AArch64CC::GE, dl, MVT::i32);
+  SDValue Hi = DAG.getNode(AArch64ISD::CSEL, dl, VT, HiForBigShift,
+                           HiForNormalShift, CCVal, Cmp);
 
   // AArch64 shifts of larger than register sizes are wrapped rather than
   // clamped, so we can't just emit "lo << a" if a is too big.
-  SDValue TrueValLo = DAG.getConstant(0, dl, VT);
-  SDValue FalseValLo = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
-  SDValue Lo =
-      DAG.getNode(AArch64ISD::CSEL, dl, VT, TrueValLo, FalseValLo, CCVal, Cmp);
+  SDValue LoForBigShift = DAG.getConstant(0, dl, VT);
+  SDValue LoForNormalShift = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt);
+  SDValue Lo = DAG.getNode(AArch64ISD::CSEL, dl, VT, LoForBigShift,
+                           LoForNormalShift, CCVal, Cmp);
 
   SDValue Ops[2] = { Lo, Hi };
   return DAG.getMergeValues(Ops, dl);