1 // Copyright (c) 2019 Google LLC
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
15 #include "source/fuzz/shrinker.h"
19 #include "source/fuzz/added_function_reducer.h"
20 #include "source/fuzz/pseudo_random_generator.h"
21 #include "source/fuzz/replayer.h"
22 #include "source/opt/build_module.h"
23 #include "source/opt/ir_context.h"
24 #include "source/spirv_fuzzer_options.h"
25 #include "source/util/make_unique.h"
32 // A helper to get the size of a protobuf transformation sequence in a less
34 uint32_t NumRemainingTransformations(
35 const protobufs::TransformationSequence
& transformation_sequence
) {
36 return static_cast<uint32_t>(transformation_sequence
.transformation_size());
39 // A helper to return a transformation sequence identical to |transformations|,
40 // except that a chunk of size |chunk_size| starting from |chunk_index| x
41 // |chunk_size| is removed (or as many transformations as available if the whole
43 protobufs::TransformationSequence
RemoveChunk(
44 const protobufs::TransformationSequence
& transformations
,
45 uint32_t chunk_index
, uint32_t chunk_size
) {
46 uint32_t lower
= chunk_index
* chunk_size
;
47 uint32_t upper
= std::min((chunk_index
+ 1) * chunk_size
,
48 NumRemainingTransformations(transformations
));
49 assert(lower
< upper
);
50 assert(upper
<= NumRemainingTransformations(transformations
));
51 protobufs::TransformationSequence result
;
52 for (uint32_t j
= 0; j
< NumRemainingTransformations(transformations
); j
++) {
53 if (j
>= lower
&& j
< upper
) {
56 protobufs::Transformation transformation
=
57 transformations
.transformation()[j
];
58 *result
.mutable_transformation()->Add() = transformation
;
66 spv_target_env target_env
, MessageConsumer consumer
,
67 const std::vector
<uint32_t>& binary_in
,
68 const protobufs::FactSequence
& initial_facts
,
69 const protobufs::TransformationSequence
& transformation_sequence_in
,
70 const InterestingnessFunction
& interestingness_function
,
71 uint32_t step_limit
, bool validate_during_replay
,
72 spv_validator_options validator_options
)
73 : target_env_(target_env
),
74 consumer_(std::move(consumer
)),
75 binary_in_(binary_in
),
76 initial_facts_(initial_facts
),
77 transformation_sequence_in_(transformation_sequence_in
),
78 interestingness_function_(interestingness_function
),
79 step_limit_(step_limit
),
80 validate_during_replay_(validate_during_replay
),
81 validator_options_(validator_options
) {}
83 Shrinker::~Shrinker() = default;
85 Shrinker::ShrinkerResult
Shrinker::Run() {
86 // Check compatibility between the library version being linked with and the
87 // header files being used.
88 GOOGLE_PROTOBUF_VERIFY_VERSION
;
90 SpirvTools
tools(target_env_
);
91 if (!tools
.IsValid()) {
92 consumer_(SPV_MSG_ERROR
, nullptr, {},
93 "Failed to create SPIRV-Tools interface; stopping.");
94 return {Shrinker::ShrinkerResultStatus::kFailedToCreateSpirvToolsInterface
,
95 std::vector
<uint32_t>(), protobufs::TransformationSequence()};
98 // Initial binary should be valid.
99 if (!tools
.Validate(&binary_in_
[0], binary_in_
.size(), validator_options_
)) {
100 consumer_(SPV_MSG_INFO
, nullptr, {},
101 "Initial binary is invalid; stopping.");
102 return {Shrinker::ShrinkerResultStatus::kInitialBinaryInvalid
,
103 std::vector
<uint32_t>(), protobufs::TransformationSequence()};
106 // Run a replay of the initial transformation sequence to check that it
108 auto initial_replay_result
=
109 Replayer(target_env_
, consumer_
, binary_in_
, initial_facts_
,
110 transformation_sequence_in_
,
111 static_cast<uint32_t>(
112 transformation_sequence_in_
.transformation_size()),
113 validate_during_replay_
, validator_options_
)
115 if (initial_replay_result
.status
!=
116 Replayer::ReplayerResultStatus::kComplete
) {
117 return {ShrinkerResultStatus::kReplayFailed
, std::vector
<uint32_t>(),
118 protobufs::TransformationSequence()};
120 // Get the binary that results from running these transformations, and the
121 // subsequence of the initial transformations that actually apply (in
122 // principle this could be a strict subsequence).
123 std::vector
<uint32_t> current_best_binary
;
124 initial_replay_result
.transformed_module
->module()->ToBinary(
125 ¤t_best_binary
, false);
126 protobufs::TransformationSequence current_best_transformations
=
127 std::move(initial_replay_result
.applied_transformations
);
129 // Check that the binary produced by applying the initial transformations is
130 // indeed interesting.
131 if (!interestingness_function_(current_best_binary
, 0)) {
132 consumer_(SPV_MSG_INFO
, nullptr, {},
133 "Initial binary is not interesting; stopping.");
134 return {ShrinkerResultStatus::kInitialBinaryNotInteresting
,
135 std::vector
<uint32_t>(), protobufs::TransformationSequence()};
138 uint32_t attempt
= 0; // Keeps track of the number of shrink attempts that
139 // have been tried, whether successful or not.
141 uint32_t chunk_size
=
142 std::max(1u, NumRemainingTransformations(current_best_transformations
) /
143 2); // The number of contiguous transformations that the
144 // shrinker will try to remove in one go; starts
145 // high and decreases during the shrinking process.
147 // Keep shrinking until we:
148 // - reach the step limit,
149 // - run out of transformations to remove, or
150 // - cannot make the chunk size any smaller.
151 while (attempt
< step_limit_
&&
152 !current_best_transformations
.transformation().empty() &&
154 bool progress_this_round
=
155 false; // Used to decide whether to make the chunk size with which we
156 // remove transformations smaller. If we managed to remove at
157 // least one chunk of transformations at a particular chunk
158 // size, we set this flag so that we do not yet decrease the
162 NumRemainingTransformations(current_best_transformations
) &&
163 "Chunk size should never exceed the number of transformations that "
166 // The number of chunks is the ceiling of (#remaining_transformations /
168 const uint32_t num_chunks
=
169 (NumRemainingTransformations(current_best_transformations
) +
172 assert(num_chunks
>= 1 && "There should be at least one chunk.");
173 assert(num_chunks
* chunk_size
>=
174 NumRemainingTransformations(current_best_transformations
) &&
175 "All transformations should be in some chunk.");
177 // We go through the transformations in reverse, in chunks of size
178 // |chunk_size|, using |chunk_index| to track which chunk to try removing
179 // next. The loop exits early if we reach the shrinking step limit.
180 for (int chunk_index
= num_chunks
- 1;
181 attempt
< step_limit_
&& chunk_index
>= 0; chunk_index
--) {
182 // Remove a chunk of transformations according to the current index and
184 auto transformations_with_chunk_removed
=
185 RemoveChunk(current_best_transformations
,
186 static_cast<uint32_t>(chunk_index
), chunk_size
);
188 // Replay the smaller sequence of transformations to get a next binary and
189 // transformation sequence. Note that the transformations arising from
190 // replay might be even smaller than the transformations with the chunk
191 // removed, because removing those transformations might make further
192 // transformations inapplicable.
195 target_env_
, consumer_
, binary_in_
, initial_facts_
,
196 transformations_with_chunk_removed
,
197 static_cast<uint32_t>(
198 transformations_with_chunk_removed
.transformation_size()),
199 validate_during_replay_
, validator_options_
)
201 if (replay_result
.status
!= Replayer::ReplayerResultStatus::kComplete
) {
202 // Replay should not fail; if it does, we need to abort shrinking.
203 return {ShrinkerResultStatus::kReplayFailed
, std::vector
<uint32_t>(),
204 protobufs::TransformationSequence()};
208 NumRemainingTransformations(replay_result
.applied_transformations
) >=
209 chunk_index
* chunk_size
&&
210 "Removing this chunk of transformations should not have an effect "
211 "on earlier chunks.");
213 std::vector
<uint32_t> transformed_binary
;
214 replay_result
.transformed_module
->module()->ToBinary(&transformed_binary
,
216 if (interestingness_function_(transformed_binary
, attempt
)) {
217 // If the binary arising from the smaller transformation sequence is
218 // interesting, this becomes our current best binary and transformation
220 current_best_binary
= std::move(transformed_binary
);
221 current_best_transformations
=
222 std::move(replay_result
.applied_transformations
);
223 progress_this_round
= true;
225 // Either way, this was a shrink attempt, so increment our count of shrink
229 if (!progress_this_round
) {
230 // If we didn't manage to remove any chunks at this chunk size, try a
231 // smaller chunk size.
234 // Decrease the chunk size until it becomes no larger than the number of
235 // remaining transformations.
237 NumRemainingTransformations(current_best_transformations
)) {
242 // We now use spirv-reduce to minimise the functions associated with any
243 // AddFunction transformations that remain.
245 // Consider every remaining transformation.
246 for (uint32_t transformation_index
= 0;
247 attempt
< step_limit_
&&
248 transformation_index
<
249 static_cast<uint32_t>(
250 current_best_transformations
.transformation_size());
251 transformation_index
++) {
252 // Skip all transformations apart from TransformationAddFunction.
253 if (!current_best_transformations
.transformation(transformation_index
)
254 .has_add_function()) {
257 // Invoke spirv-reduce on the function encoded in this AddFunction
258 // transformation. The details of this are rather involved, and so are
259 // encapsulated in a separate class.
260 auto added_function_reducer_result
=
261 AddedFunctionReducer(target_env_
, consumer_
, binary_in_
, initial_facts_
,
262 current_best_transformations
, transformation_index
,
263 interestingness_function_
, validate_during_replay_
,
264 validator_options_
, step_limit_
, attempt
)
266 // Reducing the added function should succeed. If it doesn't, we report
267 // a shrinking error.
268 if (added_function_reducer_result
.status
!=
269 AddedFunctionReducer::AddedFunctionReducerResultStatus::kComplete
) {
270 return {ShrinkerResultStatus::kAddedFunctionReductionFailed
,
271 std::vector
<uint32_t>(), protobufs::TransformationSequence()};
273 assert(current_best_transformations
.transformation_size() ==
274 added_function_reducer_result
.applied_transformations
275 .transformation_size() &&
276 "The number of transformations should not have changed.");
277 current_best_binary
=
278 std::move(added_function_reducer_result
.transformed_binary
);
279 current_best_transformations
=
280 std::move(added_function_reducer_result
.applied_transformations
);
281 // The added function reducer reports how many reduction attempts
282 // spirv-reduce took when reducing the function. We regard each of these
283 // as a shrinker attempt.
284 attempt
+= added_function_reducer_result
.num_reduction_attempts
;
287 // Indicate whether shrinking completed or was truncated due to reaching the
290 // Either way, the output from the shrinker is the best binary we saw, and the
291 // transformations that led to it.
292 assert(attempt
<= step_limit_
);
293 if (attempt
== step_limit_
) {
294 std::stringstream strstream
;
295 strstream
<< "Shrinking did not complete; step limit " << step_limit_
297 consumer_(SPV_MSG_WARNING
, nullptr, {}, strstream
.str().c_str());
298 return {Shrinker::ShrinkerResultStatus::kStepLimitReached
,
299 std::move(current_best_binary
),
300 std::move(current_best_transformations
)};
302 return {Shrinker::ShrinkerResultStatus::kComplete
,
303 std::move(current_best_binary
),
304 std::move(current_best_transformations
)};
307 uint32_t Shrinker::GetIdBound(const std::vector
<uint32_t>& binary
) const {
308 // Build the module from the input binary.
309 std::unique_ptr
<opt::IRContext
> ir_context
=
310 BuildModule(target_env_
, consumer_
, binary
.data(), binary
.size());
311 assert(ir_context
&& "Error building module.");
312 return ir_context
->module()->id_bound();
316 } // namespace spvtools