X-Git-Url: http://plrg.eecs.uci.edu/git/?p=oota-llvm.git;a=blobdiff_plain;f=lib%2FTarget%2FAArch64%2FAArch64ISelLowering.cpp;h=92cf1cd71970bd76acd92f4947afceb1ce967cb5;hp=f9af05e84d24a43ca9b4a513c0339b313d9fb0a9;hb=131d76722983cb030c392bcb50bba940e98ea0c6;hpb=15c5be1ee58a67965bee79832441f1136a7698dc diff --git a/lib/Target/AArch64/AArch64ISelLowering.cpp b/lib/Target/AArch64/AArch64ISelLowering.cpp index f9af05e84d2..92cf1cd7197 100644 --- a/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -2338,11 +2338,6 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op, } } -/// getFunctionAlignment - Return the Log2 alignment of this function. -unsigned AArch64TargetLowering::getFunctionAlignment(const Function *F) const { - return 2; -} - //===----------------------------------------------------------------------===// // Calling Convention Implementation //===----------------------------------------------------------------------===// @@ -2431,7 +2426,7 @@ SDValue AArch64TargetLowering::LowerFormalArguments( continue; } - + if (VA.isRegLoc()) { // Arguments stored in registers. EVT RegVT = VA.getLocVT(); @@ -3276,6 +3271,19 @@ AArch64TargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv, Flag = Chain.getValue(1); RetOps.push_back(DAG.getRegister(VA.getLocReg(), VA.getLocVT())); } + const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo(); + const MCPhysReg *I = + TRI->getCalleeSavedRegsViaCopy(&DAG.getMachineFunction()); + if (I) { + for (; *I; ++I) { + if (AArch64::GPR64RegClass.contains(*I)) + RetOps.push_back(DAG.getRegister(*I, MVT::i64)); + else if (AArch64::FPR64RegClass.contains(*I)) + RetOps.push_back(DAG.getRegister(*I, MVT::getFloatingPointVT(64))); + else + llvm_unreachable("Unexpected register class in CSRsViaCopy!"); + } + } RetOps[0] = Chain; // Update chain. @@ -5066,7 +5074,7 @@ static bool isEXTMask(ArrayRef M, EVT VT, bool &ReverseEXT, // The index of an EXT is the first element if it is not UNDEF. // Watch out for the beginning UNDEFs. The EXT index should be the expected - // value of the first element. E.g. + // value of the first element. E.g. // <-1, -1, 3, ...> is treated as <1, 2, 3, ...>. // <-1, -1, 0, 1, ...> is treated as <2*NumElts-2, 2*NumElts-1, 0, 1, ...>. // ExpectedElt is the last mask index plus 1. @@ -6681,6 +6689,9 @@ SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op, return DAG.getSExtOrTrunc(Cmp, dl, Op.getValueType()); } + if (LHS.getValueType().getVectorElementType() == MVT::f16) + return SDValue(); + assert(LHS.getValueType().getVectorElementType() == MVT::f32 || LHS.getValueType().getVectorElementType() == MVT::f64); @@ -9483,6 +9494,103 @@ static SDValue performBRCONDCombine(SDNode *N, return SDValue(); } +// Optimize some simple tbz/tbnz cases. Returns the new operand and bit to test +// as well as whether the test should be inverted. This code is required to +// catch these cases (as opposed to standard dag combines) because +// AArch64ISD::TBZ is matched during legalization. +static SDValue getTestBitOperand(SDValue Op, unsigned &Bit, bool &Invert, + SelectionDAG &DAG) { + + if (!Op->hasOneUse()) + return Op; + + // We don't handle undef/constant-fold cases below, as they should have + // already been taken care of (e.g. and of 0, test of undefined shifted bits, + // etc.) + + // (tbz (trunc x), b) -> (tbz x, b) + // This case is just here to enable more of the below cases to be caught. + if (Op->getOpcode() == ISD::TRUNCATE && + Bit < Op->getValueType(0).getSizeInBits()) { + return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG); + } + + if (Op->getNumOperands() != 2) + return Op; + + auto *C = dyn_cast(Op->getOperand(1)); + if (!C) + return Op; + + switch (Op->getOpcode()) { + default: + return Op; + + // (tbz (and x, m), b) -> (tbz x, b) + case ISD::AND: + if ((C->getZExtValue() >> Bit) & 1) + return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG); + return Op; + + // (tbz (shl x, c), b) -> (tbz x, b-c) + case ISD::SHL: + if (C->getZExtValue() <= Bit && + (Bit - C->getZExtValue()) < Op->getValueType(0).getSizeInBits()) { + Bit = Bit - C->getZExtValue(); + return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG); + } + return Op; + + // (tbz (sra x, c), b) -> (tbz x, b+c) or (tbz x, msb) if b+c is > # bits in x + case ISD::SRA: + Bit = Bit + C->getZExtValue(); + if (Bit >= Op->getValueType(0).getSizeInBits()) + Bit = Op->getValueType(0).getSizeInBits() - 1; + return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG); + + // (tbz (srl x, c), b) -> (tbz x, b+c) + case ISD::SRL: + if ((Bit + C->getZExtValue()) < Op->getValueType(0).getSizeInBits()) { + Bit = Bit + C->getZExtValue(); + return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG); + } + return Op; + + // (tbz (xor x, -1), b) -> (tbnz x, b) + case ISD::XOR: + if ((C->getZExtValue() >> Bit) & 1) + Invert = !Invert; + return getTestBitOperand(Op->getOperand(0), Bit, Invert, DAG); + } +} + +// Optimize test single bit zero/non-zero and branch. +static SDValue performTBZCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + SelectionDAG &DAG) { + unsigned Bit = cast(N->getOperand(2))->getZExtValue(); + bool Invert = false; + SDValue TestSrc = N->getOperand(1); + SDValue NewTestSrc = getTestBitOperand(TestSrc, Bit, Invert, DAG); + + if (TestSrc == NewTestSrc) + return SDValue(); + + unsigned NewOpc = N->getOpcode(); + if (Invert) { + if (NewOpc == AArch64ISD::TBZ) + NewOpc = AArch64ISD::TBNZ; + else { + assert(NewOpc == AArch64ISD::TBNZ); + NewOpc = AArch64ISD::TBZ; + } + } + + SDLoc DL(N); + return DAG.getNode(NewOpc, DL, MVT::Other, N->getOperand(0), NewTestSrc, + DAG.getConstant(Bit, DL, MVT::i64), N->getOperand(3)); +} + // vselect (v1i1 setcc) -> // vselect (v1iXX setcc) (XX is the size of the compared operand type) // FIXME: Currently the type legalizer can't handle VSELECT having v1i1 as @@ -9634,6 +9742,9 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N, return performSTORECombine(N, DCI, DAG, Subtarget); case AArch64ISD::BRCOND: return performBRCONDCombine(N, DCI, DAG); + case AArch64ISD::TBNZ: + case AArch64ISD::TBZ: + return performTBZCombine(N, DCI, DAG); case AArch64ISD::CSEL: return performCONDCombine(N, DCI, DAG, 2, 3); case AArch64ISD::DUP: @@ -10008,3 +10119,50 @@ Value *AArch64TargetLowering::getSafeStackPointerLocation(IRBuilder<> &IRB) cons IRB.CreateConstGEP1_32(IRB.CreateCall(ThreadPointerFunc), TlsOffset), Type::getInt8PtrTy(IRB.getContext())->getPointerTo(0)); } + +void AArch64TargetLowering::initializeSplitCSR(MachineBasicBlock *Entry) const { + // Update IsSplitCSR in AArch64unctionInfo. + AArch64FunctionInfo *AFI = Entry->getParent()->getInfo(); + AFI->setIsSplitCSR(true); +} + +void AArch64TargetLowering::insertCopiesSplitCSR( + MachineBasicBlock *Entry, + const SmallVectorImpl &Exits) const { + const AArch64RegisterInfo *TRI = Subtarget->getRegisterInfo(); + const MCPhysReg *IStart = TRI->getCalleeSavedRegsViaCopy(Entry->getParent()); + if (!IStart) + return; + + const TargetInstrInfo *TII = Subtarget->getInstrInfo(); + MachineRegisterInfo *MRI = &Entry->getParent()->getRegInfo(); + MachineBasicBlock::iterator MBBI = Entry->begin(); + for (const MCPhysReg *I = IStart; *I; ++I) { + const TargetRegisterClass *RC = nullptr; + if (AArch64::GPR64RegClass.contains(*I)) + RC = &AArch64::GPR64RegClass; + else if (AArch64::FPR64RegClass.contains(*I)) + RC = &AArch64::FPR64RegClass; + else + llvm_unreachable("Unexpected register class in CSRsViaCopy!"); + + unsigned NewVR = MRI->createVirtualRegister(RC); + // Create copy from CSR to a virtual register. + // FIXME: this currently does not emit CFI pseudo-instructions, it works + // fine for CXX_FAST_TLS since the C++-style TLS access functions should be + // nounwind. If we want to generalize this later, we may need to emit + // CFI pseudo-instructions. + assert(Entry->getParent()->getFunction()->hasFnAttribute( + Attribute::NoUnwind) && + "Function should be nounwind in insertCopiesSplitCSR!"); + Entry->addLiveIn(*I); + BuildMI(*Entry, MBBI, DebugLoc(), TII->get(TargetOpcode::COPY), NewVR) + .addReg(*I); + + // Insert the copy-back instructions right before the terminator. + for (auto *Exit : Exits) + BuildMI(*Exit, Exit->getFirstTerminator(), DebugLoc(), + TII->get(TargetOpcode::COPY), *I) + .addReg(NewVR); + } +}