1 //===- Pattern.h - SPIRV Common Conversion Patterns -----------------------===//
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //===----------------------------------------------------------------------===//
9 #ifndef MLIR_CONVERSION_SPIRVCOMMON_PATTERN_H
10 #define MLIR_CONVERSION_SPIRVCOMMON_PATTERN_H
12 #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
13 #include "mlir/IR/TypeUtilities.h"
14 #include "mlir/Transforms/DialectConversion.h"
15 #include "llvm/Support/FormatVariadic.h"
20 /// Converts elementwise unary, binary and ternary standard operations to SPIR-V
22 template <typename Op
, typename SPIRVOp
>
23 struct ElementwiseOpPattern
: public OpConversionPattern
<Op
> {
24 using OpConversionPattern
<Op
>::OpConversionPattern
;
27 matchAndRewrite(Op op
, typename
Op::Adaptor adaptor
,
28 ConversionPatternRewriter
&rewriter
) const override
{
29 assert(adaptor
.getOperands().size() <= 3);
30 Type dstType
= this->getTypeConverter()->convertType(op
.getType());
32 return rewriter
.notifyMatchFailure(
34 llvm::formatv("failed to convert type {0} for SPIR-V", op
.getType()));
37 if (SPIRVOp::template hasTrait
<OpTrait::spirv::UnsignedOp
>() &&
38 !getElementTypeOrSelf(op
.getType()).isIndex() &&
39 dstType
!= op
.getType()) {
41 return op
.emitError("bitwidth emulation is not implemented yet on "
42 "unsigned op pattern version");
44 rewriter
.template replaceOpWithNewOp
<SPIRVOp
>(op
, dstType
,
45 adaptor
.getOperands());
53 #endif // MLIR_CONVERSION_SPIRVCOMMON_PATTERN_H