PTX: Fix whitespace errors
[oota-llvm.git] / lib / Target / PTX / PTXAsmPrinter.cpp
1 //===-- PTXAsmPrinter.cpp - PTX LLVM assembly writer ----------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file contains a printer that converts from our internal representation
11 // of machine-dependent LLVM code to PTX assembly language.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #define DEBUG_TYPE "ptx-asm-printer"
16
17 #include "PTX.h"
18 #include "PTXMachineFunctionInfo.h"
19 #include "PTXTargetMachine.h"
20 #include "llvm/DerivedTypes.h"
21 #include "llvm/Module.h"
22 #include "llvm/ADT/SmallString.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/Twine.h"
25 #include "llvm/CodeGen/AsmPrinter.h"
26 #include "llvm/CodeGen/MachineInstr.h"
27 #include "llvm/CodeGen/MachineRegisterInfo.h"
28 #include "llvm/MC/MCStreamer.h"
29 #include "llvm/MC/MCSymbol.h"
30 #include "llvm/Target/Mangler.h"
31 #include "llvm/Target/TargetLoweringObjectFile.h"
32 #include "llvm/Target/TargetRegistry.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/Debug.h"
35 #include "llvm/Support/ErrorHandling.h"
36 #include "llvm/Support/MathExtras.h"
37 #include "llvm/Support/raw_ostream.h"
38
39 using namespace llvm;
40
41 namespace {
42 class PTXAsmPrinter : public AsmPrinter {
43 public:
44   explicit PTXAsmPrinter(TargetMachine &TM, MCStreamer &Streamer)
45     : AsmPrinter(TM, Streamer) {}
46
47   const char *getPassName() const { return "PTX Assembly Printer"; }
48
49   bool doFinalization(Module &M);
50
51   virtual void EmitStartOfAsmFile(Module &M);
52
53   virtual bool runOnMachineFunction(MachineFunction &MF);
54
55   virtual void EmitFunctionBodyStart();
56   virtual void EmitFunctionBodyEnd() { OutStreamer.EmitRawText(Twine("}")); }
57
58   virtual void EmitInstruction(const MachineInstr *MI);
59
60   void printOperand(const MachineInstr *MI, int opNum, raw_ostream &OS);
61   void printMemOperand(const MachineInstr *MI, int opNum, raw_ostream &OS,
62                        const char *Modifier = 0);
63   void printParamOperand(const MachineInstr *MI, int opNum, raw_ostream &OS,
64                          const char *Modifier = 0);
65   void printPredicateOperand(const MachineInstr *MI, raw_ostream &O);
66
67   // autogen'd.
68   void printInstruction(const MachineInstr *MI, raw_ostream &OS);
69   static const char *getRegisterName(unsigned RegNo);
70
71 private:
72   void EmitVariableDeclaration(const GlobalVariable *gv);
73   void EmitFunctionDeclaration();
74 }; // class PTXAsmPrinter
75 } // namespace
76
77 static const char PARAM_PREFIX[] = "__param_";
78
79 static const char *getRegisterTypeName(unsigned RegNo) {
80 #define TEST_REGCLS(cls, clsstr)                \
81   if (PTX::cls ## RegisterClass->contains(RegNo)) return # clsstr;
82   TEST_REGCLS(Preds, pred);
83   TEST_REGCLS(RRegu16, u16);
84   TEST_REGCLS(RRegu32, u32);
85   TEST_REGCLS(RRegu64, u64);
86   TEST_REGCLS(RRegf32, f32);
87   TEST_REGCLS(RRegf64, f64);
88 #undef TEST_REGCLS
89
90   llvm_unreachable("Not in any register class!");
91   return NULL;
92 }
93
94 static const char *getStateSpaceName(unsigned addressSpace) {
95   switch (addressSpace) {
96   default: llvm_unreachable("Unknown state space");
97   case PTX::GLOBAL:    return "global";
98   case PTX::CONSTANT:  return "const";
99   case PTX::LOCAL:     return "local";
100   case PTX::PARAMETER: return "param";
101   case PTX::SHARED:    return "shared";
102   }
103   return NULL;
104 }
105
106 static const char *getTypeName(const Type* type) {
107   while (true) {
108     switch (type->getTypeID()) {
109       default: llvm_unreachable("Unknown type");
110       case Type::FloatTyID: return ".f32";
111       case Type::DoubleTyID: return ".f64";
112       case Type::IntegerTyID:
113         switch (type->getPrimitiveSizeInBits()) {
114           default: llvm_unreachable("Unknown integer bit-width");
115           case 16: return ".u16";
116           case 32: return ".u32";
117           case 64: return ".u64";
118         }
119       case Type::ArrayTyID:
120       case Type::PointerTyID:
121         type = dyn_cast<const SequentialType>(type)->getElementType();
122         break;
123     }
124   }
125   return NULL;
126 }
127
128 bool PTXAsmPrinter::doFinalization(Module &M) {
129   // XXX Temproarily remove global variables so that doFinalization() will not
130   // emit them again (global variables are emitted at beginning).
131
132   Module::GlobalListType &global_list = M.getGlobalList();
133   int i, n = global_list.size();
134   GlobalVariable **gv_array = new GlobalVariable* [n];
135
136   // first, back-up GlobalVariable in gv_array
137   i = 0;
138   for (Module::global_iterator I = global_list.begin(), E = global_list.end();
139        I != E; ++I)
140     gv_array[i++] = &*I;
141
142   // second, empty global_list
143   while (!global_list.empty())
144     global_list.remove(global_list.begin());
145
146   // call doFinalization
147   bool ret = AsmPrinter::doFinalization(M);
148
149   // now we restore global variables
150   for (i = 0; i < n; i ++)
151     global_list.insert(global_list.end(), gv_array[i]);
152
153   delete[] gv_array;
154   return ret;
155 }
156
157 void PTXAsmPrinter::EmitStartOfAsmFile(Module &M)
158 {
159   const PTXSubtarget& ST = TM.getSubtarget<PTXSubtarget>();
160
161   OutStreamer.EmitRawText(Twine("\t.version " + ST.getPTXVersionString()));
162   OutStreamer.EmitRawText(Twine("\t.target " + ST.getTargetString() +
163                                 (ST.supportsDouble() ? ""
164                                                      : ", map_f64_to_f32")));
165   OutStreamer.AddBlankLine();
166
167   // declare global variables
168   for (Module::const_global_iterator i = M.global_begin(), e = M.global_end();
169        i != e; ++i)
170     EmitVariableDeclaration(i);
171 }
172
173 bool PTXAsmPrinter::runOnMachineFunction(MachineFunction &MF) {
174   SetupMachineFunction(MF);
175   EmitFunctionDeclaration();
176   EmitFunctionBody();
177   return false;
178 }
179
180 void PTXAsmPrinter::EmitFunctionBodyStart() {
181   OutStreamer.EmitRawText(Twine("{"));
182
183   const PTXMachineFunctionInfo *MFI = MF->getInfo<PTXMachineFunctionInfo>();
184
185   // Print local variable definition
186   for (PTXMachineFunctionInfo::reg_iterator
187        i = MFI->localVarRegBegin(), e = MFI->localVarRegEnd(); i != e; ++ i) {
188     unsigned reg = *i;
189
190     std::string def = "\t.reg .";
191     def += getRegisterTypeName(reg);
192     def += ' ';
193     def += getRegisterName(reg);
194     def += ';';
195     OutStreamer.EmitRawText(Twine(def));
196   }
197 }
198
199 void PTXAsmPrinter::EmitInstruction(const MachineInstr *MI) {
200   std::string str;
201   str.reserve(64);
202
203   raw_string_ostream OS(str);
204
205   // Emit predicate
206   printPredicateOperand(MI, OS);
207
208   // Write instruction to str
209   printInstruction(MI, OS);
210   OS << ';';
211   OS.flush();
212
213   StringRef strref = StringRef(str);
214   OutStreamer.EmitRawText(strref);
215 }
216
217 void PTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum,
218                                  raw_ostream &OS) {
219   const MachineOperand &MO = MI->getOperand(opNum);
220
221   switch (MO.getType()) {
222     default:
223       llvm_unreachable("<unknown operand type>");
224       break;
225     case MachineOperand::MO_GlobalAddress:
226       OS << *Mang->getSymbol(MO.getGlobal());
227       break;
228     case MachineOperand::MO_Immediate:
229       OS << (long) MO.getImm();
230       break;
231     case MachineOperand::MO_MachineBasicBlock:
232       OS << *MO.getMBB()->getSymbol();
233       break;
234     case MachineOperand::MO_Register:
235       OS << getRegisterName(MO.getReg());
236       break;
237     case MachineOperand::MO_FPImmediate:
238       APInt constFP = MO.getFPImm()->getValueAPF().bitcastToAPInt();
239       bool  isFloat = MO.getFPImm()->getType()->getTypeID() == Type::FloatTyID;
240       // Emit 0F for 32-bit floats and 0D for 64-bit doubles.
241       if (isFloat) {
242         OS << "0F";
243       }
244       else {
245         OS << "0D";
246       }
247       // Emit the encoded floating-point value.
248       if (constFP.getZExtValue() > 0) {
249         OS << constFP.toString(16, false);
250       }
251       else {
252         OS << "00000000";
253         // If We have a double-precision zero, pad to 8-bytes.
254         if (!isFloat) {
255           OS << "00000000";
256         }
257       }
258       break;
259   }
260 }
261
262 void PTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum,
263                                     raw_ostream &OS, const char *Modifier) {
264   printOperand(MI, opNum, OS);
265
266   if (MI->getOperand(opNum+1).isImm() && MI->getOperand(opNum+1).getImm() == 0)
267     return; // don't print "+0"
268
269   OS << "+";
270   printOperand(MI, opNum+1, OS);
271 }
272
273 void PTXAsmPrinter::printParamOperand(const MachineInstr *MI, int opNum,
274                                       raw_ostream &OS, const char *Modifier) {
275   OS << PARAM_PREFIX << (int) MI->getOperand(opNum).getImm() + 1;
276 }
277
278 void PTXAsmPrinter::EmitVariableDeclaration(const GlobalVariable *gv) {
279   // Check to see if this is a special global used by LLVM, if so, emit it.
280   if (EmitSpecialLLVMGlobal(gv))
281     return;
282
283   MCSymbol *gvsym = Mang->getSymbol(gv);
284
285   assert(gvsym->isUndefined() && "Cannot define a symbol twice!");
286
287   std::string decl;
288
289   // check if it is defined in some other translation unit
290   if (gv->isDeclaration())
291     decl += ".extern ";
292
293   // state space: e.g., .global
294   decl += ".";
295   decl += getStateSpaceName(gv->getType()->getAddressSpace());
296   decl += " ";
297
298   // alignment (optional)
299   unsigned alignment = gv->getAlignment();
300   if (alignment != 0) {
301     decl += ".align ";
302     decl += utostr(Log2_32(gv->getAlignment()));
303     decl += " ";
304   }
305
306
307   if (PointerType::classof(gv->getType())) {
308     const PointerType* pointerTy = dyn_cast<const PointerType>(gv->getType());
309     const Type* elementTy = pointerTy->getElementType();
310
311     decl += ".b8 ";
312     decl += gvsym->getName();
313     decl += "[";
314
315     if (elementTy->isArrayTy())
316     {
317       assert(elementTy->isArrayTy() && "Only pointers to arrays are supported");
318
319       const ArrayType* arrayTy = dyn_cast<const ArrayType>(elementTy);
320       elementTy = arrayTy->getElementType();
321
322       unsigned numElements = arrayTy->getNumElements();
323
324       while (elementTy->isArrayTy()) {
325
326         arrayTy = dyn_cast<const ArrayType>(elementTy);
327         elementTy = arrayTy->getElementType();
328
329         numElements *= arrayTy->getNumElements();
330       }
331
332       // FIXME: isPrimitiveType() == false for i16?
333       assert(elementTy->isSingleValueType() &&
334               "Non-primitive types are not handled");
335
336       // Compute the size of the array, in bytes.
337       uint64_t arraySize = (elementTy->getPrimitiveSizeInBits() >> 3)
338                         * numElements;
339
340       decl += utostr(arraySize);
341     }
342
343     decl += "]";
344
345     // handle string constants (assume ConstantArray means string)
346
347     if (gv->hasInitializer())
348     {
349       Constant *C = gv->getInitializer();  
350       if (const ConstantArray *CA = dyn_cast<ConstantArray>(C))
351       {
352         decl += " = {";
353
354         for (unsigned i = 0, e = C->getNumOperands(); i != e; ++i)
355         {
356           if (i > 0)   decl += ",";
357
358           decl += "0x" +
359                 utohexstr(cast<ConstantInt>(CA->getOperand(i))->getZExtValue());
360         }
361
362         decl += "}";
363       }
364     }
365   }
366   else {
367     // Note: this is currently the fall-through case and most likely generates
368     //       incorrect code.
369     decl += getTypeName(gv->getType());
370     decl += " ";
371
372     decl += gvsym->getName();
373
374     if (ArrayType::classof(gv->getType()) ||
375         PointerType::classof(gv->getType()))
376       decl += "[]";
377   }
378
379   decl += ";";
380
381   OutStreamer.EmitRawText(Twine(decl));
382
383   OutStreamer.AddBlankLine();
384 }
385
386 void PTXAsmPrinter::EmitFunctionDeclaration() {
387   // The function label could have already been emitted if two symbols end up
388   // conflicting due to asm renaming.  Detect this and emit an error.
389   if (!CurrentFnSym->isUndefined()) {
390     report_fatal_error("'" + Twine(CurrentFnSym->getName()) +
391                        "' label emitted multiple times to assembly file");
392     return;
393   }
394
395   const PTXMachineFunctionInfo *MFI = MF->getInfo<PTXMachineFunctionInfo>();
396   const bool isKernel = MFI->isKernel();
397   unsigned reg;
398
399   std::string decl = isKernel ? ".entry" : ".func";
400
401   // Print return register
402   reg = MFI->retReg();
403   if (!isKernel && reg != PTX::NoRegister) {
404     decl += " (.reg ."; // FIXME: could it return in .param space?
405     decl += getRegisterTypeName(reg);
406     decl += " ";
407     decl += getRegisterName(reg);
408     decl += ")";
409   }
410
411   // Print function name
412   decl += " ";
413   decl += CurrentFnSym->getName().str();
414
415   // Print parameter list
416   if (!MFI->argRegEmpty()) {
417     decl += " (";
418     if (isKernel) {
419       unsigned cnt = 0;
420       for(PTXMachineFunctionInfo::reg_iterator
421           i = MFI->argRegBegin(), e = MFI->argRegEnd(), b = i;
422           i != e; ++i) {
423         reg = *i;
424         assert(reg != PTX::NoRegister && "Not a valid register!");
425         if (i != b)
426           decl += ", ";
427         decl += ".param .";
428         decl += getRegisterTypeName(reg);
429         decl += " ";
430         decl += PARAM_PREFIX;
431         decl += utostr(++cnt);
432       }
433     } else {
434       for (PTXMachineFunctionInfo::reg_iterator
435            i = MFI->argRegBegin(), e = MFI->argRegEnd(), b = i;
436            i != e; ++i) {
437         reg = *i;
438         assert(reg != PTX::NoRegister && "Not a valid register!");
439         if (i != b)
440           decl += ", ";
441         decl += ".reg .";
442         decl += getRegisterTypeName(reg);
443         decl += " ";
444         decl += getRegisterName(reg);
445       }
446     }
447     decl += ")";
448   }
449
450   OutStreamer.EmitRawText(Twine(decl));
451 }
452
453 void PTXAsmPrinter::
454 printPredicateOperand(const MachineInstr *MI, raw_ostream &O) {
455   int i = MI->findFirstPredOperandIdx();
456   if (i == -1)
457     llvm_unreachable("missing predicate operand");
458
459   unsigned reg = MI->getOperand(i).getReg();
460   int predOp = MI->getOperand(i+1).getImm();
461
462   DEBUG(dbgs() << "predicate: (" << reg << ", " << predOp << ")\n");
463
464   if (reg != PTX::NoRegister) {
465     O << '@';
466     if (predOp == PTX::PRED_NEGATE)
467       O << '!';
468     O << getRegisterName(reg);
469   }
470 }
471
472 #include "PTXGenAsmWriter.inc"
473
474 // Force static initialization.
475 extern "C" void LLVMInitializePTXAsmPrinter() {
476   RegisterAsmPrinter<PTXAsmPrinter> X(ThePTX32Target);
477   RegisterAsmPrinter<PTXAsmPrinter> Y(ThePTX64Target);
478 }