Add gfx950 mfma instructions to ROCDL dialect (#123361)
[llvm-project.git] / llvm / lib / CodeGen / GlobalISel / CombinerHelperCasts.cpp
blob7b4c427a9c50415911a944cf1034b124f68e8326
1 //===- CombinerHelperCasts.cpp---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements CombinerHelper for G_ANYEXT, G_SEXT, G_TRUNC, and
// G_ZEXT.
12 //===----------------------------------------------------------------------===//
13 #include "llvm/CodeGen/GlobalISel/CombinerHelper.h"
14 #include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
15 #include "llvm/CodeGen/GlobalISel/LegalizerInfo.h"
16 #include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
17 #include "llvm/CodeGen/GlobalISel/Utils.h"
18 #include "llvm/CodeGen/LowLevelTypeUtils.h"
19 #include "llvm/CodeGen/MachineOperand.h"
20 #include "llvm/CodeGen/MachineRegisterInfo.h"
21 #include "llvm/CodeGen/TargetOpcodes.h"
22 #include "llvm/Support/Casting.h"
24 #define DEBUG_TYPE "gi-combiner"
26 using namespace llvm;
28 bool CombinerHelper::matchSextOfTrunc(const MachineOperand &MO,
29 BuildFnTy &MatchInfo) const {
30 GSext *Sext = cast<GSext>(getDefIgnoringCopies(MO.getReg(), MRI));
31 GTrunc *Trunc = cast<GTrunc>(getDefIgnoringCopies(Sext->getSrcReg(), MRI));
33 Register Dst = Sext->getReg(0);
34 Register Src = Trunc->getSrcReg();
36 LLT DstTy = MRI.getType(Dst);
37 LLT SrcTy = MRI.getType(Src);
39 if (DstTy == SrcTy) {
40 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); };
41 return true;
44 if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() &&
45 isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) {
46 MatchInfo = [=](MachineIRBuilder &B) {
47 B.buildTrunc(Dst, Src, MachineInstr::MIFlag::NoSWrap);
49 return true;
52 if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() &&
53 isLegalOrBeforeLegalizer({TargetOpcode::G_SEXT, {DstTy, SrcTy}})) {
54 MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); };
55 return true;
58 return false;
61 bool CombinerHelper::matchZextOfTrunc(const MachineOperand &MO,
62 BuildFnTy &MatchInfo) const {
63 GZext *Zext = cast<GZext>(getDefIgnoringCopies(MO.getReg(), MRI));
64 GTrunc *Trunc = cast<GTrunc>(getDefIgnoringCopies(Zext->getSrcReg(), MRI));
66 Register Dst = Zext->getReg(0);
67 Register Src = Trunc->getSrcReg();
69 LLT DstTy = MRI.getType(Dst);
70 LLT SrcTy = MRI.getType(Src);
72 if (DstTy == SrcTy) {
73 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); };
74 return true;
77 if (DstTy.getScalarSizeInBits() < SrcTy.getScalarSizeInBits() &&
78 isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}})) {
79 MatchInfo = [=](MachineIRBuilder &B) {
80 B.buildTrunc(Dst, Src, MachineInstr::MIFlag::NoUWrap);
82 return true;
85 if (DstTy.getScalarSizeInBits() > SrcTy.getScalarSizeInBits() &&
86 isLegalOrBeforeLegalizer({TargetOpcode::G_ZEXT, {DstTy, SrcTy}})) {
87 MatchInfo = [=](MachineIRBuilder &B) {
88 B.buildZExt(Dst, Src, MachineInstr::MIFlag::NonNeg);
90 return true;
93 return false;
96 bool CombinerHelper::matchNonNegZext(const MachineOperand &MO,
97 BuildFnTy &MatchInfo) const {
98 GZext *Zext = cast<GZext>(MRI.getVRegDef(MO.getReg()));
100 Register Dst = Zext->getReg(0);
101 Register Src = Zext->getSrcReg();
103 LLT DstTy = MRI.getType(Dst);
104 LLT SrcTy = MRI.getType(Src);
105 const auto &TLI = getTargetLowering();
107 // Convert zext nneg to sext if sext is the preferred form for the target.
108 if (isLegalOrBeforeLegalizer({TargetOpcode::G_SEXT, {DstTy, SrcTy}}) &&
109 TLI.isSExtCheaperThanZExt(getMVTForLLT(SrcTy), getMVTForLLT(DstTy))) {
110 MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); };
111 return true;
114 return false;
117 bool CombinerHelper::matchTruncateOfExt(const MachineInstr &Root,
118 const MachineInstr &ExtMI,
119 BuildFnTy &MatchInfo) const {
120 const GTrunc *Trunc = cast<GTrunc>(&Root);
121 const GExtOp *Ext = cast<GExtOp>(&ExtMI);
123 if (!MRI.hasOneNonDBGUse(Ext->getReg(0)))
124 return false;
126 Register Dst = Trunc->getReg(0);
127 Register Src = Ext->getSrcReg();
128 LLT DstTy = MRI.getType(Dst);
129 LLT SrcTy = MRI.getType(Src);
131 if (SrcTy == DstTy) {
132 // The source and the destination are equally sized. We need to copy.
133 MatchInfo = [=](MachineIRBuilder &B) { B.buildCopy(Dst, Src); };
135 return true;
138 if (SrcTy.getScalarSizeInBits() < DstTy.getScalarSizeInBits()) {
139 // If the source is smaller than the destination, we need to extend.
141 if (!isLegalOrBeforeLegalizer({Ext->getOpcode(), {DstTy, SrcTy}}))
142 return false;
144 MatchInfo = [=](MachineIRBuilder &B) {
145 B.buildInstr(Ext->getOpcode(), {Dst}, {Src});
148 return true;
151 if (SrcTy.getScalarSizeInBits() > DstTy.getScalarSizeInBits()) {
152 // If the source is larger than the destination, then we need to truncate.
154 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_TRUNC, {DstTy, SrcTy}}))
155 return false;
157 MatchInfo = [=](MachineIRBuilder &B) { B.buildTrunc(Dst, Src); };
159 return true;
162 return false;
165 bool CombinerHelper::isCastFree(unsigned Opcode, LLT ToTy, LLT FromTy) const {
166 const TargetLowering &TLI = getTargetLowering();
167 LLVMContext &Ctx = getContext();
169 switch (Opcode) {
170 case TargetOpcode::G_ANYEXT:
171 case TargetOpcode::G_ZEXT:
172 return TLI.isZExtFree(FromTy, ToTy, Ctx);
173 case TargetOpcode::G_TRUNC:
174 return TLI.isTruncateFree(FromTy, ToTy, Ctx);
175 default:
176 return false;
180 bool CombinerHelper::matchCastOfSelect(const MachineInstr &CastMI,
181 const MachineInstr &SelectMI,
182 BuildFnTy &MatchInfo) const {
183 const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(&CastMI);
184 const GSelect *Select = cast<GSelect>(&SelectMI);
186 if (!MRI.hasOneNonDBGUse(Select->getReg(0)))
187 return false;
189 Register Dst = Cast->getReg(0);
190 LLT DstTy = MRI.getType(Dst);
191 LLT CondTy = MRI.getType(Select->getCondReg());
192 Register TrueReg = Select->getTrueReg();
193 Register FalseReg = Select->getFalseReg();
194 LLT SrcTy = MRI.getType(TrueReg);
195 Register Cond = Select->getCondReg();
197 if (!isLegalOrBeforeLegalizer({TargetOpcode::G_SELECT, {DstTy, CondTy}}))
198 return false;
200 if (!isCastFree(Cast->getOpcode(), DstTy, SrcTy))
201 return false;
203 MatchInfo = [=](MachineIRBuilder &B) {
204 auto True = B.buildInstr(Cast->getOpcode(), {DstTy}, {TrueReg});
205 auto False = B.buildInstr(Cast->getOpcode(), {DstTy}, {FalseReg});
206 B.buildSelect(Dst, Cond, True, False);
209 return true;
212 bool CombinerHelper::matchExtOfExt(const MachineInstr &FirstMI,
213 const MachineInstr &SecondMI,
214 BuildFnTy &MatchInfo) const {
215 const GExtOp *First = cast<GExtOp>(&FirstMI);
216 const GExtOp *Second = cast<GExtOp>(&SecondMI);
218 Register Dst = First->getReg(0);
219 Register Src = Second->getSrcReg();
220 LLT DstTy = MRI.getType(Dst);
221 LLT SrcTy = MRI.getType(Src);
223 if (!MRI.hasOneNonDBGUse(Second->getReg(0)))
224 return false;
226 // ext of ext -> later ext
227 if (First->getOpcode() == Second->getOpcode() &&
228 isLegalOrBeforeLegalizer({Second->getOpcode(), {DstTy, SrcTy}})) {
229 if (Second->getOpcode() == TargetOpcode::G_ZEXT) {
230 MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags;
231 if (Second->getFlag(MachineInstr::MIFlag::NonNeg))
232 Flag = MachineInstr::MIFlag::NonNeg;
233 MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); };
234 return true;
236 // not zext -> no flags
237 MatchInfo = [=](MachineIRBuilder &B) {
238 B.buildInstr(Second->getOpcode(), {Dst}, {Src});
240 return true;
243 // anyext of sext/zext -> sext/zext
244 // -> pick anyext as second ext, then ext of ext
245 if (First->getOpcode() == TargetOpcode::G_ANYEXT &&
246 isLegalOrBeforeLegalizer({Second->getOpcode(), {DstTy, SrcTy}})) {
247 if (Second->getOpcode() == TargetOpcode::G_ZEXT) {
248 MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags;
249 if (Second->getFlag(MachineInstr::MIFlag::NonNeg))
250 Flag = MachineInstr::MIFlag::NonNeg;
251 MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); };
252 return true;
254 MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); };
255 return true;
258 // sext/zext of anyext -> sext/zext
259 // -> pick anyext as first ext, then ext of ext
260 if (Second->getOpcode() == TargetOpcode::G_ANYEXT &&
261 isLegalOrBeforeLegalizer({First->getOpcode(), {DstTy, SrcTy}})) {
262 if (First->getOpcode() == TargetOpcode::G_ZEXT) {
263 MachineInstr::MIFlag Flag = MachineInstr::MIFlag::NoFlags;
264 if (First->getFlag(MachineInstr::MIFlag::NonNeg))
265 Flag = MachineInstr::MIFlag::NonNeg;
266 MatchInfo = [=](MachineIRBuilder &B) { B.buildZExt(Dst, Src, Flag); };
267 return true;
269 MatchInfo = [=](MachineIRBuilder &B) { B.buildSExt(Dst, Src); };
270 return true;
273 return false;
276 bool CombinerHelper::matchCastOfBuildVector(const MachineInstr &CastMI,
277 const MachineInstr &BVMI,
278 BuildFnTy &MatchInfo) const {
279 const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(&CastMI);
280 const GBuildVector *BV = cast<GBuildVector>(&BVMI);
282 if (!MRI.hasOneNonDBGUse(BV->getReg(0)))
283 return false;
285 Register Dst = Cast->getReg(0);
286 // The type of the new build vector.
287 LLT DstTy = MRI.getType(Dst);
288 // The scalar or element type of the new build vector.
289 LLT ElemTy = DstTy.getScalarType();
290 // The scalar or element type of the old build vector.
291 LLT InputElemTy = MRI.getType(BV->getReg(0)).getElementType();
293 // Check legality of new build vector, the scalar casts, and profitability of
294 // the many casts.
295 if (!isLegalOrBeforeLegalizer(
296 {TargetOpcode::G_BUILD_VECTOR, {DstTy, ElemTy}}) ||
297 !isLegalOrBeforeLegalizer({Cast->getOpcode(), {ElemTy, InputElemTy}}) ||
298 !isCastFree(Cast->getOpcode(), ElemTy, InputElemTy))
299 return false;
301 MatchInfo = [=](MachineIRBuilder &B) {
302 SmallVector<Register> Casts;
303 unsigned Elements = BV->getNumSources();
304 for (unsigned I = 0; I < Elements; ++I) {
305 auto CastI =
306 B.buildInstr(Cast->getOpcode(), {ElemTy}, {BV->getSourceReg(I)});
307 Casts.push_back(CastI.getReg(0));
310 B.buildBuildVector(Dst, Casts);
313 return true;
316 bool CombinerHelper::matchNarrowBinop(const MachineInstr &TruncMI,
317 const MachineInstr &BinopMI,
318 BuildFnTy &MatchInfo) const {
319 const GTrunc *Trunc = cast<GTrunc>(&TruncMI);
320 const GBinOp *BinOp = cast<GBinOp>(&BinopMI);
322 if (!MRI.hasOneNonDBGUse(BinOp->getReg(0)))
323 return false;
325 Register Dst = Trunc->getReg(0);
326 LLT DstTy = MRI.getType(Dst);
328 // Is narrow binop legal?
329 if (!isLegalOrBeforeLegalizer({BinOp->getOpcode(), {DstTy}}))
330 return false;
332 MatchInfo = [=](MachineIRBuilder &B) {
333 auto LHS = B.buildTrunc(DstTy, BinOp->getLHSReg());
334 auto RHS = B.buildTrunc(DstTy, BinOp->getRHSReg());
335 B.buildInstr(BinOp->getOpcode(), {Dst}, {LHS, RHS});
338 return true;
341 bool CombinerHelper::matchCastOfInteger(const MachineInstr &CastMI,
342 APInt &MatchInfo) const {
343 const GExtOrTruncOp *Cast = cast<GExtOrTruncOp>(&CastMI);
345 APInt Input = getIConstantFromReg(Cast->getSrcReg(), MRI);
347 LLT DstTy = MRI.getType(Cast->getReg(0));
349 if (!isConstantLegalOrBeforeLegalizer(DstTy))
350 return false;
352 switch (Cast->getOpcode()) {
353 case TargetOpcode::G_TRUNC: {
354 MatchInfo = Input.trunc(DstTy.getScalarSizeInBits());
355 return true;
357 default:
358 return false;