Update broken references to image assets
[chromium-blink-merge.git] / chrome / browser / safe_browsing / client_side_model_loader_unittest.cc
blob4679b82f317e2d10cd1f14c7fee3077909301f84
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include <map>
6 #include <string>
8 #include "base/callback.h"
9 #include "base/memory/scoped_ptr.h"
10 #include "base/metrics/field_trial.h"
11 #include "base/run_loop.h"
12 #include "base/strings/string_number_conversions.h"
13 #include "base/time/time.h"
14 #include "chrome/browser/safe_browsing/client_side_model_loader.h"
15 #include "chrome/common/safe_browsing/client_model.pb.h"
16 #include "chrome/common/safe_browsing/csd.pb.h"
17 #include "components/variations/variations_associated_data.h"
18 #include "content/public/test/test_browser_thread_bundle.h"
19 #include "net/http/http_status_code.h"
20 #include "net/url_request/test_url_fetcher_factory.h"
21 #include "net/url_request/url_request_status.h"
22 #include "testing/gmock/include/gmock/gmock.h"
23 #include "testing/gtest/include/gtest/gtest.h"
24 #include "url/gurl.h"
26 using ::testing::Invoke;
27 using ::testing::Mock;
28 using ::testing::StrictMock;
29 using ::testing::_;
31 namespace safe_browsing {
32 namespace {
34 class MockModelLoader : public ModelLoader {
35 public:
36 explicit MockModelLoader(base::Closure update_renderers_callback,
37 const std::string model_name)
38 : ModelLoader(update_renderers_callback, model_name) {}
39 ~MockModelLoader() override {}
41 MOCK_METHOD1(ScheduleFetch, void(int64));
42 MOCK_METHOD2(EndFetch, void(ClientModelStatus, base::TimeDelta));
44 private:
45 DISALLOW_COPY_AND_ASSIGN(MockModelLoader);
48 } // namespace
50 class ModelLoaderTest : public testing::Test {
51 protected:
52 ModelLoaderTest()
53 : factory_(new net::FakeURLFetcherFactory(NULL)),
54 field_trials_(new base::FieldTrialList(NULL)) {}
56 void SetUp() override {
57 variations::testing::ClearAllVariationIDs();
58 variations::testing::ClearAllVariationParams();
61 // Set up the finch experiment to control the model number
62 // used in the model URL. This clears all existing state.
63 void SetFinchModelNumber(int model_number) {
64 // Destroy the existing FieldTrialList before creating a new one to avoid
65 // a DCHECK.
66 field_trials_.reset();
67 field_trials_.reset(new base::FieldTrialList(NULL));
68 variations::testing::ClearAllVariationIDs();
69 variations::testing::ClearAllVariationParams();
71 const std::string group_name = "ModelFoo"; // Not used in CSD code.
72 ASSERT_TRUE(base::FieldTrialList::CreateFieldTrial(
73 ModelLoader::kClientModelFinchExperiment, group_name));
75 std::map<std::string, std::string> params;
76 params[ModelLoader::kClientModelFinchParam] =
77 base::IntToString(model_number);
79 ASSERT_TRUE(variations::AssociateVariationParams(
80 ModelLoader::kClientModelFinchExperiment, group_name, params));
83 // Set the URL for future SetModelFetchResponse() calls.
84 void SetModelUrl(const ModelLoader& loader) { model_url_ = loader.url_; }
86 void SetModelFetchResponse(std::string response_data,
87 net::HttpStatusCode response_code,
88 net::URLRequestStatus::Status status) {
89 CHECK(model_url_.is_valid());
90 factory_->SetFakeResponse(model_url_, response_data, response_code, status);
93 private:
94 content::TestBrowserThreadBundle thread_bundle_;
95 scoped_ptr<net::FakeURLFetcherFactory> factory_;
96 scoped_ptr<base::FieldTrialList> field_trials_;
97 GURL model_url_;
100 ACTION_P(InvokeClosure, closure) {
101 closure.Run();
104 // Test the reponse to many variations of model responses.
105 TEST_F(ModelLoaderTest, FetchModelTest) {
106 StrictMock<MockModelLoader> loader(base::Closure(), "top_model.pb");
107 SetModelUrl(loader);
109 // The model fetch failed.
111 base::RunLoop loop;
112 SetModelFetchResponse("blamodel", net::HTTP_INTERNAL_SERVER_ERROR,
113 net::URLRequestStatus::FAILED);
114 EXPECT_CALL(loader, EndFetch(ModelLoader::MODEL_FETCH_FAILED, _))
115 .WillOnce(InvokeClosure(loop.QuitClosure()));
116 loader.StartFetch();
117 loop.Run();
118 Mock::VerifyAndClearExpectations(&loader);
121 // Empty model file.
123 base::RunLoop loop;
124 SetModelFetchResponse(std::string(), net::HTTP_OK,
125 net::URLRequestStatus::SUCCESS);
126 EXPECT_CALL(loader, EndFetch(ModelLoader::MODEL_EMPTY, _))
127 .WillOnce(InvokeClosure(loop.QuitClosure()));
128 loader.StartFetch();
129 loop.Run();
130 Mock::VerifyAndClearExpectations(&loader);
133 // Model is too large.
135 base::RunLoop loop;
136 SetModelFetchResponse(std::string(ModelLoader::kMaxModelSizeBytes + 1, 'x'),
137 net::HTTP_OK, net::URLRequestStatus::SUCCESS);
138 EXPECT_CALL(loader, EndFetch(ModelLoader::MODEL_TOO_LARGE, _))
139 .WillOnce(InvokeClosure(loop.QuitClosure()));
140 loader.StartFetch();
141 loop.Run();
142 Mock::VerifyAndClearExpectations(&loader);
145 // Unable to parse the model file.
147 base::RunLoop loop;
148 SetModelFetchResponse("Invalid model file", net::HTTP_OK,
149 net::URLRequestStatus::SUCCESS);
150 EXPECT_CALL(loader, EndFetch(ModelLoader::MODEL_PARSE_ERROR, _))
151 .WillOnce(InvokeClosure(loop.QuitClosure()));
152 loader.StartFetch();
153 loop.Run();
154 Mock::VerifyAndClearExpectations(&loader);
157 // Model that is missing some required fields (missing the version field).
158 ClientSideModel model;
159 model.set_max_words_per_term(4);
161 base::RunLoop loop;
162 SetModelFetchResponse(model.SerializePartialAsString(), net::HTTP_OK,
163 net::URLRequestStatus::SUCCESS);
164 EXPECT_CALL(loader, EndFetch(ModelLoader::MODEL_MISSING_FIELDS, _))
165 .WillOnce(InvokeClosure(loop.QuitClosure()));
166 loader.StartFetch();
167 loop.Run();
168 Mock::VerifyAndClearExpectations(&loader);
171 // Model that points to hashes that don't exist.
172 model.set_version(10);
173 model.add_hashes("bla");
174 model.add_page_term(1); // Should be 0 instead of 1.
176 base::RunLoop loop;
177 SetModelFetchResponse(model.SerializePartialAsString(), net::HTTP_OK,
178 net::URLRequestStatus::SUCCESS);
179 EXPECT_CALL(loader, EndFetch(ModelLoader::MODEL_BAD_HASH_IDS, _))
180 .WillOnce(InvokeClosure(loop.QuitClosure()));
181 loader.StartFetch();
182 loop.Run();
183 Mock::VerifyAndClearExpectations(&loader);
185 model.set_page_term(0, 0);
187 // Model version number is wrong.
188 model.set_version(-1);
190 base::RunLoop loop;
191 SetModelFetchResponse(model.SerializeAsString(), net::HTTP_OK,
192 net::URLRequestStatus::SUCCESS);
193 EXPECT_CALL(loader, EndFetch(ModelLoader::MODEL_INVALID_VERSION_NUMBER, _))
194 .WillOnce(InvokeClosure(loop.QuitClosure()));
195 loader.StartFetch();
196 loop.Run();
197 Mock::VerifyAndClearExpectations(&loader);
200 // Normal model.
201 model.set_version(10);
203 base::RunLoop loop;
204 SetModelFetchResponse(model.SerializeAsString(), net::HTTP_OK,
205 net::URLRequestStatus::SUCCESS);
206 EXPECT_CALL(loader, EndFetch(ModelLoader::MODEL_SUCCESS, _))
207 .WillOnce(InvokeClosure(loop.QuitClosure()));
208 loader.StartFetch();
209 loop.Run();
210 Mock::VerifyAndClearExpectations(&loader);
213 // Model version number is decreasing. Set the model version number of the
214 // model that is currently loaded in the loader object to 11.
215 loader.model_.reset(new ClientSideModel(model));
216 loader.model_->set_version(11);
218 base::RunLoop loop;
219 SetModelFetchResponse(model.SerializeAsString(), net::HTTP_OK,
220 net::URLRequestStatus::SUCCESS);
221 EXPECT_CALL(loader, EndFetch(ModelLoader::MODEL_INVALID_VERSION_NUMBER, _))
222 .WillOnce(InvokeClosure(loop.QuitClosure()));
223 loader.StartFetch();
224 loop.Run();
225 Mock::VerifyAndClearExpectations(&loader);
228 // Model version hasn't changed since the last reload.
229 loader.model_->set_version(10);
231 base::RunLoop loop;
232 SetModelFetchResponse(model.SerializeAsString(), net::HTTP_OK,
233 net::URLRequestStatus::SUCCESS);
234 EXPECT_CALL(loader, EndFetch(ModelLoader::MODEL_NOT_CHANGED, _))
235 .WillOnce(InvokeClosure(loop.QuitClosure()));
236 loader.StartFetch();
237 loop.Run();
238 Mock::VerifyAndClearExpectations(&loader);
242 // Test that a successful reponse will update the renderers
243 TEST_F(ModelLoaderTest, UpdateRenderersTest) {
244 // Use runloop for convenient callback detection.
245 base::RunLoop loop;
246 StrictMock<MockModelLoader> loader(loop.QuitClosure(), "top_model.pb");
247 EXPECT_CALL(loader, ScheduleFetch(_));
248 loader.ModelLoader::EndFetch(ModelLoader::MODEL_SUCCESS, base::TimeDelta());
249 loop.Run();
250 Mock::VerifyAndClearExpectations(&loader);
253 // Test that a one fetch schedules another fetch.
254 TEST_F(ModelLoaderTest, RescheduleFetchTest) {
255 StrictMock<MockModelLoader> loader(base::Closure(), "top_model.pb");
257 // Zero max_age. Uses default.
258 base::TimeDelta max_age;
259 EXPECT_CALL(loader, ScheduleFetch(ModelLoader::kClientModelFetchIntervalMs));
260 loader.ModelLoader::EndFetch(ModelLoader::MODEL_NOT_CHANGED, max_age);
261 Mock::VerifyAndClearExpectations(&loader);
263 // Non-zero max_age from header.
264 max_age = base::TimeDelta::FromMinutes(42);
265 EXPECT_CALL(loader, ScheduleFetch((max_age + base::TimeDelta::FromMinutes(1))
266 .InMilliseconds()));
267 loader.ModelLoader::EndFetch(ModelLoader::MODEL_NOT_CHANGED, max_age);
268 Mock::VerifyAndClearExpectations(&loader);
270 // Non-zero max_age, but failed load should use default interval.
271 max_age = base::TimeDelta::FromMinutes(42);
272 EXPECT_CALL(loader, ScheduleFetch(ModelLoader::kClientModelFetchIntervalMs));
273 loader.ModelLoader::EndFetch(ModelLoader::MODEL_FETCH_FAILED, max_age);
274 Mock::VerifyAndClearExpectations(&loader);
277 // Test that Finch params control the model names.
278 TEST_F(ModelLoaderTest, ModelNamesTest) {
279 // Test the name-templating.
280 EXPECT_EQ(ModelLoader::FillInModelName(true, 3),
281 "client_model_v5_ext_variation_3.pb");
282 EXPECT_EQ(ModelLoader::FillInModelName(false, 5),
283 "client_model_v5_variation_5.pb");
285 // No Finch setup. Should default to 0.
286 scoped_ptr<ModelLoader> loader;
287 loader.reset(new ModelLoader(base::Closure(), NULL,
288 false /* is_extended_reporting */));
289 EXPECT_EQ(loader->name(), "client_model_v5_variation_0.pb");
290 EXPECT_EQ(loader->url_.spec(),
291 "https://ssl.gstatic.com/safebrowsing/csd/"
292 "client_model_v5_variation_0.pb");
294 // Model 1, no extended reporting.
295 SetFinchModelNumber(1);
296 loader.reset(new ModelLoader(base::Closure(), NULL, false));
297 EXPECT_EQ(loader->name(), "client_model_v5_variation_1.pb");
299 // Model 2, with extended reporting.
300 SetFinchModelNumber(2);
301 loader.reset(new ModelLoader(base::Closure(), NULL, true));
302 EXPECT_EQ(loader->name(), "client_model_v5_ext_variation_2.pb");
305 TEST_F(ModelLoaderTest, ModelHasValidHashIds) {
306 ClientSideModel model;
307 EXPECT_TRUE(ModelLoader::ModelHasValidHashIds(model));
308 model.add_hashes("bla");
309 EXPECT_TRUE(ModelLoader::ModelHasValidHashIds(model));
310 model.add_page_term(0);
311 EXPECT_TRUE(ModelLoader::ModelHasValidHashIds(model));
313 model.add_page_term(-1);
314 EXPECT_FALSE(ModelLoader::ModelHasValidHashIds(model));
315 model.set_page_term(1, 1);
316 EXPECT_FALSE(ModelLoader::ModelHasValidHashIds(model));
317 model.set_page_term(1, 0);
318 EXPECT_TRUE(ModelLoader::ModelHasValidHashIds(model));
320 // Test bad rules.
321 model.add_hashes("blu");
322 ClientSideModel::Rule* rule = model.add_rule();
323 rule->add_feature(0);
324 rule->add_feature(1);
325 rule->set_weight(0.1f);
326 EXPECT_TRUE(ModelLoader::ModelHasValidHashIds(model));
328 rule = model.add_rule();
329 rule->add_feature(0);
330 rule->add_feature(1);
331 rule->add_feature(-1);
332 rule->set_weight(0.2f);
333 EXPECT_FALSE(ModelLoader::ModelHasValidHashIds(model));
335 rule->set_feature(2, 2);
336 EXPECT_FALSE(ModelLoader::ModelHasValidHashIds(model));
338 rule->set_feature(2, 1);
339 EXPECT_TRUE(ModelLoader::ModelHasValidHashIds(model));
342 } // namespace safe_browsing