Drop a few unneeded ctor calls (missed code review comment).
[oota-llvm.git] / utils / lit / lit / ShUtil.py
1 from __future__ import absolute_import
2 import itertools
3
4 import lit.util
5 from lit.ShCommands import Command, Pipeline, Seq
6
class ShLexer:
    """
    ShLexer - Tokenizer for the subset of POSIX sh syntax used by lit.

    lex() yields words as plain strings (with quotes and escapes already
    removed) and operators as tuples: ('op',) for a bare operator such as
    ('|',) or ('>>',), and ('op', fd) for a redirection carrying an explicit
    descriptor number, e.g. ('>', 2) for '2>'. """

    def __init__(self, data, win32Escapes = False):
        self.data = data
        self.pos = 0
        self.end = len(data)
        # When set, a backslash outside of double quotes is kept literally
        # rather than being treated as an escape character.
        self.win32Escapes = win32Escapes

    def eat(self):
        """Consume and return the next character (IndexError at end)."""
        c = self.data[self.pos]
        self.pos += 1
        return c

    def look(self):
        """Return the next character without consuming it."""
        return self.data[self.pos]

    def maybe_eat(self, c):
        """
        maybe_eat(c) - Consume the character c if it is the next character,
        returning True if a character was consumed. Returns False at the end
        of the input. """
        # The bounds check matters: lex_one_token calls this right after a
        # trailing operator character such as the '&' in 'a &'.
        if self.pos < self.end and self.data[self.pos] == c:
            self.pos += 1
            return True
        return False

    def lex_arg_fast(self, c):
        """
        lex_arg_fast(c) - Fast path for plain words. c is the character just
        consumed by eat(). Returns the whitespace-delimited chunk starting at
        c and advances past it, or returns None (leaving the position alone)
        when the chunk contains quotes, escapes or operator characters. """
        # Get the leading whitespace free section.
        chunk = self.data[self.pos - 1:].split(None, 1)[0]

        # If it has special characters, the fast path failed.
        if ('|' in chunk or '&' in chunk or
            '<' in chunk or '>' in chunk or
            "'" in chunk or '"' in chunk or
            ';' in chunk or '\\' in chunk):
            return None

        self.pos = self.pos - 1 + len(chunk)
        return chunk

    def lex_arg_slow(self, c):
        """
        lex_arg_slow(c) - Lex a word whose first character c was just consumed
        by eat() (c is never whitespace or an operator character). Handles
        quoting, backslash escapes and the 'N>' / 'N<' redirection forms.
        Returns the unquoted word, or an (operator, fd) tuple. """
        # Re-read the first character so a leading quote or backslash gets
        # exactly the same treatment as one in the middle of the word.
        self.pos -= 1
        result = ''
        # Whether any part of the word was quoted or escaped; such a word is
        # never a file-descriptor number (sh writes "2" to c for '"2">c').
        quoted = False
        while self.pos != self.end:
            c = self.look()
            if c.isspace() or c in "|&;":
                break
            elif c in '><':
                # This is an annoying case; we treat '2>' as a single token so
                # we don't have to track whitespace tokens.

                # Only an unquoted run of ASCII digits names a descriptor;
                # anything else ('a2', '"2"', '\\2', '²') is an ordinary
                # word, so do the usual thing.
                is_fd = (not quoted and result != '' and
                         all(ch in '0123456789' for ch in result))
                if not is_fd:
                    break

                # Otherwise, lex the operator and convert to a redirection
                # token.
                num = int(result)
                tok = self.lex_one_token()
                assert isinstance(tok, tuple) and len(tok) == 1
                return (tok[0], num)
            elif c in "'\"":
                self.eat()
                quoted = True
                result += self.lex_arg_quoted(c)
            elif not self.win32Escapes and c == '\\':
                # Outside of a string, '\\' escapes everything.
                self.eat()
                quoted = True
                if self.pos == self.end:
                    lit.util.warning(
                        "escape at end of argument in: %r" % self.data)
                    return result
                result += self.eat()
            else:
                result += self.eat()
        return result

    def lex_arg_quoted(self, delim):
        """
        lex_arg_quoted(delim) - Lex the body of a string quoted with delim,
        whose opening delimiter has already been consumed, and return its
        contents. A missing closing delimiter only produces a warning. """
        result = ''
        while self.pos != self.end:
            c = self.eat()
            if c == delim:
                return result
            elif c == '\\' and delim == '"':
                # Inside a '"' quoted string, '\\' only escapes the quote
                # character and backslash, otherwise it is preserved.
                if self.pos == self.end:
                    lit.util.warning(
                        "escape at end of quoted argument in: %r" % self.data)
                    return result
                c = self.eat()
                if c in '"\\':
                    result += c
                else:
                    result += '\\' + c
            else:
                result += c
        lit.util.warning("missing quote character in %r" % self.data)
        return result

    def lex_arg_checked(self, c):
        """
        lex_arg_checked(c) - Debugging aid: lex the word with both the fast
        and the slow path and raise ValueError if the fast path disagrees. """
        pos = self.pos
        res = self.lex_arg_fast(c)
        end = self.pos

        self.pos = pos
        reference = self.lex_arg_slow(c)
        if res is not None:
            if res != reference:
                raise ValueError("Fast path failure: %r != %r" % (
                        res, reference))
            if self.pos != end:
                raise ValueError("Fast path failure: %r != %r" % (
                        self.pos, end))
        return reference

    def lex_arg(self, c):
        """Lex a word starting at the just-consumed character c."""
        # A successful fast path never returns '' (c itself is part of it).
        return self.lex_arg_fast(c) or self.lex_arg_slow(c)

    def lex_one_token(self):
        """
        lex_one_token - Lex a single 'sh' token. """

        c = self.eat()
        if c == ';':
            return (c,)
        if c == '|':
            if self.maybe_eat('|'):
                return ('||',)
            return (c,)
        if c == '&':
            if self.maybe_eat('&'):
                return ('&&',)
            if self.maybe_eat('>'):
                return ('&>',)
            return (c,)
        if c == '>':
            if self.maybe_eat('&'):
                return ('>&',)
            if self.maybe_eat('>'):
                return ('>>',)
            return (c,)
        if c == '<':
            if self.maybe_eat('&'):
                return ('<&',)
            # '<<' and '<>' are distinct sh operators; report each under its
            # own spelling.
            if self.maybe_eat('<'):
                return ('<<',)
            if self.maybe_eat('>'):
                return ('<>',)
            return (c,)

        return self.lex_arg(c)

    def lex(self):
        """Yield every token of the input, skipping whitespace between them."""
        while self.pos != self.end:
            if self.look().isspace():
                self.eat()
            else:
                yield self.lex_one_token()
