1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/dns/dns_test_util.h"
9 #include "base/big_endian.h"
10 #include "base/bind.h"
11 #include "base/memory/weak_ptr.h"
12 #include "base/message_loop/message_loop.h"
13 #include "base/sys_byteorder.h"
14 #include "net/base/dns_util.h"
15 #include "net/base/io_buffer.h"
16 #include "net/base/net_errors.h"
17 #include "net/dns/address_sorter.h"
18 #include "net/dns/dns_query.h"
19 #include "net/dns/dns_response.h"
20 #include "net/dns/dns_transaction.h"
21 #include "testing/gtest/include/gtest/gtest.h"
// Test AddressSorter whose Sort() does not reorder anything. It synchronously
// runs |callback| with `true` (presumably the success flag — confirm against
// AddressSorter::CallbackType) and the very |list| it was given.
// NOTE(review): the lines closing Sort() and the class are not present in this
// view of the file — verify against the full source.
26 class MockAddressSorter
: public AddressSorter
{
28 ~MockAddressSorter() override
{}
// Passes |list| through unchanged; no asynchronous work is posted.
29 void Sort(const AddressList
& list
,
30 const CallbackType
& callback
) const override
{
32 callback
.Run(true, list
);
// Fake DNS transaction whose outcome is chosen from a MockDnsClientRuleList.
// It derives from base::SupportsWeakPtr so that posted tasks and the factory's
// delayed list can refer to it without owning it.
// NOTE(review): this view of the class is missing many original lines (e.g. the
// qtype constructor parameter, parts of Start()/FinishDelayedTransaction(), the
// Finish() signature, switch/break lines, some members and the closing brace).
// Verify against the full file before relying on these comments.
36 // A DnsTransaction which uses MockDnsClientRuleList to determine the response.
37 class MockTransaction
: public DnsTransaction
,
38 public base::SupportsWeakPtr
<MockTransaction
> {
// Constructor: result_ defaults to MockDnsClientRule::FAIL. A rule matches when
// its qtype equals |qtype| and its prefix is a prefix of |hostname|. A matching
// rule's result and delay are copied into result_ and delayed_.
// NOTE(review): the lines after the assignments are not visible, so it cannot
// be confirmed here whether the loop stops at the first match.
40 MockTransaction(const MockDnsClientRuleList
& rules
,
41 const std::string
& hostname
,
43 const DnsTransactionFactory::CallbackType
& callback
)
44 : result_(MockDnsClientRule::FAIL
),
50 // Find the relevant rule which matches |qtype| and prefix of |hostname|.
51 for (size_t i
= 0; i
< rules
.size(); ++i
) {
52 const std::string
& prefix
= rules
[i
].prefix
;
// The size check guards compare() so a prefix longer than |hostname| never
// matches.
53 if ((rules
[i
].qtype
== qtype
) &&
54 (hostname
.size() >= prefix
.size()) &&
55 (hostname
.compare(0, prefix
.size(), prefix
) == 0)) {
56 result_
= rules
[i
].result
;
57 delayed_
= rules
[i
].delay
;
// Accessors for the queried hostname and query type.
63 const std::string
& GetHostname() const override
{ return hostname_
; }
65 uint16
GetType() const override
{ return qtype_
; }
// Start(): expects not to have been started already, then posts Finish() to the
// current MessageLoop through a weak pointer. The result is therefore delivered
// asynchronously, and the task is dropped if the transaction has been destroyed.
// NOTE(review): the lines between the started_ check and PostTask are not
// visible (presumably they set started_ and handle delayed_) — confirm.
67 void Start() override
{
68 EXPECT_FALSE(started_
);
72 // Using WeakPtr to cleanly cancel when transaction is destroyed.
73 base::MessageLoop::current()->PostTask(
74 FROM_HERE
, base::Bind(&MockTransaction::Finish
, AsWeakPtr()));
// Called by the factory to complete a transaction whose rule requested a delay.
// Only the delayed_ expectation is visible here; the rest of the body is missing.
77 void FinishDelayedTransaction() {
78 EXPECT_TRUE(delayed_
);
// True when the matched rule asked for this transaction to be held back.
83 bool delayed() const { return delayed_
; }
// NOTE(review): presumably the body of Finish(). Its signature, the
// `switch (result_)` header and the `response`/`qname` declarations are not
// visible in this view.
// EMPTY and OK both echo the query back with the response flag set. Only OK also
// appends an answer record.
88 case MockDnsClientRule::EMPTY
:
89 case MockDnsClientRule::OK
: {
91 DNSDomainFromDot(hostname_
, &qname
);
92 DnsQuery
query(0, qname
, qtype_
);
// Copy the serialized query into the response buffer and mark it as a response.
95 char* buffer
= response
.io_buffer()->data();
96 int nbytes
= query
.io_buffer()->size();
97 memcpy(buffer
, query
.io_buffer()->data(), nbytes
);
98 dns_protocol::Header
* header
=
99 reinterpret_cast<dns_protocol::Header
*>(buffer
);
100 header
->flags
|= dns_protocol::kFlagResponse
;
102 if (MockDnsClientRule::OK
== result_
) {
// 0xc000 sets the top two bits of a DNS compression pointer. The offset is
// sizeof(*header) because the question name starts right after the header.
103 const uint16 kPointerToQueryName
=
104 static_cast<uint16
>(0xc000 | sizeof(*header
));
106 const uint32 kTTL
= 86400; // One day.
// NOTE(review): the other arm of this conditional (original line 110) is not
// visible; presumably the IPv6 address size — confirm.
108 // Size of RDATA which is a IPv4 or IPv6 address.
109 size_t rdata_size
= qtype_
== dns_protocol::kTypeA
? kIPv4AddressSize
112 // 12 is the sum of sizes of the compressed name reference, TYPE,
113 // CLASS, TTL and RDLENGTH.
114 size_t answer_size
= 12 + rdata_size
;
116 // Write answer with loopback IP address.
// ancount = 1, stored in network byte order.
117 header
->ancount
= base::HostToNet16(1);
// Append the answer record right after the echoed query bytes.
118 base::BigEndianWriter
writer(buffer
+ nbytes
, answer_size
);
119 writer
.WriteU16(kPointerToQueryName
);
120 writer
.WriteU16(qtype_
);
121 writer
.WriteU16(dns_protocol::kClassIN
);
122 writer
.WriteU32(kTTL
);
123 writer
.WriteU16(rdata_size
);
// RDATA: 127.0.0.1 for A queries.
124 if (qtype_
== dns_protocol::kTypeA
) {
125 char kIPv4Loopback
[] = { 0x7f, 0, 0, 1 };
126 writer
.WriteBytes(kIPv4Loopback
, sizeof(kIPv4Loopback
));
// RDATA: ::1 otherwise. The `} else {` line itself is not visible in this view.
128 char kIPv6Loopback
[] = { 0, 0, 0, 0, 0, 0, 0, 0,
129 0, 0, 0, 0, 0, 0, 0, 1 };
130 writer
.WriteBytes(kIPv6Loopback
, sizeof(kIPv6Loopback
));
132 nbytes
+= answer_size
;
// Parse the assembled bytes back into |response|, then report success.
134 EXPECT_TRUE(response
.InitParse(nbytes
, query
));
135 callback_
.Run(this, OK
, &response
);
// FAIL reports ERR_NAME_NOT_RESOLVED and TIMEOUT reports ERR_DNS_TIMED_OUT,
// both with a NULL response.
137 case MockDnsClientRule::FAIL
:
138 callback_
.Run(this, ERR_NAME_NOT_RESOLVED
, NULL
);
140 case MockDnsClientRule::TIMEOUT
:
141 callback_
.Run(this, ERR_DNS_TIMED_OUT
, NULL
);
// Members. Some declarations (e.g. qtype_, delayed_, started_) are referenced
// above but are not visible in this view.
149 MockDnsClientRule::Result result_
;
150 const std::string hostname_
;
152 DnsTransactionFactory::CallbackType callback_
;
// Factory that builds MockTransactions from a stored copy of the rule list.
// It keeps weak pointers to transactions whose matched rule requested a delay,
// so tests can complete them later via CompleteDelayedTransactions().
// NOTE(review): the constructor's initializer/body, the `qtype` parameter of
// CreateTransaction, part of the completion loop and the closing brace are
// missing from this view.
159 // A DnsTransactionFactory which creates MockTransaction.
160 class MockTransactionFactory
: public DnsTransactionFactory
{
162 explicit MockTransactionFactory(const MockDnsClientRuleList
& rules
)
165 ~MockTransactionFactory() override
{}
// Creates a transaction for |hostname|/|qtype| and returns ownership to the
// caller. The factory only retains a weak pointer, and only when the
// transaction is delayed. The BoundNetLog argument is unused.
167 scoped_ptr
<DnsTransaction
> CreateTransaction(
168 const std::string
& hostname
,
170 const DnsTransactionFactory::CallbackType
& callback
,
171 const BoundNetLog
&) override
{
172 MockTransaction
* transaction
=
173 new MockTransaction(rules_
, hostname
, qtype
, callback
);
174 if (transaction
->delayed())
175 delayed_transactions_
.push_back(transaction
->AsWeakPtr());
176 return scoped_ptr
<DnsTransaction
>(transaction
);
// Finishes every transaction recorded as delayed so far. The list is swapped
// into a local first, so delayed_transactions_ is empty while the
// transactions are being finished.
// NOTE(review): original line 184 (between the loop header and the call) is
// not visible; presumably it skips weak pointers whose transaction was already
// destroyed — confirm.
179 void CompleteDelayedTransactions() {
180 DelayedTransactionList old_delayed_transactions
;
181 old_delayed_transactions
.swap(delayed_transactions_
);
182 for (DelayedTransactionList::iterator it
= old_delayed_transactions
.begin();
183 it
!= old_delayed_transactions
.end(); ++it
) {
185 (*it
)->FinishDelayedTransaction();
// Non-owning handles to delayed transactions; the transactions themselves are
// owned by whoever received them from CreateTransaction().
190 typedef std::vector
<base::WeakPtr
<MockTransaction
> > DelayedTransactionList
;
192 MockDnsClientRuleList rules_
;
193 DelayedTransactionList delayed_transactions_
;
// Builds a mock client. The transaction factory is created from |rules| and the
// address sorter is a MockAddressSorter.
// NOTE(review): the first initializer line (original line 198, presumably
// storing |config|) and the constructor body/closing brace are not visible here.
196 MockDnsClient::MockDnsClient(const DnsConfig
& config
,
197 const MockDnsClientRuleList
& rules
)
199 factory_(new MockTransactionFactory(rules
)),
200 address_sorter_(new MockAddressSorter()) {
// Destructor. Owned members are released by their own destructors.
203 MockDnsClient::~MockDnsClient() {}
// Replaces the client's DNS configuration.
// NOTE(review): the body (original lines 206-208) is not visible in this view;
// presumably it assigns |config| to config_ — confirm.
205 void MockDnsClient::SetConfig(const DnsConfig
& config
) {
// Returns a pointer to the current config, or NULL when config_ is not valid.
209 const DnsConfig
* MockDnsClient::GetConfig() const {
210 return config_
.IsValid() ? &config_
: NULL
;
// Returns the (non-owned) mock transaction factory. Like GetConfig(), it
// returns NULL when config_ is not valid, so no transactions can be created
// without a valid config.
213 DnsTransactionFactory
* MockDnsClient::GetTransactionFactory() {
214 return config_
.IsValid() ? factory_
.get() : NULL
;
// Returns the (non-owned) MockAddressSorter. Unlike the factory accessor, this
// does not depend on config validity.
217 AddressSorter
* MockDnsClient::GetAddressSorter() {
218 return address_sorter_
.get();
// Test hook: forwards to the factory so every transaction created with a
// delayed rule is finished now.
221 void MockDnsClient::CompleteDelayedTransactions() {
222 factory_
->CompleteDelayedTransactions();