1 // Copyright 2015 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "components/dom_distiller/core/page_features.h"
10 #include "base/files/file_util.h"
11 #include "base/json/json_reader.h"
12 #include "base/json/json_writer.h"
13 #include "base/memory/scoped_ptr.h"
14 #include "base/path_service.h"
15 #include "testing/gtest/include/gtest/gtest.h"
17 namespace dom_distiller
{
19 // This test uses input data of core features and the output of the training
20 // pipeline's derived feature extraction to ensure that the extraction that is
21 // done in Chromium matches that in the training pipeline.
22 TEST(DomDistillerPageFeaturesTest
, TestCalculateDerivedFeatures
) {
23 base::FilePath dir_source_root
;
24 EXPECT_TRUE(PathService::Get(base::DIR_SOURCE_ROOT
, &dir_source_root
));
25 std::string input_data
;
26 ASSERT_TRUE(base::ReadFileToString(
27 dir_source_root
.AppendASCII(
28 "components/test/data/dom_distiller/core_features.json"),
30 std::string expected_output_data
;
31 // This file contains the output from the calculation of derived features in
32 // the training pipeline.
33 ASSERT_TRUE(base::ReadFileToString(
34 dir_source_root
.AppendASCII(
35 "components/test/data/dom_distiller/derived_features.json"),
36 &expected_output_data
));
38 scoped_ptr
<base::Value
> input_json
= base::JSONReader::Read(input_data
);
39 ASSERT_TRUE(input_json
);
41 scoped_ptr
<base::Value
> expected_output_json(
42 base::JSONReader::DeprecatedRead(expected_output_data
));
43 ASSERT_TRUE(expected_output_json
);
45 base::ListValue
* input_entries
;
46 ASSERT_TRUE(input_json
->GetAsList(&input_entries
));
47 ASSERT_GT(input_entries
->GetSize(), 0u);
49 base::ListValue
* expected_output_entries
;
50 ASSERT_TRUE(expected_output_json
->GetAsList(&expected_output_entries
));
51 ASSERT_EQ(expected_output_entries
->GetSize(), input_entries
->GetSize());
53 // In the output, the features list is a sequence of labels followed by values
54 // (so labels at even indices, values at odd indices).
55 base::DictionaryValue
* entry
;
56 base::ListValue
* derived_features
;
57 ASSERT_TRUE(expected_output_entries
->GetDictionary(0, &entry
));
58 ASSERT_TRUE(entry
->GetList("features", &derived_features
));
59 std::vector
<std::string
> labels
;
60 for (size_t i
= 0; i
< derived_features
->GetSize(); i
+= 2) {
62 ASSERT_TRUE(derived_features
->GetString(i
, &label
));
63 labels
.push_back(label
);
66 for (size_t i
= 0; i
< input_entries
->GetSize(); ++i
) {
67 base::DictionaryValue
* core_features
;
68 ASSERT_TRUE(input_entries
->GetDictionary(i
, &entry
));
69 ASSERT_TRUE(entry
->GetDictionary("features", &core_features
));
70 // CalculateDerivedFeaturesFromJSON expects a base::Value of the stringified
71 // JSON (and not a base::Value of the JSON itself)
72 std::string stringified_json
;
73 ASSERT_TRUE(base::JSONWriter::Write(*core_features
, &stringified_json
));
74 scoped_ptr
<base::Value
> stringified_value(
75 new base::StringValue(stringified_json
));
76 std::vector
<double> derived(
77 CalculateDerivedFeaturesFromJSON(stringified_value
.get()));
79 ASSERT_EQ(labels
.size(), derived
.size());
80 ASSERT_TRUE(expected_output_entries
->GetDictionary(i
, &entry
));
81 ASSERT_TRUE(entry
->GetList("features", &derived_features
));
82 std::string entry_url
;
83 ASSERT_TRUE(entry
->GetString("url", &entry_url
));
84 for (size_t j
= 0, value_index
= 1; j
< derived
.size();
85 ++j
, value_index
+= 2) {
86 double expected_value
;
87 if (!derived_features
->GetDouble(value_index
, &expected_value
)) {
89 ASSERT_TRUE(derived_features
->GetBoolean(value_index
, &bool_value
));
90 expected_value
= bool_value
? 1.0 : 0.0;
92 EXPECT_DOUBLE_EQ(derived
[j
], expected_value
)
93 << "incorrect value for entry with url " << entry_url
94 << " for derived feature " << labels
[j
];