168
169 ###
170  
class ShParser:
    """
    ShParser - Build Command/Pipeline/Seq trees from a shell command line.

    Accepted shape (no operator precedence, see parse()):
      command  := word (word | redirection)*
      pipeline := command ('|' command)*
      list     := pipeline (('&&' | '||' | ';' | '&') pipeline)*
    Malformed input raises ValueError. """

    def __init__(self, data, win32Escapes = False, pipefail = False):
        self.data = data
        self.pipefail = pipefail
        self.tokens = ShLexer(data, win32Escapes = win32Escapes).lex()
        # One-token lookahead buffer: filled by look(), drained by lex().
        # (Re-wrapping self.tokens in itertools.chain on every look() nested
        # the chains one level per lookahead, making token access O(n^2).)
        self._pending = []

    def lex(self):
        """Consume and return the next token, or None at end of input."""
        if self._pending:
            return self._pending.pop()
        return next(self.tokens, None)

    def look(self):
        """Return the next token without consuming it, or None at end."""
        if not self._pending:
            token = next(self.tokens, None)
            if token is None:
                return None
            self._pending.append(token)
        return self._pending[-1]

    def parse_command(self):
        """
        parse_command - Parse one simple command: its words plus any
        redirections, stopping at a terminator or end of input.

        Raises ValueError for an empty command, a command starting with an
        operator, or a redirection that is not followed by a word. """
        tok = self.lex()
        if not tok:
            raise ValueError("empty command!")
        if isinstance(tok, tuple):
            raise ValueError("syntax error near unexpected token %r" % tok[0])

        args = [tok]
        redirects = []
        while True:
            tok = self.look()

            # EOF?
            if tok is None:
                break

            # If this is an argument, just add it to the current command.
            if isinstance(tok, str):
                args.append(self.lex())
                continue

            # Otherwise see if it is a terminator.
            assert isinstance(tok, tuple)
            if tok[0] in ('|',';','&','||','&&'):
                break

            # Otherwise it must be a redirection. Its target must be a word;
            # end of input, '' or another operator ('a > ; b') is an error.
            op = self.lex()
            arg = self.lex()
            if not arg or isinstance(arg, tuple):
                raise ValueError("syntax error near token %r" % op[0])
            redirects.append((op, arg))

        return Command(args, redirects)

    def parse_pipeline(self):
        """Parse commands separated by '|' into a single Pipeline."""
        # '!' negation is not parsed; pipelines are never negated.
        negate = False

        commands = [self.parse_command()]
        while self.look() == ('|',):
            self.lex()
            commands.append(self.parse_command())
        return Pipeline(commands, negate, self.pipefail)

    def parse(self):
        """
        parse - Parse the whole input. Pipelines joined by ';', '&', '&&' or
        '||' are folded left-associatively into nested Seq nodes. """
        lhs = self.parse_pipeline()

        while self.look():
            operator = self.lex()
            assert isinstance(operator, tuple) and len(operator) == 1

            if not self.look():
                raise ValueError(
                    "missing argument to operator %r" % operator[0])

            # FIXME: Operator precedence!!
            lhs = Seq(lhs, operator[0], self.parse_pipeline())

        return lhs
