Check in statistifying patch for Bill
[oota-llvm.git] / lib / Transforms / IPO / LowerSetJmp.cpp
1 //===- LowerSetJmp.cpp - Code pertaining to lowering set/long jumps -------===//
2 // 
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by the LLVM research group and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 // 
8 //===----------------------------------------------------------------------===//
9 //
10 //  This file implements the lowering of setjmp and longjmp to use the
11 //  LLVM invoke and unwind instructions as necessary.
12 //
13 //  Lowering of longjmp is fairly trivial. We replace the call with a
14 //  call to the LLVM library function "__llvm_sjljeh_throw_longjmp()".
15 //  This unwinds the stack for us calling all of the destructors for
16 //  objects allocated on the stack.
17 //
18 //  At a setjmp call, the basic block is split and the setjmp removed.
19 //  The calls in a function that have a setjmp are converted to invoke
20 //  where the except part checks to see if it's a longjmp exception and,
21 //  if so, if it's handled in the function. If it is, then it gets the
22 //  value returned by the longjmp and goes to where the basic block was
23 //  split. Invoke instructions are handled in a similar fashion with the
24 //  original except block being executed if it isn't a longjmp except
25 //  that is handled by that function.
26 //
27 //===----------------------------------------------------------------------===//
28
29 //===----------------------------------------------------------------------===//
30 // FIXME: This pass doesn't deal with PHI statements just yet. That is,
31 // we expect this to occur before SSAification is done. This would seem
32 // to make sense, but in general, it might be a good idea to make this
33 // pass invokable via the "opt" command at will.
34 //===----------------------------------------------------------------------===//
35
36 #include "llvm/Constants.h"
37 #include "llvm/DerivedTypes.h"
38 #include "llvm/Instructions.h"
39 #include "llvm/Intrinsics.h"
40 #include "llvm/Module.h"
41 #include "llvm/Pass.h"
42 #include "llvm/Support/CFG.h"
43 #include "llvm/Support/InstVisitor.h"
44 #include "Support/DepthFirstIterator.h"
45 #include "Support/Statistic.h"
46 #include "Support/StringExtras.h"
47 #include "Support/VectorExtras.h"
48
49 namespace {
50   Statistic<> LongJmpsTransformed("lowersetjmp",
51                                   "Number of longjmps transformed");
52   Statistic<> SetJmpsTransformed("lowersetjmp",
53                                  "Number of setjmps transformed");
54   Statistic<> CallsTransformed("lowersetjmp",
55                                "Number of calls invokified");
56   Statistic<> InvokesTransformed("lowersetjmp",
57                                  "Number of invokes modified");
58
59   //===--------------------------------------------------------------------===//
60   // LowerSetJmp pass implementation. This is subclassed from the "Pass"
61   // class because it works on a module as a whole, not a function at a
62   // time.
63
64   class LowerSetJmp : public Pass,
65                       public InstVisitor<LowerSetJmp> {
66     // LLVM library functions...
67     Function* InitSJMap;        // __llvm_sjljeh_init_setjmpmap
68     Function* DestroySJMap;     // __llvm_sjljeh_destroy_setjmpmap
69     Function* AddSJToMap;       // __llvm_sjljeh_add_setjmp_to_map
70     Function* ThrowLongJmp;     // __llvm_sjljeh_throw_longjmp
71     Function* TryCatchLJ;       // __llvm_sjljeh_try_catching_longjmp_exception
72     Function* IsLJException;    // __llvm_sjljeh_is_longjmp_exception
73     Function* GetLJValue;       // __llvm_sjljeh_get_longjmp_value
74
75     typedef std::pair<SwitchInst*, CallInst*> SwitchValuePair;
76
77     // Keep track of those basic blocks reachable via a depth-first search of
78     // the CFG from a setjmp call. We only need to transform those "call" and
79     // "invoke" instructions that are reachable from the setjmp call site.
80     std::set<BasicBlock*> DFSBlocks;
81
82     // The setjmp map is going to hold information about which setjmps
83     // were called (each setjmp gets its own number) and with which
84     // buffer it was called.
85     std::map<Function*, AllocaInst*>            SJMap;
86
87     // The rethrow basic block map holds the basic block to branch to if
88     // the exception isn't handled in the current function and needs to
89     // be rethrown.
90     std::map<const Function*, BasicBlock*>      RethrowBBMap;
91
92     // The preliminary basic block map holds a basic block that grabs the
93     // exception and determines if it's handled by the current function.
94     std::map<const Function*, BasicBlock*>      PrelimBBMap;
95
96     // The switch/value map holds a switch inst/call inst pair. The
97     // switch inst controls which handler (if any) gets called and the
98     // value is the value returned to that handler by the call to
99     // __llvm_sjljeh_get_longjmp_value.
100     std::map<const Function*, SwitchValuePair>  SwitchValMap;
101
102     // A map of which setjmps we've seen so far in a function.
103     std::map<const Function*, unsigned>         SetJmpIDMap;
104
105     AllocaInst*     GetSetJmpMap(Function* Func);
106     BasicBlock*     GetRethrowBB(Function* Func);
107     SwitchValuePair GetSJSwitch(Function* Func, BasicBlock* Rethrow);
108
109     void TransformLongJmpCall(CallInst* Inst);
110     void TransformSetJmpCall(CallInst* Inst);
111
112     bool IsTransformableFunction(const std::string& Name);
113   public:
114     void visitCallInst(CallInst& CI);
115     void visitInvokeInst(InvokeInst& II);
116     void visitReturnInst(ReturnInst& RI);
117     void visitUnwindInst(UnwindInst& UI);
118
119     bool run(Module& M);
120     bool doInitialization(Module& M);
121   };
122
123   RegisterOpt<LowerSetJmp> X("lowersetjmp", "Lower Set Jump");
124 } // end anonymous namespace
125
126 // run - Run the transformation on the program. We grab the function
127 // prototypes for longjmp and setjmp. If they are used in the program,
128 // then we can go directly to the places they're at and transform them.
129 bool LowerSetJmp::run(Module& M)
130 {
131   bool Changed = false;
132
133   // These are what the functions are called.
134   Function* SetJmp = M.getNamedFunction("llvm.setjmp");
135   Function* LongJmp = M.getNamedFunction("llvm.longjmp");
136
137   // This program doesn't have longjmp and setjmp calls.
138   if ((!LongJmp || LongJmp->use_empty()) &&
139         (!SetJmp || SetJmp->use_empty())) return false;
140
141   // Initialize some values and functions we'll need to transform the
142   // setjmp/longjmp functions.
143   doInitialization(M);
144
145   if (SetJmp) {
146     for (Value::use_iterator B = SetJmp->use_begin(), E = SetJmp->use_end();
147          B != E; ++B) {
148       BasicBlock* BB = cast<Instruction>(*B)->getParent();
149       for (df_ext_iterator<BasicBlock*> I = df_ext_begin(BB, DFSBlocks),
150              E = df_ext_end(BB, DFSBlocks); I != E; ++I)
151         /* empty */;
152     }
153
154     while (!SetJmp->use_empty()) {
155       assert(isa<CallInst>(SetJmp->use_back()) &&
156              "User of setjmp intrinsic not a call?");
157       TransformSetJmpCall(cast<CallInst>(SetJmp->use_back()));
158       Changed = true;
159     }
160   }
161
162   if (LongJmp)
163     while (!LongJmp->use_empty()) {
164       assert(isa<CallInst>(LongJmp->use_back()) &&
165              "User of longjmp intrinsic not a call?");
166       TransformLongJmpCall(cast<CallInst>(LongJmp->use_back()));
167       Changed = true;
168     }
169
170   // Now go through the affected functions and convert calls and invokes
171   // to new invokes...
172   for (std::map<Function*, AllocaInst*>::iterator
173       B = SJMap.begin(), E = SJMap.end(); B != E; ++B) {
174     Function* F = B->first;
175     for (Function::iterator BB = F->begin(), BE = F->end(); BB != BE; ++BB)
176       for (BasicBlock::iterator IB = BB->begin(), IE = BB->end(); IB != IE; ) {
177         visit(*IB++);
178         if (IB != BB->end() && IB->getParent() != BB)
179           break;  // The next instruction got moved to a different block!
180       }
181   }
182
183   DFSBlocks.clear();
184   SJMap.clear();
185   RethrowBBMap.clear();
186   PrelimBBMap.clear();
187   SwitchValMap.clear();
188   SetJmpIDMap.clear();
189
190   return Changed;
191 }
192
193 // doInitialization - For the lower long/setjmp pass, this ensures that a
194 // module contains a declaration for the intrisic functions we are going
195 // to call to convert longjmp and setjmp calls.
196 //
197 // This function is always successful, unless it isn't.
198 bool LowerSetJmp::doInitialization(Module& M)
199 {
200   const Type *SBPTy = PointerType::get(Type::SByteTy);
201   const Type *SBPPTy = PointerType::get(SBPTy);
202
203   // N.B. See llvm/runtime/GCCLibraries/libexception/SJLJ-Exception.h for
204   // a description of the following library functions.
205
206   // void __llvm_sjljeh_init_setjmpmap(void**)
207   InitSJMap = M.getOrInsertFunction("__llvm_sjljeh_init_setjmpmap",
208                                     Type::VoidTy, SBPPTy, 0); 
209   // void __llvm_sjljeh_destroy_setjmpmap(void**)
210   DestroySJMap = M.getOrInsertFunction("__llvm_sjljeh_destroy_setjmpmap",
211                                        Type::VoidTy, SBPPTy, 0);
212
213   // void __llvm_sjljeh_add_setjmp_to_map(void**, void*, unsigned)
214   AddSJToMap = M.getOrInsertFunction("__llvm_sjljeh_add_setjmp_to_map",
215                                      Type::VoidTy, SBPPTy, SBPTy,
216                                      Type::UIntTy, 0);
217
218   // void __llvm_sjljeh_throw_longjmp(int*, int)
219   ThrowLongJmp = M.getOrInsertFunction("__llvm_sjljeh_throw_longjmp",
220                                        Type::VoidTy, SBPTy, Type::IntTy, 0);
221
222   // unsigned __llvm_sjljeh_try_catching_longjmp_exception(void **)
223   TryCatchLJ =
224     M.getOrInsertFunction("__llvm_sjljeh_try_catching_longjmp_exception",
225                           Type::UIntTy, SBPPTy, 0);
226
227   // bool __llvm_sjljeh_is_longjmp_exception()
228   IsLJException = M.getOrInsertFunction("__llvm_sjljeh_is_longjmp_exception",
229                                         Type::BoolTy, 0);
230
231   // int __llvm_sjljeh_get_longjmp_value()
232   GetLJValue = M.getOrInsertFunction("__llvm_sjljeh_get_longjmp_value",
233                                      Type::IntTy, 0);
234   return true;
235 }
236
237 // IsTransformableFunction - Return true if the function name isn't one
238 // of the ones we don't want transformed. Currently, don't transform any
239 // "llvm.{setjmp,longjmp}" functions and none of the setjmp/longjmp error
240 // handling functions (beginning with __llvm_sjljeh_...they don't throw
241 // exceptions).
242 bool LowerSetJmp::IsTransformableFunction(const std::string& Name)
243 {
244   std::string SJLJEh("__llvm_sjljeh");
245
246   if (Name.size() > SJLJEh.size())
247     return std::string(Name.begin(), Name.begin() + SJLJEh.size()) != SJLJEh;
248
249   return true;
250 }
251
252 // TransformLongJmpCall - Transform a longjmp call into a call to the
253 // internal __llvm_sjljeh_throw_longjmp function. It then takes care of
254 // throwing the exception for us.
255 void LowerSetJmp::TransformLongJmpCall(CallInst* Inst)
256 {
257   const Type* SBPTy = PointerType::get(Type::SByteTy);
258
259   // Create the call to "__llvm_sjljeh_throw_longjmp". This takes the
260   // same parameters as "longjmp", except that the buffer is cast to a
261   // char*. It returns "void", so it doesn't need to replace any of
262   // Inst's uses and doesn't get a name.
263   CastInst* CI = new CastInst(Inst->getOperand(1), SBPTy, "LJBuf", Inst);
264   new CallInst(ThrowLongJmp, make_vector<Value*>(CI, Inst->getOperand(2), 0),
265                "", Inst);
266
267   SwitchValuePair& SVP = SwitchValMap[Inst->getParent()->getParent()];
268
269   // If the function has a setjmp call in it (they are transformed first)
270   // we should branch to the basic block that determines if this longjmp
271   // is applicable here. Otherwise, issue an unwind.
272   if (SVP.first)
273     new BranchInst(SVP.first->getParent(), Inst);
274   else
275     new UnwindInst(Inst);
276
277   // Remove all insts after the branch/unwind inst.
278   Inst->getParent()->getInstList().erase(Inst,
279                                        Inst->getParent()->getInstList().end());
280
281   ++LongJmpsTransformed;
282 }
283
284 // GetSetJmpMap - Retrieve (create and initialize, if necessary) the
285 // setjmp map. This map is going to hold information about which setjmps
286 // were called (each setjmp gets its own number) and with which buffer it
287 // was called. There can be only one!
288 AllocaInst* LowerSetJmp::GetSetJmpMap(Function* Func)
289 {
290   if (SJMap[Func]) return SJMap[Func];
291
292   // Insert the setjmp map initialization before the first instruction in
293   // the function.
294   Instruction* Inst = Func->getEntryBlock().begin();
295   assert(Inst && "Couldn't find even ONE instruction in entry block!");
296
297   // Fill in the alloca and call to initialize the SJ map.
298   const Type *SBPTy = PointerType::get(Type::SByteTy);
299   AllocaInst* Map = new AllocaInst(SBPTy, 0, "SJMap", Inst);
300   new CallInst(InitSJMap, make_vector<Value*>(Map, 0), "", Inst);
301   return SJMap[Func] = Map;
302 }
303
304 // GetRethrowBB - Only one rethrow basic block is needed per function.
305 // If this is a longjmp exception but not handled in this block, this BB
306 // performs the rethrow.
307 BasicBlock* LowerSetJmp::GetRethrowBB(Function* Func)
308 {
309   if (RethrowBBMap[Func]) return RethrowBBMap[Func];
310
311   // The basic block we're going to jump to if we need to rethrow the
312   // exception.
313   BasicBlock* Rethrow = new BasicBlock("RethrowExcept", Func);
314   BasicBlock::InstListType& RethrowBlkIL = Rethrow->getInstList();
315
316   // Fill in the "Rethrow" BB with a call to rethrow the exception. This
317   // is the last instruction in the BB since at this point the runtime
318   // should exit this function and go to the next function.
319   RethrowBlkIL.push_back(new UnwindInst());
320   return RethrowBBMap[Func] = Rethrow;
321 }
322
323 // GetSJSwitch - Return the switch statement that controls which handler
324 // (if any) gets called and the value returned to that handler.
325 LowerSetJmp::SwitchValuePair LowerSetJmp::GetSJSwitch(Function* Func,
326                                                       BasicBlock* Rethrow)
327 {
328   if (SwitchValMap[Func].first) return SwitchValMap[Func];
329
330   BasicBlock* LongJmpPre = new BasicBlock("LongJmpBlkPre", Func);
331   BasicBlock::InstListType& LongJmpPreIL = LongJmpPre->getInstList();
332
333   // Keep track of the preliminary basic block for some of the other
334   // transformations.
335   PrelimBBMap[Func] = LongJmpPre;
336
337   // Grab the exception.
338   CallInst* Cond = new
339     CallInst(IsLJException, std::vector<Value*>(), "IsLJExcept");
340   LongJmpPreIL.push_back(Cond);
341
342   // The "decision basic block" gets the number associated with the
343   // setjmp call returning to switch on and the value returned by
344   // longjmp.
345   BasicBlock* DecisionBB = new BasicBlock("LJDecisionBB", Func);
346   BasicBlock::InstListType& DecisionBBIL = DecisionBB->getInstList();
347
348   LongJmpPreIL.push_back(new BranchInst(DecisionBB, Rethrow, Cond));
349
350   // Fill in the "decision" basic block.
351   CallInst* LJVal = new CallInst(GetLJValue, std::vector<Value*>(), "LJVal");
352   DecisionBBIL.push_back(LJVal);
353   CallInst* SJNum = new
354     CallInst(TryCatchLJ, make_vector<Value*>(GetSetJmpMap(Func), 0), "SJNum");
355   DecisionBBIL.push_back(SJNum);
356
357   SwitchInst* SI = new SwitchInst(SJNum, Rethrow);
358   DecisionBBIL.push_back(SI);
359   return SwitchValMap[Func] = SwitchValuePair(SI, LJVal);
360 }
361
362 // TransformSetJmpCall - The setjmp call is a bit trickier to transform.
363 // We're going to convert all setjmp calls to nops. Then all "call" and
364 // "invoke" instructions in the function are converted to "invoke" where
365 // the "except" branch is used when returning from a longjmp call.
366 void LowerSetJmp::TransformSetJmpCall(CallInst* Inst)
367 {
368   BasicBlock* ABlock = Inst->getParent();
369   Function* Func = ABlock->getParent();
370
371   // Add this setjmp to the setjmp map.
372   const Type* SBPTy = PointerType::get(Type::SByteTy);
373   CastInst* BufPtr = new CastInst(Inst->getOperand(1), SBPTy, "SBJmpBuf", Inst);
374   new CallInst(AddSJToMap,
375                make_vector<Value*>(GetSetJmpMap(Func), BufPtr,
376                                    ConstantUInt::get(Type::UIntTy,
377                                                      SetJmpIDMap[Func]++), 0),
378                "", Inst);
379
380   // Change the setjmp call into a branch statement. We'll remove the
381   // setjmp call in a little bit. No worries.
382   BasicBlock* SetJmpContBlock = ABlock->splitBasicBlock(Inst);
383   assert(SetJmpContBlock && "Couldn't split setjmp BB!!");
384
385   SetJmpContBlock->setName("SetJmpContBlock");
386
387   // Reposition the split BB in the BB list to make things tidier.
388   Func->getBasicBlockList().remove(SetJmpContBlock);
389   Func->getBasicBlockList().insert(++Function::iterator(ABlock),
390                                    SetJmpContBlock);
391
392   // This PHI node will be in the new block created from the
393   // splitBasicBlock call.
394   PHINode* PHI = new PHINode(Type::IntTy, "SetJmpReturn", Inst);
395
396   // Coming from a call to setjmp, the return is 0.
397   PHI->addIncoming(ConstantInt::getNullValue(Type::IntTy), ABlock);
398
399   // Add the case for this setjmp's number...
400   SwitchValuePair SVP = GetSJSwitch(Func, GetRethrowBB(Func));
401   SVP.first->addCase(ConstantUInt::get(Type::UIntTy, SetJmpIDMap[Func] - 1),
402                      SetJmpContBlock);
403
404   // Value coming from the handling of the exception.
405   PHI->addIncoming(SVP.second, SVP.second->getParent());
406
407   // Replace all uses of this instruction with the PHI node created by
408   // the eradication of setjmp.
409   Inst->replaceAllUsesWith(PHI);
410   Inst->getParent()->getInstList().erase(Inst);
411
412   ++SetJmpsTransformed;
413 }
414
415 // visitCallInst - This converts all LLVM call instructions into invoke
416 // instructions. The except part of the invoke goes to the "LongJmpBlkPre"
417 // that grabs the exception and proceeds to determine if it's a longjmp
418 // exception or not.
419 void LowerSetJmp::visitCallInst(CallInst& CI)
420 {
421   if (CI.getCalledFunction())
422     if (!IsTransformableFunction(CI.getCalledFunction()->getName()) ||
423         CI.getCalledFunction()->isIntrinsic()) return;
424
425   BasicBlock* OldBB = CI.getParent();
426
427   // If not reachable from a setjmp call, don't transform.
428   if (!DFSBlocks.count(OldBB)) return;
429
430   BasicBlock* NewBB = OldBB->splitBasicBlock(CI);
431   assert(NewBB && "Couldn't split BB of \"call\" instruction!!");
432   NewBB->setName("Call2Invoke");
433
434   // Reposition the split BB in the BB list to make things tidier.
435   Function* Func = OldBB->getParent();
436   Func->getBasicBlockList().remove(NewBB);
437   Func->getBasicBlockList().insert(++Function::iterator(OldBB), NewBB);
438
439   // Construct the new "invoke" instruction.
440   TerminatorInst* Term = OldBB->getTerminator();
441   std::vector<Value*> Params(CI.op_begin() + 1, CI.op_end());
442   InvokeInst* II = new
443     InvokeInst(CI.getCalledValue(), NewBB, PrelimBBMap[Func],
444                Params, CI.getName(), Term); 
445
446   // Replace the old call inst with the invoke inst and remove the call.
447   CI.replaceAllUsesWith(II);
448   CI.getParent()->getInstList().erase(&CI);
449
450   // The old terminator is useless now that we have the invoke inst.
451   Term->getParent()->getInstList().erase(Term);
452   ++CallsTransformed;
453 }
454
455 // visitInvokeInst - Converting the "invoke" instruction is fairly
456 // straight-forward. The old exception part is replaced by a query asking
457 // if this is a longjmp exception. If it is, then it goes to the longjmp
458 // exception blocks. Otherwise, control is passed the old exception.
459 void LowerSetJmp::visitInvokeInst(InvokeInst& II)
460 {
461   if (II.getCalledFunction())
462     if (!IsTransformableFunction(II.getCalledFunction()->getName()) ||
463         II.getCalledFunction()->isIntrinsic()) return;
464
465   BasicBlock* BB = II.getParent();
466
467   // If not reachable from a setjmp call, don't transform.
468   if (!DFSBlocks.count(BB)) return;
469
470   BasicBlock* NormalBB = II.getNormalDest();
471   BasicBlock* ExceptBB = II.getExceptionalDest();
472
473   Function* Func = BB->getParent();
474   BasicBlock* NewExceptBB = new BasicBlock("InvokeExcept", Func);
475   BasicBlock::InstListType& InstList = NewExceptBB->getInstList();
476
477   // If this is a longjmp exception, then branch to the preliminary BB of
478   // the longjmp exception handling. Otherwise, go to the old exception.
479   CallInst* IsLJExcept = new
480     CallInst(IsLJException, std::vector<Value*>(), "IsLJExcept");
481   InstList.push_back(IsLJExcept);
482
483   BranchInst* BR = new BranchInst(PrelimBBMap[Func], ExceptBB, IsLJExcept);
484   InstList.push_back(BR);
485
486   II.setExceptionalDest(NewExceptBB);
487   ++InvokesTransformed;
488 }
489
490 // visitReturnInst - We want to destroy the setjmp map upon exit from the
491 // function.
492 void LowerSetJmp::visitReturnInst(ReturnInst& RI)
493 {
494   Function* Func = RI.getParent()->getParent();
495   new CallInst(DestroySJMap, make_vector<Value*>(GetSetJmpMap(Func), 0),
496                "", &RI);
497 }
498
499 // visitUnwindInst - We want to destroy the setjmp map upon exit from the
500 // function.
501 void LowerSetJmp::visitUnwindInst(UnwindInst& UI)
502 {
503   Function* Func = UI.getParent()->getParent();
504   new CallInst(DestroySJMap, make_vector<Value*>(GetSetJmpMap(Func), 0),
505                "", &UI);
506 }
507
508 Pass* createLowerSetJmpPass()
509 {
510   return new LowerSetJmp();
511 }