ptx: add ld instruction and test
[oota-llvm.git] / lib / Target / PTX / PTXAsmPrinter.cpp
index f4525521d50537196ddaecfbdc353cac613342dc..cd27fb5d82efcbe1432dfb782467a6da135193b0 100644 (file)
@@ -17,7 +17,8 @@
 #include "PTX.h"
 #include "PTXMachineFunctionInfo.h"
 #include "PTXTargetMachine.h"
-#include "llvm/Support/raw_ostream.h"
+#include "llvm/DerivedTypes.h"
+#include "llvm/Module.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/CodeGen/MachineInstr.h"
 #include "llvm/MC/MCStreamer.h"
 #include "llvm/MC/MCSymbol.h"
+#include "llvm/Target/Mangler.h"
 #include "llvm/Target/TargetLoweringObjectFile.h"
 #include "llvm/Target/TargetRegistry.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/MathExtras.h"
 #include "llvm/Support/raw_ostream.h"
 
 using namespace llvm;
@@ -50,6 +53,8 @@ public:
 
   const char *getPassName() const { return "PTX Assembly Printer"; }
 
+  bool doFinalization(Module &M);
+
   virtual void EmitStartOfAsmFile(Module &M);
 
   virtual bool runOnMachineFunction(MachineFunction &MF);
@@ -68,6 +73,7 @@ public:
   static const char *getRegisterName(unsigned RegNo);
 
 private:
+  void EmitVariableDeclaration(const GlobalVariable *gv);
   void EmitFunctionDeclaration();
 }; // class PTXAsmPrinter
 } // namespace
@@ -96,11 +102,54 @@ static const char *getInstructionTypeName(const MachineInstr *MI) {
   return NULL;
 }
 
+static const char *getStateSpaceName(unsigned addressSpace) {
+  if (addressSpace <= 255)
+    return "global";
+  // TODO Add more state spaces
+
+  llvm_unreachable("Unknown state space");
+  return NULL;
+}
+
+bool PTXAsmPrinter::doFinalization(Module &M) {
+  // XXX Temproarily remove global variables so that doFinalization() will not
+  // emit them again (global variables are emitted at beginning).
+
+  Module::GlobalListType &global_list = M.getGlobalList();
+  int i, n = global_list.size();
+  GlobalVariable **gv_array = new GlobalVariable* [n];
+
+  // first, back-up GlobalVariable in gv_array
+  i = 0;
+  for (Module::global_iterator I = global_list.begin(), E = global_list.end();
+       I != E; ++I)
+    gv_array[i++] = &*I;
+
+  // second, empty global_list
+  while (!global_list.empty())
+    global_list.remove(global_list.begin());
+
+  // call doFinalization
+  bool ret = AsmPrinter::doFinalization(M);
+
+  // now we restore global variables
+  for (i = 0; i < n; i ++)
+    global_list.insert(global_list.end(), gv_array[i]);
+
+  delete[] gv_array;
+  return ret;
+}
+
 void PTXAsmPrinter::EmitStartOfAsmFile(Module &M)
 {
   OutStreamer.EmitRawText(Twine("\t.version " + OptPTXVersion));
   OutStreamer.EmitRawText(Twine("\t.target " + OptPTXTarget));
   OutStreamer.AddBlankLine();
+
+  // declare global variables
+  for (Module::const_global_iterator i = M.global_begin(), e = M.global_end();
+       i != e; ++i)
+    EmitVariableDeclaration(i);
 }
 
 bool PTXAsmPrinter::runOnMachineFunction(MachineFunction &MF) {
@@ -156,12 +205,15 @@ void PTXAsmPrinter::printOperand(const MachineInstr *MI, int opNum,
     default:
       llvm_unreachable("<unknown operand type>");
       break;
-    case MachineOperand::MO_Register:
-      OS << getRegisterName(MO.getReg());
+    case MachineOperand::MO_GlobalAddress:
+      OS << *Mang->getSymbol(MO.getGlobal());
       break;
     case MachineOperand::MO_Immediate:
       OS << (int) MO.getImm();
       break;
+    case MachineOperand::MO_Register:
+      OS << getRegisterName(MO.getReg());
+      break;
   }
 }
 
@@ -176,6 +228,49 @@ void PTXAsmPrinter::printMemOperand(const MachineInstr *MI, int opNum,
   printOperand(MI, opNum+1, OS);
 }
 
+void PTXAsmPrinter::EmitVariableDeclaration(const GlobalVariable *gv) {
+  // Check to see if this is a special global used by LLVM, if so, emit it.
+  if (EmitSpecialLLVMGlobal(gv))
+    return;
+
+  MCSymbol *gvsym = Mang->getSymbol(gv);
+
+  assert(gvsym->isUndefined() && "Cannot define a symbol twice!");
+
+  std::string decl;
+
+  // check if it is defined in some other translation unit
+  if (gv->isDeclaration())
+    decl += ".extern ";
+
+  // state space: e.g., .global
+  decl += ".";
+  decl += getStateSpaceName(gv->getType()->getAddressSpace());
+  decl += " ";
+
+  // alignment (optional)
+  unsigned alignment = gv->getAlignment();
+  if (alignment != 0) {
+    decl += ".align ";
+    decl += utostr(Log2_32(gv->getAlignment()));
+    decl += " ";
+  }
+
+  // TODO: add types
+  decl += ".s32 ";
+
+  decl += gvsym->getName();
+
+  if (ArrayType::classof(gv->getType()) || PointerType::classof(gv->getType()))
+    decl += "[]";
+
+  decl += ";";
+
+  OutStreamer.EmitRawText(Twine(decl));
+
+  OutStreamer.AddBlankLine();
+}
+
 void PTXAsmPrinter::EmitFunctionDeclaration() {
   // The function label could have already been emitted if two symbols end up
   // conflicting due to asm renaming.  Detect this and emit an error.
@@ -212,7 +307,7 @@ void PTXAsmPrinter::EmitFunctionDeclaration() {
       for (int i = 0, e = MFI->getNumArg(); i != e; ++i) {
         if (i != 0)
           decl += ", ";
-        decl += ".param .s32 "; // TODO: param's type
+        decl += ".param .s32 "; // TODO: add types
         decl += PARAM_PREFIX;
         decl += utostr(i + 1);
       }