Roll src/third_party/WebKit bf18a82:a9cee16 (svn 185297:185304)
[chromium-blink-merge.git] / chrome / test / chromedriver / server / chromedriver_server.cc
blobd7f09d222e0e18cb1a540a1113a1d1d16ade36f6
1 // Copyright (c) 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include <stdio.h>
6 #include <locale>
7 #include <string>
8 #include <vector>
10 #include "base/at_exit.h"
11 #include "base/bind.h"
12 #include "base/callback.h"
13 #include "base/command_line.h"
14 #include "base/files/file_path.h"
15 #include "base/lazy_instance.h"
16 #include "base/logging.h"
17 #include "base/memory/scoped_ptr.h"
18 #include "base/message_loop/message_loop.h"
19 #include "base/run_loop.h"
20 #include "base/strings/string_number_conversions.h"
21 #include "base/strings/string_split.h"
22 #include "base/strings/string_util.h"
23 #include "base/strings/stringprintf.h"
24 #include "base/synchronization/waitable_event.h"
25 #include "base/threading/thread.h"
26 #include "base/threading/thread_local.h"
27 #include "chrome/test/chromedriver/logging.h"
28 #include "chrome/test/chromedriver/net/port_server.h"
29 #include "chrome/test/chromedriver/server/http_handler.h"
30 #include "chrome/test/chromedriver/version.h"
31 #include "net/base/ip_endpoint.h"
32 #include "net/base/net_errors.h"
33 #include "net/server/http_server.h"
34 #include "net/server/http_server_request_info.h"
35 #include "net/server/http_server_response_info.h"
36 #include "net/socket/tcp_server_socket.h"
38 namespace {
40 const char kLocalHostAddress[] = "127.0.0.1";
41 const int kBufferSize = 100 * 1024 * 1024; // 100 MB
43 typedef base::Callback<
44 void(const net::HttpServerRequestInfo&, const HttpResponseSenderFunc&)>
45 HttpRequestHandlerFunc;
47 class HttpServer : public net::HttpServer::Delegate {
48 public:
49 explicit HttpServer(const HttpRequestHandlerFunc& handle_request_func)
50 : handle_request_func_(handle_request_func),
51 weak_factory_(this) {}
53 virtual ~HttpServer() {}
55 bool Start(int port, bool allow_remote) {
56 std::string binding_ip = kLocalHostAddress;
57 if (allow_remote)
58 binding_ip = "0.0.0.0";
59 scoped_ptr<net::ServerSocket> server_socket(
60 new net::TCPServerSocket(NULL, net::NetLog::Source()));
61 server_socket->ListenWithAddressAndPort(binding_ip, port, 1);
62 server_.reset(new net::HttpServer(server_socket.Pass(), this));
63 net::IPEndPoint address;
64 return server_->GetLocalAddress(&address) == net::OK;
67 // Overridden from net::HttpServer::Delegate:
68 void OnConnect(int connection_id) override {
69 server_->SetSendBufferSize(connection_id, kBufferSize);
70 server_->SetReceiveBufferSize(connection_id, kBufferSize);
72 void OnHttpRequest(int connection_id,
73 const net::HttpServerRequestInfo& info) override {
74 handle_request_func_.Run(
75 info,
76 base::Bind(&HttpServer::OnResponse,
77 weak_factory_.GetWeakPtr(),
78 connection_id));
80 void OnWebSocketRequest(int connection_id,
81 const net::HttpServerRequestInfo& info) override {}
82 void OnWebSocketMessage(int connection_id, const std::string& data) override {
84 void OnClose(int connection_id) override {}
86 private:
87 void OnResponse(int connection_id,
88 scoped_ptr<net::HttpServerResponseInfo> response) {
89 // Don't support keep-alive, since there's no way to detect if the
90 // client is HTTP/1.0. In such cases, the client may hang waiting for
91 // the connection to close (e.g., python 2.7 urllib).
92 response->AddHeader("Connection", "close");
93 server_->SendResponse(connection_id, *response);
94 // Don't need to call server_->Close(), since SendResponse() will handle
95 // this for us.
98 HttpRequestHandlerFunc handle_request_func_;
99 scoped_ptr<net::HttpServer> server_;
100 base::WeakPtrFactory<HttpServer> weak_factory_; // Should be last.
103 void SendResponseOnCmdThread(
104 const scoped_refptr<base::SingleThreadTaskRunner>& io_task_runner,
105 const HttpResponseSenderFunc& send_response_on_io_func,
106 scoped_ptr<net::HttpServerResponseInfo> response) {
107 io_task_runner->PostTask(
108 FROM_HERE, base::Bind(send_response_on_io_func, base::Passed(&response)));
111 void HandleRequestOnCmdThread(
112 HttpHandler* handler,
113 const std::vector<std::string>& whitelisted_ips,
114 const net::HttpServerRequestInfo& request,
115 const HttpResponseSenderFunc& send_response_func) {
116 if (!whitelisted_ips.empty()) {
117 std::string peer_address = request.peer.ToStringWithoutPort();
118 if (peer_address != kLocalHostAddress &&
119 std::find(whitelisted_ips.begin(), whitelisted_ips.end(),
120 peer_address) == whitelisted_ips.end()) {
121 LOG(WARNING) << "unauthorized access from " << request.peer.ToString();
122 scoped_ptr<net::HttpServerResponseInfo> response(
123 new net::HttpServerResponseInfo(net::HTTP_UNAUTHORIZED));
124 response->SetBody("Unauthorized access", "text/plain");
125 send_response_func.Run(response.Pass());
126 return;
130 handler->Handle(request, send_response_func);
133 void HandleRequestOnIOThread(
134 const scoped_refptr<base::SingleThreadTaskRunner>& cmd_task_runner,
135 const HttpRequestHandlerFunc& handle_request_on_cmd_func,
136 const net::HttpServerRequestInfo& request,
137 const HttpResponseSenderFunc& send_response_func) {
138 cmd_task_runner->PostTask(
139 FROM_HERE,
140 base::Bind(handle_request_on_cmd_func,
141 request,
142 base::Bind(&SendResponseOnCmdThread,
143 base::MessageLoopProxy::current(),
144 send_response_func)));
147 base::LazyInstance<base::ThreadLocalPointer<HttpServer> >
148 lazy_tls_server = LAZY_INSTANCE_INITIALIZER;
150 void StopServerOnIOThread() {
151 // Note, |server| may be NULL.
152 HttpServer* server = lazy_tls_server.Pointer()->Get();
153 lazy_tls_server.Pointer()->Set(NULL);
154 delete server;
157 void StartServerOnIOThread(int port,
158 bool allow_remote,
159 const HttpRequestHandlerFunc& handle_request_func) {
160 scoped_ptr<HttpServer> temp_server(new HttpServer(handle_request_func));
161 if (!temp_server->Start(port, allow_remote)) {
162 printf("Port not available. Exiting...\n");
163 exit(1);
165 lazy_tls_server.Pointer()->Set(temp_server.release());
168 void RunServer(int port,
169 bool allow_remote,
170 const std::vector<std::string>& whitelisted_ips,
171 const std::string& url_base,
172 int adb_port,
173 scoped_ptr<PortServer> port_server) {
174 base::Thread io_thread("ChromeDriver IO");
175 CHECK(io_thread.StartWithOptions(
176 base::Thread::Options(base::MessageLoop::TYPE_IO, 0)));
178 base::MessageLoop cmd_loop;
179 base::RunLoop cmd_run_loop;
180 HttpHandler handler(cmd_run_loop.QuitClosure(),
181 io_thread.message_loop_proxy(),
182 url_base,
183 adb_port,
184 port_server.Pass());
185 HttpRequestHandlerFunc handle_request_func =
186 base::Bind(&HandleRequestOnCmdThread, &handler, whitelisted_ips);
188 io_thread.message_loop()
189 ->PostTask(FROM_HERE,
190 base::Bind(&StartServerOnIOThread,
191 port,
192 allow_remote,
193 base::Bind(&HandleRequestOnIOThread,
194 cmd_loop.message_loop_proxy(),
195 handle_request_func)));
196 // Run the command loop. This loop is quit after the response for a shutdown
197 // request is posted to the IO loop. After the command loop quits, a task
198 // is posted to the IO loop to stop the server. Lastly, the IO thread is
199 // destroyed, which waits until all pending tasks have been completed.
200 // This assumes the response is sent synchronously as part of the IO task.
201 cmd_run_loop.Run();
202 io_thread.message_loop()
203 ->PostTask(FROM_HERE, base::Bind(&StopServerOnIOThread));
206 } // namespace
208 int main(int argc, char *argv[]) {
209 CommandLine::Init(argc, argv);
211 base::AtExitManager at_exit;
212 CommandLine* cmd_line = CommandLine::ForCurrentProcess();
214 #if defined(OS_LINUX)
215 // Select the locale from the environment by passing an empty string instead
216 // of the default "C" locale. This is particularly needed for the keycode
217 // conversion code to work.
218 setlocale(LC_ALL, "");
219 #endif
221 // Parse command line flags.
222 int port = 9515;
223 int adb_port = 5037;
224 bool allow_remote = false;
225 std::vector<std::string> whitelisted_ips;
226 std::string url_base;
227 scoped_ptr<PortServer> port_server;
228 if (cmd_line->HasSwitch("h") || cmd_line->HasSwitch("help")) {
229 std::string options;
230 const char* const kOptionAndDescriptions[] = {
231 "port=PORT", "port to listen on",
232 "adb-port=PORT", "adb server port",
233 "log-path=FILE", "write server log to file instead of stderr, "
234 "increases log level to INFO",
235 "verbose", "log verbosely",
236 "version", "print the version number and exit",
237 "silent", "log nothing",
238 "url-base", "base URL path prefix for commands, e.g. wd/url",
239 "port-server", "address of server to contact for reserving a port",
240 "whitelisted-ips", "comma-separated whitelist of remote IPv4 addresses "
241 "which are allowed to connect to ChromeDriver",
243 for (size_t i = 0; i < arraysize(kOptionAndDescriptions) - 1; i += 2) {
244 options += base::StringPrintf(
245 " --%-30s%s\n",
246 kOptionAndDescriptions[i], kOptionAndDescriptions[i + 1]);
248 printf("Usage: %s [OPTIONS]\n\nOptions\n%s", argv[0], options.c_str());
249 return 0;
251 if (cmd_line->HasSwitch("v") || cmd_line->HasSwitch("version")) {
252 printf("ChromeDriver %s\n", kChromeDriverVersion);
253 return 0;
255 if (cmd_line->HasSwitch("port")) {
256 if (!base::StringToInt(cmd_line->GetSwitchValueASCII("port"), &port)) {
257 printf("Invalid port. Exiting...\n");
258 return 1;
261 if (cmd_line->HasSwitch("adb-port")) {
262 if (!base::StringToInt(cmd_line->GetSwitchValueASCII("adb-port"),
263 &adb_port)) {
264 printf("Invalid adb-port. Exiting...\n");
265 return 1;
268 if (cmd_line->HasSwitch("port-server")) {
269 #if defined(OS_LINUX)
270 std::string address = cmd_line->GetSwitchValueASCII("port-server");
271 if (address.empty() || address[0] != '@') {
272 printf("Invalid port-server. Exiting...\n");
273 return 1;
275 std::string path;
276 // First character of path is \0 to use Linux's abstract namespace.
277 path.push_back(0);
278 path += address.substr(1);
279 port_server.reset(new PortServer(path));
280 #else
281 printf("Warning: port-server not implemented for this platform.\n");
282 #endif
284 if (cmd_line->HasSwitch("url-base"))
285 url_base = cmd_line->GetSwitchValueASCII("url-base");
286 if (url_base.empty() || url_base[0] != '/')
287 url_base = "/" + url_base;
288 if (url_base[url_base.length() - 1] != '/')
289 url_base = url_base + "/";
290 if (cmd_line->HasSwitch("whitelisted-ips")) {
291 allow_remote = true;
292 std::string whitelist = cmd_line->GetSwitchValueASCII("whitelisted-ips");
293 base::SplitString(whitelist, ',', &whitelisted_ips);
295 if (!cmd_line->HasSwitch("silent")) {
296 printf("Starting ChromeDriver %s on port %d\n", kChromeDriverVersion, port);
297 if (!allow_remote) {
298 printf("Only local connections are allowed.\n");
299 } else if (!whitelisted_ips.empty()) {
300 printf("Remote connections are allowed by a whitelist (%s).\n",
301 cmd_line->GetSwitchValueASCII("whitelisted-ips").c_str());
302 } else {
303 printf("All remote connections are allowed. Use a whitelist instead!\n");
305 fflush(stdout);
308 if (!InitLogging()) {
309 printf("Unable to initialize logging. Exiting...\n");
310 return 1;
312 RunServer(port, allow_remote, whitelisted_ips,
313 url_base, adb_port, port_server.Pass());
314 return 0;