Encode a cost of zero as a cost of 1.
[oota-llvm.git] / utils / PerfectShuffle / PerfectShuffle.cpp
1 //===-- PerfectShuffle.cpp - Perfect Shuffle Generator --------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file was developed by Chris Lattner and is distributed under
6 // the University of Illinois Open Source License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file computes an optimal sequence of instructions for doing all shuffles
11 // of two 4-element vectors.  With a release build and when configured to emit
12 // an altivec instruction table, this takes about 30s to run on a 2.7Ghz
13 // PowerPC G5.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include <iostream>
18 #include <vector>
19
20 struct Operator;
21
22 // Masks are 4-nibble hex numbers.  Values 0-7 in any nibble means that it takes
23 // an element from that value of the input vectors.  A value of 8 means the 
24 // entry is undefined.
25
26 // Mask manipulation functions.
27 static inline unsigned short MakeMask(unsigned V0, unsigned V1, 
28                                       unsigned V2, unsigned V3) {
29   return (V0 << (3*4)) | (V1 << (2*4)) | (V2 << (1*4)) | (V3 << (0*4));
30 }
31
32 /// getMaskElt - Return element N of the specified mask.
33 static unsigned getMaskElt(unsigned Mask, unsigned Elt) {
34   return (Mask >> ((3-Elt)*4)) & 0xF;
35 }
36
37 static unsigned setMaskElt(unsigned Mask, unsigned Elt, unsigned NewVal) {
38   unsigned FieldShift = ((3-Elt)*4);
39   return (Mask & ~(0xF << FieldShift)) | (NewVal << FieldShift);
40 }
41
42 // Reject elements where the values are 9-15.
43 static bool isValidMask(unsigned short Mask) {
44   unsigned short UndefBits = Mask & 0x8888;
45   return (Mask & ((UndefBits >> 1)|(UndefBits>>2)|(UndefBits>>3))) == 0;
46 }
47
48 /// hasUndefElements - Return true if any of the elements in the mask are undefs
49 ///
50 static bool hasUndefElements(unsigned short Mask) {
51   return (Mask & 0x8888) != 0;
52 }
53
54 /// isOnlyLHSMask - Return true if this mask only refers to its LHS, not
55 /// including undef values..
56 static bool isOnlyLHSMask(unsigned short Mask) {
57   return (Mask & 0x4444) == 0;
58 }
59
60 /// getLHSOnlyMask - Given a mask that refers to its LHS and RHS, modify it to
61 /// refer to the LHS only (for when one argument value is passed into the same
62 /// function twice).
63 static unsigned short getLHSOnlyMask(unsigned short Mask) {
64   return Mask & 0xBBBB;  // Keep only LHS and Undefs.
65 }
66
67 /// getCompressedMask - Turn a 16-bit uncompressed mask (where each elt uses 4
68 /// bits) into a compressed 13-bit mask, where each elt is multiplied by 9.
69 static unsigned getCompressedMask(unsigned short Mask) {
70   return getMaskElt(Mask, 0)*9*9*9 + getMaskElt(Mask, 1)*9*9 + 
71          getMaskElt(Mask, 2)*9     + getMaskElt(Mask, 3);
72 }
73
74 static void PrintMask(unsigned i, std::ostream &OS) {
75   OS << "<" << (char)(getMaskElt(i, 0) == 8 ? 'u' : ('0'+getMaskElt(i, 0)))
76      << "," << (char)(getMaskElt(i, 1) == 8 ? 'u' : ('0'+getMaskElt(i, 1)))
77      << "," << (char)(getMaskElt(i, 2) == 8 ? 'u' : ('0'+getMaskElt(i, 2)))
78      << "," << (char)(getMaskElt(i, 3) == 8 ? 'u' : ('0'+getMaskElt(i, 3)))
79      << ">";
80 }
81
82 /// ShuffleVal - This represents a shufflevector operation.
83 struct ShuffleVal {
84   unsigned Cost;  // Number of instrs used to generate this value.
85   Operator *Op;   // The Operation used to generate this value.
86   unsigned short Arg0, Arg1;  // Input operands for this value.
87   
88   ShuffleVal() : Cost(1000000) {}
89 };
90
91
92 /// ShufTab - This is the actual shuffle table that we are trying to generate.
93 ///
94 static ShuffleVal ShufTab[65536];
95
96 /// TheOperators - All of the operators that this target supports.
97 static std::vector<Operator*> TheOperators;
98
99 /// Operator - This is a vector operation that is available for use.
100 struct Operator {
101   unsigned short ShuffleMask;
102   unsigned short OpNum;
103   const char *Name;
104   
105   Operator(unsigned short shufflemask, const char *name, unsigned opnum)
106     : ShuffleMask(shufflemask), OpNum(opnum), Name(name) {
107     TheOperators.push_back(this);
108   }
109   ~Operator() {
110     assert(TheOperators.back() == this);
111     TheOperators.pop_back();
112   }
113   
114   bool isOnlyLHSOperator() const {
115     return isOnlyLHSMask(ShuffleMask);
116   }
117   
118   const char *getName() const { return Name; }
119   
120   unsigned short getTransformedMask(unsigned short LHSMask, unsigned RHSMask) {
121     // Extract the elements from LHSMask and RHSMask, as appropriate.
122     unsigned Result = 0;
123     for (unsigned i = 0; i != 4; ++i) {
124       unsigned SrcElt = (ShuffleMask >> (4*i)) & 0xF;
125       unsigned ResElt;
126       if (SrcElt < 4)
127         ResElt = getMaskElt(LHSMask, SrcElt);
128       else if (SrcElt < 8)
129         ResElt = getMaskElt(RHSMask, SrcElt-4);
130       else {
131         assert(SrcElt == 8 && "Bad src elt!");
132         ResElt = 8;
133       }
134       Result |= ResElt << (4*i);
135     }
136     return Result;
137   }
138 };
139
140 static const char *getZeroCostOpName(unsigned short Op) {
141   if (ShufTab[Op].Arg0 == 0x0123)
142     return "LHS";
143   else if (ShufTab[Op].Arg0 == 0x4567)
144     return "RHS";
145   else {
146     assert(0 && "bad zero cost operation");
147     abort();
148   }
149 }
150
151 static void PrintOperation(unsigned ValNo, unsigned short Vals[]) {
152   unsigned short ThisOp = Vals[ValNo];
153   std::cerr << "t" << ValNo;
154   PrintMask(ThisOp, std::cerr);
155   std::cerr << " = " << ShufTab[ThisOp].Op->getName() << "(";
156     
157   if (ShufTab[ShufTab[ThisOp].Arg0].Cost == 0) {
158     std::cerr << getZeroCostOpName(ShufTab[ThisOp].Arg0);
159     PrintMask(ShufTab[ThisOp].Arg0, std::cerr);
160   } else {
161     // Figure out what tmp # it is.
162     for (unsigned i = 0; ; ++i)
163       if (Vals[i] == ShufTab[ThisOp].Arg0) {
164         std::cerr << "t" << i;
165         break;
166       }
167   }
168   
169   if (!ShufTab[Vals[ValNo]].Op->isOnlyLHSOperator()) {
170     std::cerr << ", ";
171     if (ShufTab[ShufTab[ThisOp].Arg1].Cost == 0) {
172       std::cerr << getZeroCostOpName(ShufTab[ThisOp].Arg1);
173       PrintMask(ShufTab[ThisOp].Arg1, std::cerr);
174     } else {
175       // Figure out what tmp # it is.
176       for (unsigned i = 0; ; ++i)
177         if (Vals[i] == ShufTab[ThisOp].Arg1) {
178           std::cerr << "t" << i;
179           break;
180         }
181     }
182   }
183   std::cerr << ")  ";
184 }
185
186 static unsigned getNumEntered() {
187   unsigned Count = 0;
188   for (unsigned i = 0; i != 65536; ++i)
189     Count += ShufTab[i].Cost < 100;
190   return Count;
191 }
192
193 static void EvaluateOps(unsigned short Elt, unsigned short Vals[], 
194                         unsigned &NumVals) {
195   if (ShufTab[Elt].Cost == 0) return;
196
197   // If this value has already been evaluated, it is free.  FIXME: match undefs.
198   for (unsigned i = 0, e = NumVals; i != e; ++i)
199     if (Vals[i] == Elt) return;
200   
201   // Otherwise, get the operands of the value, then add it.
202   unsigned Arg0 = ShufTab[Elt].Arg0, Arg1 = ShufTab[Elt].Arg1;
203   if (ShufTab[Arg0].Cost)
204     EvaluateOps(Arg0, Vals, NumVals);
205   if (Arg0 != Arg1 && ShufTab[Arg1].Cost)
206     EvaluateOps(Arg1, Vals, NumVals);
207   
208   Vals[NumVals++] = Elt;
209 }
210
211
212 int main() {
213   // Seed the table with accesses to the LHS and RHS.
214   ShufTab[0x0123].Cost = 0;
215   ShufTab[0x0123].Op = 0;
216   ShufTab[0x0123].Arg0 = 0x0123;
217   ShufTab[0x4567].Cost = 0;
218   ShufTab[0x4567].Op = 0;
219   ShufTab[0x4567].Arg0 = 0x4567;
220   
221   // Seed the first-level of shuffles, shuffles whose inputs are the input to
222   // the vectorshuffle operation.
223   bool MadeChange = true;
224   unsigned OpCount = 0;
225   while (MadeChange) {
226     MadeChange = false;
227     ++OpCount;
228     std::cerr << "Starting iteration #" << OpCount << " with "
229               << getNumEntered() << " entries established.\n";
230     
231     // Scan the table for two reasons: First, compute the maximum cost of any
232     // operation left in the table.  Second, make sure that values with undefs
233     // have the cheapest alternative that they match.
234     unsigned MaxCost = ShufTab[0].Cost;
235     for (unsigned i = 1; i != 0x8889; ++i) {
236       if (!isValidMask(i)) continue;
237       if (ShufTab[i].Cost > MaxCost)
238         MaxCost = ShufTab[i].Cost;
239       
240       // If this value has an undef, make it be computed the cheapest possible
241       // way of any of the things that it matches.
242       if (hasUndefElements(i)) {
243         // This code is a little bit tricky, so here's the idea: consider some
244         // permutation, like 7u4u.  To compute the lowest cost for 7u4u, we
245         // need to take the minimum cost of all of 7[0-8]4[0-8], 81 entries.  If
246         // there are 3 undefs, the number rises to 729 entries we have to scan,
247         // and for the 4 undef case, we have to scan the whole table.
248         //
249         // Instead of doing this huge amount of scanning, we process the table
250         // entries *in order*, and use the fact that 'u' is 8, larger than any
251         // valid index.  Given an entry like 7u4u then, we only need to scan
252         // 7[0-7]4u - 8 entries.  We can get away with this, because we already
253         // know that each of 704u, 714u, 724u, etc contain the minimum value of
254         // all of the 704[0-8], 714[0-8] and 724[0-8] entries respectively.
255         unsigned UndefIdx;
256         if (i & 0x8000)
257           UndefIdx = 0;
258         else if (i & 0x0800)
259           UndefIdx = 1;
260         else if (i & 0x0080)
261           UndefIdx = 2;
262         else if (i & 0x0008)
263           UndefIdx = 3;
264         else
265           abort();
266         
267         unsigned MinVal  = i;
268         unsigned MinCost = ShufTab[i].Cost;
269         
270         // Scan the 8 entries.
271         for (unsigned j = 0; j != 8; ++j) {
272           unsigned NewElt = setMaskElt(i, UndefIdx, j);
273           if (ShufTab[NewElt].Cost < MinCost) {
274             MinCost = ShufTab[NewElt].Cost;
275             MinVal = NewElt;
276           }
277         }
278         
279         // If we found something cheaper than what was here before, use it.
280         if (i != MinVal) {
281           MadeChange = true;
282           ShufTab[i] = ShufTab[MinVal];
283         }
284       } 
285     }
286     
287     for (unsigned LHS = 0; LHS != 0x8889; ++LHS) {
288       if (!isValidMask(LHS)) continue;
289       if (ShufTab[LHS].Cost > 1000) continue;
290
291       // If nothing involving this operand could possibly be cheaper than what
292       // we already have, don't consider it.
293       if (ShufTab[LHS].Cost + 1 >= MaxCost)
294         continue;
295         
296       for (unsigned opnum = 0, e = TheOperators.size(); opnum != e; ++opnum) {
297         Operator *Op = TheOperators[opnum];
298         unsigned short Mask = Op->ShuffleMask;
299
300         // Evaluate op(LHS,LHS)
301         unsigned ResultMask = Op->getTransformedMask(LHS, LHS);
302
303         unsigned Cost = ShufTab[LHS].Cost + 1;
304         if (Cost < ShufTab[ResultMask].Cost) {
305           ShufTab[ResultMask].Cost = Cost;
306           ShufTab[ResultMask].Op = Op;
307           ShufTab[ResultMask].Arg0 = LHS;
308           ShufTab[ResultMask].Arg1 = LHS;
309           MadeChange = true;
310         }
311         
312         // If this is a two input instruction, include the op(x,y) cases.  If
313         // this is a one input instruction, skip this.
314         if (Op->isOnlyLHSOperator()) continue;
315         
316         for (unsigned RHS = 0; RHS != 0x8889; ++RHS) {
317           if (!isValidMask(RHS)) continue;
318           if (ShufTab[RHS].Cost > 1000) continue;
319           
320           // If nothing involving this operand could possibly be cheaper than
321           // what we already have, don't consider it.
322           if (ShufTab[RHS].Cost + 1 >= MaxCost)
323             continue;
324           
325
326           // Evaluate op(LHS,RHS)
327           unsigned ResultMask = Op->getTransformedMask(LHS, RHS);
328
329           if (ShufTab[ResultMask].Cost <= OpCount ||
330               ShufTab[ResultMask].Cost <= ShufTab[LHS].Cost ||
331               ShufTab[ResultMask].Cost <= ShufTab[RHS].Cost)
332             continue;
333           
334           // Figure out the cost to evaluate this, knowing that CSE's only need
335           // to be evaluated once.
336           unsigned short Vals[30];
337           unsigned NumVals = 0;
338           EvaluateOps(LHS, Vals, NumVals);
339           EvaluateOps(RHS, Vals, NumVals);
340
341           unsigned Cost = NumVals + 1;
342           if (Cost < ShufTab[ResultMask].Cost) {
343             ShufTab[ResultMask].Cost = Cost;
344             ShufTab[ResultMask].Op = Op;
345             ShufTab[ResultMask].Arg0 = LHS;
346             ShufTab[ResultMask].Arg1 = RHS;
347             MadeChange = true;
348           }
349         }
350       }
351     }
352   }
353   
354   std::cerr << "Finished Table has " << getNumEntered()
355             << " entries established.\n";
356   
357   unsigned CostArray[10] = { 0 };
358
359   // Compute a cost histogram.
360   for (unsigned i = 0; i != 65536; ++i) {
361     if (!isValidMask(i)) continue;
362     if (ShufTab[i].Cost > 9)
363       ++CostArray[9];
364     else
365       ++CostArray[ShufTab[i].Cost];
366   }
367   
368   for (unsigned i = 0; i != 9; ++i)
369     if (CostArray[i])
370       std::cout << "// " << CostArray[i] << " entries have cost " << i << "\n";
371   if (CostArray[9])
372     std::cout << "// " << CostArray[9] << " entries have higher cost!\n";
373   
374   
375   // Build up the table to emit.
376   std::cout << "\n// This table is 6561*4 = 26244 bytes in size.\n";
377   std::cout << "static const unsigned PerfectShuffleTable[6561+1] = {\n";
378   
379   for (unsigned i = 0; i != 0x8889; ++i) {
380     if (!isValidMask(i)) continue;
381     
382     // CostSat - The cost of this operation saturated to two bits.
383     unsigned CostSat = ShufTab[i].Cost;
384     if (CostSat > 4) CostSat = 4;
385     if (CostSat == 0) CostSat = 1;
386     --CostSat;  // Cost is now between 0-3.
387     
388     unsigned OpNum = ShufTab[i].Op ? ShufTab[i].Op->OpNum : 0;
389     assert(OpNum < 16 && "Too few bits to encode operation!");
390     
391     unsigned LHS = getCompressedMask(ShufTab[i].Arg0);
392     unsigned RHS = getCompressedMask(ShufTab[i].Arg1);
393     
394     // Encode this as 2 bits of saturated cost, 4 bits of opcodes, 13 bits of
395     // LHS, and 13 bits of RHS = 32 bits.
396     unsigned Val = (CostSat << 30) | (OpNum << 26) | (LHS << 13) | RHS;
397
398     std::cout << "  " << Val << "U,\t// ";
399     PrintMask(i, std::cout);
400     std::cout << ": Cost " << ShufTab[i].Cost;
401     std::cout << " " << (ShufTab[i].Op ? ShufTab[i].Op->getName() : "copy");
402     std::cout << " ";
403     if (ShufTab[ShufTab[i].Arg0].Cost == 0) {
404       std::cout << getZeroCostOpName(ShufTab[i].Arg0);
405     } else {
406       PrintMask(ShufTab[i].Arg0, std::cout);
407     }
408
409     if (ShufTab[i].Op && !ShufTab[i].Op->isOnlyLHSOperator()) {
410       std::cout << ", ";
411       if (ShufTab[ShufTab[i].Arg1].Cost == 0) {
412         std::cout << getZeroCostOpName(ShufTab[i].Arg1);
413       } else {
414         PrintMask(ShufTab[i].Arg1, std::cout);
415       }
416     }
417     std::cout << "\n";
418   }  
419   std::cout << "  0\n};\n";
420
421   if (0) {
422     // Print out the table.
423     for (unsigned i = 0; i != 0x8889; ++i) {
424       if (!isValidMask(i)) continue;
425       if (ShufTab[i].Cost < 1000) {
426         PrintMask(i, std::cerr);
427         std::cerr << " - Cost " << ShufTab[i].Cost << " - ";
428         
429         unsigned short Vals[30];
430         unsigned NumVals = 0;
431         EvaluateOps(i, Vals, NumVals);
432
433         for (unsigned j = 0, e = NumVals; j != e; ++j)
434           PrintOperation(j, Vals);
435         std::cerr << "\n";
436       }
437     }
438   }
439 }
440
441
442 #define GENERATE_ALTIVEC
443
444 #ifdef GENERATE_ALTIVEC
445
446 ///===---------------------------------------------------------------------===//
447 /// The altivec instruction definitions.  This is the altivec-specific part of
448 /// this file.
449 ///===---------------------------------------------------------------------===//
450
451 // Note that the opcode numbers here must match those in the PPC backend.
452 enum {
453   OP_COPY = 0,   // Copy, used for things like <u,u,u,3> to say it is <0,1,2,3>
454   OP_VMRGHW,
455   OP_VMRGLW,
456   OP_VSPLTISW0,
457   OP_VSPLTISW1,
458   OP_VSPLTISW2,
459   OP_VSPLTISW3,
460   OP_VSLDOI4,
461   OP_VSLDOI8,
462   OP_VSLDOI12,
463 };
464
465 struct vmrghw : public Operator {
466   vmrghw() : Operator(0x0415, "vmrghw", OP_VMRGHW) {}
467 } the_vmrghw;
468
469 struct vmrglw : public Operator {
470   vmrglw() : Operator(0x2637, "vmrglw", OP_VMRGLW) {}
471 } the_vmrglw;
472
473 template<unsigned Elt>
474 struct vspltisw : public Operator {
475   vspltisw(const char *N, unsigned Opc)
476     : Operator(MakeMask(Elt, Elt, Elt, Elt), N, Opc) {}
477 };
478
479 vspltisw<0> the_vspltisw0("vspltisw0", OP_VSPLTISW0);
480 vspltisw<1> the_vspltisw1("vspltisw1", OP_VSPLTISW1);
481 vspltisw<2> the_vspltisw2("vspltisw2", OP_VSPLTISW2);
482 vspltisw<3> the_vspltisw3("vspltisw3", OP_VSPLTISW3);
483
484 template<unsigned N>
485 struct vsldoi : public Operator {
486   vsldoi(const char *Name, unsigned Opc)
487     : Operator(MakeMask(N&7, (N+1)&7, (N+2)&7, (N+3)&7), Name, Opc) {
488   }
489 };
490
491 vsldoi<1> the_vsldoi1("vsldoi4" , OP_VSLDOI4);
492 vsldoi<2> the_vsldoi2("vsldoi8" , OP_VSLDOI8);
493 vsldoi<3> the_vsldoi3("vsldoi12", OP_VSLDOI12);
494
495 #endif