247
248 ###
249
250 import unittest
251
class TestShLexer(unittest.TestCase):
    """Checks the token stream ShLexer produces for small inputs."""

    def lex(self, str, *args, **kwargs):
        # Drain the token generator so results compare with assertEqual.
        tokens = ShLexer(str, *args, **kwargs).lex()
        return list(tokens)

    def test_basic(self):
        expected = ['a', ('|',), 'b', ('>',), 'c', ('&',), 'd',
                    ('<',), 'e', (';',), 'f']
        self.assertEqual(self.lex('a|b>c&d<e;f'), expected)

    def test_redirection_tokens(self):
        # A digit word directly before '>' becomes the descriptor number.
        for text, expected in (('a2>c', ['a2', ('>',), 'c']),
                               ('a 2>c', ['a', ('>',2), 'c'])):
            self.assertEqual(self.lex(text), expected)

    def test_quoting(self):
        # (input, extra ShLexer keyword arguments, expected tokens)
        cases = [
            (""" 'a' """, {}, ['a']),
            (""" "hello\\"world" """, {}, ['hello"world']),
            (""" "hello\\'world" """, {}, ["hello\\'world"]),
            (""" "hello\\\\world" """, {}, ["hello\\world"]),
            (""" he"llo wo"rld """, {}, ["hello world"]),
            (""" a\\ b a\\\\b """, {}, ["a b", "a\\b"]),
            (""" "" "" """, {}, ["", ""]),
            (""" a\\ b """, {'win32Escapes': True}, ['a\\', 'b']),
        ]
        for text, options, expected in cases:
            self.assertEqual(self.lex(text, **options), expected)
284
class TestShParse(unittest.TestCase):
    """Checks the Command/Pipeline/Seq trees ShParser builds."""

    def parse(self, str):
        return ShParser(str).parse()

    def _simple(self, *words):
        # A non-negated pipeline holding one command with no redirections.
        return Pipeline([Command(list(words), [])], False)

    def test_basic(self):
        for text, words in (('echo hello', ('echo', 'hello')),
                            ('echo ""', ('echo', '')),
                            ("""echo -DFOO='a'""", ('echo', '-DFOO=a')),
                            ('echo -DFOO="a"', ('echo', '-DFOO=a'))):
            self.assertEqual(self.parse(text), self._simple(*words))

    def test_redirection(self):
        def single(words, redirects):
            return Pipeline([Command(words, redirects)], False)

        self.assertEqual(self.parse('echo hello > c'),
                         single(['echo', 'hello'], [(('>',), 'c')]))
        self.assertEqual(self.parse('echo hello > c >> d'),
                         single(['echo', 'hello'],
                                [(('>',), 'c'), (('>>',), 'd')]))
        self.assertEqual(self.parse('a 2>&1'),
                         single(['a'], [(('>&',2), '1')]))

    def test_pipeline(self):
        self.assertEqual(self.parse('a | b'),
                         Pipeline([Command([name], []) for name in 'ab'],
                                  False))
        self.assertEqual(self.parse('a | b | c'),
                         Pipeline([Command([name], []) for name in 'abc'],
                                  False))

    def test_list(self):
        a, b = self._simple('a'), self._simple('b')
        for text, op in (('a ; b', ';'), ('a & b', '&'),
                         ('a && b', '&&'), ('a || b', '||')):
            self.assertEqual(self.parse(text), Seq(a, op, b))

        # No precedence: operators fold to the left.
        self.assertEqual(self.parse('a && b || c'),
                         Seq(Seq(a, '&&', b), '||', self._simple('c')))

        self.assertEqual(self.parse('a; b'), Seq(a, ';', b))
353
if __name__ == '__main__':
    # Executed directly as a script: run the lexer and parser unit tests.
    unittest.main()