1 //===- ConvertShapeConstraints.cpp - Conversion of shape constraints ------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
11 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
12 #include "mlir/Dialect/SCF/IR/SCF.h"
13 #include "mlir/Dialect/Shape/IR/Shape.h"
14 #include "mlir/Dialect/Tensor/IR/Tensor.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Pass/PassRegistry.h"
18 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 #define GEN_PASS_DEF_CONVERTSHAPECONSTRAINTS
22 #include "mlir/Conversion/Passes.h.inc"
28 #include "ShapeToStandard.cpp.inc"
32 class ConvertCstrRequireOp
: public OpRewritePattern
<shape::CstrRequireOp
> {
34 using OpRewritePattern::OpRewritePattern
;
35 LogicalResult
matchAndRewrite(shape::CstrRequireOp op
,
36 PatternRewriter
&rewriter
) const override
{
37 rewriter
.create
<cf::AssertOp
>(op
.getLoc(), op
.getPred(), op
.getMsgAttr());
38 rewriter
.replaceOpWithNewOp
<shape::ConstWitnessOp
>(op
, true);
44 void mlir::populateConvertShapeConstraintsConversionPatterns(
45 RewritePatternSet
&patterns
) {
46 patterns
.add
<CstrBroadcastableToRequire
>(patterns
.getContext());
47 patterns
.add
<CstrEqToRequire
>(patterns
.getContext());
48 patterns
.add
<ConvertCstrRequireOp
>(patterns
.getContext());
52 // This pass eliminates shape constraints from the program, converting them to
53 // eager (side-effecting) error handling code. After eager error handling code
54 // is emitted, witnesses are satisfied, so they are replace with
55 // `shape.const_witness true`.
56 class ConvertShapeConstraints
57 : public impl::ConvertShapeConstraintsBase
<ConvertShapeConstraints
> {
58 void runOnOperation() override
{
59 auto *func
= getOperation();
60 auto *context
= &getContext();
62 RewritePatternSet
patterns(context
);
63 populateConvertShapeConstraintsConversionPatterns(patterns
);
65 if (failed(applyPatternsAndFoldGreedily(func
, std::move(patterns
))))
66 return signalPassFailure();
71 std::unique_ptr
<Pass
> mlir::createConvertShapeConstraintsPass() {
72 return std::make_unique
<ConvertShapeConstraints
>();