1 // Copyright (c) 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
10 #include "base/at_exit.h"
11 #include "base/bind.h"
12 #include "base/callback.h"
13 #include "base/command_line.h"
14 #include "base/files/file_path.h"
15 #include "base/lazy_instance.h"
16 #include "base/logging.h"
17 #include "base/memory/scoped_ptr.h"
18 #include "base/message_loop/message_loop.h"
19 #include "base/run_loop.h"
20 #include "base/strings/string_number_conversions.h"
21 #include "base/strings/string_split.h"
22 #include "base/strings/string_util.h"
23 #include "base/strings/stringprintf.h"
24 #include "base/synchronization/waitable_event.h"
25 #include "base/threading/thread.h"
26 #include "base/threading/thread_local.h"
27 #include "chrome/test/chromedriver/logging.h"
28 #include "chrome/test/chromedriver/net/port_server.h"
29 #include "chrome/test/chromedriver/server/http_handler.h"
30 #include "chrome/test/chromedriver/version.h"
31 #include "net/base/ip_endpoint.h"
32 #include "net/base/net_errors.h"
33 #include "net/server/http_server.h"
34 #include "net/server/http_server_request_info.h"
35 #include "net/server/http_server_response_info.h"
36 #include "net/socket/tcp_server_socket.h"
// Loopback address. Used as the initial bind address in HttpServer::Start()
// and as the peer address that HandleRequestOnCmdThread() accepts without
// consulting the whitelist.
40 const char kLocalHostAddress
[] = "127.0.0.1";
// Value passed to SetSendBufferSize()/SetReceiveBufferSize() for every new
// connection in HttpServer::OnConnect().
41 const int kBufferSize
= 100 * 1024 * 1024; // 100 MB
// Callback type for handling a single HTTP request. It receives the request
// info and an HttpResponseSenderFunc; the handlers in this file run that
// sender with the response they produce.
43 typedef base::Callback
<
44 void(const net::HttpServerRequestInfo
&, const HttpResponseSenderFunc
&)>
45 HttpRequestHandlerFunc
;
// net::HttpServer::Delegate implementation that hands every incoming HTTP
// request to |handle_request_func_| and writes the resulting response back on
// the originating connection. WebSocket and close notifications are ignored.
// NOTE(review): the leading numbers are the original file's line numbers and
// several are absent in this view (e.g. 57, 65-66, 75, 78, 95-97), so access
// specifiers, closing braces and a few statements are not visible here.
47 class HttpServer
: public net::HttpServer::Delegate
{
// Stores a copy of |handle_request_func| and binds the weak pointer factory
// to this object.
49 explicit HttpServer(const HttpRequestHandlerFunc
& handle_request_func
)
50 : handle_request_func_(handle_request_func
),
51 weak_factory_(this) {}
53 virtual ~HttpServer() {}
// Creates a TCP server socket, asks it to listen on |binding_ip| and |port|,
// and wraps it in a new net::HttpServer that reports to this delegate.
// Returns true iff GetLocalAddress() on the new server yields net::OK.
// The return value of ListenWithAddressAndPort() is not inspected, so a
// failed listen is only detected indirectly through GetLocalAddress().
55 bool Start(uint16 port
, bool allow_remote
) {
56 std::string binding_ip
= kLocalHostAddress
;
// NOTE(review): as visible here the assignment below is unconditional, which
// would bind to all interfaces regardless of |allow_remote|. Original line 57
// is absent from this view and presumably guards it on |allow_remote| --
// confirm against the real file.
58 binding_ip
= "0.0.0.0";
59 scoped_ptr
<net::ServerSocket
> server_socket(
60 new net::TCPServerSocket(NULL
, net::NetLog::Source()));
61 server_socket
->ListenWithAddressAndPort(binding_ip
, port
, 1);
// Ownership of the socket moves into the net::HttpServer.
62 server_
.reset(new net::HttpServer(server_socket
.Pass(), this));
63 net::IPEndPoint address
;
64 return server_
->GetLocalAddress(&address
) == net::OK
;
67 // Overridden from net::HttpServer::Delegate:
// Applies kBufferSize to both the send and receive buffers of the new
// connection.
68 void OnConnect(int connection_id
) override
{
69 server_
->SetSendBufferSize(connection_id
, kBufferSize
);
70 server_
->SetReceiveBufferSize(connection_id
, kBufferSize
);
// Runs |handle_request_func_| with a response callback bound to OnResponse()
// through a weak pointer to this object.
// NOTE(review): the remaining arguments (original lines 75 and 78) are not
// visible; presumably |info| and |connection_id| -- confirm.
72 void OnHttpRequest(int connection_id
,
73 const net::HttpServerRequestInfo
& info
) override
{
74 handle_request_func_
.Run(
76 base::Bind(&HttpServer::OnResponse
,
77 weak_factory_
.GetWeakPtr(),
// WebSocket requests are ignored (empty body).
80 void OnWebSocketRequest(int connection_id
,
81 const net::HttpServerRequestInfo
& info
) override
{}
// No statements are visible in this handler.
82 void OnWebSocketMessage(int connection_id
, const std::string
& data
) override
{
// Connection close notifications are ignored (empty body).
84 void OnClose(int connection_id
) override
{}
// Adds a "Connection: close" header to |response| and sends it on
// |connection_id|.
87 void OnResponse(int connection_id
,
88 scoped_ptr
<net::HttpServerResponseInfo
> response
) {
89 // Don't support keep-alive, since there's no way to detect if the
90 // client is HTTP/1.0. In such cases, the client may hang waiting for
91 // the connection to close (e.g., python 2.7 urllib).
92 response
->AddHeader("Connection", "close");
93 server_
->SendResponse(connection_id
, *response
);
94 // Don't need to call server_->Close(), since SendResponse() will handle
// Request handler supplied at construction; run once per HTTP request.
98 HttpRequestHandlerFunc handle_request_func_
;
// Underlying HTTP server; created in Start().
99 scoped_ptr
<net::HttpServer
> server_
;
// Source of the weak pointers bound into response callbacks.
100 base::WeakPtrFactory
<HttpServer
> weak_factory_
; // Should be last.
// Posts a task to |io_task_runner| that runs |send_response_on_io_func| with
// |response|; ownership of |response| is passed into the bound task.
// NOTE(review): the name implies this is called on the command thread and
// hops to the IO thread; nothing in this function checks that -- confirm
// against callers.
103 void SendResponseOnCmdThread(
104 const scoped_refptr
<base::SingleThreadTaskRunner
>& io_task_runner
,
105 const HttpResponseSenderFunc
& send_response_on_io_func
,
106 scoped_ptr
<net::HttpServerResponseInfo
> response
) {
107 io_task_runner
->PostTask(
108 FROM_HERE
, base::Bind(send_response_on_io_func
, base::Passed(&response
)));
// Applies the IP whitelist to |request| and forwards it to |handler|.
//
// When |whitelisted_ips| is non-empty, the peer address (without port) must
// either equal kLocalHostAddress or appear in |whitelisted_ips|. Otherwise a
// warning is logged and an HTTP_UNAUTHORIZED response with the plain-text
// body "Unauthorized access" is sent through |send_response_func|.
// An empty whitelist disables the check entirely.
// NOTE(review): original lines 126-129 are absent from this view; presumably
// they return early after the unauthorized response so that the request is
// not also passed to |handler| -- confirm.
111 void HandleRequestOnCmdThread(
112 HttpHandler
* handler
,
113 const std::vector
<std::string
>& whitelisted_ips
,
114 const net::HttpServerRequestInfo
& request
,
115 const HttpResponseSenderFunc
& send_response_func
) {
116 if (!whitelisted_ips
.empty()) {
117 std::string peer_address
= request
.peer
.ToStringWithoutPort();
// Loopback peers are always accepted; other peers need an exact string match
// in the whitelist.
118 if (peer_address
!= kLocalHostAddress
&&
119 std::find(whitelisted_ips
.begin(), whitelisted_ips
.end(),
120 peer_address
) == whitelisted_ips
.end()) {
121 LOG(WARNING
) << "unauthorized access from " << request
.peer
.ToString();
122 scoped_ptr
<net::HttpServerResponseInfo
> response(
123 new net::HttpServerResponseInfo(net::HTTP_UNAUTHORIZED
));
124 response
->SetBody("Unauthorized access", "text/plain");
125 send_response_func
.Run(response
.Pass());
// Hand the request to the HTTP handler.
130 handler
->Handle(request
, send_response_func
);
// Posts |handle_request_on_cmd_func| to |cmd_task_runner|. The response
// sender bound into it is SendResponseOnCmdThread, given the calling thread's
// MessageLoopProxy and |send_response_func|, so the response is posted back
// to the thread this function was called on.
// NOTE(review): original lines 139 and 141 are absent from this view;
// presumably they supply the FROM_HERE location and the |request| argument --
// confirm.
133 void HandleRequestOnIOThread(
134 const scoped_refptr
<base::SingleThreadTaskRunner
>& cmd_task_runner
,
135 const HttpRequestHandlerFunc
& handle_request_on_cmd_func
,
136 const net::HttpServerRequestInfo
& request
,
137 const HttpResponseSenderFunc
& send_response_func
) {
138 cmd_task_runner
->PostTask(
140 base::Bind(handle_request_on_cmd_func
,
142 base::Bind(&SendResponseOnCmdThread
,
143 base::MessageLoopProxy::current(),
144 send_response_func
)));
// Lazily created thread-local slot holding the HttpServer pointer. It is set
// by StartServerOnIOThread() and reset to NULL by StopServerOnIOThread().
147 base::LazyInstance
<base::ThreadLocalPointer
<HttpServer
> >
148 lazy_tls_server
= LAZY_INSTANCE_INITIALIZER
;
// Reads the HttpServer pointer from the thread-local slot and then resets the
// slot to NULL.
// NOTE(review): what is done with |server| afterwards (original line 154) is
// not visible in this view; presumably it is deleted, since
// StartServerOnIOThread() releases ownership into the slot -- confirm.
150 void StopServerOnIOThread() {
151 // Note, |server| may be NULL.
152 HttpServer
* server
= lazy_tls_server
.Pointer()->Get();
153 lazy_tls_server
.Pointer()->Set(NULL
);
// Creates an HttpServer around |handle_request_func| and starts it on |port|.
// If Start() fails, prints "Port not available. Exiting...". Otherwise the
// server pointer is released from the scoped_ptr into the thread-local slot
// |lazy_tls_server|, which then holds the only reference to it.
// NOTE(review): the declaration of |allow_remote| (original line 158) and the
// statements following the printf (original lines 163-164) are absent from
// this view; presumably the process exits on failure -- confirm.
157 void StartServerOnIOThread(uint16 port
,
159 const HttpRequestHandlerFunc
& handle_request_func
) {
160 scoped_ptr
<HttpServer
> temp_server(new HttpServer(handle_request_func
));
161 if (!temp_server
->Start(port
, allow_remote
)) {
162 printf("Port not available. Exiting...\n");
// Transfer ownership to the thread-local slot.
165 lazy_tls_server
.Pointer()->Set(temp_server
.release());
// Runs the ChromeDriver HTTP server until shutdown.
//
// Starts a dedicated "ChromeDriver IO" thread with a TYPE_IO message loop,
// creates the command message loop and run loop on the calling thread,
// constructs the HttpHandler, and posts StartServerOnIOThread() to the IO
// loop. Requests received on the IO thread go through
// HandleRequestOnIOThread(), which forwards them to HandleRequestOnCmdThread()
// on the command loop. Finally StopServerOnIOThread() is posted to the IO
// loop.
// NOTE(review): several parameters and arguments (original lines 169, 172,
// 182-184, 191-192) and the statement on original line 201 are absent from
// this view; presumably line 201 runs |cmd_run_loop| -- confirm.
168 void RunServer(uint16 port
,
170 const std::vector
<std::string
>& whitelisted_ips
,
171 const std::string
& url_base
,
173 scoped_ptr
<PortServer
> port_server
) {
// The process cannot continue without the IO thread, hence CHECK.
174 base::Thread
io_thread("ChromeDriver IO");
175 CHECK(io_thread
.StartWithOptions(
176 base::Thread::Options(base::MessageLoop::TYPE_IO
, 0)));
178 base::MessageLoop cmd_loop
;
179 base::RunLoop cmd_run_loop
;
// The handler receives the run loop's quit closure and the IO thread's
// message loop proxy.
180 HttpHandler
handler(cmd_run_loop
.QuitClosure(),
181 io_thread
.message_loop_proxy(),
// |handler| is bound by raw pointer; it is a local of this function.
185 HttpRequestHandlerFunc handle_request_func
=
186 base::Bind(&HandleRequestOnCmdThread
, &handler
, whitelisted_ips
);
// Start the HTTP server on the IO thread, routing requests back to the
// command loop.
188 io_thread
.message_loop()
189 ->PostTask(FROM_HERE
,
190 base::Bind(&StartServerOnIOThread
,
193 base::Bind(&HandleRequestOnIOThread
,
194 cmd_loop
.message_loop_proxy(),
195 handle_request_func
)));
196 // Run the command loop. This loop is quit after the response for a shutdown
197 // request is posted to the IO loop. After the command loop quits, a task
198 // is posted to the IO loop to stop the server. Lastly, the IO thread is
199 // destroyed, which waits until all pending tasks have been completed.
200 // This assumes the response is sent synchronously as part of the IO task.
202 io_thread
.message_loop()
203 ->PostTask(FROM_HERE
, base::Bind(&StopServerOnIOThread
));
// Program entry point: parses command-line switches, validates them,
// normalizes the URL base, prints a startup banner unless "silent" is given,
// initializes logging and calls RunServer().
// NOTE(review): many original lines are absent from this view (declarations
// of |port|, |adb_port|, |options|, |cmd_line_port| and |path|, the exits
// after each error message, and the end of the function), so only the visible
// statements are described below.
208 int main(int argc
, char *argv
[]) {
209 base::CommandLine::Init(argc
, argv
);
211 base::AtExitManager at_exit
;
212 base::CommandLine
* cmd_line
= base::CommandLine::ForCurrentProcess();
214 #if defined(OS_LINUX)
215 // Select the locale from the environment by passing an empty string instead
216 // of the default "C" locale. This is particularly needed for the keycode
217 // conversion code to work.
218 setlocale(LC_ALL
, "");
221 // Parse command line flags.
224 bool allow_remote
= false;
225 std::vector
<std::string
> whitelisted_ips
;
226 std::string url_base
;
227 scoped_ptr
<PortServer
> port_server
;
// "h"/"help": build and print the usage text.
228 if (cmd_line
->HasSwitch("h") || cmd_line
->HasSwitch("help")) {
// Flat array of (option, description) pairs; adjacent string literals are
// concatenated into a single description.
230 const char* const kOptionAndDescriptions
[] = {
231 "port=PORT", "port to listen on",
232 "adb-port=PORT", "adb server port",
233 "log-path=FILE", "write server log to file instead of stderr, "
234 "increases log level to INFO",
235 "verbose", "log verbosely",
236 "version", "print the version number and exit",
237 "silent", "log nothing",
238 "url-base", "base URL path prefix for commands, e.g. wd/url",
239 "port-server", "address of server to contact for reserving a port",
240 "whitelisted-ips", "comma-separated whitelist of remote IPv4 addresses "
241 "which are allowed to connect to ChromeDriver",
// Step by two so that [i] is the option and [i + 1] its description.
243 for (size_t i
= 0; i
< arraysize(kOptionAndDescriptions
) - 1; i
+= 2) {
244 options
+= base::StringPrintf(
246 kOptionAndDescriptions
[i
], kOptionAndDescriptions
[i
+ 1]);
248 printf("Usage: %s [OPTIONS]\n\nOptions\n%s", argv
[0], options
.c_str());
// "v"/"version": print the version string.
251 if (cmd_line
->HasSwitch("v") || cmd_line
->HasSwitch("version")) {
252 printf("ChromeDriver %s\n", kChromeDriverVersion
);
// "port": must parse as an integer in [0, 65535]; otherwise an error is
// printed.
255 if (cmd_line
->HasSwitch("port")) {
257 if (!base::StringToInt(cmd_line
->GetSwitchValueASCII("port"),
259 cmd_line_port
< 0 || cmd_line_port
> 65535) {
260 printf("Invalid port. Exiting...\n");
263 port
= static_cast<uint16
>(cmd_line_port
);
// "adb-port": must parse as an integer; otherwise an error is printed.
265 if (cmd_line
->HasSwitch("adb-port")) {
266 if (!base::StringToInt(cmd_line
->GetSwitchValueASCII("adb-port"),
268 printf("Invalid adb-port. Exiting...\n");
// "port-server": on Linux the value must be non-empty and start with '@';
// the text after '@' is appended to |path|, which is used to construct the
// PortServer. A warning for other platforms is also present below.
272 if (cmd_line
->HasSwitch("port-server")) {
273 #if defined(OS_LINUX)
274 std::string address
= cmd_line
->GetSwitchValueASCII("port-server");
275 if (address
.empty() || address
[0] != '@') {
276 printf("Invalid port-server. Exiting...\n");
280 // First character of path is \0 to use Linux's abstract namespace.
282 path
+= address
.substr(1);
283 port_server
.reset(new PortServer(path
));
285 printf("Warning: port-server not implemented for this platform.\n");
// "url-base": normalize so the value always begins and ends with '/'.
288 if (cmd_line
->HasSwitch("url-base"))
289 url_base
= cmd_line
->GetSwitchValueASCII("url-base");
290 if (url_base
.empty() || url_base
[0] != '/')
291 url_base
= "/" + url_base
;
// |url_base| is non-empty at this point because of the prefixing above, so
// indexing length() - 1 is in range.
292 if (url_base
[url_base
.length() - 1] != '/')
293 url_base
= url_base
+ "/";
// "whitelisted-ips": comma-separated list split into |whitelisted_ips|.
// NOTE(review): original line 295 is absent from this view; presumably it
// sets |allow_remote| -- confirm.
294 if (cmd_line
->HasSwitch("whitelisted-ips")) {
296 std::string whitelist
= cmd_line
->GetSwitchValueASCII("whitelisted-ips");
297 base::SplitString(whitelist
, ',', &whitelisted_ips
);
// Startup banner, suppressed by "silent". It reports the version and port
// followed by one of three connection-policy messages.
// NOTE(review): the condition selecting the first message (original line
// 301) is not visible here.
299 if (!cmd_line
->HasSwitch("silent")) {
300 printf("Starting ChromeDriver %s on port %u\n", kChromeDriverVersion
, port
);
302 printf("Only local connections are allowed.\n");
303 } else if (!whitelisted_ips
.empty()) {
304 printf("Remote connections are allowed by a whitelist (%s).\n",
305 cmd_line
->GetSwitchValueASCII("whitelisted-ips").c_str());
307 printf("All remote connections are allowed. Use a whitelist instead!\n");
// Logging must be initialized before the server runs.
312 if (!InitLogging()) {
313 printf("Unable to initialize logging. Exiting...\n");
// Ownership of |port_server| is passed to RunServer().
316 RunServer(port
, allow_remote
, whitelisted_ips
,
317 url_base
, adb_port
, port_server
.Pass());