Fix another annoying bug that took forever to track down. This one involves abstract...
[oota-llvm.git] / lib / Bytecode / Reader / InstructionReader.cpp
index 9dc5c6fc3dabdea5236aed50c98b304f11e7908a..1f4aa68f4aebaeface3c719138be676a6ca25137 100644 (file)
@@ -1,4 +1,4 @@
-//===- ReadInst.cpp - Code to read an instruction from bytecode -------------===
+//===- ReadInst.cpp - Code to read an instruction from bytecode -----------===//
 //
 // This file defines the mechanism to read an instruction from a bytecode 
 // stream.
@@ -9,7 +9,7 @@
 // TODO: Change from getValue(Raw.Arg1) etc, to getArg(Raw, 1)
 //       Make it check type, so that casts are checked.
 //
-//===------------------------------------------------------------------------===
+//===----------------------------------------------------------------------===//
 
 #include "llvm/iOther.h"
 #include "llvm/iTerminators.h"
@@ -20,7 +20,7 @@
 bool BytecodeParser::ParseRawInst(const uchar *&Buf, const uchar *EndBuf, 
                                  RawInst &Result) {
   unsigned Op, Typ;
-  if (read(Buf, EndBuf, Op)) return true;
+  if (read(Buf, EndBuf, Op)) return failure(true);
 
   Result.NumOperands =  Op >> 30;
   Result.Opcode      = (Op >> 24) & 63;
@@ -45,44 +45,45 @@ bool BytecodeParser::ParseRawInst(const uchar *&Buf, const uchar *EndBuf,
     break;
   case 0:
     Buf -= 4;  // Hrm, try this again...
-    if (read_vbr(Buf, EndBuf, Result.Opcode)) return true;
-    if (read_vbr(Buf, EndBuf, Typ)) return true;
+    if (read_vbr(Buf, EndBuf, Result.Opcode)) return failure(true);
+    if (read_vbr(Buf, EndBuf, Typ)) return failure(true);
     Result.Ty = getType(Typ);
-    if (read_vbr(Buf, EndBuf, Result.NumOperands)) return true;
+    if (Result.Ty == 0) return failure(true);
+    if (read_vbr(Buf, EndBuf, Result.NumOperands)) return failure(true);
 
     switch (Result.NumOperands) {
     case 0: 
       cerr << "Zero Arg instr found!\n"; 
-      return true;  // This encoding is invalid!
+      return failure(true);  // This encoding is invalid!
     case 1: 
-      if (read_vbr(Buf, EndBuf, Result.Arg1)) return true;
+      if (read_vbr(Buf, EndBuf, Result.Arg1)) return failure(true);
       break;
     case 2:
       if (read_vbr(Buf, EndBuf, Result.Arg1) || 
-         read_vbr(Buf, EndBuf, Result.Arg2)) return true;
+         read_vbr(Buf, EndBuf, Result.Arg2)) return failure(true);
       break;
     case 3:
       if (read_vbr(Buf, EndBuf, Result.Arg1) || 
          read_vbr(Buf, EndBuf, Result.Arg2) ||
-         read_vbr(Buf, EndBuf, Result.Arg3)) return true;
+          read_vbr(Buf, EndBuf, Result.Arg3)) return failure(true);
       break;
     default:
       if (read_vbr(Buf, EndBuf, Result.Arg1) || 
-         read_vbr(Buf, EndBuf, Result.Arg2)) return true;
+         read_vbr(Buf, EndBuf, Result.Arg2)) return failure(true);
 
       // Allocate a vector to hold arguments 3, 4, 5, 6 ...
       Result.VarArgs = new vector<unsigned>(Result.NumOperands-2);
       for (unsigned a = 0; a < Result.NumOperands-2; a++)
-       if (read_vbr(Buf, EndBuf, (*Result.VarArgs)[a])) return true;
+       if (read_vbr(Buf, EndBuf, (*Result.VarArgs)[a])) return failure(true);
       break;
     }
