[LIR] Make the LoopIdiomRecognize pass get analyses essentially the same
[oota-llvm.git] / lib / Transforms / Scalar / Float2Int.cpp
1 //===- Float2Int.cpp - Demote floating point ops to work on integers ------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file implements the Float2Int pass, which aims to demote floating
11 // point operations to work on integers, where that is losslessly possible.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #define DEBUG_TYPE "float2int"
16 #include "llvm/ADT/APInt.h"
17 #include "llvm/ADT/APSInt.h"
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/ADT/EquivalenceClasses.h"
20 #include "llvm/ADT/MapVector.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/Analysis/AliasAnalysis.h"
23 #include "llvm/IR/ConstantRange.h"
24 #include "llvm/IR/Constants.h"
25 #include "llvm/IR/IRBuilder.h"
26 #include "llvm/IR/InstIterator.h"
27 #include "llvm/IR/Instructions.h"
28 #include "llvm/IR/Module.h"
29 #include "llvm/Pass.h"
30 #include "llvm/Support/Debug.h"
31 #include "llvm/Support/raw_ostream.h"
32 #include "llvm/Transforms/Scalar.h"
33 #include <deque>
34 #include <functional> // For std::function
35 using namespace llvm;
36
37 // The algorithm is simple. Start at instructions that convert from the
38 // float to the int domain: fptoui, fptosi and fcmp. Walk up the def-use
39 // graph, using an equivalence datastructure to unify graphs that interfere.
40 //
41 // Mappable instructions are those with an integer corrollary that, given
42 // integer domain inputs, produce an integer output; fadd, for example.
43 //
44 // If a non-mappable instruction is seen, this entire def-use graph is marked
45 // as non-transformable. If we see an instruction that converts from the 
46 // integer domain to FP domain (uitofp,sitofp), we terminate our walk.
47
48 /// The largest integer type worth dealing with.
49 static cl::opt<unsigned>
50 MaxIntegerBW("float2int-max-integer-bw", cl::init(64), cl::Hidden,
51              cl::desc("Max integer bitwidth to consider in float2int"
52                       "(default=64)"));
53
54 namespace {
55   struct Float2Int : public FunctionPass {
56     static char ID; // Pass identification, replacement for typeid
57     Float2Int() : FunctionPass(ID) {
58       initializeFloat2IntPass(*PassRegistry::getPassRegistry());
59     }
60
61     bool runOnFunction(Function &F) override;
62     void getAnalysisUsage(AnalysisUsage &AU) const override {
63       AU.setPreservesCFG();
64       AU.addPreserved<AliasAnalysis>();
65     }
66
67     void findRoots(Function &F, SmallPtrSet<Instruction*,8> &Roots);
68     ConstantRange seen(Instruction *I, ConstantRange R);
69     ConstantRange badRange();
70     ConstantRange unknownRange();
71     ConstantRange validateRange(ConstantRange R);
72     void walkBackwards(const SmallPtrSetImpl<Instruction*> &Roots);
73     void walkForwards();
74     bool validateAndTransform();
75     Value *convert(Instruction *I, Type *ToTy);
76     void cleanup();
77
78     MapVector<Instruction*, ConstantRange > SeenInsts;
79     SmallPtrSet<Instruction*,8> Roots;
80     EquivalenceClasses<Instruction*> ECs;
81     MapVector<Instruction*, Value*> ConvertedInsts;
82     LLVMContext *Ctx;
83   };
84 }
85
86 char Float2Int::ID = 0;
87 INITIALIZE_PASS(Float2Int, "float2int", "Float to int", false, false)
88
89 // Given a FCmp predicate, return a matching ICmp predicate if one
90 // exists, otherwise return BAD_ICMP_PREDICATE.
91 static CmpInst::Predicate mapFCmpPred(CmpInst::Predicate P) {
92   switch (P) {
93   case CmpInst::FCMP_OEQ:
94   case CmpInst::FCMP_UEQ:
95     return CmpInst::ICMP_EQ;
96   case CmpInst::FCMP_OGT:
97   case CmpInst::FCMP_UGT:
98     return CmpInst::ICMP_SGT;
99   case CmpInst::FCMP_OGE:
100   case CmpInst::FCMP_UGE:
101     return CmpInst::ICMP_SGE;
102   case CmpInst::FCMP_OLT:
103   case CmpInst::FCMP_ULT:
104     return CmpInst::ICMP_SLT;
105   case CmpInst::FCMP_OLE:
106   case CmpInst::FCMP_ULE:
107     return CmpInst::ICMP_SLE;
108   case CmpInst::FCMP_ONE:
109   case CmpInst::FCMP_UNE:
110     return CmpInst::ICMP_NE;
111   default:
112     return CmpInst::BAD_ICMP_PREDICATE;
113   }
114 }
115
116 // Given a floating point binary operator, return the matching
117 // integer version.
118 static Instruction::BinaryOps mapBinOpcode(unsigned Opcode) {
119   switch (Opcode) {
120   default: llvm_unreachable("Unhandled opcode!");
121   case Instruction::FAdd: return Instruction::Add;
122   case Instruction::FSub: return Instruction::Sub;
123   case Instruction::FMul: return Instruction::Mul;
124   }
125 }
126
127 // Find the roots - instructions that convert from the FP domain to
128 // integer domain.
129 void Float2Int::findRoots(Function &F, SmallPtrSet<Instruction*,8> &Roots) {
130   for (auto &I : instructions(F)) {
131     switch (I.getOpcode()) {
132     default: break;
133     case Instruction::FPToUI:
134     case Instruction::FPToSI:
135       Roots.insert(&I);
136       break;
137     case Instruction::FCmp:
138       if (mapFCmpPred(cast<CmpInst>(&I)->getPredicate()) != 
139           CmpInst::BAD_ICMP_PREDICATE)
140         Roots.insert(&I);
141       break;
142     }
143   }
144 }
145
146 // Helper - mark I as having been traversed, having range R.
147 ConstantRange Float2Int::seen(Instruction *I, ConstantRange R) {
148   DEBUG(dbgs() << "F2I: " << *I << ":" << R << "\n");
149   if (SeenInsts.find(I) != SeenInsts.end())
150     SeenInsts.find(I)->second = R;
151   else
152     SeenInsts.insert(std::make_pair(I, R));
153   return R;
154 }
155
156 // Helper - get a range representing a poison value.
157 ConstantRange Float2Int::badRange() {
158   return ConstantRange(MaxIntegerBW + 1, true);
159 }
160 ConstantRange Float2Int::unknownRange() {
161   return ConstantRange(MaxIntegerBW + 1, false);
162 }
163 ConstantRange Float2Int::validateRange(ConstantRange R) {
164   if (R.getBitWidth() > MaxIntegerBW + 1)
165     return badRange();
166   return R;
167 }
168
169 // The most obvious way to structure the search is a depth-first, eager
170 // search from each root. However, that require direct recursion and so
171 // can only handle small instruction sequences. Instead, we split the search
172 // up into two phases:
173 //   - walkBackwards:  A breadth-first walk of the use-def graph starting from
174 //                     the roots. Populate "SeenInsts" with interesting
175 //                     instructions and poison values if they're obvious and
176 //                     cheap to compute. Calculate the equivalance set structure
177 //                     while we're here too.
178 //   - walkForwards:  Iterate over SeenInsts in reverse order, so we visit
179 //                     defs before their uses. Calculate the real range info.
180
181 // Breadth-first walk of the use-def graph; determine the set of nodes 
182 // we care about and eagerly determine if some of them are poisonous.
183 void Float2Int::walkBackwards(const SmallPtrSetImpl<Instruction*> &Roots) {
184   std::deque<Instruction*> Worklist(Roots.begin(), Roots.end());
185   while (!Worklist.empty()) {
186     Instruction *I = Worklist.back();
187     Worklist.pop_back();
188
189     if (SeenInsts.find(I) != SeenInsts.end())
190       // Seen already.
191       continue;
192
193     switch (I->getOpcode()) {
194       // FIXME: Handle select and phi nodes.
195     default:
196       // Path terminated uncleanly.
197       seen(I, badRange());
198       break;
199
200     case Instruction::UIToFP: {
201       // Path terminated cleanly.
202       unsigned BW = I->getOperand(0)->getType()->getPrimitiveSizeInBits();
203       APInt Min = APInt::getMinValue(BW).zextOrSelf(MaxIntegerBW+1);
204       APInt Max = APInt::getMaxValue(BW).zextOrSelf(MaxIntegerBW+1);
205       seen(I, validateRange(ConstantRange(Min, Max)));
206       continue;
207     }
208
209     case Instruction::SIToFP: {
210       // Path terminated cleanly.
211       unsigned BW = I->getOperand(0)->getType()->getPrimitiveSizeInBits();
212       APInt SMin = APInt::getSignedMinValue(BW).sextOrSelf(MaxIntegerBW+1);
213       APInt SMax = APInt::getSignedMaxValue(BW).sextOrSelf(MaxIntegerBW+1);
214       seen(I, validateRange(ConstantRange(SMin, SMax)));
215       continue;
216     }
217
218     case Instruction::FAdd:
219     case Instruction::FSub:
220     case Instruction::FMul:
221     case Instruction::FPToUI:
222     case Instruction::FPToSI:
223     case Instruction::FCmp:
224       seen(I, unknownRange());
225       break;
226     }
227   
228     for (Value *O : I->operands()) {
229       if (Instruction *OI = dyn_cast<Instruction>(O)) {
230         // Unify def-use chains if they interfere.
231         ECs.unionSets(I, OI);
232         if (SeenInsts.find(I)->second != badRange())
233           Worklist.push_back(OI);
234       } else if (!isa<ConstantFP>(O)) {      
235         // Not an instruction or ConstantFP? we can't do anything.
236         seen(I, badRange());
237       }
238     }
239   }
240 }
241
242 // Walk forwards down the list of seen instructions, so we visit defs before
243 // uses.
244 void Float2Int::walkForwards() {
245   for (auto &It : make_range(SeenInsts.rbegin(), SeenInsts.rend())) {
246     if (It.second != unknownRange())
247       continue;
248
249     Instruction *I = It.first;
250     std::function<ConstantRange(ArrayRef<ConstantRange>)> Op;
251     switch (I->getOpcode()) {
252       // FIXME: Handle select and phi nodes.
253     default:
254     case Instruction::UIToFP:
255     case Instruction::SIToFP:
256       llvm_unreachable("Should have been handled in walkForwards!");
257
258     case Instruction::FAdd:
259       Op = [](ArrayRef<ConstantRange> Ops) {
260         assert(Ops.size() == 2 && "FAdd is a binary operator!");
261         return Ops[0].add(Ops[1]);
262       };
263       break;
264
265     case Instruction::FSub:
266       Op = [](ArrayRef<ConstantRange> Ops) {
267         assert(Ops.size() == 2 && "FSub is a binary operator!");
268         return Ops[0].sub(Ops[1]);
269       };
270       break;
271
272     case Instruction::FMul:
273       Op = [](ArrayRef<ConstantRange> Ops) {
274         assert(Ops.size() == 2 && "FMul is a binary operator!");
275         return Ops[0].multiply(Ops[1]);
276       };
277       break;
278
279     //
280     // Root-only instructions - we'll only see these if they're the
281     //                          first node in a walk.
282     //
283     case Instruction::FPToUI:
284     case Instruction::FPToSI:
285       Op = [](ArrayRef<ConstantRange> Ops) {
286         assert(Ops.size() == 1 && "FPTo[US]I is a unary operator!");
287         return Ops[0];
288       };
289       break;
290
291     case Instruction::FCmp:
292       Op = [](ArrayRef<ConstantRange> Ops) {
293         assert(Ops.size() == 2 && "FCmp is a binary operator!");
294         return Ops[0].unionWith(Ops[1]);
295       };
296       break;
297     }
298
299     bool Abort = false;
300     SmallVector<ConstantRange,4> OpRanges;
301     for (Value *O : I->operands()) {
302       if (Instruction *OI = dyn_cast<Instruction>(O)) {
303         assert(SeenInsts.find(OI) != SeenInsts.end() &&
304                "def not seen before use!");
305         OpRanges.push_back(SeenInsts.find(OI)->second);
306       } else if (ConstantFP *CF = dyn_cast<ConstantFP>(O)) {
307         // Work out if the floating point number can be losslessly represented
308         // as an integer.
309         // APFloat::convertToInteger(&Exact) purports to do what we want, but
310         // the exactness can be too precise. For example, negative zero can
311         // never be exactly converted to an integer.
312         //
313         // Instead, we ask APFloat to round itself to an integral value - this
314         // preserves sign-of-zero - then compare the result with the original.
315         //
316         APFloat F = CF->getValueAPF();
317
318         // First, weed out obviously incorrect values. Non-finite numbers
319         // can't be represented and neither can negative zero, unless 
320         // we're in fast math mode.
321         if (!F.isFinite() ||
322             (F.isZero() && F.isNegative() && isa<FPMathOperator>(I) &&
323              !I->hasNoSignedZeros())) {
324           seen(I, badRange());
325           Abort = true;
326           break;
327         }
328
329         APFloat NewF = F;
330         auto Res = NewF.roundToIntegral(APFloat::rmNearestTiesToEven);
331         if (Res != APFloat::opOK || NewF.compare(F) != APFloat::cmpEqual) {
332           seen(I, badRange());
333           Abort = true;
334           break;
335         }
336         // OK, it's representable. Now get it.
337         APSInt Int(MaxIntegerBW+1, false);
338         bool Exact;
339         CF->getValueAPF().convertToInteger(Int,
340                                            APFloat::rmNearestTiesToEven,
341                                            &Exact);
342         OpRanges.push_back(ConstantRange(Int));
343       } else {
344         llvm_unreachable("Should have already marked this as badRange!");
345       }
346     }
347
348     // Reduce the operands' ranges to a single range and return.
349     if (!Abort)
350       seen(I, Op(OpRanges));    
351   }
352 }
353
354 // If there is a valid transform to be done, do it.
355 bool Float2Int::validateAndTransform() {
356   bool MadeChange = false;
357
358   // Iterate over every disjoint partition of the def-use graph.
359   for (auto It = ECs.begin(), E = ECs.end(); It != E; ++It) {
360     ConstantRange R(MaxIntegerBW + 1, false);
361     bool Fail = false;
362     Type *ConvertedToTy = nullptr;
363
364     // For every member of the partition, union all the ranges together.
365     for (auto MI = ECs.member_begin(It), ME = ECs.member_end();
366          MI != ME; ++MI) {
367       Instruction *I = *MI;
368       auto SeenI = SeenInsts.find(I);
369       if (SeenI == SeenInsts.end())
370         continue;
371
372       R = R.unionWith(SeenI->second);
373       // We need to ensure I has no users that have not been seen.
374       // If it does, transformation would be illegal.
375       //
376       // Don't count the roots, as they terminate the graphs.
377       if (Roots.count(I) == 0) {
378         // Set the type of the conversion while we're here.
379         if (!ConvertedToTy)
380           ConvertedToTy = I->getType();
381         for (User *U : I->users()) {
382           Instruction *UI = dyn_cast<Instruction>(U);
383           if (!UI || SeenInsts.find(UI) == SeenInsts.end()) {
384             DEBUG(dbgs() << "F2I: Failing because of " << *U << "\n");
385             Fail = true;
386             break;
387           }
388         }
389       }
390       if (Fail)
391         break;
392     }
393
394     // If the set was empty, or we failed, or the range is poisonous,
395     // bail out.
396     if (ECs.member_begin(It) == ECs.member_end() || Fail ||
397         R.isFullSet() || R.isSignWrappedSet())
398       continue;
399     assert(ConvertedToTy && "Must have set the convertedtoty by this point!");
400     
401     // The number of bits required is the maximum of the upper and
402     // lower limits, plus one so it can be signed.
403     unsigned MinBW = std::max(R.getLower().getMinSignedBits(),
404                               R.getUpper().getMinSignedBits()) + 1;
405     DEBUG(dbgs() << "F2I: MinBitwidth=" << MinBW << ", R: " << R << "\n");
406
407     // If we've run off the realms of the exactly representable integers,
408     // the floating point result will differ from an integer approximation.
409
410     // Do we need more bits than are in the mantissa of the type we converted
411     // to? semanticsPrecision returns the number of mantissa bits plus one
412     // for the sign bit.
413     unsigned MaxRepresentableBits
414       = APFloat::semanticsPrecision(ConvertedToTy->getFltSemantics()) - 1;
415     if (MinBW > MaxRepresentableBits) {
416       DEBUG(dbgs() << "F2I: Value not guaranteed to be representable!\n");
417       continue;
418     }
419     if (MinBW > 64) {
420       DEBUG(dbgs() << "F2I: Value requires more than 64 bits to represent!\n");
421       continue;
422     }
423
424     // OK, R is known to be representable. Now pick a type for it.
425     // FIXME: Pick the smallest legal type that will fit.
426     Type *Ty = (MinBW > 32) ? Type::getInt64Ty(*Ctx) : Type::getInt32Ty(*Ctx);
427
428     for (auto MI = ECs.member_begin(It), ME = ECs.member_end();
429          MI != ME; ++MI)
430       convert(*MI, Ty);
431     MadeChange = true;
432   }
433
434   return MadeChange;
435 }
436
437 Value *Float2Int::convert(Instruction *I, Type *ToTy) {
438   if (ConvertedInsts.find(I) != ConvertedInsts.end())
439     // Already converted this instruction.
440     return ConvertedInsts[I];
441
442   SmallVector<Value*,4> NewOperands;
443   for (Value *V : I->operands()) {
444     // Don't recurse if we're an instruction that terminates the path.
445     if (I->getOpcode() == Instruction::UIToFP ||
446         I->getOpcode() == Instruction::SIToFP) {
447       NewOperands.push_back(V);
448     } else if (Instruction *VI = dyn_cast<Instruction>(V)) {
449       NewOperands.push_back(convert(VI, ToTy));
450     } else if (ConstantFP *CF = dyn_cast<ConstantFP>(V)) {
451       APSInt Val(ToTy->getPrimitiveSizeInBits(), /*IsUnsigned=*/false);
452       bool Exact;
453       CF->getValueAPF().convertToInteger(Val,
454                                          APFloat::rmNearestTiesToEven,
455                                          &Exact);
456       NewOperands.push_back(ConstantInt::get(ToTy, Val));
457     } else {
458       llvm_unreachable("Unhandled operand type?");
459     }
460   }
461
462   // Now create a new instruction.
463   IRBuilder<> IRB(I);
464   Value *NewV = nullptr;
465   switch (I->getOpcode()) {
466   default: llvm_unreachable("Unhandled instruction!");
467
468   case Instruction::FPToUI:
469     NewV = IRB.CreateZExtOrTrunc(NewOperands[0], I->getType());
470     break;
471
472   case Instruction::FPToSI:
473     NewV = IRB.CreateSExtOrTrunc(NewOperands[0], I->getType());
474     break;
475
476   case Instruction::FCmp: {
477     CmpInst::Predicate P = mapFCmpPred(cast<CmpInst>(I)->getPredicate());
478     assert(P != CmpInst::BAD_ICMP_PREDICATE && "Unhandled predicate!");
479     NewV = IRB.CreateICmp(P, NewOperands[0], NewOperands[1], I->getName());
480     break;
481   }
482
483   case Instruction::UIToFP:
484     NewV = IRB.CreateZExtOrTrunc(NewOperands[0], ToTy);
485     break;
486
487   case Instruction::SIToFP:
488     NewV = IRB.CreateSExtOrTrunc(NewOperands[0], ToTy);
489     break;
490
491   case Instruction::FAdd:
492   case Instruction::FSub:
493   case Instruction::FMul:
494     NewV = IRB.CreateBinOp(mapBinOpcode(I->getOpcode()),
495                            NewOperands[0], NewOperands[1],
496                            I->getName());
497     break;
498   }
499
500   // If we're a root instruction, RAUW.
501   if (Roots.count(I))
502     I->replaceAllUsesWith(NewV);
503
504   ConvertedInsts[I] = NewV;
505   return NewV;
506 }
507
508 // Perform dead code elimination on the instructions we just modified.
509 void Float2Int::cleanup() {
510   for (auto &I : make_range(ConvertedInsts.rbegin(), ConvertedInsts.rend()))
511     I.first->eraseFromParent();
512 }
513
514 bool Float2Int::runOnFunction(Function &F) {
515   if (skipOptnoneFunction(F))
516     return false;
517
518   DEBUG(dbgs() << "F2I: Looking at function " << F.getName() << "\n");
519   // Clear out all state.
520   ECs = EquivalenceClasses<Instruction*>();
521   SeenInsts.clear();
522   ConvertedInsts.clear();
523   Roots.clear();
524
525   Ctx = &F.getParent()->getContext();
526
527   findRoots(F, Roots);
528
529   walkBackwards(Roots);
530   walkForwards();
531
532   bool Modified = validateAndTransform();
533   if (Modified)
534     cleanup();
535   return Modified;
536 }
537
538 FunctionPass *llvm::createFloat2IntPass() {
539   return new Float2Int();
540 }
541