1 //===- ArmNeon2dToIntr.cpp - convert Arm Neon 2d ops to intrinsics --------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
#include "mlir/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.h"

#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include <memory>
#include <utility>

// The generated pass definitions (impl::ConvertArmNeon2dToIntrBase) are
// expected to live in namespace mlir.
namespace mlir {
#define GEN_PASS_DEF_CONVERTARMNEON2DTOINTR
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::arm_neon;
28 class Sdot2dLoweringPattern
: public OpRewritePattern
<Sdot2dOp
> {
30 using OpRewritePattern::OpRewritePattern
;
32 /// Convert to 1-dimensional vector type to match the requirements of
33 /// arm.neon.intr.sdot
34 LogicalResult
matchAndRewrite(Sdot2dOp op
,
35 PatternRewriter
&rewriter
) const override
{
36 Type elemType
= cast
<VectorType
>(op
.getB().getType()).getElementType();
37 int length
= cast
<VectorType
>(op
.getB().getType()).getShape()[0] *
38 Sdot2dOp::kReductionSize
;
39 VectorType flattenedVectorType
= VectorType::get({length
}, elemType
);
40 Value b2d
= op
.getB();
41 Value c2d
= op
.getC();
42 Location loc
= op
.getLoc();
44 rewriter
.create
<vector::ShapeCastOp
>(loc
, flattenedVectorType
, b2d
);
46 rewriter
.create
<vector::ShapeCastOp
>(loc
, flattenedVectorType
, c2d
);
47 Value newOp
= rewriter
.create
<SdotOp
>(loc
, op
.getRes().getType(), op
.getA(),
49 rewriter
.replaceOp(op
, {newOp
});
54 class ConvertArmNeon2dToIntr
55 : public impl::ConvertArmNeon2dToIntrBase
<ConvertArmNeon2dToIntr
> {
56 void runOnOperation() override
{
57 auto *context
= &getContext();
59 RewritePatternSet
patterns(context
);
60 populateConvertArmNeon2dToIntrPatterns(patterns
);
63 applyPatternsAndFoldGreedily(getOperation(), std::move(patterns
))))
64 return signalPassFailure();
70 void mlir::populateConvertArmNeon2dToIntrPatterns(RewritePatternSet
&patterns
) {
71 patterns
.add
<Sdot2dLoweringPattern
>(patterns
.getContext());
74 std::unique_ptr
<Pass
> mlir::createConvertArmNeon2dToIntrPass() {
75 return std::make_unique
<ConvertArmNeon2dToIntr
>();