ocaml bindings: introduce classify_value
[oota-llvm.git] / bindings / ocaml / llvm / llvm_ocaml.c
index cbc05448fa75c99baffc5836944e48b96045ff35..ec2d9490c28596761206f8d2e8dafaa2759177ac 100644 (file)
@@ -375,6 +375,69 @@ CAMLprim LLVMTypeRef llvm_type_of(LLVMValueRef Val) {
   return LLVMTypeOf(Val);
 }
 
+/* keep in sync with ValueKind.t */
+enum ValueKind {
+  NullValue=0,
+  Argument,
+  BasicBlock,
+  InlineAsm,
+  MDNode,
+  MDString,
+  BlockAddress,
+  ConstantAggregateZero,
+  ConstantArray,
+  ConstantExpr,
+  ConstantFP,
+  ConstantInt,
+  ConstantPointerNull,
+  ConstantStruct,
+  ConstantVector,
+  Function,
+  GlobalAlias,
+  GlobalVariable,
+  UndefValue,
+  Instruction
+};
+
+/* llvalue -> ValueKind.t */
+#define DEFINE_CASE(Val, Kind) \
+    do {if (LLVMIsA##Kind(Val)) CAMLreturn(Val_int(Kind));} while(0)
+
+CAMLprim value llvm_classify_value(LLVMValueRef Val) {
+  CAMLparam0();
+  if (!Val)
+    CAMLreturn(Val_int(NullValue));
+  if (LLVMIsAConstant(Val)) {
+    DEFINE_CASE(Val, BlockAddress);
+    DEFINE_CASE(Val, ConstantAggregateZero);
+    DEFINE_CASE(Val, ConstantArray);
+    DEFINE_CASE(Val, ConstantExpr);
+    DEFINE_CASE(Val, ConstantFP);
+    DEFINE_CASE(Val, ConstantInt);
+    DEFINE_CASE(Val, ConstantPointerNull);
+    DEFINE_CASE(Val, ConstantStruct);
+    DEFINE_CASE(Val, ConstantVector);
+  }
+  if (LLVMIsAInstruction(Val)) {
+    CAMLlocal1(result);
+    result = caml_alloc_small(1, 0);
+    Store_field(result, 0, Val_int(LLVMGetInstructionOpcode(Val)));
+    CAMLreturn(result);
+  }
+  if (LLVMIsAGlobalValue(Val)) {
+    DEFINE_CASE(Val, Function);
+    DEFINE_CASE(Val, GlobalAlias);
+    DEFINE_CASE(Val, GlobalVariable);
+  }
+  DEFINE_CASE(Val, Argument);
+  DEFINE_CASE(Val, BasicBlock);
+  DEFINE_CASE(Val, InlineAsm);
+  DEFINE_CASE(Val, MDNode);
+  DEFINE_CASE(Val, MDString);
+  DEFINE_CASE(Val, UndefValue);
+  failwith("Unknown Value class");
+}
+
 /* llvalue -> string */
 CAMLprim value llvm_value_name(LLVMValueRef Val) {
   return copy_string(LLVMGetValueName(Val));
@@ -1034,7 +1097,10 @@ DEFINE_ITERATORS(instr, Instruction, LLVMBasicBlockRef, LLVMValueRef,
 
 /* llvalue -> Opcode.t */
 CAMLprim value llvm_instr_get_opcode(LLVMValueRef Inst) {
-  LLVMOpcode o = LLVMGetInstructionOpcode(Inst);
+  LLVMOpcode o;
+  if (!LLVMIsAInstruction(Inst))
+      failwith("Not an instruction");
+  o = LLVMGetInstructionOpcode(Inst);
   assert (o <= LLVMUnwind );
   return Val_int(o);
 }