1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/socket/websocket_endpoint_lock_manager.h"
10 #include "base/location.h"
11 #include "base/logging.h"
12 #include "base/single_thread_task_runner.h"
13 #include "base/thread_task_runner_handle.h"
14 #include "net/base/net_errors.h"
15 #include "net/log/net_log.h"
// Default delay, in milliseconds, between a request to unlock an endpoint and
// the endpoint actually being released. This delay prevents DoS attacks.
// TODO(ricea): Replace this with randomised truncated exponential backoff.
// See crbug.com/377613.
const int kUnlockDelayInMs = 10;
// Destroys a Waiter.
// NOTE(review): the body and closing brace of this destructor are missing from
// this copy of the file (the next line already starts GetInstance()). Restore
// them from upstream before building. Presumably the destructor unlinks a
// waiter that is still queued in an endpoint's WaiterQueue (waiters are added
// with Append() and removed with RemoveFromList() elsewhere) — TODO confirm.
28 WebSocketEndpointLockManager::Waiter::~Waiter() {
35 WebSocketEndpointLockManager
* WebSocketEndpointLockManager::GetInstance() {
36 return base::Singleton
<WebSocketEndpointLockManager
>::get();
39 int WebSocketEndpointLockManager::LockEndpoint(const IPEndPoint
& endpoint
,
41 LockInfoMap::value_type
insert_value(endpoint
, LockInfo());
42 std::pair
<LockInfoMap::iterator
, bool> rv
=
43 lock_info_map_
.insert(insert_value
);
44 LockInfo
& lock_info_in_map
= rv
.first
->second
;
46 DVLOG(3) << "Locking endpoint " << endpoint
.ToString();
47 lock_info_in_map
.queue
.reset(new LockInfo::WaiterQueue
);
50 DVLOG(3) << "Waiting for endpoint " << endpoint
.ToString();
51 lock_info_in_map
.queue
->Append(waiter
);
52 return ERR_IO_PENDING
;
55 void WebSocketEndpointLockManager::RememberSocket(StreamSocket
* socket
,
56 const IPEndPoint
& endpoint
) {
57 LockInfoMap::iterator lock_info_it
= lock_info_map_
.find(endpoint
);
58 CHECK(lock_info_it
!= lock_info_map_
.end());
60 socket_lock_info_map_
.insert(SocketLockInfoMap::value_type(
61 socket
, lock_info_it
)).second
;
63 DCHECK(!lock_info_it
->second
.socket
);
64 lock_info_it
->second
.socket
= socket
;
65 DVLOG(3) << "Remembered (StreamSocket*)" << socket
<< " for "
66 << endpoint
.ToString() << " (" << socket_lock_info_map_
.size()
67 << " socket(s) remembered)";
70 void WebSocketEndpointLockManager::UnlockSocket(StreamSocket
* socket
) {
71 SocketLockInfoMap::iterator socket_it
= socket_lock_info_map_
.find(socket
);
72 if (socket_it
== socket_lock_info_map_
.end())
75 LockInfoMap::iterator lock_info_it
= socket_it
->second
;
77 DVLOG(3) << "Unlocking (StreamSocket*)" << socket
<< " for "
78 << lock_info_it
->first
.ToString() << " ("
79 << socket_lock_info_map_
.size() << " socket(s) left)";
80 socket_lock_info_map_
.erase(socket_it
);
81 DCHECK_EQ(socket
, lock_info_it
->second
.socket
);
82 lock_info_it
->second
.socket
= NULL
;
83 UnlockEndpointAfterDelay(lock_info_it
->first
);
86 void WebSocketEndpointLockManager::UnlockEndpoint(const IPEndPoint
& endpoint
) {
87 LockInfoMap::iterator lock_info_it
= lock_info_map_
.find(endpoint
);
88 if (lock_info_it
== lock_info_map_
.end())
90 if (lock_info_it
->second
.socket
)
91 EraseSocket(lock_info_it
);
92 UnlockEndpointAfterDelay(endpoint
);
95 bool WebSocketEndpointLockManager::IsEmpty() const {
96 return lock_info_map_
.empty() && socket_lock_info_map_
.empty();
99 base::TimeDelta
WebSocketEndpointLockManager::SetUnlockDelayForTesting(
100 base::TimeDelta new_delay
) {
101 base::TimeDelta old_delay
= unlock_delay_
;
102 unlock_delay_
= new_delay
;
106 WebSocketEndpointLockManager::LockInfo::LockInfo() : socket(NULL
) {}
// Destroys a LockInfo.
// NOTE(review): the body and closing brace of this destructor are missing from
// this copy of the file (the next line starts the copy constructor). Restore
// them from upstream before building.
107 WebSocketEndpointLockManager::LockInfo::~LockInfo() {
// Copy constructor. Only |socket| is initialized from |rhs| in the
// initializer list; |queue| is not copied there.
// NOTE(review): the body and closing brace are missing from this copy of the
// file (the next line starts the manager's constructor). Restore them from
// upstream before building.
111 WebSocketEndpointLockManager::LockInfo::LockInfo(const LockInfo
& rhs
)
112 : socket(rhs
.socket
) {
116 WebSocketEndpointLockManager::WebSocketEndpointLockManager()
117 : unlock_delay_(base::TimeDelta::FromMilliseconds(kUnlockDelayInMs
)),
118 pending_unlock_count_(0),
119 weak_factory_(this) {
122 WebSocketEndpointLockManager::~WebSocketEndpointLockManager() {
123 DCHECK_EQ(lock_info_map_
.size(), pending_unlock_count_
);
124 DCHECK(socket_lock_info_map_
.empty());
127 void WebSocketEndpointLockManager::UnlockEndpointAfterDelay(
128 const IPEndPoint
& endpoint
) {
129 DVLOG(3) << "Delaying " << unlock_delay_
.InMilliseconds()
130 << "ms before unlocking endpoint " << endpoint
.ToString();
131 ++pending_unlock_count_
;
132 base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
134 base::Bind(&WebSocketEndpointLockManager::DelayedUnlockEndpoint
,
135 weak_factory_
.GetWeakPtr(), endpoint
),
139 void WebSocketEndpointLockManager::DelayedUnlockEndpoint(
140 const IPEndPoint
& endpoint
) {
141 LockInfoMap::iterator lock_info_it
= lock_info_map_
.find(endpoint
);
142 DCHECK_GT(pending_unlock_count_
, 0U);
143 --pending_unlock_count_
;
144 if (lock_info_it
== lock_info_map_
.end())
146 DCHECK(!lock_info_it
->second
.socket
);
147 LockInfo::WaiterQueue
* queue
= lock_info_it
->second
.queue
.get();
149 if (queue
->empty()) {
150 DVLOG(3) << "Unlocking endpoint " << lock_info_it
->first
.ToString();
151 lock_info_map_
.erase(lock_info_it
);
155 DVLOG(3) << "Unlocking endpoint " << lock_info_it
->first
.ToString()
156 << " and activating next waiter";
157 Waiter
* next_job
= queue
->head()->value();
158 next_job
->RemoveFromList();
159 next_job
->GotEndpointLock();
162 void WebSocketEndpointLockManager::EraseSocket(
163 LockInfoMap::iterator lock_info_it
) {
164 DVLOG(3) << "Removing (StreamSocket*)" << lock_info_it
->second
.socket
165 << " for " << lock_info_it
->first
.ToString() << " ("
166 << socket_lock_info_map_
.size() << " socket(s) left)";
167 size_t erased
= socket_lock_info_map_
.erase(lock_info_it
->second
.socket
);
168 DCHECK_EQ(1U, erased
);
169 lock_info_it
->second
.socket
= NULL
;