//===-- AMDGPURegBankLegalizeHelper.cpp -----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// Implements actual lowering algorithms for each ID that can be used in
/// Rule.OperandMapping. Similar to legalizer helper but with register banks.
//
//===----------------------------------------------------------------------===//

#include "AMDGPURegBankLegalizeHelper.h"
#include "AMDGPUGlobalISelUtils.h"
#include "AMDGPUInstrInfo.h"
#include "AMDGPURegBankLegalizeRules.h"
#include "AMDGPURegisterBankInfo.h"
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
#include "llvm/CodeGen/MachineUniformityAnalysis.h"

#define DEBUG_TYPE "amdgpu-regbanklegalize"

using namespace llvm;
using namespace AMDGPU;

RegBankLegalizeHelper::RegBankLegalizeHelper(
    MachineIRBuilder &B, const MachineUniformityInfo &MUI,
    const RegisterBankInfo &RBI, const RegBankLegalizeRules &RBLRules)
    : B(B), MRI(*B.getMRI()), MUI(MUI), RBI(RBI), RBLRules(RBLRules),
      SgprRB(&RBI.getRegBank(AMDGPU::SGPRRegBankID)),
      VgprRB(&RBI.getRegBank(AMDGPU::VGPRRegBankID)),
      VccRB(&RBI.getRegBank(AMDGPU::VCCRegBankID)) {}

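// Entry point: look up the lowering rule for MI's opcode, apply the selected
// register bank and LLT mapping to its defs and uses, then run the rule's
// lowering method.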
void RegBankLegalizeHelper::findRuleAndApplyMapping(MachineInstr &MI) {
  const SetOfRulesForOpcode &RuleSet = RBLRules.getRulesForOpc(MI);
  const RegBankLLTMapping &Mapping = RuleSet.findMappingForMI(MI, MRI, MUI);

  SmallSet<Register, 4> WaterfallSgprs;
  unsigned OpIdx = 0;
  if (Mapping.DstOpMapping.size() > 0) {
    B.setInsertPt(*MI.getParent(), std::next(MI.getIterator()));
    applyMappingDst(MI, OpIdx, Mapping.DstOpMapping);
  }
  if (Mapping.SrcOpMapping.size() > 0) {
    B.setInstr(MI);
    applyMappingSrc(MI, OpIdx, Mapping.SrcOpMapping, WaterfallSgprs);
  }

  lower(MI, Mapping, WaterfallSgprs);
}

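// Split a load that is too wide into consecutive smaller loads of the types
// in LLTBreakdown, at increasing byte offsets from the base pointer, then
// rebuild Dst from the loaded parts.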
void RegBankLegalizeHelper::splitLoad(MachineInstr &MI,
                                      ArrayRef<LLT> LLTBreakdown,
                                      LLT MergeTy) {
  MachineFunction &MF = B.getMF();
  assert(MI.getNumMemOperands() == 1);
  MachineMemOperand &BaseMMO = **MI.memoperands_begin();
  Register Dst = MI.getOperand(0).getReg();
  const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
  Register Base = MI.getOperand(1).getReg();
  LLT PtrTy = MRI.getType(Base);
  const RegisterBank *PtrRB = MRI.getRegBankOrNull(Base);
  LLT OffsetTy = LLT::scalar(PtrTy.getSizeInBits());
  SmallVector<Register, 4> LoadPartRegs;

  unsigned ByteOffset = 0;
  for (LLT PartTy : LLTBreakdown) {
    Register BasePlusOffset;
    if (ByteOffset == 0) {
      BasePlusOffset = Base;
    } else {
      auto Offset = B.buildConstant({PtrRB, OffsetTy}, ByteOffset);
      BasePlusOffset = B.buildPtrAdd({PtrRB, PtrTy}, Base, Offset).getReg(0);
    }
    auto *OffsetMMO = MF.getMachineMemOperand(&BaseMMO, ByteOffset, PartTy);
    auto LoadPart = B.buildLoad({DstRB, PartTy}, BasePlusOffset, *OffsetMMO);
    LoadPartRegs.push_back(LoadPart.getReg(0));
    ByteOffset += PartTy.getSizeInBytes();
  }

  if (!MergeTy.isValid()) {
    // Loads are all of the same size; concat or merge them together.
    B.buildMergeLikeInstr(Dst, LoadPartRegs);
  } else {
    // Loads are not all of the same size; unmerge them into smaller pieces
    // of MergeTy type, then merge the pieces into Dst.
    SmallVector<Register, 4> MergeTyParts;
    for (Register Reg : LoadPartRegs) {
      if (MRI.getType(Reg) == MergeTy) {
        MergeTyParts.push_back(Reg);
      } else {
        auto Unmerge = B.buildUnmerge({DstRB, MergeTy}, Reg);
        for (unsigned i = 0; i < Unmerge->getNumOperands() - 1; ++i)
          MergeTyParts.push_back(Unmerge.getReg(i));
      }
    }
    B.buildMergeLikeInstr(Dst, MergeTyParts);
  }
  MI.eraseFromParent();
}

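// Instead of splitting, perform a single load of a wider legal type WideTy,
// then shrink the result back to Dst: truncate for scalars, unmerge into
// MergeTy pieces and remerge for vectors.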
void RegBankLegalizeHelper::widenLoad(MachineInstr &MI, LLT WideTy,
                                      LLT MergeTy) {
  MachineFunction &MF = B.getMF();
  assert(MI.getNumMemOperands() == 1);
  MachineMemOperand &BaseMMO = **MI.memoperands_begin();
  Register Dst = MI.getOperand(0).getReg();
  const RegisterBank *DstRB = MRI.getRegBankOrNull(Dst);
  Register Base = MI.getOperand(1).getReg();

  MachineMemOperand *WideMMO = MF.getMachineMemOperand(&BaseMMO, 0, WideTy);
  auto WideLoad = B.buildLoad({DstRB, WideTy}, Base, *WideMMO);

  if (WideTy.isScalar()) {
    B.buildTrunc(Dst, WideLoad);
  } else {
    SmallVector<Register, 4> MergeTyParts;
    auto Unmerge = B.buildUnmerge({DstRB, MergeTy}, WideLoad);

    LLT DstTy = MRI.getType(Dst);
    unsigned NumElts = DstTy.getSizeInBits() / MergeTy.getSizeInBits();
    for (unsigned i = 0; i < NumElts; ++i) {
      MergeTyParts.push_back(Unmerge.getReg(i));
    }
    B.buildMergeLikeInstr(Dst, MergeTyParts);
  }
  MI.eraseFromParent();
}

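// Perform the lowering selected by the rule: rewrite MI into an equivalent
// sequence of register-bank-legal instructions.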
void RegBankLegalizeHelper::lower(MachineInstr &MI,
                                  const RegBankLLTMapping &Mapping,
                                  SmallSet<Register, 4> &WaterfallSgprs) {

  switch (Mapping.LoweringMethod) {
  case DoNotLower:
    return;
  case UniExtToSel: {
    LLT Ty = MRI.getType(MI.getOperand(0).getReg());
    auto True = B.buildConstant({SgprRB, Ty},
                                MI.getOpcode() == AMDGPU::G_SEXT ? -1 : 1);
    auto False = B.buildConstant({SgprRB, Ty}, 0);
    // The input to G_{Z|S}EXT is 'Legalizer legal' S1, most commonly a
    // compare. We build a select here; the S1 condition was already
    // any-extended to S32 and ANDed with 1 to clean the high bits by
    // Sgpr32AExtBoolInReg.
    B.buildSelect(MI.getOperand(0).getReg(), MI.getOperand(1).getReg(), True,
                  False);
    MI.eraseFromParent();
    return;
  }
  case Ext32To64: {
    const RegisterBank *RB = MRI.getRegBank(MI.getOperand(0).getReg());
    MachineInstrBuilder Hi;

    if (MI.getOpcode() == AMDGPU::G_ZEXT) {
      Hi = B.buildConstant({RB, S32}, 0);
    } else {
      // Replicate sign bit from 32-bit extended part.
      auto ShiftAmt = B.buildConstant({RB, S32}, 31);
      Hi = B.buildAShr({RB, S32}, MI.getOperand(1).getReg(), ShiftAmt);
    }

    B.buildMergeLikeInstr(MI.getOperand(0).getReg(),
                          {MI.getOperand(1).getReg(), Hi});
    MI.eraseFromParent();
    return;
  }
  case UniCstExt: {
    uint64_t ConstVal = MI.getOperand(1).getCImm()->getZExtValue();
    B.buildConstant(MI.getOperand(0).getReg(), ConstVal);

    MI.eraseFromParent();
    return;
  }
  case VgprToVccCopy: {
    Register Src = MI.getOperand(1).getReg();
    LLT Ty = MRI.getType(Src);
    // Take the lowest bit from each lane and put it in the lane mask.
    // Lower via compare, but clean the high bits first, since the compare
    // checks all bits in the register.
    Register BoolSrc = MRI.createVirtualRegister({VgprRB, Ty});
    if (Ty == S64) {
      auto Src64 = B.buildUnmerge(VgprRB_S32, Src);
      auto One = B.buildConstant(VgprRB_S32, 1);
      auto AndLo = B.buildAnd(VgprRB_S32, Src64.getReg(0), One);
      auto Zero = B.buildConstant(VgprRB_S32, 0);
      auto AndHi = B.buildAnd(VgprRB_S32, Src64.getReg(1), Zero);
      B.buildMergeLikeInstr(BoolSrc, {AndLo, AndHi});
    } else {
      assert(Ty == S32 || Ty == S16);
      auto One = B.buildConstant({VgprRB, Ty}, 1);
      B.buildAnd(BoolSrc, Src, One);
    }
    auto Zero = B.buildConstant({VgprRB, Ty}, 0);
    B.buildICmp(CmpInst::ICMP_NE, MI.getOperand(0).getReg(), BoolSrc, Zero);
    MI.eraseFromParent();
    return;
  }
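  // Split a 64-bit operation into two 32-bit halves: unmerge both sources,
  // apply the same opcode to each pair of halves, and merge the results.
  // Note: this assumes a binary op whose halves are independent, e.g. bitwise
  // G_AND/G_OR/G_XOR.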
  case SplitTo32: {
    auto Op1 = B.buildUnmerge(VgprRB_S32, MI.getOperand(1).getReg());
    auto Op2 = B.buildUnmerge(VgprRB_S32, MI.getOperand(2).getReg());
    unsigned Opc = MI.getOpcode();
    auto Lo = B.buildInstr(Opc, {VgprRB_S32}, {Op1.getReg(0), Op2.getReg(0)});
    auto Hi = B.buildInstr(Opc, {VgprRB_S32}, {Op1.getReg(1), Op2.getReg(1)});
    B.buildMergeLikeInstr(MI.getOperand(0).getReg(), {Lo, Hi});
    MI.eraseFromParent();
    break;
  }
  case SplitLoad: {
    LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
    unsigned Size = DstTy.getSizeInBits();
    // Even split to 128-bit loads.
    if (Size > 128) {
      LLT B128;
      if (DstTy.isVector()) {
        LLT EltTy = DstTy.getElementType();
        B128 = LLT::fixed_vector(128 / EltTy.getSizeInBits(), EltTy);
      } else {
        B128 = LLT::scalar(128);
      }
      if (Size / 128 == 2)
        splitLoad(MI, {B128, B128});
      else if (Size / 128 == 4)
        splitLoad(MI, {B128, B128, B128, B128});
      else {
        LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
        llvm_unreachable("SplitLoad type not supported for MI");
      }
    }
    // 64- and 32-bit loads.
    else if (DstTy == S96)
      splitLoad(MI, {S64, S32}, S32);
    else if (DstTy == V3S32)
      splitLoad(MI, {V2S32, S32}, S32);
    else if (DstTy == V6S16)
      splitLoad(MI, {V4S16, V2S16}, V2S16);
    else {
      LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
      llvm_unreachable("SplitLoad type not supported for MI");
    }
    break;
  }
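  // Widen odd-sized (96-bit) loads to the next power-of-two type and shrink
  // the result back in widenLoad.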
  case WidenLoad: {
    LLT DstTy = MRI.getType(MI.getOperand(0).getReg());
    if (DstTy == S96)
      widenLoad(MI, S128);
    else if (DstTy == V3S32)
      widenLoad(MI, V4S32, S32);
    else if (DstTy == V6S16)
      widenLoad(MI, V8S16, V2S16);
    else {
      LLVM_DEBUG(dbgs() << "MI: "; MI.dump(););
      llvm_unreachable("WidenLoad type not supported for MI");
    }
    break;
  }
  }

  // TODO: executeInWaterfallLoop(... WaterfallSgprs)
}

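// Return the single LLT required by a mapping ID, or an invalid LLT for IDs
// that allow several types (the B-type IDs handled by getBTyFromID).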
LLT RegBankLegalizeHelper::getTyFromID(RegBankLLTMappingApplyID ID) {
  switch (ID) {
  case Vcc:
  case UniInVcc:
    return LLT::scalar(1);
  case Sgpr16:
    return LLT::scalar(16);
  case Sgpr32:
  case Sgpr32Trunc:
  case Sgpr32AExt:
  case Sgpr32AExtBoolInReg:
  case Sgpr32SExt:
  case UniInVgprS32:
  case Vgpr32:
    return LLT::scalar(32);
  case Sgpr64:
  case Vgpr64:
    return LLT::scalar(64);
  case SgprP1:
  case VgprP1:
    return LLT::pointer(1, 64);
  case SgprP3:
  case VgprP3:
    return LLT::pointer(3, 32);
  case SgprP4:
  case VgprP4:
    return LLT::pointer(4, 64);
  case SgprP5:
  case VgprP5:
    return LLT::pointer(5, 32);
  case SgprV4S32:
  case VgprV4S32:
  case UniInVgprV4S32:
    return LLT::fixed_vector(4, 32);
  default:
    return LLT();
  }
}

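// B-type IDs fix only the bit width. Return Ty unchanged if it is one of the
// supported types of that width, otherwise an invalid LLT.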
LLT RegBankLegalizeHelper::getBTyFromID(RegBankLLTMappingApplyID ID, LLT Ty) {
  switch (ID) {
  case SgprB32:
  case VgprB32:
  case UniInVgprB32:
    if (Ty == LLT::scalar(32) || Ty == LLT::fixed_vector(2, 16) ||
        Ty == LLT::pointer(3, 32) || Ty == LLT::pointer(5, 32) ||
        Ty == LLT::pointer(6, 32))
      return Ty;
    return LLT();
  case SgprB64:
  case VgprB64:
  case UniInVgprB64:
    if (Ty == LLT::scalar(64) || Ty == LLT::fixed_vector(2, 32) ||
        Ty == LLT::fixed_vector(4, 16) || Ty == LLT::pointer(0, 64) ||
        Ty == LLT::pointer(1, 64) || Ty == LLT::pointer(4, 64))
      return Ty;
    return LLT();
  case SgprB96:
  case VgprB96:
  case UniInVgprB96:
    if (Ty == LLT::scalar(96) || Ty == LLT::fixed_vector(3, 32) ||
        Ty == LLT::fixed_vector(6, 16))
      return Ty;
    return LLT();
  case SgprB128:
  case VgprB128:
  case UniInVgprB128:
    if (Ty == LLT::scalar(128) || Ty == LLT::fixed_vector(4, 32) ||
        Ty == LLT::fixed_vector(2, 64))
      return Ty;
    return LLT();
  case SgprB256:
  case VgprB256:
  case UniInVgprB256:
    if (Ty == LLT::scalar(256) || Ty == LLT::fixed_vector(8, 32) ||
        Ty == LLT::fixed_vector(4, 64) || Ty == LLT::fixed_vector(16, 16))
      return Ty;
    return LLT();
  case SgprB512:
  case VgprB512:
  case UniInVgprB512:
    if (Ty == LLT::scalar(512) || Ty == LLT::fixed_vector(16, 32) ||
        Ty == LLT::fixed_vector(8, 64))
      return Ty;
    return LLT();
  default:
    return LLT();
  }
}

const RegisterBank *
RegBankLegalizeHelper::getRegBankFromID(RegBankLLTMappingApplyID ID) {
  switch (ID) {
  case Vcc:
    return VccRB;
  case Sgpr16:
  case Sgpr32:
  case Sgpr64:
  case SgprP1:
  case SgprP3:
  case SgprP4:
  case SgprP5:
  case SgprV4S32:
  case SgprB32:
  case SgprB64:
  case SgprB96:
  case SgprB128:
  case SgprB256:
  case SgprB512:
  case UniInVcc:
  case UniInVgprS32:
  case UniInVgprV4S32:
  case UniInVgprB32:
  case UniInVgprB64:
  case UniInVgprB96:
  case UniInVgprB128:
  case UniInVgprB256:
  case UniInVgprB512:
  case Sgpr32Trunc:
  case Sgpr32AExt:
  case Sgpr32AExtBoolInReg:
  case Sgpr32SExt:
    return SgprRB;
  case Vgpr32:
  case Vgpr64:
  case VgprP1:
  case VgprP3:
  case VgprP4:
  case VgprP5:
  case VgprV4S32:
  case VgprB32:
  case VgprB64:
  case VgprB96:
  case VgprB128:
  case VgprB256:
  case VgprB512:
    return VgprRB;
  default:
    return nullptr;
  }
}

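// Process def operands (operands 0..NumDefs-1), with the insert point set
// right after MI. IDs that already match the def only assert bank and type;
// UniInVcc/UniInVgpr* and Sgpr32Trunc retype the def and emit repair code
// (trunc, readanylane, or copy-from-vcc) that the old uses read from.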
void RegBankLegalizeHelper::applyMappingDst(
    MachineInstr &MI, unsigned &OpIdx,
    const SmallVectorImpl<RegBankLLTMappingApplyID> &MethodIDs) {
  // Defs start from operand 0.
  for (; OpIdx < MethodIDs.size(); ++OpIdx) {
    if (MethodIDs[OpIdx] == None)
      continue;
    MachineOperand &Op = MI.getOperand(OpIdx);
    Register Reg = Op.getReg();
    LLT Ty = MRI.getType(Reg);
    [[maybe_unused]] const RegisterBank *RB = MRI.getRegBank(Reg);

    switch (MethodIDs[OpIdx]) {
    // vcc, sgpr and vgpr scalars, pointers and vectors
    case Vcc:
    case Sgpr16:
    case Sgpr32:
    case Sgpr64:
    case SgprP1:
    case SgprP3:
    case SgprP4:
    case SgprP5:
    case SgprV4S32:
    case Vgpr32:
    case Vgpr64:
    case VgprP1:
    case VgprP3:
    case VgprP4:
    case VgprP5:
    case VgprV4S32: {
      assert(Ty == getTyFromID(MethodIDs[OpIdx]));
      assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
      break;
    }
    // sgpr and vgpr B-types
    case SgprB32:
    case SgprB64:
    case SgprB96:
    case SgprB128:
    case SgprB256:
    case SgprB512:
    case VgprB32:
    case VgprB64:
    case VgprB96:
    case VgprB128:
    case VgprB256:
    case VgprB512: {
      assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
      assert(RB == getRegBankFromID(MethodIDs[OpIdx]));
      break;
    }
    // uniform in vcc/vgpr: scalars, vectors and B-types
    case UniInVcc: {
      assert(Ty == S1);
      assert(RB == SgprRB);
      Register NewDst = MRI.createVirtualRegister(VccRB_S1);
      Op.setReg(NewDst);
      auto CopyS32_Vcc =
          B.buildInstr(AMDGPU::G_AMDGPU_COPY_SCC_VCC, {SgprRB_S32}, {NewDst});
      B.buildTrunc(Reg, CopyS32_Vcc);
      break;
    }
    case UniInVgprS32:
    case UniInVgprV4S32: {
      assert(Ty == getTyFromID(MethodIDs[OpIdx]));
      assert(RB == SgprRB);
      Register NewVgprDst = MRI.createVirtualRegister({VgprRB, Ty});
      Op.setReg(NewVgprDst);
      buildReadAnyLane(B, Reg, NewVgprDst, RBI);
      break;
    }
    case UniInVgprB32:
    case UniInVgprB64:
    case UniInVgprB96:
    case UniInVgprB128:
    case UniInVgprB256:
    case UniInVgprB512: {
      assert(Ty == getBTyFromID(MethodIDs[OpIdx], Ty));
      assert(RB == SgprRB);
      Register NewVgprDst = MRI.createVirtualRegister({VgprRB, Ty});
      Op.setReg(NewVgprDst);
      AMDGPU::buildReadAnyLane(B, Reg, NewVgprDst, RBI);
      break;
    }
    // sgpr trunc
    case Sgpr32Trunc: {
      assert(Ty.getSizeInBits() < 32);
      assert(RB == SgprRB);
      Register NewDst = MRI.createVirtualRegister(SgprRB_S32);
      Op.setReg(NewDst);
      B.buildTrunc(Reg, NewDst);
      break;
    }
    case InvalidMapping: {
      LLVM_DEBUG(dbgs() << "Instruction with Invalid mapping: "; MI.dump(););
      llvm_unreachable("missing fast rule for MI");
    }
    default:
      llvm_unreachable("ID not supported");
    }
  }
}

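// Process use operands. Sgpr IDs assert the expected bank and type; the Vcc
// ID rebuilds a lane mask from an sgpr bool; Vgpr IDs copy sgpr inputs to
// vgpr; and the Sgpr32*Ext IDs extend sub-32-bit sgpr scalars to S32.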
void RegBankLegalizeHelper::applyMappingSrc(
    MachineInstr &MI, unsigned &OpIdx,
    const SmallVectorImpl<RegBankLLTMappingApplyID> &MethodIDs,
    SmallSet<Register, 4> &SgprWaterfallOperandRegs) {
  for (unsigned i = 0; i < MethodIDs.size(); ++OpIdx, ++i) {
    if (MethodIDs[i] == None || MethodIDs[i] == IntrId || MethodIDs[i] == Imm)
      continue;

    MachineOperand &Op = MI.getOperand(OpIdx);
    Register Reg = Op.getReg();
    LLT Ty = MRI.getType(Reg);
    const RegisterBank *RB = MRI.getRegBank(Reg);

    switch (MethodIDs[i]) {
    case Vcc: {
      assert(Ty == S1);
      assert(RB == VccRB || RB == SgprRB);
      if (RB == SgprRB) {
        auto Aext = B.buildAnyExt(SgprRB_S32, Reg);
        auto CopyVcc_Scc =
            B.buildInstr(AMDGPU::G_AMDGPU_COPY_VCC_SCC, {VccRB_S1}, {Aext});
        Op.setReg(CopyVcc_Scc.getReg(0));
      }
      break;
    }
    // sgpr scalars, pointers and vectors
    case Sgpr16:
    case Sgpr32:
    case Sgpr64:
    case SgprP1:
    case SgprP3:
    case SgprP4:
    case SgprP5:
    case SgprV4S32: {
      assert(Ty == getTyFromID(MethodIDs[i]));
      assert(RB == getRegBankFromID(MethodIDs[i]));
      break;
    }
    // sgpr B-types
    case SgprB32:
    case SgprB64:
    case SgprB96:
    case SgprB128:
    case SgprB256:
    case SgprB512: {
      assert(Ty == getBTyFromID(MethodIDs[i], Ty));
      assert(RB == getRegBankFromID(MethodIDs[i]));
      break;
    }
    // vgpr scalars, pointers and vectors
    case Vgpr32:
    case Vgpr64:
    case VgprP1:
    case VgprP3:
    case VgprP4:
    case VgprP5:
    case VgprV4S32: {
      assert(Ty == getTyFromID(MethodIDs[i]));
      if (RB != VgprRB) {
        auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg);
        Op.setReg(CopyToVgpr.getReg(0));
      }
      break;
    }
    // vgpr B-types
    case VgprB32:
    case VgprB64:
    case VgprB96:
    case VgprB128:
    case VgprB256:
    case VgprB512: {
      assert(Ty == getBTyFromID(MethodIDs[i], Ty));
      if (RB != VgprRB) {
        auto CopyToVgpr = B.buildCopy({VgprRB, Ty}, Reg);
        Op.setReg(CopyToVgpr.getReg(0));
      }
      break;
    }
    // sgpr scalars with extend
    case Sgpr32AExt: {
      // Note: this ext allows S1, and it is meant to be combined away.
      assert(Ty.getSizeInBits() < 32);
      assert(RB == SgprRB);
      auto Aext = B.buildAnyExt(SgprRB_S32, Reg);
      Op.setReg(Aext.getReg(0));
      break;
    }
    case Sgpr32AExtBoolInReg: {
      // Note: this ext allows S1, and it is meant to be combined away.
      assert(Ty.getSizeInBits() == 1);
      assert(RB == SgprRB);
      auto Aext = B.buildAnyExt(SgprRB_S32, Reg);
      // Zext of SgprS1 is not legal; clean the high bits with an AND with 1
      // instead. Most of the time this is combined away in the RB combiner.
      auto Cst1 = B.buildConstant(SgprRB_S32, 1);
      auto BoolInReg = B.buildAnd(SgprRB_S32, Aext, Cst1);
      Op.setReg(BoolInReg.getReg(0));
      break;
    }
    case Sgpr32SExt: {
      assert(1 < Ty.getSizeInBits() && Ty.getSizeInBits() < 32);
      assert(RB == SgprRB);
      auto Sext = B.buildSExt(SgprRB_S32, Reg);
      Op.setReg(Sext.getReg(0));
      break;
    }
    default:
      llvm_unreachable("ID not supported");
    }
  }
}

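// G_PHI handling. Uniform S1 phis become S32 sgpr phis: the def is truncated
// back to S1 right after the phis, and each incoming value is any-extended
// to S32 right after its defining instruction.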
void RegBankLegalizeHelper::applyMappingPHI(MachineInstr &MI) {
  Register Dst = MI.getOperand(0).getReg();
  LLT Ty = MRI.getType(Dst);

  if (Ty == LLT::scalar(1) && MUI.isUniform(Dst)) {
    B.setInsertPt(*MI.getParent(), MI.getParent()->getFirstNonPHI());

    Register NewDst = MRI.createVirtualRegister(SgprRB_S32);
    MI.getOperand(0).setReg(NewDst);
    B.buildTrunc(Dst, NewDst);

    for (unsigned i = 1; i < MI.getNumOperands(); i += 2) {
      Register UseReg = MI.getOperand(i).getReg();

      auto DefMI = MRI.getVRegDef(UseReg)->getIterator();
      MachineBasicBlock *DefMBB = DefMI->getParent();

      B.setInsertPt(*DefMBB, DefMBB->SkipPHIsAndLabels(std::next(DefMI)));

      auto NewUse = B.buildAnyExt(SgprRB_S32, UseReg);
      MI.getOperand(i).setReg(NewUse.getReg(0));
    }

    return;
  }

  // ALL divergent i1 phis should be already lowered and inst-selected into
  // PHI with sgpr reg class and S1 LLT.
  // Note: this includes divergent phis that don't require lowering.
  if (Ty == LLT::scalar(1) && MUI.isDivergent(Dst)) {
    LLVM_DEBUG(dbgs() << "Divergent S1 G_PHI: "; MI.dump(););
    llvm_unreachable("Make sure to run AMDGPUGlobalISelDivergenceLowering "
                     "before RegBankLegalize to lower lane mask(vcc) phis");
  }

  // We accept all types that can fit in some register class.
  // Uniform G_PHIs have all sgpr registers.
  // Divergent G_PHIs have vgpr dst but inputs can be sgpr or vgpr.
  if (Ty == LLT::scalar(32) || Ty == LLT::pointer(4, 64)) {
    return;
  }

  LLVM_DEBUG(dbgs() << "G_PHI not handled: "; MI.dump(););
  llvm_unreachable("type not supported");
}

[[maybe_unused]] static bool verifyRegBankOnOperands(MachineInstr &MI,
                                                     const RegisterBank *RB,
                                                     MachineRegisterInfo &MRI,
                                                     unsigned StartOpIdx,
                                                     unsigned EndOpIdx) {
  for (unsigned i = StartOpIdx; i <= EndOpIdx; ++i) {
    if (MRI.getRegBankOrNull(MI.getOperand(i).getReg()) != RB)
      return false;
  }
  return true;
}

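// Trivial mapping: all operands live on the register bank of the first def.
// Sgpr instructions require all-sgpr operands; for vgpr instructions, any
// sgpr use is repaired with a copy to vgpr.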
void RegBankLegalizeHelper::applyMappingTrivial(MachineInstr &MI) {
  const RegisterBank *RB = MRI.getRegBank(MI.getOperand(0).getReg());
  // Put RB on all registers.
  unsigned NumDefs = MI.getNumDefs();
  unsigned NumOperands = MI.getNumOperands();

  assert(verifyRegBankOnOperands(MI, RB, MRI, 0, NumDefs - 1));
  if (RB == SgprRB)
    assert(verifyRegBankOnOperands(MI, RB, MRI, NumDefs, NumOperands - 1));

  if (RB == VgprRB) {
    B.setInstr(MI);
    for (unsigned i = NumDefs; i < NumOperands; ++i) {
      Register Reg = MI.getOperand(i).getReg();
      if (MRI.getRegBank(Reg) != RB) {
        auto Copy = B.buildCopy({VgprRB, MRI.getType(Reg)}, Reg);
        MI.getOperand(i).setReg(Copy.getReg(0));
      }
    }
  }
}