Remove attribution from file headers, per discussion on llvmdev.
[oota-llvm.git] / lib / Analysis / ConstantFolding.cpp
index 886dd9f4f762a7a60b8fa0762ed0b24027fe8dbd..21e19444cfe5414b5fb39d74427eec0879644be6 100644 (file)
@@ -2,8 +2,8 @@
 //
 //                     The LLVM Compiler Infrastructure
 //
-// This file was developed by the LLVM research group and is distributed under
-// the University of Illinois Open Source License. See LICENSE.TXT for details.
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
 //
 //===----------------------------------------------------------------------===//
 //
@@ -55,7 +55,8 @@ static bool IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV,
   if (CE->getOpcode() == Instruction::GetElementPtr) {
     // Cannot compute this if the element type of the pointer is missing size
     // info.
-    if (!cast<PointerType>(CE->getOperand(0)->getType())->getElementType()->isSized())
+    if (!cast<PointerType>(CE->getOperand(0)->getType())
+                 ->getElementType()->isSized())
       return false;
     
     // If the base isn't a global+constant, we aren't either.
@@ -117,7 +118,7 @@ static Constant *SymbolicallyEvaluateBinop(unsigned Opc, Constant *Op0,
 
 /// SymbolicallyEvaluateGEP - If we can symbolically evaluate the specified GEP
 /// constant expression, do so.
-static Constant *SymbolicallyEvaluateGEP(Constant** Ops, unsigned NumOps,
+static Constant *SymbolicallyEvaluateGEP(Constant* const* Ops, unsigned NumOps,
                                          const Type *ResultTy,
                                          const TargetData *TD) {
   Constant *Ptr = Ops[0];
@@ -144,6 +145,122 @@ static Constant *SymbolicallyEvaluateGEP(Constant** Ops, unsigned NumOps,
   return 0;
 }
 
+/// FoldBitCast - Constant fold bitcast, symbolically evaluating it with 
+/// targetdata.  Return 0 if unfoldable.
+static Constant *FoldBitCast(Constant *C, const Type *DestTy,
+                             const TargetData &TD) {
+  // If this is a bitcast from constant vector -> vector, fold it.
+  if (ConstantVector *CV = dyn_cast<ConstantVector>(C)) {
+    if (const VectorType *DestVTy = dyn_cast<VectorType>(DestTy)) {
+      // If the element types match, VMCore can fold it.
+      unsigned NumDstElt = DestVTy->getNumElements();
+      unsigned NumSrcElt = CV->getNumOperands();
+      if (NumDstElt == NumSrcElt)
+        return 0;
+      
+      const Type *SrcEltTy = CV->getType()->getElementType();
+      const Type *DstEltTy = DestVTy->getElementType();
+      
+      // Otherwise, we're changing the number of elements in a vector, which 
+      // requires endianness information to do the right thing.  For example,
+      //    bitcast (<2 x i64> <i64 0, i64 1> to <4 x i32>)
+      // folds to (little endian):
+      //    <4 x i32> <i32 0, i32 0, i32 1, i32 0>
+      // and to (big endian):
+      //    <4 x i32> <i32 0, i32 0, i32 0, i32 1>
+      
+      // First thing is first.  We only want to think about integer here, so if
+      // we have something in FP form, recast it as integer.
+      if (DstEltTy->isFloatingPoint()) {
+        // Fold to an vector of integers with same size as our FP type.
+        unsigned FPWidth = DstEltTy->getPrimitiveSizeInBits();
+        const Type *DestIVTy = VectorType::get(IntegerType::get(FPWidth),
+                                               NumDstElt);
+        // Recursively handle this integer conversion, if possible.
+        C = FoldBitCast(C, DestIVTy, TD);
+        if (!C) return 0;
+        
+        // Finally, VMCore can handle this now that #elts line up.
+        return ConstantExpr::getBitCast(C, DestTy);
+      }
+      
+      // Okay, we know the destination is integer, if the input is FP, convert
+      // it to integer first.
+      if (SrcEltTy->isFloatingPoint()) {
+        unsigned FPWidth = SrcEltTy->getPrimitiveSizeInBits();
+        const Type *SrcIVTy = VectorType::get(IntegerType::get(FPWidth),
+                                              NumSrcElt);
+        // Ask VMCore to do the conversion now that #elts line up.
+        C = ConstantExpr::getBitCast(C, SrcIVTy);
+        CV = dyn_cast<ConstantVector>(C);
+        if (!CV) return 0;  // If VMCore wasn't able to fold it, bail out.
+      }
+      
+      // Now we know that the input and output vectors are both integer vectors
+      // of the same size, and that their #elements is not the same.  Do the
+      // conversion here, which depends on whether the input or output has
+      // more elements.
+      bool isLittleEndian = TD.isLittleEndian();
+      
+      SmallVector<Constant*, 32> Result;
+      if (NumDstElt < NumSrcElt) {
+        // Handle: bitcast (<4 x i32> <i32 0, i32 1, i32 2, i32 3> to <2 x i64>)
+        Constant *Zero = Constant::getNullValue(DstEltTy);
+        unsigned Ratio = NumSrcElt/NumDstElt;
+        unsigned SrcBitSize = SrcEltTy->getPrimitiveSizeInBits();
+        unsigned SrcElt = 0;
+        for (unsigned i = 0; i != NumDstElt; ++i) {
+          // Build each element of the result.
+          Constant *Elt = Zero;
+          unsigned ShiftAmt = isLittleEndian ? 0 : SrcBitSize*(Ratio-1);
+          for (unsigned j = 0; j != Ratio; ++j) {
+            Constant *Src = dyn_cast<ConstantInt>(CV->getOperand(SrcElt++));
+            if (!Src) return 0;  // Reject constantexpr elements.
+            
+            // Zero extend the element to the right size.
+            Src = ConstantExpr::getZExt(Src, Elt->getType());
+            
+            // Shift it to the right place, depending on endianness.
+            Src = ConstantExpr::getShl(Src, 
+                                    ConstantInt::get(Src->getType(), ShiftAmt));
+            ShiftAmt += isLittleEndian ? SrcBitSize : -SrcBitSize;
+            
+            // Mix it in.
+            Elt = ConstantExpr::getOr(Elt, Src);
+          }
+          Result.push_back(Elt);
+        }
+      } else {
+        // Handle: bitcast (<2 x i64> <i64 0, i64 1> to <4 x i32>)
+        unsigned Ratio = NumDstElt/NumSrcElt;
+        unsigned DstBitSize = DstEltTy->getPrimitiveSizeInBits();
+        
+        // Loop over each source value, expanding into multiple results.
+        for (unsigned i = 0; i != NumSrcElt; ++i) {
+          Constant *Src = dyn_cast<ConstantInt>(CV->getOperand(i));
+          if (!Src) return 0;  // Reject constantexpr elements.
+
+          unsigned ShiftAmt = isLittleEndian ? 0 : DstBitSize*(Ratio-1);
+          for (unsigned j = 0; j != Ratio; ++j) {
+            // Shift the piece of the value into the right place, depending on
+            // endianness.
+            Constant *Elt = ConstantExpr::getLShr(Src, 
+                                ConstantInt::get(Src->getType(), ShiftAmt));
+            ShiftAmt += isLittleEndian ? DstBitSize : -DstBitSize;
+
+            // Truncate and remember this piece.
+            Result.push_back(ConstantExpr::getTrunc(Elt, DstEltTy));
+          }
+        }
+      }
+      
+      return ConstantVector::get(&Result[0], Result.size());
+    }
+  }
+  
+  return 0;
+}
+
 
 //===----------------------------------------------------------------------===//
 // Constant Folding public APIs
@@ -181,7 +298,12 @@ Constant *llvm::ConstantFoldInstruction(Instruction *I, const TargetData *TD) {
     else
       return 0;  // All operands not constant!
 
-  return ConstantFoldInstOperands(I, &Ops[0], Ops.size(), TD);
+  if (const CmpInst *CI = dyn_cast<CmpInst>(I))
+    return ConstantFoldCompareInstOperands(CI->getPredicate(),
+                                           &Ops[0], Ops.size(), TD);
+  else
+    return ConstantFoldInstOperands(I->getOpcode(), I->getType(),
+                                    &Ops[0], Ops.size(), TD);
 }
 
 /// ConstantFoldInstOperands - Attempt to constant fold an instruction with the
@@ -190,23 +312,19 @@ Constant *llvm::ConstantFoldInstruction(Instruction *I, const TargetData *TD) {
 /// attempting to fold instructions like loads and stores, which have no
 /// constant expression form.
 ///
-Constant *llvm::ConstantFoldInstOperands(const Instruction* I
-                                         Constant** Ops, unsigned NumOps,
+Constant *llvm::ConstantFoldInstOperands(unsigned Opcode, const Type *DestTy
+                                         Constant* const* Ops, unsigned NumOps,
                                          const TargetData *TD) {
-  unsigned Opc = I->getOpcode();
-  const Type *DestTy = I->getType();
-
   // Handle easy binops first.
-  if (isa<BinaryOperator>(I)) {
+  if (Instruction::isBinaryOp(Opcode)) {
     if (isa<ConstantExpr>(Ops[0]) || isa<ConstantExpr>(Ops[1]))
-      if (Constant *C = SymbolicallyEvaluateBinop(I->getOpcode(), Ops[0],
-                                                  Ops[1], TD))
+      if (Constant *C = SymbolicallyEvaluateBinop(Opcode, Ops[0], Ops[1], TD))
         return C;
     
-    return ConstantExpr::get(Opc, Ops[0], Ops[1]);
+    return ConstantExpr::get(Opcode, Ops[0], Ops[1]);
   }
   
-  switch (Opc) {
+  switch (Opcode) {
   default: return 0;
   case Instruction::Call:
     if (Function *F = dyn_cast<Function>(Ops[0]))
@@ -215,8 +333,7 @@ Constant *llvm::ConstantFoldInstOperands(const Instruction* I,
     return 0;
   case Instruction::ICmp:
   case Instruction::FCmp:
-    return ConstantExpr::getCompare(cast<CmpInst>(I)->getPredicate(), Ops[0], 
-                                    Ops[1]);
+    assert(0 &&"This function is invalid for compares: no predicate specified");
   case Instruction::PtrToInt:
     // If the input is a inttoptr, eliminate the pair.  This requires knowing
     // the width of a pointer, so it can't be done in ConstantExpr::getCast.
@@ -229,10 +346,10 @@ Constant *llvm::ConstantFoldInstOperands(const Instruction* I,
                                                 TD->getPointerSizeInBits()));
         Input = ConstantExpr::getAnd(Input, Mask);
         // Do a zext or trunc to get to the dest size.
-        return ConstantExpr::getIntegerCast(Input, I->getType(), false);
+        return ConstantExpr::getIntegerCast(Input, DestTy, false);
       }
     }
-    // FALL THROUGH.
+    return ConstantExpr::getCast(Opcode, Ops[0], DestTy);
   case Instruction::IntToPtr:
   case Instruction::Trunc:
   case Instruction::ZExt:
@@ -243,8 +360,12 @@ Constant *llvm::ConstantFoldInstOperands(const Instruction* I,
   case Instruction::SIToFP:
   case Instruction::FPToUI:
   case Instruction::FPToSI:
+      return ConstantExpr::getCast(Opcode, Ops[0], DestTy);
   case Instruction::BitCast:
-    return ConstantExpr::getCast(Opc, Ops[0], DestTy);
+    if (TD)
+      if (Constant *C = FoldBitCast(Ops[0], DestTy, *TD))
+        return C;
+    return ConstantExpr::getBitCast(Ops[0], DestTy);
   case Instruction::Select:
     return ConstantExpr::getSelect(Ops[0], Ops[1], Ops[2]);
   case Instruction::ExtractElement:
@@ -254,13 +375,73 @@ Constant *llvm::ConstantFoldInstOperands(const Instruction* I,
   case Instruction::ShuffleVector:
     return ConstantExpr::getShuffleVector(Ops[0], Ops[1], Ops[2]);
   case Instruction::GetElementPtr:
-    if (Constant *C = SymbolicallyEvaluateGEP(Ops, NumOps, I->getType(), TD))
+    if (Constant *C = SymbolicallyEvaluateGEP(Ops, NumOps, DestTy, TD))
       return C;
     
     return ConstantExpr::getGetElementPtr(Ops[0], Ops+1, NumOps-1);
   }
 }
 
+/// ConstantFoldCompareInstOperands - Attempt to constant fold a compare
+/// instruction (icmp/fcmp) with the specified operands.  If it fails, it
+/// returns a constant expression of the specified operands.
+///
+Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate,
+                                                Constant*const * Ops, 
+                                                unsigned NumOps,
+                                                const TargetData *TD) {
+  // fold: icmp (inttoptr x), null         -> icmp x, 0
+  // fold: icmp (ptrtoint x), 0            -> icmp x, null
+  // fold: icmp (inttoptr x), (inttoptr y) -> icmp x, y
+  // fold: icmp (ptrtoint x), (ptrtoint y) -> icmp x, y
+  //
+  // ConstantExpr::getCompare cannot do this, because it doesn't have TD
+  // around to know if bit truncation is happening.
+  if (ConstantExpr *CE0 = dyn_cast<ConstantExpr>(Ops[0])) {
+    if (TD && Ops[1]->isNullValue()) {
+      const Type *IntPtrTy = TD->getIntPtrType();
+      if (CE0->getOpcode() == Instruction::IntToPtr) {
+        // Convert the integer value to the right size to ensure we get the
+        // proper extension or truncation.
+        Constant *C = ConstantExpr::getIntegerCast(CE0->getOperand(0),
+                                                   IntPtrTy, false);
+        Constant *NewOps[] = { C, Constant::getNullValue(C->getType()) };
+        return ConstantFoldCompareInstOperands(Predicate, NewOps, 2, TD);
+      }
+      
+      // Only do this transformation if the int is intptrty in size, otherwise
+      // there is a truncation or extension that we aren't modeling.
+      if (CE0->getOpcode() == Instruction::PtrToInt && 
+          CE0->getType() == IntPtrTy) {
+        Constant *C = CE0->getOperand(0);
+        Constant *NewOps[] = { C, Constant::getNullValue(C->getType()) };
+        // FIXME!
+        return ConstantFoldCompareInstOperands(Predicate, NewOps, 2, TD);
+      }
+    }
+    
+    if (TD && isa<ConstantExpr>(Ops[1]) &&
+        cast<ConstantExpr>(Ops[1])->getOpcode() == CE0->getOpcode()) {
+      const Type *IntPtrTy = TD->getIntPtrType();
+      // Only do this transformation if the int is intptrty in size, otherwise
+      // there is a truncation or extension that we aren't modeling.
+      if ((CE0->getOpcode() == Instruction::IntToPtr &&
+           CE0->getOperand(0)->getType() == IntPtrTy &&
+           Ops[1]->getOperand(0)->getType() == IntPtrTy) ||
+          (CE0->getOpcode() == Instruction::PtrToInt &&
+           CE0->getType() == IntPtrTy &&
+           CE0->getOperand(0)->getType() == Ops[1]->getOperand(0)->getType())) {
+        Constant *NewOps[] = { 
+          CE0->getOperand(0), cast<ConstantExpr>(Ops[1])->getOperand(0) 
+        };
+        return ConstantFoldCompareInstOperands(Predicate, NewOps, 2, TD);
+      }
+    }
+  }
+  return ConstantExpr::getCompare(Predicate, Ops[0], Ops[1]); 
+}
+
+
 /// ConstantFoldLoadThroughGEPConstantExpr - Given a constant and a
 /// getelementptr constantexpr, return the constant value being addressed by the
 /// constant expression, or null if something is funny and we can't decide.
@@ -438,7 +619,8 @@ static Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double),
 /// with the specified arguments, returning null if unsuccessful.
 
 Constant *
-llvm::ConstantFoldCall(Function *F, Constant** Operands, unsigned NumOperands) {
+llvm::ConstantFoldCall(Function *F, 
+                       Constant* const* Operands, unsigned NumOperands) {
   const ValueName *NameVal = F->getValueName();
   if (NameVal == 0) return 0;
   const char *Str = NameVal->getKeyData();
@@ -516,18 +698,14 @@ llvm::ConstantFoldCall(Function *F, Constant** Operands, unsigned NumOperands) {
         break;
       }
     } else if (ConstantInt *Op = dyn_cast<ConstantInt>(Operands[0])) {
-      if (Len > 11 && !memcmp(Str, "llvm.bswap", 10)) {
+      if (Len > 11 && !memcmp(Str, "llvm.bswap", 10))
         return ConstantInt::get(Op->getValue().byteSwap());
-      } else if (Len > 11 && !memcmp(Str, "llvm.ctpop", 10)) {
-        uint64_t ctpop = Op->getValue().countPopulation();
-        return ConstantInt::get(Ty, ctpop);
-      } else if (Len > 10 && !memcmp(Str, "llvm.cttz", 9)) {
-        uint64_t cttz = Op->getValue().countTrailingZeros();
-        return ConstantInt::get(Ty, cttz);
-      } else if (Len > 10 && !memcmp(Str, "llvm.ctlz", 9)) {
-        uint64_t ctlz = Op->getValue().countLeadingZeros();
-        return ConstantInt::get(Ty, ctlz);
-      }
+      else if (Len > 11 && !memcmp(Str, "llvm.ctpop", 10))
+        return ConstantInt::get(Ty, Op->getValue().countPopulation());
+      else if (Len > 10 && !memcmp(Str, "llvm.cttz", 9))
+        return ConstantInt::get(Ty, Op->getValue().countTrailingZeros());
+      else if (Len > 10 && !memcmp(Str, "llvm.ctlz", 9))
+        return ConstantInt::get(Ty, Op->getValue().countLeadingZeros());
     }
   } else if (NumOperands == 2) {
     if (ConstantFP *Op1 = dyn_cast<ConstantFP>(Operands[0])) {