Probably broke Win7 Tests (dbg)(6). http://build.chromium.org/p/chromium.win/builders...
[chromium-blink-merge.git] / net / dns / dns_test_util.cc
blob01acb346fa9d8f940a90543f49378e20e8e295d5
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/dns/dns_test_util.h"
7 #include <string>
9 #include "base/big_endian.h"
10 #include "base/bind.h"
11 #include "base/memory/weak_ptr.h"
12 #include "base/message_loop/message_loop.h"
13 #include "base/sys_byteorder.h"
14 #include "net/base/dns_util.h"
15 #include "net/base/io_buffer.h"
16 #include "net/base/net_errors.h"
17 #include "net/dns/address_sorter.h"
18 #include "net/dns/dns_query.h"
19 #include "net/dns/dns_response.h"
20 #include "net/dns/dns_transaction.h"
21 #include "testing/gtest/include/gtest/gtest.h"
23 namespace net {
24 namespace {
26 class MockAddressSorter : public AddressSorter {
27 public:
28 virtual ~MockAddressSorter() {}
29 virtual void Sort(const AddressList& list,
30 const CallbackType& callback) const OVERRIDE {
31 // Do nothing.
32 callback.Run(true, list);
36 // A DnsTransaction which uses MockDnsClientRuleList to determine the response.
37 class MockTransaction : public DnsTransaction,
38 public base::SupportsWeakPtr<MockTransaction> {
39 public:
40 MockTransaction(const MockDnsClientRuleList& rules,
41 const std::string& hostname,
42 uint16 qtype,
43 const DnsTransactionFactory::CallbackType& callback)
44 : result_(MockDnsClientRule::FAIL),
45 hostname_(hostname),
46 qtype_(qtype),
47 callback_(callback),
48 started_(false),
49 delayed_(false) {
50 // Find the relevant rule which matches |qtype| and prefix of |hostname|.
51 for (size_t i = 0; i < rules.size(); ++i) {
52 const std::string& prefix = rules[i].prefix;
53 if ((rules[i].qtype == qtype) &&
54 (hostname.size() >= prefix.size()) &&
55 (hostname.compare(0, prefix.size(), prefix) == 0)) {
56 result_ = rules[i].result;
57 delayed_ = rules[i].delay;
58 break;
63 virtual const std::string& GetHostname() const OVERRIDE {
64 return hostname_;
67 virtual uint16 GetType() const OVERRIDE {
68 return qtype_;
71 virtual void Start() OVERRIDE {
72 EXPECT_FALSE(started_);
73 started_ = true;
74 if (delayed_)
75 return;
76 // Using WeakPtr to cleanly cancel when transaction is destroyed.
77 base::MessageLoop::current()->PostTask(
78 FROM_HERE, base::Bind(&MockTransaction::Finish, AsWeakPtr()));
81 void FinishDelayedTransaction() {
82 EXPECT_TRUE(delayed_);
83 delayed_ = false;
84 Finish();
87 bool delayed() const { return delayed_; }
89 private:
90 void Finish() {
91 switch (result_) {
92 case MockDnsClientRule::EMPTY:
93 case MockDnsClientRule::OK: {
94 std::string qname;
95 DNSDomainFromDot(hostname_, &qname);
96 DnsQuery query(0, qname, qtype_);
98 DnsResponse response;
99 char* buffer = response.io_buffer()->data();
100 int nbytes = query.io_buffer()->size();
101 memcpy(buffer, query.io_buffer()->data(), nbytes);
102 dns_protocol::Header* header =
103 reinterpret_cast<dns_protocol::Header*>(buffer);
104 header->flags |= dns_protocol::kFlagResponse;
106 if (MockDnsClientRule::OK == result_) {
107 const uint16 kPointerToQueryName =
108 static_cast<uint16>(0xc000 | sizeof(*header));
110 const uint32 kTTL = 86400; // One day.
112 // Size of RDATA which is a IPv4 or IPv6 address.
113 size_t rdata_size = qtype_ == net::dns_protocol::kTypeA ?
114 net::kIPv4AddressSize : net::kIPv6AddressSize;
116 // 12 is the sum of sizes of the compressed name reference, TYPE,
117 // CLASS, TTL and RDLENGTH.
118 size_t answer_size = 12 + rdata_size;
120 // Write answer with loopback IP address.
121 header->ancount = base::HostToNet16(1);
122 base::BigEndianWriter writer(buffer + nbytes, answer_size);
123 writer.WriteU16(kPointerToQueryName);
124 writer.WriteU16(qtype_);
125 writer.WriteU16(net::dns_protocol::kClassIN);
126 writer.WriteU32(kTTL);
127 writer.WriteU16(rdata_size);
128 if (qtype_ == net::dns_protocol::kTypeA) {
129 char kIPv4Loopback[] = { 0x7f, 0, 0, 1 };
130 writer.WriteBytes(kIPv4Loopback, sizeof(kIPv4Loopback));
131 } else {
132 char kIPv6Loopback[] = { 0, 0, 0, 0, 0, 0, 0, 0,
133 0, 0, 0, 0, 0, 0, 0, 1 };
134 writer.WriteBytes(kIPv6Loopback, sizeof(kIPv6Loopback));
136 nbytes += answer_size;
138 EXPECT_TRUE(response.InitParse(nbytes, query));
139 callback_.Run(this, OK, &response);
140 } break;
141 case MockDnsClientRule::FAIL:
142 callback_.Run(this, ERR_NAME_NOT_RESOLVED, NULL);
143 break;
144 case MockDnsClientRule::TIMEOUT:
145 callback_.Run(this, ERR_DNS_TIMED_OUT, NULL);
146 break;
147 default:
148 NOTREACHED();
149 break;
153 MockDnsClientRule::Result result_;
154 const std::string hostname_;
155 const uint16 qtype_;
156 DnsTransactionFactory::CallbackType callback_;
157 bool started_;
158 bool delayed_;
161 } // namespace
163 // A DnsTransactionFactory which creates MockTransaction.
164 class MockTransactionFactory : public DnsTransactionFactory {
165 public:
166 explicit MockTransactionFactory(const MockDnsClientRuleList& rules)
167 : rules_(rules) {}
169 virtual ~MockTransactionFactory() {}
171 virtual scoped_ptr<DnsTransaction> CreateTransaction(
172 const std::string& hostname,
173 uint16 qtype,
174 const DnsTransactionFactory::CallbackType& callback,
175 const BoundNetLog&) OVERRIDE {
176 MockTransaction* transaction =
177 new MockTransaction(rules_, hostname, qtype, callback);
178 if (transaction->delayed())
179 delayed_transactions_.push_back(transaction->AsWeakPtr());
180 return scoped_ptr<DnsTransaction>(transaction);
183 void CompleteDelayedTransactions() {
184 DelayedTransactionList old_delayed_transactions;
185 old_delayed_transactions.swap(delayed_transactions_);
186 for (DelayedTransactionList::iterator it = old_delayed_transactions.begin();
187 it != old_delayed_transactions.end(); ++it) {
188 if (it->get())
189 (*it)->FinishDelayedTransaction();
193 private:
194 typedef std::vector<base::WeakPtr<MockTransaction> > DelayedTransactionList;
196 MockDnsClientRuleList rules_;
197 DelayedTransactionList delayed_transactions_;
200 MockDnsClient::MockDnsClient(const DnsConfig& config,
201 const MockDnsClientRuleList& rules)
202 : config_(config),
203 factory_(new MockTransactionFactory(rules)),
204 address_sorter_(new MockAddressSorter()) {
207 MockDnsClient::~MockDnsClient() {}
209 void MockDnsClient::SetConfig(const DnsConfig& config) {
210 config_ = config;
213 const DnsConfig* MockDnsClient::GetConfig() const {
214 return config_.IsValid() ? &config_ : NULL;
217 DnsTransactionFactory* MockDnsClient::GetTransactionFactory() {
218 return config_.IsValid() ? factory_.get() : NULL;
221 AddressSorter* MockDnsClient::GetAddressSorter() {
222 return address_sorter_.get();
225 void MockDnsClient::CompleteDelayedTransactions() {
226 factory_->CompleteDelayedTransactions();
229 } // namespace net