Roll external/abseil_cpp/ 8f739d18b..917bfee46 (2 commits) (#5887)
[KhronosGroup/SPIRV-Tools.git] / source / fuzz / transformation_add_function.cpp
blob1f61ede725dacab5544b94fb0733730f09b7298f
1 // Copyright (c) 2019 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
15 #include "source/fuzz/transformation_add_function.h"
17 #include "source/fuzz/fuzzer_util.h"
18 #include "source/fuzz/instruction_message.h"
20 namespace spvtools {
21 namespace fuzz {
23 TransformationAddFunction::TransformationAddFunction(
24 protobufs::TransformationAddFunction message)
25 : message_(std::move(message)) {}
27 TransformationAddFunction::TransformationAddFunction(
28 const std::vector<protobufs::Instruction>& instructions) {
29 for (auto& instruction : instructions) {
30 *message_.add_instruction() = instruction;
32 message_.set_is_livesafe(false);
35 TransformationAddFunction::TransformationAddFunction(
36 const std::vector<protobufs::Instruction>& instructions,
37 uint32_t loop_limiter_variable_id, uint32_t loop_limit_constant_id,
38 const std::vector<protobufs::LoopLimiterInfo>& loop_limiters,
39 uint32_t kill_unreachable_return_value_id,
40 const std::vector<protobufs::AccessChainClampingInfo>&
41 access_chain_clampers) {
42 for (auto& instruction : instructions) {
43 *message_.add_instruction() = instruction;
45 message_.set_is_livesafe(true);
46 message_.set_loop_limiter_variable_id(loop_limiter_variable_id);
47 message_.set_loop_limit_constant_id(loop_limit_constant_id);
48 for (auto& loop_limiter : loop_limiters) {
49 *message_.add_loop_limiter_info() = loop_limiter;
51 message_.set_kill_unreachable_return_value_id(
52 kill_unreachable_return_value_id);
53 for (auto& access_clamper : access_chain_clampers) {
54 *message_.add_access_chain_clamping_info() = access_clamper;
58 bool TransformationAddFunction::IsApplicable(
59 opt::IRContext* ir_context,
60 const TransformationContext& transformation_context) const {
61 // This transformation may use a lot of ids, all of which need to be fresh
62 // and distinct. This set tracks them.
63 std::set<uint32_t> ids_used_by_this_transformation;
65 // Ensure that all result ids in the new function are fresh and distinct.
66 for (auto& instruction : message_.instruction()) {
67 if (instruction.result_id()) {
68 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
69 instruction.result_id(), ir_context,
70 &ids_used_by_this_transformation)) {
71 return false;
76 if (message_.is_livesafe()) {
77 // Ensure that all ids provided for making the function livesafe are fresh
78 // and distinct.
79 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
80 message_.loop_limiter_variable_id(), ir_context,
81 &ids_used_by_this_transformation)) {
82 return false;
84 for (auto& loop_limiter_info : message_.loop_limiter_info()) {
85 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
86 loop_limiter_info.load_id(), ir_context,
87 &ids_used_by_this_transformation)) {
88 return false;
90 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
91 loop_limiter_info.increment_id(), ir_context,
92 &ids_used_by_this_transformation)) {
93 return false;
95 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
96 loop_limiter_info.compare_id(), ir_context,
97 &ids_used_by_this_transformation)) {
98 return false;
100 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
101 loop_limiter_info.logical_op_id(), ir_context,
102 &ids_used_by_this_transformation)) {
103 return false;
106 for (auto& access_chain_clamping_info :
107 message_.access_chain_clamping_info()) {
108 for (auto& pair : access_chain_clamping_info.compare_and_select_ids()) {
109 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
110 pair.first(), ir_context, &ids_used_by_this_transformation)) {
111 return false;
113 if (!CheckIdIsFreshAndNotUsedByThisTransformation(
114 pair.second(), ir_context, &ids_used_by_this_transformation)) {
115 return false;
121 // Because checking all the conditions for a function to be valid is a big
122 // job that the SPIR-V validator can already do, a "try it and see" approach
123 // is taken here.
125 // We first clone the current module, so that we can try adding the new
126 // function without risking wrecking |ir_context|.
127 auto cloned_module = fuzzerutil::CloneIRContext(ir_context);
129 // We try to add a function to the cloned module, which may fail if
130 // |message_.instruction| is not sufficiently well-formed.
131 if (!TryToAddFunction(cloned_module.get())) {
132 return false;
135 // Check whether the cloned module is still valid after adding the function.
136 // If it is not, the transformation is not applicable.
137 if (!fuzzerutil::IsValid(cloned_module.get(),
138 transformation_context.GetValidatorOptions(),
139 fuzzerutil::kSilentMessageConsumer)) {
140 return false;
143 if (message_.is_livesafe()) {
144 if (!TryToMakeFunctionLivesafe(cloned_module.get(),
145 transformation_context)) {
146 return false;
148 // After making the function livesafe, we check validity of the module
149 // again. This is because the turning of OpKill, OpUnreachable and OpReturn
150 // instructions into branches changes control flow graph reachability, which
151 // has the potential to make the module invalid when it was otherwise valid.
152 // It is simpler to rely on the validator to guard against this than to
153 // consider all scenarios when making a function livesafe.
154 if (!fuzzerutil::IsValid(cloned_module.get(),
155 transformation_context.GetValidatorOptions(),
156 fuzzerutil::kSilentMessageConsumer)) {
157 return false;
160 return true;
163 void TransformationAddFunction::Apply(
164 opt::IRContext* ir_context,
165 TransformationContext* transformation_context) const {
166 // Add the function to the module. As the transformation is applicable, this
167 // should succeed.
168 bool success = TryToAddFunction(ir_context);
169 assert(success && "The function should be successfully added.");
170 (void)(success); // Keep release builds happy (otherwise they may complain
171 // that |success| is not used).
173 if (message_.is_livesafe()) {
174 // Make the function livesafe, which also should succeed.
175 success = TryToMakeFunctionLivesafe(ir_context, *transformation_context);
176 assert(success && "It should be possible to make the function livesafe.");
177 (void)(success); // Keep release builds happy.
179 ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
181 assert(spv::Op(message_.instruction(0).opcode()) == spv::Op::OpFunction &&
182 "The first instruction of an 'add function' transformation must be "
183 "OpFunction.");
185 if (message_.is_livesafe()) {
186 // Inform the fact manager that the function is livesafe.
187 transformation_context->GetFactManager()->AddFactFunctionIsLivesafe(
188 message_.instruction(0).result_id());
189 } else {
190 // Inform the fact manager that all blocks in the function are dead.
191 for (auto& inst : message_.instruction()) {
192 if (spv::Op(inst.opcode()) == spv::Op::OpLabel) {
193 transformation_context->GetFactManager()->AddFactBlockIsDead(
194 inst.result_id());
199 // Record the fact that all pointer parameters and variables declared in the
200 // function should be regarded as having irrelevant values. This allows other
201 // passes to store arbitrarily to such variables, and to pass them freely as
202 // parameters to other functions knowing that it is OK if they get
203 // over-written.
204 for (auto& instruction : message_.instruction()) {
205 switch (spv::Op(instruction.opcode())) {
206 case spv::Op::OpFunctionParameter:
207 if (ir_context->get_def_use_mgr()
208 ->GetDef(instruction.result_type_id())
209 ->opcode() == spv::Op::OpTypePointer) {
210 transformation_context->GetFactManager()
211 ->AddFactValueOfPointeeIsIrrelevant(instruction.result_id());
213 break;
214 case spv::Op::OpVariable:
215 transformation_context->GetFactManager()
216 ->AddFactValueOfPointeeIsIrrelevant(instruction.result_id());
217 break;
218 default:
219 break;
224 protobufs::Transformation TransformationAddFunction::ToMessage() const {
225 protobufs::Transformation result;
226 *result.mutable_add_function() = message_;
227 return result;
230 bool TransformationAddFunction::TryToAddFunction(
231 opt::IRContext* ir_context) const {
232 // This function returns false if |message_.instruction| was not well-formed
233 // enough to actually create a function and add it to |ir_context|.
235 // A function must have at least some instructions.
236 if (message_.instruction().empty()) {
237 return false;
240 // A function must start with OpFunction.
241 auto function_begin = message_.instruction(0);
242 if (spv::Op(function_begin.opcode()) != spv::Op::OpFunction) {
243 return false;
246 // Make a function, headed by the OpFunction instruction.
247 std::unique_ptr<opt::Function> new_function = MakeUnique<opt::Function>(
248 InstructionFromMessage(ir_context, function_begin));
250 // Keeps track of which instruction protobuf message we are currently
251 // considering.
252 uint32_t instruction_index = 1;
253 const auto num_instructions =
254 static_cast<uint32_t>(message_.instruction().size());
256 // Iterate through all function parameter instructions, adding parameters to
257 // the new function.
258 while (instruction_index < num_instructions &&
259 spv::Op(message_.instruction(instruction_index).opcode()) ==
260 spv::Op::OpFunctionParameter) {
261 new_function->AddParameter(InstructionFromMessage(
262 ir_context, message_.instruction(instruction_index)));
263 instruction_index++;
266 // After the parameters, there needs to be a label.
267 if (instruction_index == num_instructions ||
268 spv::Op(message_.instruction(instruction_index).opcode()) !=
269 spv::Op::OpLabel) {
270 return false;
273 // Iterate through the instructions block by block until the end of the
274 // function is reached.
275 while (instruction_index < num_instructions &&
276 spv::Op(message_.instruction(instruction_index).opcode()) !=
277 spv::Op::OpFunctionEnd) {
278 // Invariant: we should always be at a label instruction at this point.
279 assert(spv::Op(message_.instruction(instruction_index).opcode()) ==
280 spv::Op::OpLabel);
282 // Make a basic block using the label instruction.
283 std::unique_ptr<opt::BasicBlock> block =
284 MakeUnique<opt::BasicBlock>(InstructionFromMessage(
285 ir_context, message_.instruction(instruction_index)));
287 // Consider successive instructions until we hit another label or the end
288 // of the function, adding each such instruction to the block.
289 instruction_index++;
290 while (instruction_index < num_instructions &&
291 spv::Op(message_.instruction(instruction_index).opcode()) !=
292 spv::Op::OpFunctionEnd &&
293 spv::Op(message_.instruction(instruction_index).opcode()) !=
294 spv::Op::OpLabel) {
295 block->AddInstruction(InstructionFromMessage(
296 ir_context, message_.instruction(instruction_index)));
297 instruction_index++;
299 // Add the block to the new function.
300 new_function->AddBasicBlock(std::move(block));
302 // Having considered all the blocks, we should be at the last instruction and
303 // it needs to be OpFunctionEnd.
304 if (instruction_index != num_instructions - 1 ||
305 spv::Op(message_.instruction(instruction_index).opcode()) !=
306 spv::Op::OpFunctionEnd) {
307 return false;
309 // Set the function's final instruction, add the function to the module and
310 // report success.
311 new_function->SetFunctionEnd(InstructionFromMessage(
312 ir_context, message_.instruction(instruction_index)));
313 ir_context->AddFunction(std::move(new_function));
315 ir_context->InvalidateAnalysesExceptFor(opt::IRContext::kAnalysisNone);
317 return true;
320 bool TransformationAddFunction::TryToMakeFunctionLivesafe(
321 opt::IRContext* ir_context,
322 const TransformationContext& transformation_context) const {
323 assert(message_.is_livesafe() && "Precondition: is_livesafe must hold.");
325 // Get a pointer to the added function.
326 opt::Function* added_function = nullptr;
327 for (auto& function : *ir_context->module()) {
328 if (function.result_id() == message_.instruction(0).result_id()) {
329 added_function = &function;
330 break;
333 assert(added_function && "The added function should have been found.");
335 if (!TryToAddLoopLimiters(ir_context, added_function)) {
336 // Adding loop limiters did not work; bail out.
337 return false;
340 // Consider all the instructions in the function, and:
341 // - attempt to replace OpKill and OpUnreachable with return instructions
342 // - attempt to clamp access chains to be within bounds
343 // - check that OpFunctionCall instructions are only to livesafe functions
344 for (auto& block : *added_function) {
345 for (auto& inst : block) {
346 switch (inst.opcode()) {
347 case spv::Op::OpKill:
348 case spv::Op::OpUnreachable:
349 if (!TryToTurnKillOrUnreachableIntoReturn(ir_context, added_function,
350 &inst)) {
351 return false;
353 break;
354 case spv::Op::OpAccessChain:
355 case spv::Op::OpInBoundsAccessChain:
356 if (!TryToClampAccessChainIndices(ir_context, &inst)) {
357 return false;
359 break;
360 case spv::Op::OpFunctionCall:
361 // A livesafe function my only call other livesafe functions.
362 if (!transformation_context.GetFactManager()->FunctionIsLivesafe(
363 inst.GetSingleWordInOperand(0))) {
364 return false;
366 default:
367 break;
371 return true;
374 uint32_t TransformationAddFunction::GetBackEdgeBlockId(
375 opt::IRContext* ir_context, uint32_t loop_header_block_id) {
376 const auto* loop_header_block =
377 ir_context->cfg()->block(loop_header_block_id);
378 assert(loop_header_block && "|loop_header_block_id| is invalid");
380 for (auto pred : ir_context->cfg()->preds(loop_header_block_id)) {
381 if (ir_context->GetDominatorAnalysis(loop_header_block->GetParent())
382 ->Dominates(loop_header_block_id, pred)) {
383 return pred;
387 return 0;
// Adds loop-limiting instrumentation to every loop of |added_function|.
//
// A single counter variable, |message_.loop_limiter_variable_id|, is declared
// (initialized to 0) at the start of the function's entry block and is shared
// by all of the function's loops. For each loop whose back-edge block is
// reachable, that block is rewritten so that each traversal of the back edge
// does the following:
// - loads the counter and stores back the counter plus one;
// - compares the loaded (pre-increment) value with the loop limit constant
//   |message_.loop_limit_constant_id|;
// - once that value is >= the limit, branches to the loop's merge block
//   instead of the loop header.
//
// Returns true on success, including when the function has no loops.
// Returns false if any of the following holds:
// - the loop limit is not a 32-bit integer OpConstant;
// - the module lacks the 32-bit unsigned integer type, its constants 0 and 1,
//   a Function-storage pointer to it, or the bool type;
// - a loop header does not dominate its merge block;
// - a loop header has no matching LoopLimiterInfo;
// - the supplied OpPhi ids are unsuitable for a new edge to a merge block.
// The counter variable is inserted before the loops are examined, so
// |ir_context| may already be modified when false is returned. IsApplicable
// only reaches this (via TryToMakeFunctionLivesafe) on a cloned module.
bool TransformationAddFunction::TryToAddLoopLimiters(
    opt::IRContext* ir_context, opt::Function* added_function) const {
  // Collect up all the loop headers so that we can subsequently add loop
  // limiting logic.
  std::vector<opt::BasicBlock*> loop_headers;
  for (auto& block : *added_function) {
    if (block.IsLoopHeader()) {
      loop_headers.push_back(&block);
    }
  }

  if (loop_headers.empty()) {
    // There are no loops, so no need to add any loop limiters.
    return true;
  }

  // Check that the module contains appropriate ingredients for declaring and
  // manipulating a loop limiter.

  auto loop_limit_constant_id_instr =
      ir_context->get_def_use_mgr()->GetDef(message_.loop_limit_constant_id());
  if (!loop_limit_constant_id_instr ||
      loop_limit_constant_id_instr->opcode() != spv::Op::OpConstant) {
    // The loop limit constant id instruction must exist and have an
    // appropriate opcode.
    return false;
  }

  auto loop_limit_type = ir_context->get_def_use_mgr()->GetDef(
      loop_limit_constant_id_instr->type_id());
  if (loop_limit_type->opcode() != spv::Op::OpTypeInt ||
      loop_limit_type->GetSingleWordInOperand(0) != 32) {
    // The type of the loop limit constant must be 32-bit integer. It
    // doesn't actually matter whether the integer is signed or not.
    return false;
  }

  // Find the id of the "unsigned int" type.
  opt::analysis::Integer unsigned_int_type(32, false);
  uint32_t unsigned_int_type_id =
      ir_context->get_type_mgr()->GetId(&unsigned_int_type);
  if (!unsigned_int_type_id) {
    // Unsigned int is not available; we need this type in order to add loop
    // limiters.
    return false;
  }
  auto registered_unsigned_int_type =
      ir_context->get_type_mgr()->GetRegisteredType(&unsigned_int_type);

  // Look for 0 of type unsigned int.
  opt::analysis::IntConstant zero(registered_unsigned_int_type->AsInteger(),
                                  {0});
  auto registered_zero = ir_context->get_constant_mgr()->FindConstant(&zero);
  if (!registered_zero) {
    // We need 0 in order to be able to initialize loop limiters.
    return false;
  }
  uint32_t zero_id = ir_context->get_constant_mgr()
                         ->GetDefiningInstruction(registered_zero)
                         ->result_id();

  // Look for 1 of type unsigned int.
  opt::analysis::IntConstant one(registered_unsigned_int_type->AsInteger(),
                                 {1});
  auto registered_one = ir_context->get_constant_mgr()->FindConstant(&one);
  if (!registered_one) {
    // We need 1 in order to be able to increment loop limiters.
    return false;
  }
  uint32_t one_id = ir_context->get_constant_mgr()
                        ->GetDefiningInstruction(registered_one)
                        ->result_id();

  // Look for pointer-to-unsigned int type.
  opt::analysis::Pointer pointer_to_unsigned_int_type(
      registered_unsigned_int_type, spv::StorageClass::Function);
  uint32_t pointer_to_unsigned_int_type_id =
      ir_context->get_type_mgr()->GetId(&pointer_to_unsigned_int_type);
  if (!pointer_to_unsigned_int_type_id) {
    // We need pointer-to-unsigned int in order to declare the loop limiter
    // variable.
    return false;
  }

  // Look for bool type.
  opt::analysis::Bool bool_type;
  uint32_t bool_type_id = ir_context->get_type_mgr()->GetId(&bool_type);
  if (!bool_type_id) {
    // We need bool in order to compare the loop limiter's value with the loop
    // limit constant.
    return false;
  }

  // Declare the loop limiter variable at the start of the function's entry
  // block, via an instruction of the form:
  //   %loop_limiter_var = OpVariable %ptr_to_uint Function %zero
  // From here on |ir_context| has been modified, even if a later check fails.
  added_function->begin()->begin()->InsertBefore(MakeUnique<opt::Instruction>(
      ir_context, spv::Op::OpVariable, pointer_to_unsigned_int_type_id,
      message_.loop_limiter_variable_id(),
      opt::Instruction::OperandList({{SPV_OPERAND_TYPE_STORAGE_CLASS,
                                      {uint32_t(spv::StorageClass::Function)}},
                                     {SPV_OPERAND_TYPE_ID, {zero_id}}})));
  // Update the module's id bound since we have added the loop limiter
  // variable id.
  fuzzerutil::UpdateModuleIdBound(ir_context,
                                  message_.loop_limiter_variable_id());

  // Consider each loop in turn.
  for (auto loop_header : loop_headers) {
    // Look for the loop's back-edge block. This is a predecessor of the loop
    // header that is dominated by the loop header.
    const auto back_edge_block_id =
        GetBackEdgeBlockId(ir_context, loop_header->id());
    if (!back_edge_block_id) {
      // The loop's back-edge block must be unreachable. This means that the
      // loop cannot iterate, so there is no need to make it livesafe; we can
      // move on from this loop.
      continue;
    }

    // If the loop's merge block is unreachable, then there are no constraints
    // on where the merge block appears in relation to the blocks of the loop.
    // This means we need to be careful when adding a branch from the back-edge
    // block to the merge block: the branch might make the loop merge reachable,
    // and it might then be dominated by the loop header and possibly by other
    // blocks in the loop. Since a block needs to appear before those blocks it
    // strictly dominates, this could make the module invalid. To avoid this
    // problem we bail out in the case where the loop header does not dominate
    // the loop merge.
    //
    // NOTE(review): earlier iterations of this loop may already have rewritten
    // back-edge terminators, and nothing here invalidates the CFG/dominator
    // analyses in between. Confirm that the cached analyses cannot give a
    // stale answer here. The validator re-check in IsApplicable is the
    // backstop.
    if (!ir_context->GetDominatorAnalysis(added_function)
             ->Dominates(loop_header->id(), loop_header->MergeBlockId())) {
      return false;
    }

    // Go through the sequence of loop limiter infos and find the one
    // corresponding to this loop.
    bool found = false;
    protobufs::LoopLimiterInfo loop_limiter_info;
    for (auto& info : message_.loop_limiter_info()) {
      if (info.loop_header_id() == loop_header->id()) {
        loop_limiter_info = info;
        found = true;
        break;
      }
    }
    if (!found) {
      // We don't have loop limiter info for this loop header.
      return false;
    }

    // The back-edge block either has the form:
    //
    // (1)
    //
    // %l = OpLabel
    //      ... instructions ...
    //      OpBranch %loop_header
    //
    // (2)
    //
    // %l = OpLabel
    //      ... instructions ...
    //      OpBranchConditional %c %loop_header %loop_merge
    //
    // (3)
    //
    // %l = OpLabel
    //      ... instructions ...
    //      OpBranchConditional %c %loop_merge %loop_header
    //
    // We turn these into the following:
    //
    // (1)
    //
    //  %l = OpLabel
    //       ... instructions ...
    // %t1 = OpLoad %uint32 %loop_limiter
    // %t2 = OpIAdd %uint32 %t1 %one
    //       OpStore %loop_limiter %t2
    // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
    //       OpBranchConditional %t3 %loop_merge %loop_header
    //
    // (2)
    //
    //  %l = OpLabel
    //       ... instructions ...
    // %t1 = OpLoad %uint32 %loop_limiter
    // %t2 = OpIAdd %uint32 %t1 %one
    //       OpStore %loop_limiter %t2
    // %t3 = OpULessThan %bool %t1 %loop_limit
    // %t4 = OpLogicalAnd %bool %c %t3
    //       OpBranchConditional %t4 %loop_header %loop_merge
    //
    // (3)
    //
    //  %l = OpLabel
    //       ... instructions ...
    // %t1 = OpLoad %uint32 %loop_limiter
    // %t2 = OpIAdd %uint32 %t1 %one
    //       OpStore %loop_limiter %t2
    // %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
    // %t4 = OpLogicalOr %bool %c %t3
    //       OpBranchConditional %t4 %loop_merge %loop_header

    auto back_edge_block = ir_context->cfg()->block(back_edge_block_id);
    auto back_edge_block_terminator = back_edge_block->terminator();
    // True exactly when the "true" target of the final conditional branch is
    // the merge block (forms (1) and (3) above).
    bool compare_using_greater_than_equal;
    if (back_edge_block_terminator->opcode() == spv::Op::OpBranch) {
      compare_using_greater_than_equal = true;
    } else {
      assert(back_edge_block_terminator->opcode() ==
             spv::Op::OpBranchConditional);
      assert(((back_edge_block_terminator->GetSingleWordInOperand(1) ==
                   loop_header->id() &&
               back_edge_block_terminator->GetSingleWordInOperand(2) ==
                   loop_header->MergeBlockId()) ||
              (back_edge_block_terminator->GetSingleWordInOperand(2) ==
                   loop_header->id() &&
               back_edge_block_terminator->GetSingleWordInOperand(1) ==
                   loop_header->MergeBlockId())) &&
             "A back edge edge block must branch to"
             " either the loop header or merge");
      compare_using_greater_than_equal =
          back_edge_block_terminator->GetSingleWordInOperand(1) ==
          loop_header->MergeBlockId();
    }

    std::vector<std::unique_ptr<opt::Instruction>> new_instructions;

    // Add a load from the loop limiter variable, of the form:
    //   %t1 = OpLoad %uint32 %loop_limiter
    new_instructions.push_back(MakeUnique<opt::Instruction>(
        ir_context, spv::Op::OpLoad, unsigned_int_type_id,
        loop_limiter_info.load_id(),
        opt::Instruction::OperandList(
            {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}}})));

    // Increment the loaded value:
    //   %t2 = OpIAdd %uint32 %t1 %one
    new_instructions.push_back(MakeUnique<opt::Instruction>(
        ir_context, spv::Op::OpIAdd, unsigned_int_type_id,
        loop_limiter_info.increment_id(),
        opt::Instruction::OperandList(
            {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
             {SPV_OPERAND_TYPE_ID, {one_id}}})));

    // Store the incremented value back to the loop limiter variable:
    //   OpStore %loop_limiter %t2
    new_instructions.push_back(MakeUnique<opt::Instruction>(
        ir_context, spv::Op::OpStore, 0, 0,
        opt::Instruction::OperandList(
            {{SPV_OPERAND_TYPE_ID, {message_.loop_limiter_variable_id()}},
             {SPV_OPERAND_TYPE_ID, {loop_limiter_info.increment_id()}}})));

    // Compare the loaded value with the loop limit; either:
    //   %t3 = OpUGreaterThanEqual %bool %t1 %loop_limit
    // or
    //   %t3 = OpULessThan %bool %t1 %loop_limit
    new_instructions.push_back(MakeUnique<opt::Instruction>(
        ir_context,
        compare_using_greater_than_equal ? spv::Op::OpUGreaterThanEqual
                                         : spv::Op::OpULessThan,
        bool_type_id, loop_limiter_info.compare_id(),
        opt::Instruction::OperandList(
            {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.load_id()}},
             {SPV_OPERAND_TYPE_ID, {message_.loop_limit_constant_id()}}})));

    // For an existing conditional back edge, combine the original condition
    // with the comparison (OR for form (3), AND for form (2)).
    if (back_edge_block_terminator->opcode() == spv::Op::OpBranchConditional) {
      new_instructions.push_back(MakeUnique<opt::Instruction>(
          ir_context,
          compare_using_greater_than_equal ? spv::Op::OpLogicalOr
                                           : spv::Op::OpLogicalAnd,
          bool_type_id, loop_limiter_info.logical_op_id(),
          opt::Instruction::OperandList(
              {{SPV_OPERAND_TYPE_ID,
                {back_edge_block_terminator->GetSingleWordInOperand(0)}},
               {SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}}})));
    }

    // Add the new instructions at the end of the back edge block, before the
    // terminator and any loop merge instruction (as the back edge block can
    // be the loop header).
    if (back_edge_block->GetLoopMergeInst()) {
      back_edge_block->GetLoopMergeInst()->InsertBefore(
          std::move(new_instructions));
    } else {
      back_edge_block_terminator->InsertBefore(std::move(new_instructions));
    }

    if (back_edge_block_terminator->opcode() == spv::Op::OpBranchConditional) {
      // Forms (2)/(3): branch on the combined condition; targets are unchanged.
      back_edge_block_terminator->SetInOperand(
          0, {loop_limiter_info.logical_op_id()});
    } else {
      assert(back_edge_block_terminator->opcode() == spv::Op::OpBranch &&
             "Back-edge terminator must be OpBranch or OpBranchConditional");

      // Check that, if the merge block starts with OpPhi instructions, suitable
      // ids have been provided to give these instructions a value corresponding
      // to the new incoming edge from the back edge block.
      auto merge_block = ir_context->cfg()->block(loop_header->MergeBlockId());
      if (!fuzzerutil::PhiIdsOkForNewEdge(ir_context, back_edge_block,
                                          merge_block,
                                          loop_limiter_info.phi_id())) {
        return false;
      }

      // Augment OpPhi instructions at the loop merge with the given ids.
      uint32_t phi_index = 0;
      for (auto& inst : *merge_block) {
        if (inst.opcode() != spv::Op::OpPhi) {
          break;
        }
        assert(phi_index <
                   static_cast<uint32_t>(loop_limiter_info.phi_id().size()) &&
               "There should be at least one phi id per OpPhi instruction.");
        inst.AddOperand(
            {SPV_OPERAND_TYPE_ID, {loop_limiter_info.phi_id(phi_index)}});
        inst.AddOperand({SPV_OPERAND_TYPE_ID, {back_edge_block_id}});
        phi_index++;
      }

      // Add the new edge, by changing OpBranch to OpBranchConditional.
      back_edge_block_terminator->SetOpcode(spv::Op::OpBranchConditional);
      back_edge_block_terminator->SetInOperands(opt::Instruction::OperandList(
          {{SPV_OPERAND_TYPE_ID, {loop_limiter_info.compare_id()}},
           {SPV_OPERAND_TYPE_ID, {loop_header->MergeBlockId()}},
           {SPV_OPERAND_TYPE_ID, {loop_header->id()}}}));
    }

    // Update the module's id bound with respect to the various ids that
    // have been used for loop limiter manipulation.
    fuzzerutil::UpdateModuleIdBound(ir_context, loop_limiter_info.load_id());
    fuzzerutil::UpdateModuleIdBound(ir_context,
                                    loop_limiter_info.increment_id());
    fuzzerutil::UpdateModuleIdBound(ir_context, loop_limiter_info.compare_id());
    fuzzerutil::UpdateModuleIdBound(ir_context,
                                    loop_limiter_info.logical_op_id());
  }
  return true;
}
731 bool TransformationAddFunction::TryToTurnKillOrUnreachableIntoReturn(
732 opt::IRContext* ir_context, opt::Function* added_function,
733 opt::Instruction* kill_or_unreachable_inst) const {
734 assert((kill_or_unreachable_inst->opcode() == spv::Op::OpKill ||
735 kill_or_unreachable_inst->opcode() == spv::Op::OpUnreachable) &&
736 "Precondition: instruction must be OpKill or OpUnreachable.");
738 // Get the function's return type.
739 auto function_return_type_inst =
740 ir_context->get_def_use_mgr()->GetDef(added_function->type_id());
742 if (function_return_type_inst->opcode() == spv::Op::OpTypeVoid) {
743 // The function has void return type, so change this instruction to
744 // OpReturn.
745 kill_or_unreachable_inst->SetOpcode(spv::Op::OpReturn);
746 } else {
747 // The function has non-void return type, so change this instruction
748 // to OpReturnValue, using the value id provided with the
749 // transformation.
751 // We first check that the id, %id, provided with the transformation
752 // specifically to turn OpKill and OpUnreachable instructions into
753 // OpReturnValue %id has the same type as the function's return type.
754 if (ir_context->get_def_use_mgr()
755 ->GetDef(message_.kill_unreachable_return_value_id())
756 ->type_id() != function_return_type_inst->result_id()) {
757 return false;
759 kill_or_unreachable_inst->SetOpcode(spv::Op::OpReturnValue);
760 kill_or_unreachable_inst->SetInOperands(
761 {{SPV_OPERAND_TYPE_ID, {message_.kill_unreachable_return_value_id()}}});
763 return true;
766 bool TransformationAddFunction::TryToClampAccessChainIndices(
767 opt::IRContext* ir_context, opt::Instruction* access_chain_inst) const {
768 assert((access_chain_inst->opcode() == spv::Op::OpAccessChain ||
769 access_chain_inst->opcode() == spv::Op::OpInBoundsAccessChain) &&
770 "Precondition: instruction must be OpAccessChain or "
771 "OpInBoundsAccessChain.");
773 // Find the AccessChainClampingInfo associated with this access chain.
774 const protobufs::AccessChainClampingInfo* access_chain_clamping_info =
775 nullptr;
776 for (auto& clamping_info : message_.access_chain_clamping_info()) {
777 if (clamping_info.access_chain_id() == access_chain_inst->result_id()) {
778 access_chain_clamping_info = &clamping_info;
779 break;
782 if (!access_chain_clamping_info) {
783 // No access chain clamping information was found; the function cannot be
784 // made livesafe.
785 return false;
788 // Check that there is a (compare_id, select_id) pair for every
789 // index associated with the instruction.
790 if (static_cast<uint32_t>(
791 access_chain_clamping_info->compare_and_select_ids().size()) !=
792 access_chain_inst->NumInOperands() - 1) {
793 return false;
796 // Walk the access chain, clamping each index to be within bounds if it is
797 // not a constant.
798 auto base_object = ir_context->get_def_use_mgr()->GetDef(
799 access_chain_inst->GetSingleWordInOperand(0));
800 assert(base_object && "The base object must exist.");
801 auto pointer_type =
802 ir_context->get_def_use_mgr()->GetDef(base_object->type_id());
803 assert(pointer_type && pointer_type->opcode() == spv::Op::OpTypePointer &&
804 "The base object must have pointer type.");
805 auto should_be_composite_type = ir_context->get_def_use_mgr()->GetDef(
806 pointer_type->GetSingleWordInOperand(1));
808 // Consider each index input operand in turn (operand 0 is the base object).
809 for (uint32_t index = 1; index < access_chain_inst->NumInOperands();
810 index++) {
811 // We are going to turn:
813 // %result = OpAccessChain %type %object ... %index ...
815 // into:
817 // %t1 = OpULessThanEqual %bool %index %bound_minus_one
818 // %t2 = OpSelect %int_type %t1 %index %bound_minus_one
819 // %result = OpAccessChain %type %object ... %t2 ...
821 // ... unless %index is already a constant.
823 // Get the bound for the composite being indexed into; e.g. the number of
824 // columns of matrix or the size of an array.
825 uint32_t bound = fuzzerutil::GetBoundForCompositeIndex(
826 *should_be_composite_type, ir_context);
828 // Get the instruction associated with the index and figure out its integer
829 // type.
830 const uint32_t index_id = access_chain_inst->GetSingleWordInOperand(index);
831 auto index_inst = ir_context->get_def_use_mgr()->GetDef(index_id);
832 auto index_type_inst =
833 ir_context->get_def_use_mgr()->GetDef(index_inst->type_id());
834 assert(index_type_inst->opcode() == spv::Op::OpTypeInt);
835 assert(index_type_inst->GetSingleWordInOperand(0) == 32);
836 opt::analysis::Integer* index_int_type =
837 ir_context->get_type_mgr()
838 ->GetType(index_type_inst->result_id())
839 ->AsInteger();
841 if (index_inst->opcode() != spv::Op::OpConstant ||
842 index_inst->GetSingleWordInOperand(0) >= bound) {
843 // The index is either non-constant or an out-of-bounds constant, so we
844 // need to clamp it.
845 assert(should_be_composite_type->opcode() != spv::Op::OpTypeStruct &&
846 "Access chain indices into structures are required to be "
847 "constants.");
848 opt::analysis::IntConstant bound_minus_one(index_int_type, {bound - 1});
849 if (!ir_context->get_constant_mgr()->FindConstant(&bound_minus_one)) {
850 // We do not have an integer constant whose value is |bound| -1.
851 return false;
854 opt::analysis::Bool bool_type;
855 uint32_t bool_type_id = ir_context->get_type_mgr()->GetId(&bool_type);
856 if (!bool_type_id) {
857 // Bool type is not declared; we cannot do a comparison.
858 return false;
861 uint32_t bound_minus_one_id =
862 ir_context->get_constant_mgr()
863 ->GetDefiningInstruction(&bound_minus_one)
864 ->result_id();
866 uint32_t compare_id =
867 access_chain_clamping_info->compare_and_select_ids(index - 1).first();
868 uint32_t select_id =
869 access_chain_clamping_info->compare_and_select_ids(index - 1)
870 .second();
871 std::vector<std::unique_ptr<opt::Instruction>> new_instructions;
873 // Compare the index with the bound via an instruction of the form:
874 // %t1 = OpULessThanEqual %bool %index %bound_minus_one
875 new_instructions.push_back(MakeUnique<opt::Instruction>(
876 ir_context, spv::Op::OpULessThanEqual, bool_type_id, compare_id,
877 opt::Instruction::OperandList(
878 {{SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
879 {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));
881 // Select the index if in-bounds, otherwise one less than the bound:
882 // %t2 = OpSelect %int_type %t1 %index %bound_minus_one
883 new_instructions.push_back(MakeUnique<opt::Instruction>(
884 ir_context, spv::Op::OpSelect, index_type_inst->result_id(),
885 select_id,
886 opt::Instruction::OperandList(
887 {{SPV_OPERAND_TYPE_ID, {compare_id}},
888 {SPV_OPERAND_TYPE_ID, {index_inst->result_id()}},
889 {SPV_OPERAND_TYPE_ID, {bound_minus_one_id}}})));
891 // Add the new instructions before the access chain
892 access_chain_inst->InsertBefore(std::move(new_instructions));
894 // Replace %index with %t2.
895 access_chain_inst->SetInOperand(index, {select_id});
896 fuzzerutil::UpdateModuleIdBound(ir_context, compare_id);
897 fuzzerutil::UpdateModuleIdBound(ir_context, select_id);
899 should_be_composite_type =
900 FollowCompositeIndex(ir_context, *should_be_composite_type, index_id);
902 return true;
905 opt::Instruction* TransformationAddFunction::FollowCompositeIndex(
906 opt::IRContext* ir_context, const opt::Instruction& composite_type_inst,
907 uint32_t index_id) {
908 uint32_t sub_object_type_id;
909 switch (composite_type_inst.opcode()) {
910 case spv::Op::OpTypeArray:
911 case spv::Op::OpTypeRuntimeArray:
912 sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
913 break;
914 case spv::Op::OpTypeMatrix:
915 case spv::Op::OpTypeVector:
916 sub_object_type_id = composite_type_inst.GetSingleWordInOperand(0);
917 break;
918 case spv::Op::OpTypeStruct: {
919 auto index_inst = ir_context->get_def_use_mgr()->GetDef(index_id);
920 assert(index_inst->opcode() == spv::Op::OpConstant);
921 assert(ir_context->get_def_use_mgr()
922 ->GetDef(index_inst->type_id())
923 ->opcode() == spv::Op::OpTypeInt);
924 assert(ir_context->get_def_use_mgr()
925 ->GetDef(index_inst->type_id())
926 ->GetSingleWordInOperand(0) == 32);
927 uint32_t index_value = index_inst->GetSingleWordInOperand(0);
928 sub_object_type_id =
929 composite_type_inst.GetSingleWordInOperand(index_value);
930 break;
932 default:
933 assert(false && "Unknown composite type.");
934 sub_object_type_id = 0;
935 break;
937 assert(sub_object_type_id && "No sub-object found.");
938 return ir_context->get_def_use_mgr()->GetDef(sub_object_type_id);
941 std::unordered_set<uint32_t> TransformationAddFunction::GetFreshIds() const {
942 std::unordered_set<uint32_t> result;
943 for (auto& instruction : message_.instruction()) {
944 result.insert(instruction.result_id());
946 if (message_.is_livesafe()) {
947 result.insert(message_.loop_limiter_variable_id());
948 for (auto& loop_limiter_info : message_.loop_limiter_info()) {
949 result.insert(loop_limiter_info.load_id());
950 result.insert(loop_limiter_info.increment_id());
951 result.insert(loop_limiter_info.compare_id());
952 result.insert(loop_limiter_info.logical_op_id());
954 for (auto& access_chain_clamping_info :
955 message_.access_chain_clamping_info()) {
956 for (auto& pair : access_chain_clamping_info.compare_and_select_ids()) {
957 result.insert(pair.first());
958 result.insert(pair.second());
962 return result;
965 } // namespace fuzz
966 } // namespace spvtools