- Add a "getOrInsertGlobal" method to the Module class. This acts similarly to
[oota-llvm.git] / lib / CodeGen / StackProtector.cpp
1 //===-- StackProtector.cpp - Stack Protector Insertion --------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass inserts stack protectors into functions which need them. A variable
11 // with a random value in it is stored onto the stack before the local variables
12 // are allocated. Upon exiting the block, the stored value is checked. If it's
13 // changed, then there was some sort of violation and the program aborts.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #define DEBUG_TYPE "stack-protector"
18 #include "llvm/CodeGen/Passes.h"
19 #include "llvm/Constants.h"
20 #include "llvm/DerivedTypes.h"
21 #include "llvm/Function.h"
22 #include "llvm/Instructions.h"
23 #include "llvm/Module.h"
24 #include "llvm/Pass.h"
25 #include "llvm/ADT/APInt.h"
26 #include "llvm/Support/CommandLine.h"
27 #include "llvm/Target/TargetData.h"
28 #include "llvm/Target/TargetLowering.h"
29 using namespace llvm;
30
31 // Enable stack protectors.
32 static cl::opt<unsigned>
33 SSPBufferSize("stack-protector-buffer-size", cl::init(8),
34               cl::desc("The lower bound for a buffer to be considered for "
35                        "stack smashing protection."));
36
37 namespace {
38   class VISIBILITY_HIDDEN StackProtector : public FunctionPass {
39     /// Level - The level of stack protection.
40     SSP::StackProtectorLevel Level;
41
42     /// TLI - Keep a pointer of a TargetLowering to consult for determining
43     /// target type sizes.
44     const TargetLowering *TLI;
45
46     /// FailBB - Holds the basic block to jump to when the stack protector check
47     /// fails.
48     BasicBlock *FailBB;
49
50     /// StackProtFrameSlot - The place on the stack that the stack protector
51     /// guard is kept.
52     AllocaInst *StackProtFrameSlot;
53
54     /// StackGuardVar - The global variable for the stack guard.
55     Constant *StackGuardVar;
56
57     Function *F;
58     Module *M;
59
60     /// InsertStackProtectorPrologue - Insert code into the entry block that
61     /// stores the __stack_chk_guard variable onto the stack.
62     void InsertStackProtectorPrologue();
63
64     /// InsertStackProtectorEpilogue - Insert code before the return
65     /// instructions checking the stack value that was stored in the
66     /// prologue. If it isn't the same as the original value, then call a
67     /// "failure" function.
68     void InsertStackProtectorEpilogue();
69
70     /// CreateFailBB - Create a basic block to jump to when the stack protector
71     /// check fails.
72     void CreateFailBB();
73
74     /// RequiresStackProtector - Check whether or not this function needs a
75     /// stack protector based upon the stack protector level.
76     bool RequiresStackProtector() const;
77   public:
78     static char ID;             // Pass identification, replacement for typeid.
79     StackProtector() : FunctionPass(&ID), Level(SSP::OFF), TLI(0), FailBB(0) {}
80     StackProtector(SSP::StackProtectorLevel lvl, const TargetLowering *tli)
81       : FunctionPass(&ID), Level(lvl), TLI(tli), FailBB(0) {}
82
83     virtual bool runOnFunction(Function &Fn);
84   };
85 } // end anonymous namespace
86
87 char StackProtector::ID = 0;
88 static RegisterPass<StackProtector>
89 X("stack-protector", "Insert stack protectors");
90
91 FunctionPass *llvm::createStackProtectorPass(SSP::StackProtectorLevel lvl,
92                                              const TargetLowering *tli) {
93   return new StackProtector(lvl, tli);
94 }
95
96 bool StackProtector::runOnFunction(Function &Fn) {
97   F = &Fn;
98   M = F->getParent();
99
100   if (!RequiresStackProtector()) return false;
101   
102   InsertStackProtectorPrologue();
103   InsertStackProtectorEpilogue();
104
105   // Cleanup.
106   FailBB = 0;
107   StackProtFrameSlot = 0;
108   StackGuardVar = 0;
109   return true;
110 }
111
112 /// InsertStackProtectorPrologue - Insert code into the entry block that stores
113 /// the __stack_chk_guard variable onto the stack.
114 void StackProtector::InsertStackProtectorPrologue() {
115   BasicBlock &Entry = F->getEntryBlock();
116   Instruction &InsertPt = Entry.front();
117
118   StackGuardVar = M->getOrInsertGlobal("__stack_chk_guard",
119                                        PointerType::getUnqual(Type::Int8Ty));
120   StackProtFrameSlot = new AllocaInst(PointerType::getUnqual(Type::Int8Ty),
121                                       "StackProt_Frame", &InsertPt);
122   LoadInst *LI = new LoadInst(StackGuardVar, "StackGuard", false, &InsertPt);
123   new StoreInst(LI, StackProtFrameSlot, false, &InsertPt);
124 }
125
126 /// InsertStackProtectorEpilogue - Insert code before the return instructions
127 /// checking the stack value that was stored in the prologue. If it isn't the
128 /// same as the original value, then call a "failure" function.
129 void StackProtector::InsertStackProtectorEpilogue() {
130   // Create the basic block to jump to when the guard check fails.
131   CreateFailBB();
132
133   Function::iterator I = F->begin(), E = F->end();
134   std::vector<BasicBlock*> ReturnBBs;
135   ReturnBBs.reserve(F->size());
136
137   for (; I != E; ++I)
138     if (isa<ReturnInst>(I->getTerminator()))
139       ReturnBBs.push_back(I);
140
141   if (ReturnBBs.empty()) return; // Odd, but could happen. . .
142
143   // Loop through the basic blocks that have return instructions. Convert this:
144   //
145   //   return:
146   //     ...
147   //     ret ...
148   //
149   // into this:
150   //
151   //   return:
152   //     ...
153   //     %1 = load __stack_chk_guard
154   //     %2 = load <stored stack guard>
155   //     %3 = cmp i1 %1, %2
156   //     br i1 %3, label %SPRet, label %CallStackCheckFailBlk
157   //
158   //   SP_return:
159   //     ret ...
160   //
161   //   CallStackCheckFailBlk:
162   //     call void @__stack_chk_fail()
163   //     unreachable
164   //
165   for (std::vector<BasicBlock*>::iterator
166          II = ReturnBBs.begin(), IE = ReturnBBs.end(); II != IE; ++II) {
167     BasicBlock *BB = *II;
168     ReturnInst *RI = cast<ReturnInst>(BB->getTerminator());
169     Function::iterator InsPt = BB; ++InsPt; // Insertion point for new BB.
170
171     // Split the basic block before the return instruction.
172     BasicBlock *NewBB = BB->splitBasicBlock(RI, "SP_return");
173
174     // Move the newly created basic block to the point right after the old basic
175     // block.
176     NewBB->removeFromParent();
177     F->getBasicBlockList().insert(InsPt, NewBB);
178
179     // Generate the stack protector instructions in the old basic block.
180     LoadInst *LI2 = new LoadInst(StackGuardVar, "", false, BB);
181     LoadInst *LI1 = new LoadInst(StackProtFrameSlot, "", true, BB);
182     ICmpInst *Cmp = new ICmpInst(CmpInst::ICMP_EQ, LI1, LI2, "", BB);
183     BranchInst::Create(NewBB, FailBB, Cmp, BB);
184   }
185 }
186
187 /// CreateFailBB - Create a basic block to jump to when the stack protector
188 /// check fails.
189 void StackProtector::CreateFailBB() {
190   assert(!FailBB && "Failure basic block already created?!");
191   FailBB = BasicBlock::Create("CallStackCheckFailBlk", F);
192   std::vector<const Type*> Params;
193   Constant *StackChkFail =
194     M->getOrInsertFunction("__stack_chk_fail", Type::VoidTy, NULL);
195   CallInst::Create(StackChkFail, "", FailBB);
196   new UnreachableInst(FailBB);
197 }
198
199 /// RequiresStackProtector - Check whether or not this function needs a stack
200 /// protector based upon the stack protector level.
201 bool StackProtector::RequiresStackProtector() const {
202   switch (Level) {
203   default: return false;
204   case SSP::ALL: return true;
205   case SSP::SOME: {
206     // If the size of the local variables allocated on the stack is greater than
207     // SSPBufferSize, then we require a stack protector.
208     uint64_t StackSize = 0;
209     const TargetData *TD = TLI->getTargetData();
210
211     for (Function::iterator I = F->begin(), E = F->end(); I != E; ++I) {
212       BasicBlock *BB = I;
213
214       for (BasicBlock::iterator
215              II = BB->begin(), IE = BB->end(); II != IE; ++II)
216         if (AllocaInst *AI = dyn_cast<AllocaInst>(II)) {
217           if (ConstantInt *CI = dyn_cast<ConstantInt>(AI->getArraySize())) {
218             uint64_t Bytes = TD->getTypeSizeInBits(AI->getAllocatedType()) / 8;
219             const APInt &Size = CI->getValue();
220             StackSize += Bytes * Size.getZExtValue();
221
222             if (SSPBufferSize <= StackSize)
223               return true;
224           }
225         }
226     }
227
228     return false;
229   }
230   }
231 }