Add addrspacecast instruction.
[oota-llvm.git] / lib / Analysis / ConstantFolding.cpp
index fa6e55859d80c3af01418ba0ca813c9e37cf3441..3d32232dacf9dc21ebf5e2f8a206ddb56bd07fb3 100644 (file)
@@ -224,7 +224,8 @@ static bool IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV,
                                        APInt &Offset, const DataLayout &TD) {
   // Trivial case, constant is the global.
   if ((GV = dyn_cast<GlobalValue>(C))) {
-    Offset.clearAllBits();
+    unsigned BitWidth = TD.getPointerTypeSizeInBits(GV->getType());
+    Offset = APInt(BitWidth, 0);
     return true;
   }
 
@@ -238,16 +239,23 @@ static bool IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV,
     return IsConstantOffsetFromGlobal(CE->getOperand(0), GV, Offset, TD);
 
   // i32* getelementptr ([5 x i32]* @a, i32 0, i32 5)
-  if (GEPOperator *GEP = dyn_cast<GEPOperator>(CE)) {
-    // If the base isn't a global+constant, we aren't either.
-    if (!IsConstantOffsetFromGlobal(CE->getOperand(0), GV, Offset, TD))
-      return false;
+  GEPOperator *GEP = dyn_cast<GEPOperator>(CE);
+  if (!GEP)
+    return false;
 
-    // Otherwise, add any offset that our operands provide.
-    return GEP->accumulateConstantOffset(TD, Offset);
-  }
+  unsigned BitWidth = TD.getPointerTypeSizeInBits(GEP->getType());
+  APInt TmpOffset(BitWidth, 0);
 
-  return false;
+  // If the base isn't a global+constant, we aren't either.
+  if (!IsConstantOffsetFromGlobal(CE->getOperand(0), GV, TmpOffset, TD))
+    return false;
+
+  // Otherwise, add any offset that our operands provide.
+  if (!GEP->accumulateConstantOffset(TD, TmpOffset))
+    return false;
+
+  Offset = TmpOffset;
+  return true;
 }
 
 /// ReadDataFromGlobal - Recursive helper to read bits out of global.  C is the
@@ -324,12 +332,12 @@ static bool ReadDataFromGlobal(Constant *C, uint64_t ByteOffset,
       // If we read all of the bytes we needed from this element we're done.
       uint64_t NextEltOffset = SL->getElementOffset(Index);
 
-      if (BytesLeft <= NextEltOffset-CurEltOffset-ByteOffset)
+      if (BytesLeft <= NextEltOffset - CurEltOffset - ByteOffset)
         return true;
 
       // Move to the next element of the struct.
-      CurPtr += NextEltOffset-CurEltOffset-ByteOffset;
-      BytesLeft -= NextEltOffset-CurEltOffset-ByteOffset;
+      CurPtr += NextEltOffset - CurEltOffset - ByteOffset;
+      BytesLeft -= NextEltOffset - CurEltOffset - ByteOffset;
       ByteOffset = 0;
       CurEltOffset = NextEltOffset;
     }
