1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "net/socket/websocket_endpoint_lock_manager.h"
10 #include "base/logging.h"
11 #include "base/message_loop/message_loop.h"
12 #include "net/base/net_errors.h"
13 #include "net/log/net_log.h"
19 // This delay prevents DoS attacks.
20 // TODO(ricea): Replace this with randomised truncated exponential backoff.
21 // See crbug.com/377613.
// Default delay, in milliseconds, between an endpoint being released
// (UnlockSocket()/UnlockEndpoint()) and DelayedUnlockEndpoint() actually
// unlocking it or handing it to the next waiter. It initialises
// unlock_delay_ in the constructor; tests can override it with
// SetUnlockDelayForTesting().
22 const int kUnlockDelayInMs
= 10;
// Destructor for a Waiter. Waiters are queued per endpoint by LockEndpoint()
// and dequeued with RemoveFromList() in DelayedUnlockEndpoint().
// NOTE(review): the body of this destructor (original lines 27-31) is missing
// from this excerpt, so its behavior cannot be documented from here.
26 WebSocketEndpointLockManager::Waiter::~Waiter() {
// Returns the shared WebSocketEndpointLockManager instance obtained from
// Singleton<WebSocketEndpointLockManager>::get().
33 WebSocketEndpointLockManager
* WebSocketEndpointLockManager::GetInstance() {
34 return Singleton
<WebSocketEndpointLockManager
>::get();
// Attempts to lock |endpoint| on behalf of |waiter|.
// A default LockInfo is inserted for |endpoint| if none exists. The visible
// code then either:
//  - creates a fresh WaiterQueue for the endpoint ("Locking endpoint"), or
//  - appends |waiter| to the existing queue and returns ERR_IO_PENDING
//    ("Waiting for endpoint"). The waiter is later activated by
//    DelayedUnlockEndpoint(), which calls GotEndpointLock() on it.
// NOTE(review): original lines 38, 43, 46-47 and 51 are missing from this
// excerpt. They include the declaration of |waiter| and whatever selects
// between the two paths (presumably `if (rv.second)`). Verify against the
// full file.
37 int WebSocketEndpointLockManager::LockEndpoint(const IPEndPoint
& endpoint
,
39 LockInfoMap::value_type
insert_value(endpoint
, LockInfo());
40 std::pair
<LockInfoMap::iterator
, bool> rv
=
41 lock_info_map_
.insert(insert_value
);
// rv.first refers to the entry for |endpoint| whether it was just inserted or
// already existed; rv.second reports whether the insert actually happened.
42 LockInfo
& lock_info_in_map
= rv
.first
->second
;
// NOTE(review): original line 43 is missing here (presumably the start of the
// "newly locked" branch).
44 DVLOG(3) << "Locking endpoint " << endpoint
.ToString();
45 lock_info_in_map
.queue
.reset(new LockInfo::WaiterQueue
);
// NOTE(review): original lines 46-47 are missing here (presumably the end of
// the "newly locked" branch). Below, the endpoint is already locked, so
// |waiter| is queued.
48 DVLOG(3) << "Waiting for endpoint " << endpoint
.ToString();
49 lock_info_in_map
.queue
->Append(waiter
);
50 return ERR_IO_PENDING
;
// Associates |socket| with the lock entry for |endpoint| so that the lock can
// later be released with UnlockSocket(|socket|).
// CHECK-fails if |endpoint| has no entry in lock_info_map_. It also DCHECKs
// that no socket has already been remembered for that endpoint.
// NOTE(review): original lines 57 and 60 are missing from this excerpt.
58 socket_lock_info_map_
.insert(SocketLockInfoMap::value_type(
59 socket
, lock_info_it
)).second
;
61 DCHECK(!lock_info_it
->second
.socket
);
62 lock_info_it
->second
.socket
= socket
;
63 DVLOG(3) << "Remembered (StreamSocket*)" << socket
<< " for "
64 << endpoint
.ToString() << " (" << socket_lock_info_map_
.size()
65 << " socket(s) remembered)";
// Releases the endpoint lock that RememberSocket() associated with |socket|.
// It forgets the socket -> endpoint mapping and clears the endpoint's
// remembered socket. It then schedules the endpoint to be unlocked after
// unlock_delay_ via UnlockEndpointAfterDelay().
68 void WebSocketEndpointLockManager::UnlockSocket(StreamSocket
* socket
) {
69 SocketLockInfoMap::iterator socket_it
= socket_lock_info_map_
.find(socket
);
70 if (socket_it
== socket_lock_info_map_
.end())
// NOTE(review): original line 71 is missing here. It is presumably `return;`,
// which would make this a no-op for sockets that were never remembered. As
// shown, the `if` above would instead guard the declaration below. Verify
// against the full file.
73 LockInfoMap::iterator lock_info_it
= socket_it
->second
;
75 DVLOG(3) << "Unlocking (StreamSocket*)" << socket
<< " for "
76 << lock_info_it
->first
.ToString() << " ("
77 << socket_lock_info_map_
.size() << " socket(s) left)";
78 socket_lock_info_map_
.erase(socket_it
);
79 DCHECK_EQ(socket
, lock_info_it
->second
.socket
);
80 lock_info_it
->second
.socket
= NULL
;
81 UnlockEndpointAfterDelay(lock_info_it
->first
);
// Releases the lock on |endpoint| directly, by endpoint rather than by socket.
// If a socket is remembered for the endpoint, its socket mapping is dropped
// via EraseSocket(). The endpoint is then scheduled to be unlocked after
// unlock_delay_.
84 void WebSocketEndpointLockManager::UnlockEndpoint(const IPEndPoint
& endpoint
) {
85 LockInfoMap::iterator lock_info_it
= lock_info_map_
.find(endpoint
);
86 if (lock_info_it
== lock_info_map_
.end())
// NOTE(review): original line 87 is missing here. It is presumably `return;`
// for an endpoint that is not locked. As shown, the two `if` statements would
// nest. Verify against the full file.
88 if (lock_info_it
->second
.socket
)
89 EraseSocket(lock_info_it
);
90 UnlockEndpointAfterDelay(endpoint
);
// Returns true iff both lock_info_map_ and socket_lock_info_map_ are empty.
// Endpoints awaiting a delayed unlock still have entries in lock_info_map_
// until DelayedUnlockEndpoint() erases them, so they count as non-empty here.
93 bool WebSocketEndpointLockManager::IsEmpty() const {
94 return lock_info_map_
.empty() && socket_lock_info_map_
.empty();
// Test-only: replaces unlock_delay_ (the delay used before unlocking an
// endpoint) with |new_delay| and saves the previous value in old_delay.
// NOTE(review): original line 101 is missing from this excerpt. It is
// presumably `return old_delay;`, given the base::TimeDelta return type.
97 base::TimeDelta
WebSocketEndpointLockManager::SetUnlockDelayForTesting(
98 base::TimeDelta new_delay
) {
99 base::TimeDelta old_delay
= unlock_delay_
;
100 unlock_delay_
= new_delay
;
// Constructs a LockInfo with no remembered socket. |queue| is not set here;
// LockEndpoint() allocates it when the endpoint is first locked.
104 WebSocketEndpointLockManager::LockInfo::LockInfo() : socket(NULL
) {}
// LockInfo destructor.
// NOTE(review): the body (original lines 106-107) is missing from this
// excerpt, so its behavior cannot be documented from here.
105 WebSocketEndpointLockManager::LockInfo::~LockInfo() {
// Copy constructor. Only |socket| is copied from |rhs|; |queue| is not.
// A copy is made when LockEndpoint() builds a LockInfoMap::value_type from a
// temporary LockInfo().
// NOTE(review): the rest of the body (original lines 111-112) is missing from
// this excerpt. Verify against the full file.
109 WebSocketEndpointLockManager::LockInfo::LockInfo(const LockInfo
& rhs
)
110 : socket(rhs
.socket
) {
// Initialises unlock_delay_ to kUnlockDelayInMs milliseconds, with no pending
// delayed unlocks. weak_factory_ is bound to this object; the delayed-unlock
// tasks posted by UnlockEndpointAfterDelay() hold WeakPtrs from it, presumably
// so that they are dropped if the manager is destroyed first.
114 WebSocketEndpointLockManager::WebSocketEndpointLockManager()
115 : unlock_delay_(base::TimeDelta::FromMilliseconds(kUnlockDelayInMs
)),
116 pending_unlock_count_(0),
117 weak_factory_(this) {
// DCHECKs two conditions at destruction:
//  - the number of lock_info_map_ entries still present equals
//    pending_unlock_count_, i.e. only endpoints whose delayed unlock has not
//    yet run may remain;
//  - no sockets are still remembered.
120 WebSocketEndpointLockManager::~WebSocketEndpointLockManager() {
121 DCHECK_EQ(lock_info_map_
.size(), pending_unlock_count_
);
122 DCHECK(socket_lock_info_map_
.empty());
// Increments pending_unlock_count_ and posts DelayedUnlockEndpoint(|endpoint|)
// as a delayed task on the current MessageLoop, bound to a WeakPtr to this
// manager. The log message reports unlock_delay_ as the delay.
// NOTE(review): original lines 131 and 134-135 are missing from this excerpt
// (presumably FROM_HERE, the delay argument and the closing brace).
125 void WebSocketEndpointLockManager::UnlockEndpointAfterDelay(
126 const IPEndPoint
& endpoint
) {
127 DVLOG(3) << "Delaying " << unlock_delay_
.InMilliseconds()
128 << "ms before unlocking endpoint " << endpoint
.ToString();
// Balanced by the decrement in DelayedUnlockEndpoint().
129 ++pending_unlock_count_
;
130 base::MessageLoop::current()->PostDelayedTask(
132 base::Bind(&WebSocketEndpointLockManager::DelayedUnlockEndpoint
,
133 weak_factory_
.GetWeakPtr(), endpoint
),
// Delayed task posted by UnlockEndpointAfterDelay(); it performs the actual
// unlock of |endpoint|. It first decrements pending_unlock_count_. Then:
//  - if no Waiter is queued, it erases the endpoint's entry from
//    lock_info_map_;
//  - otherwise it hands the lock to the first queued Waiter (removing it from
//    the queue and calling GotEndpointLock()) and leaves the entry in
//    lock_info_map_.
137 void WebSocketEndpointLockManager::DelayedUnlockEndpoint(
138 const IPEndPoint
& endpoint
) {
139 LockInfoMap::iterator lock_info_it
= lock_info_map_
.find(endpoint
);
140 DCHECK_GT(pending_unlock_count_
, 0U);
141 --pending_unlock_count_
;
142 if (lock_info_it
== lock_info_map_
.end())
// NOTE(review): original line 143 is missing here (presumably `return;` when
// the endpoint has no entry).
// Any remembered socket should already have been cleared: UnlockSocket() and
// EraseSocket() both reset it to NULL before scheduling/allowing the unlock.
144 DCHECK(!lock_info_it
->second
.socket
);
145 LockInfo::WaiterQueue
* queue
= lock_info_it
->second
.queue
.get();
147 if (queue
->empty()) {
148 DVLOG(3) << "Unlocking endpoint " << lock_info_it
->first
.ToString();
149 lock_info_map_
.erase(lock_info_it
);
// NOTE(review): original lines 150-152 are missing here (presumably `return;`
// and the closing brace of the `if`). Without that return, line 153 would
// read through the iterator erased on line 149. Verify against the full file.
153 DVLOG(3) << "Unlocking endpoint " << lock_info_it
->first
.ToString()
154 << " and activating next waiter";
// Hand the lock straight to the first waiter. The entry is not erased, so the
// endpoint stays locked on that waiter's behalf.
155 Waiter
* next_job
= queue
->head()->value();
156 next_job
->RemoveFromList();
157 next_job
->GotEndpointLock();
// Forgets the socket remembered for the endpoint at |lock_info_it|. It removes
// that socket from socket_lock_info_map_ and clears LockInfo::socket. It does
// not itself schedule an unlock.
// Callers should only call this when a socket is remembered (UnlockEndpoint()
// checks this first); otherwise DCHECK_EQ(1U, erased) fails.
160 void WebSocketEndpointLockManager::EraseSocket(
161 LockInfoMap::iterator lock_info_it
) {
162 DVLOG(3) << "Removing (StreamSocket*)" << lock_info_it
->second
.socket
163 << " for " << lock_info_it
->first
.ToString() << " ("
164 << socket_lock_info_map_
.size() << " socket(s) left)";
// Erase by key; exactly one mapping is expected to exist for this socket.
165 size_t erased
= socket_lock_info_map_
.erase(lock_info_it
->second
.socket
);
166 DCHECK_EQ(1U, erased
);
167 lock_info_it
->second
.socket
= NULL
;