1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
8 #include "base/callback.h"
9 #include "base/memory/scoped_ptr.h"
10 #include "base/metrics/field_trial.h"
11 #include "base/run_loop.h"
12 #include "base/strings/string_number_conversions.h"
13 #include "base/time/time.h"
14 #include "chrome/browser/safe_browsing/client_side_model_loader.h"
15 #include "chrome/common/safe_browsing/client_model.pb.h"
16 #include "chrome/common/safe_browsing/csd.pb.h"
17 #include "components/variations/variations_associated_data.h"
18 #include "content/public/test/test_browser_thread_bundle.h"
19 #include "net/http/http_status_code.h"
20 #include "net/url_request/test_url_fetcher_factory.h"
21 #include "net/url_request/url_request_status.h"
22 #include "testing/gmock/include/gmock/gmock.h"
23 #include "testing/gtest/include/gtest/gtest.h"
26 using ::testing::Invoke
;
27 using ::testing::Mock
;
28 using ::testing::StrictMock
;
31 namespace safe_browsing
{
// Test double for ModelLoader. The constructor forwards both arguments to the
// ModelLoader constructor unchanged; ScheduleFetch and EndFetch are replaced
// by gmock mock methods so tests can set expectations on them (the tests in
// this file wrap it in StrictMock, so any unexpected call fails the test).
// NOTE(review): the access specifier and the closing "};" of this class are
// not visible in this view — confirm they are present in the full file.
34 class MockModelLoader
: public ModelLoader
{
// Passes the renderer-update callback and the model file name straight
// through to ModelLoader.
// NOTE(review): model_name is taken as a const value; a const reference
// would avoid a string copy.
36 explicit MockModelLoader(base::Closure update_renderers_callback
,
37 const std::string model_name
)
38 : ModelLoader(update_renderers_callback
, model_name
) {}
39 ~MockModelLoader() override
{}
// Mocked fetch scheduling; tests set expectations on the int64 argument.
41 MOCK_METHOD1(ScheduleFetch
, void(int64
));
// Mocked fetch completion; tests match on the ClientModelStatus argument and
// usually ignore the TimeDelta argument with `_`.
42 MOCK_METHOD2(EndFetch
, void(ClientModelStatus
, base::TimeDelta
));
45 DISALLOW_COPY_AND_ASSIGN(MockModelLoader
);
// Test fixture for ModelLoader. It owns:
//  - a FakeURLFetcherFactory on which SetModelFetchResponse() registers fake
//    responses for model_url_;
//  - a FieldTrialList used by SetFinchModelNumber() to configure the Finch
//    experiment that selects the model number.
// NOTE(review): the constructor declaration line, several closing braces and
// the model_url_ member declaration are not visible in this view.
50 class ModelLoaderTest
: public testing::Test
{
// Constructor initializer list: create the fake fetcher factory and the
// FieldTrialList (both with NULL arguments).
53 : factory_(new net::FakeURLFetcherFactory(NULL
)),
54 field_trials_(new base::FieldTrialList(NULL
)) {}
// Resets all variation IDs and variation params before each test.
56 void SetUp() override
{
57 variations::testing::ClearAllVariationIDs();
58 variations::testing::ClearAllVariationParams();
61 // Set up the finch experiment to control the model number
62 // used in the model URL. This clears all existing state.
// ASSERT_TRUE-fails if the field trial cannot be created or the params
// cannot be associated with it.
63 void SetFinchModelNumber(int model_number
) {
64 // Destroy the existing FieldTrialList before creating a new one to avoid
// NOTE(review): the end of the comment above is missing from this view;
// presumably only one FieldTrialList may be alive at a time — confirm.
66 field_trials_
.reset();
67 field_trials_
.reset(new base::FieldTrialList(NULL
));
68 variations::testing::ClearAllVariationIDs();
69 variations::testing::ClearAllVariationParams();
71 const std::string group_name
= "ModelFoo"; // Not used in CSD code.
72 ASSERT_TRUE(base::FieldTrialList::CreateFieldTrial(
73 ModelLoader::kClientModelFinchExperiment
, group_name
));
// Map the model-number Finch param to the requested number (as a string).
75 std::map
<std::string
, std::string
> params
;
76 params
[ModelLoader::kClientModelFinchParam
] =
77 base::IntToString(model_number
);
79 ASSERT_TRUE(variations::AssociateVariationParams(
80 ModelLoader::kClientModelFinchExperiment
, group_name
, params
));
83 // Set the URL for future SetModelFetchResponse() calls.
// Copies loader.url_ into the fixture's model_url_.
84 void SetModelUrl(const ModelLoader
& loader
) { model_url_
= loader
.url_
; }
// Registers a fake response (body, HTTP status code, request status) for
// model_url_ on the fake fetcher factory. SetModelUrl() must have been called
// first: this CHECK-fails if model_url_ is not a valid URL.
86 void SetModelFetchResponse(std::string response_data
,
87 net::HttpStatusCode response_code
,
88 net::URLRequestStatus::Status status
) {
89 CHECK(model_url_
.is_valid());
90 factory_
->SetFakeResponse(model_url_
, response_data
, response_code
, status
);
// Presumably provides the browser threads / message loop that RunLoop and the
// URL fetches in these tests rely on — confirm.
94 content::TestBrowserThreadBundle thread_bundle_
;
// Fake URL fetcher factory; responses are registered per URL above.
95 scoped_ptr
<net::FakeURLFetcherFactory
> factory_
;
// Recreated by SetFinchModelNumber() for each new Finch configuration.
96 scoped_ptr
<base::FieldTrialList
> field_trials_
;
// gmock action with one parameter, `closure`. It is used in this file as
// .WillOnce(InvokeClosure(loop.QuitClosure())).
// NOTE(review): the action body is not visible in this view; presumably it
// runs `closure` so that the waiting RunLoop quits — confirm.
100 ACTION_P(InvokeClosure
, closure
) {
104 // Test the response to many variations of model responses.
// Each case follows the same steps:
//  1. register a fake fetch response;
//  2. expect EndFetch with the matching ClientModelStatus (the TimeDelta
//     argument is ignored);
//  3. verify and clear the mock expectations.
// NOTE(review): the RunLoop `loop` declarations and the code that starts and
// waits for each fetch are not visible in this view.
105 TEST_F(ModelLoaderTest
, FetchModelTest
) {
106 StrictMock
<MockModelLoader
> loader(base::Closure(), "top_model.pb");
109 // The model fetch failed.
112 SetModelFetchResponse("blamodel", net::HTTP_INTERNAL_SERVER_ERROR
,
113 net::URLRequestStatus::FAILED
);
114 EXPECT_CALL(loader
, EndFetch(ModelLoader::MODEL_FETCH_FAILED
, _
))
115 .WillOnce(InvokeClosure(loop
.QuitClosure()));
118 Mock::VerifyAndClearExpectations(&loader
);
// The fetch succeeds (HTTP 200) but the body is empty.
124 SetModelFetchResponse(std::string(), net::HTTP_OK
,
125 net::URLRequestStatus::SUCCESS
);
126 EXPECT_CALL(loader
, EndFetch(ModelLoader::MODEL_EMPTY
, _
))
127 .WillOnce(InvokeClosure(loop
.QuitClosure()));
130 Mock::VerifyAndClearExpectations(&loader
);
133 // Model is too large.
// The body is one byte over kMaxModelSizeBytes.
136 SetModelFetchResponse(std::string(ModelLoader::kMaxModelSizeBytes
+ 1, 'x'),
137 net::HTTP_OK
, net::URLRequestStatus::SUCCESS
);
138 EXPECT_CALL(loader
, EndFetch(ModelLoader::MODEL_TOO_LARGE
, _
))
139 .WillOnce(InvokeClosure(loop
.QuitClosure()));
142 Mock::VerifyAndClearExpectations(&loader
);
145 // Unable to parse the model file.
148 SetModelFetchResponse("Invalid model file", net::HTTP_OK
,
149 net::URLRequestStatus::SUCCESS
);
150 EXPECT_CALL(loader
, EndFetch(ModelLoader::MODEL_PARSE_ERROR
, _
))
151 .WillOnce(InvokeClosure(loop
.QuitClosure()));
154 Mock::VerifyAndClearExpectations(&loader
);
157 // Model that is missing some required fields (missing the version field).
// SerializePartialAsString() is used because the message deliberately lacks
// a required field.
158 ClientSideModel model
;
159 model
.set_max_words_per_term(4);
162 SetModelFetchResponse(model
.SerializePartialAsString(), net::HTTP_OK
,
163 net::URLRequestStatus::SUCCESS
);
164 EXPECT_CALL(loader
, EndFetch(ModelLoader::MODEL_MISSING_FIELDS
, _
))
165 .WillOnce(InvokeClosure(loop
.QuitClosure()));
168 Mock::VerifyAndClearExpectations(&loader
);
171 // Model that points to hashes that don't exist.
// There is a single hash, so the only valid index is 0.
172 model
.set_version(10);
173 model
.add_hashes("bla");
174 model
.add_page_term(1); // Should be 0 instead of 1.
177 SetModelFetchResponse(model
.SerializePartialAsString(), net::HTTP_OK
,
178 net::URLRequestStatus::SUCCESS
);
179 EXPECT_CALL(loader
, EndFetch(ModelLoader::MODEL_BAD_HASH_IDS
, _
))
180 .WillOnce(InvokeClosure(loop
.QuitClosure()));
183 Mock::VerifyAndClearExpectations(&loader
);
// Repair the bad page-term index so that later cases exercise other checks.
185 model
.set_page_term(0, 0);
187 // Model version number is wrong.
188 model
.set_version(-1);
191 SetModelFetchResponse(model
.SerializeAsString(), net::HTTP_OK
,
192 net::URLRequestStatus::SUCCESS
);
193 EXPECT_CALL(loader
, EndFetch(ModelLoader::MODEL_INVALID_VERSION_NUMBER
, _
))
194 .WillOnce(InvokeClosure(loop
.QuitClosure()));
197 Mock::VerifyAndClearExpectations(&loader
);
// Valid model: version restored to 10 with consistent hash ids.
201 model
.set_version(10);
204 SetModelFetchResponse(model
.SerializeAsString(), net::HTTP_OK
,
205 net::URLRequestStatus::SUCCESS
);
206 EXPECT_CALL(loader
, EndFetch(ModelLoader::MODEL_SUCCESS
, _
))
207 .WillOnce(InvokeClosure(loop
.QuitClosure()));
210 Mock::VerifyAndClearExpectations(&loader
);
213 // Model version number is decreasing. Set the model version number of the
214 // model that is currently loaded in the loader object to 11.
// The fetched model still has version 10, which is lower than 11.
215 loader
.model_
.reset(new ClientSideModel(model
));
216 loader
.model_
->set_version(11);
219 SetModelFetchResponse(model
.SerializeAsString(), net::HTTP_OK
,
220 net::URLRequestStatus::SUCCESS
);
221 EXPECT_CALL(loader
, EndFetch(ModelLoader::MODEL_INVALID_VERSION_NUMBER
, _
))
222 .WillOnce(InvokeClosure(loop
.QuitClosure()));
225 Mock::VerifyAndClearExpectations(&loader
);
228 // Model version hasn't changed since the last reload.
// Both the loaded model and the fetched model now have version 10.
229 loader
.model_
->set_version(10);
232 SetModelFetchResponse(model
.SerializeAsString(), net::HTTP_OK
,
233 net::URLRequestStatus::SUCCESS
);
234 EXPECT_CALL(loader
, EndFetch(ModelLoader::MODEL_NOT_CHANGED
, _
))
235 .WillOnce(InvokeClosure(loop
.QuitClosure()));
238 Mock::VerifyAndClearExpectations(&loader
);
242 // Test that a successful response will update the renderers.
243 TEST_F(ModelLoaderTest
, UpdateRenderersTest
) {
244 // Use runloop for convenient callback detection.
// The loop's QuitClosure is passed as the update-renderers callback, so
// presumably the loop exits only once that callback runs.
// NOTE(review): the declaration of `loop` and its Run() call are not visible
// in this view.
246 StrictMock
<MockModelLoader
> loader(loop
.QuitClosure(), "top_model.pb");
// A successful EndFetch must also schedule the next fetch (any interval).
247 EXPECT_CALL(loader
, ScheduleFetch(_
));
// Calls the real ModelLoader::EndFetch rather than the mocked override.
248 loader
.ModelLoader::EndFetch(ModelLoader::MODEL_SUCCESS
, base::TimeDelta());
250 Mock::VerifyAndClearExpectations(&loader
);
253 // Test that one fetch schedules another fetch.
// Each case calls the real ModelLoader::EndFetch (bypassing the mock) and
// checks the interval passed to the mocked ScheduleFetch.
254 TEST_F(ModelLoaderTest
, RescheduleFetchTest
) {
255 StrictMock
<MockModelLoader
> loader(base::Closure(), "top_model.pb");
257 // Zero max_age. Uses default.
258 base::TimeDelta max_age
;
259 EXPECT_CALL(loader
, ScheduleFetch(ModelLoader::kClientModelFetchIntervalMs
));
260 loader
.ModelLoader::EndFetch(ModelLoader::MODEL_NOT_CHANGED
, max_age
);
261 Mock::VerifyAndClearExpectations(&loader
);
263 // Non-zero max_age from header.
// Expects the max_age from the header plus one minute.
// NOTE(review): the EXPECT_CALL below is not terminated in this view (the
// following source line is missing); presumably it converts the sum to
// milliseconds to match ScheduleFetch's int64 parameter — confirm.
264 max_age
= base::TimeDelta::FromMinutes(42);
265 EXPECT_CALL(loader
, ScheduleFetch((max_age
+ base::TimeDelta::FromMinutes(1))
267 loader
.ModelLoader::EndFetch(ModelLoader::MODEL_NOT_CHANGED
, max_age
);
268 Mock::VerifyAndClearExpectations(&loader
);
270 // Non-zero max_age, but failed load should use default interval.
271 max_age
= base::TimeDelta::FromMinutes(42);
272 EXPECT_CALL(loader
, ScheduleFetch(ModelLoader::kClientModelFetchIntervalMs
));
273 loader
.ModelLoader::EndFetch(ModelLoader::MODEL_FETCH_FAILED
, max_age
);
274 Mock::VerifyAndClearExpectations(&loader
);
277 // Test that Finch params control the model names.
278 TEST_F(ModelLoaderTest
, ModelNamesTest
) {
279 // Test the name-templating.
// As the assertions show, the first argument selects the "_ext" variant and
// the second fills in the variation number.
280 EXPECT_EQ(ModelLoader::FillInModelName(true, 3),
281 "client_model_v5_ext_variation_3.pb");
282 EXPECT_EQ(ModelLoader::FillInModelName(false, 5),
283 "client_model_v5_variation_5.pb");
285 // No Finch setup. Should default to 0.
// NOTE(review): the meaning of the NULL second constructor argument is not
// visible here (presumably a request context) — confirm.
286 scoped_ptr
<ModelLoader
> loader
;
287 loader
.reset(new ModelLoader(base::Closure(), NULL
,
288 false /* is_extended_reporting */));
289 EXPECT_EQ(loader
->name(), "client_model_v5_variation_0.pb");
290 EXPECT_EQ(loader
->url_
.spec(),
291 "https://ssl.gstatic.com/safebrowsing/csd/"
292 "client_model_v5_variation_0.pb");
294 // Model 1, no extended reporting.
295 SetFinchModelNumber(1);
296 loader
.reset(new ModelLoader(base::Closure(), NULL
, false));
297 EXPECT_EQ(loader
->name(), "client_model_v5_variation_1.pb");
299 // Model 2, with extended reporting.
300 SetFinchModelNumber(2);
301 loader
.reset(new ModelLoader(base::Closure(), NULL
, true));
302 EXPECT_EQ(loader
->name(), "client_model_v5_ext_variation_2.pb");
// Checks that ModelHasValidHashIds() accepts page terms and rule features only
// when they are in-range indices into the model's hashes:
//  - an empty model, and index 0 with one hash, are valid;
//  - -1, and an index equal to the number of hashes, are rejected.
305 TEST_F(ModelLoaderTest
, ModelHasValidHashIds
) {
306 ClientSideModel model
;
307 EXPECT_TRUE(ModelLoader::ModelHasValidHashIds(model
));
308 model
.add_hashes("bla");
309 EXPECT_TRUE(ModelLoader::ModelHasValidHashIds(model
));
310 model
.add_page_term(0);
311 EXPECT_TRUE(ModelLoader::ModelHasValidHashIds(model
));
// Page terms: with a single hash, only index 0 is valid.
313 model
.add_page_term(-1);
314 EXPECT_FALSE(ModelLoader::ModelHasValidHashIds(model
));
315 model
.set_page_term(1, 1);
316 EXPECT_FALSE(ModelLoader::ModelHasValidHashIds(model
));
317 model
.set_page_term(1, 0);
318 EXPECT_TRUE(ModelLoader::ModelHasValidHashIds(model
));
// Rule features: with two hashes, valid indices are 0 and 1.
321 model
.add_hashes("blu");
322 ClientSideModel::Rule
* rule
= model
.add_rule();
323 rule
->add_feature(0);
324 rule
->add_feature(1);
325 rule
->set_weight(0.1f
);
326 EXPECT_TRUE(ModelLoader::ModelHasValidHashIds(model
));
// A second rule with a negative feature index makes the model invalid.
328 rule
= model
.add_rule();
329 rule
->add_feature(0);
330 rule
->add_feature(1);
331 rule
->add_feature(-1);
332 rule
->set_weight(0.2f
);
333 EXPECT_FALSE(ModelLoader::ModelHasValidHashIds(model
));
// Index 2 equals the number of hashes, so it is still out of range.
335 rule
->set_feature(2, 2);
336 EXPECT_FALSE(ModelLoader::ModelHasValidHashIds(model
));
338 rule
->set_feature(2, 1);
339 EXPECT_TRUE(ModelLoader::ModelHasValidHashIds(model
));
342 } // namespace safe_browsing