1 // Copyright (c) 2012 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/dns/dns_socket_pool.h"
7 #include "base/logging.h"
8 #include "base/rand_util.h"
9 #include "base/stl_util.h"
10 #include "net/base/address_list.h"
11 #include "net/base/ip_endpoint.h"
12 #include "net/base/net_errors.h"
13 #include "net/base/rand_callback.h"
14 #include "net/socket/client_socket_factory.h"
15 #include "net/socket/stream_socket.h"
16 #include "net/udp/datagram_client_socket.h"
22 // When we initialize the SocketPool, we allocate kInitialPoolSize sockets.
23 // When we allocate a socket, we ensure we have at least kAllocateMinSize
24 // sockets to choose from. Freed sockets are not retained.
26 // On Windows, we can't request specific (random) ports, since that will
27 // trigger firewall prompts, so request default ones, but keep a pile of
28 // them. Everywhere else, request fresh, random ports each time.
30 const DatagramSocket::BindType kBindType
= DatagramSocket::DEFAULT_BIND
;
31 const unsigned kInitialPoolSize
= 256;
32 const unsigned kAllocateMinSize
= 256;
34 const DatagramSocket::BindType kBindType
= DatagramSocket::RANDOM_BIND
;
35 const unsigned kInitialPoolSize
= 0;
36 const unsigned kAllocateMinSize
= 1;
41 DnsSocketPool::DnsSocketPool(ClientSocketFactory
* socket_factory
)
42 : socket_factory_(socket_factory
),
48 void DnsSocketPool::InitializeInternal(
49 const std::vector
<IPEndPoint
>* nameservers
,
52 DCHECK(!initialized_
);
55 nameservers_
= nameservers
;
59 scoped_ptr
<StreamSocket
> DnsSocketPool::CreateTCPSocket(
60 unsigned server_index
,
61 const NetLog::Source
& source
) {
62 DCHECK_LT(server_index
, nameservers_
->size());
64 return scoped_ptr
<StreamSocket
>(
65 socket_factory_
->CreateTransportClientSocket(
66 AddressList((*nameservers_
)[server_index
]), net_log_
, source
));
69 scoped_ptr
<DatagramClientSocket
> DnsSocketPool::CreateConnectedSocket(
70 unsigned server_index
) {
71 DCHECK_LT(server_index
, nameservers_
->size());
73 scoped_ptr
<DatagramClientSocket
> socket
;
75 NetLog::Source no_source
;
76 socket
= socket_factory_
->CreateDatagramClientSocket(
77 kBindType
, base::Bind(&base::RandInt
), net_log_
, no_source
);
80 int rv
= socket
->Connect((*nameservers_
)[server_index
]);
82 VLOG(1) << "Failed to connect socket: " << rv
;
86 LOG(WARNING
) << "Failed to create socket.";
92 class NullDnsSocketPool
: public DnsSocketPool
{
94 NullDnsSocketPool(ClientSocketFactory
* factory
)
95 : DnsSocketPool(factory
) {
98 virtual void Initialize(
99 const std::vector
<IPEndPoint
>* nameservers
,
100 NetLog
* net_log
) OVERRIDE
{
101 InitializeInternal(nameservers
, net_log
);
104 virtual scoped_ptr
<DatagramClientSocket
> AllocateSocket(
105 unsigned server_index
) OVERRIDE
{
106 return CreateConnectedSocket(server_index
);
109 virtual void FreeSocket(
110 unsigned server_index
,
111 scoped_ptr
<DatagramClientSocket
> socket
) OVERRIDE
{
115 DISALLOW_COPY_AND_ASSIGN(NullDnsSocketPool
);
119 scoped_ptr
<DnsSocketPool
> DnsSocketPool::CreateNull(
120 ClientSocketFactory
* factory
) {
121 return scoped_ptr
<DnsSocketPool
>(new NullDnsSocketPool(factory
));
124 class DefaultDnsSocketPool
: public DnsSocketPool
{
126 DefaultDnsSocketPool(ClientSocketFactory
* factory
)
127 : DnsSocketPool(factory
) {
130 virtual ~DefaultDnsSocketPool();
132 virtual void Initialize(
133 const std::vector
<IPEndPoint
>* nameservers
,
134 NetLog
* net_log
) OVERRIDE
;
136 virtual scoped_ptr
<DatagramClientSocket
> AllocateSocket(
137 unsigned server_index
) OVERRIDE
;
139 virtual void FreeSocket(
140 unsigned server_index
,
141 scoped_ptr
<DatagramClientSocket
> socket
) OVERRIDE
;
144 void FillPool(unsigned server_index
, unsigned size
);
146 typedef std::vector
<DatagramClientSocket
*> SocketVector
;
148 std::vector
<SocketVector
> pools_
;
150 DISALLOW_COPY_AND_ASSIGN(DefaultDnsSocketPool
);
154 scoped_ptr
<DnsSocketPool
> DnsSocketPool::CreateDefault(
155 ClientSocketFactory
* factory
) {
156 return scoped_ptr
<DnsSocketPool
>(new DefaultDnsSocketPool(factory
));
159 void DefaultDnsSocketPool::Initialize(
160 const std::vector
<IPEndPoint
>* nameservers
,
162 InitializeInternal(nameservers
, net_log
);
164 DCHECK(pools_
.empty());
165 const unsigned num_servers
= nameservers
->size();
166 pools_
.resize(num_servers
);
167 for (unsigned server_index
= 0; server_index
< num_servers
; ++server_index
)
168 FillPool(server_index
, kInitialPoolSize
);
171 DefaultDnsSocketPool::~DefaultDnsSocketPool() {
172 unsigned num_servers
= pools_
.size();
173 for (unsigned server_index
= 0; server_index
< num_servers
; ++server_index
) {
174 SocketVector
& pool
= pools_
[server_index
];
175 STLDeleteElements(&pool
);
179 scoped_ptr
<DatagramClientSocket
> DefaultDnsSocketPool::AllocateSocket(
180 unsigned server_index
) {
181 DCHECK_LT(server_index
, pools_
.size());
182 SocketVector
& pool
= pools_
[server_index
];
184 FillPool(server_index
, kAllocateMinSize
);
185 if (pool
.size() == 0) {
186 LOG(WARNING
) << "No DNS sockets available in pool " << server_index
<< "!";
187 return scoped_ptr
<DatagramClientSocket
>();
190 if (pool
.size() < kAllocateMinSize
) {
191 LOG(WARNING
) << "Low DNS port entropy: wanted " << kAllocateMinSize
192 << " sockets to choose from, but only have " << pool
.size()
193 << " in pool " << server_index
<< ".";
196 unsigned socket_index
= base::RandInt(0, pool
.size() - 1);
197 DatagramClientSocket
* socket
= pool
[socket_index
];
198 pool
[socket_index
] = pool
.back();
201 return scoped_ptr
<DatagramClientSocket
>(socket
);
204 void DefaultDnsSocketPool::FreeSocket(
205 unsigned server_index
,
206 scoped_ptr
<DatagramClientSocket
> socket
) {
207 DCHECK_LT(server_index
, pools_
.size());
210 void DefaultDnsSocketPool::FillPool(unsigned server_index
, unsigned size
) {
211 SocketVector
& pool
= pools_
[server_index
];
213 for (unsigned pool_index
= pool
.size(); pool_index
< size
; ++pool_index
) {
214 DatagramClientSocket
* socket
=
215 CreateConnectedSocket(server_index
).release();
218 pool
.push_back(socket
);