[WebAssembly] Support non-legal argument and return types.
authorDan Gohman <dan433584@gmail.com>
Wed, 11 Nov 2015 01:33:02 +0000 (01:33 +0000)
committerDan Gohman <dan433584@gmail.com>
Wed, 11 Nov 2015 01:33:02 +0000 (01:33 +0000)
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@252687 91177308-0d34-0410-b5e6-96231b3b80d8

lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp
lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.h
test/CodeGen/WebAssembly/import.ll

index 6eb9ae6..27095ec 100644 (file)
@@ -23,6 +23,7 @@
 
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringExtras.h"
+#include "llvm/CodeGen/Analysis.h"
 #include "llvm/CodeGen/AsmPrinter.h"
 #include "llvm/CodeGen/MachineConstantPool.h"
 #include "llvm/CodeGen/MachineInstr.h"
@@ -66,7 +67,7 @@ private:
     const auto &Subtarget = MF.getSubtarget<WebAssemblySubtarget>();
     TII = Subtarget.getInstrInfo();
     MRI = &MF.getRegInfo();
-    NumArgs = MF.getInfo<WebAssemblyFunctionInfo>()->getNumArguments();
+    NumArgs = MF.getInfo<WebAssemblyFunctionInfo>()->getParams().size();
     return AsmPrinter::runOnMachineFunction(MF);
   }
 
@@ -82,7 +83,7 @@ private:
 
   std::string getRegTypeName(unsigned RegNo) const;
   static std::string toString(const APFloat &APF);
-  const char *toString(Type *Ty) const;
+  const char *toString(MVT VT) const;
   std::string regToString(const MachineOperand &MO);
   std::string argToString(const MachineOperand &MO);
 };
@@ -167,40 +168,20 @@ std::string WebAssemblyAsmPrinter::argToString(const MachineOperand &MO) {
   return utostr(ArgNo);
 }
 
-const char *WebAssemblyAsmPrinter::toString(Type *Ty) const {
-  switch (Ty->getTypeID()) {
+const char *WebAssemblyAsmPrinter::toString(MVT VT) const {
+  switch (VT.SimpleTy) {
   default:
     break;
-  // Treat all pointers as the underlying integer into linear memory.
-  case Type::PointerTyID:
-    switch (getPointerSize()) {
-    case 4:
-      return "i32";
-    case 8:
-      return "i64";
-    default:
-      llvm_unreachable("unsupported pointer size");
-    }
-    break;
-  case Type::FloatTyID:
+  case MVT::f32:
     return "f32";
-  case Type::DoubleTyID:
+  case MVT::f64:
     return "f64";
-  case Type::IntegerTyID:
-    switch (Ty->getIntegerBitWidth()) {
-    case 8:
-      return "i8";
-    case 16:
-      return "i16";
-    case 32:
-      return "i32";
-    case 64:
-      return "i64";
-    default:
-      break;
-    }
+  case MVT::i32:
+    return "i32";
+  case MVT::i64:
+    return "i64";
   }
-  DEBUG(dbgs() << "Invalid type "; Ty->print(dbgs()); dbgs() << '\n');
+  DEBUG(dbgs() << "Invalid type " << EVT(VT).getEVTString() << '\n');
   llvm_unreachable("invalid type");
   return "<invalid>";
 }
@@ -219,40 +200,37 @@ void WebAssemblyAsmPrinter::EmitJumpTableInfo() {
 }
 
 void WebAssemblyAsmPrinter::EmitFunctionBodyStart() {
-  const Function *F = MF->getFunction();
-  Type *Rt = F->getReturnType();
   SmallString<128> Str;
   raw_svector_ostream OS(Str);
-  bool First = true;
 
-  if (!Rt->isVoidTy() || !F->arg_empty()) {
-    for (const Argument &A : F->args()) {
-      OS << (First ? "" : "\n") << "\t.param " << toString(A.getType());
-      First = false;
-    }
-    if (!Rt->isVoidTy()) {
-      OS << (First ? "" : "\n") << "\t.result " << toString(Rt);
-      First = false;
-    }
-  }
+  for (MVT VT : MF->getInfo<WebAssemblyFunctionInfo>()->getParams())
+    OS << "\t" ".param "
+       << toString(VT) << '\n';
+  for (MVT VT : MF->getInfo<WebAssemblyFunctionInfo>()->getResults())
+    OS << "\t" ".result "
+       << toString(VT) << '\n';
 
   bool FirstVReg = true;
   for (unsigned Idx = 0, IdxE = MRI->getNumVirtRegs(); Idx != IdxE; ++Idx) {
     unsigned VReg = TargetRegisterInfo::index2VirtReg(Idx);
     // FIXME: Don't skip dead virtual registers for now: that would require
     //        remapping all locals' numbers.
-    //if (!MRI->use_empty(VReg)) {
-      if (FirstVReg) {
-        OS << (First ? "" : "\n") << "\t.local ";
-        First = false;
-      }
-      OS << (FirstVReg ? "" : ", ") << getRegTypeName(VReg);
-      FirstVReg = false;
+    // if (!MRI->use_empty(VReg)) {
+    if (FirstVReg)
+      OS << "\t" ".local ";
+    else
+      OS << ", ";
+    OS << getRegTypeName(VReg);
+    FirstVReg = false;
     //}
   }
+  if (!FirstVReg)
+    OS << '\n';
 
-  if (!First)
-    OutStreamer->EmitRawText(OS.str());
+  // EmitRawText appends a newline, so strip off the last newline.
+  StringRef Text = OS.str();
+  if (!Text.empty())
+    OutStreamer->EmitRawText(Text.substr(0, Text.size() - 1));
   AsmPrinter::EmitFunctionBodyStart();
 }
 
@@ -334,27 +312,75 @@ void WebAssemblyAsmPrinter::EmitInstruction(const MachineInstr *MI) {
   }
 }
 
+static void ComputeLegalValueVTs(LLVMContext &Context,
+                                 const WebAssemblyTargetLowering &TLI,
+                                 const DataLayout &DL, Type *Ty,
+                                 SmallVectorImpl<MVT> &ValueVTs) {
+  SmallVector<EVT, 4> VTs;
+  ComputeValueVTs(TLI, DL, Ty, VTs);
+
+  for (EVT VT : VTs) {
+    unsigned NumRegs = TLI.getNumRegisters(Context, VT);
+    MVT RegisterVT = TLI.getRegisterType(Context, VT);
+    for (unsigned i = 0; i != NumRegs; ++i)
+      ValueVTs.push_back(RegisterVT);
+  }
+}
+
 void WebAssemblyAsmPrinter::EmitEndOfAsmFile(Module &M) {
+  const DataLayout &DL = M.getDataLayout();
+
   SmallString<128> Str;
   raw_svector_ostream OS(Str);
   for (const Function &F : M)
     if (F.isDeclarationForLinker()) {
       assert(F.hasName() && "imported functions must have a name");
-      if (F.getName().startswith("llvm."))
+      if (F.isIntrinsic())
         continue;
       if (Str.empty())
         OS << "\t.imports\n";
-      Type *Rt = F.getReturnType();
+
       OS << "\t.import " << toSymbol(F.getName()) << " \"\" \"" << F.getName()
-         << "\" (param";
-      for (const Argument &A : F.args())
-        OS << ' ' << toString(A.getType());
-      OS << ')';
-      if (!Rt->isVoidTy())
-        OS << " (result " << toString(Rt) << ')';
+         << "\"";
+
+      const WebAssemblyTargetLowering &TLI =
+          *TM.getSubtarget<WebAssemblySubtarget>(F).getTargetLowering();
+
+      // If we need to legalize the return type, it'll get converted into
+      // passing a pointer.
+      bool SawParam = false;
+      SmallVector<MVT, 4> ResultVTs;
+      ComputeLegalValueVTs(M.getContext(), TLI, DL, F.getReturnType(),
+                           ResultVTs);
+      if (ResultVTs.size() > 1) {
+        ResultVTs.clear();
+        OS << " (param " << toString(TLI.getPointerTy(DL));
+        SawParam = true;
+      }
+
+      for (const Argument &A : F.args()) {
+        SmallVector<MVT, 4> ParamVTs;
+        ComputeLegalValueVTs(M.getContext(), TLI, DL, A.getType(), ParamVTs);
+        for (EVT VT : ParamVTs) {
+          if (!SawParam) {
+            OS << " (param";
+            SawParam = true;
+          }
+          OS << ' ' << toString(VT.getSimpleVT());
+        }
+      }
+      if (SawParam)
+        OS << ')';
+
+      for (EVT VT : ResultVTs)
+        OS << " (result " << toString(VT.getSimpleVT()) << ')';
+
       OS << '\n';
     }
-  OutStreamer->EmitRawText(OS.str());
+
+  StringRef Text = OS.str();
+  if (!Text.empty())
+    OutStreamer->EmitRawText(Text.substr(0, Text.size() - 1));
 }
 
 // Force static initialization.
index d813367..899e768 100644 (file)
@@ -252,13 +252,8 @@ WebAssemblyTargetLowering::LowerCall(CallLoweringInfo &CLI,
     fail(DL, DAG, "WebAssembly doesn't support tail call yet");
   CLI.IsTailCall = false;
 
-  SmallVectorImpl<ISD::OutputArg> &Outs = CLI.Outs;
   SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
 
-  bool IsStructRet = (Outs.empty()) ? false : Outs[0].Flags.isSRet();
-  if (IsStructRet)
-    fail(DL, DAG, "WebAssembly doesn't support struct return yet");
-
   SmallVectorImpl<ISD::InputArg> &Ins = CLI.Ins;
   if (Ins.size() > 1)
     fail(DL, DAG, "WebAssembly doesn't support more than 1 returned value yet");
@@ -316,6 +311,7 @@ SDValue WebAssemblyTargetLowering::LowerReturn(
     const SmallVectorImpl<ISD::OutputArg> &Outs,
     const SmallVectorImpl<SDValue> &OutVals, SDLoc DL,
     SelectionDAG &DAG) const {
+  MachineFunction &MF = DAG.getMachineFunction();
 
   assert(Outs.size() <= 1 && "WebAssembly can only return up to one value");
   if (CallConv != CallingConv::C)
@@ -327,6 +323,33 @@ SDValue WebAssemblyTargetLowering::LowerReturn(
   RetOps.append(OutVals.begin(), OutVals.end());
   Chain = DAG.getNode(WebAssemblyISD::RETURN, DL, MVT::Other, RetOps);
 
+  // Record the number and types of the return values.
+  for (const ISD::OutputArg &Out : Outs) {
+    if (Out.Flags.isZExt())
+      fail(DL, DAG, "WebAssembly hasn't implemented zext results");
+    if (Out.Flags.isSExt())
+      fail(DL, DAG, "WebAssembly hasn't implemented sext results");
+    if (Out.Flags.isInReg())
+      fail(DL, DAG, "WebAssembly hasn't implemented inreg results");
+    if (Out.Flags.isSRet())
+      fail(DL, DAG, "WebAssembly hasn't implemented sret results");
+    if (Out.Flags.isByVal())
+      fail(DL, DAG, "WebAssembly hasn't implemented byval results");
+    if (Out.Flags.isInAlloca())
+      fail(DL, DAG, "WebAssembly hasn't implemented inalloca results");
+    if (Out.Flags.isNest())
+      fail(DL, DAG, "WebAssembly hasn't implemented nest results");
+    if (Out.Flags.isReturned())
+      fail(DL, DAG, "WebAssembly hasn't implemented returned results");
+    if (Out.Flags.isInConsecutiveRegs())
+      fail(DL, DAG, "WebAssembly hasn't implemented cons regs results");
+    if (Out.Flags.isInConsecutiveRegsLast())
+      fail(DL, DAG, "WebAssembly hasn't implemented cons regs last results");
+    if (!Out.IsFixed)
+      fail(DL, DAG, "WebAssembly doesn't support non-fixed results yet");
+    MF.getInfo<WebAssemblyFunctionInfo>()->addResult(Out.VT);
+  }
+
   return Chain;
 }
 
@@ -340,8 +363,6 @@ SDValue WebAssemblyTargetLowering::LowerFormalArguments(
     fail(DL, DAG, "WebAssembly doesn't support non-C calling conventions");
   if (IsVarArg)
     fail(DL, DAG, "WebAssembly doesn't support varargs yet");
-  if (MF.getFunction()->hasStructRetAttr())
-    fail(DL, DAG, "WebAssembly doesn't support struct return yet");
 
   unsigned ArgNo = 0;
   for (const ISD::InputArg &In : Ins) {
@@ -365,21 +386,18 @@ SDValue WebAssemblyTargetLowering::LowerFormalArguments(
       fail(DL, DAG, "WebAssembly hasn't implemented cons regs arguments");
     if (In.Flags.isInConsecutiveRegsLast())
       fail(DL, DAG, "WebAssembly hasn't implemented cons regs last arguments");
-    if (In.Flags.isSplit())
-      fail(DL, DAG, "WebAssembly hasn't implemented split arguments");
     // FIXME Do something with In.getOrigAlign()?
     InVals.push_back(
         In.Used
             ? DAG.getNode(WebAssemblyISD::ARGUMENT, DL, In.VT,
                           DAG.getTargetConstant(ArgNo, DL, MVT::i32))
             : DAG.getNode(ISD::UNDEF, DL, In.VT));
+
+    // Record the number and types of arguments.
+    MF.getInfo<WebAssemblyFunctionInfo>()->addParam(In.VT);
     ++ArgNo;
   }
 
-  // Record the number of arguments, since argument indices and local variable
-  // indices are in the same index space.
-  MF.getInfo<WebAssemblyFunctionInfo>()->setNumArguments(ArgNo);
-
   return Chain;
 }
 
index a571e63..bac0dfa 100644 (file)
@@ -27,15 +27,19 @@ namespace llvm {
 class WebAssemblyFunctionInfo final : public MachineFunctionInfo {
   MachineFunction &MF;
 
-  unsigned NumArguments;
+  std::vector<MVT> Params;
+  std::vector<MVT> Results;
 
 public:
   explicit WebAssemblyFunctionInfo(MachineFunction &MF)
-      : MF(MF), NumArguments(0) {}
+      : MF(MF) {}
   ~WebAssemblyFunctionInfo() override;
 
-  void setNumArguments(unsigned N) { NumArguments = N; }
-  unsigned getNumArguments() const { return NumArguments; }
+  void addParam(MVT VT) { Params.push_back(VT); }
+  const std::vector<MVT> &getParams() const { return Params; }
+
+  void addResult(MVT VT) { Results.push_back(VT); }
+  const std::vector<MVT> &getResults() const { return Results; }
 };
 
 } // end namespace llvm
index 6f1f8e0..09c7cef 100644 (file)
@@ -5,19 +5,28 @@ target triple = "wasm32-unknown-unknown"
 
 ; CHECK-LABEL: .text
 ; CHECK-LABEL: f:
-define void @f(i32 %a, float %b) {
+define void @f(i32 %a, float %b, i128 %c, i1 %d) {
   tail call i32 @printi(i32 %a)
   tail call float @printf(float %b)
   tail call void @printv()
+  tail call void @split_arg(i128 %c)
+  tail call void @expanded_arg(i1 %d)
+  tail call i1 @lowered_result()
   ret void
 }
 
 ; CHECK-LABEL: .imports
-; CHECK-NEXT:  .import $printi "" "printi" (param i32) (result i32)
-; CHECK-NEXT:  .import $printf "" "printf" (param f32) (result f32)
-; CHECK-NEXT:  .import $printv "" "printv" (param)
-; CHECK-NEXT:  .import $add2 "" "add2" (param i32 i32) (result i32)
+; CHECK-NEXT:  .import $printi "" "printi" (param i32) (result i32){{$}}
+; CHECK-NEXT:  .import $printf "" "printf" (param f32) (result f32){{$}}
+; CHECK-NEXT:  .import $printv "" "printv"{{$}}
+; CHECK-NEXT:  .import $add2 "" "add2" (param i32 i32) (result i32){{$}}
+; CHECK-NEXT:  .import $split_arg "" "split_arg" (param i64 i64){{$}}
+; CHECK-NEXT:  .import $expanded_arg "" "expanded_arg" (param i32){{$}}
+; CHECK-NEXT:  .import $lowered_result "" "lowered_result" (result i32){{$}}
 declare i32 @printi(i32)
 declare float @printf(float)
 declare void @printv()
 declare i32 @add2(i32, i32)
+declare void @split_arg(i128)
+declare void @expanded_arg(i1)
+declare i1 @lowered_result()