Don't use PassInfo* as a type identifier for passes. Instead, use the address of...
[oota-llvm.git] / lib / Analysis / PointerTracking.cpp
index ce7ac899cd2869dfa5f98eb631c7f2eea240bb6b..07f46824700a8e0afb02717466fd067d6eb0c128 100644 (file)
@@ -28,7 +28,7 @@
 using namespace llvm;
 
 char PointerTracking::ID = 0;
-PointerTracking::PointerTracking() : FunctionPass(&ID) {}
+PointerTracking::PointerTracking() : FunctionPass(ID) {}
 
 bool PointerTracking::runOnFunction(Function &F) {
   predCache.clear();
@@ -144,6 +144,55 @@ const SCEV *PointerTracking::computeAllocationCount(Value *P,
   return SE->getCouldNotCompute();
 }
 
+Value *PointerTracking::computeAllocationCountValue(Value *P, const Type *&Ty) const 
+{
+  Value *V = P->stripPointerCasts();
+  if (AllocaInst *AI = dyn_cast<AllocaInst>(V)) {
+    Ty = AI->getAllocatedType();
+    // arraySize elements of type Ty.
+    return AI->getArraySize();
+  }
+
+  if (CallInst *CI = extractMallocCall(V)) {
+    Ty = getMallocAllocatedType(CI);
+    if (!Ty)
+      return 0;
+    Value *arraySize = getMallocArraySize(CI, TD);
+    if (!arraySize) {
+      Ty = Type::getInt8Ty(P->getContext());
+      return CI->getArgOperand(0);
+    }
+    // arraySize elements of type Ty.
+    return arraySize;
+  }
+
+  if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) {
+    if (GV->hasDefinitiveInitializer()) {
+      Constant *C = GV->getInitializer();
+      if (const ArrayType *ATy = dyn_cast<ArrayType>(C->getType())) {
+        Ty = ATy->getElementType();
+        return ConstantInt::get(Type::getInt32Ty(P->getContext()),
+                               ATy->getNumElements());
+      }
+    }
+    Ty = cast<PointerType>(GV->getType())->getElementType();
+    return ConstantInt::get(Type::getInt32Ty(P->getContext()), 1);
+    //TODO: implement more tracking for globals
+  }
+
+  if (CallInst *CI = dyn_cast<CallInst>(V)) {
+    CallSite CS(CI);
+    Function *F = dyn_cast<Function>(CS.getCalledValue()->stripPointerCasts());
+    if (F == reallocFunc) {
+      Ty = Type::getInt8Ty(P->getContext());
+      // realloc allocates arg1 bytes.
+      return CS.getArgument(1);
+    }
+  }
+
+  return 0;
+}
+
 // Calculates the number of elements of type Ty allocated for P.
 const SCEV *PointerTracking::computeAllocationCountForType(Value *P,
                                                            const Type *Ty)
@@ -183,17 +232,17 @@ enum SolverResult PointerTracking::isLoopGuardedBy(const Loop *L,
                                                    Predicate Pred,
                                                    const SCEV *A,
                                                    const SCEV *B) const {
-  if (SE->isLoopGuardedByCond(L, Pred, A, B))
+  if (SE->isLoopEntryGuardedByCond(L, Pred, A, B))
     return AlwaysTrue;
   Pred = ICmpInst::getSwappedPredicate(Pred);
-  if (SE->isLoopGuardedByCond(L, Pred, B, A))
+  if (SE->isLoopEntryGuardedByCond(L, Pred, B, A))
     return AlwaysTrue;
 
   Pred = ICmpInst::getInversePredicate(Pred);
-  if (SE->isLoopGuardedByCond(L, Pred, B, A))
+  if (SE->isLoopEntryGuardedByCond(L, Pred, B, A))
     return AlwaysFalse;
   Pred = ICmpInst::getSwappedPredicate(Pred);
-  if (SE->isLoopGuardedByCond(L, Pred, A, B))
+  if (SE->isLoopEntryGuardedByCond(L, Pred, A, B))
     return AlwaysTrue;
   return Unknown;
 }
@@ -263,5 +312,5 @@ void PointerTracking::print(raw_ostream &OS, const Module* M) const {
   }
 }
 
-static RegisterPass<PointerTracking> X("pointertracking",
-                                       "Track pointer bounds", false, true);
+INITIALIZE_PASS(PointerTracking, "pointertracking",
+                "Track pointer bounds", false, true);