[Kaleidoscope] Fix a bug in Chapter 4 of the Kaleidoscope tutorial where repeat
[oota-llvm.git] / examples / Kaleidoscope / Chapter4 / toy.cpp
index a52b5552a291d185c81a918ae3d23de36b508305..d1520c2a13dac2caf50e1c1e8faeca979572acd0 100644 (file)
@@ -1,12 +1,13 @@
 #include "llvm/Analysis/Passes.h"
-#include "llvm/Analysis/Verifier.h"
 #include "llvm/ExecutionEngine/ExecutionEngine.h"
-#include "llvm/ExecutionEngine/JIT.h"
+#include "llvm/ExecutionEngine/MCJIT.h"
+#include "llvm/ExecutionEngine/SectionMemoryManager.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/Verifier.h"
 #include "llvm/PassManager.h"
 #include "llvm/Support/TargetSelect.h"
 #include "llvm/Transforms/Scalar.h"
@@ -27,14 +28,16 @@ enum Token {
   tok_eof = -1,
 
   // commands
-  tok_def = -2, tok_extern = -3,
+  tok_def = -2,
+  tok_extern = -3,
 
   // primary
-  tok_identifier = -4, tok_number = -5
+  tok_identifier = -4,
+  tok_number = -5
 };
 
-static std::string IdentifierStr;  // Filled in if tok_identifier
-static double NumVal;              // Filled in if tok_number
+static std::string IdentifierStr; // Filled in if tok_identifier
+static double NumVal;             // Filled in if tok_number
 
 /// gettok - Return the next token from standard input.
 static int gettok() {
@@ -49,12 +52,14 @@ static int gettok() {
     while (isalnum((LastChar = getchar())))
       IdentifierStr += LastChar;
 
-    if (IdentifierStr == "def") return tok_def;
-    if (IdentifierStr == "extern") return tok_extern;
+    if (IdentifierStr == "def")
+      return tok_def;
+    if (IdentifierStr == "extern")
+      return tok_extern;
     return tok_identifier;
   }
 
-  if (isdigit(LastChar) || LastChar == '.') {   // Number: [0-9.]+
+  if (isdigit(LastChar) || LastChar == '.') { // Number: [0-9.]+
     std::string NumStr;
     do {
       NumStr += LastChar;
@@ -67,13 +72,14 @@ static int gettok() {
 
   if (LastChar == '#') {
     // Comment until end of line.
-    do LastChar = getchar();
+    do
+      LastChar = getchar();
     while (LastChar != EOF && LastChar != '\n' && LastChar != '\r');
-    
+
     if (LastChar != EOF)
       return gettok();
   }
-  
+
   // Check for end of file.  Don't eat the EOF.
   if (LastChar == EOF)
     return tok_eof;
@@ -87,19 +93,18 @@ static int gettok() {
 //===----------------------------------------------------------------------===//
 // Abstract Syntax Tree (aka Parse Tree)
 //===----------------------------------------------------------------------===//
-
+namespace {
 /// ExprAST - Base class for all expression nodes.
 class ExprAST {
 public:
-  virtual ~ExprAST();
+  virtual ~ExprAST() {}
   virtual Value *Codegen() = 0;
 };
 
-ExprAST::~ExprAST() {}
-
 /// NumberExprAST - Expression class for numeric literals like "1.0".
 class NumberExprAST : public ExprAST {
   double Val;
+
 public:
   NumberExprAST(double val) : Val(val) {}
   virtual Value *Codegen();
@@ -108,6 +113,7 @@ public:
 /// VariableExprAST - Expression class for referencing a variable, like "a".
 class VariableExprAST : public ExprAST {
   std::string Name;
+
 public:
   VariableExprAST(const std::string &name) : Name(name) {}
   virtual Value *Codegen();
@@ -117,19 +123,21 @@ public:
 class BinaryExprAST : public ExprAST {
   char Op;
   ExprAST *LHS, *RHS;
+
 public:
-  BinaryExprAST(char op, ExprAST *lhs, ExprAST *rhs) 
-    : Op(op), LHS(lhs), RHS(rhs) {}
+  BinaryExprAST(char op, ExprAST *lhs, ExprAST *rhs)
+      : Op(op), LHS(lhs), RHS(rhs) {}
   virtual Value *Codegen();
 };
 
 /// CallExprAST - Expression class for function calls.
 class CallExprAST : public ExprAST {
   std::string Callee;
-  std::vector<ExprAST*> Args;
+  std::vector<ExprAST *> Args;
+
 public:
-  CallExprAST(const std::string &callee, std::vector<ExprAST*> &args)
-    : Callee(callee), Args(args) {}
+  CallExprAST(const std::string &callee, std::vector<ExprAST *> &args)
+      : Callee(callee), Args(args) {}
   virtual Value *Codegen();
 };
 
@@ -139,10 +147,11 @@ public:
 class PrototypeAST {
   std::string Name;
   std::vector<std::string> Args;
+
 public:
   PrototypeAST(const std::string &name, const std::vector<std::string> &args)
-    : Name(name), Args(args) {}
-  
+      : Name(name), Args(args) {}
+
   Function *Codegen();
 };
 
@@ -150,12 +159,13 @@ public:
 class FunctionAST {
   PrototypeAST *Proto;
   ExprAST *Body;
+
 public:
-  FunctionAST(PrototypeAST *proto, ExprAST *body)
-    : Proto(proto), Body(body) {}
-  
+  FunctionAST(PrototypeAST *proto, ExprAST *body) : Proto(proto), Body(body) {}
+
   Function *Codegen();
 };
+} // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
 // Parser
@@ -165,9 +175,7 @@ public:
 /// token the parser is looking at.  getNextToken reads another token from the
 /// lexer and updates CurTok with its results.
 static int CurTok;
-static int getNextToken() {
-  return CurTok = gettok();
-}
+static int getNextToken() { return CurTok = gettok(); }
 
 /// BinopPrecedence - This holds the precedence for each binary operator that is
 /// defined.
@@ -177,17 +185,27 @@ static std::map<char, int> BinopPrecedence;
 static int GetTokPrecedence() {
   if (!isascii(CurTok))
     return -1;
-  
+
   // Make sure it's a declared binop.
   int TokPrec = BinopPrecedence[CurTok];
-  if (TokPrec <= 0) return -1;
+  if (TokPrec <= 0)
+    return -1;
   return TokPrec;
 }
 
 /// Error* - These are little helper functions for error handling.
-ExprAST *Error(const char *Str) { fprintf(stderr, "Error: %s\n", Str);return 0;}
-PrototypeAST *ErrorP(const char *Str) { Error(Str); return 0; }
-FunctionAST *ErrorF(const char *Str) { Error(Str); return 0; }
+ExprAST *Error(const char *Str) {
+  fprintf(stderr, "Error: %s\n", Str);
+  return 0;
+}
+PrototypeAST *ErrorP(const char *Str) {
+  Error(Str);
+  return 0;
+}
+FunctionAST *ErrorF(const char *Str) {
+  Error(Str);
+  return 0;
+}
 
 static ExprAST *ParseExpression();
 
@@ -196,22 +214,24 @@ static ExprAST *ParseExpression();
 ///   ::= identifier '(' expression* ')'
 static ExprAST *ParseIdentifierExpr() {
   std::string IdName = IdentifierStr;
-  
-  getNextToken();  // eat identifier.
-  
+
+  getNextToken(); // eat identifier.
+
   if (CurTok != '(') // Simple variable ref.
     return new VariableExprAST(IdName);
-  
+
   // Call.
-  getNextToken();  // eat (
-  std::vector<ExprAST*> Args;
+  getNextToken(); // eat (
+  std::vector<ExprAST *> Args;
   if (CurTok != ')') {
     while (1) {
       ExprAST *Arg = ParseExpression();
-      if (!Arg) return 0;
+      if (!Arg)
+        return 0;
       Args.push_back(Arg);
 
-      if (CurTok == ')') break;
+      if (CurTok == ')')
+        break;
 
       if (CurTok != ',')
         return Error("Expected ')' or ',' in argument list");
@@ -221,7 +241,7 @@ static ExprAST *ParseIdentifierExpr() {
 
   // Eat the ')'.
   getNextToken();
-  
+
   return new CallExprAST(IdName, Args);
 }
 
@@ -234,13 +254,14 @@ static ExprAST *ParseNumberExpr() {
 
 /// parenexpr ::= '(' expression ')'
 static ExprAST *ParseParenExpr() {
-  getNextToken();  // eat (.
+  getNextToken(); // eat (.
   ExprAST *V = ParseExpression();
-  if (!V) return 0;
-  
+  if (!V)
+    return 0;
+
   if (CurTok != ')')
     return Error("expected ')'");
-  getNextToken();  // eat ).
+  getNextToken(); // eat ).
   return V;
 }
 
@@ -250,10 +271,14 @@ static ExprAST *ParseParenExpr() {
 ///   ::= parenexpr
 static ExprAST *ParsePrimary() {
   switch (CurTok) {
-  default: return Error("unknown token when expecting an expression");
-  case tok_identifier: return ParseIdentifierExpr();
-  case tok_number:     return ParseNumberExpr();
-  case '(':            return ParseParenExpr();
+  default:
+    return Error("unknown token when expecting an expression");
+  case tok_identifier:
+    return ParseIdentifierExpr();
+  case tok_number:
+    return ParseNumberExpr();
+  case '(':
+    return ParseParenExpr();
   }
 }
 
@@ -263,28 +288,30 @@ static ExprAST *ParseBinOpRHS(int ExprPrec, ExprAST *LHS) {
   // If this is a binop, find its precedence.
   while (1) {
     int TokPrec = GetTokPrecedence();
-    
+
     // If this is a binop that binds at least as tightly as the current binop,
     // consume it, otherwise we are done.
     if (TokPrec < ExprPrec)
       return LHS;
-    
+
     // Okay, we know this is a binop.
     int BinOp = CurTok;
-    getNextToken();  // eat binop
-    
+    getNextToken(); // eat binop
+
     // Parse the primary expression after the binary operator.
     ExprAST *RHS = ParsePrimary();
-    if (!RHS) return 0;
-    
+    if (!RHS)
+      return 0;
+
     // If BinOp binds less tightly with RHS than the operator after RHS, let
     // the pending operator take RHS as its LHS.
     int NextPrec = GetTokPrecedence();
     if (TokPrec < NextPrec) {
-      RHS = ParseBinOpRHS(TokPrec+1, RHS);
-      if (RHS == 0) return 0;
+      RHS = ParseBinOpRHS(TokPrec + 1, RHS);
+      if (RHS == 0)
+        return 0;
     }
-    
+
     // Merge LHS/RHS.
     LHS = new BinaryExprAST(BinOp, LHS, RHS);
   }
@@ -295,8 +322,9 @@ static ExprAST *ParseBinOpRHS(int ExprPrec, ExprAST *LHS) {
 ///
 static ExprAST *ParseExpression() {
   ExprAST *LHS = ParsePrimary();
-  if (!LHS) return 0;
-  
+  if (!LHS)
+    return 0;
+
   return ParseBinOpRHS(0, LHS);
 }
 
@@ -308,27 +336,28 @@ static PrototypeAST *ParsePrototype() {
 
   std::string FnName = IdentifierStr;
   getNextToken();
-  
+
   if (CurTok != '(')
     return ErrorP("Expected '(' in prototype");
-  
+
   std::vector<std::string> ArgNames;
   while (getNextToken() == tok_identifier)
     ArgNames.push_back(IdentifierStr);
   if (CurTok != ')')
     return ErrorP("Expected ')' in prototype");
-  
+
   // success.
-  getNextToken();  // eat ')'.
-  
+  getNextToken(); // eat ')'.
+
   return new PrototypeAST(FnName, ArgNames);
 }
 
 /// definition ::= 'def' prototype expression
 static FunctionAST *ParseDefinition() {
-  getNextToken();  // eat def.
+  getNextToken(); // eat def.
   PrototypeAST *Proto = ParsePrototype();
-  if (Proto == 0) return 0;
+  if (Proto == 0)
+    return 0;
 
   if (ExprAST *E = ParseExpression())
     return new FunctionAST(Proto, E);
@@ -347,20 +376,263 @@ static FunctionAST *ParseTopLevelExpr() {
 
 /// external ::= 'extern' prototype
 static PrototypeAST *ParseExtern() {
-  getNextToken();  // eat extern.
+  getNextToken(); // eat extern.
   return ParsePrototype();
 }
 
+//===----------------------------------------------------------------------===//
+// Quick and dirty hack
+//===----------------------------------------------------------------------===//
+
+// FIXME: Obviously we can do better than this
+std::string GenerateUniqueName(const char *root)
+{
+  static int i = 0;
+  char s[16];
+  sprintf(s, "%s%d", root, i++);
+  std::string S = s;
+  return S;
+}
+
+std::string MakeLegalFunctionName(std::string Name)
+{
+  std::string NewName;
+  if (!Name.length())
+      return GenerateUniqueName("anon_func_");
+
+  // Start with what we have
+  NewName = Name;
+
+  // Look for a numberic first character
+  if (NewName.find_first_of("0123456789") == 0) {
+    NewName.insert(0, 1, 'n');
+  }
+
+  // Replace illegal characters with their ASCII equivalent
+  std::string legal_elements = "_abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
+  size_t pos;
+  while ((pos = NewName.find_first_not_of(legal_elements)) != std::string::npos) {
+    char old_c = NewName.at(pos);
+    char new_str[16];
+    sprintf(new_str, "%d", (int)old_c);
+    NewName = NewName.replace(pos, 1, new_str);
+  }
+
+  return NewName;
+}
+
+//===----------------------------------------------------------------------===//
+// MCJIT helper class
+//===----------------------------------------------------------------------===//
+
+class MCJITHelper
+{
+public:
+  MCJITHelper(LLVMContext& C) : Context(C), OpenModule(NULL) {}
+  ~MCJITHelper();
+
+  Function *getFunction(const std::string FnName);
+  Module *getModuleForNewFunction();
+  void *getPointerToFunction(Function* F);
+  void *getSymbolAddress(const std::string &Name);
+  void dump();
+
+private:
+  typedef std::vector<Module*> ModuleVector;
+  typedef std::vector<ExecutionEngine*> EngineVector;
+
+  LLVMContext  &Context;
+  Module       *OpenModule;
+  ModuleVector  Modules;
+  EngineVector  Engines;
+};
+
+class HelpingMemoryManager : public SectionMemoryManager
+{
+  HelpingMemoryManager(const HelpingMemoryManager&) LLVM_DELETED_FUNCTION;
+  void operator=(const HelpingMemoryManager&) LLVM_DELETED_FUNCTION;
+
+public:
+  HelpingMemoryManager(MCJITHelper *Helper) : MasterHelper(Helper) {}
+  virtual ~HelpingMemoryManager() {}
+
+  /// This method returns the address of the specified symbol.
+  /// Our implementation will attempt to find symbols in other
+  /// modules associated with the MCJITHelper to cross link symbols
+  /// from one generated module to another.
+  virtual uint64_t getSymbolAddress(const std::string &Name) override;
+private:
+  MCJITHelper *MasterHelper;
+};
+
+uint64_t HelpingMemoryManager::getSymbolAddress(const std::string &Name)
+{
+  uint64_t FnAddr = SectionMemoryManager::getSymbolAddress(Name);
+  if (FnAddr)
+    return FnAddr;
+
+  uint64_t HelperFun = (uint64_t) MasterHelper->getSymbolAddress(Name);
+  if (!HelperFun)
+    report_fatal_error("Program used extern function '" + Name +
+                       "' which could not be resolved!");
+
+  return HelperFun;
+}
+
+MCJITHelper::~MCJITHelper()
+{
+  if (OpenModule)
+    delete OpenModule;
+  EngineVector::iterator begin = Engines.begin();
+  EngineVector::iterator end = Engines.end();
+  EngineVector::iterator it;
+  for (it = begin; it != end; ++it)
+    delete *it;
+}
+
+Function *MCJITHelper::getFunction(const std::string FnName) {
+  ModuleVector::iterator begin = Modules.begin();
+  ModuleVector::iterator end = Modules.end();
+  ModuleVector::iterator it;
+  for (it = begin; it != end; ++it) {
+    Function *F = (*it)->getFunction(FnName);
+    if (F) {
+      if (*it == OpenModule)
+          return F;
+
+      assert(OpenModule != NULL);
+
+      // This function is in a module that has already been JITed.
+      // We need to generate a new prototype for external linkage.
+      Function *PF = OpenModule->getFunction(FnName);
+      if (PF && !PF->empty()) {
+        ErrorF("redefinition of function across modules");
+        return 0;
+      }
+
+      // If we don't have a prototype yet, create one.
+      if (!PF)
+        PF = Function::Create(F->getFunctionType(), 
+                                      Function::ExternalLinkage, 
+                                      FnName, 
+                                      OpenModule);
+      return PF;
+    }
+  }
+  return NULL;
+}
+
+Module *MCJITHelper::getModuleForNewFunction() {
+  // If we have a Module that hasn't been JITed, use that.
+  if (OpenModule)
+    return OpenModule;
+
+  // Otherwise create a new Module.
+  std::string ModName = GenerateUniqueName("mcjit_module_");
+  Module *M = new Module(ModName, Context);
+  Modules.push_back(M);
+  OpenModule = M;
+  return M;
+}
+
+void *MCJITHelper::getPointerToFunction(Function* F) {
+  // See if an existing instance of MCJIT has this function.
+  EngineVector::iterator begin = Engines.begin();
+  EngineVector::iterator end = Engines.end();
+  EngineVector::iterator it;
+  for (it = begin; it != end; ++it) {
+    void *P = (*it)->getPointerToFunction(F);
+    if (P)
+      return P;
+  }
+
+  // If we didn't find the function, see if we can generate it.
+  if (OpenModule) {
+    std::string ErrStr;
+    ExecutionEngine *NewEngine = EngineBuilder(std::unique_ptr<Module>(OpenModule))
+                                              .setErrorStr(&ErrStr)
+                                              .setMCJITMemoryManager(std::unique_ptr<HelpingMemoryManager>(new HelpingMemoryManager(this)))
+                                              .create();
+    if (!NewEngine) {
+      fprintf(stderr, "Could not create ExecutionEngine: %s\n", ErrStr.c_str());
+      exit(1);
+    }
+
+    // Create a function pass manager for this engine
+    FunctionPassManager *FPM = new FunctionPassManager(OpenModule);
+
+    // Set up the optimizer pipeline.  Start with registering info about how the
+    // target lays out data structures.
+    OpenModule->setDataLayout(NewEngine->getDataLayout());
+    FPM->add(new DataLayoutPass());
+    // Provide basic AliasAnalysis support for GVN.
+    FPM->add(createBasicAliasAnalysisPass());
+    // Promote allocas to registers.
+    FPM->add(createPromoteMemoryToRegisterPass());
+    // Do simple "peephole" optimizations and bit-twiddling optzns.
+    FPM->add(createInstructionCombiningPass());
+    // Reassociate expressions.
+    FPM->add(createReassociatePass());
+    // Eliminate Common SubExpressions.
+    FPM->add(createGVNPass());
+    // Simplify the control flow graph (deleting unreachable blocks, etc).
+    FPM->add(createCFGSimplificationPass());
+    FPM->doInitialization();
+
+    // For each function in the module
+    Module::iterator it;
+    Module::iterator end = OpenModule->end();
+    for (it = OpenModule->begin(); it != end; ++it) {
+      // Run the FPM on this function
+      FPM->run(*it);
+    }
+
+    // We don't need this anymore
+    delete FPM;
+
+    OpenModule = NULL;
+    Engines.push_back(NewEngine);
+    NewEngine->finalizeObject();
+    return NewEngine->getPointerToFunction(F);
+  }
+  return NULL;
+}
+
+void *MCJITHelper::getSymbolAddress(const std::string &Name)
+{
+  // Look for the symbol in each of our execution engines.
+  EngineVector::iterator begin = Engines.begin();
+  EngineVector::iterator end = Engines.end();
+  EngineVector::iterator it;
+  for (it = begin; it != end; ++it) {
+    uint64_t FAddr = (*it)->getFunctionAddress(Name);
+    if (FAddr) {
+       return (void *)FAddr; 
+    }
+  }
+  return NULL;
+}
+
+void MCJITHelper::dump()
+{
+  ModuleVector::iterator begin = Modules.begin();
+  ModuleVector::iterator end = Modules.end();
+  ModuleVector::iterator it;
+  for (it = begin; it != end; ++it)
+    (*it)->dump();
+}
 //===----------------------------------------------------------------------===//
 // Code Generation
 //===----------------------------------------------------------------------===//
 
-static Module *TheModule;
+static MCJITHelper *JITHelper;
 static IRBuilder<> Builder(getGlobalContext());
-static std::map<std::string, Value*> NamedValues;
-static FunctionPassManager *TheFPM;
+static std::map<std::string, Value *> NamedValues;
 
-Value *ErrorV(const char *Str) { Error(Str); return 0; }
+Value *ErrorV(const char *Str) {
+  Error(Str);
+  return 0;
+}
 
 Value *NumberExprAST::Codegen() {
   return ConstantFP::get(getGlobalContext(), APFloat(Val));
@@ -375,93 +647,103 @@ Value *VariableExprAST::Codegen() {
 Value *BinaryExprAST::Codegen() {
   Value *L = LHS->Codegen();
   Value *R = RHS->Codegen();
-  if (L == 0 || R == 0) return 0;
-  
+  if (L == 0 || R == 0)
+    return 0;
+
   switch (Op) {
-  case '+': return Builder.CreateFAdd(L, R, "addtmp");
-  case '-': return Builder.CreateFSub(L, R, "subtmp");
-  case '*': return Builder.CreateFMul(L, R, "multmp");
+  case '+':
+    return Builder.CreateFAdd(L, R, "addtmp");
+  case '-':
+    return Builder.CreateFSub(L, R, "subtmp");
+  case '*':
+    return Builder.CreateFMul(L, R, "multmp");
   case '<':
     L = Builder.CreateFCmpULT(L, R, "cmptmp");
     // Convert bool 0/1 to double 0.0 or 1.0
     return Builder.CreateUIToFP(L, Type::getDoubleTy(getGlobalContext()),
                                 "booltmp");
-  default: return ErrorV("invalid binary operator");
+  default:
+    return ErrorV("invalid binary operator");
   }
 }
 
 Value *CallExprAST::Codegen() {
   // Look up the name in the global module table.
-  Function *CalleeF = TheModule->getFunction(Callee);
+  Function *CalleeF = JITHelper->getFunction(Callee);
   if (CalleeF == 0)
     return ErrorV("Unknown function referenced");
-  
+
   // If argument mismatch error.
   if (CalleeF->arg_size() != Args.size())
     return ErrorV("Incorrect # arguments passed");
 
-  std::vector<Value*> ArgsV;
+  std::vector<Value *> ArgsV;
   for (unsigned i = 0, e = Args.size(); i != e; ++i) {
     ArgsV.push_back(Args[i]->Codegen());
-    if (ArgsV.back() == 0) return 0;
+    if (ArgsV.back() == 0)
+      return 0;
   }
-  
+
   return Builder.CreateCall(CalleeF, ArgsV, "calltmp");
 }
 
 Function *PrototypeAST::Codegen() {
   // Make the function type:  double(double,double) etc.
-  std::vector<Type*> Doubles(Args.size(),
-                             Type::getDoubleTy(getGlobalContext()));
-  FunctionType *FT = FunctionType::get(Type::getDoubleTy(getGlobalContext()),
-                                       Doubles, false);
-  
-  Function *F = Function::Create(FT, Function::ExternalLinkage, Name, TheModule);
-  
+  std::vector<Type *> Doubles(Args.size(),
+                              Type::getDoubleTy(getGlobalContext()));
+  FunctionType *FT =
+      FunctionType::get(Type::getDoubleTy(getGlobalContext()), Doubles, false);
+
+  std::string FnName = MakeLegalFunctionName(Name);
+
+  Module *M = JITHelper->getModuleForNewFunction();
+
+  Function *F =
+      Function::Create(FT, Function::ExternalLinkage, FnName, M);
+
   // If F conflicted, there was already something named 'Name'.  If it has a
   // body, don't allow redefinition or reextern.
-  if (F->getName() != Name) {
+  if (F->getName() != FnName) {
     // Delete the one we just made and get the existing one.
     F->eraseFromParent();
-    F = TheModule->getFunction(Name);
-    
+    F = JITHelper->getFunction(Name); 
     // If F already has a body, reject this.
     if (!F->empty()) {
       ErrorF("redefinition of function");
       return 0;
     }
-    
+
     // If F took a different number of args, reject.
     if (F->arg_size() != Args.size()) {
       ErrorF("redefinition of function with different # args");
       return 0;
     }
   }
-  
+
   // Set names for all arguments.
   unsigned Idx = 0;
   for (Function::arg_iterator AI = F->arg_begin(); Idx != Args.size();
        ++AI, ++Idx) {
     AI->setName(Args[Idx]);
-    
+
     // Add arguments to variable symbol table.
     NamedValues[Args[Idx]] = AI;
   }
-  
+
   return F;
 }
 
 Function *FunctionAST::Codegen() {
   NamedValues.clear();
-  
+
   Function *TheFunction = Proto->Codegen();
   if (TheFunction == 0)
     return 0;
-  
+
   // Create a new basic block to start insertion into.
   BasicBlock *BB = BasicBlock::Create(getGlobalContext(), "entry", TheFunction);
   Builder.SetInsertPoint(BB);
-  
+
   if (Value *RetVal = Body->Codegen()) {
     // Finish off the function.
     Builder.CreateRet(RetVal);
@@ -469,12 +751,9 @@ Function *FunctionAST::Codegen() {
     // Validate the generated code, checking for consistency.
     verifyFunction(*TheFunction);
 
-    // Optimize the function.
-    TheFPM->run(*TheFunction);
-    
     return TheFunction;
   }
-  
+
   // Error reading body, remove function.
   TheFunction->eraseFromParent();
   return 0;
@@ -484,8 +763,6 @@ Function *FunctionAST::Codegen() {
 // Top-Level parsing and JIT Driver
 //===----------------------------------------------------------------------===//
 
-static ExecutionEngine *TheExecutionEngine;
-
 static void HandleDefinition() {
   if (FunctionAST *F = ParseDefinition()) {
     if (Function *LF = F->Codegen()) {
@@ -515,8 +792,8 @@ static void HandleTopLevelExpression() {
   if (FunctionAST *F = ParseTopLevelExpr()) {
     if (Function *LF = F->Codegen()) {
       // JIT the function, returning a function pointer.
-      void *FPtr = TheExecutionEngine->getPointerToFunction(LF);
-      
+      void *FPtr = JITHelper->getPointerToFunction(LF);
+
       // Cast it to the right type (takes no arguments, returns a double) so we
       // can call it as a native function.
       double (*FP)() = (double (*)())(intptr_t)FPtr;
@@ -533,11 +810,20 @@ static void MainLoop() {
   while (1) {
     fprintf(stderr, "ready> ");
     switch (CurTok) {
-    case tok_eof:    return;
-    case ';':        getNextToken(); break;  // ignore top-level semicolons.
-    case tok_def:    HandleDefinition(); break;
-    case tok_extern: HandleExtern(); break;
-    default:         HandleTopLevelExpression(); break;
+    case tok_eof:
+      return;
+    case ';':
+      getNextToken();
+      break; // ignore top-level semicolons.
+    case tok_def:
+      HandleDefinition();
+      break;
+    case tok_extern:
+      HandleExtern();
+      break;
+    default:
+      HandleTopLevelExpression();
+      break;
     }
   }
 }
@@ -547,8 +833,7 @@ static void MainLoop() {
 //===----------------------------------------------------------------------===//
 
 /// putchard - putchar that takes a double and returns 0.
-extern "C" 
-double putchard(double X) {
+extern "C" double putchard(double X) {
   putchar((char)X);
   return 0;
 }
@@ -559,58 +844,27 @@ double putchard(double X) {
 
 int main() {
   InitializeNativeTarget();
+  InitializeNativeTargetAsmPrinter();
+  InitializeNativeTargetAsmParser();
   LLVMContext &Context = getGlobalContext();
+  JITHelper = new MCJITHelper(Context);
 
   // Install standard binary operators.
   // 1 is lowest precedence.
   BinopPrecedence['<'] = 10;
   BinopPrecedence['+'] = 20;
   BinopPrecedence['-'] = 20;
-  BinopPrecedence['*'] = 40;  // highest.
+  BinopPrecedence['*'] = 40; // highest.
 
   // Prime the first token.
   fprintf(stderr, "ready> ");
   getNextToken();
 
-  // Make the module, which holds all the code.
-  TheModule = new Module("my cool jit", Context);
-
-  // Create the JIT.  This takes ownership of the module.
-  std::string ErrStr;
-  TheExecutionEngine = EngineBuilder(TheModule).setErrorStr(&ErrStr).create();
-  if (!TheExecutionEngine) {
-    fprintf(stderr, "Could not create ExecutionEngine: %s\n", ErrStr.c_str());
-    exit(1);
-  }
-
-  FunctionPassManager OurFPM(TheModule);
-
-  // Set up the optimizer pipeline.  Start with registering info about how the
-  // target lays out data structures.
-  OurFPM.add(new DataLayout(*TheExecutionEngine->getDataLayout()));
-  // Provide basic AliasAnalysis support for GVN.
-  OurFPM.add(createBasicAliasAnalysisPass());
-  // Do simple "peephole" optimizations and bit-twiddling optzns.
-  OurFPM.add(createInstructionCombiningPass());
-  // Reassociate expressions.
-  OurFPM.add(createReassociatePass());
-  // Eliminate Common SubExpressions.
-  OurFPM.add(createGVNPass());
-  // Simplify the control flow graph (deleting unreachable blocks, etc).
-  OurFPM.add(createCFGSimplificationPass());
-
-  OurFPM.doInitialization();
-
-  // Set the global so the code gen can use this.
-  TheFPM = &OurFPM;
-
   // Run the main "interpreter loop" now.
   MainLoop();
 
-  TheFPM = 0;
-
   // Print out all of the generated code.
-  TheModule->dump();
+  JITHelper->dump();
 
   return 0;
 }