[mlir][ArithToAMDGPU] Add option for saturating truncation to fp8 (#74153)
commit750e90e4403df23d6b271afb90e6b4d463739965
authorKrzysztof Drewniak <Krzysztof.Drewniak@amd.com>
Tue, 23 Jan 2024 22:52:21 +0000 (23 16:52 -0600)
committerGitHub <noreply@github.com>
Tue, 23 Jan 2024 22:52:21 +0000 (23 16:52 -0600)
tree9c89cd12c1c4a803eb02b08c747b7e73cad41687
parent575568de4166bf69e0a5bc68978580afbe936878
[mlir][ArithToAMDGPU] Add option for saturating truncation to fp8 (#74153)

Many machine-learning applications (and most software written at AMD)
expect the operation that truncates floats to 8-bit floats to be
saturatinng. That is, they expect `truncf 256.0 : f32 to f8E4M3FNUZ` to
yield `240.0`, not `NaN`, and similarly for negative numbers. However,
the underlying hardware instruction that can be used for this truncation
implements overflow-to-NaN semantics.

To enable handling this usecase, we add the saturate-fp8-truncf option
to ArithToAMDGPU (off by default), which causes the requisite clamping
code to be emitted. Said clamping code ensures that Inf and NaN are
passed through exactly (and thus trancate to NaN).

Per review feedback, this commit efactors
createScalarOrSplatConstant() to the Arith dialect utilities and uses
it in this code. It also fixes naming of existing patterns and
switches from vector.extractelement/insertelement to
vector.extract/insert.
mlir/include/mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Dialect/Arith/Utils/Utils.h
mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
mlir/lib/Conversion/ArithToAMDGPU/CMakeLists.txt
mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
mlir/lib/Dialect/Arith/Utils/Utils.cpp
mlir/test/Conversion/ArithToAMDGPU/8-bit-float-saturation.mlir [new file with mode: 0644]
mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir