679faeffefcc1b430d466fb9e761d9d576bbb86e
[oota-llvm.git] / lib / CodeGen / ForwardControlFlowIntegrity.cpp
1 //===-- ForwardControlFlowIntegrity.cpp: Forward-Edge CFI -----------------===//
2 //
3 // This file is distributed under the University of Illinois Open Source
4 // License. See LICENSE.TXT for details.
5 //
6 //===----------------------------------------------------------------------===//
7 ///
8 /// \file
9 /// \brief A pass that instruments code with fast checks for indirect calls and
10 /// hooks for a function to check violations.
11 ///
12 //===----------------------------------------------------------------------===//
13
14 #define DEBUG_TYPE "cfi"
15
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/Analysis/JumpInstrTableInfo.h"
19 #include "llvm/CodeGen/ForwardControlFlowIntegrity.h"
20 #include "llvm/CodeGen/JumpInstrTables.h"
21 #include "llvm/CodeGen/Passes.h"
22 #include "llvm/IR/Attributes.h"
23 #include "llvm/IR/CallSite.h"
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/DerivedTypes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/GlobalValue.h"
28 #include "llvm/IR/Instructions.h"
29 #include "llvm/IR/InlineAsm.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/LLVMContext.h"
32 #include "llvm/IR/Module.h"
33 #include "llvm/IR/Operator.h"
34 #include "llvm/IR/Type.h"
35 #include "llvm/IR/Verifier.h"
36 #include "llvm/Pass.h"
37 #include "llvm/Support/CommandLine.h"
38 #include "llvm/Support/Debug.h"
39 #include "llvm/Support/raw_ostream.h"
40
41 using namespace llvm;
42
43 STATISTIC(NumCFIIndirectCalls,
44           "Number of indirect call sites rewritten by the CFI pass");
45
46 char ForwardControlFlowIntegrity::ID = 0;
47 INITIALIZE_PASS_BEGIN(ForwardControlFlowIntegrity, "forward-cfi",
48                       "Control-Flow Integrity", true, true)
49 INITIALIZE_PASS_DEPENDENCY(JumpInstrTableInfo);
50 INITIALIZE_PASS_DEPENDENCY(JumpInstrTables);
51 INITIALIZE_PASS_END(ForwardControlFlowIntegrity, "forward-cfi",
52                     "Control-Flow Integrity", true, true)
53
54 ModulePass *llvm::createForwardControlFlowIntegrityPass() {
55   return new ForwardControlFlowIntegrity();
56 }
57
58 ModulePass *llvm::createForwardControlFlowIntegrityPass(
59     JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing,
60     StringRef CFIFuncName) {
61   return new ForwardControlFlowIntegrity(JTT, CFIType, CFIEnforcing,
62                                          CFIFuncName);
63 }
64
65 // Checks to see if a given CallSite is making an indirect call, including
66 // cases where the indirect call is made through a bitcast.
67 static bool isIndirectCall(CallSite &CS) {
68   if (CS.getCalledFunction())
69     return false;
70
71   // Check the value to see if it is merely a bitcast of a function. In
72   // this case, it will translate to a direct function call in the resulting
73   // assembly, so we won't treat it as an indirect call here.
74   const Value *V = CS.getCalledValue();
75   if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) {
76     return !(CE->isCast() && isa<Function>(CE->getOperand(0)));
77   }
78
79   // Otherwise, since we know it's a call, it must be an indirect call
80   return true;
81 }
82
83 static const char cfi_failure_func_name[] = "__llvm_cfi_pointer_warning";
84 static const char cfi_func_name_prefix[] = "__llvm_cfi_function_";
85
86 ForwardControlFlowIntegrity::ForwardControlFlowIntegrity()
87     : ModulePass(ID), IndirectCalls(), JTType(JumpTable::Single),
88       CFIType(CFIntegrity::Sub), CFIEnforcing(false), CFIFuncName("") {
89   initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry());
90 }
91
92 ForwardControlFlowIntegrity::ForwardControlFlowIntegrity(
93     JumpTable::JumpTableType JTT, CFIntegrity CFIType, bool CFIEnforcing,
94     std::string CFIFuncName)
95     : ModulePass(ID), IndirectCalls(), JTType(JTT), CFIType(CFIType),
96       CFIEnforcing(CFIEnforcing), CFIFuncName(CFIFuncName) {
97   initializeForwardControlFlowIntegrityPass(*PassRegistry::getPassRegistry());
98 }
99
100 ForwardControlFlowIntegrity::~ForwardControlFlowIntegrity() {}
101
102 void ForwardControlFlowIntegrity::getAnalysisUsage(AnalysisUsage &AU) const {
103   AU.addRequired<JumpInstrTableInfo>();
104   AU.addRequired<JumpInstrTables>();
105 }
106
107 void ForwardControlFlowIntegrity::getIndirectCalls(Module &M) {
108   // To get the indirect calls, we iterate over all functions and iterate over
109   // the list of basic blocks in each. We extract a total list of indirect calls
110   // before modifying any of them, since our modifications will modify the list
111   // of basic blocks.
112   for (Function &F : M) {
113     for (BasicBlock &BB : F) {
114       for (Instruction &I : BB) {
115         CallSite CS(&I);
116         if (!(CS && isIndirectCall(CS)))
117           continue;
118
119         Value *CalledValue = CS.getCalledValue();
120
121         // Don't rewrite this instruction if the indirect call is actually just
122         // inline assembly, since our transformation will generate an invalid
123         // module in that case.
124         if (isa<InlineAsm>(CalledValue))
125           continue;
126
127         IndirectCalls.push_back(&I);
128       }
129     }
130   }
131 }
132
133 void ForwardControlFlowIntegrity::updateIndirectCalls(Module &M,
134                                                       CFITables &CFIT) {
135   Type *Int64Ty = Type::getInt64Ty(M.getContext());
136   for (Instruction *I : IndirectCalls) {
137     CallSite CS(I);
138     Value *CalledValue = CS.getCalledValue();
139
140     // Get the function type for this call and look it up in the tables.
141     Type *VTy = CalledValue->getType();
142     PointerType *PTy = dyn_cast<PointerType>(VTy);
143     Type *EltTy = PTy->getElementType();
144     FunctionType *FunTy = dyn_cast<FunctionType>(EltTy);
145     FunctionType *TransformedTy = JumpInstrTables::transformType(JTType, FunTy);
146     ++NumCFIIndirectCalls;
147     Constant *JumpTableStart = nullptr;
148     Constant *JumpTableMask = nullptr;
149     Constant *JumpTableSize = nullptr;
150
151     // Some call sites have function types that don't correspond to any
152     // address-taken function in the module. This happens when function pointers
153     // are passed in from external code.
154     auto it = CFIT.find(TransformedTy);
155     if (it == CFIT.end()) {
156       // In this case, make sure that the function pointer will change by
157       // setting the mask and the start to be 0 so that the transformed
158       // function is 0.
159       JumpTableStart = ConstantInt::get(Int64Ty, 0);
160       JumpTableMask = ConstantInt::get(Int64Ty, 0);
161       JumpTableSize = ConstantInt::get(Int64Ty, 0);
162     } else {
163       JumpTableStart = it->second.StartValue;
164       JumpTableMask = it->second.MaskValue;
165       JumpTableSize = it->second.Size;
166     }
167
168     rewriteFunctionPointer(M, I, CalledValue, JumpTableStart, JumpTableMask,
169                            JumpTableSize);
170   }
171
172   return;
173 }
174
175 bool ForwardControlFlowIntegrity::runOnModule(Module &M) {
176   JumpInstrTableInfo *JITI = &getAnalysis<JumpInstrTableInfo>();
177   Type *Int64Ty = Type::getInt64Ty(M.getContext());
178   Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext());
179
180   // JumpInstrTableInfo stores information about the alignment of each entry.
181   // The alignment returned by JumpInstrTableInfo is alignment in bytes, not
182   // in the exponent.
183   ByteAlignment = JITI->entryByteAlignment();
184   LogByteAlignment = llvm::Log2_64(ByteAlignment);
185
186   // Set up tables for control-flow integrity based on information about the
187   // jump-instruction tables.
188   CFITables CFIT;
189   for (const auto &KV : JITI->getTables()) {
190     uint64_t Size = static_cast<uint64_t>(KV.second.size());
191     uint64_t TableSize = NextPowerOf2(Size);
192
193     int64_t MaskValue = ((TableSize << LogByteAlignment) - 1) & -ByteAlignment;
194     Constant *JumpTableMaskValue = ConstantInt::get(Int64Ty, MaskValue);
195     Constant *JumpTableSize = ConstantInt::get(Int64Ty, Size);
196
197     // The base of the table is defined to be the first jumptable function in
198     // the table.
199     Function *First = KV.second.begin()->second;
200     Constant *JumpTableStartValue = ConstantExpr::getBitCast(First, VoidPtrTy);
201     CFIT[KV.first].StartValue = JumpTableStartValue;
202     CFIT[KV.first].MaskValue = JumpTableMaskValue;
203     CFIT[KV.first].Size = JumpTableSize;
204   }
205
206   if (CFIT.empty())
207     return false;
208
209   getIndirectCalls(M);
210
211   if (!CFIEnforcing) {
212     addWarningFunction(M);
213   }
214
215   // Update the instructions with the check and the indirect jump through our
216   // table.
217   updateIndirectCalls(M, CFIT);
218
219   return true;
220 }
221
222 void ForwardControlFlowIntegrity::addWarningFunction(Module &M) {
223   PointerType *CharPtrTy = Type::getInt8PtrTy(M.getContext());
224
225   // Get the type of the Warning Function: void (i8*, i8*),
226   // where the first argument is the name of the function in which the violation
227   // occurs, and the second is the function pointer that violates CFI.
228   SmallVector<Type *, 2> WarningFunArgs;
229   WarningFunArgs.push_back(CharPtrTy);
230   WarningFunArgs.push_back(CharPtrTy);
231   FunctionType *WarningFunTy =
232       FunctionType::get(Type::getVoidTy(M.getContext()), WarningFunArgs, false);
233
234   if (!CFIFuncName.empty()) {
235     Constant *FailureFun = M.getOrInsertFunction(CFIFuncName, WarningFunTy);
236     if (!FailureFun)
237       report_fatal_error("Could not get or insert the function specified by"
238                          " -cfi-func-name");
239   } else {
240     // The default warning function swallows the warning and lets the call
241     // continue, since there's no generic way for it to print out this
242     // information.
243     Function *WarningFun = M.getFunction(cfi_failure_func_name);
244     if (!WarningFun) {
245       WarningFun =
246           Function::Create(WarningFunTy, GlobalValue::LinkOnceAnyLinkage,
247                            cfi_failure_func_name, &M);
248     }
249
250     BasicBlock *Entry =
251         BasicBlock::Create(M.getContext(), "entry", WarningFun, 0);
252     ReturnInst::Create(M.getContext(), Entry);
253   }
254 }
255
256 void ForwardControlFlowIntegrity::rewriteFunctionPointer(
257     Module &M, Instruction *I, Value *FunPtr, Constant *JumpTableStart,
258     Constant *JumpTableMask, Constant *JumpTableSize) {
259   IRBuilder<> TempBuilder(I);
260
261   Type *OrigFunType = FunPtr->getType();
262
263   BasicBlock *CurBB = cast<BasicBlock>(I->getParent());
264   Function *CurF = cast<Function>(CurBB->getParent());
265   Type *Int64Ty = Type::getInt64Ty(M.getContext());
266
267   Value *TI = TempBuilder.CreatePtrToInt(FunPtr, Int64Ty);
268   Value *TStartInt = TempBuilder.CreatePtrToInt(JumpTableStart, Int64Ty);
269
270   Value *NewFunPtr = nullptr;
271   Value *Check = nullptr;
272   switch (CFIType) {
273   case CFIntegrity::Sub: {
274     // This is the subtract, mask, and add version.
275     // Subtract from the base.
276     Value *Sub = TempBuilder.CreateSub(TI, TStartInt);
277
278     // Mask the difference to force this to be a table offset.
279     Value *And = TempBuilder.CreateAnd(Sub, JumpTableMask);
280
281     // Add it back to the base.
282     Value *Result = TempBuilder.CreateAdd(And, TStartInt);
283
284     // Convert it back into a function pointer that we can call.
285     NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType);
286     break;
287   }
288   case CFIntegrity::Ror: {
289     // This is the subtract and rotate version.
290     // Rotate right by the alignment value. The optimizer should recognize
291     // this sequence as a rotation.
292
293     // This cast is safe, since unsigned is always a subset of uint64_t.
294     uint64_t LogByteAlignment64 = static_cast<uint64_t>(LogByteAlignment);
295     Constant *RightShift = ConstantInt::get(Int64Ty, LogByteAlignment64);
296     Constant *LeftShift = ConstantInt::get(Int64Ty, 64 - LogByteAlignment64);
297
298     // Subtract from the base.
299     Value *Sub = TempBuilder.CreateSub(TI, TStartInt);
300
301     // Create the equivalent of a rotate-right instruction.
302     Value *Shr = TempBuilder.CreateLShr(Sub, RightShift);
303     Value *Shl = TempBuilder.CreateShl(Sub, LeftShift);
304     Value *Or = TempBuilder.CreateOr(Shr, Shl);
305
306     // Perform unsigned comparison to check for inclusion in the table.
307     Check = TempBuilder.CreateICmpULT(Or, JumpTableSize);
308     NewFunPtr = FunPtr;
309     break;
310   }
311   case CFIntegrity::Add: {
312     // This is the mask and add version.
313     // Mask the function pointer to turn it into an offset into the table.
314     Value *And = TempBuilder.CreateAnd(TI, JumpTableMask);
315
316     // Then or this offset to the base and get the pointer value.
317     Value *Result = TempBuilder.CreateAdd(And, TStartInt);
318
319     // Convert it back into a function pointer that we can call.
320     NewFunPtr = TempBuilder.CreateIntToPtr(Result, OrigFunType);
321     break;
322   }
323   }
324
325   if (!CFIEnforcing) {
326     // If a check hasn't been added (in the rotation version), then check to see
327     // if it's the same as the original function. This check determines whether
328     // or not we call the CFI failure function.
329     if (!Check)
330       Check = TempBuilder.CreateICmpEQ(NewFunPtr, FunPtr);
331     BasicBlock *InvalidPtrBlock =
332         BasicBlock::Create(M.getContext(), "invalid.ptr", CurF, 0);
333     BasicBlock *ContinuationBB = CurBB->splitBasicBlock(I);
334
335     // Remove the unconditional branch that connects the two blocks.
336     TerminatorInst *TermInst = CurBB->getTerminator();
337     TermInst->eraseFromParent();
338
339     // Add a conditional branch that depends on the Check above.
340     BranchInst::Create(ContinuationBB, InvalidPtrBlock, Check, CurBB);
341
342     // Call the warning function for this pointer, then continue.
343     Instruction *BI = BranchInst::Create(ContinuationBB, InvalidPtrBlock);
344     insertWarning(M, InvalidPtrBlock, BI, FunPtr);
345   } else {
346     // Modify the instruction to call this value.
347     CallSite CS(I);
348     CS.setCalledFunction(NewFunPtr);
349   }
350 }
351
352 void ForwardControlFlowIntegrity::insertWarning(Module &M, BasicBlock *Block,
353                                                 Instruction *I, Value *FunPtr) {
354   Function *ParentFun = cast<Function>(Block->getParent());
355
356   // Get the function to call right before the instruction.
357   Function *WarningFun = nullptr;
358   if (CFIFuncName.empty()) {
359     WarningFun = M.getFunction(cfi_failure_func_name);
360   } else {
361     WarningFun = M.getFunction(CFIFuncName);
362   }
363
364   assert(WarningFun && "Could not find the CFI failure function");
365
366   Type *VoidPtrTy = Type::getInt8PtrTy(M.getContext());
367
368   IRBuilder<> WarningInserter(I);
369   // Create a mergeable GlobalVariable containing the name of the function.
370   Value *ParentNameGV =
371       WarningInserter.CreateGlobalString(ParentFun->getName());
372   Value *ParentNamePtr = WarningInserter.CreateBitCast(ParentNameGV, VoidPtrTy);
373   Value *FunVoidPtr = WarningInserter.CreateBitCast(FunPtr, VoidPtrTy);
374   WarningInserter.CreateCall2(WarningFun, ParentNamePtr, FunVoidPtr);
375 }