1 //===- OpGenHelpers.cpp - MLIR operation generator helpers ----------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 // This file defines helpers used in the op generators.
11 //===----------------------------------------------------------------------===//
13 #include "OpGenHelpers.h"
14 #include "llvm/ADT/StringSet.h"
15 #include "llvm/Support/CommandLine.h"
16 #include "llvm/Support/FormatVariadic.h"
17 #include "llvm/Support/Regex.h"
18 #include "llvm/TableGen/Error.h"
22 using namespace mlir::tblgen
;
24 cl::OptionCategory
opDefGenCat("Options for op definition generators");
26 static cl::opt
<std::string
> opIncFilter(
28 cl::desc("Regex of name of op's to include (no filter if empty)"),
29 cl::cat(opDefGenCat
));
30 static cl::opt
<std::string
> opExcFilter(
32 cl::desc("Regex of name of op's to exclude (no filter if empty)"),
33 cl::cat(opDefGenCat
));
34 static cl::opt
<unsigned> opShardCount(
36 cl::desc("The number of shards into which the op classes will be divided"),
37 cl::cat(opDefGenCat
), cl::init(1));
39 static std::string
getOperationName(const Record
&def
) {
40 auto prefix
= def
.getValueAsDef("opDialect")->getValueAsString("name");
41 auto opName
= def
.getValueAsString("opName");
43 return std::string(opName
);
44 return std::string(llvm::formatv("{0}.{1}", prefix
, opName
));
48 mlir::tblgen::getRequestedOpDefinitions(const RecordKeeper
&recordKeeper
) {
49 Record
*classDef
= recordKeeper
.getClass("Op");
51 PrintFatalError("ERROR: Couldn't find the 'Op' class!\n");
53 llvm::Regex
includeRegex(opIncFilter
), excludeRegex(opExcFilter
);
54 std::vector
<Record
*> defs
;
55 for (const auto &def
: recordKeeper
.getDefs()) {
56 if (!def
.second
->isSubClassOf(classDef
))
58 // Include if no include filter or include filter matches.
59 if (!opIncFilter
.empty() &&
60 !includeRegex
.match(getOperationName(*def
.second
)))
62 // Unless there is an exclude filter and it matches.
63 if (!opExcFilter
.empty() &&
64 excludeRegex
.match(getOperationName(*def
.second
)))
66 defs
.push_back(def
.second
.get());
72 bool mlir::tblgen::isPythonReserved(StringRef str
) {
73 static llvm::StringSet
<> reserved({
74 "False", "None", "True", "and", "as", "assert", "async",
75 "await", "break", "class", "continue", "def", "del", "elif",
76 "else", "except", "finally", "for", "from", "global", "if",
77 "import", "in", "is", "lambda", "nonlocal", "not", "or",
78 "pass", "raise", "return", "try", "while", "with", "yield",
80 // These aren't Python keywords but builtin functions that shouldn't/can't be
82 reserved
.insert("callable");
83 reserved
.insert("issubclass");
84 reserved
.insert("type");
85 return reserved
.contains(str
);
88 void mlir::tblgen::shardOpDefinitions(
89 ArrayRef
<llvm::Record
*> defs
,
90 SmallVectorImpl
<ArrayRef
<llvm::Record
*>> &shardedDefs
) {
91 assert(opShardCount
> 0 && "expected a positive shard count");
92 if (opShardCount
== 1) {
93 shardedDefs
.push_back(defs
);
97 unsigned minShardSize
= defs
.size() / opShardCount
;
98 unsigned numMissing
= defs
.size() - minShardSize
* opShardCount
;
99 shardedDefs
.reserve(opShardCount
);
100 for (unsigned i
= 0, start
= 0; i
< opShardCount
; ++i
) {
101 unsigned size
= minShardSize
+ (i
< numMissing
);
102 shardedDefs
.push_back(defs
.slice(start
, size
));