1 // Copyright (c) 2020 André Perez Maselco
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
15 #include "source/fuzz/transformation_adjust_branch_weights.h"
17 #include "source/fuzz/fuzzer_util.h"
18 #include "source/fuzz/instruction_descriptor.h"
// Absolute operand positions of the two optional branch-weight literals of an
// OpBranchConditional instruction. Per the SPIR-V specification the operands
// preceding them are the condition (0), the true label (1) and the false
// label (2); the weights, when present, come next in true/false order.
constexpr uint32_t kBranchWeightForTrueLabelIndex = 3;
constexpr uint32_t kBranchWeightForFalseLabelIndex = 4;
30 TransformationAdjustBranchWeights::TransformationAdjustBranchWeights(
31 protobufs::TransformationAdjustBranchWeights message
)
32 : message_(std::move(message
)) {}
34 TransformationAdjustBranchWeights::TransformationAdjustBranchWeights(
35 const protobufs::InstructionDescriptor
& instruction_descriptor
,
36 const std::pair
<uint32_t, uint32_t>& branch_weights
) {
37 *message_
.mutable_instruction_descriptor() = instruction_descriptor
;
38 message_
.mutable_branch_weights()->set_first(branch_weights
.first
);
39 message_
.mutable_branch_weights()->set_second(branch_weights
.second
);
42 bool TransformationAdjustBranchWeights::IsApplicable(
43 opt::IRContext
* ir_context
, const TransformationContext
& /*unused*/) const {
45 FindInstruction(message_
.instruction_descriptor(), ir_context
);
46 if (instruction
== nullptr) {
50 spv::Op opcode
= static_cast<spv::Op
>(
51 message_
.instruction_descriptor().target_instruction_opcode());
53 assert(instruction
->opcode() == opcode
&&
54 "The located instruction must have the same opcode as in the "
57 // Must be an OpBranchConditional instruction.
58 if (opcode
!= spv::Op::OpBranchConditional
) {
62 assert((message_
.branch_weights().first() != 0 ||
63 message_
.branch_weights().second() != 0) &&
64 "At least one weight must be non-zero.");
66 assert(message_
.branch_weights().first() <=
67 UINT32_MAX
- message_
.branch_weights().second() &&
68 "The sum of the two weights must not be greater than UINT32_MAX.");
73 void TransformationAdjustBranchWeights::Apply(
74 opt::IRContext
* ir_context
, TransformationContext
* /*unused*/) const {
76 FindInstruction(message_
.instruction_descriptor(), ir_context
);
77 if (instruction
->HasBranchWeights()) {
78 instruction
->SetOperand(kBranchWeightForTrueLabelIndex
,
79 {message_
.branch_weights().first()});
80 instruction
->SetOperand(kBranchWeightForFalseLabelIndex
,
81 {message_
.branch_weights().second()});
83 instruction
->AddOperand({SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER
,
84 {message_
.branch_weights().first()}});
85 instruction
->AddOperand({SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER
,
86 {message_
.branch_weights().second()}});
90 protobufs::Transformation
TransformationAdjustBranchWeights::ToMessage() const {
91 protobufs::Transformation result
;
92 *result
.mutable_adjust_branch_weights() = message_
;
96 std::unordered_set
<uint32_t> TransformationAdjustBranchWeights::GetFreshIds()
98 return std::unordered_set
<uint32_t>();
102 } // namespace spvtools