revamp BoundsChecking considerably:
[oota-llvm.git] / lib / Transforms / Scalar / BoundsChecking.cpp
1 //===- BoundsChecking.cpp - Instrumentation for run-time bounds checking --===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements a pass that instruments the code to perform run-time
11 // bounds checking on loads, stores, and other memory intrinsics.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #define DEBUG_TYPE "bounds-checking"
16 #include "llvm/Transforms/Scalar.h"
17 #include "llvm/ADT/DenseMap.h"
18 #include "llvm/ADT/Statistic.h"
19 #include "llvm/Analysis/LoopInfo.h"
20 #include "llvm/Analysis/ScalarEvolution.h"
21 #include "llvm/Analysis/ScalarEvolutionExpander.h"
22 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
23 #include "llvm/Support/Debug.h"
24 #include "llvm/Support/InstIterator.h"
25 #include "llvm/Support/IRBuilder.h"
26 #include "llvm/Support/raw_ostream.h"
27 #include "llvm/Support/TargetFolder.h"
28 #include "llvm/Target/TargetData.h"
29 #include "llvm/Transforms/Utils/Local.h"
30 #include "llvm/GlobalVariable.h"
31 #include "llvm/Instructions.h"
32 #include "llvm/Intrinsics.h"
33 #include "llvm/Metadata.h"
34 #include "llvm/Operator.h"
35 #include "llvm/Pass.h"
36 using namespace llvm;
37
38 STATISTIC(ChecksAdded, "Bounds checks added");
39 STATISTIC(ChecksSkipped, "Bounds checks skipped");
40 STATISTIC(ChecksUnable, "Bounds checks unable to add");
41 STATISTIC(ChecksUnableInterproc, "Bounds checks unable to add (interprocedural)");
42 STATISTIC(ChecksUnableLoad, "Bounds checks unable to add (LoadInst)");
43
44 typedef IRBuilder<true, TargetFolder> BuilderTy;
45
46 namespace {
47   // FIXME: can use unions here to save space
48   struct CacheData {
49     APInt Offset;
50     Value *OffsetValue;
51     APInt Size;
52     Value *SizeValue;
53     bool ReturnVal;
54   };
55   typedef DenseMap<Value*, CacheData> CacheMapTy;
56
57   struct BoundsChecking : public FunctionPass {
58     static char ID;
59
60     BoundsChecking(unsigned _Penalty = 5) : FunctionPass(ID), Penalty(_Penalty){
61       initializeBoundsCheckingPass(*PassRegistry::getPassRegistry());
62     }
63
64     virtual bool runOnFunction(Function &F);
65
66     virtual void getAnalysisUsage(AnalysisUsage &AU) const {
67       AU.addRequired<TargetData>();
68       AU.addRequired<LoopInfo>();
69       AU.addRequired<ScalarEvolution>();
70     }
71
72   private:
73     const TargetData *TD;
74     LoopInfo *LI;
75     ScalarEvolution *SE;
76     BuilderTy *Builder;
77     Function *Fn;
78     BasicBlock *TrapBB;
79     unsigned Penalty;
80     CacheMapTy CacheMap;
81
82     BasicBlock *getTrapBB();
83     void emitBranchToTrap(Value *Cmp = 0);
84     bool computeAllocSize(Value *Ptr, APInt &Offset, Value* &OffsetValue,
85                           APInt &Size, Value* &SizeValue);
86     bool instrument(Value *Ptr, Value *Val);
87  };
88 }
89
90 char BoundsChecking::ID = 0;
91 INITIALIZE_PASS_BEGIN(BoundsChecking, "bounds-checking",
92                       "Run-time bounds checking", false, false)
93 INITIALIZE_PASS_DEPENDENCY(ScalarEvolution)
94 INITIALIZE_PASS_END(BoundsChecking, "bounds-checking",
95                       "Run-time bounds checking", false, false)
96
97
98 /// getTrapBB - create a basic block that traps. All overflowing conditions
99 /// branch to this block. There's only one trap block per function.
100 BasicBlock *BoundsChecking::getTrapBB() {
101   if (TrapBB)
102     return TrapBB;
103
104   BasicBlock::iterator PrevInsertPoint = Builder->GetInsertPoint();
105   TrapBB = BasicBlock::Create(Fn->getContext(), "trap", Fn);
106   Builder->SetInsertPoint(TrapBB);
107
108   llvm::Value *F = Intrinsic::getDeclaration(Fn->getParent(), Intrinsic::trap);
109   CallInst *TrapCall = Builder->CreateCall(F);
110   TrapCall->setDoesNotReturn();
111   TrapCall->setDoesNotThrow();
112   Builder->CreateUnreachable();
113
114   Builder->SetInsertPoint(PrevInsertPoint);
115   return TrapBB;
116 }
117
118
119 /// emitBranchToTrap - emit a branch instruction to a trap block.
120 /// If Cmp is non-null, perform a jump only if its value evaluates to true.
121 void BoundsChecking::emitBranchToTrap(Value *Cmp) {
122   Instruction *Inst = Builder->GetInsertPoint();
123   BasicBlock *OldBB = Inst->getParent();
124   BasicBlock *Cont = OldBB->splitBasicBlock(Inst);
125   OldBB->getTerminator()->eraseFromParent();
126
127   if (Cmp)
128     BranchInst::Create(getTrapBB(), Cont, Cmp, OldBB);
129   else
130     BranchInst::Create(getTrapBB(), OldBB);
131 }
132
133
134 #define GET_VALUE(Val, Int) \
135   if (!Val) \
136     Val = ConstantInt::get(IntTy, Int)
137
138 #define RETURN(Val) \
139   do { ReturnVal = Val; goto cache_and_return; } while (0)
140
141 /// computeAllocSize - compute the object size and the offset within the object
142 /// pointed by Ptr. OffsetValue/SizeValue will be null if they are constant, and
143 /// therefore the result is given in Offset/Size variables instead.
144 /// Returns true if the offset and size could be computed within the given
145 /// maximum run-time penalty.
146 bool BoundsChecking::computeAllocSize(Value *Ptr, APInt &Offset,
147                                       Value* &OffsetValue, APInt &Size,
148                                       Value* &SizeValue) {
149   Ptr = Ptr->stripPointerCasts();
150
151   // lookup to see if we've seen the Ptr before
152   CacheMapTy::iterator CacheIt = CacheMap.find(Ptr);
153   if (CacheIt != CacheMap.end()) {
154     CacheData &Cache = CacheIt->second;
155     Offset = Cache.Offset;
156     OffsetValue = Cache.OffsetValue;
157     Size = Cache.Size;
158     SizeValue = Cache.SizeValue;
159     return Cache.ReturnVal;
160   }
161
162   IntegerType *IntTy = TD->getIntPtrType(Fn->getContext());
163   unsigned IntTyBits = IntTy->getBitWidth();
164   bool ReturnVal;
165
166   // always generate code immediately before the instruction being processed, so
167   // that the generated code dominates the same BBs
168   Instruction *PrevInsertPoint = Builder->GetInsertPoint();
169   if (Instruction *I = dyn_cast<Instruction>(Ptr))
170     Builder->SetInsertPoint(I);
171
172   // initalize with "don't know" state: offset=0 and size=uintmax
173   Offset = 0;
174   Size = APInt::getMaxValue(TD->getTypeSizeInBits(IntTy));
175   OffsetValue = SizeValue = 0;
176
177   if (GEPOperator *GEP = dyn_cast<GEPOperator>(Ptr)) {
178     APInt PtrOffset(IntTyBits, 0);
179     Value *PtrOffsetValue = 0;
180     if (!computeAllocSize(GEP->getPointerOperand(), PtrOffset, PtrOffsetValue,
181                           Size, SizeValue))
182       RETURN(false);
183
184     if (GEP->hasAllConstantIndices()) {
185       SmallVector<Value*, 8> Ops(GEP->idx_begin(), GEP->idx_end());
186       Offset = TD->getIndexedOffset(GEP->getPointerOperandType(), Ops);
187       // if PtrOffset is constant, return immediately
188       if (!PtrOffsetValue) {
189         Offset += PtrOffset;
190         RETURN(true);
191       }
192       OffsetValue = ConstantInt::get(IntTy, Offset);
193     } else {
194       OffsetValue = EmitGEPOffset(Builder, *TD, GEP);
195     }
196
197     GET_VALUE(PtrOffsetValue, PtrOffset);
198     OffsetValue = Builder->CreateAdd(PtrOffsetValue, OffsetValue);
199     RETURN(true);
200
201   // global variable with definitive size
202   } else if (GlobalVariable *GV = dyn_cast<GlobalVariable>(Ptr)) {
203     if (GV->hasDefinitiveInitializer()) {
204       Constant *C = GV->getInitializer();
205       Size = TD->getTypeAllocSize(C->getType());
206       RETURN(true);
207     }
208     RETURN(false);
209
210   // stack allocation
211   } else if (AllocaInst *AI = dyn_cast<AllocaInst>(Ptr)) {
212     if (!AI->getAllocatedType()->isSized())
213       RETURN(false);
214
215     Size = TD->getTypeAllocSize(AI->getAllocatedType());
216     if (!AI->isArrayAllocation())
217       RETURN(true); // we are done
218
219     Value *ArraySize = AI->getArraySize();
220     if (const ConstantInt *C = dyn_cast<ConstantInt>(ArraySize)) {
221       Size *= C->getValue();
222       RETURN(true);
223     }
224
225     if (Penalty < 2)
226       RETURN(false);
227
228     // VLA: compute size dynamically
229     SizeValue = ConstantInt::get(ArraySize->getType(), Size);
230     SizeValue = Builder->CreateMul(SizeValue, ArraySize);
231     RETURN(true);
232
233   // function arguments
234   } else if (Argument *A = dyn_cast<Argument>(Ptr)) {
235     // right now we only support byval arguments, so that no interprocedural
236     // analysis is necessary
237     if (!A->hasByValAttr()) {
238       ++ChecksUnableInterproc;
239       RETURN(false);
240     }
241
242     PointerType *PT = cast<PointerType>(A->getType());
243     Size = TD->getTypeAllocSize(PT->getElementType());
244     RETURN(true);
245
246   // ptr = select(ptr1, ptr2)
247   } else if (SelectInst *SI = dyn_cast<SelectInst>(Ptr)) {
248     APInt OffsetTrue(IntTyBits, 0), OffsetFalse(IntTyBits, 0);
249     APInt SizeTrue(IntTyBits, 0), SizeFalse(IntTyBits, 0);
250     Value *OffsetValueTrue = 0, *OffsetValueFalse = 0;
251     Value *SizeValueTrue = 0, *SizeValueFalse = 0;
252
253     bool TrueAlloc = computeAllocSize(SI->getTrueValue(), OffsetTrue,
254                                       OffsetValueTrue, SizeTrue, SizeValueTrue);
255     bool FalseAlloc = computeAllocSize(SI->getFalseValue(), OffsetFalse,
256                                        OffsetValueFalse, SizeFalse,
257                                        SizeValueFalse);
258     if (!TrueAlloc && !FalseAlloc)
259       RETURN(false);
260
261     // fold constant sizes & offsets if they are equal
262     if (!OffsetValueTrue && !OffsetValueFalse && OffsetTrue == OffsetFalse)
263       Offset = OffsetTrue;
264     else if (Penalty > 1) {
265       GET_VALUE(OffsetValueTrue, OffsetTrue);
266       GET_VALUE(OffsetValueFalse, OffsetFalse);
267       OffsetValue = Builder->CreateSelect(SI->getCondition(), OffsetValueTrue,
268                                           OffsetValueFalse);
269     } else
270       RETURN(false);
271
272     if (!SizeValueTrue && !SizeValueFalse && SizeTrue == SizeFalse)
273       Size = SizeTrue;
274     else if (Penalty > 1) {
275       GET_VALUE(SizeValueTrue, SizeTrue);
276       GET_VALUE(SizeValueFalse, SizeFalse);
277       SizeValue = Builder->CreateSelect(SI->getCondition(), SizeValueTrue,
278                                         SizeValueFalse);
279     } else
280       RETURN(false);
281     RETURN(true);
282
283   // call allocation function
284   } else if (CallInst *CI = dyn_cast<CallInst>(Ptr)) {
285     SmallVector<unsigned, 4> Args;
286
287     if (MDNode *MD = CI->getMetadata("alloc_size")) {
288       for (unsigned i = 0, e = MD->getNumOperands(); i != e; ++i)
289         Args.push_back(cast<ConstantInt>(MD->getOperand(i))->getZExtValue());
290
291     } else if (Function *Callee = CI->getCalledFunction()) {
292       FunctionType *FTy = Callee->getFunctionType();
293
294       // alloc(size)
295       if (FTy->getNumParams() == 1 && FTy->getParamType(0)->isIntegerTy()) {
296         if ((Callee->getName() == "malloc" ||
297              Callee->getName() == "valloc" ||
298              Callee->getName() == "_Znwj"  || // operator new(unsigned int)
299              Callee->getName() == "_Znwm"  || // operator new(unsigned long)
300              Callee->getName() == "_Znaj"  || // operator new[](unsigned int)
301              Callee->getName() == "_Znam")) {
302           Args.push_back(0);
303         }
304       } else if (FTy->getNumParams() == 2) {
305         // alloc(_, x)
306         if (FTy->getParamType(1)->isIntegerTy() &&
307             ((Callee->getName() == "realloc" ||
308               Callee->getName() == "reallocf"))) {
309           Args.push_back(1);
310
311         // alloc(x, y)
312         } else if (FTy->getParamType(0)->isIntegerTy() &&
313                    FTy->getParamType(1)->isIntegerTy() &&
314                    Callee->getName() == "calloc") {
315           Args.push_back(0);
316           Args.push_back(1);
317         }
318       } else if (FTy->getNumParams() == 3) {
319         // alloc(_, _, x)
320         if (FTy->getParamType(2)->isIntegerTy() &&
321             Callee->getName() == "posix_memalign") {
322           Args.push_back(2);
323         }
324       }
325     }
326
327     if (Args.empty())
328       RETURN(false);
329
330     // check if all arguments are constant. if so, the object size is also const
331     bool AllConst = true;
332     for (SmallVectorImpl<unsigned>::iterator I = Args.begin(), E = Args.end();
333          I != E; ++I) {
334       if (!isa<ConstantInt>(CI->getArgOperand(*I))) {
335         AllConst = false;
336         break;
337       }
338     }
339
340     if (AllConst) {
341       Size = 1;
342       for (SmallVectorImpl<unsigned>::iterator I = Args.begin(), E = Args.end();
343            I != E; ++I) {
344         ConstantInt *Arg = cast<ConstantInt>(CI->getArgOperand(*I));
345         Size *= Arg->getValue().zextOrSelf(IntTyBits);
346       }
347       RETURN(true);
348     }
349
350     if (Penalty < 2)
351       RETURN(false);
352
353     // not all arguments are constant, so create a sequence of multiplications
354     for (SmallVectorImpl<unsigned>::iterator I = Args.begin(), E = Args.end();
355          I != E; ++I) {
356       Value *Arg = Builder->CreateZExt(CI->getArgOperand(*I), IntTy);
357       if (!SizeValue) {
358         SizeValue = Arg;
359         continue;
360       }
361       SizeValue = Builder->CreateMul(SizeValue, Arg);
362     }
363     RETURN(true);
364
365     // TODO: handle more standard functions:
366     // - strdup / strndup
367     // - strcpy / strncpy
368     // - memcpy / memmove
369     // - strcat / strncat
370
371   } else if (PHINode *PHI = dyn_cast<PHINode>(Ptr)) {
372     // create 2 PHIs: one for offset and another for size
373     PHINode *OffsetPHI = Builder->CreatePHI(IntTy, PHI->getNumIncomingValues());
374     PHINode *SizePHI   = Builder->CreatePHI(IntTy, PHI->getNumIncomingValues());
375
376     // insert right away in the cache to handle recursive PHIs
377     CacheData CacheEntry;
378     CacheEntry.Offset = CacheEntry.Size = 0;
379     CacheEntry.OffsetValue = OffsetPHI;
380     CacheEntry.SizeValue = SizePHI;
381     CacheEntry.ReturnVal = true;
382     CacheMap[Ptr] = CacheEntry;
383
384     // compute offset/size for each PHI incoming pointer
385     bool someOk = false;
386     for (unsigned i = 0, e = PHI->getNumIncomingValues(); i != e; ++i) {
387       Builder->SetInsertPoint(PHI->getIncomingBlock(i)->getFirstInsertionPt());
388
389       APInt PhiOffset(IntTyBits, 0), PhiSize(IntTyBits, 0);
390       Value *PhiOffsetValue = 0, *PhiSizeValue = 0;
391       someOk |= computeAllocSize(PHI->getIncomingValue(i), PhiOffset,
392                                  PhiOffsetValue, PhiSize, PhiSizeValue);
393
394       GET_VALUE(PhiOffsetValue, PhiOffset);
395       GET_VALUE(PhiSizeValue, PhiSize);
396
397       OffsetPHI->addIncoming(PhiOffsetValue, PHI->getIncomingBlock(i));
398       SizePHI->addIncoming(PhiSizeValue, PHI->getIncomingBlock(i));
399     }
400
401     // fail here if we couldn't compute the size/offset in any incoming edge
402     if (!someOk)
403       RETURN(false);
404
405     OffsetValue = OffsetPHI;
406     SizeValue = SizePHI;
407     RETURN(true);    
408
409   } else if (isa<UndefValue>(Ptr)) {
410     Size = 0;
411     RETURN(true);
412
413   } else if (isa<LoadInst>(Ptr)) {
414     ++ChecksUnableLoad;
415     RETURN(false);
416   }
417
418   RETURN(false);
419
420 cache_and_return:
421   // cache the result and return
422   CacheData CacheEntry;
423   CacheEntry.Offset = Offset;
424   CacheEntry.OffsetValue = OffsetValue;
425   CacheEntry.Size = Size;
426   CacheEntry.SizeValue = SizeValue;
427   CacheEntry.ReturnVal = ReturnVal;
428   CacheMap[Ptr] = CacheEntry;
429
430   Builder->SetInsertPoint(PrevInsertPoint);
431   return ReturnVal;
432 }
433
434
435 /// instrument - adds run-time bounds checks to memory accessing instructions.
436 /// Ptr is the pointer that will be read/written, and InstVal is either the
437 /// result from the load or the value being stored. It is used to determine the
438 /// size of memory block that is touched.
439 /// Returns true if any change was made to the IR, false otherwise.
440 bool BoundsChecking::instrument(Value *Ptr, Value *InstVal) {
441   uint64_t NeededSize = TD->getTypeStoreSize(InstVal->getType());
442   DEBUG(dbgs() << "Instrument " << *Ptr << " for " << Twine(NeededSize)
443               << " bytes\n");
444
445   IntegerType *IntTy = TD->getIntPtrType(Fn->getContext());
446   unsigned IntTyBits = IntTy->getBitWidth();
447
448   APInt Offset(IntTyBits, 0), Size(IntTyBits, 0);
449   Value *OffsetValue = 0, *SizeValue = 0;
450
451   if (!computeAllocSize(Ptr, Offset, OffsetValue, Size, SizeValue)) {
452     DEBUG(dbgs() << "computeAllocSize failed:\n" << *Ptr << "\n");
453     ++ChecksUnable;
454     return false;
455   }
456
457   // three checks are required to ensure safety:
458   // . Offset >= 0  (since the offset is given from the base ptr)
459   // . Size >= Offset  (unsigned)
460   // . Size - Offset >= NeededSize  (unsigned)
461   if (!OffsetValue && !SizeValue) {
462     if (Offset.slt(0) || Size.ult(Offset) || (Size - Offset).ult(NeededSize)) {
463       // Out of bounds
464       emitBranchToTrap();
465       ++ChecksAdded;
466       return true;
467     }
468     // in bounds
469     ++ChecksSkipped;
470     return false;
471   }
472
473   // emit check for offset < 0
474   Value *CmpOffset = 0;
475   if (OffsetValue)
476     CmpOffset = Builder->CreateICmpSLT(OffsetValue, ConstantInt::get(IntTy, 0));
477   else if (Offset.slt(0)) {
478     // offset proved to be negative
479     emitBranchToTrap();
480     ++ChecksAdded;
481     return true;
482   }
483
484   // we couldn't determine statically if the memory access is safe; emit a
485   // run-time check
486   GET_VALUE(OffsetValue, Offset);
487   GET_VALUE(SizeValue, Size);
488
489   Value *NeededSizeVal = ConstantInt::get(IntTy, NeededSize);
490   // FIXME: add NSW/NUW here?  -- we dont care if the subtraction overflows
491   Value *ObjSize = Builder->CreateSub(SizeValue, OffsetValue);
492   Value *Cmp1 = Builder->CreateICmpULT(SizeValue, OffsetValue);
493   Value *Cmp2 = Builder->CreateICmpULT(ObjSize, NeededSizeVal);
494   Value *Or = Builder->CreateOr(Cmp1, Cmp2);
495   if (CmpOffset)
496     Or = Builder->CreateOr(CmpOffset, Or);
497   emitBranchToTrap(Or);
498
499   ++ChecksAdded;
500   return true;
501 }
502
503 bool BoundsChecking::runOnFunction(Function &F) {
504   TD = &getAnalysis<TargetData>();
505   LI = &getAnalysis<LoopInfo>();
506   SE = &getAnalysis<ScalarEvolution>();
507
508   TrapBB = 0;
509   Fn = &F;
510   BuilderTy TheBuilder(F.getContext(), TargetFolder(TD));
511   Builder = &TheBuilder;
512
513   // check HANDLE_MEMORY_INST in include/llvm/Instruction.def for memory
514   // touching instructions
515   std::vector<Instruction*> WorkList;
516   for (inst_iterator i = inst_begin(F), e = inst_end(F); i != e; ++i) {
517     Instruction *I = &*i;
518     if (isa<LoadInst>(I) || isa<StoreInst>(I) || isa<AtomicCmpXchgInst>(I) ||
519         isa<AtomicRMWInst>(I))
520         WorkList.push_back(I);
521   }
522
523   bool MadeChange = false;
524   for (std::vector<Instruction*>::iterator i = WorkList.begin(),
525        e = WorkList.end(); i != e; ++i) {
526     Instruction *I = *i;
527
528     Builder->SetInsertPoint(I);
529     if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
530       MadeChange |= instrument(LI->getPointerOperand(), LI);
531     } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
532       MadeChange |= instrument(SI->getPointerOperand(), SI->getValueOperand());
533     } else if (AtomicCmpXchgInst *AI = dyn_cast<AtomicCmpXchgInst>(I)) {
534       MadeChange |= instrument(AI->getPointerOperand(),AI->getCompareOperand());
535     } else if (AtomicRMWInst *AI = dyn_cast<AtomicRMWInst>(I)) {
536       MadeChange |= instrument(AI->getPointerOperand(), AI->getValOperand());
537     } else {
538       llvm_unreachable("unknown Instruction type");
539     }
540   }
541   return MadeChange;
542 }
543
544 FunctionPass *llvm::createBoundsCheckingPass(unsigned Penalty) {
545   return new BoundsChecking(Penalty);
546 }