Fix GCC build problem with 288f05f related to SmallVector. (#116958)
[llvm-project.git] / mlir / test / lib / Dialect / Arith / TestEmulateWideInt.cpp
blob738d4ee59cbdeae2927197330bdaa5632163e198
//===- TestEmulateWideInt.cpp - Test Wide Int Emulation --------*- c++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass for integration testing of wide integer
// emulation patterns. The conversion patterns are applied only to functions
// whose names start with the prefix given by the `function-prefix` option.
13 //===----------------------------------------------------------------------===//
15 #include "mlir/Dialect/Arith/IR/Arith.h"
16 #include "mlir/Dialect/Arith/Transforms/Passes.h"
17 #include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/Dialect/Vector/IR/VectorOps.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/Transforms/DialectConversion.h"
24 using namespace mlir;
26 namespace {
27 struct TestEmulateWideIntPass
28 : public PassWrapper<TestEmulateWideIntPass, OperationPass<func::FuncOp>> {
29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateWideIntPass)
31 TestEmulateWideIntPass() = default;
32 TestEmulateWideIntPass(const TestEmulateWideIntPass &pass)
33 : PassWrapper(pass) {}
35 void getDependentDialects(DialectRegistry &registry) const override {
36 registry.insert<arith::ArithDialect, func::FuncDialect, LLVM::LLVMDialect,
37 vector::VectorDialect>();
39 StringRef getArgument() const final { return "test-arith-emulate-wide-int"; }
40 StringRef getDescription() const final {
41 return "Function pass to test Wide Integer Emulation";
44 void runOnOperation() override {
45 if (!llvm::isPowerOf2_32(widestIntSupported) || widestIntSupported < 2) {
46 signalPassFailure();
47 return;
50 func::FuncOp op = getOperation();
51 if (!op.getSymName().starts_with(testFunctionPrefix))
52 return;
54 MLIRContext *ctx = op.getContext();
55 arith::WideIntEmulationConverter typeConverter(widestIntSupported);
57 // Use `llvm.bitcast` as the bridge so that we can use preserve the
58 // function argument and return types of the processed function.
59 // TODO: Consider extending `arith.bitcast` to support scalar-to-1D-vector
60 // casts (and vice versa) and using it insted of `llvm.bitcast`.
61 auto addBitcast = [](OpBuilder &builder, Type type, ValueRange inputs,
62 Location loc) -> Value {
63 auto cast = builder.create<LLVM::BitcastOp>(loc, type, inputs);
64 return cast->getResult(0);
66 typeConverter.addSourceMaterialization(addBitcast);
67 typeConverter.addTargetMaterialization(addBitcast);
69 ConversionTarget target(*ctx);
70 target
71 .addDynamicallyLegalDialect<arith::ArithDialect, vector::VectorDialect>(
72 [&typeConverter](Operation *op) {
73 return typeConverter.isLegal(op);
74 });
76 RewritePatternSet patterns(ctx);
77 arith::populateArithWideIntEmulationPatterns(typeConverter, patterns);
78 if (failed(applyPartialConversion(op, target, std::move(patterns))))
79 signalPassFailure();
82 Option<std::string> testFunctionPrefix{
83 *this, "function-prefix",
84 llvm::cl::desc("Prefix of functions to run the emulation pass on"),
85 llvm::cl::init("emulate_")};
86 Option<unsigned> widestIntSupported{
87 *this, "widest-int-supported",
88 llvm::cl::desc("Maximum integer bit width supported by the target"),
89 llvm::cl::init(32)};
91 } // namespace
93 namespace mlir::test {
94 void registerTestArithEmulateWideIntPass() {
95 PassRegistration<TestEmulateWideIntPass>();
97 } // namespace mlir::test