1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/dns/dns_test_util.h"
9 #include "base/big_endian.h"
10 #include "base/bind.h"
11 #include "base/memory/weak_ptr.h"
12 #include "base/message_loop/message_loop.h"
13 #include "base/sys_byteorder.h"
14 #include "net/base/dns_util.h"
15 #include "net/base/io_buffer.h"
16 #include "net/base/net_errors.h"
17 #include "net/dns/address_sorter.h"
18 #include "net/dns/dns_query.h"
19 #include "net/dns/dns_response.h"
20 #include "net/dns/dns_transaction.h"
21 #include "testing/gtest/include/gtest/gtest.h"
// Mock AddressSorter that performs no sorting: Sort() runs |callback| directly,
// with success=true and the input |list| in its original order.
26 class MockAddressSorter
: public AddressSorter
{
28 virtual ~MockAddressSorter() {}
29 virtual void Sort(const AddressList
& list
,
30 const CallbackType
& callback
) const OVERRIDE
{
32 callback
.Run(true, list
);
36 // A DnsTransaction which uses MockDnsClientRuleList to determine the response.
// Also derives from base::SupportsWeakPtr so that posted tasks and the
// factory's list of delayed transactions can refer to a transaction without
// owning it.
37 class MockTransaction
: public DnsTransaction
,
38 public base::SupportsWeakPtr
<MockTransaction
> {
// Chooses this transaction's behaviour from |rules|. A rule applies when its
// qtype equals |qtype| and its prefix is a prefix of |hostname|. An applying
// rule supplies |result_| and |delayed_|. If no rule applies, |result_| keeps
// its initial value of FAIL.
// NOTE(review): the |qtype| parameter, the remaining member initializers and
// any early exit from the loop are not visible in this extract. Confirm
// whether the first or the last matching rule is meant to win.
40 MockTransaction(const MockDnsClientRuleList
& rules
,
41 const std::string
& hostname
,
43 const DnsTransactionFactory::CallbackType
& callback
)
44 : result_(MockDnsClientRule::FAIL
),
50 // Find the relevant rule which matches |qtype| and prefix of |hostname|.
51 for (size_t i
= 0; i
< rules
.size(); ++i
) {
52 const std::string
& prefix
= rules
[i
].prefix
;
// Match requires an equal query type and |hostname| starting with |prefix|.
53 if ((rules
[i
].qtype
== qtype
) &&
54 (hostname
.size() >= prefix
.size()) &&
55 (hostname
.compare(0, prefix
.size(), prefix
) == 0)) {
56 result_
= rules
[i
].result
;
57 delayed_
= rules
[i
].delay
;
// DnsTransaction accessor for the queried hostname. The body is not visible
// here; it presumably returns |hostname_|.
63 virtual const std::string
& GetHostname() const OVERRIDE
{
// DnsTransaction accessor for the query type. The body is not visible here;
// it presumably returns |qtype_|.
67 virtual uint16
GetType() const OVERRIDE
{
// Begins the transaction and expects it not to have been started already.
// The result is delivered asynchronously by a task, posted to the current
// MessageLoop, that runs Finish(). The task is bound to a WeakPtr, so it does
// nothing if this transaction is destroyed before the task runs.
// NOTE(review): the lines between EXPECT_FALSE and PostTask are not visible.
// They presumably mark the transaction as started and return early for
// delayed transactions; confirm.
71 virtual void Start() OVERRIDE
{
72 EXPECT_FALSE(started_
);
76 // Using WeakPtr to cleanly cancel when transaction is destroyed.
77 base::MessageLoop::current()->PostTask(
78 FROM_HERE
, base::Bind(&MockTransaction::Finish
, AsWeakPtr()));
// Completes a transaction whose matching rule requested a delay. The factory
// calls this from its CompleteDelayedTransactions(). Expects |delayed_| to be
// set.
81 void FinishDelayedTransaction() {
82 EXPECT_TRUE(delayed_
);
// True when the matching rule asked for the response to be held back until
// CompleteDelayedTransactions() is called.
87 bool delayed() const { return delayed_
; }
// Case bodies that deliver the mocked outcome through |callback_|. This is
// presumably a switch on |result_|; the enclosing function header, the switch
// statement and any break statements are not visible in this extract.
// EMPTY shares the OK case body.
92 case MockDnsClientRule::EMPTY
:
93 case MockDnsClientRule::OK
: {
// Rebuild the query for |hostname_| / |qtype_| and copy it verbatim into the
// response buffer, so the response echoes the query's header and question.
95 DNSDomainFromDot(hostname_
, &qname
);
96 DnsQuery
query(0, qname
, qtype_
);
// NOTE(review): assumes |response|'s buffer can hold the query plus
// |answer_size| extra bytes. Confirm against DnsResponse's buffer size.
99 char* buffer
= response
.io_buffer()->data();
100 int nbytes
= query
.io_buffer()->size();
101 memcpy(buffer
, query
.io_buffer()->data(), nbytes
);
// Turn the copied query into a response by setting the response (QR) flag.
102 dns_protocol::Header
* header
=
103 reinterpret_cast<dns_protocol::Header
*>(buffer
);
104 header
->flags
|= dns_protocol::kFlagResponse
;
// Only OK appends an answer record. EMPTY leaves the copied query without
// any appended answers.
106 if (MockDnsClientRule::OK
== result_
) {
// Name-compression pointer (top two bits set) to offset sizeof(*header),
// where the copied question name starts.
107 const uint16 kPointerToQueryName
=
108 static_cast<uint16
>(0xc000 | sizeof(*header
));
110 const uint32 kTTL
= 86400; // One day.
112 // Size of RDATA which is a IPv4 or IPv6 address.
113 size_t rdata_size
= qtype_
== net::dns_protocol::kTypeA
?
114 net::kIPv4AddressSize
: net::kIPv6AddressSize
;
116 // 12 is the sum of sizes of the compressed name reference, TYPE,
117 // CLASS, TTL and RDLENGTH.
118 size_t answer_size
= 12 + rdata_size
;
120 // Write answer with loopback IP address.
121 header
->ancount
= base::HostToNet16(1);
// The single resource record is appended directly after the copied query.
122 base::BigEndianWriter
writer(buffer
+ nbytes
, answer_size
);
123 writer
.WriteU16(kPointerToQueryName
);
124 writer
.WriteU16(qtype_
);
125 writer
.WriteU16(net::dns_protocol::kClassIN
);
126 writer
.WriteU32(kTTL
);
127 writer
.WriteU16(rdata_size
);
// RDATA is 127.0.0.1 for A queries and ::1 otherwise.
128 if (qtype_
== net::dns_protocol::kTypeA
) {
129 char kIPv4Loopback
[] = { 0x7f, 0, 0, 1 };
130 writer
.WriteBytes(kIPv4Loopback
, sizeof(kIPv4Loopback
));
132 char kIPv6Loopback
[] = { 0, 0, 0, 0, 0, 0, 0, 0,
133 0, 0, 0, 0, 0, 0, 0, 1 };
134 writer
.WriteBytes(kIPv6Loopback
, sizeof(kIPv6Loopback
));
136 nbytes
+= answer_size
;
// NOTE(review): the brace that closes the OK-only branch is not visible.
// Presumably the parse check and the OK callback run for both EMPTY and OK.
138 EXPECT_TRUE(response
.InitParse(nbytes
, query
));
139 callback_
.Run(this, OK
, &response
);
// FAIL and TIMEOUT report an error code and pass no response object.
141 case MockDnsClientRule::FAIL
:
142 callback_
.Run(this, ERR_NAME_NOT_RESOLVED
, NULL
);
144 case MockDnsClientRule::TIMEOUT
:
145 callback_
.Run(this, ERR_DNS_TIMED_OUT
, NULL
);
// Outcome taken from the matching rule. It is initialized to FAIL in the
// constructor.
153 MockDnsClientRule::Result result_
;
// Hostname given at construction. Used to rebuild the query in the response.
154 const std::string hostname_
;
// Receives the outcome when the transaction finishes.
156 DnsTransactionFactory::CallbackType callback_
;
163 // A DnsTransactionFactory which creates MockTransaction.
// It also remembers transactions whose rule requested a delay, so a test can
// release them later via CompleteDelayedTransactions().
164 class MockTransactionFactory
: public DnsTransactionFactory
{
// |rules| determines the behaviour of every transaction this factory creates.
// The initializer is not visible here; it presumably copies |rules| into
// |rules_|.
166 explicit MockTransactionFactory(const MockDnsClientRuleList
& rules
)
// Empty destructor. |delayed_transactions_| holds only WeakPtrs, so destroying
// the factory does not destroy any transaction.
169 virtual ~MockTransactionFactory() {}
// Creates a MockTransaction for |hostname| / |qtype| that reports to
// |callback|. Ownership passes to the caller through the returned scoped_ptr.
// If the transaction is delayed, the factory also keeps a WeakPtr to it so
// CompleteDelayedTransactions() can finish it later. The BoundNetLog argument
// is unused.
171 virtual scoped_ptr
<DnsTransaction
> CreateTransaction(
172 const std::string
& hostname
,
174 const DnsTransactionFactory::CallbackType
& callback
,
175 const BoundNetLog
&) OVERRIDE
{
176 MockTransaction
* transaction
=
177 new MockTransaction(rules_
, hostname
, qtype
, callback
);
178 if (transaction
->delayed())
179 delayed_transactions_
.push_back(transaction
->AsWeakPtr());
180 return scoped_ptr
<DnsTransaction
>(transaction
);
// Finishes every transaction that was waiting as delayed when this was called.
// The pending list is swapped into a local first. Entries pushed during the
// loop therefore land in the now-empty |delayed_transactions_| for a later
// call, and the vector being iterated is never modified.
// NOTE(review): a WeakPtr whose transaction was destroyed is null. Confirm each
// element is checked before the dereference below; that line is not visible.
183 void CompleteDelayedTransactions() {
184 DelayedTransactionList old_delayed_transactions
;
185 old_delayed_transactions
.swap(delayed_transactions_
);
186 for (DelayedTransactionList::iterator it
= old_delayed_transactions
.begin();
187 it
!= old_delayed_transactions
.end(); ++it
) {
189 (*it
)->FinishDelayedTransaction();
// Non-owning handles to transactions waiting for CompleteDelayedTransactions().
194 typedef std::vector
<base::WeakPtr
<MockTransaction
> > DelayedTransactionList
;
// Rules handed to every MockTransaction this factory creates.
196 MockDnsClientRuleList rules_
;
197 DelayedTransactionList delayed_transactions_
;
// Creates a mock client. Transactions are served by a MockTransactionFactory
// built from |rules|. Addresses go through a MockAddressSorter, which returns
// them in their original order.
// NOTE(review): the initializer for |config_| is not visible here; it is
// presumably initialized from |config|.
200 MockDnsClient::MockDnsClient(const DnsConfig
& config
,
201 const MockDnsClientRuleList
& rules
)
203 factory_(new MockTransactionFactory(rules
)),
204 address_sorter_(new MockAddressSorter()) {
// Empty destructor. |factory_| and |address_sorter_| are accessed via .get(),
// so they are presumably smart pointers that release their objects.
207 MockDnsClient::~MockDnsClient() {}
// Replaces the client's DNS configuration.
// NOTE(review): the body is not visible in this extract. It presumably assigns
// |config| to |config_|, whose validity gates GetConfig() and
// GetTransactionFactory().
209 void MockDnsClient::SetConfig(const DnsConfig
& config
) {
// Returns the current config, or NULL while |config_| is not valid.
213 const DnsConfig
* MockDnsClient::GetConfig() const {
214 return config_
.IsValid() ? &config_
: NULL
;
// Returns the mock transaction factory, or NULL while |config_| is not valid.
// This uses the same validity check as GetConfig().
217 DnsTransactionFactory
* MockDnsClient::GetTransactionFactory() {
218 return config_
.IsValid() ? factory_
.get() : NULL
;
// Returns the MockAddressSorter created in the constructor. Unlike the getters
// above, this does not depend on whether the config is valid.
221 AddressSorter
* MockDnsClient::GetAddressSorter() {
222 return address_sorter_
.get();
// Lets a test release the transactions held back by delayed rules. It simply
// forwards to the factory.
225 void MockDnsClient::CompleteDelayedTransactions() {
226 factory_
->CompleteDelayedTransactions();