1 //===- TestShapeFunctions.cpp - Passes to test shape function ------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
20 /// This is a pass that reports shape functions associated with ops.
21 struct ReportShapeFnPass
22 : public PassWrapper
<ReportShapeFnPass
, OperationPass
<ModuleOp
>> {
23 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReportShapeFnPass
)
25 void runOnOperation() override
;
26 StringRef
getArgument() const final
{ return "test-shape-function-report"; }
27 StringRef
getDescription() const final
{
28 return "Test pass to report associated shape functions";
33 void ReportShapeFnPass::runOnOperation() {
34 auto module
= getOperation();
36 // Report the shape function available to refine the op.
37 auto shapeFnId
= StringAttr::get(&getContext(), "shape.function");
38 auto remarkShapeFn
= [&](shape::FunctionLibraryOp shapeFnLib
, Operation
*op
) {
39 if (op
->hasTrait
<OpTrait::IsTerminator
>())
41 if (auto typeInterface
= dyn_cast
<InferTypeOpInterface
>(op
)) {
42 op
->emitRemark() << "implements InferType op interface";
45 if (auto fn
= shapeFnLib
.getShapeFunction(op
)) {
46 op
->emitRemark() << "associated shape function: " << fn
.getName();
49 if (auto symbol
= op
->getAttrOfType
<SymbolRefAttr
>(shapeFnId
)) {
51 cast
<shape::FuncOp
>(SymbolTable::lookupSymbolIn(module
, symbol
));
52 op
->emitRemark() << "associated shape function: " << fn
.getName();
58 // Lookup shape function library.
59 SmallVector
<shape::FunctionLibraryOp
, 4> libraries
;
60 auto attr
= module
->getDiscardableAttr("shape.lib");
62 auto lookup
= [&](Attribute attr
) {
63 return cast
<shape::FunctionLibraryOp
>(
64 SymbolTable::lookupSymbolIn(module
, cast
<SymbolRefAttr
>(attr
)));
66 if (auto arrayAttr
= dyn_cast
<ArrayAttr
>(attr
)) {
67 libraries
.reserve(arrayAttr
.size());
68 for (auto attr
: arrayAttr
)
69 libraries
.push_back(lookup(attr
));
72 libraries
.push_back(lookup(attr
));
76 module
.getBodyRegion().walk([&](func::FuncOp func
) {
77 // Skip ops in the shape function library.
78 if (isa
<shape::FunctionLibraryOp
>(func
->getParentOp()))
81 func
.walk([&](Operation
*op
) {
82 bool found
= llvm::any_of(libraries
, [&](shape::FunctionLibraryOp lib
) {
83 return remarkShapeFn(lib
, op
);
86 op
->emitRemark() << "no associated way to refine shape";
92 void registerShapeFunctionTestPasses() {
93 PassRegistration
<ReportShapeFnPass
>();