Transform: add SymbolRewriter pass
[oota-llvm.git] / lib / Transforms / Utils / SymbolRewriter.cpp
1 //===- SymbolRewriter.cpp - Symbol Rewriter ---------------------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // SymbolRewriter is a LLVM pass which can rewrite symbols transparently within
11 // existing code.  It is implemented as a compiler pass and is configured via a
12 // YAML configuration file.
13 //
14 // The YAML configuration file format is as follows:
15 //
16 // RewriteMapFile := RewriteDescriptors
17 // RewriteDescriptors := RewriteDescriptor | RewriteDescriptors
18 // RewriteDescriptor := RewriteDescriptorType ':' '{' RewriteDescriptorFields '}'
19 // RewriteDescriptorFields := RewriteDescriptorField | RewriteDescriptorFields
20 // RewriteDescriptorField := FieldIdentifier ':' FieldValue ','
21 // RewriteDescriptorType := Identifier
22 // FieldIdentifier := Identifier
23 // FieldValue := Identifier
24 // Identifier := [0-9a-zA-Z]+
25 //
26 // Currently, the following descriptor types are supported:
27 //
28 // - function:          (function rewriting)
29 //      + Source        (original name of the function)
30 //      + Target        (explicit transformation)
31 //      + Transform     (pattern transformation)
32 //      + Naked         (boolean, whether the function is undecorated)
33 // - global variable:   (external linkage global variable rewriting)
34 //      + Source        (original name of externally visible variable)
35 //      + Target        (explicit transformation)
36 //      + Transform     (pattern transformation)
37 // - global alias:      (global alias rewriting)
38 //      + Source        (original name of the aliased name)
39 //      + Target        (explicit transformation)
40 //      + Transform     (pattern transformation)
41 //
42 // Note that source and exactly one of [Target, Transform] must be provided
43 //
44 // New rewrite descriptors can be created.  Addding a new rewrite descriptor
45 // involves:
46 //
47 //  a) extended the rewrite descriptor kind enumeration
48 //     (<anonymous>::RewriteDescriptor::RewriteDescriptorType)
49 //  b) implementing the new descriptor
50 //     (c.f. <anonymous>::ExplicitRewriteFunctionDescriptor)
51 //  c) extending the rewrite map parser
52 //     (<anonymous>::RewriteMapParser::parseEntry)
53 //
54 //  Specify to rewrite the symbols using the `-rewrite-symbols` option, and
55 //  specify the map file to use for the rewriting via the `-rewrite-map-file`
56 //  option.
57 //
58 //===----------------------------------------------------------------------===//
59
60 #define DEBUG_TYPE "symbol-rewriter"
61 #include "llvm/CodeGen/Passes.h"
62 #include "llvm/Pass.h"
63 #include "llvm/PassManager.h"
64 #include "llvm/Support/CommandLine.h"
65 #include "llvm/Support/Debug.h"
66 #include "llvm/Support/MemoryBuffer.h"
67 #include "llvm/Support/Regex.h"
68 #include "llvm/Support/SourceMgr.h"
69 #include "llvm/Support/YAMLParser.h"
70 #include "llvm/Support/raw_ostream.h"
71 #include "llvm/Transforms/IPO/PassManagerBuilder.h"
72 #include "llvm/Transforms/Utils/SymbolRewriter.h"
73
74 using namespace llvm;
75
76 static cl::list<std::string> RewriteMapFiles("rewrite-map-file",
77                                              cl::desc("Symbol Rewrite Map"),
78                                              cl::value_desc("filename"));
79
80 namespace llvm {
81 namespace SymbolRewriter {
82 template <RewriteDescriptor::Type DT, typename ValueType,
83           ValueType *(llvm::Module::*Get)(StringRef) const>
84 class ExplicitRewriteDescriptor : public RewriteDescriptor {
85 public:
86   const std::string Source;
87   const std::string Target;
88
89   ExplicitRewriteDescriptor(StringRef S, StringRef T, const bool Naked)
90       : RewriteDescriptor(DT), Source(Naked ? StringRef("\01" + S.str()) : S),
91         Target(T) {}
92
93   bool performOnModule(Module &M) override;
94
95   static bool classof(const RewriteDescriptor *RD) {
96     return RD->getType() == DT;
97   }
98 };
99
100 template <RewriteDescriptor::Type DT, typename ValueType,
101           ValueType *(llvm::Module::*Get)(StringRef) const>
102 bool ExplicitRewriteDescriptor<DT, ValueType, Get>::performOnModule(Module &M) {
103   bool Changed = false;
104   if (ValueType *S = (M.*Get)(Source)) {
105     if (Value *T = (M.*Get)(Target))
106       S->setValueName(T->getValueName());
107     else
108       S->setName(Target);
109     Changed = true;
110   }
111   return Changed;
112 }
113
114 template <RewriteDescriptor::Type DT, typename ValueType,
115           ValueType *(llvm::Module::*Get)(StringRef) const,
116           iterator_range<typename iplist<ValueType>::iterator> (llvm::Module::*Iterator)()>
117 class PatternRewriteDescriptor : public RewriteDescriptor {
118 public:
119   const std::string Pattern;
120   const std::string Transform;
121
122   PatternRewriteDescriptor(StringRef P, StringRef T)
123     : RewriteDescriptor(DT), Pattern(P), Transform(T) { }
124
125   bool performOnModule(Module &M) override;
126
127   static bool classof(const RewriteDescriptor *RD) {
128     return RD->getType() == DT;
129   }
130 };
131
132 template <RewriteDescriptor::Type DT, typename ValueType,
133           ValueType *(llvm::Module::*Get)(StringRef) const,
134           iterator_range<typename iplist<ValueType>::iterator> (llvm::Module::*Iterator)()>
135 bool PatternRewriteDescriptor<DT, ValueType, Get, Iterator>::
136 performOnModule(Module &M) {
137   bool Changed = false;
138   for (auto &C : (M.*Iterator)()) {
139     std::string Error;
140
141     std::string Name = Regex(Pattern).sub(Transform, C.getName(), &Error);
142     if (!Error.empty())
143       report_fatal_error("unable to transforn " + C.getName() + " in " +
144                          M.getModuleIdentifier() + ": " + Error);
145
146     if (Value *V = (M.*Get)(Name))
147       C.setValueName(V->getValueName());
148     else
149       C.setName(Name);
150
151     Changed = true;
152   }
153   return Changed;
154 }
155
156 /// Represents a rewrite for an explicitly named (function) symbol.  Both the
157 /// source function name and target function name of the transformation are
158 /// explicitly spelt out.
159 using ExplicitRewriteFunctionDescriptor =
160     ExplicitRewriteDescriptor<RewriteDescriptor::Type::Function, llvm::Function,
161                               &llvm::Module::getFunction>;
162
163 /// Represents a rewrite for an explicitly named (global variable) symbol.  Both
164 /// the source variable name and target variable name are spelt out.  This
165 /// applies only to module level variables.
166 using ExplicitRewriteGlobalVariableDescriptor =
167     ExplicitRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable,
168                               llvm::GlobalVariable,
169                               &llvm::Module::getGlobalVariable>;
170
171 /// Represents a rewrite for an explicitly named global alias.  Both the source
172 /// and target name are explicitly spelt out.
173 using ExplicitRewriteNamedAliasDescriptor =
174     ExplicitRewriteDescriptor<RewriteDescriptor::Type::NamedAlias,
175                               llvm::GlobalAlias, &llvm::Module::getNamedAlias>;
176
177 /// Represents a rewrite for a regular expression based pattern for functions.
178 /// A pattern for the function name is provided and a transformation for that
179 /// pattern to determine the target function name create the rewrite rule.
180 using PatternRewriteFunctionDescriptor =
181     PatternRewriteDescriptor<RewriteDescriptor::Type::Function, llvm::Function,
182                              &llvm::Module::getFunction,
183                              &llvm::Module::functions>;
184
185
186 /// Represents a rewrite for a global variable based upon a matching pattern.
187 /// Each global variable matching the provided pattern will be transformed as
188 /// described in the transformation pattern for the target.  Applies only to
189 /// module level variables.
190 using PatternRewriteGlobalVariableDescriptor =
191     PatternRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable,
192                              llvm::GlobalVariable,
193                              &llvm::Module::getGlobalVariable,
194                              &llvm::Module::globals>;
195
196 /// PatternRewriteNamedAliasDescriptor - represents a rewrite for global
197 /// aliases which match a given pattern.  The provided transformation will be
198 /// applied to each of the matching names.
199 using PatternRewriteNamedAliasDescriptor =
200     PatternRewriteDescriptor<RewriteDescriptor::Type::NamedAlias,
201                              llvm::GlobalAlias,
202                              &llvm::Module::getNamedAlias,
203                              &llvm::Module::aliases>;
204
205
206 bool RewriteMapParser::parse(const std::string &MapFile,
207                              RewriteDescriptorList *DL) {
208   ErrorOr<std::unique_ptr<MemoryBuffer>> Mapping =
209       MemoryBuffer::getFile(MapFile);
210
211   if (!Mapping)
212     report_fatal_error("unable to read rewrite map '" + MapFile + "': " +
213                        Mapping.getError().message());
214
215   if (!parse(*Mapping, DL))
216     report_fatal_error("unable to parse rewrite map '" + MapFile + "'");
217
218   return true;
219 }
220
221 bool RewriteMapParser::parse(std::unique_ptr<MemoryBuffer> &MapFile,
222                              RewriteDescriptorList *DL) {
223   SourceMgr SM;
224   yaml::Stream YS(MapFile->getBuffer(), SM);
225
226   for (auto &Document : YS) {
227     yaml::MappingNode *DescriptorList;
228
229     // ignore empty documents
230     if (isa<yaml::NullNode>(Document.getRoot()))
231       continue;
232
233     DescriptorList = dyn_cast<yaml::MappingNode>(Document.getRoot());
234     if (!DescriptorList) {
235       YS.printError(Document.getRoot(), "DescriptorList node must be a map");
236       return false;
237     }
238
239     for (auto &Descriptor : *DescriptorList)
240       if (!parseEntry(YS, Descriptor, DL))
241         return false;
242   }
243
244   return true;
245 }
246
247 bool RewriteMapParser::parseEntry(yaml::Stream &YS, yaml::KeyValueNode &Entry,
248                                   RewriteDescriptorList *DL) {
249   const std::string kRewriteTypeFunction = "function";
250   const std::string kRewriteTypeGlobalVariable = "global variable";
251   const std::string kRewriteTypeGlobalAlias = "global alias";
252
253   yaml::ScalarNode *Key;
254   yaml::MappingNode *Value;
255   SmallString<32> KeyStorage;
256   StringRef RewriteType;
257
258   Key = dyn_cast<yaml::ScalarNode>(Entry.getKey());
259   if (!Key) {
260     YS.printError(Entry.getKey(), "rewrite type must be a scalar");
261     return false;
262   }
263
264   Value = dyn_cast<yaml::MappingNode>(Entry.getValue());
265   if (!Value) {
266     YS.printError(Entry.getValue(), "rewrite descriptor must be a map");
267     return false;
268   }
269
270   RewriteType = Key->getValue(KeyStorage);
271   if (RewriteType == kRewriteTypeFunction)
272     return parseRewriteFunctionDescriptor(YS, Key, Value, DL);
273   else if (RewriteType == kRewriteTypeGlobalVariable)
274     return parseRewriteGlobalVariableDescriptor(YS, Key, Value, DL);
275   else if (RewriteType == kRewriteTypeGlobalAlias)
276     return parseRewriteGlobalAliasDescriptor(YS, Key, Value, DL);
277
278   YS.printError(Entry.getKey(), "unknown rewrite type");
279   return false;
280 }
281
282 bool RewriteMapParser::
283 parseRewriteFunctionDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
284                                yaml::MappingNode *Descriptor,
285                                RewriteDescriptorList *DL) {
286   const std::string kDescriptorFieldSource = "source";
287   const std::string kDescriptorFieldTarget = "target";
288   const std::string kDescriptorFieldTransform = "transform";
289   const std::string kDescriptorFieldNaked = "naked";
290
291   bool Naked = false;
292   std::string Source;
293   std::string Target;
294   std::string Transform;
295
296   for (auto &Field : *Descriptor) {
297     yaml::ScalarNode *Key;
298     yaml::ScalarNode *Value;
299     SmallString<32> KeyStorage;
300     SmallString<32> ValueStorage;
301     StringRef KeyValue;
302
303     Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
304     if (!Key) {
305       YS.printError(Field.getKey(), "descriptor key must be a scalar");
306       return false;
307     }
308
309     Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
310     if (!Value) {
311       YS.printError(Field.getValue(), "descriptor value must be a scalar");
312       return false;
313     }
314
315     KeyValue = Key->getValue(KeyStorage);
316     if (KeyValue == kDescriptorFieldSource) {
317       std::string Error;
318
319       Source = Value->getValue(ValueStorage);
320       if (!Regex(Source).isValid(Error)) {
321         YS.printError(Field.getKey(), "invalid regex: " + Error);
322         return false;
323       }
324     } else if (KeyValue == kDescriptorFieldTarget) {
325       Target = Value->getValue(ValueStorage);
326     } else if (KeyValue == kDescriptorFieldTransform) {
327       Transform = Value->getValue(ValueStorage);
328     } else if (KeyValue == kDescriptorFieldNaked) {
329       std::string Undecorated;
330
331       Undecorated = Value->getValue(ValueStorage);
332       Naked = StringRef(Undecorated).lower() == "true" || Undecorated == "1";
333     } else {
334       YS.printError(Field.getKey(), "unknown key for function");
335       return false;
336     }
337   }
338
339   if (Transform.empty() == Target.empty()) {
340     YS.printError(Descriptor,
341                   "exactly one of transform or target must be specified");
342     return false;
343   }
344
345   // TODO see if there is a more elegant solution to selecting the rewrite
346   // descriptor type
347   if (!Target.empty())
348     DL->push_back(new ExplicitRewriteFunctionDescriptor(Source, Target, Naked));
349   else
350     DL->push_back(new PatternRewriteFunctionDescriptor(Source, Transform));
351
352   return true;
353 }
354
355 bool RewriteMapParser::
356 parseRewriteGlobalVariableDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
357                                      yaml::MappingNode *Descriptor,
358                                      RewriteDescriptorList *DL) {
359   const std::string kDescriptorFieldSource = "source";
360   const std::string kDescriptorFieldTarget = "target";
361   const std::string kDescriptorFieldTransform = "transform";
362
363   std::string Source;
364   std::string Target;
365   std::string Transform;
366
367   for (auto &Field : *Descriptor) {
368     yaml::ScalarNode *Key;
369     yaml::ScalarNode *Value;
370     SmallString<32> KeyStorage;
371     SmallString<32> ValueStorage;
372     StringRef KeyValue;
373
374     Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
375     if (!Key) {
376       YS.printError(Field.getKey(), "descriptor Key must be a scalar");
377       return false;
378     }
379
380     Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
381     if (!Value) {
382       YS.printError(Field.getValue(), "descriptor value must be a scalar");
383       return false;
384     }
385
386     KeyValue = Key->getValue(KeyStorage);
387     if (KeyValue == kDescriptorFieldSource) {
388       std::string Error;
389
390       Source = Value->getValue(ValueStorage);
391       if (!Regex(Source).isValid(Error)) {
392         YS.printError(Field.getKey(), "invalid regex: " + Error);
393         return false;
394       }
395     } else if (KeyValue == kDescriptorFieldTarget) {
396       Target = Value->getValue(ValueStorage);
397     } else if (KeyValue == kDescriptorFieldTransform) {
398       Transform = Value->getValue(ValueStorage);
399     } else {
400       YS.printError(Field.getKey(), "unknown Key for Global Variable");
401       return false;
402     }
403   }
404
405   if (Transform.empty() == Target.empty()) {
406     YS.printError(Descriptor,
407                   "exactly one of transform or target must be specified");
408     return false;
409   }
410
411   if (!Target.empty())
412     DL->push_back(new ExplicitRewriteGlobalVariableDescriptor(Source, Target,
413                                                               /*Naked*/false));
414   else
415     DL->push_back(new PatternRewriteGlobalVariableDescriptor(Source,
416                                                              Transform));
417
418   return true;
419 }
420
421 bool RewriteMapParser::
422 parseRewriteGlobalAliasDescriptor(yaml::Stream &YS, yaml::ScalarNode *K,
423                                   yaml::MappingNode *Descriptor,
424                                   RewriteDescriptorList *DL) {
425   const std::string kDescriptorFieldSource = "source";
426   const std::string kDescriptorFieldTarget = "target";
427   const std::string kDescriptorFieldTransform = "transform";
428
429   std::string Source;
430   std::string Target;
431   std::string Transform;
432
433   for (auto &Field : *Descriptor) {
434     yaml::ScalarNode *Key;
435     yaml::ScalarNode *Value;
436     SmallString<32> KeyStorage;
437     SmallString<32> ValueStorage;
438     StringRef KeyValue;
439
440     Key = dyn_cast<yaml::ScalarNode>(Field.getKey());
441     if (!Key) {
442       YS.printError(Field.getKey(), "descriptor key must be a scalar");
443       return false;
444     }
445
446     Value = dyn_cast<yaml::ScalarNode>(Field.getValue());
447     if (!Value) {
448       YS.printError(Field.getValue(), "descriptor value must be a scalar");
449       return false;
450     }
451
452     KeyValue = Key->getValue(KeyStorage);
453     if (KeyValue == kDescriptorFieldSource) {
454       std::string Error;
455
456       Source = Value->getValue(ValueStorage);
457       if (!Regex(Source).isValid(Error)) {
458         YS.printError(Field.getKey(), "invalid regex: " + Error);
459         return false;
460       }
461     } else if (KeyValue == kDescriptorFieldTarget) {
462       Target = Value->getValue(ValueStorage);
463     } else if (KeyValue == kDescriptorFieldTransform) {
464       Transform = Value->getValue(ValueStorage);
465     } else {
466       YS.printError(Field.getKey(), "unknown key for Global Alias");
467       return false;
468     }
469   }
470
471   if (Transform.empty() == Target.empty()) {
472     YS.printError(Descriptor,
473                   "exactly one of transform or target must be specified");
474     return false;
475   }
476
477   if (!Target.empty())
478     DL->push_back(new ExplicitRewriteNamedAliasDescriptor(Source, Target,
479                                                           /*Naked*/false));
480   else
481     DL->push_back(new PatternRewriteNamedAliasDescriptor(Source, Transform));
482
483   return true;
484 }
485 }
486 }
487
488 namespace {
489 class RewriteSymbols : public ModulePass {
490 public:
491   static char ID; // Pass identification, replacement for typeid
492
493   RewriteSymbols();
494   RewriteSymbols(SymbolRewriter::RewriteDescriptorList &DL);
495
496   virtual bool runOnModule(Module &M) override;
497
498 private:
499   void loadAndParseMapFiles();
500
501   SymbolRewriter::RewriteDescriptorList Descriptors;
502 };
503
504 char RewriteSymbols::ID = 0;
505
506 RewriteSymbols::RewriteSymbols() : ModulePass(ID) {
507   initializeRewriteSymbolsPass(*PassRegistry::getPassRegistry());
508   loadAndParseMapFiles();
509 }
510
511 RewriteSymbols::RewriteSymbols(SymbolRewriter::RewriteDescriptorList &DL)
512     : ModulePass(ID) {
513   std::swap(Descriptors, DL);
514 }
515
516 bool RewriteSymbols::runOnModule(Module &M) {
517   bool Changed;
518
519   Changed = false;
520   for (auto &Descriptor : Descriptors)
521     Changed |= Descriptor.performOnModule(M);
522
523   return Changed;
524 }
525
526 void RewriteSymbols::loadAndParseMapFiles() {
527   const std::vector<std::string> MapFiles(RewriteMapFiles);
528   SymbolRewriter::RewriteMapParser parser;
529
530   for (const auto &MapFile : MapFiles)
531     parser.parse(MapFile, &Descriptors);
532 }
533 }
534
535 INITIALIZE_PASS(RewriteSymbols, "rewrite-symbols", "Rewrite Symbols", false,
536                 false);
537
538 ModulePass *llvm::createRewriteSymbolsPass() { return new RewriteSymbols(); }
539
540 ModulePass *
541 llvm::createRewriteSymbolsPass(SymbolRewriter::RewriteDescriptorList &DL) {
542   return new RewriteSymbols(DL);
543 }