move a bunch of constant folding code f rom Transforms/Utils/Local.cpp into
[oota-llvm.git] / lib / Analysis / ConstantFolding.cpp
1 //===-- ConstantFolding.cpp - Analyze constant folding possibilities ------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by the LLVM research group and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This family of functions determines the possibility of performing constant
11 // folding.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "llvm/Analysis/ConstantFolding.h"
16 #include "llvm/Constants.h"
17 #include "llvm/DerivedTypes.h"
18 #include "llvm/Function.h"
19 #include "llvm/Instructions.h"
20 #include "llvm/Intrinsics.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Support/GetElementPtrTypeIterator.h"
23 #include "llvm/Support/MathExtras.h"
24 #include <cerrno>
25 #include <cmath>
26 using namespace llvm;
27
28 /// ConstantFoldInstruction - Attempt to constant fold the specified
29 /// instruction.  If successful, the constant result is returned, if not, null
30 /// is returned.  Note that this function can only fail when attempting to fold
31 /// instructions like loads and stores, which have no constant expression form.
32 ///
33 Constant *llvm::ConstantFoldInstruction(Instruction *I, const TargetData *TD) {
34   if (PHINode *PN = dyn_cast<PHINode>(I)) {
35     if (PN->getNumIncomingValues() == 0)
36       return Constant::getNullValue(PN->getType());
37
38     Constant *Result = dyn_cast<Constant>(PN->getIncomingValue(0));
39     if (Result == 0) return 0;
40
41     // Handle PHI nodes specially here...
42     for (unsigned i = 1, e = PN->getNumIncomingValues(); i != e; ++i)
43       if (PN->getIncomingValue(i) != Result && PN->getIncomingValue(i) != PN)
44         return 0;   // Not all the same incoming constants...
45
46     // If we reach here, all incoming values are the same constant.
47     return Result;
48   }
49
50   // Scan the operand list, checking to see if they are all constants, if so,
51   // hand off to ConstantFoldInstOperands.
52   SmallVector<Constant*, 8> Ops;
53   for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i)
54     if (Constant *Op = dyn_cast<Constant>(I->getOperand(i)))
55       Ops.push_back(Op);
56     else
57       return 0;  // All operands not constant!
58
59   return ConstantFoldInstOperands(I, &Ops[0], Ops.size());
60 }
61
62 /// ConstantFoldInstOperands - Attempt to constant fold an instruction with the
63 /// specified opcode and operands.  If successful, the constant result is
64 /// returned, if not, null is returned.  Note that this function can fail when
65 /// attempting to fold instructions like loads and stores, which have no
66 /// constant expression form.
67 ///
68 Constant *llvm::ConstantFoldInstOperands(const Instruction* I, 
69                                          Constant** Ops, unsigned NumOps,
70                                          const TargetData *TD) {
71   unsigned Opc = I->getOpcode();
72   const Type *DestTy = I->getType();
73
74   // Handle easy binops first
75   if (isa<BinaryOperator>(I))
76     return ConstantExpr::get(Opc, Ops[0], Ops[1]);
77   
78   switch (Opc) {
79   default: return 0;
80   case Instruction::Call:
81     if (Function *F = dyn_cast<Function>(Ops[0]))
82       if (canConstantFoldCallTo(F))
83         return ConstantFoldCall(F, Ops+1, NumOps);
84     return 0;
85   case Instruction::ICmp:
86   case Instruction::FCmp:
87     return ConstantExpr::getCompare(cast<CmpInst>(I)->getPredicate(), Ops[0], 
88                                     Ops[1]);
89   case Instruction::Shl:
90   case Instruction::LShr:
91   case Instruction::AShr:
92     return ConstantExpr::get(Opc, Ops[0], Ops[1]);
93   case Instruction::Trunc:
94   case Instruction::ZExt:
95   case Instruction::SExt:
96   case Instruction::FPTrunc:
97   case Instruction::FPExt:
98   case Instruction::UIToFP:
99   case Instruction::SIToFP:
100   case Instruction::FPToUI:
101   case Instruction::FPToSI:
102   case Instruction::PtrToInt:
103   case Instruction::IntToPtr:
104   case Instruction::BitCast:
105     return ConstantExpr::getCast(Opc, Ops[0], DestTy);
106   case Instruction::Select:
107     return ConstantExpr::getSelect(Ops[0], Ops[1], Ops[2]);
108   case Instruction::ExtractElement:
109     return ConstantExpr::getExtractElement(Ops[0], Ops[1]);
110   case Instruction::InsertElement:
111     return ConstantExpr::getInsertElement(Ops[0], Ops[1], Ops[2]);
112   case Instruction::ShuffleVector:
113     return ConstantExpr::getShuffleVector(Ops[0], Ops[1], Ops[2]);
114   case Instruction::GetElementPtr:
115     return ConstantExpr::getGetElementPtr(Ops[0],
116                                           std::vector<Constant*>(Ops+1, 
117                                                                  Ops+NumOps));
118   }
119 }
120
121 /// ConstantFoldLoadThroughGEPConstantExpr - Given a constant and a
122 /// getelementptr constantexpr, return the constant value being addressed by the
123 /// constant expression, or null if something is funny and we can't decide.
124 Constant *llvm::ConstantFoldLoadThroughGEPConstantExpr(Constant *C, 
125                                                        ConstantExpr *CE) {
126   if (CE->getOperand(1) != Constant::getNullValue(CE->getOperand(1)->getType()))
127     return 0;  // Do not allow stepping over the value!
128   
129   // Loop over all of the operands, tracking down which value we are
130   // addressing...
131   gep_type_iterator I = gep_type_begin(CE), E = gep_type_end(CE);
132   for (++I; I != E; ++I)
133     if (const StructType *STy = dyn_cast<StructType>(*I)) {
134       ConstantInt *CU = cast<ConstantInt>(I.getOperand());
135       assert(CU->getZExtValue() < STy->getNumElements() &&
136              "Struct index out of range!");
137       unsigned El = (unsigned)CU->getZExtValue();
138       if (ConstantStruct *CS = dyn_cast<ConstantStruct>(C)) {
139         C = CS->getOperand(El);
140       } else if (isa<ConstantAggregateZero>(C)) {
141         C = Constant::getNullValue(STy->getElementType(El));
142       } else if (isa<UndefValue>(C)) {
143         C = UndefValue::get(STy->getElementType(El));
144       } else {
145         return 0;
146       }
147     } else if (ConstantInt *CI = dyn_cast<ConstantInt>(I.getOperand())) {
148       if (const ArrayType *ATy = dyn_cast<ArrayType>(*I)) {
149         if (CI->getZExtValue() >= ATy->getNumElements())
150          return 0;
151         if (ConstantArray *CA = dyn_cast<ConstantArray>(C))
152           C = CA->getOperand(CI->getZExtValue());
153         else if (isa<ConstantAggregateZero>(C))
154           C = Constant::getNullValue(ATy->getElementType());
155         else if (isa<UndefValue>(C))
156           C = UndefValue::get(ATy->getElementType());
157         else
158           return 0;
159       } else if (const PackedType *PTy = dyn_cast<PackedType>(*I)) {
160         if (CI->getZExtValue() >= PTy->getNumElements())
161           return 0;
162         if (ConstantPacked *CP = dyn_cast<ConstantPacked>(C))
163           C = CP->getOperand(CI->getZExtValue());
164         else if (isa<ConstantAggregateZero>(C))
165           C = Constant::getNullValue(PTy->getElementType());
166         else if (isa<UndefValue>(C))
167           C = UndefValue::get(PTy->getElementType());
168         else
169           return 0;
170       } else {
171         return 0;
172       }
173     } else {
174       return 0;
175     }
176   return C;
177 }
178
179
180 //===----------------------------------------------------------------------===//
181 //  Constant Folding for Calls
182 //
183
184 /// canConstantFoldCallTo - Return true if its even possible to fold a call to
185 /// the specified function.
186 bool
187 llvm::canConstantFoldCallTo(Function *F) {
188   const std::string &Name = F->getName();
189
190   switch (F->getIntrinsicID()) {
191   case Intrinsic::sqrt_f32:
192   case Intrinsic::sqrt_f64:
193   case Intrinsic::bswap_i16:
194   case Intrinsic::bswap_i32:
195   case Intrinsic::bswap_i64:
196   case Intrinsic::powi_f32:
197   case Intrinsic::powi_f64:
198   // FIXME: these should be constant folded as well
199   //case Intrinsic::ctpop_i8:
200   //case Intrinsic::ctpop_i16:
201   //case Intrinsic::ctpop_i32:
202   //case Intrinsic::ctpop_i64:
203   //case Intrinsic::ctlz_i8:
204   //case Intrinsic::ctlz_i16:
205   //case Intrinsic::ctlz_i32:
206   //case Intrinsic::ctlz_i64:
207   //case Intrinsic::cttz_i8:
208   //case Intrinsic::cttz_i16:
209   //case Intrinsic::cttz_i32:
210   //case Intrinsic::cttz_i64:
211     return true;
212   default: break;
213   }
214
215   switch (Name[0])
216   {
217     case 'a':
218       return Name == "acos" || Name == "asin" || Name == "atan" ||
219              Name == "atan2";
220     case 'c':
221       return Name == "ceil" || Name == "cos" || Name == "cosf" ||
222              Name == "cosh";
223     case 'e':
224       return Name == "exp";
225     case 'f':
226       return Name == "fabs" || Name == "fmod" || Name == "floor";
227     case 'l':
228       return Name == "log" || Name == "log10";
229     case 'p':
230       return Name == "pow";
231     case 's':
232       return Name == "sin" || Name == "sinh" || 
233              Name == "sqrt" || Name == "sqrtf";
234     case 't':
235       return Name == "tan" || Name == "tanh";
236     default:
237       return false;
238   }
239 }
240
241 static Constant *ConstantFoldFP(double (*NativeFP)(double), double V, 
242                                 const Type *Ty) {
243   errno = 0;
244   V = NativeFP(V);
245   if (errno == 0)
246     return ConstantFP::get(Ty, V);
247   errno = 0;
248   return 0;
249 }
250
251 /// ConstantFoldCall - Attempt to constant fold a call to the specified function
252 /// with the specified arguments, returning null if unsuccessful.
253 Constant *
254 llvm::ConstantFoldCall(Function *F, Constant** Operands, unsigned NumOperands) {
255   const std::string &Name = F->getName();
256   const Type *Ty = F->getReturnType();
257
258   if (NumOperands == 1) {
259     if (ConstantFP *Op = dyn_cast<ConstantFP>(Operands[0])) {
260       double V = Op->getValue();
261       switch (Name[0])
262       {
263         case 'a':
264           if (Name == "acos")
265             return ConstantFoldFP(acos, V, Ty);
266           else if (Name == "asin")
267             return ConstantFoldFP(asin, V, Ty);
268           else if (Name == "atan")
269             return ConstantFP::get(Ty, atan(V));
270           break;
271         case 'c':
272           if (Name == "ceil")
273             return ConstantFoldFP(ceil, V, Ty);
274           else if (Name == "cos")
275             return ConstantFP::get(Ty, cos(V));
276           else if (Name == "cosh")
277             return ConstantFP::get(Ty, cosh(V));
278           break;
279         case 'e':
280           if (Name == "exp")
281             return ConstantFP::get(Ty, exp(V));
282           break;
283         case 'f':
284           if (Name == "fabs")
285             return ConstantFP::get(Ty, fabs(V));
286           else if (Name == "floor")
287             return ConstantFoldFP(floor, V, Ty);
288           break;
289         case 'l':
290           if (Name == "log" && V > 0)
291             return ConstantFP::get(Ty, log(V));
292           else if (Name == "log10" && V > 0)
293             return ConstantFoldFP(log10, V, Ty);
294           else if (Name == "llvm.sqrt.f32" || Name == "llvm.sqrt.f64") {
295             if (V >= -0.0)
296               return ConstantFP::get(Ty, sqrt(V));
297             else // Undefined
298               return ConstantFP::get(Ty, 0.0);
299           }
300           break;
301         case 's':
302           if (Name == "sin")
303             return ConstantFP::get(Ty, sin(V));
304           else if (Name == "sinh")
305             return ConstantFP::get(Ty, sinh(V));
306           else if (Name == "sqrt" && V >= 0)
307             return ConstantFP::get(Ty, sqrt(V));
308           else if (Name == "sqrtf" && V >= 0)
309             return ConstantFP::get(Ty, sqrt((float)V));
310           break;
311         case 't':
312           if (Name == "tan")
313             return ConstantFP::get(Ty, tan(V));
314           else if (Name == "tanh")
315             return ConstantFP::get(Ty, tanh(V));
316           break;
317         default:
318           break;
319       }
320     } else if (ConstantInt *Op = dyn_cast<ConstantInt>(Operands[0])) {
321       uint64_t V = Op->getZExtValue();
322       if (Name == "llvm.bswap.i16")
323         return ConstantInt::get(Ty, ByteSwap_16(V));
324       else if (Name == "llvm.bswap.i32")
325         return ConstantInt::get(Ty, ByteSwap_32(V));
326       else if (Name == "llvm.bswap.i64")
327         return ConstantInt::get(Ty, ByteSwap_64(V));
328     }
329   } else if (NumOperands == 2) {
330     if (ConstantFP *Op1 = dyn_cast<ConstantFP>(Operands[0])) {
331       double Op1V = Op1->getValue();
332       if (ConstantFP *Op2 = dyn_cast<ConstantFP>(Operands[1])) {
333         double Op2V = Op2->getValue();
334
335         if (Name == "pow") {
336           errno = 0;
337           double V = pow(Op1V, Op2V);
338           if (errno == 0)
339             return ConstantFP::get(Ty, V);
340         } else if (Name == "fmod") {
341           errno = 0;
342           double V = fmod(Op1V, Op2V);
343           if (errno == 0)
344             return ConstantFP::get(Ty, V);
345         } else if (Name == "atan2") {
346           return ConstantFP::get(Ty, atan2(Op1V,Op2V));
347         }
348       } else if (ConstantInt *Op2C = dyn_cast<ConstantInt>(Operands[1])) {
349         if (Name == "llvm.powi.f32") {
350           return ConstantFP::get(Ty, std::pow((float)Op1V,
351                                               (int)Op2C->getZExtValue()));
352         } else if (Name == "llvm.powi.f64") {
353           return ConstantFP::get(Ty, std::pow((double)Op1V,
354                                               (int)Op2C->getZExtValue()));
355         }
356       }
357     }
358   }
359   return 0;
360 }
361