[python-bindings] Added support for iterating over a function's basic blocks, dumping...
authorMichael Gottesman <mgottesman@apple.com>
Wed, 11 Sep 2013 01:01:40 +0000 (01:01 +0000)
committerMichael Gottesman <mgottesman@apple.com>
Wed, 11 Sep 2013 01:01:40 +0000 (01:01 +0000)
Tests are included.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@190473 91177308-0d34-0410-b5e6-96231b3b80d8

bindings/python/llvm/core.py
bindings/python/llvm/tests/test_core.py

index 3da69d39f1c1ae6c65696d95afa871b6175c0eb1..14b0b4ce7b141cd1c5ce070e9086d569e5148451 100644 (file)
@@ -25,6 +25,7 @@ __all__ = [
     "Module",
     "Value",
     "Function",
+    "BasicBlock",
     "Context",
     "PassRegistry"
 ]
@@ -196,6 +197,69 @@ class Function(Value):
         f = lib.LLVMGetPreviousFunction(self)
         return f and Function(f)
     
+    @property
+    def first(self):
+        b = lib.LLVMGetFirstBasicBlock(self)
+        return b and BasicBlock(b)
+
+    @property
+    def last(self):
+        b = lib.LLVMGetLastBasicBlock(self)
+        return b and BasicBlock(b)
+
+    class __bb_iterator(object):
+        def __init__(self, function, reverse=False):
+            self.function = function
+            self.reverse = reverse
+            if self.reverse:
+                self.bb = function.last
+            else:
+                self.bb = function.first
+        
+        def __iter__(self):
+            return self
+        
+        def next(self):
+            if not isinstance(self.bb, BasicBlock):
+                raise StopIteration("")
+            result = self.bb
+            if self.reverse:
+                self.bb = self.bb.prev
+            else:
+                self.bb = self.bb.next
+            return result
+    
+    def __iter__(self):
+        return Function.__bb_iterator(self)
+
+    def __reversed__(self):
+        return Function.__bb_iterator(self, reverse=True)
+    
+    def __len__(self):
+        return lib.LLVMCountBasicBlocks(self)
+
+class BasicBlock(LLVMObject):
+    
+    def __init__(self, value):
+        LLVMObject.__init__(self, value)
+
+    @property
+    def next(self):
+        b = lib.LLVMGetNextBasicBlock(self)
+        return b and BasicBlock(b)
+
+    @property
+    def prev(self):
+        b = lib.LLVMGetPreviousBasicBlock(self)
+        return b and BasicBlock(b)
+    
+    @property
+    def name(self):
+        return lib.LLVMGetValueName(Value(lib.LLVMBasicBlockAsValue(self)))
+
+    def dump(self):
+        lib.LLVMDumpValue(Value(lib.LLVMBasicBlockAsValue(self)))
+
 class Context(LLVMObject):
 
     def __init__(self, context=None):
@@ -325,6 +389,25 @@ def register_library(library):
     library.LLVMDumpValue.argtypes = [Value]
     library.LLVMDumpValue.restype = None
 
+    # Basic Block Declarations.
+    library.LLVMGetFirstBasicBlock.argtypes = [Function]
+    library.LLVMGetFirstBasicBlock.restype = c_object_p
+
+    library.LLVMGetLastBasicBlock.argtypes = [Function]
+    library.LLVMGetLastBasicBlock.restype = c_object_p
+
+    library.LLVMGetNextBasicBlock.argtypes = [BasicBlock]
+    library.LLVMGetNextBasicBlock.restype = c_object_p
+
+    library.LLVMGetPreviousBasicBlock.argtypes = [BasicBlock]
+    library.LLVMGetPreviousBasicBlock.restype = c_object_p
+
+    library.LLVMBasicBlockAsValue.argtypes = [BasicBlock]
+    library.LLVMBasicBlockAsValue.restype = c_object_p
+
+    library.LLVMCountBasicBlocks.argtypes = [Function]
+    library.LLVMCountBasicBlocks.restype = c_uint
+
 def register_enumerations():
     for name, value in enumerations.OpCodes:
         OpCode.register(name, value)
index a1f79a490db502e2be6abac081a1773393343482..67e294b056bc94d808cfe30e7d8698ea8e7db0e3 100644 (file)
@@ -78,3 +78,25 @@ class TestCore(TestBase):
             self.assertEqual(f.name, functions[i])
             f.dump()
 
+    def test_function_basicblock_iteration(self):
+        m = parse_bitcode(MemoryBuffer(filename=self.get_test_bc()))
+        i = 0
+        
+        bb_list = ['b1', 'b2', 'end']
+        
+        f = m.first
+        while f.name != "f6":
+            f = f.next
+        
+        # Forward
+        for bb in f:
+            self.assertEqual(bb.name, bb_list[i])
+            bb.dump()
+            i += 1
+        
+        # Backwards
+        for bb in reversed(f):
+            i -= 1
+            self.assertEqual(bb.name, bb_list[i])
+            bb.dump()
+