Fix memory leak in the stackifier, due to the machinebasicblocks not holding
[oota-llvm.git] / lib / VMCore / ConstantFold.cpp
1 //===- ConstantHandling.cpp - Implement ConstantHandling.h ----------------===//
2 // 
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by the LLVM research group and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 // 
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements the various intrinsic operations, on constant values.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "llvm/ConstantHandling.h"
15 #include "llvm/iPHINode.h"
16 #include "llvm/InstrTypes.h"
17 #include "llvm/DerivedTypes.h"
18 #include "llvm/Support/GetElementPtrTypeIterator.h"
19 #include <cmath>
20 using namespace llvm;
21
22 // ConstantFoldInstruction - Attempt to constant fold the specified instruction.
23 // If successful, the constant result is returned, if not, null is returned.
24 //
25 Constant *llvm::ConstantFoldInstruction(Instruction *I) {
26   if (PHINode *PN = dyn_cast<PHINode>(I)) {
27     if (PN->getNumIncomingValues() == 0)
28       return Constant::getNullValue(PN->getType());
29     
30     Constant *Result = dyn_cast<Constant>(PN->getIncomingValue(0));
31     if (Result == 0) return 0;
32
33     // Handle PHI nodes specially here...
34     for (unsigned i = 1, e = PN->getNumIncomingValues(); i != e; ++i)
35       if (PN->getIncomingValue(i) != Result)
36         return 0;   // Not all the same incoming constants...
37
38     // If we reach here, all incoming values are the same constant.
39     return Result;
40   }
41
42   Constant *Op0 = 0;
43   Constant *Op1 = 0;
44
45   if (I->getNumOperands() != 0) {    // Get first operand if it's a constant...
46     Op0 = dyn_cast<Constant>(I->getOperand(0));
47     if (Op0 == 0) return 0;          // Not a constant?, can't fold
48
49     if (I->getNumOperands() != 1) {  // Get second operand if it's a constant...
50       Op1 = dyn_cast<Constant>(I->getOperand(1));
51       if (Op1 == 0) return 0;        // Not a constant?, can't fold
52     }
53   }
54
55   if (isa<BinaryOperator>(I))
56     return ConstantExpr::get(I->getOpcode(), Op0, Op1);    
57
58   switch (I->getOpcode()) {
59   case Instruction::Cast:
60     return ConstantExpr::getCast(Op0, I->getType());
61   case Instruction::Shl:
62   case Instruction::Shr:
63     return ConstantExpr::getShift(I->getOpcode(), Op0, Op1);
64   case Instruction::GetElementPtr: {
65     std::vector<Constant*> IdxList;
66     IdxList.reserve(I->getNumOperands()-1);
67     if (Op1) IdxList.push_back(Op1);
68     for (unsigned i = 2, e = I->getNumOperands(); i != e; ++i)
69       if (Constant *C = dyn_cast<Constant>(I->getOperand(i)))
70         IdxList.push_back(C);
71       else
72         return 0;  // Non-constant operand
73     return ConstantExpr::getGetElementPtr(Op0, IdxList);
74   }
75   default:
76     return 0;
77   }
78 }
79
80 static unsigned getSize(const Type *Ty) {
81   unsigned S = Ty->getPrimitiveSize();
82   return S ? S : 8;  // Treat pointers at 8 bytes
83 }
84
85 Constant *llvm::ConstantFoldCastInstruction(const Constant *V,
86                                             const Type *DestTy) {
87   if (V->getType() == DestTy) return (Constant*)V;
88
89   if (const ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
90     if (CE->getOpcode() == Instruction::Cast) {
91       Constant *Op = const_cast<Constant*>(CE->getOperand(0));
92       // Try to not produce a cast of a cast, which is almost always redundant.
93       if (!Op->getType()->isFloatingPoint() &&
94           !CE->getType()->isFloatingPoint() &&
95           !DestTy->getType()->isFloatingPoint()) {
96         unsigned S1 = getSize(Op->getType()), S2 = getSize(CE->getType());
97         unsigned S3 = getSize(DestTy);
98         if (Op->getType() == DestTy && S3 >= S2)
99           return Op;
100         if (S1 >= S2 && S2 >= S3)
101           return ConstantExpr::getCast(Op, DestTy);
102         if (S1 <= S2 && S2 >= S3 && S1 <= S3)
103           return ConstantExpr::getCast(Op, DestTy);
104       }
105     } else if (CE->getOpcode() == Instruction::GetElementPtr) {
106       // If all of the indexes in the GEP are null values, there is no pointer
107       // adjustment going on.  We might as well cast the source pointer.
108       bool isAllNull = true;
109       for (unsigned i = 1, e = CE->getNumOperands(); i != e; ++i)
110         if (!CE->getOperand(i)->isNullValue()) {
111           isAllNull = false;
112           break;
113         }
114       if (isAllNull)
115         return ConstantExpr::getCast(CE->getOperand(0), DestTy);
116     }
117
118   return ConstRules::get(*V, *V).castTo(V, DestTy);
119 }
120
121 Constant *llvm::ConstantFoldBinaryInstruction(unsigned Opcode,
122                                               const Constant *V1,
123                                               const Constant *V2) {
124   switch (Opcode) {
125   case Instruction::Add:     return *V1 + *V2;
126   case Instruction::Sub:     return *V1 - *V2;
127   case Instruction::Mul:     return *V1 * *V2;
128   case Instruction::Div:     return *V1 / *V2;
129   case Instruction::Rem:     return *V1 % *V2;
130   case Instruction::And:     return *V1 & *V2;
131   case Instruction::Or:      return *V1 | *V2;
132   case Instruction::Xor:     return *V1 ^ *V2;
133
134   case Instruction::SetEQ:   return *V1 == *V2;
135   case Instruction::SetNE:   return *V1 != *V2;
136   case Instruction::SetLE:   return *V1 <= *V2;
137   case Instruction::SetGE:   return *V1 >= *V2;
138   case Instruction::SetLT:   return *V1 <  *V2;
139   case Instruction::SetGT:   return *V1 >  *V2;
140   }
141   return 0;
142 }
143
144 Constant *llvm::ConstantFoldShiftInstruction(unsigned Opcode,
145                                              const Constant *V1, 
146                                              const Constant *V2) {
147   switch (Opcode) {
148   case Instruction::Shl:     return *V1 << *V2;
149   case Instruction::Shr:     return *V1 >> *V2;
150   default:                   return 0;
151   }
152 }
153
154 Constant *llvm::ConstantFoldGetElementPtr(const Constant *C,
155                                         const std::vector<Constant*> &IdxList) {
156   if (IdxList.size() == 0 ||
157       (IdxList.size() == 1 && IdxList[0]->isNullValue()))
158     return const_cast<Constant*>(C);
159
160   // TODO If C is null and all idx's are null, return null of the right type.
161
162
163   if (ConstantExpr *CE = dyn_cast<ConstantExpr>(const_cast<Constant*>(C))) {
164     // Combine Indices - If the source pointer to this getelementptr instruction
165     // is a getelementptr instruction, combine the indices of the two
166     // getelementptr instructions into a single instruction.
167     //
168     if (CE->getOpcode() == Instruction::GetElementPtr) {
169       const Type *LastTy = 0;
170       for (gep_type_iterator I = gep_type_begin(CE), E = gep_type_end(CE);
171            I != E; ++I)
172         LastTy = *I;
173
174       if (LastTy && isa<ArrayType>(LastTy)) {
175         std::vector<Constant*> NewIndices;
176         NewIndices.reserve(IdxList.size() + CE->getNumOperands());
177         for (unsigned i = 1, e = CE->getNumOperands()-1; i != e; ++i)
178           NewIndices.push_back(cast<Constant>(CE->getOperand(i)));
179
180         // Add the last index of the source with the first index of the new GEP.
181         // Make sure to handle the case when they are actually different types.
182         Constant *Combined =
183           ConstantExpr::get(Instruction::Add,
184                             ConstantExpr::getCast(IdxList[0], Type::LongTy),
185    ConstantExpr::getCast(CE->getOperand(CE->getNumOperands()-1), Type::LongTy));
186                             
187         NewIndices.push_back(Combined);
188         NewIndices.insert(NewIndices.end(), IdxList.begin()+1, IdxList.end());
189         return ConstantExpr::getGetElementPtr(CE->getOperand(0), NewIndices);
190       }
191     }
192
193     // Implement folding of:
194     //    int* getelementptr ([2 x int]* cast ([3 x int]* %X to [2 x int]*),
195     //                        long 0, long 0)
196     // To: int* getelementptr ([3 x int]* %X, long 0, long 0)
197     //
198     if (CE->getOpcode() == Instruction::Cast && IdxList.size() > 1 &&
199         IdxList[0]->isNullValue())
200       if (const PointerType *SPT = 
201           dyn_cast<PointerType>(CE->getOperand(0)->getType()))
202         if (const ArrayType *SAT = dyn_cast<ArrayType>(SPT->getElementType()))
203           if (const ArrayType *CAT =
204               dyn_cast<ArrayType>(cast<PointerType>(C->getType())->getElementType()))
205             if (CAT->getElementType() == SAT->getElementType())
206               return ConstantExpr::getGetElementPtr(
207                       (Constant*)CE->getOperand(0), IdxList);
208   }
209   return 0;
210 }
211
212
213 //===----------------------------------------------------------------------===//
214 //                             TemplateRules Class
215 //===----------------------------------------------------------------------===//
216 //
217 // TemplateRules - Implement a subclass of ConstRules that provides all 
218 // operations as noops.  All other rules classes inherit from this class so 
219 // that if functionality is needed in the future, it can simply be added here 
220 // and to ConstRules without changing anything else...
221 // 
222 // This class also provides subclasses with typesafe implementations of methods
223 // so that don't have to do type casting.
224 //
225 template<class ArgType, class SubClassName>
226 class TemplateRules : public ConstRules {
227
228   //===--------------------------------------------------------------------===//
229   // Redirecting functions that cast to the appropriate types
230   //===--------------------------------------------------------------------===//
231
232   virtual Constant *add(const Constant *V1, const Constant *V2) const { 
233     return SubClassName::Add((const ArgType *)V1, (const ArgType *)V2);  
234   }
235   virtual Constant *sub(const Constant *V1, const Constant *V2) const { 
236     return SubClassName::Sub((const ArgType *)V1, (const ArgType *)V2);  
237   }
238   virtual Constant *mul(const Constant *V1, const Constant *V2) const { 
239     return SubClassName::Mul((const ArgType *)V1, (const ArgType *)V2);  
240   }
241   virtual Constant *div(const Constant *V1, const Constant *V2) const { 
242     return SubClassName::Div((const ArgType *)V1, (const ArgType *)V2);  
243   }
244   virtual Constant *rem(const Constant *V1, const Constant *V2) const { 
245     return SubClassName::Rem((const ArgType *)V1, (const ArgType *)V2);  
246   }
247   virtual Constant *op_and(const Constant *V1, const Constant *V2) const { 
248     return SubClassName::And((const ArgType *)V1, (const ArgType *)V2);  
249   }
250   virtual Constant *op_or(const Constant *V1, const Constant *V2) const { 
251     return SubClassName::Or((const ArgType *)V1, (const ArgType *)V2);  
252   }
253   virtual Constant *op_xor(const Constant *V1, const Constant *V2) const { 
254     return SubClassName::Xor((const ArgType *)V1, (const ArgType *)V2);  
255   }
256   virtual Constant *shl(const Constant *V1, const Constant *V2) const { 
257     return SubClassName::Shl((const ArgType *)V1, (const ArgType *)V2);  
258   }
259   virtual Constant *shr(const Constant *V1, const Constant *V2) const { 
260     return SubClassName::Shr((const ArgType *)V1, (const ArgType *)V2);  
261   }
262
263   virtual ConstantBool *lessthan(const Constant *V1, 
264                                  const Constant *V2) const { 
265     return SubClassName::LessThan((const ArgType *)V1, (const ArgType *)V2);
266   }
267   virtual ConstantBool *equalto(const Constant *V1, 
268                                 const Constant *V2) const { 
269     return SubClassName::EqualTo((const ArgType *)V1, (const ArgType *)V2);
270   }
271
272   // Casting operators.  ick
273   virtual ConstantBool *castToBool(const Constant *V) const {
274     return SubClassName::CastToBool((const ArgType*)V);
275   }
276   virtual ConstantSInt *castToSByte(const Constant *V) const {
277     return SubClassName::CastToSByte((const ArgType*)V);
278   }
279   virtual ConstantUInt *castToUByte(const Constant *V) const {
280     return SubClassName::CastToUByte((const ArgType*)V);
281   }
282   virtual ConstantSInt *castToShort(const Constant *V) const {
283     return SubClassName::CastToShort((const ArgType*)V);
284   }
285   virtual ConstantUInt *castToUShort(const Constant *V) const {
286     return SubClassName::CastToUShort((const ArgType*)V);
287   }
288   virtual ConstantSInt *castToInt(const Constant *V) const {
289     return SubClassName::CastToInt((const ArgType*)V);
290   }
291   virtual ConstantUInt *castToUInt(const Constant *V) const {
292     return SubClassName::CastToUInt((const ArgType*)V);
293   }
294   virtual ConstantSInt *castToLong(const Constant *V) const {
295     return SubClassName::CastToLong((const ArgType*)V);
296   }
297   virtual ConstantUInt *castToULong(const Constant *V) const {
298     return SubClassName::CastToULong((const ArgType*)V);
299   }
300   virtual ConstantFP   *castToFloat(const Constant *V) const {
301     return SubClassName::CastToFloat((const ArgType*)V);
302   }
303   virtual ConstantFP   *castToDouble(const Constant *V) const {
304     return SubClassName::CastToDouble((const ArgType*)V);
305   }
306   virtual Constant *castToPointer(const Constant *V, 
307                                   const PointerType *Ty) const {
308     return SubClassName::CastToPointer((const ArgType*)V, Ty);
309   }
310
311   //===--------------------------------------------------------------------===//
312   // Default "noop" implementations
313   //===--------------------------------------------------------------------===//
314
315   static Constant *Add(const ArgType *V1, const ArgType *V2) { return 0; }
316   static Constant *Sub(const ArgType *V1, const ArgType *V2) { return 0; }
317   static Constant *Mul(const ArgType *V1, const ArgType *V2) { return 0; }
318   static Constant *Div(const ArgType *V1, const ArgType *V2) { return 0; }
319   static Constant *Rem(const ArgType *V1, const ArgType *V2) { return 0; }
320   static Constant *And(const ArgType *V1, const ArgType *V2) { return 0; }
321   static Constant *Or (const ArgType *V1, const ArgType *V2) { return 0; }
322   static Constant *Xor(const ArgType *V1, const ArgType *V2) { return 0; }
323   static Constant *Shl(const ArgType *V1, const ArgType *V2) { return 0; }
324   static Constant *Shr(const ArgType *V1, const ArgType *V2) { return 0; }
325   static ConstantBool *LessThan(const ArgType *V1, const ArgType *V2) {
326     return 0;
327   }
328   static ConstantBool *EqualTo(const ArgType *V1, const ArgType *V2) {
329     return 0;
330   }
331
332   // Casting operators.  ick
333   static ConstantBool *CastToBool  (const Constant *V) { return 0; }
334   static ConstantSInt *CastToSByte (const Constant *V) { return 0; }
335   static ConstantUInt *CastToUByte (const Constant *V) { return 0; }
336   static ConstantSInt *CastToShort (const Constant *V) { return 0; }
337   static ConstantUInt *CastToUShort(const Constant *V) { return 0; }
338   static ConstantSInt *CastToInt   (const Constant *V) { return 0; }
339   static ConstantUInt *CastToUInt  (const Constant *V) { return 0; }
340   static ConstantSInt *CastToLong  (const Constant *V) { return 0; }
341   static ConstantUInt *CastToULong (const Constant *V) { return 0; }
342   static ConstantFP   *CastToFloat (const Constant *V) { return 0; }
343   static ConstantFP   *CastToDouble(const Constant *V) { return 0; }
344   static Constant     *CastToPointer(const Constant *,
345                                      const PointerType *) {return 0;}
346 };
347
348
349
350 //===----------------------------------------------------------------------===//
351 //                             EmptyRules Class
352 //===----------------------------------------------------------------------===//
353 //
354 // EmptyRules provides a concrete base class of ConstRules that does nothing
355 //
356 struct EmptyRules : public TemplateRules<Constant, EmptyRules> {
357   static ConstantBool *EqualTo(const Constant *V1, const Constant *V2) {
358     if (V1 == V2) return ConstantBool::True;
359     return 0;
360   }
361 };
362
363
364
365 //===----------------------------------------------------------------------===//
366 //                              BoolRules Class
367 //===----------------------------------------------------------------------===//
368 //
369 // BoolRules provides a concrete base class of ConstRules for the 'bool' type.
370 //
371 struct BoolRules : public TemplateRules<ConstantBool, BoolRules> {
372
373   static ConstantBool *LessThan(const ConstantBool *V1, const ConstantBool *V2){
374     return ConstantBool::get(V1->getValue() < V2->getValue());
375   }
376
377   static ConstantBool *EqualTo(const Constant *V1, const Constant *V2) {
378     return ConstantBool::get(V1 == V2);
379   }
380
381   static Constant *And(const ConstantBool *V1, const ConstantBool *V2) {
382     return ConstantBool::get(V1->getValue() & V2->getValue());
383   }
384
385   static Constant *Or(const ConstantBool *V1, const ConstantBool *V2) {
386     return ConstantBool::get(V1->getValue() | V2->getValue());
387   }
388
389   static Constant *Xor(const ConstantBool *V1, const ConstantBool *V2) {
390     return ConstantBool::get(V1->getValue() ^ V2->getValue());
391   }
392
393   // Casting operators.  ick
394 #define DEF_CAST(TYPE, CLASS, CTYPE) \
395   static CLASS *CastTo##TYPE  (const ConstantBool *V) {    \
396     return CLASS::get(Type::TYPE##Ty, (CTYPE)(bool)V->getValue()); \
397   }
398
399   DEF_CAST(Bool  , ConstantBool, bool)
400   DEF_CAST(SByte , ConstantSInt, signed char)
401   DEF_CAST(UByte , ConstantUInt, unsigned char)
402   DEF_CAST(Short , ConstantSInt, signed short)
403   DEF_CAST(UShort, ConstantUInt, unsigned short)
404   DEF_CAST(Int   , ConstantSInt, signed int)
405   DEF_CAST(UInt  , ConstantUInt, unsigned int)
406   DEF_CAST(Long  , ConstantSInt, int64_t)
407   DEF_CAST(ULong , ConstantUInt, uint64_t)
408   DEF_CAST(Float , ConstantFP  , float)
409   DEF_CAST(Double, ConstantFP  , double)
410 #undef DEF_CAST
411 };
412
413
414 //===----------------------------------------------------------------------===//
415 //                            NullPointerRules Class
416 //===----------------------------------------------------------------------===//
417 //
418 // NullPointerRules provides a concrete base class of ConstRules for null
419 // pointers.
420 //
421 struct NullPointerRules : public TemplateRules<ConstantPointerNull,
422                                                NullPointerRules> {
423   static ConstantBool *EqualTo(const Constant *V1, const Constant *V2) {
424     return ConstantBool::True;  // Null pointers are always equal
425   }
426   static ConstantBool *CastToBool  (const Constant *V) {
427     return ConstantBool::False;
428   }
429   static ConstantSInt *CastToSByte (const Constant *V) {
430     return ConstantSInt::get(Type::SByteTy, 0);
431   }
432   static ConstantUInt *CastToUByte (const Constant *V) {
433     return ConstantUInt::get(Type::UByteTy, 0);
434   }
435   static ConstantSInt *CastToShort (const Constant *V) {
436     return ConstantSInt::get(Type::ShortTy, 0);
437   }
438   static ConstantUInt *CastToUShort(const Constant *V) {
439     return ConstantUInt::get(Type::UShortTy, 0);
440   }
441   static ConstantSInt *CastToInt   (const Constant *V) {
442     return ConstantSInt::get(Type::IntTy, 0);
443   }
444   static ConstantUInt *CastToUInt  (const Constant *V) {
445     return ConstantUInt::get(Type::UIntTy, 0);
446   }
447   static ConstantSInt *CastToLong  (const Constant *V) {
448     return ConstantSInt::get(Type::LongTy, 0);
449   }
450   static ConstantUInt *CastToULong (const Constant *V) {
451     return ConstantUInt::get(Type::ULongTy, 0);
452   }
453   static ConstantFP   *CastToFloat (const Constant *V) {
454     return ConstantFP::get(Type::FloatTy, 0);
455   }
456   static ConstantFP   *CastToDouble(const Constant *V) {
457     return ConstantFP::get(Type::DoubleTy, 0);
458   }
459
460   static Constant *CastToPointer(const ConstantPointerNull *V,
461                                  const PointerType *PTy) {
462     return ConstantPointerNull::get(PTy);
463   }
464 };
465
466
467 //===----------------------------------------------------------------------===//
468 //                             DirectRules Class
469 //===----------------------------------------------------------------------===//
470 //
471 // DirectRules provides a concrete base classes of ConstRules for a variety of
472 // different types.  This allows the C++ compiler to automatically generate our
473 // constant handling operations in a typesafe and accurate manner.
474 //
475 template<class ConstantClass, class BuiltinType, Type **Ty, class SuperClass>
476 struct DirectRules : public TemplateRules<ConstantClass, SuperClass> {
477   static Constant *Add(const ConstantClass *V1, const ConstantClass *V2) {
478     BuiltinType R = (BuiltinType)V1->getValue() + (BuiltinType)V2->getValue();
479     return ConstantClass::get(*Ty, R);
480   }
481
482   static Constant *Sub(const ConstantClass *V1, const ConstantClass *V2) {
483     BuiltinType R = (BuiltinType)V1->getValue() - (BuiltinType)V2->getValue();
484     return ConstantClass::get(*Ty, R);
485   }
486
487   static Constant *Mul(const ConstantClass *V1, const ConstantClass *V2) {
488     BuiltinType R = (BuiltinType)V1->getValue() * (BuiltinType)V2->getValue();
489     return ConstantClass::get(*Ty, R);
490   }
491
492   static Constant *Div(const ConstantClass *V1, const ConstantClass *V2) {
493     if (V2->isNullValue()) return 0;
494     BuiltinType R = (BuiltinType)V1->getValue() / (BuiltinType)V2->getValue();
495     return ConstantClass::get(*Ty, R);
496   }
497
498   static ConstantBool *LessThan(const ConstantClass *V1,
499                                 const ConstantClass *V2) {
500     bool R = (BuiltinType)V1->getValue() < (BuiltinType)V2->getValue();
501     return ConstantBool::get(R);
502   } 
503
504   static ConstantBool *EqualTo(const ConstantClass *V1,
505                                const ConstantClass *V2) {
506     bool R = (BuiltinType)V1->getValue() == (BuiltinType)V2->getValue();
507     return ConstantBool::get(R);
508   }
509
510   static Constant *CastToPointer(const ConstantClass *V,
511                                  const PointerType *PTy) {
512     if (V->isNullValue())    // Is it a FP or Integral null value?
513       return ConstantPointerNull::get(PTy);
514     return 0;  // Can't const prop other types of pointers
515   }
516
517   // Casting operators.  ick
518 #define DEF_CAST(TYPE, CLASS, CTYPE) \
519   static CLASS *CastTo##TYPE  (const ConstantClass *V) {    \
520     return CLASS::get(Type::TYPE##Ty, (CTYPE)(BuiltinType)V->getValue()); \
521   }
522
523   DEF_CAST(Bool  , ConstantBool, bool)
524   DEF_CAST(SByte , ConstantSInt, signed char)
525   DEF_CAST(UByte , ConstantUInt, unsigned char)
526   DEF_CAST(Short , ConstantSInt, signed short)
527   DEF_CAST(UShort, ConstantUInt, unsigned short)
528   DEF_CAST(Int   , ConstantSInt, signed int)
529   DEF_CAST(UInt  , ConstantUInt, unsigned int)
530   DEF_CAST(Long  , ConstantSInt, int64_t)
531   DEF_CAST(ULong , ConstantUInt, uint64_t)
532   DEF_CAST(Float , ConstantFP  , float)
533   DEF_CAST(Double, ConstantFP  , double)
534 #undef DEF_CAST
535 };
536
537
538 //===----------------------------------------------------------------------===//
539 //                           DirectIntRules Class
540 //===----------------------------------------------------------------------===//
541 //
542 // DirectIntRules provides implementations of functions that are valid on
543 // integer types, but not all types in general.
544 //
545 template <class ConstantClass, class BuiltinType, Type **Ty>
546 struct DirectIntRules
547   : public DirectRules<ConstantClass, BuiltinType, Ty,
548                        DirectIntRules<ConstantClass, BuiltinType, Ty> > {
549
550   static Constant *Div(const ConstantClass *V1, const ConstantClass *V2) {
551     if (V2->isNullValue()) return 0;
552     if (V2->isAllOnesValue() &&              // MIN_INT / -1
553         (BuiltinType)V1->getValue() == -(BuiltinType)V1->getValue())
554       return 0;
555     BuiltinType R = (BuiltinType)V1->getValue() / (BuiltinType)V2->getValue();
556     return ConstantClass::get(*Ty, R);
557   }
558
559   static Constant *Rem(const ConstantClass *V1,
560                        const ConstantClass *V2) {
561     if (V2->isNullValue()) return 0;         // X / 0
562     if (V2->isAllOnesValue() &&              // MIN_INT / -1
563         (BuiltinType)V1->getValue() == -(BuiltinType)V1->getValue())
564       return 0;
565     BuiltinType R = (BuiltinType)V1->getValue() % (BuiltinType)V2->getValue();
566     return ConstantClass::get(*Ty, R);
567   }
568
569   static Constant *And(const ConstantClass *V1, const ConstantClass *V2) {
570     BuiltinType R = (BuiltinType)V1->getValue() & (BuiltinType)V2->getValue();
571     return ConstantClass::get(*Ty, R);
572   }
573   static Constant *Or(const ConstantClass *V1, const ConstantClass *V2) {
574     BuiltinType R = (BuiltinType)V1->getValue() | (BuiltinType)V2->getValue();
575     return ConstantClass::get(*Ty, R);
576   }
577   static Constant *Xor(const ConstantClass *V1, const ConstantClass *V2) {
578     BuiltinType R = (BuiltinType)V1->getValue() ^ (BuiltinType)V2->getValue();
579     return ConstantClass::get(*Ty, R);
580   }
581
582   static Constant *Shl(const ConstantClass *V1, const ConstantClass *V2) {
583     BuiltinType R = (BuiltinType)V1->getValue() << (BuiltinType)V2->getValue();
584     return ConstantClass::get(*Ty, R);
585   }
586
587   static Constant *Shr(const ConstantClass *V1, const ConstantClass *V2) {
588     BuiltinType R = (BuiltinType)V1->getValue() >> (BuiltinType)V2->getValue();
589     return ConstantClass::get(*Ty, R);
590   }
591 };
592
593
594 //===----------------------------------------------------------------------===//
595 //                           DirectFPRules Class
596 //===----------------------------------------------------------------------===//
597 //
598 // DirectFPRules provides implementations of functions that are valid on
599 // floating point types, but not all types in general.
600 //
601 template <class ConstantClass, class BuiltinType, Type **Ty>
602 struct DirectFPRules
603   : public DirectRules<ConstantClass, BuiltinType, Ty,
604                        DirectFPRules<ConstantClass, BuiltinType, Ty> > {
605   static Constant *Rem(const ConstantClass *V1, const ConstantClass *V2) {
606     if (V2->isNullValue()) return 0;
607     BuiltinType Result = std::fmod((BuiltinType)V1->getValue(),
608                                    (BuiltinType)V2->getValue());
609     return ConstantClass::get(*Ty, Result);
610   }
611 };
612
613 ConstRules &ConstRules::get(const Constant &V1, const Constant &V2) {
614   static EmptyRules       EmptyR;
615   static BoolRules        BoolR;
616   static NullPointerRules NullPointerR;
617   static DirectIntRules<ConstantSInt,   signed char , &Type::SByteTy>  SByteR;
618   static DirectIntRules<ConstantUInt, unsigned char , &Type::UByteTy>  UByteR;
619   static DirectIntRules<ConstantSInt,   signed short, &Type::ShortTy>  ShortR;
620   static DirectIntRules<ConstantUInt, unsigned short, &Type::UShortTy> UShortR;
621   static DirectIntRules<ConstantSInt,   signed int  , &Type::IntTy>    IntR;
622   static DirectIntRules<ConstantUInt, unsigned int  , &Type::UIntTy>   UIntR;
623   static DirectIntRules<ConstantSInt,  int64_t      , &Type::LongTy>   LongR;
624   static DirectIntRules<ConstantUInt, uint64_t      , &Type::ULongTy>  ULongR;
625   static DirectFPRules <ConstantFP  , float         , &Type::FloatTy>  FloatR;
626   static DirectFPRules <ConstantFP  , double        , &Type::DoubleTy> DoubleR;
627
628   if (isa<ConstantExpr>(V1) || isa<ConstantExpr>(V2) ||
629       isa<ConstantPointerRef>(V1) || isa<ConstantPointerRef>(V2))
630     return EmptyR;
631
632   // FIXME: This assert doesn't work because shifts pass both operands in to
633   // check for constant exprs.  :(
634   //assert(V1.getType() == V2.getType() &&"Nonequal types to constant folder?");
635
636   switch (V1.getType()->getPrimitiveID()) {
637   default: assert(0 && "Unknown value type for constant folding!");
638   case Type::BoolTyID:    return BoolR;
639   case Type::PointerTyID: return NullPointerR;
640   case Type::SByteTyID:   return SByteR;
641   case Type::UByteTyID:   return UByteR;
642   case Type::ShortTyID:   return ShortR;
643   case Type::UShortTyID:  return UShortR;
644   case Type::IntTyID:     return IntR;
645   case Type::UIntTyID:    return UIntR;
646   case Type::LongTyID:    return LongR;
647   case Type::ULongTyID:   return ULongR;
648   case Type::FloatTyID:   return FloatR;
649   case Type::DoubleTyID:  return DoubleR;
650   }
651 }