-    if (align32(Buf, EndBuf)) return true;
+    if (align32(Buf, EndBuf)) return failure(true);
     break;
   }
 
 #if 0
   cerr << "NO: "  << Result.NumOperands   << " opcode: " << Result.Opcode 
-       << " Ty: " << Result.Ty->getName() << " arg1: "   << Result.Arg1 
+       << " Ty: " << Result.Ty->getDescription() << " arg1: "   << Result.Arg1 
        << " arg2: "   << Result.Arg2 << " arg3: "   << Result.Arg3 << endl;
 #endif
   return false;
@@ -92,7 +93,7 @@ bool BytecodeParser::ParseRawInst(const uchar *&Buf, const uchar *EndBuf,
 bool BytecodeParser::ParseInstruction(const uchar *&Buf, const uchar *EndBuf,
                                      Instruction *&Res) {
   RawInst Raw;
-  if (ParseRawInst(Buf, EndBuf, Raw)) return true;;
+  if (ParseRawInst(Buf, EndBuf, Raw)) return failure(true);
 
   if (Raw.Opcode >= Instruction::FirstUnaryOp && 
       Raw.Opcode <  Instruction::NumUnaryOps  && Raw.NumOperands == 1) {
@@ -109,10 +110,13 @@ bool BytecodeParser::ParseInstruction(const uchar *&Buf, const uchar *EndBuf,
 
   Value *V;
   switch (Raw.Opcode) {
-  case Instruction::Cast:
-    Res = new CastInst(getValue(Raw.Ty, Raw.Arg1), getType(Raw.Arg2));
+  case Instruction::Cast: {
+    V = getValue(Raw.Ty, Raw.Arg1);
+    const Type *Ty = getType(Raw.Arg2);
+    if (V == 0 || Ty == 0) { cerr << "Invalid cast!\n"; return true; }
+    Res = new CastInst(V, Ty);
     return false;
-
+  }
   case Instruction::PHINode: {
     PHINode *PN = new PHINode(Raw.Ty);
     switch (Raw.NumOperands) {
@@ -120,22 +124,22 @@ bool BytecodeParser::ParseInstruction(const uchar *&Buf, const uchar *EndBuf,
     case 1: 
     case 3: cerr << "Invalid phi node encountered!\n"; 
             delete PN; 
-           return true;
+           return failure(true);
     case 2: PN->addIncoming(getValue(Raw.Ty, Raw.Arg1),
-                           (BasicBlock*)getValue(Type::LabelTy, Raw.Arg2)); 
+                           cast<BasicBlock>(getValue(Type::LabelTy,Raw.Arg2)));
       break;
     default:
       PN->addIncoming(getValue(Raw.Ty, Raw.Arg1), 
-                     (BasicBlock*)getValue(Type::LabelTy, Raw.Arg2));
+                     cast<BasicBlock>(getValue(Type::LabelTy, Raw.Arg2)));
       if (Raw.VarArgs->size() & 1) {
        cerr << "PHI Node with ODD number of arguments!\n";
        delete PN;
-       return true;
+       return failure(true);
       } else {
         vector<unsigned> &args = *Raw.VarArgs;
         for (unsigned i = 0; i < args.size(); i+=2)
           PN->addIncoming(getValue(Raw.Ty, args[i]),
-                         (BasicBlock*)getValue(Type::LabelTy, args[i+1]));
+                         cast<BasicBlock>(getValue(Type::LabelTy, args[i+1])));
       }
       delete Raw.VarArgs; 
       break;
@@ -160,12 +164,12 @@ bool BytecodeParser::ParseInstruction(const uchar *&Buf, const uchar *EndBuf,
 
   case Instruction::Br:
     if (Raw.NumOperands == 1) {
-      Res = new BranchInst((BasicBlock*)getValue(Type::LabelTy, Raw.Arg1));
+      Res = new BranchInst(cast<BasicBlock>(getValue(Type::LabelTy, Raw.Arg1)));
       return false;
     } else if (Raw.NumOperands == 3) {
-      Res = new BranchInst((BasicBlock*)getValue(Type::LabelTy, Raw.Arg1),
-                          (BasicBlock*)getValue(Type::LabelTy, Raw.Arg2),
-                                       getValue(Type::BoolTy , Raw.Arg3));
+      Res = new BranchInst(cast<BasicBlock>(getValue(Type::LabelTy, Raw.Arg1)),
+                          cast<BasicBlock>(getValue(Type::LabelTy, Raw.Arg2)),
+                                            getValue(Type::BoolTy , Raw.Arg3));
       return false;
     }
     break;
@@ -173,110 +177,219 @@ bool BytecodeParser::ParseInstruction(const uchar *&Buf, const uchar *EndBuf,
   case Instruction::Switch: {
     SwitchInst *I = 
       new SwitchInst(getValue(Raw.Ty, Raw.Arg1), 
-                     (BasicBlock*)getValue(Type::LabelTy, Raw.Arg2));
+                     cast<BasicBlock>(getValue(Type::LabelTy, Raw.Arg2)));
     Res = I;
     if (Raw.NumOperands < 3) return false;  // No destinations?  Wierd.
 
     if (Raw.NumOperands == 3 || Raw.VarArgs->size() & 1) {
       cerr << "Switch statement with odd number of arguments!\n";
       delete I;
-      return true;
+      return failure(true);
     }      
     
     vector<unsigned> &args = *Raw.VarArgs;
     for (unsigned i = 0; i < args.size(); i += 2)
-      I->dest_push_back((ConstPoolVal*)getValue(Raw.Ty, args[i]),
-                        (BasicBlock*)getValue(Type::LabelTy, args[i+1]));
+      I->dest_push_back(cast<ConstPoolVal>(getValue(Raw.Ty, args[i])),
+                        cast<BasicBlock>(getValue(Type::LabelTy, args[i+1])));
 
     delete Raw.VarArgs;
     return false;
   }
 
   case Instruction::Call: {
-    Method *M = (Method*)getValue(Raw.Ty, Raw.Arg1);
-    if (M == 0) return true;
+    Value *M = getValue(Raw.Ty, Raw.Arg1);
+    if (M == 0) return failure(true);
 
-    const MethodType::ParamTypes &PL = M->getMethodType()->getParamTypes();
-    MethodType::ParamTypes::const_iterator It = PL.begin();
+    // Check to make sure we have a pointer to method type
+    PointerType *PTy = dyn_cast<PointerType>(M->getType());
+    if (PTy == 0) return failure(true);
+    MethodType *MTy = dyn_cast<MethodType>(PTy->getValueType());
+    if (MTy == 0) return failure(true);
 
     vector<Value *> Params;
-    switch (Raw.NumOperands) {
-    case 0: cerr << "Invalid call instruction encountered!\n";
-           return true;
-    case 1: break;
-    case 2: Params.push_back(getValue(*It++, Raw.Arg2)); break;
-    case 3: Params.push_back(getValue(*It++, Raw.Arg2)); 
-            if (It == PL.end()) return true;
-            Params.push_back(getValue(*It++, Raw.Arg3)); break;
-    default:
-      Params.push_back(getValue(*It++, Raw.Arg2));
-      {
-        vector<unsigned> &args = *Raw.VarArgs;
-        for (unsigned i = 0; i < args.size(); i++) {
-         if (It == PL.end()) return true;
-          Params.push_back(getValue(*It++, args[i]));
+    const MethodType::ParamTypes &PL = MTy->getParamTypes();
+
+    if (!MTy->isVarArg()) {
+      MethodType::ParamTypes::const_iterator It = PL.begin();
+
+      switch (Raw.NumOperands) {
+      case 0: cerr << "Invalid call instruction encountered!\n";
+       return failure(true);
+      case 1: break;
+      case 2: Params.push_back(getValue(*It++, Raw.Arg2)); break;
+      case 3: Params.push_back(getValue(*It++, Raw.Arg2)); 
+       if (It == PL.end()) return failure(true);
+       Params.push_back(getValue(*It++, Raw.Arg3)); break;
+      default:
+       Params.push_back(getValue(*It++, Raw.Arg2));
+       {
+         vector<unsigned> &args = *Raw.VarArgs;
+         for (unsigned i = 0; i < args.size(); i++) {
+           if (It == PL.end()) return failure(true);
+           // TODO: Check getValue for null!
+           Params.push_back(getValue(*It++, args[i]));
+         }
+       }
+       delete Raw.VarArgs;
+      }
+      if (It != PL.end()) return failure(true);
+    } else {
+      if (Raw.NumOperands > 2) {
+       vector<unsigned> &args = *Raw.VarArgs;
+       if (args.size() < 1) return failure(true);
+
+       if ((args.size() & 1) != 0)
+         return failure(true);  // Must be pairs of type/value
+       for (unsigned i = 0; i < args.size(); i+=2) {
+         const Type *Ty = getType(args[i]);
+         if (Ty == 0)
+           return failure(true);
+         
+         Value *V = getValue(Ty, args[i+1]);
+         if (V == 0) return failure(true);
+         Params.push_back(V);
        }
+       delete Raw.VarArgs;
       }
-      delete Raw.VarArgs;
     }
-    if (It != PL.end()) return true;
 
     Res = new CallInst(M, Params);
     return false;
   }
+  case Instruction::Invoke: {
+    Value *M = getValue(Raw.Ty, Raw.Arg1);
+    if (M == 0) return failure(true);
+
+    // Check to make sure we have a pointer to method type
+    PointerType *PTy = dyn_cast<PointerType>(M->getType());
+    if (PTy == 0) return failure(true);
+    MethodType *MTy = dyn_cast<MethodType>(PTy->getValueType());
+    if (MTy == 0) return failure(true);
+
+    vector<Value *> Params;
+    const MethodType::ParamTypes &PL = MTy->getParamTypes();
+    vector<unsigned> &args = *Raw.VarArgs;
+
+    BasicBlock *Normal, *Except;
+
+    if (!MTy->isVarArg()) {
+      if (Raw.NumOperands < 3) return failure(true);
+
+      Normal = cast<BasicBlock>(getValue(Type::LabelTy, Raw.Arg2));
+      Except = cast<BasicBlock>(getValue(Type::LabelTy, args[0]));
+
+      MethodType::ParamTypes::const_iterator It = PL.begin();
+      for (unsigned i = 1; i < args.size(); i++) {
+       if (It == PL.end()) return failure(true);
+       // TODO: Check getValue for null!
+       Params.push_back(getValue(*It++, args[i]));
+      }
+
+      if (It != PL.end()) return failure(true);
+    } else {
+      if (args.size() < 4) return failure(true);
+
+      Normal = cast<BasicBlock>(getValue(Type::LabelTy, args[0]));
+      Except = cast<BasicBlock>(getValue(Type::LabelTy, args[2]));
+
+      if ((args.size() & 1) != 0)
+       return failure(true);  // Must be pairs of type/value
+      for (unsigned i = 4; i < args.size(); i+=2) {
+       // TODO: Check getValue for null!
+       Params.push_back(getValue(getType(args[i]), args[i+1]));
+      }
+    }
+
+    delete Raw.VarArgs;
+    Res = new InvokeInst(M, Normal, Except, Params);
+    return false;
+  }
   case Instruction::Malloc:
-    if (Raw.NumOperands > 2) return true;
+    if (Raw.NumOperands > 2) return failure(true);
     V = Raw.NumOperands ? getValue(Type::UIntTy, Raw.Arg1) : 0;
     Res = new MallocInst(Raw.Ty, V);
     return false;
 
   case Instruction::Alloca:
-    if (Raw.NumOperands > 2) return true;
+    if (Raw.NumOperands > 2) return failure(true);
     V = Raw.NumOperands ? getValue(Type::UIntTy, Raw.Arg1) : 0;
     Res = new AllocaInst(Raw.Ty, V);
     return false;
 
   case Instruction::Free:
     V = getValue(Raw.Ty, Raw.Arg1);
-    if (!V->getType()->isPointerType()) return true;
+    if (!V->getType()->isPointerType()) return failure(true);
     Res = new FreeInst(V);
     return false;
 
-  case Instruction::Load: {
+  case Instruction::Load:
+  case Instruction::GetElementPtr: {
     vector<ConstPoolVal*> Idx;
     switch (Raw.NumOperands) {
-    case 0: cerr << "Invalid load encountered!\n"; return true;
+    case 0: cerr << "Invalid load encountered!\n"; return failure(true);
     case 1: break;
     case 2: V = getValue(Type::UByteTy, Raw.Arg2);
-            if (!V->isConstant()) return true;
-            Idx.push_back(V->castConstant());
+            if (!isa<ConstPoolVal>(V)) return failure(true);
+            Idx.push_back(cast<ConstPoolVal>(V));
             break;
     case 3: V = getValue(Type::UByteTy, Raw.Arg2);
-            if (!V->isConstant()) return true;
-            Idx.push_back(V->castConstant());
+            if (!isa<ConstPoolVal>(V)) return failure(true);
+            Idx.push_back(cast<ConstPoolVal>(V));
            V = getValue(Type::UByteTy, Raw.Arg3);
-            if (!V->isConstant()) return true;
-            Idx.push_back(V->castConstant());
+            if (!isa<ConstPoolVal>(V)) return failure(true);
+            Idx.push_back(cast<ConstPoolVal>(V));
             break;
     default:
       V = getValue(Type::UByteTy, Raw.Arg2);
-      if (!V->isConstant()) return true;
-      Idx.push_back(V->castConstant());
+      if (!isa<ConstPoolVal>(V)) return failure(true);
+      Idx.push_back(cast<ConstPoolVal>(V));
+      vector<unsigned> &args = *Raw.VarArgs;
+      for (unsigned i = 0, E = args.size(); i != E; ++i) {
+       V = getValue(Type::UByteTy, args[i]);
+       if (!isa<ConstPoolVal>(V)) return failure(true);
+       Idx.push_back(cast<ConstPoolVal>(V));
+      }
+      delete Raw.VarArgs; 
+      break;
+    }
+    if (Raw.Opcode == Instruction::Load)
+      Res = new LoadInst(getValue(Raw.Ty, Raw.Arg1), Idx);
+    else if (Raw.Opcode == Instruction::GetElementPtr)
+      Res = new GetElementPtrInst(getValue(Raw.Ty, Raw.Arg1), Idx);
+    else
+      abort();
+    return false;
+  }
+  case Instruction::Store: {
+    vector<ConstPoolVal*> Idx;
+    switch (Raw.NumOperands) {
+    case 0: 
+    case 1: cerr << "Invalid store encountered!\n"; return failure(true);
+    case 2: break;
+    case 3: V = getValue(Type::UByteTy, Raw.Arg3);
+            if (!isa<ConstPoolVal>(V)) return failure(true);
+            Idx.push_back(cast<ConstPoolVal>(V));
+            break;
+    default:
       vector<unsigned> &args = *Raw.VarArgs;
       for (unsigned i = 0, E = args.size(); i != E; ++i) {
        V = getValue(Type::UByteTy, args[i]);
-       if (!V->isConstant()) return true;
-       Idx.push_back(V->castConstant());
+       if (!isa<ConstPoolVal>(V)) return failure(true);
+       Idx.push_back(cast<ConstPoolVal>(V));
       }
       delete Raw.VarArgs; 
       break;
     }
-    Res = new LoadInst(getValue(Raw.Ty, Raw.Arg1), Idx);
+
+    const Type *ElType = StoreInst::getIndexedType(Raw.Ty, Idx);
+    if (ElType == 0) return failure(true);
+    Res = new StoreInst(getValue(ElType, Raw.Arg1), getValue(Raw.Ty, Raw.Arg2),
+                       Idx);
     return false;
   }
   }  // end switch(Raw.Opcode) 
 
   cerr << "Unrecognized instruction! " << Raw.Opcode 
        << " ADDR = 0x" << (void*)Buf << endl;
-  return true;
+  return failure(true);
 }