Teach the include sorting script about the gtest headers; sort them with
[oota-llvm.git] / utils / sort_includes.py
1 #!/usr/bin/env python
2
3 """Script to sort the top-most block of #include lines.
4
5 Assumes the LLVM coding conventions.
6
7 Currently, this script only bothers sorting the llvm/... headers. Patches
8 welcome for more functionality, and sorting other header groups.
9 """
10
11 import argparse
12 import os
13
14 def sort_includes(f):
15   """Sort the #include lines of a specific file."""
16   lines = f.readlines()
17   look_for_api_header = os.path.splitext(f.name)[1] == '.cpp'
18   found_headers = False
19   headers_begin = 0
20   headers_end = 0
21   api_headers = []
22   local_headers = []
23   project_headers = []
24   system_headers = []
25   for (i, l) in enumerate(lines):
26     if l.strip() == '':
27       continue
28     if l.startswith('#include'):
29       if not found_headers:
30         headers_begin = i
31         found_headers = True
32       headers_end = i
33       header = l[len('#include'):].lstrip()
34       if look_for_api_header and header.startswith('"'):
35         api_headers.append(header)
36         look_for_api_header = False
37         continue
38       if header.startswith('<') or header.startswith('"gtest/'):
39         system_headers.append(header)
40         continue
41       if (header.startswith('"llvm/') or header.startswith('"llvm-c/') or
42           header.startswith('"clang/') or header.startswith('"clang-c/')):
43         project_headers.append(header)
44         continue
45       local_headers.append(header)
46       continue
47
48     # Only allow comments and #defines prior to any includes. If either are
49     # mixed with includes, the order might be sensitive.
50     if found_headers:
51       break
52     if l.startswith('//') or l.startswith('#define') or l.startswith('#ifndef'):
53       continue
54     break
55   if not found_headers:
56     return
57
58   local_headers.sort()
59   project_headers.sort()
60   system_headers.sort()
61   headers = api_headers + local_headers + project_headers + system_headers
62   header_lines = ['#include ' + h for h in headers]
63   lines = lines[:headers_begin] + header_lines + lines[headers_end + 1:]
64
65   f.seek(0)
66   f.truncate()
67   f.writelines(lines)
68
69 def main():
70   parser = argparse.ArgumentParser(description=__doc__)
71   parser.add_argument('files', nargs='+', type=argparse.FileType('r+'),
72                       help='the source files to sort includes within')
73   args = parser.parse_args()
74   for f in args.files:
75     sort_includes(f)
76
77 if __name__ == '__main__':
78   main()