1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/dns/dns_test_util.h"
10 #include "base/memory/weak_ptr.h"
11 #include "base/message_loop.h"
12 #include "base/sys_byteorder.h"
13 #include "net/base/big_endian.h"
14 #include "net/base/dns_util.h"
15 #include "net/base/io_buffer.h"
16 #include "net/base/net_errors.h"
17 #include "net/dns/address_sorter.h"
18 #include "net/dns/dns_client.h"
19 #include "net/dns/dns_config_service.h"
20 #include "net/dns/dns_protocol.h"
21 #include "net/dns/dns_query.h"
22 #include "net/dns/dns_response.h"
23 #include "net/dns/dns_transaction.h"
24 #include "testing/gtest/include/gtest/gtest.h"
29 // A DnsTransaction which uses MockDnsClientRuleList to determine the response.
// The scripted outcome (a MockDnsClientRule::Result) is selected in the
// constructor. Start() then fails synchronously for FAIL_SYNC and otherwise
// posts the completion to the current MessageLoop.
// NOTE(review): this view of the file omits a number of source lines (e.g. the
// |qtype| constructor parameter and some member initializers); confirm details
// against the complete file.
30 class MockTransaction
: public DnsTransaction
,
31 public base::SupportsWeakPtr
<MockTransaction
> {
// |rules|: candidate scripted results; |hostname|: name being resolved;
// |callback|: completion callback (presumably stored in |callback_|; that
// initializer is not shown in this view).
33 MockTransaction(const MockDnsClientRuleList
& rules
,
34 const std::string
& hostname
,
36 const DnsTransactionFactory::CallbackType
& callback
)
// Default outcome when no rule matches: Start() fails synchronously.
37 : result_(MockDnsClientRule::FAIL_SYNC
),
42 // Find the relevant rule which matches |qtype| and prefix of |hostname|.
// A rule matches when its qtype equals |qtype| and its prefix string is a
// leading substring of |hostname|. NOTE(review): whether the scan stops at
// the first match is decided by lines not shown in this view.
43 for (size_t i
= 0; i
< rules
.size(); ++i
) {
44 const std::string
& prefix
= rules
[i
].prefix
;
45 if ((rules
[i
].qtype
== qtype
) &&
46 (hostname
.size() >= prefix
.size()) &&
47 (hostname
.compare(0, prefix
.size(), prefix
) == 0)) {
// Adopt the matching rule's scripted result.
48 result_
= rules
[i
].result
;
// DnsTransaction override. NOTE(review): the body is not shown in this view;
// presumably it returns |hostname_|.
54 virtual const std::string
& GetHostname() const OVERRIDE
{
// DnsTransaction override. NOTE(review): the body is not shown in this view;
// presumably it returns the query type |qtype_|.
58 virtual uint16
GetType() const OVERRIDE
{
// DnsTransaction override. Returns ERR_NAME_NOT_RESOLVED immediately when the
// scripted result is FAIL_SYNC. Otherwise it posts Finish() to the current
// MessageLoop and returns ERR_IO_PENDING; the outcome is delivered later
// through the callback.
62 virtual int Start() OVERRIDE
{
// Presumably guards against Start() being called more than once on the same
// transaction. NOTE(review): the statement that sets |started_| is not shown
// in this view.
63 EXPECT_FALSE(started_
);
65 if (MockDnsClientRule::FAIL_SYNC
== result_
)
66 return ERR_NAME_NOT_RESOLVED
;
67 // Using WeakPtr to cleanly cancel when transaction is destroyed.
// (A task bound to AsWeakPtr() is dropped if this object is deleted before
// the task runs, so Finish() never runs on a destroyed transaction.)
68 MessageLoop::current()->PostTask(
70 base::Bind(&MockTransaction::Finish
, AsWeakPtr()));
71 return ERR_IO_PENDING
;
// Body of the completion handler that Start() posts (MockTransaction::Finish).
// Each case reports the scripted outcome through |callback_|.
// NOTE(review): the Finish() signature, the switch header over |result_|, and
// the declarations of |qname| and |response| are not shown in this view.
77 case MockDnsClientRule::EMPTY
:
78 case MockDnsClientRule::OK
: {
// EMPTY and OK both build a response that echoes the query; only OK appends
// an answer record.
80 DNSDomainFromDot(hostname_
, &qname
);
81 DnsQuery
query(0, qname
, qtype_
);
// Copy the query's wire bytes into the response buffer and set the response
// flag; |nbytes| tracks the length of the response being assembled.
84 char* buffer
= response
.io_buffer()->data();
85 int nbytes
= query
.io_buffer()->size();
86 memcpy(buffer
, query
.io_buffer()->data(), nbytes
);
87 dns_protocol::Header
* header
=
88 reinterpret_cast<dns_protocol::Header
*>(buffer
);
89 header
->flags
|= dns_protocol::kFlagResponse
;
91 if (MockDnsClientRule::OK
== result_
) {
// Name-compression pointer (top two bits set) to offset sizeof(*header),
// i.e. the question name that directly follows the header in the copied
// query.
92 const uint16 kPointerToQueryName
=
93 static_cast<uint16
>(0xc000 | sizeof(*header
));
95 const uint32 kTTL
= 86400; // One day.
97 // Size of RDATA, which is an IPv4 or IPv6 address.
// Any qtype other than kTypeA is given net::kIPv6AddressSize.
98 size_t rdata_size
= qtype_
== net::dns_protocol::kTypeA
?
99 net::kIPv4AddressSize
: net::kIPv6AddressSize
;
101 // 12 is the sum of sizes of the compressed name reference, TYPE,
102 // CLASS, TTL and RDLENGTH.
// (2 + 2 + 2 + 4 + 2 bytes, matching the WriteU16/WriteU32 calls below.)
103 size_t answer_size
= 12 + rdata_size
;
105 // Write answer with loopback IP address.
106 header
->ancount
= base::HostToNet16(1);
// NOTE(review): assumes the response buffer has room for
// nbytes + answer_size bytes; confirm against DnsResponse's buffer size.
107 BigEndianWriter
writer(buffer
+ nbytes
, answer_size
);
108 writer
.WriteU16(kPointerToQueryName
);
109 writer
.WriteU16(qtype_
);
110 writer
.WriteU16(net::dns_protocol::kClassIN
);
111 writer
.WriteU32(kTTL
);
// RDLENGTH. rdata_size is an address-size constant (presumably 4 or 16), so
// the implicit size_t -> uint16 conversion should be lossless.
112 writer
.WriteU16(rdata_size
);
113 if (qtype_
== net::dns_protocol::kTypeA
) {
// 127.0.0.1
114 char kIPv4Loopback
[] = { 0x7f, 0, 0, 1 };
115 writer
.WriteBytes(kIPv4Loopback
, sizeof(kIPv4Loopback
));
// ::1 -- presumably the else branch of the kTypeA check (that line is not
// shown in this view).
117 char kIPv6Loopback
[] = { 0, 0, 0, 0, 0, 0, 0, 0,
118 0, 0, 0, 0, 0, 0, 0, 1 };
119 writer
.WriteBytes(kIPv6Loopback
, sizeof(kIPv6Loopback
));
// Include the appended answer in the response length.
121 nbytes
+= answer_size
;
// Parse the assembled bytes against |query|. A parse failure is recorded as
// a non-fatal gtest failure, and the response is then delivered with OK.
123 EXPECT_TRUE(response
.InitParse(nbytes
, query
));
124 callback_
.Run(this, OK
, &response
);
// Asynchronous failure: no response object is delivered.
126 case MockDnsClientRule::FAIL_ASYNC
:
127 callback_
.Run(this, ERR_NAME_NOT_RESOLVED
, NULL
);
// Simulated timeout: no response object is delivered.
129 case MockDnsClientRule::TIMEOUT
:
130 callback_
.Run(this, ERR_DNS_TIMED_OUT
, NULL
);
// Scripted outcome chosen in the constructor (FAIL_SYNC when no rule matched).
138 MockDnsClientRule::Result result_
;
// Hostname this transaction resolves; also used to build the mock query.
139 const std::string hostname_
;
// Run with the transaction's result code and, for OK/EMPTY, the parsed
// response (NULL for FAIL_ASYNC and TIMEOUT).
// NOTE(review): declarations of |qtype_| and |started_|, which are used
// above, are not shown in this view.
141 DnsTransactionFactory::CallbackType callback_
;
146 // A DnsTransactionFactory which creates MockTransaction.
// Every transaction it creates is scripted by the same |rules_| list.
147 class MockTransactionFactory
: public DnsTransactionFactory
{
// NOTE(review): the initializer that stores |rules| in |rules_| is not shown
// in this view.
149 explicit MockTransactionFactory(const MockDnsClientRuleList
& rules
)
151 virtual ~MockTransactionFactory() {}
// Returns a new MockTransaction for |hostname| and |qtype|, owned by the
// caller through scoped_ptr. The BoundNetLog argument is unnamed and ignored.
// NOTE(review): the |qtype| parameter declaration is not shown in this view.
153 virtual scoped_ptr
<DnsTransaction
> CreateTransaction(
154 const std::string
& hostname
,
156 const DnsTransactionFactory::CallbackType
& callback
,
157 const BoundNetLog
&) OVERRIDE
{
158 return scoped_ptr
<DnsTransaction
>(
159 new MockTransaction(rules_
, hostname
, qtype
, callback
));
// Rule list handed to every MockTransaction this factory creates.
163 MockDnsClientRuleList rules_
;
// An AddressSorter that leaves the order unchanged: Sort() immediately runs
// |callback| with the input list as-is.
166 class MockAddressSorter
: public AddressSorter
{
168 virtual ~MockAddressSorter() {}
169 virtual void Sort(const AddressList
& list
,
170 const CallbackType
& callback
) const OVERRIDE
{
// Runs synchronously, before Sort() returns. The first argument |true| is
// presumably the success flag -- confirm against AddressSorter::CallbackType.
172 callback
.Run(true, list
);
176 // MockDnsClient provides MockTransactionFactory.
// A DnsClient for tests: DNS transactions come from a MockTransactionFactory
// built from |rules|, and address sorting is a pass-through MockAddressSorter.
// The config and the transaction factory are exposed only while the stored
// DnsConfig is valid.
177 class MockDnsClient
: public DnsClient
{
179 MockDnsClient(const DnsConfig
& config
,
180 const MockDnsClientRuleList
& rules
)
181 : config_(config
), factory_(rules
) {}
182 virtual ~MockDnsClient() {}
// NOTE(review): the body of SetConfig() is not shown in this view;
// presumably it replaces |config_| -- confirm.
184 virtual void SetConfig(const DnsConfig
& config
) OVERRIDE
{
// Returns NULL while the stored config is invalid.
188 virtual const DnsConfig
* GetConfig() const OVERRIDE
{
189 return config_
.IsValid() ? &config_
: NULL
;
// Also NULL without a valid config, even though |factory_| is always
// constructed from the rules.
192 virtual DnsTransactionFactory
* GetTransactionFactory() OVERRIDE
{
193 return config_
.IsValid() ? &factory_
: NULL
;
// Always available, independent of the config.
196 virtual AddressSorter
* GetAddressSorter() OVERRIDE
{
197 return &address_sorter_
;
// NOTE(review): the declaration of |config_| is not shown in this view.
202 MockTransactionFactory factory_
;
203 MockAddressSorter address_sorter_
;
// Creates a MockDnsClient that holds |config| and whose transactions are
// scripted by |rules|; ownership passes to the caller through scoped_ptr.
209 scoped_ptr
<DnsClient
> CreateMockDnsClient(const DnsConfig
& config
,
210 const MockDnsClientRuleList
& rules
) {
211 return scoped_ptr
<DnsClient
>(new MockDnsClient(config
, rules
));