[mlir][linalg] Add support for masked vectorization of `tensor.insert_slice` (1/N...
[llvm-project.git] / llvm / tools / llvm-mca / CodeRegionGenerator.h
blob 12261e7656a4237cf58a237c1dc632a194051bbc
1 //===----------------------- CodeRegionGenerator.h --------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 ///
10 /// This file declares classes responsible for generating llvm-mca
11 /// CodeRegions from various types of input. llvm-mca only analyzes CodeRegions,
12 /// so the classes here provide the input-to-CodeRegions translation.
14 //===----------------------------------------------------------------------===//
16 #ifndef LLVM_TOOLS_LLVM_MCA_CODEREGION_GENERATOR_H
17 #define LLVM_TOOLS_LLVM_MCA_CODEREGION_GENERATOR_H
19 #include "CodeRegion.h"
20 #include "llvm/MC/MCAsmInfo.h"
21 #include "llvm/MC/MCContext.h"
22 #include "llvm/MC/MCParser/MCAsmLexer.h"
23 #include "llvm/MC/MCStreamer.h"
24 #include "llvm/MC/MCSubtargetInfo.h"
25 #include "llvm/MC/TargetRegistry.h"
26 #include "llvm/MCA/CustomBehaviour.h"
27 #include "llvm/Support/Error.h"
28 #include "llvm/Support/SourceMgr.h"
29 #include <memory>
31 namespace llvm {
32 namespace mca {
34 class MCACommentConsumer : public AsmCommentConsumer {
35 protected:
36 bool FoundError = false;
38 public:
39 MCACommentConsumer() = default;
41 bool hadErr() const { return FoundError; }
44 /// A comment consumer that parses strings. The only valid tokens are strings.
45 class AnalysisRegionCommentConsumer : public MCACommentConsumer {
46 AnalysisRegions &Regions;
48 public:
49 AnalysisRegionCommentConsumer(AnalysisRegions &R) : Regions(R) {}
51 /// Parses a comment. It begins a new region if it is of the form
52 /// LLVM-MCA-BEGIN. It ends a region if it is of the form LLVM-MCA-END.
53 /// Regions can be optionally named if they are of the form
54 /// LLVM-MCA-BEGIN <name> or LLVM-MCA-END <name>. Subregions are
55 /// permitted, but a region that begins while another region is active
56 /// must be ended before the outer region is ended. If thre is only one
57 /// active region, LLVM-MCA-END does not need to provide a name.
58 void HandleComment(SMLoc Loc, StringRef CommentText) override;
61 /// A comment consumer that parses strings to create InstrumentRegions.
62 /// The only valid tokens are strings.
63 class InstrumentRegionCommentConsumer : public MCACommentConsumer {
64 llvm::SourceMgr &SM;
66 InstrumentRegions &Regions;
68 InstrumentManager &IM;
70 public:
71 InstrumentRegionCommentConsumer(llvm::SourceMgr &SM, InstrumentRegions &R,
72 InstrumentManager &IM)
73 : SM(SM), Regions(R), IM(IM) {}
75 /// Parses a comment. It begins a new region if it is of the form
76 /// LLVM-MCA-<INSTRUMENTATION_TYPE> <data> where INSTRUMENTATION_TYPE
77 /// is a valid InstrumentKind. If there is already an active
78 /// region of type INSTRUMENATION_TYPE, then it will end the active
79 /// one and begin a new one using the new data.
80 void HandleComment(SMLoc Loc, StringRef CommentText) override;
82 InstrumentManager &getInstrumentManager() { return IM; }
85 // This class provides the callbacks that occur when parsing input assembly.
86 class MCStreamerWrapper : public MCStreamer {
87 protected:
88 CodeRegions &Regions;
90 public:
91 MCStreamerWrapper(MCContext &Context, mca::CodeRegions &R)
92 : MCStreamer(Context), Regions(R) {}
94 // We only want to intercept the emission of new instructions.
95 void emitInstruction(const MCInst &Inst,
96 const MCSubtargetInfo & /* unused */) override {
97 Regions.addInstruction(Inst);
100 bool emitSymbolAttribute(MCSymbol *Symbol, MCSymbolAttr Attribute) override {
101 return true;
104 void emitCommonSymbol(MCSymbol *Symbol, uint64_t Size,
105 Align ByteAlignment) override {}
106 void emitZerofill(MCSection *Section, MCSymbol *Symbol = nullptr,
107 uint64_t Size = 0, Align ByteAlignment = Align(1),
108 SMLoc Loc = SMLoc()) override {}
109 void emitGPRel32Value(const MCExpr *Value) override {}
110 void beginCOFFSymbolDef(const MCSymbol *Symbol) override {}
111 void emitCOFFSymbolStorageClass(int StorageClass) override {}
112 void emitCOFFSymbolType(int Type) override {}
113 void endCOFFSymbolDef() override {}
115 ArrayRef<MCInst> GetInstructionSequence(unsigned Index) const {
116 return Regions.getInstructionSequence(Index);
120 class InstrumentMCStreamer : public MCStreamerWrapper {
121 InstrumentManager &IM;
123 public:
124 InstrumentMCStreamer(MCContext &Context, mca::InstrumentRegions &R,
125 InstrumentManager &IM)
126 : MCStreamerWrapper(Context, R), IM(IM) {}
128 void emitInstruction(const MCInst &Inst,
129 const MCSubtargetInfo &MCSI) override {
130 MCStreamerWrapper::emitInstruction(Inst, MCSI);
132 // We know that Regions is an InstrumentRegions by the constructor.
133 for (UniqueInstrument &I : IM.createInstruments(Inst)) {
134 StringRef InstrumentKind = I.get()->getDesc();
135 // End InstrumentType region if one is open
136 if (Regions.isRegionActive(InstrumentKind))
137 Regions.endRegion(InstrumentKind, Inst.getLoc());
138 // Start new instrumentation region
139 Regions.beginRegion(InstrumentKind, Inst.getLoc(), std::move(I));
144 /// This abstract class is responsible for parsing the input given to
145 /// the llvm-mca driver, and converting that into a CodeRegions instance.
146 class CodeRegionGenerator {
147 protected:
148 CodeRegionGenerator(const CodeRegionGenerator &) = delete;
149 CodeRegionGenerator &operator=(const CodeRegionGenerator &) = delete;
150 virtual Expected<const CodeRegions &>
151 parseCodeRegions(const std::unique_ptr<MCInstPrinter> &IP,
152 bool SkipFailures) = 0;
154 public:
155 CodeRegionGenerator() {}
156 virtual ~CodeRegionGenerator();
159 /// Abastract CodeRegionGenerator with AnalysisRegions member
160 class AnalysisRegionGenerator : public virtual CodeRegionGenerator {
161 protected:
162 AnalysisRegions Regions;
164 public:
165 AnalysisRegionGenerator(llvm::SourceMgr &SM) : Regions(SM) {}
167 virtual Expected<const AnalysisRegions &>
168 parseAnalysisRegions(const std::unique_ptr<MCInstPrinter> &IP,
169 bool SkipFailures) = 0;
172 /// Abstract CodeRegionGenerator with InstrumentRegionsRegions member
173 class InstrumentRegionGenerator : public virtual CodeRegionGenerator {
174 protected:
175 InstrumentRegions Regions;
177 public:
178 InstrumentRegionGenerator(llvm::SourceMgr &SM) : Regions(SM) {}
180 virtual Expected<const InstrumentRegions &>
181 parseInstrumentRegions(const std::unique_ptr<MCInstPrinter> &IP,
182 bool SkipFailures) = 0;
185 /// This abstract class is responsible for parsing input ASM and
186 /// generating a CodeRegions instance.
187 class AsmCodeRegionGenerator : public virtual CodeRegionGenerator {
188 const Target &TheTarget;
189 const MCAsmInfo &MAI;
190 const MCSubtargetInfo &STI;
191 const MCInstrInfo &MCII;
192 unsigned AssemblerDialect; // This is set during parsing.
194 protected:
195 MCContext &Ctx;
197 public:
198 AsmCodeRegionGenerator(const Target &T, MCContext &C, const MCAsmInfo &A,
199 const MCSubtargetInfo &S, const MCInstrInfo &I)
200 : TheTarget(T), MAI(A), STI(S), MCII(I), AssemblerDialect(0), Ctx(C) {}
202 virtual MCACommentConsumer *getCommentConsumer() = 0;
203 virtual CodeRegions &getRegions() = 0;
204 virtual MCStreamerWrapper *getMCStreamer() = 0;
206 unsigned getAssemblerDialect() const { return AssemblerDialect; }
207 Expected<const CodeRegions &>
208 parseCodeRegions(const std::unique_ptr<MCInstPrinter> &IP,
209 bool SkipFailures) override;
212 class AsmAnalysisRegionGenerator final : public AnalysisRegionGenerator,
213 public AsmCodeRegionGenerator {
214 AnalysisRegionCommentConsumer CC;
215 MCStreamerWrapper Streamer;
217 public:
218 AsmAnalysisRegionGenerator(const Target &T, llvm::SourceMgr &SM, MCContext &C,
219 const MCAsmInfo &A, const MCSubtargetInfo &S,
220 const MCInstrInfo &I)
221 : AnalysisRegionGenerator(SM), AsmCodeRegionGenerator(T, C, A, S, I),
222 CC(Regions), Streamer(Ctx, Regions) {}
224 MCACommentConsumer *getCommentConsumer() override { return &CC; };
225 CodeRegions &getRegions() override { return Regions; };
226 MCStreamerWrapper *getMCStreamer() override { return &Streamer; }
228 Expected<const AnalysisRegions &>
229 parseAnalysisRegions(const std::unique_ptr<MCInstPrinter> &IP,
230 bool SkipFailures) override {
231 Expected<const CodeRegions &> RegionsOrErr =
232 parseCodeRegions(IP, SkipFailures);
233 if (!RegionsOrErr)
234 return RegionsOrErr.takeError();
235 else
236 return static_cast<const AnalysisRegions &>(*RegionsOrErr);
239 Expected<const CodeRegions &>
240 parseCodeRegions(const std::unique_ptr<MCInstPrinter> &IP,
241 bool SkipFailures) override {
242 return AsmCodeRegionGenerator::parseCodeRegions(IP, SkipFailures);
246 class AsmInstrumentRegionGenerator final : public InstrumentRegionGenerator,
247 public AsmCodeRegionGenerator {
248 InstrumentRegionCommentConsumer CC;
249 InstrumentMCStreamer Streamer;
251 public:
252 AsmInstrumentRegionGenerator(const Target &T, llvm::SourceMgr &SM,
253 MCContext &C, const MCAsmInfo &A,
254 const MCSubtargetInfo &S, const MCInstrInfo &I,
255 InstrumentManager &IM)
256 : InstrumentRegionGenerator(SM), AsmCodeRegionGenerator(T, C, A, S, I),
257 CC(SM, Regions, IM), Streamer(Ctx, Regions, IM) {}
259 MCACommentConsumer *getCommentConsumer() override { return &CC; };
260 CodeRegions &getRegions() override { return Regions; };
261 MCStreamerWrapper *getMCStreamer() override { return &Streamer; }
263 Expected<const InstrumentRegions &>
264 parseInstrumentRegions(const std::unique_ptr<MCInstPrinter> &IP,
265 bool SkipFailures) override {
266 Expected<const CodeRegions &> RegionsOrErr =
267 parseCodeRegions(IP, SkipFailures);
268 if (!RegionsOrErr)
269 return RegionsOrErr.takeError();
270 else
271 return static_cast<const InstrumentRegions &>(*RegionsOrErr);
274 Expected<const CodeRegions &>
275 parseCodeRegions(const std::unique_ptr<MCInstPrinter> &IP,
276 bool SkipFailures) override {
277 return AsmCodeRegionGenerator::parseCodeRegions(IP, SkipFailures);
281 } // namespace mca
282 } // namespace llvm
284 #endif // LLVM_TOOLS_LLVM_MCA_CODEREGION_GENERATOR_H