Add support for new intrinsic
[oota-llvm.git] / lib / VMCore / ConstantFold.cpp
index aca7df2589357b2699afc41d19375f4d185acb60..04ec28bfb83a9bad84fb13c68dde8ddd49cd03d0 100644 (file)
@@ -15,6 +15,7 @@
 #include "llvm/iPHINode.h"
 #include "llvm/InstrTypes.h"
 #include "llvm/DerivedTypes.h"
+#include "llvm/Support/GetElementPtrTypeIterator.h"
 #include <cmath>
 using namespace llvm;
 
@@ -159,22 +160,29 @@ Constant *llvm::ConstantFoldGetElementPtr(const Constant *C,
   // TODO If C is null and all idx's are null, return null of the right type.
 
 
-  if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) {
+  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(const_cast<Constant*>(C))) {
     // Combine Indices - If the source pointer to this getelementptr instruction
     // is a getelementptr instruction, combine the indices of the two
     // getelementptr instructions into a single instruction.
     //
     if (CE->getOpcode() == Instruction::GetElementPtr) {
-      if (CE->getOperand(CE->getNumOperands()-1)->getType() == Type::LongTy) {
+      const Type *LastTy = 0;
+      for (gep_type_iterator I = gep_type_begin(CE), E = gep_type_end(CE);
+           I != E; ++I)
+        LastTy = *I;
+
+      if (LastTy && isa<ArrayType>(LastTy)) {
         std::vector<Constant*> NewIndices;
         NewIndices.reserve(IdxList.size() + CE->getNumOperands());
         for (unsigned i = 1, e = CE->getNumOperands()-1; i != e; ++i)
           NewIndices.push_back(cast<Constant>(CE->getOperand(i)));
 
         // Add the last index of the source with the first index of the new GEP.
+        // Make sure to handle the case when they are actually different types.
         Constant *Combined =
-          ConstantExpr::get(Instruction::Add, IdxList[0],
-                            CE->getOperand(CE->getNumOperands()-1));
+          ConstantExpr::get(Instruction::Add,
+                            ConstantExpr::getCast(IdxList[0], Type::LongTy),
+   ConstantExpr::getCast(CE->getOperand(CE->getNumOperands()-1), Type::LongTy));
                             
         NewIndices.push_back(Combined);
         NewIndices.insert(NewIndices.end(), IdxList.begin()+1, IdxList.end());
@@ -256,6 +264,10 @@ class TemplateRules : public ConstRules {
                                  const Constant *V2) const { 
     return SubClassName::LessThan((const ArgType *)V1, (const ArgType *)V2);
   }
+  virtual ConstantBool *equalto(const Constant *V1, 
+                                const Constant *V2) const { 
+    return SubClassName::EqualTo((const ArgType *)V1, (const ArgType *)V2);
+  }
 
   // Casting operators.  ick
   virtual ConstantBool *castToBool(const Constant *V) const {
@@ -313,6 +325,9 @@ class TemplateRules : public ConstRules {
   static ConstantBool *LessThan(const ArgType *V1, const ArgType *V2) {
     return 0;
   }
+  static ConstantBool *EqualTo(const ArgType *V1, const ArgType *V2) {
+    return 0;
+  }
 
   // Casting operators.  ick
   static ConstantBool *CastToBool  (const Constant *V) { return 0; }
@@ -339,6 +354,10 @@ class TemplateRules : public ConstRules {
 // EmptyRules provides a concrete base class of ConstRules that does nothing
 //
 struct EmptyRules : public TemplateRules<Constant, EmptyRules> {
+  static ConstantBool *EqualTo(const Constant *V1, const Constant *V2) {
+    if (V1 == V2) return ConstantBool::True;
+    return 0;
+  }
 };
 
 
@@ -355,6 +374,10 @@ struct BoolRules : public TemplateRules<ConstantBool, BoolRules> {
     return ConstantBool::get(V1->getValue() < V2->getValue());
   }
 
+  static ConstantBool *EqualTo(const Constant *V1, const Constant *V2) {
+    return ConstantBool::get(V1 == V2);
+  }
+
   static Constant *And(const ConstantBool *V1, const ConstantBool *V2) {
     return ConstantBool::get(V1->getValue() & V2->getValue());
   }
@@ -389,64 +412,54 @@ struct BoolRules : public TemplateRules<ConstantBool, BoolRules> {
 
 
 //===----------------------------------------------------------------------===//
-//                            PointerRules Class
+//                            NullPointerRules Class
 //===----------------------------------------------------------------------===//
 //
-// PointerRules provides a concrete base class of ConstRules for pointer types
+// NullPointerRules provides a concrete base class of ConstRules for null
+// pointers.
 //
-struct PointerRules : public TemplateRules<ConstantPointer, PointerRules> {
+struct NullPointerRules : public TemplateRules<ConstantPointerNull,
+                                               NullPointerRules> {
+  static ConstantBool *EqualTo(const Constant *V1, const Constant *V2) {
+    return ConstantBool::True;  // Null pointers are always equal
+  }
   static ConstantBool *CastToBool  (const Constant *V) {
-    if (V->isNullValue()) return ConstantBool::False;
-    return 0;  // Can't const prop other types of pointers
+    return ConstantBool::False;
   }
   static ConstantSInt *CastToSByte (const Constant *V) {
-    if (V->isNullValue()) return ConstantSInt::get(Type::SByteTy, 0);
-    return 0;  // Can't const prop other types of pointers
+    return ConstantSInt::get(Type::SByteTy, 0);
   }
   static ConstantUInt *CastToUByte (const Constant *V) {
-    if (V->isNullValue()) return ConstantUInt::get(Type::UByteTy, 0);
-    return 0;  // Can't const prop other types of pointers
+    return ConstantUInt::get(Type::UByteTy, 0);
   }
   static ConstantSInt *CastToShort (const Constant *V) {
-    if (V->isNullValue()) return ConstantSInt::get(Type::ShortTy, 0);
-    return 0;  // Can't const prop other types of pointers
+    return ConstantSInt::get(Type::ShortTy, 0);
   }
   static ConstantUInt *CastToUShort(const Constant *V) {
-    if (V->isNullValue()) return ConstantUInt::get(Type::UShortTy, 0);
-    return 0;  // Can't const prop other types of pointers
+    return ConstantUInt::get(Type::UShortTy, 0);
   }
   static ConstantSInt *CastToInt   (const Constant *V) {
-    if (V->isNullValue()) return ConstantSInt::get(Type::IntTy, 0);
-    return 0;  // Can't const prop other types of pointers
+    return ConstantSInt::get(Type::IntTy, 0);
   }
   static ConstantUInt *CastToUInt  (const Constant *V) {
-    if (V->isNullValue()) return ConstantUInt::get(Type::UIntTy, 0);
-    return 0;  // Can't const prop other types of pointers
+    return ConstantUInt::get(Type::UIntTy, 0);
   }
   static ConstantSInt *CastToLong  (const Constant *V) {
-    if (V->isNullValue()) return ConstantSInt::get(Type::LongTy, 0);
-    return 0;  // Can't const prop other types of pointers
+    return ConstantSInt::get(Type::LongTy, 0);
   }
   static ConstantUInt *CastToULong (const Constant *V) {
-    if (V->isNullValue()) return ConstantUInt::get(Type::ULongTy, 0);
-    return 0;  // Can't const prop other types of pointers
+    return ConstantUInt::get(Type::ULongTy, 0);
   }
   static ConstantFP   *CastToFloat (const Constant *V) {
-    if (V->isNullValue()) return ConstantFP::get(Type::FloatTy, 0);
-    return 0;  // Can't const prop other types of pointers
+    return ConstantFP::get(Type::FloatTy, 0);
   }
   static ConstantFP   *CastToDouble(const Constant *V) {
-    if (V->isNullValue()) return ConstantFP::get(Type::DoubleTy, 0);
-    return 0;  // Can't const prop other types of pointers
+    return ConstantFP::get(Type::DoubleTy, 0);
   }
 
-  static Constant *CastToPointer(const ConstantPointer *V,
+  static Constant *CastToPointer(const ConstantPointerNull *V,
                                  const PointerType *PTy) {
-    if (V->getType() == PTy)
-      return const_cast<ConstantPointer*>(V);  // Allow cast %PTy %ptr to %PTy
-    if (V->isNullValue())
-      return ConstantPointerNull::get(PTy);
-    return 0;  // Can't const prop other types of pointers
+    return ConstantPointerNull::get(PTy);
   }
 };
 
@@ -488,6 +501,12 @@ struct DirectRules : public TemplateRules<ConstantClass, SuperClass> {
     return ConstantBool::get(R);
   } 
 
+  static ConstantBool *EqualTo(const ConstantClass *V1,
+                               const ConstantClass *V2) {
+    bool R = (BuiltinType)V1->getValue() == (BuiltinType)V2->getValue();
+    return ConstantBool::get(R);
+  }
+
   static Constant *CastToPointer(const ConstantClass *V,
                                  const PointerType *PTy) {
     if (V->isNullValue())    // Is it a FP or Integral null value?
@@ -592,9 +611,9 @@ struct DirectFPRules
 };
 
 ConstRules &ConstRules::get(const Constant &V1, const Constant &V2) {
-  static EmptyRules   EmptyR;
-  static BoolRules    BoolR;
-  static PointerRules PointerR;
+  static EmptyRules       EmptyR;
+  static BoolRules        BoolR;
+  static NullPointerRules NullPointerR;
   static DirectIntRules<ConstantSInt,   signed char , &Type::SByteTy>  SByteR;
   static DirectIntRules<ConstantUInt, unsigned char , &Type::UByteTy>  UByteR;
   static DirectIntRules<ConstantSInt,   signed short, &Type::ShortTy>  ShortR;
@@ -606,7 +625,8 @@ ConstRules &ConstRules::get(const Constant &V1, const Constant &V2) {
   static DirectFPRules <ConstantFP  , float         , &Type::FloatTy>  FloatR;
   static DirectFPRules <ConstantFP  , double        , &Type::DoubleTy> DoubleR;
 
-  if (isa<ConstantExpr>(V1) || isa<ConstantExpr>(V2))
+  if (isa<ConstantExpr>(V1) || isa<ConstantExpr>(V2) ||
+      isa<ConstantPointerRef>(V1) || isa<ConstantPointerRef>(V2))
     return EmptyR;
 
   // FIXME: This assert doesn't work because shifts pass both operands in to
@@ -616,7 +636,7 @@ ConstRules &ConstRules::get(const Constant &V1, const Constant &V2) {
   switch (V1.getType()->getPrimitiveID()) {
   default: assert(0 && "Unknown value type for constant folding!");
   case Type::BoolTyID:    return BoolR;
-  case Type::PointerTyID: return PointerR;
+  case Type::PointerTyID: return NullPointerR;
   case Type::SByteTyID:   return SByteR;
   case Type::UByteTyID:   return UByteR;
   case Type::ShortTyID:   return ShortR;