1 //===- OpStats.cpp - Prints stats of operations in module -----------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Transforms/Passes.h"
11 #include "mlir/IR/BuiltinOps.h"
12 #include "mlir/IR/Operation.h"
13 #include "mlir/IR/OperationSupport.h"
14 #include "llvm/ADT/DenseMap.h"
15 #include "llvm/Support/Format.h"
16 #include "llvm/Support/raw_ostream.h"
19 #define GEN_PASS_DEF_PRINTOPSTATS
20 #include "mlir/Transforms/Passes.h.inc"
// Pass that keeps a per-operation-name counter in `opCount` and writes a
// summary of those counters to the stream `os` handed to its constructor.
// NOTE(review): the listing numbers embedded in this copy skip (30 -> 33,
// 36 -> 39, 40 -> 43), so closing braces and some member declarations appear
// to be missing here -- confirm against the upstream file.
26 struct PrintOpStatsPass
: public impl::PrintOpStatsBase
<PrintOpStatsPass
> {
// Builds the pass around the output stream `os`; `printAsJSON` is not touched.
27 explicit PrintOpStatsPass(raw_ostream
&os
) : os(os
) {}
// Builds the pass around the output stream `os` and stores the requested
// `printAsJSON` value in `this->printAsJSON`.
29 explicit PrintOpStatsPass(raw_ostream
&os
, bool printAsJSON
) : os(os
) {
30 this->printAsJSON
= printAsJSON
;
33 // Prints the resultant operation statistics post iterating over the module.
34 void runOnOperation() override
;
36 // Print summary of op stats.
39 // Print summary of op stats in JSON.
40 void printSummaryInJSON();
// Number of times each operation name has been counted, keyed by the name
// string.
43 llvm::StringMap
<int64_t> opCount
;
// Pass entry point. The visible part consists of a callback that increments
// the `opCount` entry for an operation's name, followed by a call that marks
// all analyses as preserved.
// NOTE(review): listing numbers 49-50, 52 and 54-57 are absent from this copy,
// so the statement that invokes the callback and whatever emits the summary
// are not visible here -- confirm against the upstream file.
48 void PrintOpStatsPass::runOnOperation() {
51 // Compute the operation statistics for the currently visited operation.
// Callback: bump the counter stored under the operation's name string.
53 [&](Operation
*op
) { ++opCount
[op
->getName().getStringRef()]; });
// Counting and printing do not modify the IR, so every analysis stays valid.
58 markAllAnalysesPreserved();
// Writes a plain-text table of the collected counters to `os`: a two-line
// header, then one row per key of `opCount` with the dialect prefix, the
// operation name, the separator " , " and the count.
// NOTE(review): several listing numbers are absent from this copy (65, 71-72,
// 80, 87, 90, 95-97), so some statements and closing braces are not visible
// here -- confirm against the upstream file.
61 void PrintOpStatsPass::printSummary() {
62 os
<< "Operations encountered:\n";
63 os
<< "-----------------------\n";
// Collect the counter keys into a vector. NOTE(review): no sort call is
// visible in this copy despite the variable name -- confirm.
64 SmallVector
<StringRef
, 64> sorted(opCount
.keys());
67 // Split an operation name from its dialect prefix.
// When the part after '.' is empty, the visible branch yields an empty
// dialect together with the part before '.' as the operation name.
// NOTE(review): the other branch of the conditional is not visible here.
68 auto splitOperationName
= [](StringRef opName
) {
69 auto splitName
= opName
.split('.');
70 return splitName
.second
.empty() ? std::make_pair("", splitName
.first
)
74 // Compute the largest dialect and operation name.
// These two maxima are the column widths used when the rows are printed.
75 size_t maxLenOpName
= 0, maxLenDialect
= 0;
76 for (const auto &key
: sorted
) {
77 auto [dialectName
, opName
] = splitOperationName(key
);
78 maxLenDialect
= std::max(maxLenDialect
, dialectName
.size());
79 maxLenOpName
= std::max(maxLenOpName
, opName
.size());
// Second pass over the keys: emit one row per key.
82 for (const auto &key
: sorted
) {
83 auto [dialectName
, opName
] = splitOperationName(key
);
85 // Left-align the names (aligning on the dialect) and right-align the count
86 // below. The alignment is for readability and does not affect CSV/FileCheck
// No dialect: indent by the dialect column width plus 3, which is the width
// the justified dialect (maxLenDialect + 2) and the '.' occupy otherwise.
88 if (dialectName
.empty())
89 os
.indent(maxLenDialect
+ 3);
// NOTE(review): listing line 90 is missing between these two statements;
// presumably an `else` -- confirm.
91 os
<< llvm::right_justify(dialectName
, maxLenDialect
+ 2) << '.';
93 // Left justify the operation name.
94 os
<< llvm::left_justify(opName
, maxLenOpName
) << " , " << opCount
[key
]
// Writes the collected counters to `os` as quoted-name / count pairs of the
// form ` "<name>" : <count>`, one per key of `opCount`.
// NOTE(review): listing numbers 101-104 and 108-115 are absent from this copy,
// so any enclosing delimiters, separators between entries and the closing
// braces are not visible here -- confirm against the upstream file.
99 void PrintOpStatsPass::printSummaryInJSON() {
// Collect the counter keys into a vector. NOTE(review): no sort call is
// visible in this copy despite the variable name -- confirm.
100 SmallVector
<StringRef
, 64> sorted(opCount
.keys());
// Index-based loop; `e` caches the number of keys.
105 for (unsigned i
= 0, e
= sorted
.size(); i
!= e
; ++i
) {
106 const auto &key
= sorted
[i
];
// Emit the key wrapped in double quotes, then " : " and its count.
107 os
<< " \"" << key
<< "\" : " << opCount
[key
];
// Factory: returns a newly allocated PrintOpStatsPass constructed with the
// output stream `os`, owned by the returned std::unique_ptr<Pass>.
// NOTE(review): the closing brace (listing line 118) is missing from this copy.
116 std::unique_ptr
<Pass
> mlir::createPrintOpStatsPass(raw_ostream
&os
) {
117 return std::make_unique
<PrintOpStatsPass
>(os
);
// Factory overload: returns a newly allocated PrintOpStatsPass constructed
// with the output stream `os` and the `printAsJSON` value.
// NOTE(review): listing line 121 (the declaration of the second parameter) and
// the closing brace are missing from this copy -- confirm against the upstream
// file.
120 std::unique_ptr
<Pass
> mlir::createPrintOpStatsPass(raw_ostream
&os
,
122 return std::make_unique
<PrintOpStatsPass
>(os
, printAsJSON
);