389793ebc12a3e6c74370ed7e557f0d6adb0ac06
[oota-llvm.git] / lib / CodeGen / StackProtector.cpp
1 //===-- StackProtector.cpp - Stack Protector Insertion --------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass inserts stack protectors into functions which need them. A variable
11 // with a random value in it is stored onto the stack before the local variables
12 // are allocated. Upon exiting the block, the stored value is checked. If it's
13 // changed, then there was some sort of violation and the program aborts.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #define DEBUG_TYPE "stack-protector"
18 #include "llvm/CodeGen/Passes.h"
19 #include "llvm/ADT/SmallPtrSet.h"
20 #include "llvm/ADT/Statistic.h"
21 #include "llvm/ADT/Triple.h"
22 #include "llvm/Analysis/Dominators.h"
23 #include "llvm/IR/Attributes.h"
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/DataLayout.h"
26 #include "llvm/IR/DerivedTypes.h"
27 #include "llvm/IR/Function.h"
28 #include "llvm/IR/GlobalValue.h"
29 #include "llvm/IR/GlobalVariable.h"
30 #include "llvm/IR/Instructions.h"
31 #include "llvm/IR/Intrinsics.h"
32 #include "llvm/IR/Module.h"
33 #include "llvm/Pass.h"
34 #include "llvm/Support/CommandLine.h"
35 #include "llvm/Target/TargetLowering.h"
36 using namespace llvm;
37
38 STATISTIC(NumFunProtected, "Number of functions protected");
39 STATISTIC(NumAddrTaken, "Number of local variables that have their address"
40                         " taken.");
41
42 namespace {
43   class StackProtector : public FunctionPass {
44     /// TLI - Keep a pointer of a TargetLowering to consult for determining
45     /// target type sizes.
46     const TargetLoweringBase *const TLI;
47     const Triple Trip;
48
49     Function *F;
50     Module *M;
51
52     DominatorTree *DT;
53
54     /// VisitedPHIs - The set of PHI nodes visited when determining
55     /// if a variable's reference has been taken.  This set 
56     /// is maintained to ensure we don't visit the same PHI node multiple
57     /// times.
58     SmallPtrSet<const PHINode*, 16> VisitedPHIs;
59
60     /// InsertStackProtectors - Insert code into the prologue and epilogue of
61     /// the function.
62     ///
63     ///  - The prologue code loads and stores the stack guard onto the stack.
64     ///  - The epilogue checks the value stored in the prologue against the
65     ///    original value. It calls __stack_chk_fail if they differ.
66     bool InsertStackProtectors();
67
68     /// CreateFailBB - Create a basic block to jump to when the stack protector
69     /// check fails.
70     BasicBlock *CreateFailBB();
71
72     /// ContainsProtectableArray - Check whether the type either is an array or
73     /// contains an array of sufficient size so that we need stack protectors
74     /// for it.
75     bool ContainsProtectableArray(Type *Ty, bool Strong = false,
76                                   bool InStruct = false) const;
77
78     /// \brief Check whether a stack allocation has its address taken.
79     bool HasAddressTaken(const Instruction *AI);
80
81     /// RequiresStackProtector - Check whether or not this function needs a
82     /// stack protector based upon the stack protector level.
83     bool RequiresStackProtector();
84   public:
85     static char ID;             // Pass identification, replacement for typeid.
86     StackProtector() : FunctionPass(ID), TLI(0) {
87       initializeStackProtectorPass(*PassRegistry::getPassRegistry());
88     }
89     StackProtector(const TargetLoweringBase *tli)
90         : FunctionPass(ID), TLI(tli),
91           Trip(tli->getTargetMachine().getTargetTriple()) {
92       initializeStackProtectorPass(*PassRegistry::getPassRegistry());
93     }
94
95     virtual void getAnalysisUsage(AnalysisUsage &AU) const {
96       AU.addPreserved<DominatorTree>();
97     }
98
99     virtual bool runOnFunction(Function &Fn);
100   };
101 } // end anonymous namespace
102
103 char StackProtector::ID = 0;
104 INITIALIZE_PASS(StackProtector, "stack-protector",
105                 "Insert stack protectors", false, false)
106
107 FunctionPass *llvm::createStackProtectorPass(const TargetLoweringBase *tli) {
108   return new StackProtector(tli);
109 }
110
111 bool StackProtector::runOnFunction(Function &Fn) {
112   F = &Fn;
113   M = F->getParent();
114   DT = getAnalysisIfAvailable<DominatorTree>();
115
116   if (!RequiresStackProtector()) return false;
117
118   ++NumFunProtected;
119   return InsertStackProtectors();
120 }
121
122 /// ContainsProtectableArray - Check whether the type either is an array or
123 /// contains a char array of sufficient size so that we need stack protectors
124 /// for it.
125 bool StackProtector::ContainsProtectableArray(Type *Ty, bool Strong,
126                                               bool InStruct) const {
127   if (!Ty) return false;
128   if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) {
129     // In strong mode any array, regardless of type and size, triggers a
130     // protector
131     if (Strong)
132       return true;
133     const TargetMachine &TM = TLI->getTargetMachine();
134     if (!AT->getElementType()->isIntegerTy(8)) {
135       // If we're on a non-Darwin platform or we're inside of a structure, don't
136       // add stack protectors unless the array is a character array.
137       if (InStruct || !Trip.isOSDarwin())
138           return false;
139     }
140
141     // If an array has more than SSPBufferSize bytes of allocated space, then we
142     // emit stack protectors.
143     if (TM.Options.SSPBufferSize <= TLI->getDataLayout()->getTypeAllocSize(AT))
144       return true;
145   }
146
147   const StructType *ST = dyn_cast<StructType>(Ty);
148   if (!ST) return false;
149
150   for (StructType::element_iterator I = ST->element_begin(),
151          E = ST->element_end(); I != E; ++I)
152     if (ContainsProtectableArray(*I, Strong, true))
153       return true;
154
155   return false;
156 }
157
158 bool StackProtector::HasAddressTaken(const Instruction *AI) {
159   for (Value::const_use_iterator UI = AI->use_begin(), UE = AI->use_end();
160         UI != UE; ++UI) {
161     const User *U = *UI;
162     if (const StoreInst *SI = dyn_cast<StoreInst>(U)) {
163       if (AI == SI->getValueOperand())
164         return true;
165     } else if (const PtrToIntInst *SI = dyn_cast<PtrToIntInst>(U)) {
166       if (AI == SI->getOperand(0))
167         return true;
168     } else if (isa<CallInst>(U)) {
169       return true;
170     } else if (isa<InvokeInst>(U)) {
171       return true;
172     } else if (const SelectInst *SI = dyn_cast<SelectInst>(U)) {
173       if (HasAddressTaken(SI))
174         return true;
175     } else if (const PHINode *PN = dyn_cast<PHINode>(U)) {
176       // Keep track of what PHI nodes we have already visited to ensure
177       // they are only visited once.
178       if (VisitedPHIs.insert(PN))
179         if (HasAddressTaken(PN))
180           return true;
181     } else if (const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(U)) {
182       if (HasAddressTaken(GEP))
183         return true;
184     } else if (const BitCastInst *BI = dyn_cast<BitCastInst>(U)) {
185       if (HasAddressTaken(BI))
186         return true;
187     }
188   }
189   return false;
190 }
191
192 /// \brief Check whether or not this function needs a stack protector based
193 /// upon the stack protector level.
194 ///
195 /// We use two heuristics: a standard (ssp) and strong (sspstrong).
196 /// The standard heuristic which will add a guard variable to functions that
197 /// call alloca with a either a variable size or a size >= SSPBufferSize,
198 /// functions with character buffers larger than SSPBufferSize, and functions
199 /// with aggregates containing character buffers larger than SSPBufferSize. The
200 /// strong heuristic will add a guard variables to functions that call alloca
201 /// regardless of size, functions with any buffer regardless of type and size,
202 /// functions with aggregates that contain any buffer regardless of type and
203 /// size, and functions that contain stack-based variables that have had their
204 /// address taken.
205 bool StackProtector::RequiresStackProtector() {
206   bool Strong = false;
207   if (F->getAttributes().hasAttribute(AttributeSet::FunctionIndex,
208                                       Attribute::StackProtectReq))
209     return true;
210   else if (F->getAttributes().hasAttribute(AttributeSet::FunctionIndex,
211                                            Attribute::StackProtectStrong))
212     Strong = true;
213   else if (!F->getAttributes().hasAttribute(AttributeSet::FunctionIndex,
214                                             Attribute::StackProtect))
215     return false;
216
217   for (Function::iterator I = F->begin(), E = F->end(); I != E; ++I) {
218     BasicBlock *BB = I;
219
220     for (BasicBlock::iterator
221            II = BB->begin(), IE = BB->end(); II != IE; ++II) {
222       if (AllocaInst *AI = dyn_cast<AllocaInst>(II)) {
223         if (AI->isArrayAllocation()) {
224           // SSP-Strong: Enable protectors for any call to alloca, regardless
225           // of size.
226           if (Strong)
227             return true;
228   
229           if (const ConstantInt *CI =
230                dyn_cast<ConstantInt>(AI->getArraySize())) {
231             unsigned BufferSize = TLI->getTargetMachine().Options.SSPBufferSize;
232             if (CI->getLimitedValue(BufferSize) >= BufferSize)
233               // A call to alloca with size >= SSPBufferSize requires
234               // stack protectors.
235               return true;
236           } else // A call to alloca with a variable size requires protectors.
237             return true;
238         }
239
240         if (ContainsProtectableArray(AI->getAllocatedType(), Strong))
241           return true;
242
243         if (Strong && HasAddressTaken(AI)) {
244           ++NumAddrTaken; 
245           return true;
246         }
247       }
248     }
249   }
250
251   return false;
252 }
253
254 /// InsertStackProtectors - Insert code into the prologue and epilogue of the
255 /// function.
256 ///
257 ///  - The prologue code loads and stores the stack guard onto the stack.
258 ///  - The epilogue checks the value stored in the prologue against the original
259 ///    value. It calls __stack_chk_fail if they differ.
260 bool StackProtector::InsertStackProtectors() {
261   BasicBlock *FailBB = 0;       // The basic block to jump to if check fails.
262   BasicBlock *FailBBDom = 0;    // FailBB's dominator.
263   AllocaInst *AI = 0;           // Place on stack that stores the stack guard.
264   Value *StackGuardVar = 0;  // The stack guard variable.
265
266   for (Function::iterator I = F->begin(), E = F->end(); I != E; ) {
267     BasicBlock *BB = I++;
268     ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator());
269     if (!RI) continue;
270
271     if (!FailBB) {
272       // Insert code into the entry block that stores the __stack_chk_guard
273       // variable onto the stack:
274       //
275       //   entry:
276       //     StackGuardSlot = alloca i8*
277       //     StackGuard = load __stack_chk_guard
278       //     call void @llvm.stackprotect.create(StackGuard, StackGuardSlot)
279       //
280       PointerType *PtrTy = Type::getInt8PtrTy(RI->getContext());
281       unsigned AddressSpace, Offset;
282       if (TLI->getStackCookieLocation(AddressSpace, Offset)) {
283         Constant *OffsetVal =
284           ConstantInt::get(Type::getInt32Ty(RI->getContext()), Offset);
285
286         StackGuardVar = ConstantExpr::getIntToPtr(OffsetVal,
287                                       PointerType::get(PtrTy, AddressSpace));
288       } else if (Trip.getOS() == llvm::Triple::OpenBSD) {
289         StackGuardVar = M->getOrInsertGlobal("__guard_local", PtrTy);
290         cast<GlobalValue>(StackGuardVar)
291             ->setVisibility(GlobalValue::HiddenVisibility);
292       } else {
293         StackGuardVar = M->getOrInsertGlobal("__stack_chk_guard", PtrTy);
294       }
295
296       BasicBlock &Entry = F->getEntryBlock();
297       Instruction *InsPt = &Entry.front();
298
299       AI = new AllocaInst(PtrTy, "StackGuardSlot", InsPt);
300       LoadInst *LI = new LoadInst(StackGuardVar, "StackGuard", false, InsPt);
301
302       Value *Args[] = { LI, AI };
303       CallInst::
304         Create(Intrinsic::getDeclaration(M, Intrinsic::stackprotector),
305                Args, "", InsPt);
306
307       // Create the basic block to jump to when the guard check fails.
308       FailBB = CreateFailBB();
309     }
310
311     // For each block with a return instruction, convert this:
312     //
313     //   return:
314     //     ...
315     //     ret ...
316     //
317     // into this:
318     //
319     //   return:
320     //     ...
321     //     %1 = load __stack_chk_guard
322     //     %2 = load StackGuardSlot
323     //     %3 = cmp i1 %1, %2
324     //     br i1 %3, label %SP_return, label %CallStackCheckFailBlk
325     //
326     //   SP_return:
327     //     ret ...
328     //
329     //   CallStackCheckFailBlk:
330     //     call void @__stack_chk_fail()
331     //     unreachable
332
333     // Split the basic block before the return instruction.
334     BasicBlock *NewBB = BB->splitBasicBlock(RI, "SP_return");
335
336     if (DT && DT->isReachableFromEntry(BB)) {
337       DT->addNewBlock(NewBB, BB);
338       FailBBDom = FailBBDom ? DT->findNearestCommonDominator(FailBBDom, BB) :BB;
339     }
340
341     // Remove default branch instruction to the new BB.
342     BB->getTerminator()->eraseFromParent();
343
344     // Move the newly created basic block to the point right after the old basic
345     // block so that it's in the "fall through" position.
346     NewBB->moveAfter(BB);
347
348     // Generate the stack protector instructions in the old basic block.
349     LoadInst *LI1 = new LoadInst(StackGuardVar, "", false, BB);
350     LoadInst *LI2 = new LoadInst(AI, "", true, BB);
351     ICmpInst *Cmp = new ICmpInst(*BB, CmpInst::ICMP_EQ, LI1, LI2, "");
352     BranchInst::Create(NewBB, FailBB, Cmp, BB);
353   }
354
355   // Return if we didn't modify any basic blocks. I.e., there are no return
356   // statements in the function.
357   if (!FailBB) return false;
358
359   if (DT && FailBBDom)
360     DT->addNewBlock(FailBB, FailBBDom);
361
362   return true;
363 }
364
365 /// CreateFailBB - Create a basic block to jump to when the stack protector
366 /// check fails.
367 BasicBlock *StackProtector::CreateFailBB() {
368   LLVMContext &Context = F->getContext();
369   BasicBlock *FailBB = BasicBlock::Create(Context, "CallStackCheckFailBlk", F);
370   if (Trip.getOS() == llvm::Triple::OpenBSD) {
371     Constant *StackChkFail = M->getOrInsertFunction(
372         "__stack_smash_handler", Type::getVoidTy(Context),
373         Type::getInt8PtrTy(Context), NULL);
374
375     Constant *NameStr = ConstantDataArray::getString(Context, F->getName());
376     Constant *FuncName =
377         new GlobalVariable(*M, NameStr->getType(), true,
378                            GlobalVariable::PrivateLinkage, NameStr, "SSH");
379
380     SmallVector<Constant *, 2> IdxList;
381     IdxList.push_back(ConstantInt::get(Type::getInt8Ty(Context), 0));
382     IdxList.push_back(ConstantInt::get(Type::getInt8Ty(Context), 0));
383
384     SmallVector<Value *, 1> Args;
385     Args.push_back(ConstantExpr::getGetElementPtr(FuncName, IdxList));
386
387     CallInst::Create(StackChkFail, Args, "", FailBB);
388   } else {
389     Constant *StackChkFail = M->getOrInsertFunction(
390         "__stack_chk_fail", Type::getVoidTy(Context), NULL);
391     CallInst::Create(StackChkFail, "", FailBB);
392   }
393   new UnreachableInst(Context, FailBB);
394   return FailBB;
395 }