1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/dns/dns_test_util.h"
9 #include "base/big_endian.h"
10 #include "base/bind.h"
11 #include "base/location.h"
12 #include "base/memory/weak_ptr.h"
13 #include "base/single_thread_task_runner.h"
14 #include "base/sys_byteorder.h"
15 #include "base/thread_task_runner_handle.h"
16 #include "net/base/dns_util.h"
17 #include "net/base/io_buffer.h"
18 #include "net/base/net_errors.h"
19 #include "net/dns/address_sorter.h"
20 #include "net/dns/dns_query.h"
21 #include "net/dns/dns_response.h"
22 #include "net/dns/dns_transaction.h"
23 #include "testing/gtest/include/gtest/gtest.h"
// Test double for AddressSorter: Sort() does not reorder anything; it
// synchronously invokes |callback| with success=true and the input |list|.
28 class MockAddressSorter
: public AddressSorter
{
30 ~MockAddressSorter() override
{}
31 void Sort(const AddressList
& list
,
32 const CallbackType
& callback
) const override
{
// Report success and hand back |list| exactly as it was passed in.
34 callback
.Run(true, list
);
38 // A DnsTransaction which uses MockDnsClientRuleList to determine the response.
39 class MockTransaction
: public DnsTransaction
,
40 public base::SupportsWeakPtr
<MockTransaction
> {
42 MockTransaction(const MockDnsClientRuleList
& rules
,
43 const std::string
& hostname
,
45 const DnsTransactionFactory::CallbackType
& callback
)
46 : result_(MockDnsClientRule::FAIL
),
52 // Find the relevant rule which matches |qtype| and prefix of |hostname|.
53 for (size_t i
= 0; i
< rules
.size(); ++i
) {
54 const std::string
& prefix
= rules
[i
].prefix
;
55 if ((rules
[i
].qtype
== qtype
) &&
56 (hostname
.size() >= prefix
.size()) &&
57 (hostname
.compare(0, prefix
.size(), prefix
) == 0)) {
58 result_
= rules
[i
].result
;
59 delayed_
= rules
[i
].delay
;
// DnsTransaction override: returns the stored hostname_.
65 const std::string
& GetHostname() const override
{ return hostname_
; }
// DnsTransaction override: returns the stored DNS query type (qtype_),
// which is also used as the QTYPE/TYPE of the synthesized query and answer.
67 uint16
GetType() const override
{ return qtype_
; }
69 void Start() override
{
70 EXPECT_FALSE(started_
);
74 // Using WeakPtr to cleanly cancel when transaction is destroyed.
75 base::ThreadTaskRunnerHandle::Get()->PostTask(
76 FROM_HERE
, base::Bind(&MockTransaction::Finish
, AsWeakPtr()));
79 void FinishDelayedTransaction() {
80 EXPECT_TRUE(delayed_
);
// Whether the matching rule's |delay| field was set. MockTransactionFactory
// tracks such transactions and later completes them by calling
// FinishDelayedTransaction().
85 bool delayed() const { return delayed_
; }
90 case MockDnsClientRule::EMPTY
:
91 case MockDnsClientRule::OK
: {
93 DNSDomainFromDot(hostname_
, &qname
);
94 DnsQuery
query(0, qname
, qtype_
);
97 char* buffer
= response
.io_buffer()->data();
98 int nbytes
= query
.io_buffer()->size();
99 memcpy(buffer
, query
.io_buffer()->data(), nbytes
);
100 dns_protocol::Header
* header
=
101 reinterpret_cast<dns_protocol::Header
*>(buffer
);
102 header
->flags
|= dns_protocol::kFlagResponse
;
104 if (MockDnsClientRule::OK
== result_
) {
105 const uint16 kPointerToQueryName
=
106 static_cast<uint16
>(0xc000 | sizeof(*header
));
108 const uint32 kTTL
= 86400; // One day.
110 // Size of RDATA, which is an IPv4 or IPv6 address.
111 size_t rdata_size
= qtype_
== dns_protocol::kTypeA
? kIPv4AddressSize
114 // 12 is the sum of sizes of the compressed name reference, TYPE,
115 // CLASS, TTL and RDLENGTH.
116 size_t answer_size
= 12 + rdata_size
;
118 // Write answer with loopback IP address.
119 header
->ancount
= base::HostToNet16(1);
120 base::BigEndianWriter
writer(buffer
+ nbytes
, answer_size
);
121 writer
.WriteU16(kPointerToQueryName
);
122 writer
.WriteU16(qtype_
);
123 writer
.WriteU16(dns_protocol::kClassIN
);
124 writer
.WriteU32(kTTL
);
125 writer
.WriteU16(rdata_size
);
126 if (qtype_
== dns_protocol::kTypeA
) {
127 char kIPv4Loopback
[] = { 0x7f, 0, 0, 1 };
128 writer
.WriteBytes(kIPv4Loopback
, sizeof(kIPv4Loopback
));
130 char kIPv6Loopback
[] = { 0, 0, 0, 0, 0, 0, 0, 0,
131 0, 0, 0, 0, 0, 0, 0, 1 };
132 writer
.WriteBytes(kIPv6Loopback
, sizeof(kIPv6Loopback
));
134 nbytes
+= answer_size
;
136 EXPECT_TRUE(response
.InitParse(nbytes
, query
));
137 callback_
.Run(this, OK
, &response
);
139 case MockDnsClientRule::FAIL
:
140 callback_
.Run(this, ERR_NAME_NOT_RESOLVED
, NULL
);
142 case MockDnsClientRule::TIMEOUT
:
143 callback_
.Run(this, ERR_DNS_TIMED_OUT
, NULL
);
151 MockDnsClientRule::Result result_
;
152 const std::string hostname_
;
154 DnsTransactionFactory::CallbackType callback_
;
161 // A DnsTransactionFactory which creates MockTransaction.
162 class MockTransactionFactory
: public DnsTransactionFactory
{
164 explicit MockTransactionFactory(const MockDnsClientRuleList
& rules
)
167 ~MockTransactionFactory() override
{}
169 scoped_ptr
<DnsTransaction
> CreateTransaction(
170 const std::string
& hostname
,
172 const DnsTransactionFactory::CallbackType
& callback
,
173 const BoundNetLog
&) override
{
174 MockTransaction
* transaction
=
175 new MockTransaction(rules_
, hostname
, qtype
, callback
);
176 if (transaction
->delayed())
177 delayed_transactions_
.push_back(transaction
->AsWeakPtr());
178 return scoped_ptr
<DnsTransaction
>(transaction
);
181 void CompleteDelayedTransactions() {
182 DelayedTransactionList old_delayed_transactions
;
183 old_delayed_transactions
.swap(delayed_transactions_
);
184 for (DelayedTransactionList::iterator it
= old_delayed_transactions
.begin();
185 it
!= old_delayed_transactions
.end(); ++it
) {
187 (*it
)->FinishDelayedTransaction();
192 typedef std::vector
<base::WeakPtr
<MockTransaction
> > DelayedTransactionList
;
194 MockDnsClientRuleList rules_
;
195 DelayedTransactionList delayed_transactions_
;
// Creates a mock client whose transactions are answered according to |rules|
// (through a MockTransactionFactory) and whose address sorter is a
// MockAddressSorter.
198 MockDnsClient::MockDnsClient(const DnsConfig
& config
,
199 const MockDnsClientRuleList
& rules
)
201 factory_(new MockTransactionFactory(rules
)),
202 address_sorter_(new MockAddressSorter()) {
// Empty destructor; no explicit cleanup is performed here.
205 MockDnsClient::~MockDnsClient() {}
207 void MockDnsClient::SetConfig(const DnsConfig
& config
) {
// Returns a pointer to the stored config, or NULL when config_ is not valid.
211 const DnsConfig
* MockDnsClient::GetConfig() const {
212 return config_
.IsValid() ? &config_
: NULL
;
// Returns the mock transaction factory, or NULL when config_ is not valid.
215 DnsTransactionFactory
* MockDnsClient::GetTransactionFactory() {
216 return config_
.IsValid() ? factory_
.get() : NULL
;
// Returns the MockAddressSorter created in the constructor, regardless of
// whether config_ is valid.
219 AddressSorter
* MockDnsClient::GetAddressSorter() {
220 return address_sorter_
.get();
223 void MockDnsClient::CompleteDelayedTransactions() {
224 factory_
->CompleteDelayedTransactions();