@@ -338,7 +346,7 @@ static bool ReadDataFromGlobal(Constant *C, uint64_t ByteOffset,
 
   if (isa<ConstantArray>(C) || isa<ConstantVector>(C) ||
       isa<ConstantDataSequential>(C)) {
-    Type *EltTy = cast<SequentialType>(C->getType())->getElementType();
+    Type *EltTy = C->getType()->getSequentialElementType();
     uint64_t EltSize = TD.getTypeAllocSize(EltTy);
     uint64_t Index = ByteOffset / EltSize;
     uint64_t Offset = ByteOffset - Index * EltSize;
@@ -346,7 +354,7 @@ static bool ReadDataFromGlobal(Constant *C, uint64_t ByteOffset,
     if (ArrayType *AT = dyn_cast<ArrayType>(C->getType()))
       NumElts = AT->getNumElements();
     else
-      NumElts = cast<VectorType>(C->getType())->getNumElements();
+      NumElts = C->getType()->getVectorNumElements();
 
     for (; Index != NumElts; ++Index) {
       if (!ReadDataFromGlobal(C->getAggregateElement(Index), Offset, CurPtr,
@@ -367,7 +375,7 @@ static bool ReadDataFromGlobal(Constant *C, uint64_t ByteOffset,
 
   if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) {
     if (CE->getOpcode() == Instruction::IntToPtr &&
-        CE->getOperand(0)->getType() == TD.getIntPtrType(CE->getContext())) {
+        CE->getOperand(0)->getType() == TD.getIntPtrType(CE->getType())) {
       return ReadDataFromGlobal(CE->getOperand(0), ByteOffset, CurPtr,
                                 BytesLeft, TD);
     }
@@ -379,26 +387,29 @@ static bool ReadDataFromGlobal(Constant *C, uint64_t ByteOffset,
 
 static Constant *FoldReinterpretLoadFromConstPtr(Constant *C,
                                                  const DataLayout &TD) {
-  Type *LoadTy = cast<PointerType>(C->getType())->getElementType();
+  PointerType *PTy = cast<PointerType>(C->getType());
+  Type *LoadTy = PTy->getElementType();
   IntegerType *IntType = dyn_cast<IntegerType>(LoadTy);
 
   // If this isn't an integer load we can't fold it directly.
   if (!IntType) {
+    unsigned AS = PTy->getAddressSpace();
+
     // If this is a float/double load, we can try folding it as an int32/64 load
     // and then bitcast the result.  This can be useful for union cases.  Note
     // that address spaces don't matter here since we're not going to result in
     // an actual new load.
     Type *MapTy;
     if (LoadTy->isHalfTy())
-      MapTy = Type::getInt16PtrTy(C->getContext());
+      MapTy = Type::getInt16PtrTy(C->getContext(), AS);
     else if (LoadTy->isFloatTy())
-      MapTy = Type::getInt32PtrTy(C->getContext());
+      MapTy = Type::getInt32PtrTy(C->getContext(), AS);
     else if (LoadTy->isDoubleTy())
-      MapTy = Type::getInt64PtrTy(C->getContext());
+      MapTy = Type::getInt64PtrTy(C->getContext(), AS);
     else if (LoadTy->isVectorTy()) {
-      MapTy = IntegerType::get(C->getContext(),
-                               TD.getTypeAllocSizeInBits(LoadTy));
-      MapTy = PointerType::getUnqual(MapTy);
+      MapTy = PointerType::getIntNPtrTy(C->getContext(),
+                                        TD.getTypeAllocSizeInBits(LoadTy),
+                                        AS);
     } else
       return 0;
 
@@ -409,10 +420,11 @@ static Constant *FoldReinterpretLoadFromConstPtr(Constant *C,
   }
 
   unsigned BytesLoaded = (IntType->getBitWidth() + 7) / 8;
-  if (BytesLoaded > 32 || BytesLoaded == 0) return 0;
+  if (BytesLoaded > 32 || BytesLoaded == 0)
+    return 0;
 
   GlobalValue *GVal;
-  APInt Offset(TD.getPointerSizeInBits(), 0);
+  APInt Offset;
   if (!IsConstantOffsetFromGlobal(C, GVal, Offset, TD))
     return 0;
 
@@ -423,7 +435,8 @@ static Constant *FoldReinterpretLoadFromConstPtr(Constant *C,
 
   // If we're loading off the beginning of the global, some bytes may be valid,
   // but we don't try to handle this.
-  if (Offset.isNegative()) return 0;
+  if (Offset.isNegative())
+    return 0;
 
   // If we're not accessing anything in this constant, the result is undefined.
   if (Offset.getZExtValue() >=
@@ -580,13 +593,13 @@ static Constant *SymbolicallyEvaluateBinop(unsigned Opc, Constant *Op0,
   // constant.  This happens frequently when iterating over a global array.
   if (Opc == Instruction::Sub && DL) {
     GlobalValue *GV1, *GV2;
-    unsigned PtrSize = DL->getPointerSizeInBits();
-    unsigned OpSize = DL->getTypeSizeInBits(Op0->getType());
-    APInt Offs1(PtrSize, 0), Offs2(PtrSize, 0);
+    APInt Offs1, Offs2;
 
     if (IsConstantOffsetFromGlobal(Op0, GV1, Offs1, *DL))
       if (IsConstantOffsetFromGlobal(Op1, GV2, Offs2, *DL) &&
           GV1 == GV2) {
+        unsigned OpSize = DL->getTypeSizeInBits(Op0->getType());
+
         // (&GV+C1) - (&GV+C2) -> C1-C2, pointer arithmetic cannot overflow.
         // PtrToInt may change the bitwidth so we have convert to the right size
         // first.
@@ -604,8 +617,10 @@ static Constant *SymbolicallyEvaluateBinop(unsigned Opc, Constant *Op0,
 static Constant *CastGEPIndices(ArrayRef<Constant *> Ops,
                                 Type *ResultTy, const DataLayout *TD,
                                 const TargetLibraryInfo *TLI) {
-  if (!TD) return 0;
-  Type *IntPtrTy = TD->getIntPtrType(ResultTy->getContext());
+  if (!TD)
+    return 0;
+
+  Type *IntPtrTy = TD->getIntPtrType(ResultTy);
 
   bool Any = false;
   SmallVector<Constant*, 32> NewIdxs;
@@ -648,7 +663,7 @@ static Constant* StripPtrCastKeepAS(Constant* Ptr) {
   if (NewPtrTy->getAddressSpace() != OldPtrTy->getAddressSpace()) {
     NewPtrTy = NewPtrTy->getElementType()->getPointerTo(
       OldPtrTy->getAddressSpace());
-    Ptr = ConstantExpr::getBitCast(Ptr, NewPtrTy);
+    Ptr = ConstantExpr::getPointerCast(Ptr, NewPtrTy);
   }
   return Ptr;
 }
@@ -659,11 +674,12 @@ static Constant *SymbolicallyEvaluateGEP(ArrayRef<Constant *> Ops,
                                          Type *ResultTy, const DataLayout *TD,
                                          const TargetLibraryInfo *TLI) {
   Constant *Ptr = Ops[0];
-  if (!TD || !cast<PointerType>(Ptr->getType())->getElementType()->isSized() ||
+  if (!TD || !Ptr->getType()->getPointerElementType()->isSized() ||
       !Ptr->getType()->isPointerTy())
     return 0;
 
-  Type *IntPtrTy = TD->getIntPtrType(Ptr->getContext());
+  Type *IntPtrTy = TD->getIntPtrType(Ptr->getType());
+  Type *ResultElementTy = ResultTy->getPointerElementType();
 
   // If this is a constant expr gep that is effectively computing an
   // "offsetof", fold it into 'cast int Size to T*' instead of 'gep 0, 0, 12'
@@ -672,8 +688,7 @@ static Constant *SymbolicallyEvaluateGEP(ArrayRef<Constant *> Ops,
 
       // If this is "gep i8* Ptr, (sub 0, V)", fold this as:
       // "inttoptr (sub (ptrtoint Ptr), V)"
-      if (Ops.size() == 2 &&
-          cast<PointerType>(ResultTy)->getElementType()->isIntegerTy(8)) {
+      if (Ops.size() == 2 && ResultElementTy->isIntegerTy(8)) {
         ConstantExpr *CE = dyn_cast<ConstantExpr>(Ops[1]);
         assert((CE == 0 || CE->getType() == IntPtrTy) &&
                "CastGEPIndices didn't canonicalize index types!");
@@ -739,7 +754,8 @@ static Constant *SymbolicallyEvaluateGEP(ArrayRef<Constant *> Ops,
   // Also, this helps GlobalOpt do SROA on GlobalVariables.
   Type *Ty = Ptr->getType();
   assert(Ty->isPointerTy() && "Forming regular GEP of non-pointer type");
-  SmallVector<Constant*, 32> NewIdxs;
+  SmallVector<Constant *, 32> NewIdxs;
+
   do {
     if (SequentialType *ATy = dyn_cast<SequentialType>(Ty)) {
       if (ATy->isPointerTy()) {
@@ -754,7 +770,6 @@ static Constant *SymbolicallyEvaluateGEP(ArrayRef<Constant *> Ops,
 
       // Determine which element of the array the offset points into.
       APInt ElemSize(BitWidth, TD->getTypeAllocSize(ATy->getElementType()));
-      IntegerType *IntPtrTy = TD->getIntPtrType(Ty->getContext());
       if (ElemSize == 0)
         // The element size is 0. This may be [0 x Ty]*, so just use a zero
         // index for this level and proceed to the next level to see if it can
@@ -789,7 +804,7 @@ static Constant *SymbolicallyEvaluateGEP(ArrayRef<Constant *> Ops,
       // We've reached some non-indexable type.
       break;
     }
-  } while (Ty != cast<PointerType>(ResultTy)->getElementType());
+  } while (Ty != ResultElementTy);
 
   // If we haven't used up the entire offset by descending the static
   // type, then the offset is pointing into the middle of an indivisible
@@ -799,12 +814,12 @@ static Constant *SymbolicallyEvaluateGEP(ArrayRef<Constant *> Ops,
 
   // Create a GEP.
   Constant *C = ConstantExpr::getGetElementPtr(Ptr, NewIdxs);
-  assert(cast<PointerType>(C->getType())->getElementType() == Ty &&
+  assert(C->getType()->getPointerElementType() == Ty &&
          "Computed GetElementPtr has unexpected type!");
 
   // If we ended up indexing a member with a type that doesn't match
   // the type of what the original indices indexed, add a cast.
-  if (Ty != cast<PointerType>(ResultTy)->getElementType())
+  if (Ty != ResultElementTy)
     C = FoldBitCast(C, ResultTy, *TD);
 
   return C;
@@ -966,10 +981,11 @@ Constant *llvm::ConstantFoldInstOperands(unsigned Opcode, Type *DestTy,
       if (TD && CE->getOpcode() == Instruction::IntToPtr) {
         Constant *Input = CE->getOperand(0);
         unsigned InWidth = Input->getType()->getScalarSizeInBits();
-        if (TD->getPointerSizeInBits() < InWidth) {
+        unsigned PtrWidth = TD->getPointerTypeSizeInBits(CE->getType());
+        if (PtrWidth < InWidth) {
           Constant *Mask =
-            ConstantInt::get(CE->getContext(), APInt::getLowBitsSet(InWidth,
-                                                  TD->getPointerSizeInBits()));
+            ConstantInt::get(CE->getContext(),
+                             APInt::getLowBitsSet(InWidth, PtrWidth));
           Input = ConstantExpr::getAnd(Input, Mask);
         }
         // Do a zext or trunc to get to the dest size.
@@ -979,13 +995,22 @@ Constant *llvm::ConstantFoldInstOperands(unsigned Opcode, Type *DestTy,
     return ConstantExpr::getCast(Opcode, Ops[0], DestTy);
   case Instruction::IntToPtr:
     // If the input is a ptrtoint, turn the pair into a ptr to ptr bitcast if
-    // the int size is >= the ptr size.  This requires knowing the width of a
-    // pointer, so it can't be done in ConstantExpr::getCast.
-    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ops[0]))
-      if (TD &&
-          TD->getPointerSizeInBits() <= CE->getType()->getScalarSizeInBits() &&
-          CE->getOpcode() == Instruction::PtrToInt)
-        return FoldBitCast(CE->getOperand(0), DestTy, *TD);
+    // the int size is >= the ptr size and the address spaces are the same.
+    // This requires knowing the width of a pointer, so it can't be done in
+    // ConstantExpr::getCast.
+    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ops[0])) {
+      if (TD && CE->getOpcode() == Instruction::PtrToInt) {
+        Constant *SrcPtr = CE->getOperand(0);
+        unsigned SrcPtrSize = TD->getPointerTypeSizeInBits(SrcPtr->getType());
+        unsigned MidIntSize = CE->getType()->getScalarSizeInBits();
+
+        if (MidIntSize >= SrcPtrSize) {
+          unsigned SrcAS = SrcPtr->getType()->getPointerAddressSpace();
+          if (SrcAS == DestTy->getPointerAddressSpace())
+            return FoldBitCast(CE->getOperand(0), DestTy, *TD);
+        }
+      }
+    }
 
     return ConstantExpr::getCast(Opcode, Ops[0], DestTy);
   case Instruction::Trunc:
@@ -997,6 +1022,7 @@ Constant *llvm::ConstantFoldInstOperands(unsigned Opcode, Type *DestTy,
   case Instruction::SIToFP:
   case Instruction::FPToUI:
   case Instruction::FPToSI:
+  case Instruction::AddrSpaceCast:
       return ConstantExpr::getCast(Opcode, Ops[0], DestTy);
   case Instruction::BitCast:
     if (TD)
@@ -1037,8 +1063,8 @@ Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate,
   // around to know if bit truncation is happening.
   if (ConstantExpr *CE0 = dyn_cast<ConstantExpr>(Ops0)) {
     if (TD && Ops1->isNullValue()) {
-      Type *IntPtrTy = TD->getIntPtrType(CE0->getContext());
       if (CE0->getOpcode() == Instruction::IntToPtr) {
+        Type *IntPtrTy = TD->getIntPtrType(CE0->getType());
         // Convert the integer value to the right size to ensure we get the
         // proper extension or truncation.
         Constant *C = ConstantExpr::getIntegerCast(CE0->getOperand(0),
@@ -1049,19 +1075,21 @@ Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate,
 
       // Only do this transformation if the int is intptrty in size, otherwise
       // there is a truncation or extension that we aren't modeling.
-      if (CE0->getOpcode() == Instruction::PtrToInt &&
-          CE0->getType() == IntPtrTy) {
-        Constant *C = CE0->getOperand(0);
-        Constant *Null = Constant::getNullValue(C->getType());
-        return ConstantFoldCompareInstOperands(Predicate, C, Null, TD, TLI);
+      if (CE0->getOpcode() == Instruction::PtrToInt) {
+        Type *IntPtrTy = TD->getIntPtrType(CE0->getOperand(0)->getType());
+        if (CE0->getType() == IntPtrTy) {
+          Constant *C = CE0->getOperand(0);
+          Constant *Null = Constant::getNullValue(C->getType());
+          return ConstantFoldCompareInstOperands(Predicate, C, Null, TD, TLI);
+        }
       }
     }
 
     if (ConstantExpr *CE1 = dyn_cast<ConstantExpr>(Ops1)) {
       if (TD && CE0->getOpcode() == CE1->getOpcode()) {
-        Type *IntPtrTy = TD->getIntPtrType(CE0->getContext());
-
         if (CE0->getOpcode() == Instruction::IntToPtr) {
+          Type *IntPtrTy = TD->getIntPtrType(CE0->getType());
+
           // Convert the integer value to the right size to ensure we get the
           // proper extension or truncation.
           Constant *C0 = ConstantExpr::getIntegerCast(CE0->getOperand(0),
@@ -1073,11 +1101,17 @@ Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate,
 
         // Only do this transformation if the int is intptrty in size, otherwise
         // there is a truncation or extension that we aren't modeling.
-        if ((CE0->getOpcode() == Instruction::PtrToInt &&
-             CE0->getType() == IntPtrTy &&
-             CE0->getOperand(0)->getType() == CE1->getOperand(0)->getType()))
-          return ConstantFoldCompareInstOperands(Predicate, CE0->getOperand(0),
-                                                 CE1->getOperand(0), TD, TLI);
+        if (CE0->getOpcode() == Instruction::PtrToInt) {
+          Type *IntPtrTy = TD->getIntPtrType(CE0->getOperand(0)->getType());
+          if (CE0->getType() == IntPtrTy &&
+              CE0->getOperand(0)->getType() == CE1->getOperand(0)->getType()) {
+            return ConstantFoldCompareInstOperands(Predicate,
+                                                   CE0->getOperand(0),
+                                                   CE1->getOperand(0),
+                                                   TD,
+                                                   TLI);
+          }
+        }
       }
     }
 
@@ -1265,7 +1299,7 @@ static Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double),
 static Constant *ConstantFoldConvertToInt(const APFloat &Val,
                                           bool roundTowardZero, Type *Ty) {
   // All of these conversion intrinsics form an integer of at most 64bits.
-  unsigned ResultWidth = cast<IntegerType>(Ty)->getBitWidth();
+  unsigned ResultWidth = Ty->getIntegerBitWidth();
   assert(ResultWidth <= 64 &&
          "Can only constant fold conversions to 64 and 32 bit ints");