Implement partial-word binary atomics on ppc.
[oota-llvm.git] / lib / Target / PowerPC / PPCISelLowering.cpp
index fc8399892fc581c38227cd34dce40579fcbb222b..9472cee028c9eed1ea57f8ed81f90038292ae004 100644 (file)
@@ -3914,6 +3914,126 @@ PPCTargetLowering::EmitAtomicBinary(MachineInstr *MI, MachineBasicBlock *BB,
   return BB;
 }
 
+MachineBasicBlock *
+PPCTargetLowering::EmitPartwordAtomicBinary(MachineInstr *MI, 
+                                            MachineBasicBlock *BB,
+                                            bool is8bit,    // operation
+                                            unsigned BinOpcode) {
+  const TargetInstrInfo *TII = getTargetMachine().getInstrInfo();
+  // In 64 bit mode we have to use 64 bits for addresses, even though the
+  // lwarx/stwcx are 32 bits.  With the 32-bit atomics we can use address
+  // registers without caring whether they're 32 or 64, but here we're
+  // doing actual arithmetic on the addresses.
+  bool is64bit = PPCSubTarget.isPPC64();
+
+  const BasicBlock *LLVM_BB = BB->getBasicBlock();
+  MachineFunction *F = BB->getParent();
+  MachineFunction::iterator It = BB;
+  ++It;
+
+  unsigned dest = MI->getOperand(0).getReg();
+  unsigned ptrA = MI->getOperand(1).getReg();
+  unsigned ptrB = MI->getOperand(2).getReg();
+  unsigned incr = MI->getOperand(3).getReg();
+
+  MachineBasicBlock *loopMBB = F->CreateMachineBasicBlock(LLVM_BB);
+  MachineBasicBlock *exitMBB = F->CreateMachineBasicBlock(LLVM_BB);
+  F->insert(It, loopMBB);
+  F->insert(It, exitMBB);
+  exitMBB->transferSuccessors(BB);
+
+  MachineRegisterInfo &RegInfo = F->getRegInfo();
+  const TargetRegisterClass *RC = 
+    is64bit ? (const TargetRegisterClass *) &PPC::GPRCRegClass :
+              (const TargetRegisterClass *) &PPC::G8RCRegClass;
+  unsigned TmpReg = RegInfo.createVirtualRegister(RC);
+  unsigned PtrReg = RegInfo.createVirtualRegister(RC);
+  unsigned Shift1Reg = RegInfo.createVirtualRegister(RC);
+  unsigned ShiftReg = RegInfo.createVirtualRegister(RC);
+  unsigned Incr2Reg = RegInfo.createVirtualRegister(RC);
+  unsigned MaskReg = RegInfo.createVirtualRegister(RC);
+  unsigned Mask2Reg = RegInfo.createVirtualRegister(RC);
+  unsigned Mask3Reg = RegInfo.createVirtualRegister(RC);
+  unsigned Tmp2Reg = RegInfo.createVirtualRegister(RC);
+  unsigned Tmp3Reg = RegInfo.createVirtualRegister(RC);
+  unsigned Tmp4Reg = RegInfo.createVirtualRegister(RC);
+  unsigned Ptr1Reg;
+
+  //  thisMBB:
+  //   ...
+  //   fallthrough --> loopMBB
+  BB->addSuccessor(loopMBB);
+
+  // The 4-byte load must be aligned, while a char or short may be
+  // anywhere in the word.  Hence all this nasty bookkeeping code.
+  //   add ptr1, ptrA, ptrB [copy if ptrA==0]
+  //   rlwinm shift1, ptr1, 3, 27, 28 [3, 27, 27]
+  //   xor shift, shift1, 24 [16]
+  //   rlwinm ptr, ptr1, 0, 0, 29
+  //   slw incr2, incr, shift
+  //   li mask2, 255 [li mask3, 0; ori mask2, mask3, 65535]
+  //   slw mask, mask2, shift
+  //  loopMBB:
+  //   l[wd]arx dest, ptr
+  //   add tmp, dest, incr2
+  //   andc tmp2, dest, mask
+  //   and tmp3, tmp, mask
+  //   or tmp4, tmp3, tmp2
+  //   st[wd]cx. tmp4, ptr
+  //   bne- loopMBB
+  //   fallthrough --> exitMBB
+
+  if (ptrA!=PPC::R0) {
+    Ptr1Reg = RegInfo.createVirtualRegister(RC);
+    BuildMI(BB, TII->get(is64bit ? PPC::ADD8 : PPC::ADD4), Ptr1Reg)
+      .addReg(ptrA).addReg(ptrB);
+  } else {
+    Ptr1Reg = ptrB;
+  }
+  BuildMI(BB, TII->get(PPC::RLWINM), Shift1Reg).addReg(Ptr1Reg)
+      .addImm(3).addImm(27).addImm(is8bit ? 28 : 27);
+  BuildMI(BB, TII->get(is64bit ? PPC::XOR8 : PPC::XOR), ShiftReg)
+      .addReg(Shift1Reg).addImm(is8bit ? 24 : 16);
+  if (is64bit)
+    BuildMI(BB, TII->get(PPC::RLDICR), PtrReg)
+      .addReg(Ptr1Reg).addImm(0).addImm(61);
+  else
+    BuildMI(BB, TII->get(PPC::RLWINM), PtrReg)
+      .addReg(Ptr1Reg).addImm(0).addImm(0).addImm(29);
+  BuildMI(BB, TII->get(PPC::SLW), Incr2Reg)
+      .addReg(incr).addReg(ShiftReg);
+  if (is8bit)
+    BuildMI(BB, TII->get(PPC::LI), Mask2Reg).addImm(255);
+  else {
+    BuildMI(BB, TII->get(PPC::LI), Mask3Reg).addImm(0);
+    BuildMI(BB, TII->get(PPC::ORI), Mask2Reg).addReg(Mask3Reg).addImm(65535);
+  }
+  BuildMI(BB, TII->get(PPC::SLW), MaskReg)
+      .addReg(Mask2Reg).addReg(ShiftReg);
+
+  BB = loopMBB;
+  BuildMI(BB, TII->get(PPC::LWARX), dest)
+    .addReg(PPC::R0).addReg(PtrReg);
+  BuildMI(BB, TII->get(BinOpcode), TmpReg).addReg(Incr2Reg).addReg(dest);
+  BuildMI(BB, TII->get(is64bit ? PPC::ANDC8 : PPC::ANDC), Tmp2Reg)
+    .addReg(dest).addReg(MaskReg);
+  BuildMI(BB, TII->get(is64bit ? PPC::AND8 : PPC::AND), Tmp3Reg)
+    .addReg(TmpReg).addReg(MaskReg);
+  BuildMI(BB, TII->get(is64bit ? PPC::OR8 : PPC::OR), Tmp4Reg)
+    .addReg(Tmp3Reg).addReg(Tmp2Reg);
+  BuildMI(BB, TII->get(PPC::STWCX))
+    .addReg(Tmp4Reg).addReg(PPC::R0).addReg(PtrReg);
+  BuildMI(BB, TII->get(PPC::BCC))
+    .addImm(PPC::PRED_NE).addReg(PPC::CR0).addMBB(loopMBB);    
+  BB->addSuccessor(loopMBB);
+  BB->addSuccessor(exitMBB);
+
+  //  exitMBB:
+  //   ...
+  BB = exitMBB;
+  return BB;
+}
+
 MachineBasicBlock *
 PPCTargetLowering::EmitInstrWithCustomInserter(MachineInstr *MI,
                                                MachineBasicBlock *BB) {
@@ -3974,30 +4094,60 @@ PPCTargetLowering::EmitInstrWithCustomInserter(MachineInstr *MI,
       .addReg(MI->getOperand(3).getReg()).addMBB(copy0MBB)
       .addReg(MI->getOperand(2).getReg()).addMBB(thisMBB);
   }
+  else if (MI->getOpcode() == PPC::ATOMIC_LOAD_ADD_I8)
+    BB = EmitPartwordAtomicBinary(MI, BB, true, PPC::ADD4);
+  else if (MI->getOpcode() == PPC::ATOMIC_LOAD_ADD_I16)
+    BB = EmitPartwordAtomicBinary(MI, BB, false, PPC::ADD4);
   else if (MI->getOpcode() == PPC::ATOMIC_LOAD_ADD_I32)
     BB = EmitAtomicBinary(MI, BB, false, PPC::ADD4);
   else if (MI->getOpcode() == PPC::ATOMIC_LOAD_ADD_I64)
     BB = EmitAtomicBinary(MI, BB, true, PPC::ADD8);
+
+  else if (MI->getOpcode() == PPC::ATOMIC_LOAD_AND_I8)
+    BB = EmitPartwordAtomicBinary(MI, BB, true, PPC::AND);
+  else if (MI->getOpcode() == PPC::ATOMIC_LOAD_AND_I16)
+    BB = EmitPartwordAtomicBinary(MI, BB, false, PPC::AND);
   else if (MI->getOpcode() == PPC::ATOMIC_LOAD_AND_I32)
     BB = EmitAtomicBinary(MI, BB, false, PPC::AND);
   else if (MI->getOpcode() == PPC::ATOMIC_LOAD_AND_I64)
     BB = EmitAtomicBinary(MI, BB, true, PPC::AND8);
+
+  else if (MI->getOpcode() == PPC::ATOMIC_LOAD_OR_I8)
+    BB = EmitPartwordAtomicBinary(MI, BB, true, PPC::OR);
+  else if (MI->getOpcode() == PPC::ATOMIC_LOAD_OR_I16)
+    BB = EmitPartwordAtomicBinary(MI, BB, false, PPC::OR);
   else if (MI->getOpcode() == PPC::ATOMIC_LOAD_OR_I32)
     BB = EmitAtomicBinary(MI, BB, false, PPC::OR);
   else if (MI->getOpcode() == PPC::ATOMIC_LOAD_OR_I64)
     BB = EmitAtomicBinary(MI, BB, true, PPC::OR8);
+
+  else if (MI->getOpcode() == PPC::ATOMIC_LOAD_XOR_I8)
+    BB = EmitPartwordAtomicBinary(MI, BB, true, PPC::XOR);
+  else if (MI->getOpcode() == PPC::ATOMIC_LOAD_XOR_I16)
+    BB = EmitPartwordAtomicBinary(MI, BB, false, PPC::XOR);
   else if (MI->getOpcode() == PPC::ATOMIC_LOAD_XOR_I32)
     BB = EmitAtomicBinary(MI, BB, false, PPC::XOR);
   else if (MI->getOpcode() == PPC::ATOMIC_LOAD_XOR_I64)
     BB = EmitAtomicBinary(MI, BB, true, PPC::XOR8);
+
+  else if (MI->getOpcode() == PPC::ATOMIC_LOAD_NAND_I8)
+    BB = EmitPartwordAtomicBinary(MI, BB, true, PPC::NAND);
+  else if (MI->getOpcode() == PPC::ATOMIC_LOAD_NAND_I16)
+    BB = EmitPartwordAtomicBinary(MI, BB, false, PPC::NAND);
   else if (MI->getOpcode() == PPC::ATOMIC_LOAD_NAND_I32)
     BB = EmitAtomicBinary(MI, BB, false, PPC::NAND);
   else if (MI->getOpcode() == PPC::ATOMIC_LOAD_NAND_I64)
     BB = EmitAtomicBinary(MI, BB, true, PPC::NAND8);
+
+  else if (MI->getOpcode() == PPC::ATOMIC_LOAD_SUB_I8)
+    BB = EmitPartwordAtomicBinary(MI, BB, true, PPC::SUBF);
+  else if (MI->getOpcode() == PPC::ATOMIC_LOAD_SUB_I16)
+    BB = EmitPartwordAtomicBinary(MI, BB, false, PPC::SUBF);
   else if (MI->getOpcode() == PPC::ATOMIC_LOAD_SUB_I32)
     BB = EmitAtomicBinary(MI, BB, false, PPC::SUBF);
   else if (MI->getOpcode() == PPC::ATOMIC_LOAD_SUB_I64)
     BB = EmitAtomicBinary(MI, BB, true, PPC::SUBF8);
+
   else if (MI->getOpcode() == PPC::ATOMIC_CMP_SWAP_I32 ||
            MI->getOpcode() == PPC::ATOMIC_CMP_SWAP_I64) {
     bool is64bit = MI->getOpcode() == PPC::ATOMIC_CMP_SWAP_I64;