Add lowering of ARM 4-element shuffles to multiple instructios via perfectshuffle...
[oota-llvm.git] / utils / PerfectShuffle / PerfectShuffle.cpp
1 //===-- PerfectShuffle.cpp - Perfect Shuffle Generator --------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file computes an optimal sequence of instructions for doing all shuffles
11 // of two 4-element vectors.  With a release build and when configured to emit
12 // an altivec instruction table, this takes about 30s to run on a 2.7Ghz
13 // PowerPC G5.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include <iostream>
18 #include <vector>
19 #include <cassert>
20 #include <cstdlib>
21 struct Operator;
22
23 // Masks are 4-nibble hex numbers.  Values 0-7 in any nibble means that it takes
24 // an element from that value of the input vectors.  A value of 8 means the
25 // entry is undefined.
26
27 // Mask manipulation functions.
28 static inline unsigned short MakeMask(unsigned V0, unsigned V1,
29                                       unsigned V2, unsigned V3) {
30   return (V0 << (3*4)) | (V1 << (2*4)) | (V2 << (1*4)) | (V3 << (0*4));
31 }
32
33 /// getMaskElt - Return element N of the specified mask.
34 static unsigned getMaskElt(unsigned Mask, unsigned Elt) {
35   return (Mask >> ((3-Elt)*4)) & 0xF;
36 }
37
38 static unsigned setMaskElt(unsigned Mask, unsigned Elt, unsigned NewVal) {
39   unsigned FieldShift = ((3-Elt)*4);
40   return (Mask & ~(0xF << FieldShift)) | (NewVal << FieldShift);
41 }
42
43 // Reject elements where the values are 9-15.
44 static bool isValidMask(unsigned short Mask) {
45   unsigned short UndefBits = Mask & 0x8888;
46   return (Mask & ((UndefBits >> 1)|(UndefBits>>2)|(UndefBits>>3))) == 0;
47 }
48
49 /// hasUndefElements - Return true if any of the elements in the mask are undefs
50 ///
51 static bool hasUndefElements(unsigned short Mask) {
52   return (Mask & 0x8888) != 0;
53 }
54
55 /// isOnlyLHSMask - Return true if this mask only refers to its LHS, not
56 /// including undef values..
57 static bool isOnlyLHSMask(unsigned short Mask) {
58   return (Mask & 0x4444) == 0;
59 }
60
61 /// getLHSOnlyMask - Given a mask that refers to its LHS and RHS, modify it to
62 /// refer to the LHS only (for when one argument value is passed into the same
63 /// function twice).
64 #if 0
65 static unsigned short getLHSOnlyMask(unsigned short Mask) {
66   return Mask & 0xBBBB;  // Keep only LHS and Undefs.
67 }
68 #endif
69
70 /// getCompressedMask - Turn a 16-bit uncompressed mask (where each elt uses 4
71 /// bits) into a compressed 13-bit mask, where each elt is multiplied by 9.
72 static unsigned getCompressedMask(unsigned short Mask) {
73   return getMaskElt(Mask, 0)*9*9*9 + getMaskElt(Mask, 1)*9*9 +
74          getMaskElt(Mask, 2)*9     + getMaskElt(Mask, 3);
75 }
76
77 static void PrintMask(unsigned i, std::ostream &OS) {
78   OS << "<" << (char)(getMaskElt(i, 0) == 8 ? 'u' : ('0'+getMaskElt(i, 0)))
79      << "," << (char)(getMaskElt(i, 1) == 8 ? 'u' : ('0'+getMaskElt(i, 1)))
80      << "," << (char)(getMaskElt(i, 2) == 8 ? 'u' : ('0'+getMaskElt(i, 2)))
81      << "," << (char)(getMaskElt(i, 3) == 8 ? 'u' : ('0'+getMaskElt(i, 3)))
82      << ">";
83 }
84
85 /// ShuffleVal - This represents a shufflevector operation.
86 struct ShuffleVal {
87   unsigned Cost;  // Number of instrs used to generate this value.
88   Operator *Op;   // The Operation used to generate this value.
89   unsigned short Arg0, Arg1;  // Input operands for this value.
90
91   ShuffleVal() : Cost(1000000) {}
92 };
93
94
95 /// ShufTab - This is the actual shuffle table that we are trying to generate.
96 ///
97 static ShuffleVal ShufTab[65536];
98
99 /// TheOperators - All of the operators that this target supports.
100 static std::vector<Operator*> TheOperators;
101
102 /// Operator - This is a vector operation that is available for use.
103 struct Operator {
104   unsigned short ShuffleMask;
105   unsigned short OpNum;
106   const char *Name;
107   unsigned Cost;
108
109   Operator(unsigned short shufflemask, const char *name, unsigned opnum,
110            unsigned cost = 1)
111     : ShuffleMask(shufflemask), OpNum(opnum), Name(name), Cost(cost) {
112     TheOperators.push_back(this);
113   }
114   ~Operator() {
115     assert(TheOperators.back() == this);
116     TheOperators.pop_back();
117   }
118
119   bool isOnlyLHSOperator() const {
120     return isOnlyLHSMask(ShuffleMask);
121   }
122
123   const char *getName() const { return Name; }
124   unsigned getCost() const { return Cost; }
125
126   unsigned short getTransformedMask(unsigned short LHSMask, unsigned RHSMask) {
127     // Extract the elements from LHSMask and RHSMask, as appropriate.
128     unsigned Result = 0;
129     for (unsigned i = 0; i != 4; ++i) {
130       unsigned SrcElt = (ShuffleMask >> (4*i)) & 0xF;
131       unsigned ResElt;
132       if (SrcElt < 4)
133         ResElt = getMaskElt(LHSMask, SrcElt);
134       else if (SrcElt < 8)
135         ResElt = getMaskElt(RHSMask, SrcElt-4);
136       else {
137         assert(SrcElt == 8 && "Bad src elt!");
138         ResElt = 8;
139       }
140       Result |= ResElt << (4*i);
141     }
142     return Result;
143   }
144 };
145
146 static const char *getZeroCostOpName(unsigned short Op) {
147   if (ShufTab[Op].Arg0 == 0x0123)
148     return "LHS";
149   else if (ShufTab[Op].Arg0 == 0x4567)
150     return "RHS";
151   else {
152     assert(0 && "bad zero cost operation");
153     abort();
154   }
155 }
156
157 static void PrintOperation(unsigned ValNo, unsigned short Vals[]) {
158   unsigned short ThisOp = Vals[ValNo];
159   std::cerr << "t" << ValNo;
160   PrintMask(ThisOp, std::cerr);
161   std::cerr << " = " << ShufTab[ThisOp].Op->getName() << "(";
162
163   if (ShufTab[ShufTab[ThisOp].Arg0].Cost == 0) {
164     std::cerr << getZeroCostOpName(ShufTab[ThisOp].Arg0);
165     PrintMask(ShufTab[ThisOp].Arg0, std::cerr);
166   } else {
167     // Figure out what tmp # it is.
168     for (unsigned i = 0; ; ++i)
169       if (Vals[i] == ShufTab[ThisOp].Arg0) {
170         std::cerr << "t" << i;
171         break;
172       }
173   }
174
175   if (!ShufTab[Vals[ValNo]].Op->isOnlyLHSOperator()) {
176     std::cerr << ", ";
177     if (ShufTab[ShufTab[ThisOp].Arg1].Cost == 0) {
178       std::cerr << getZeroCostOpName(ShufTab[ThisOp].Arg1);
179       PrintMask(ShufTab[ThisOp].Arg1, std::cerr);
180     } else {
181       // Figure out what tmp # it is.
182       for (unsigned i = 0; ; ++i)
183         if (Vals[i] == ShufTab[ThisOp].Arg1) {
184           std::cerr << "t" << i;
185           break;
186         }
187     }
188   }
189   std::cerr << ")  ";
190 }
191
192 static unsigned getNumEntered() {
193   unsigned Count = 0;
194   for (unsigned i = 0; i != 65536; ++i)
195     Count += ShufTab[i].Cost < 100;
196   return Count;
197 }
198
199 static void EvaluateOps(unsigned short Elt, unsigned short Vals[],
200                         unsigned &NumVals) {
201   if (ShufTab[Elt].Cost == 0) return;
202
203   // If this value has already been evaluated, it is free.  FIXME: match undefs.
204   for (unsigned i = 0, e = NumVals; i != e; ++i)
205     if (Vals[i] == Elt) return;
206
207   // Otherwise, get the operands of the value, then add it.
208   unsigned Arg0 = ShufTab[Elt].Arg0, Arg1 = ShufTab[Elt].Arg1;
209   if (ShufTab[Arg0].Cost)
210     EvaluateOps(Arg0, Vals, NumVals);
211   if (Arg0 != Arg1 && ShufTab[Arg1].Cost)
212     EvaluateOps(Arg1, Vals, NumVals);
213
214   Vals[NumVals++] = Elt;
215 }
216
217
218 int main() {
219   // Seed the table with accesses to the LHS and RHS.
220   ShufTab[0x0123].Cost = 0;
221   ShufTab[0x0123].Op = 0;
222   ShufTab[0x0123].Arg0 = 0x0123;
223   ShufTab[0x4567].Cost = 0;
224   ShufTab[0x4567].Op = 0;
225   ShufTab[0x4567].Arg0 = 0x4567;
226
227   // Seed the first-level of shuffles, shuffles whose inputs are the input to
228   // the vectorshuffle operation.
229   bool MadeChange = true;
230   unsigned OpCount = 0;
231   while (MadeChange) {
232     MadeChange = false;
233     ++OpCount;
234     std::cerr << "Starting iteration #" << OpCount << " with "
235               << getNumEntered() << " entries established.\n";
236
237     // Scan the table for two reasons: First, compute the maximum cost of any
238     // operation left in the table.  Second, make sure that values with undefs
239     // have the cheapest alternative that they match.
240     unsigned MaxCost = ShufTab[0].Cost;
241     for (unsigned i = 1; i != 0x8889; ++i) {
242       if (!isValidMask(i)) continue;
243       if (ShufTab[i].Cost > MaxCost)
244         MaxCost = ShufTab[i].Cost;
245
246       // If this value has an undef, make it be computed the cheapest possible
247       // way of any of the things that it matches.
248       if (hasUndefElements(i)) {
249         // This code is a little bit tricky, so here's the idea: consider some
250         // permutation, like 7u4u.  To compute the lowest cost for 7u4u, we
251         // need to take the minimum cost of all of 7[0-8]4[0-8], 81 entries.  If
252         // there are 3 undefs, the number rises to 729 entries we have to scan,
253         // and for the 4 undef case, we have to scan the whole table.
254         //
255         // Instead of doing this huge amount of scanning, we process the table
256         // entries *in order*, and use the fact that 'u' is 8, larger than any
257         // valid index.  Given an entry like 7u4u then, we only need to scan
258         // 7[0-7]4u - 8 entries.  We can get away with this, because we already
259         // know that each of 704u, 714u, 724u, etc contain the minimum value of
260         // all of the 704[0-8], 714[0-8] and 724[0-8] entries respectively.
261         unsigned UndefIdx;
262         if (i & 0x8000)
263           UndefIdx = 0;
264         else if (i & 0x0800)
265           UndefIdx = 1;
266         else if (i & 0x0080)
267           UndefIdx = 2;
268         else if (i & 0x0008)
269           UndefIdx = 3;
270         else
271           abort();
272
273         unsigned MinVal  = i;
274         unsigned MinCost = ShufTab[i].Cost;
275
276         // Scan the 8 entries.
277         for (unsigned j = 0; j != 8; ++j) {
278           unsigned NewElt = setMaskElt(i, UndefIdx, j);
279           if (ShufTab[NewElt].Cost < MinCost) {
280             MinCost = ShufTab[NewElt].Cost;
281             MinVal = NewElt;
282           }
283         }
284
285         // If we found something cheaper than what was here before, use it.
286         if (i != MinVal) {
287           MadeChange = true;
288           ShufTab[i] = ShufTab[MinVal];
289         }
290       }
291     }
292
293     for (unsigned LHS = 0; LHS != 0x8889; ++LHS) {
294       if (!isValidMask(LHS)) continue;
295       if (ShufTab[LHS].Cost > 1000) continue;
296
297       // If nothing involving this operand could possibly be cheaper than what
298       // we already have, don't consider it.
299       if (ShufTab[LHS].Cost + 1 >= MaxCost)
300         continue;
301
302       for (unsigned opnum = 0, e = TheOperators.size(); opnum != e; ++opnum) {
303         Operator *Op = TheOperators[opnum];
304
305         // Evaluate op(LHS,LHS)
306         unsigned ResultMask = Op->getTransformedMask(LHS, LHS);
307
308         unsigned Cost = ShufTab[LHS].Cost + Op->getCost();
309         if (Cost < ShufTab[ResultMask].Cost) {
310           ShufTab[ResultMask].Cost = Cost;
311           ShufTab[ResultMask].Op = Op;
312           ShufTab[ResultMask].Arg0 = LHS;
313           ShufTab[ResultMask].Arg1 = LHS;
314           MadeChange = true;
315         }
316
317         // If this is a two input instruction, include the op(x,y) cases.  If
318         // this is a one input instruction, skip this.
319         if (Op->isOnlyLHSOperator()) continue;
320
321         for (unsigned RHS = 0; RHS != 0x8889; ++RHS) {
322           if (!isValidMask(RHS)) continue;
323           if (ShufTab[RHS].Cost > 1000) continue;
324
325           // If nothing involving this operand could possibly be cheaper than
326           // what we already have, don't consider it.
327           if (ShufTab[RHS].Cost + 1 >= MaxCost)
328             continue;
329
330
331           // Evaluate op(LHS,RHS)
332           unsigned ResultMask = Op->getTransformedMask(LHS, RHS);
333
334           if (ShufTab[ResultMask].Cost <= OpCount ||
335               ShufTab[ResultMask].Cost <= ShufTab[LHS].Cost ||
336               ShufTab[ResultMask].Cost <= ShufTab[RHS].Cost)
337             continue;
338
339           // Figure out the cost to evaluate this, knowing that CSE's only need
340           // to be evaluated once.
341           unsigned short Vals[30];
342           unsigned NumVals = 0;
343           EvaluateOps(LHS, Vals, NumVals);
344           EvaluateOps(RHS, Vals, NumVals);
345
346           unsigned Cost = NumVals + Op->getCost();
347           if (Cost < ShufTab[ResultMask].Cost) {
348             ShufTab[ResultMask].Cost = Cost;
349             ShufTab[ResultMask].Op = Op;
350             ShufTab[ResultMask].Arg0 = LHS;
351             ShufTab[ResultMask].Arg1 = RHS;
352             MadeChange = true;
353           }
354         }
355       }
356     }
357   }
358
359   std::cerr << "Finished Table has " << getNumEntered()
360             << " entries established.\n";
361
362   unsigned CostArray[10] = { 0 };
363
364   // Compute a cost histogram.
365   for (unsigned i = 0; i != 65536; ++i) {
366     if (!isValidMask(i)) continue;
367     if (ShufTab[i].Cost > 9)
368       ++CostArray[9];
369     else
370       ++CostArray[ShufTab[i].Cost];
371   }
372
373   for (unsigned i = 0; i != 9; ++i)
374     if (CostArray[i])
375       std::cout << "// " << CostArray[i] << " entries have cost " << i << "\n";
376   if (CostArray[9])
377     std::cout << "// " << CostArray[9] << " entries have higher cost!\n";
378
379
380   // Build up the table to emit.
381   std::cout << "\n// This table is 6561*4 = 26244 bytes in size.\n";
382   std::cout << "static const unsigned PerfectShuffleTable[6561+1] = {\n";
383
384   for (unsigned i = 0; i != 0x8889; ++i) {
385     if (!isValidMask(i)) continue;
386
387     // CostSat - The cost of this operation saturated to two bits.
388     unsigned CostSat = ShufTab[i].Cost;
389     if (CostSat > 4) CostSat = 4;
390     if (CostSat == 0) CostSat = 1;
391     --CostSat;  // Cost is now between 0-3.
392
393     unsigned OpNum = ShufTab[i].Op ? ShufTab[i].Op->OpNum : 0;
394     assert(OpNum < 16 && "Too few bits to encode operation!");
395
396     unsigned LHS = getCompressedMask(ShufTab[i].Arg0);
397     unsigned RHS = getCompressedMask(ShufTab[i].Arg1);
398
399     // Encode this as 2 bits of saturated cost, 4 bits of opcodes, 13 bits of
400     // LHS, and 13 bits of RHS = 32 bits.
401     unsigned Val = (CostSat << 30) | (OpNum << 26) | (LHS << 13) | RHS;
402
403     std::cout << "  " << Val << "U,\t// ";
404     PrintMask(i, std::cout);
405     std::cout << ": Cost " << ShufTab[i].Cost;
406     std::cout << " " << (ShufTab[i].Op ? ShufTab[i].Op->getName() : "copy");
407     std::cout << " ";
408     if (ShufTab[ShufTab[i].Arg0].Cost == 0) {
409       std::cout << getZeroCostOpName(ShufTab[i].Arg0);
410     } else {
411       PrintMask(ShufTab[i].Arg0, std::cout);
412     }
413
414     if (ShufTab[i].Op && !ShufTab[i].Op->isOnlyLHSOperator()) {
415       std::cout << ", ";
416       if (ShufTab[ShufTab[i].Arg1].Cost == 0) {
417         std::cout << getZeroCostOpName(ShufTab[i].Arg1);
418       } else {
419         PrintMask(ShufTab[i].Arg1, std::cout);
420       }
421     }
422     std::cout << "\n";
423   }
424   std::cout << "  0\n};\n";
425
426   if (0) {
427     // Print out the table.
428     for (unsigned i = 0; i != 0x8889; ++i) {
429       if (!isValidMask(i)) continue;
430       if (ShufTab[i].Cost < 1000) {
431         PrintMask(i, std::cerr);
432         std::cerr << " - Cost " << ShufTab[i].Cost << " - ";
433
434         unsigned short Vals[30];
435         unsigned NumVals = 0;
436         EvaluateOps(i, Vals, NumVals);
437
438         for (unsigned j = 0, e = NumVals; j != e; ++j)
439           PrintOperation(j, Vals);
440         std::cerr << "\n";
441       }
442     }
443   }
444 }
445
446
447 #ifdef GENERATE_ALTIVEC
448
449 ///===---------------------------------------------------------------------===//
450 /// The altivec instruction definitions.  This is the altivec-specific part of
451 /// this file.
452 ///===---------------------------------------------------------------------===//
453
454 // Note that the opcode numbers here must match those in the PPC backend.
455 enum {
456   OP_COPY = 0,   // Copy, used for things like <u,u,u,3> to say it is <0,1,2,3>
457   OP_VMRGHW,
458   OP_VMRGLW,
459   OP_VSPLTISW0,
460   OP_VSPLTISW1,
461   OP_VSPLTISW2,
462   OP_VSPLTISW3,
463   OP_VSLDOI4,
464   OP_VSLDOI8,
465   OP_VSLDOI12
466 };
467
468 struct vmrghw : public Operator {
469   vmrghw() : Operator(0x0415, "vmrghw", OP_VMRGHW) {}
470 } the_vmrghw;
471
472 struct vmrglw : public Operator {
473   vmrglw() : Operator(0x2637, "vmrglw", OP_VMRGLW) {}
474 } the_vmrglw;
475
476 template<unsigned Elt>
477 struct vspltisw : public Operator {
478   vspltisw(const char *N, unsigned Opc)
479     : Operator(MakeMask(Elt, Elt, Elt, Elt), N, Opc) {}
480 };
481
482 vspltisw<0> the_vspltisw0("vspltisw0", OP_VSPLTISW0);
483 vspltisw<1> the_vspltisw1("vspltisw1", OP_VSPLTISW1);
484 vspltisw<2> the_vspltisw2("vspltisw2", OP_VSPLTISW2);
485 vspltisw<3> the_vspltisw3("vspltisw3", OP_VSPLTISW3);
486
487 template<unsigned N>
488 struct vsldoi : public Operator {
489   vsldoi(const char *Name, unsigned Opc)
490     : Operator(MakeMask(N&7, (N+1)&7, (N+2)&7, (N+3)&7), Name, Opc) {
491   }
492 };
493
494 vsldoi<1> the_vsldoi1("vsldoi4" , OP_VSLDOI4);
495 vsldoi<2> the_vsldoi2("vsldoi8" , OP_VSLDOI8);
496 vsldoi<3> the_vsldoi3("vsldoi12", OP_VSLDOI12);
497
498 #endif
499
500 #define GENERATE_NEON
501
502 #ifdef GENERATE_NEON
503 enum {
504   OP_COPY = 0,   // Copy, used for things like <u,u,u,3> to say it is <0,1,2,3>
505   OP_VREV,
506   OP_VDUP0,
507   OP_VDUP1,
508   OP_VDUP2,
509   OP_VDUP3,
510   OP_VEXT1,
511   OP_VEXT2,
512   OP_VEXT3,
513   OP_VUZPL, // VUZP, left result
514   OP_VUZPR, // VUZP, right result
515   OP_VZIPL, // VZIP, left result
516   OP_VZIPR, // VZIP, right result
517   OP_VTRNL, // VTRN, left result
518   OP_VTRNR  // VTRN, right result
519 };
520
521 struct vrev : public Operator {
522   vrev() : Operator(0x1032, "vrev", OP_VREV) {}
523 } the_vrev;
524
525 template<unsigned Elt>
526 struct vdup : public Operator {
527   vdup(const char *N, unsigned Opc)
528     : Operator(MakeMask(Elt, Elt, Elt, Elt), N, Opc) {}
529 };
530
531 vdup<0> the_vdup0("vdup0", OP_VDUP0);
532 vdup<1> the_vdup1("vdup1", OP_VDUP1);
533 vdup<2> the_vdup2("vdup2", OP_VDUP2);
534 vdup<3> the_vdup3("vdup3", OP_VDUP3);
535
536 template<unsigned N>
537 struct vext : public Operator {
538   vext(const char *Name, unsigned Opc)
539     : Operator(MakeMask(N&7, (N+1)&7, (N+2)&7, (N+3)&7), Name, Opc) {
540   }
541 };
542
543 vext<1> the_vext1("vext1", OP_VEXT1);
544 vext<2> the_vext2("vext2", OP_VEXT2);
545 vext<3> the_vext3("vext3", OP_VEXT3);
546
547 struct vuzpl : public Operator {
548   vuzpl() : Operator(0x1032, "vuzpl", OP_VUZPL, 2) {}
549 } the_vuzpl;
550
551 struct vuzpr : public Operator {
552   vuzpr() : Operator(0x4602, "vuzpr", OP_VUZPR, 2) {}
553 } the_vuzpr;
554
555 struct vzipl : public Operator {
556   vzipl() : Operator(0x6273, "vzipl", OP_VZIPL, 2) {}
557 } the_vzipl;
558
559 struct vzipr : public Operator {
560   vzipr() : Operator(0x4051, "vzipr", OP_VZIPR, 2) {}
561 } the_vzipr;
562
563 struct vtrnl : public Operator {
564   vtrnl() : Operator(0x5173, "vtrnl", OP_VTRNL, 2) {}
565 } the_vtrnl;
566
567 struct vtrnr : public Operator {
568   vtrnr() : Operator(0x4062, "vtrnr", OP_VTRNR, 2) {}
569 } the_vtrnr;
570
571 #endif