SupervisedUser SafeSites: Switch to the new SafeSearch API
[chromium-blink-merge.git] / net / dns / mock_host_resolver.cc
blob42ec6b9a5325e74c90d2e8cf1fa4a9b18905f546
1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/dns/mock_host_resolver.h"
7 #include <string>
8 #include <vector>
10 #include "base/bind.h"
11 #include "base/location.h"
12 #include "base/memory/ref_counted.h"
13 #include "base/single_thread_task_runner.h"
14 #include "base/stl_util.h"
15 #include "base/strings/pattern.h"
16 #include "base/strings/string_split.h"
17 #include "base/strings/string_util.h"
18 #include "base/thread_task_runner_handle.h"
19 #include "base/threading/platform_thread.h"
20 #include "net/base/ip_endpoint.h"
21 #include "net/base/net_errors.h"
22 #include "net/base/net_util.h"
23 #include "net/base/test_completion_callback.h"
24 #include "net/dns/host_cache.h"
26 #if defined(OS_WIN)
27 #include "net/base/winsock_init.h"
28 #endif
30 namespace net {
namespace {

// Maximum number of entries kept by the HostCache that a caching
// MockHostResolverBase creates.
const unsigned int kMaxCacheEntries = 100;

// Lifetime, in seconds, of a cached successful resolution. Failed resolutions
// are written with a zero TTL instead, so they are effectively not cached.
const unsigned int kCacheEntryTTLSeconds = 60;

}  // namespace
41 int ParseAddressList(const std::string& host_list,
42 const std::string& canonical_name,
43 AddressList* addrlist) {
44 *addrlist = AddressList();
45 addrlist->set_canonical_name(canonical_name);
46 for (const base::StringPiece& address : base::SplitStringPiece(
47 host_list, ",", base::TRIM_WHITESPACE, base::SPLIT_WANT_ALL)) {
48 IPAddressNumber ip_number;
49 if (!ParseIPLiteralToNumber(address, &ip_number)) {
50 LOG(WARNING) << "Not a supported IP literal: " << address.as_string();
51 return ERR_UNEXPECTED;
53 addrlist->push_back(IPEndPoint(ip_number, 0));
55 return OK;
58 struct MockHostResolverBase::Request {
59 Request(const RequestInfo& req_info,
60 AddressList* addr,
61 const CompletionCallback& cb)
62 : info(req_info), addresses(addr), callback(cb) {}
63 RequestInfo info;
64 AddressList* addresses;
65 CompletionCallback callback;
68 MockHostResolverBase::~MockHostResolverBase() {
69 STLDeleteValues(&requests_);
72 int MockHostResolverBase::Resolve(const RequestInfo& info,
73 RequestPriority priority,
74 AddressList* addresses,
75 const CompletionCallback& callback,
76 RequestHandle* handle,
77 const BoundNetLog& net_log) {
78 DCHECK(CalledOnValidThread());
79 last_request_priority_ = priority;
80 num_resolve_++;
81 size_t id = next_request_id_++;
82 int rv = ResolveFromIPLiteralOrCache(info, addresses);
83 if (rv != ERR_DNS_CACHE_MISS) {
84 return rv;
86 if (synchronous_mode_) {
87 return ResolveProc(id, info, addresses);
89 // Store the request for asynchronous resolution
90 Request* req = new Request(info, addresses, callback);
91 requests_[id] = req;
92 if (handle)
93 *handle = reinterpret_cast<RequestHandle>(id);
95 if (!ondemand_mode_) {
96 base::ThreadTaskRunnerHandle::Get()->PostTask(
97 FROM_HERE,
98 base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), id));
101 return ERR_IO_PENDING;
104 int MockHostResolverBase::ResolveFromCache(const RequestInfo& info,
105 AddressList* addresses,
106 const BoundNetLog& net_log) {
107 num_resolve_from_cache_++;
108 DCHECK(CalledOnValidThread());
109 next_request_id_++;
110 int rv = ResolveFromIPLiteralOrCache(info, addresses);
111 return rv;
114 void MockHostResolverBase::CancelRequest(RequestHandle handle) {
115 DCHECK(CalledOnValidThread());
116 size_t id = reinterpret_cast<size_t>(handle);
117 RequestMap::iterator it = requests_.find(id);
118 if (it != requests_.end()) {
119 scoped_ptr<Request> req(it->second);
120 requests_.erase(it);
121 } else {
122 NOTREACHED() << "CancelRequest must NOT be called after request is "
123 "complete or canceled.";
127 HostCache* MockHostResolverBase::GetHostCache() {
128 return cache_.get();
131 void MockHostResolverBase::ResolveAllPending() {
132 DCHECK(CalledOnValidThread());
133 DCHECK(ondemand_mode_);
134 for (RequestMap::iterator i = requests_.begin(); i != requests_.end(); ++i) {
135 base::ThreadTaskRunnerHandle::Get()->PostTask(
136 FROM_HERE,
137 base::Bind(&MockHostResolverBase::ResolveNow, AsWeakPtr(), i->first));
141 // start id from 1 to distinguish from NULL RequestHandle
142 MockHostResolverBase::MockHostResolverBase(bool use_caching)
143 : last_request_priority_(DEFAULT_PRIORITY),
144 synchronous_mode_(false),
145 ondemand_mode_(false),
146 next_request_id_(1),
147 num_resolve_(0),
148 num_resolve_from_cache_(0) {
149 rules_ = CreateCatchAllHostResolverProc();
151 if (use_caching) {
152 cache_.reset(new HostCache(kMaxCacheEntries));
156 int MockHostResolverBase::ResolveFromIPLiteralOrCache(const RequestInfo& info,
157 AddressList* addresses) {
158 IPAddressNumber ip;
159 if (ParseIPLiteralToNumber(info.hostname(), &ip)) {
160 // This matches the behavior HostResolverImpl.
161 if (info.address_family() != ADDRESS_FAMILY_UNSPECIFIED &&
162 info.address_family() != GetAddressFamily(ip)) {
163 return ERR_NAME_NOT_RESOLVED;
166 *addresses = AddressList::CreateFromIPAddress(ip, info.port());
167 if (info.host_resolver_flags() & HOST_RESOLVER_CANONNAME)
168 addresses->SetDefaultCanonicalName();
169 return OK;
171 int rv = ERR_DNS_CACHE_MISS;
172 if (cache_.get() && info.allow_cached_response()) {
173 HostCache::Key key(info.hostname(),
174 info.address_family(),
175 info.host_resolver_flags());
176 const HostCache::Entry* entry = cache_->Lookup(key, base::TimeTicks::Now());
177 if (entry) {
178 rv = entry->error;
179 if (rv == OK)
180 *addresses = AddressList::CopyWithPort(entry->addrlist, info.port());
183 return rv;
186 int MockHostResolverBase::ResolveProc(size_t id,
187 const RequestInfo& info,
188 AddressList* addresses) {
189 AddressList addr;
190 int rv = rules_->Resolve(info.hostname(),
191 info.address_family(),
192 info.host_resolver_flags(),
193 &addr,
194 NULL);
195 if (cache_.get()) {
196 HostCache::Key key(info.hostname(),
197 info.address_family(),
198 info.host_resolver_flags());
199 // Storing a failure with TTL 0 so that it overwrites previous value.
200 base::TimeDelta ttl;
201 if (rv == OK)
202 ttl = base::TimeDelta::FromSeconds(kCacheEntryTTLSeconds);
203 cache_->Set(key, HostCache::Entry(rv, addr), base::TimeTicks::Now(), ttl);
205 if (rv == OK)
206 *addresses = AddressList::CopyWithPort(addr, info.port());
207 return rv;
210 void MockHostResolverBase::ResolveNow(size_t id) {
211 RequestMap::iterator it = requests_.find(id);
212 if (it == requests_.end())
213 return; // was canceled
215 scoped_ptr<Request> req(it->second);
216 requests_.erase(it);
217 int rv = ResolveProc(id, req->info, req->addresses);
218 if (!req->callback.is_null())
219 req->callback.Run(rv);
222 //-----------------------------------------------------------------------------
224 struct RuleBasedHostResolverProc::Rule {
225 enum ResolverType {
226 kResolverTypeFail,
227 kResolverTypeSystem,
228 kResolverTypeIPLiteral,
231 ResolverType resolver_type;
232 std::string host_pattern;
233 AddressFamily address_family;
234 HostResolverFlags host_resolver_flags;
235 std::string replacement;
236 std::string canonical_name;
237 int latency_ms; // In milliseconds.
239 Rule(ResolverType resolver_type,
240 const std::string& host_pattern,
241 AddressFamily address_family,
242 HostResolverFlags host_resolver_flags,
243 const std::string& replacement,
244 const std::string& canonical_name,
245 int latency_ms)
246 : resolver_type(resolver_type),
247 host_pattern(host_pattern),
248 address_family(address_family),
249 host_resolver_flags(host_resolver_flags),
250 replacement(replacement),
251 canonical_name(canonical_name),
252 latency_ms(latency_ms) {}
255 RuleBasedHostResolverProc::RuleBasedHostResolverProc(HostResolverProc* previous)
256 : HostResolverProc(previous) {
259 void RuleBasedHostResolverProc::AddRule(const std::string& host_pattern,
260 const std::string& replacement) {
261 AddRuleForAddressFamily(host_pattern, ADDRESS_FAMILY_UNSPECIFIED,
262 replacement);
265 void RuleBasedHostResolverProc::AddRuleForAddressFamily(
266 const std::string& host_pattern,
267 AddressFamily address_family,
268 const std::string& replacement) {
269 DCHECK(!replacement.empty());
270 HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
271 HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
272 Rule rule(Rule::kResolverTypeSystem,
273 host_pattern,
274 address_family,
275 flags,
276 replacement,
277 std::string(),
279 AddRuleInternal(rule);
282 void RuleBasedHostResolverProc::AddIPLiteralRule(
283 const std::string& host_pattern,
284 const std::string& ip_literal,
285 const std::string& canonical_name) {
286 // Literals are always resolved to themselves by HostResolverImpl,
287 // consequently we do not support remapping them.
288 IPAddressNumber ip_number;
289 DCHECK(!ParseIPLiteralToNumber(host_pattern, &ip_number));
290 HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
291 HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
292 if (!canonical_name.empty())
293 flags |= HOST_RESOLVER_CANONNAME;
294 Rule rule(Rule::kResolverTypeIPLiteral, host_pattern,
295 ADDRESS_FAMILY_UNSPECIFIED, flags, ip_literal, canonical_name,
297 AddRuleInternal(rule);
300 void RuleBasedHostResolverProc::AddRuleWithLatency(
301 const std::string& host_pattern,
302 const std::string& replacement,
303 int latency_ms) {
304 DCHECK(!replacement.empty());
305 HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
306 HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
307 Rule rule(Rule::kResolverTypeSystem,
308 host_pattern,
309 ADDRESS_FAMILY_UNSPECIFIED,
310 flags,
311 replacement,
312 std::string(),
313 latency_ms);
314 AddRuleInternal(rule);
317 void RuleBasedHostResolverProc::AllowDirectLookup(
318 const std::string& host_pattern) {
319 HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
320 HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
321 Rule rule(Rule::kResolverTypeSystem,
322 host_pattern,
323 ADDRESS_FAMILY_UNSPECIFIED,
324 flags,
325 std::string(),
326 std::string(),
328 AddRuleInternal(rule);
331 void RuleBasedHostResolverProc::AddSimulatedFailure(
332 const std::string& host_pattern) {
333 HostResolverFlags flags = HOST_RESOLVER_LOOPBACK_ONLY |
334 HOST_RESOLVER_DEFAULT_FAMILY_SET_DUE_TO_NO_IPV6;
335 Rule rule(Rule::kResolverTypeFail,
336 host_pattern,
337 ADDRESS_FAMILY_UNSPECIFIED,
338 flags,
339 std::string(),
340 std::string(),
342 AddRuleInternal(rule);
345 void RuleBasedHostResolverProc::ClearRules() {
346 base::AutoLock lock(rule_lock_);
347 rules_.clear();
350 int RuleBasedHostResolverProc::Resolve(const std::string& host,
351 AddressFamily address_family,
352 HostResolverFlags host_resolver_flags,
353 AddressList* addrlist,
354 int* os_error) {
355 base::AutoLock lock(rule_lock_);
356 RuleList::iterator r;
357 for (r = rules_.begin(); r != rules_.end(); ++r) {
358 bool matches_address_family =
359 r->address_family == ADDRESS_FAMILY_UNSPECIFIED ||
360 r->address_family == address_family;
361 // Ignore HOST_RESOLVER_SYSTEM_ONLY, since it should have no impact on
362 // whether a rule matches.
363 HostResolverFlags flags = host_resolver_flags & ~HOST_RESOLVER_SYSTEM_ONLY;
364 // Flags match if all of the bitflags in host_resolver_flags are enabled
365 // in the rule's host_resolver_flags. However, the rule may have additional
366 // flags specified, in which case the flags should still be considered a
367 // match.
368 bool matches_flags = (r->host_resolver_flags & flags) == flags;
369 if (matches_flags && matches_address_family &&
370 base::MatchPattern(host, r->host_pattern)) {
371 if (r->latency_ms != 0) {
372 base::PlatformThread::Sleep(
373 base::TimeDelta::FromMilliseconds(r->latency_ms));
376 // Remap to a new host.
377 const std::string& effective_host =
378 r->replacement.empty() ? host : r->replacement;
380 // Apply the resolving function to the remapped hostname.
381 switch (r->resolver_type) {
382 case Rule::kResolverTypeFail:
383 return ERR_NAME_NOT_RESOLVED;
384 case Rule::kResolverTypeSystem:
385 #if defined(OS_WIN)
386 EnsureWinsockInit();
387 #endif
388 return SystemHostResolverCall(effective_host,
389 address_family,
390 host_resolver_flags,
391 addrlist, os_error);
392 case Rule::kResolverTypeIPLiteral:
393 return ParseAddressList(effective_host,
394 r->canonical_name,
395 addrlist);
396 default:
397 NOTREACHED();
398 return ERR_UNEXPECTED;
402 return ResolveUsingPrevious(host, address_family,
403 host_resolver_flags, addrlist, os_error);
406 RuleBasedHostResolverProc::~RuleBasedHostResolverProc() {
409 void RuleBasedHostResolverProc::AddRuleInternal(const Rule& rule) {
410 base::AutoLock lock(rule_lock_);
411 rules_.push_back(rule);
414 RuleBasedHostResolverProc* CreateCatchAllHostResolverProc() {
415 RuleBasedHostResolverProc* catchall = new RuleBasedHostResolverProc(NULL);
416 catchall->AddIPLiteralRule("*", "127.0.0.1", "localhost");
418 // Next add a rules-based layer the use controls.
419 return new RuleBasedHostResolverProc(catchall);
422 //-----------------------------------------------------------------------------
424 int HangingHostResolver::Resolve(const RequestInfo& info,
425 RequestPriority priority,
426 AddressList* addresses,
427 const CompletionCallback& callback,
428 RequestHandle* out_req,
429 const BoundNetLog& net_log) {
430 return ERR_IO_PENDING;
433 int HangingHostResolver::ResolveFromCache(const RequestInfo& info,
434 AddressList* addresses,
435 const BoundNetLog& net_log) {
436 return ERR_DNS_CACHE_MISS;
439 //-----------------------------------------------------------------------------
441 ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc() {}
443 ScopedDefaultHostResolverProc::ScopedDefaultHostResolverProc(
444 HostResolverProc* proc) {
445 Init(proc);
448 ScopedDefaultHostResolverProc::~ScopedDefaultHostResolverProc() {
449 HostResolverProc* old_proc =
450 HostResolverProc::SetDefault(previous_proc_.get());
451 // The lifetimes of multiple instances must be nested.
452 CHECK_EQ(old_proc, current_proc_.get());
455 void ScopedDefaultHostResolverProc::Init(HostResolverProc* proc) {
456 current_proc_ = proc;
457 previous_proc_ = HostResolverProc::SetDefault(current_proc_.get());
458 current_proc_->SetLastProc(previous_proc_.get());
461 } // namespace net