1 //===- TestWideIntEmulation.cpp - Test Wide Int Emulation ------*- c++ -*-===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
// This file implements a pass for integration testing of wide integer
// emulation patterns. The conversion patterns are applied only to functions
// whose names start with a configurable prefix.
13 //===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/Transforms/WideIntEmulationConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/MathExtras.h"

using namespace mlir;
27 struct TestEmulateWideIntPass
28 : public PassWrapper
<TestEmulateWideIntPass
, OperationPass
<func::FuncOp
>> {
29 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEmulateWideIntPass
)
31 TestEmulateWideIntPass() = default;
32 TestEmulateWideIntPass(const TestEmulateWideIntPass
&pass
)
33 : PassWrapper(pass
) {}
35 void getDependentDialects(DialectRegistry
®istry
) const override
{
36 registry
.insert
<arith::ArithDialect
, func::FuncDialect
, LLVM::LLVMDialect
,
37 vector::VectorDialect
>();
39 StringRef
getArgument() const final
{ return "test-arith-emulate-wide-int"; }
40 StringRef
getDescription() const final
{
41 return "Function pass to test Wide Integer Emulation";
44 void runOnOperation() override
{
45 if (!llvm::isPowerOf2_32(widestIntSupported
) || widestIntSupported
< 2) {
50 func::FuncOp op
= getOperation();
51 if (!op
.getSymName().starts_with(testFunctionPrefix
))
54 MLIRContext
*ctx
= op
.getContext();
55 arith::WideIntEmulationConverter
typeConverter(widestIntSupported
);
57 // Use `llvm.bitcast` as the bridge so that we can use preserve the
58 // function argument and return types of the processed function.
59 // TODO: Consider extending `arith.bitcast` to support scalar-to-1D-vector
60 // casts (and vice versa) and using it insted of `llvm.bitcast`.
61 auto addBitcast
= [](OpBuilder
&builder
, Type type
, ValueRange inputs
,
62 Location loc
) -> Value
{
63 auto cast
= builder
.create
<LLVM::BitcastOp
>(loc
, type
, inputs
);
64 return cast
->getResult(0);
66 typeConverter
.addSourceMaterialization(addBitcast
);
67 typeConverter
.addTargetMaterialization(addBitcast
);
69 ConversionTarget
target(*ctx
);
71 .addDynamicallyLegalDialect
<arith::ArithDialect
, vector::VectorDialect
>(
72 [&typeConverter
](Operation
*op
) {
73 return typeConverter
.isLegal(op
);
76 RewritePatternSet
patterns(ctx
);
77 arith::populateArithWideIntEmulationPatterns(typeConverter
, patterns
);
78 if (failed(applyPartialConversion(op
, target
, std::move(patterns
))))
82 Option
<std::string
> testFunctionPrefix
{
83 *this, "function-prefix",
84 llvm::cl::desc("Prefix of functions to run the emulation pass on"),
85 llvm::cl::init("emulate_")};
86 Option
<unsigned> widestIntSupported
{
87 *this, "widest-int-supported",
88 llvm::cl::desc("Maximum integer bit width supported by the target"),
93 namespace mlir::test
{
94 void registerTestArithEmulateWideIntPass() {
95 PassRegistration
<TestEmulateWideIntPass
>();
97 } // namespace mlir::test