test-release.sh: Drop some unused command-line options.
[oota-llvm.git] / utils / sort_includes.py
1 #!/usr/bin/env python
2
3 """Script to sort the top-most block of #include lines.
4
5 Assumes the LLVM coding conventions.
6
7 Currently, this script only bothers sorting the llvm/... headers. Patches
8 welcome for more functionality, and sorting other header groups.
9 """
10
11 import argparse
12 import os
13
14 def sort_includes(f):
15   """Sort the #include lines of a specific file."""
16
17   # Skip files which are under INPUTS trees or test trees.
18   if 'INPUTS/' in f.name or 'test/' in f.name:
19     return
20
21   ext = os.path.splitext(f.name)[1]
22   if ext not in ['.cpp', '.c', '.h', '.inc', '.def']:
23     return
24
25   lines = f.readlines()
26   look_for_api_header = ext in ['.cpp', '.c']
27   found_headers = False
28   headers_begin = 0
29   headers_end = 0
30   api_headers = []
31   local_headers = []
32   project_headers = []
33   system_headers = []
34   for (i, l) in enumerate(lines):
35     if l.strip() == '':
36       continue
37     if l.startswith('#include'):
38       if not found_headers:
39         headers_begin = i
40         found_headers = True
41       headers_end = i
42       header = l[len('#include'):].lstrip()
43       if look_for_api_header and header.startswith('"'):
44         api_headers.append(header)
45         look_for_api_header = False
46         continue
47       if header.startswith('<') or header.startswith('"gtest/'):
48         system_headers.append(header)
49         continue
50       if (header.startswith('"llvm/') or header.startswith('"llvm-c/') or
51           header.startswith('"clang/') or header.startswith('"clang-c/')):
52         project_headers.append(header)
53         continue
54       local_headers.append(header)
55       continue
56
57     # Only allow comments and #defines prior to any includes. If either are
58     # mixed with includes, the order might be sensitive.
59     if found_headers:
60       break
61     if l.startswith('//') or l.startswith('#define') or l.startswith('#ifndef'):
62       continue
63     break
64   if not found_headers:
65     return
66
67   local_headers = sorted(set(local_headers))
68   project_headers = sorted(set(project_headers))
69   system_headers = sorted(set(system_headers))
70   headers = api_headers + local_headers + project_headers + system_headers
71   header_lines = ['#include ' + h for h in headers]
72   lines = lines[:headers_begin] + header_lines + lines[headers_end + 1:]
73
74   f.seek(0)
75   f.truncate()
76   f.writelines(lines)
77
78 def main():
79   parser = argparse.ArgumentParser(description=__doc__)
80   parser.add_argument('files', nargs='+', type=argparse.FileType('r+'),
81                       help='the source files to sort includes within')
82   args = parser.parse_args()
83   for f in args.files:
84     sort_includes(f)
85
86 if __name__ == '__main__':
87   main()