1 //=--------------------- AArch64InterleavedAccess.cpp ----------------------==//
3 // The LLVM Compiler Infrastructure
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
8 //===----------------------------------------------------------------------===//
// This file implements the AArch64InterleavedAccess pass, which identifies
// interleaved memory accesses and transforms them into AArch64 ldN/stN
// intrinsics (N = 2, 3, 4).
// An interleaved load reads data from memory into several vectors,
// DE-interleaving the data according to the factor. An interleaved store
// writes several vectors to memory, RE-interleaving the data according to the
// factor. The interleave factor is equal to the number of vectors. The AArch64
// backend supports interleave factors of 2, 3 and 4.
// E.g. Transform an interleaved load (Factor = 2):
//        %wide.vec = load <8 x i32>, <8 x i32>* %ptr
//        %v0 = shuffle %wide.vec, undef, <0, 2, 4, 6>  ; Extract even elements
//        %v1 = shuffle %wide.vec, undef, <1, 3, 5, 7>  ; Extract odd elements
// Into:
//        %ld2 = { <4 x i32>, <4 x i32> } call aarch64.neon.ld2(%ptr)
//        %v0 = extractvalue { <4 x i32>, <4 x i32> } %ld2, i32 0
//        %v1 = extractvalue { <4 x i32>, <4 x i32> } %ld2, i32 1
// E.g. Transform an interleaved store (Factor = 2):
//        %i.vec = shuffle %v0, %v1, <0, 4, 1, 5, 2, 6, 3, 7>  ; Interleaved vec
//        store <8 x i32> %i.vec, <8 x i32>* %ptr
// Into:
//        %v0 = shuffle %i.vec, undef, <0, 1, 2, 3>
//        %v1 = shuffle %i.vec, undef, <4, 5, 6, 7>
//        call void aarch64.neon.st2(%v0, %v1, %ptr)
37 //===----------------------------------------------------------------------===//
#include "AArch64.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;
51 #define DEBUG_TYPE "aarch64-interleaved-access"
53 static const unsigned MIN_FACTOR = 2;
54 static const unsigned MAX_FACTOR = 4;
57 static void initializeAArch64InterleavedAccessPass(PassRegistry &);
62 class AArch64InterleavedAccess : public FunctionPass {
66 AArch64InterleavedAccess() : FunctionPass(ID) {
67 initializeAArch64InterleavedAccessPass(*PassRegistry::getPassRegistry());
70 const char *getPassName() const override {
71 return "AArch64 Interleaved Access Pass";
74 bool runOnFunction(Function &F) override;
80 /// \brief Transform an interleaved load into ldN intrinsic.
81 bool matchInterleavedLoad(ShuffleVectorInst *SVI,
82 SmallSetVector<Instruction *, 32> &DeadInsts);
84 /// \brief Transform an interleaved store into stN intrinsic.
85 bool matchInterleavedStore(ShuffleVectorInst *SVI,
86 SmallSetVector<Instruction *, 32> &DeadInsts);
88 } // end anonymous namespace.
90 char AArch64InterleavedAccess::ID = 0;
92 INITIALIZE_PASS_BEGIN(AArch64InterleavedAccess, DEBUG_TYPE,
93 "AArch64 interleaved access Pass", false, false)
94 INITIALIZE_PASS_END(AArch64InterleavedAccess, DEBUG_TYPE,
95 "AArch64 interleaved access Pass", false, false)
97 FunctionPass *llvm::createAArch64InterleavedAccessPass() {
98 return new AArch64InterleavedAccess();
101 /// \brief Get a ldN/stN intrinsic according to the Factor (2, 3, or 4).
102 static Intrinsic::ID getLdNStNIntrinsic(unsigned Factor, bool IsLoad) {
103 static const Intrinsic::ID LoadInt[3] = {Intrinsic::aarch64_neon_ld2,
104 Intrinsic::aarch64_neon_ld3,
105 Intrinsic::aarch64_neon_ld4};
106 static const Intrinsic::ID StoreInt[3] = {Intrinsic::aarch64_neon_st2,
107 Intrinsic::aarch64_neon_st3,
108 Intrinsic::aarch64_neon_st4};
110 assert(Factor >= MIN_FACTOR && Factor <= MAX_FACTOR &&
111 "Invalid interleave factor");
114 return LoadInt[Factor - 2];
116 return StoreInt[Factor - 2];
119 /// \brief Check if the mask is a DE-interleave mask of the given factor
121 /// <Index, Index+Factor, ..., Index+(NumElts-1)*Factor>
122 static bool isDeInterleaveMaskOfFactor(ArrayRef<int> Mask, unsigned Factor,
124 // Check all potential start indices from 0 to (Factor - 1).
125 for (Index = 0; Index < Factor; Index++) {
128 // Check that elements are in ascending order by Factor.
129 for (; i < Mask.size(); i++)
130 if (Mask[i] >= 0 && static_cast<unsigned>(Mask[i]) != Index + i * Factor)
133 if (i == Mask.size())
140 /// \brief Check if the mask is a DE-interleave mask for an interleaved load.
142 /// E.g. DE-interleave masks (Factor = 2) could be:
143 /// <0, 2, 4, 6> (mask of index 0 to extract even elements)
144 /// <1, 3, 5, 7> (mask of index 1 to extract odd elements)
145 static bool isDeInterleaveMask(ArrayRef<int> Mask, unsigned &Factor,
147 unsigned NumElts = Mask.size();
151 for (Factor = MIN_FACTOR; Factor <= MAX_FACTOR; Factor++)
152 if (isDeInterleaveMaskOfFactor(Mask, Factor, Index))
158 /// \brief Check if the given mask \p Mask is RE-interleaved mask of the given
159 /// factor \p Factor.
161 /// I.e. <0, NumSubElts, ... , NumSubElts*(Factor - 1), 1, NumSubElts + 1, ...>
162 static bool isReInterleaveMaskOfFactor(ArrayRef<int> Mask, unsigned Factor) {
163 unsigned NumElts = Mask.size();
164 if (NumElts % Factor)
167 unsigned NumSubElts = NumElts / Factor;
168 if (!isPowerOf2_32(NumSubElts))
171 for (unsigned i = 0; i < NumSubElts; i++)
172 for (unsigned j = 0; j < Factor; j++)
173 if (Mask[i * Factor + j] >= 0 &&
174 static_cast<unsigned>(Mask[i * Factor + j]) != j * NumSubElts + i)
180 /// \brief Check if the mask is RE-interleave mask for an interleaved store.
182 /// E.g. The RE-interleave mask (Factor = 2) could be:
183 /// <0, 4, 1, 5, 2, 6, 3, 7>
184 static bool isReInterleaveMask(ArrayRef<int> Mask, unsigned &Factor) {
188 // Check potential Factors and return true if find a factor for the mask.
189 for (Factor = MIN_FACTOR; Factor <= MAX_FACTOR; Factor++)
190 if (isReInterleaveMaskOfFactor(Mask, Factor))
196 /// \brief Get a mask consisting of sequential integers starting from \p Start.
198 /// I.e. <Start, Start + 1, ..., Start + NumElts - 1>
199 static Constant *getSequentialMask(IRBuilder<> &Builder, unsigned Start,
201 SmallVector<Constant *, 16> Mask;
202 for (unsigned i = 0; i < NumElts; i++)
203 Mask.push_back(Builder.getInt32(Start + i));
205 return ConstantVector::get(Mask);
208 bool AArch64InterleavedAccess::matchInterleavedLoad(
209 ShuffleVectorInst *SVI, SmallSetVector<Instruction *, 32> &DeadInsts) {
210 if (DeadInsts.count(SVI))
213 LoadInst *LI = dyn_cast<LoadInst>(SVI->getOperand(0));
214 if (!LI || !LI->isSimple() || !isa<UndefValue>(SVI->getOperand(1)))
217 SmallVector<ShuffleVectorInst *, 4> Shuffles;
219 // Check if all users of this load are shufflevectors.
220 for (auto UI = LI->user_begin(), E = LI->user_end(); UI != E; UI++) {
221 ShuffleVectorInst *SV = dyn_cast<ShuffleVectorInst>(*UI);
225 Shuffles.push_back(SV);
228 // Check if the type of the first shuffle is legal.
229 VectorType *VecTy = Shuffles[0]->getType();
230 unsigned TypeSize = DL->getTypeAllocSizeInBits(VecTy);
231 if (TypeSize != 64 && TypeSize != 128)
234 // Check if the mask of the first shuffle is strided and get the start index.
235 unsigned Factor, Index;
236 if (!isDeInterleaveMask(Shuffles[0]->getShuffleMask(), Factor, Index))
239 // Holds the corresponding index for each strided shuffle.
240 SmallVector<unsigned, 4> Indices;
241 Indices.push_back(Index);
243 // Check if other shufflevectors are of the same type and factor
244 for (unsigned i = 1; i < Shuffles.size(); i++) {
245 if (Shuffles[i]->getType() != VecTy)
249 if (!isDeInterleaveMaskOfFactor(Shuffles[i]->getShuffleMask(), Factor,
253 Indices.push_back(Index);
256 DEBUG(dbgs() << "Found an interleaved load:" << *LI << "\n");
258 // A pointer vector can not be the return type of the ldN intrinsics. Need to
259 // load integer vectors first and then convert to pointer vectors.
260 Type *EltTy = VecTy->getVectorElementType();
261 if (EltTy->isPointerTy())
262 VecTy = VectorType::get(DL->getIntPtrType(EltTy),
263 VecTy->getVectorNumElements());
265 Type *PtrTy = VecTy->getPointerTo(LI->getPointerAddressSpace());
266 Type *Tys[2] = {VecTy, PtrTy};
268 Intrinsic::getDeclaration(M, getLdNStNIntrinsic(Factor, true), Tys);
270 IRBuilder<> Builder(LI);
271 Value *Ptr = Builder.CreateBitCast(LI->getPointerOperand(), PtrTy);
273 CallInst *LdN = Builder.CreateCall(LdNFunc, Ptr, "ldN");
274 DEBUG(dbgs() << " Created:" << *LdN << "\n");
276 // Replace each strided shufflevector with the corresponding vector loaded
278 for (unsigned i = 0; i < Shuffles.size(); i++) {
279 ShuffleVectorInst *SV = Shuffles[i];
280 unsigned Index = Indices[i];
282 Value *SubVec = Builder.CreateExtractValue(LdN, Index);
284 // Convert the integer vector to pointer vector if the element is pointer.
285 if (EltTy->isPointerTy())
286 SubVec = Builder.CreateIntToPtr(SubVec, SV->getType());
288 SV->replaceAllUsesWith(SubVec);
290 DEBUG(dbgs() << " Replaced:" << *SV << "\n"
291 << " With:" << *SubVec << "\n");
293 // Avoid analyzing it twice.
294 DeadInsts.insert(SV);
297 // Mark this load as dead.
298 DeadInsts.insert(LI);
302 bool AArch64InterleavedAccess::matchInterleavedStore(
303 ShuffleVectorInst *SVI, SmallSetVector<Instruction *, 32> &DeadInsts) {
304 if (DeadInsts.count(SVI) || !SVI->hasOneUse())
307 StoreInst *SI = dyn_cast<StoreInst>(SVI->user_back());
308 if (!SI || !SI->isSimple())
311 // Check if the mask is interleaved and get the interleave factor.
313 if (!isReInterleaveMask(SVI->getShuffleMask(), Factor))
316 VectorType *VecTy = SVI->getType();
317 unsigned NumSubElts = VecTy->getVectorNumElements() / Factor;
318 Type *EltTy = VecTy->getVectorElementType();
319 VectorType *SubVecTy = VectorType::get(EltTy, NumSubElts);
321 // Skip illegal vector types.
322 unsigned TypeSize = DL->getTypeAllocSizeInBits(SubVecTy);
323 if (TypeSize != 64 && TypeSize != 128)
326 DEBUG(dbgs() << "Found an interleaved store:" << *SI << "\n");
328 Value *Op0 = SVI->getOperand(0);
329 Value *Op1 = SVI->getOperand(1);
330 IRBuilder<> Builder(SI);
332 // StN intrinsics don't support pointer vectors as arguments. Convert pointer
333 // vectors to integer vectors.
334 if (EltTy->isPointerTy()) {
335 Type *IntTy = DL->getIntPtrType(EltTy);
337 dyn_cast<VectorType>(Op0->getType())->getVectorNumElements();
339 // The corresponding integer vector type of the same element size.
340 Type *IntVecTy = VectorType::get(IntTy, NumOpElts);
342 Op0 = Builder.CreatePtrToInt(Op0, IntVecTy);
343 Op1 = Builder.CreatePtrToInt(Op1, IntVecTy);
344 SubVecTy = VectorType::get(IntTy, NumSubElts);
347 Type *PtrTy = SubVecTy->getPointerTo(SI->getPointerAddressSpace());
348 Type *Tys[2] = {SubVecTy, PtrTy};
350 Intrinsic::getDeclaration(M, getLdNStNIntrinsic(Factor, false), Tys);
352 SmallVector<Value *, 5> Ops;
354 // Split the shufflevector operands into sub vectors for the new stN call.
355 for (unsigned i = 0; i < Factor; i++)
356 Ops.push_back(Builder.CreateShuffleVector(
357 Op0, Op1, getSequentialMask(Builder, NumSubElts * i, NumSubElts)));
359 Ops.push_back(Builder.CreateBitCast(SI->getPointerOperand(), PtrTy));
360 CallInst *StN = Builder.CreateCall(StNFunc, Ops);
362 (void)StN; // silence warning.
363 DEBUG(dbgs() << " Replaced:" << *SI << "'\n");
364 DEBUG(dbgs() << " with:" << *StN << "\n");
366 // Mark this shufflevector and store as dead.
367 DeadInsts.insert(SI);
368 DeadInsts.insert(SVI);
372 bool AArch64InterleavedAccess::runOnFunction(Function &F) {
373 DEBUG(dbgs() << "*** " << getPassName() << ": " << F.getName() << "\n");
376 DL = &M->getDataLayout();
378 // Holds dead instructions that will be erased later.
379 SmallSetVector<Instruction *, 32> DeadInsts;
380 bool Changed = false;
381 for (auto &I : inst_range(F)) {
382 if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(&I)) {
383 Changed |= matchInterleavedLoad(SVI, DeadInsts);
384 Changed |= matchInterleavedStore(SVI, DeadInsts);
388 for (auto I : DeadInsts)
389 I->eraseFromParent();