Roll src/third_party/WebKit c63b89c:29324ab (svn 202546:202547)
[chromium-blink-merge.git] / device / test / usb_test_gadget_impl.cc
blob df85756207276754aa685a6b461f1459d21b4a31
1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
#include "device/test/usb_test_gadget.h"

#include <string>
#include <vector>

#include "base/bind.h"
#include "base/command_line.h"
#include "base/compiler_specific.h"
#include "base/files/file.h"
#include "base/files/file_path.h"
#include "base/logging.h"
#include "base/macros.h"
#include "base/memory/ref_counted.h"
#include "base/memory/scoped_ptr.h"
#include "base/path_service.h"
#include "base/process/process_handle.h"
#include "base/run_loop.h"
#include "base/scoped_observer.h"
#include "base/strings/stringprintf.h"
#include "base/strings/utf_string_conversions.h"
#include "base/thread_task_runner_handle.h"
#include "base/time/time.h"
#include "device/core/device_client.h"
#include "device/usb/usb_device.h"
#include "device/usb/usb_device_handle.h"
#include "device/usb/usb_service.h"
#include "net/base/escape.h"
#include "net/proxy/proxy_service.h"
#include "net/url_request/url_fetcher.h"
#include "net/url_request/url_fetcher_delegate.h"
#include "net/url_request/url_request_context.h"
#include "net/url_request/url_request_context_builder.h"
#include "net/url_request/url_request_context_getter.h"
#include "url/gurl.h"
38 namespace device {
40 class UsbTestGadgetImpl : public UsbTestGadget {
41 public:
42 UsbTestGadgetImpl(
43 scoped_refptr<net::URLRequestContextGetter> request_context_getter,
44 UsbService* usb_service,
45 scoped_refptr<UsbDevice> device);
46 ~UsbTestGadgetImpl() override;
48 bool Unclaim() override;
49 bool Disconnect() override;
50 bool Reconnect() override;
51 bool SetType(Type type) override;
52 UsbDevice* GetDevice() const override;
54 private:
55 std::string device_address_;
56 scoped_refptr<UsbDevice> device_;
57 scoped_refptr<net::URLRequestContextGetter> request_context_getter_;
58 UsbService* usb_service_;
60 DISALLOW_COPY_AND_ASSIGN(UsbTestGadgetImpl);
63 namespace {
65 static const char kCommandLineSwitch[] = "enable-gadget-tests";
66 static const int kReenumeratePeriod = 100; // 0.1 seconds
68 struct UsbTestGadgetConfiguration {
69 UsbTestGadget::Type type;
70 const char* http_resource;
71 uint16 product_id;
74 static const struct UsbTestGadgetConfiguration kConfigurations[] = {
75 {UsbTestGadget::DEFAULT, "/unconfigure", 0x58F0},
76 {UsbTestGadget::KEYBOARD, "/keyboard/configure", 0x58F1},
77 {UsbTestGadget::MOUSE, "/mouse/configure", 0x58F2},
78 {UsbTestGadget::HID_ECHO, "/hid_echo/configure", 0x58F3},
79 {UsbTestGadget::ECHO, "/echo/configure", 0x58F4},
82 bool ReadFile(const base::FilePath& file_path, std::string* content) {
83 base::File file(file_path, base::File::FLAG_OPEN | base::File::FLAG_READ);
84 if (!file.IsValid()) {
85 LOG(ERROR) << "Cannot open " << file_path.MaybeAsASCII() << ": "
86 << base::File::ErrorToString(file.error_details());
87 return false;
90 STLClearObject(content);
91 int rv;
92 do {
93 char buf[4096];
94 rv = file.ReadAtCurrentPos(buf, sizeof buf);
95 if (rv == -1) {
96 LOG(ERROR) << "Cannot read " << file_path.MaybeAsASCII() << ": "
97 << base::File::ErrorToString(file.error_details());
98 return false;
100 content->append(buf, rv);
101 } while (rv > 0);
103 return true;
106 bool ReadLocalVersion(std::string* version) {
107 base::FilePath file_path;
108 CHECK(PathService::Get(base::DIR_EXE, &file_path));
109 file_path = file_path.AppendASCII("usb_gadget.zip.md5");
111 return ReadFile(file_path, version);
114 bool ReadLocalPackage(std::string* package) {
115 base::FilePath file_path;
116 CHECK(PathService::Get(base::DIR_EXE, &file_path));
117 file_path = file_path.AppendASCII("usb_gadget.zip");
119 return ReadFile(file_path, package);
122 scoped_ptr<net::URLFetcher> CreateURLFetcher(
123 scoped_refptr<net::URLRequestContextGetter> request_context_getter,
124 const GURL& url,
125 net::URLFetcher::RequestType request_type,
126 net::URLFetcherDelegate* delegate) {
127 scoped_ptr<net::URLFetcher> url_fetcher =
128 net::URLFetcher::Create(url, request_type, delegate);
130 url_fetcher->SetRequestContext(request_context_getter.get());
132 return url_fetcher;
135 class URLRequestContextGetter : public net::URLRequestContextGetter {
136 public:
137 URLRequestContextGetter(
138 scoped_refptr<base::SingleThreadTaskRunner> network_task_runner)
139 : network_task_runner_(network_task_runner) {}
141 private:
142 ~URLRequestContextGetter() override {}
144 // net::URLRequestContextGetter implementation
145 net::URLRequestContext* GetURLRequestContext() override {
146 if (!context_) {
147 net::URLRequestContextBuilder context_builder;
148 context_builder.set_proxy_service(net::ProxyService::CreateDirect());
149 context_ = context_builder.Build().Pass();
151 return context_.get();
154 scoped_refptr<base::SingleThreadTaskRunner> GetNetworkTaskRunner()
155 const override {
156 return network_task_runner_;
159 scoped_ptr<net::URLRequestContext> context_;
160 scoped_refptr<base::SingleThreadTaskRunner> network_task_runner_;
163 class URLFetcherDelegate : public net::URLFetcherDelegate {
164 public:
165 URLFetcherDelegate() {}
166 ~URLFetcherDelegate() override {}
168 void WaitForCompletion() { run_loop_.Run(); }
170 void OnURLFetchComplete(const net::URLFetcher* source) override {
171 run_loop_.Quit();
174 private:
175 base::RunLoop run_loop_;
177 DISALLOW_COPY_AND_ASSIGN(URLFetcherDelegate);
180 int SimplePOSTRequest(
181 scoped_refptr<net::URLRequestContextGetter> request_context_getter,
182 const GURL& url,
183 const std::string& form_data) {
184 URLFetcherDelegate delegate;
185 scoped_ptr<net::URLFetcher> url_fetcher = CreateURLFetcher(
186 request_context_getter, url, net::URLFetcher::POST, &delegate);
188 url_fetcher->SetUploadData("application/x-www-form-urlencoded", form_data);
189 url_fetcher->Start();
190 delegate.WaitForCompletion();
192 return url_fetcher->GetResponseCode();
195 class UsbGadgetFactory : public UsbService::Observer,
196 public net::URLFetcherDelegate {
197 public:
198 UsbGadgetFactory(scoped_refptr<base::SingleThreadTaskRunner> io_task_runner)
199 : observer_(this), weak_factory_(this) {
200 usb_service_ = DeviceClient::Get()->GetUsbService();
201 request_context_getter_ = new URLRequestContextGetter(io_task_runner);
203 static uint32 next_session_id;
204 base::ProcessId process_id = base::GetCurrentProcId();
205 session_id_ = base::StringPrintf("%d-%d", process_id, next_session_id++);
207 observer_.Add(usb_service_);
210 ~UsbGadgetFactory() override {}
212 scoped_ptr<UsbTestGadget> WaitForDevice() {
213 EnumerateDevices();
214 run_loop_.Run();
215 return make_scoped_ptr(
216 new UsbTestGadgetImpl(request_context_getter_, usb_service_, device_));
219 private:
220 void EnumerateDevices() {
221 if (!device_) {
222 usb_service_->GetDevices(base::Bind(
223 &UsbGadgetFactory::OnDevicesEnumerated, weak_factory_.GetWeakPtr()));
227 void OnDevicesEnumerated(
228 const std::vector<scoped_refptr<UsbDevice>>& devices) {
229 for (const scoped_refptr<UsbDevice>& device : devices) {
230 OnDeviceAdded(device);
233 if (!device_) {
234 // TODO(reillyg): This timer could be replaced by a way to use long-
235 // polling to wait for claimed devices to become unclaimed.
236 base::MessageLoop::current()->PostDelayedTask(
237 FROM_HERE, base::Bind(&UsbGadgetFactory::EnumerateDevices,
238 weak_factory_.GetWeakPtr()),
239 base::TimeDelta::FromMilliseconds(kReenumeratePeriod));
243 void OnDeviceAdded(scoped_refptr<UsbDevice> device) override {
244 if (device_.get()) {
245 // Already trying to claim a device.
246 return;
249 if (device->vendor_id() != 0x18D1 || device->product_id() != 0x58F0 ||
250 device->serial_number().empty()) {
251 return;
254 std::string serial_number = base::UTF16ToUTF8(device->serial_number());
255 if (serial_number == serial_number_) {
256 // We were waiting for the device to reappear after upgrade.
257 device_ = device;
258 run_loop_.Quit();
259 return;
262 device_ = device;
263 serial_number_ = serial_number;
264 Claim();
267 void Claim() {
268 VLOG(1) << "Trying to claim " << serial_number_ << ".";
270 GURL url("http://" + serial_number_ + "/claim");
271 std::string form_data = base::StringPrintf(
272 "session_id=%s", net::EscapeUrlEncodedData(session_id_, true).c_str());
273 url_fetcher_ = CreateURLFetcher(request_context_getter_, url,
274 net::URLFetcher::POST, this);
275 url_fetcher_->SetUploadData("application/x-www-form-urlencoded", form_data);
276 url_fetcher_->Start();
279 void GetVersion() {
280 GURL url("http://" + serial_number_ + "/version");
281 url_fetcher_ = CreateURLFetcher(request_context_getter_, url,
282 net::URLFetcher::GET, this);
283 url_fetcher_->Start();
286 bool Update(const std::string& version) {
287 LOG(INFO) << "Updating " << serial_number_ << " to " << version << "...";
289 GURL url("http://" + serial_number_ + "/update");
290 url_fetcher_ = CreateURLFetcher(request_context_getter_, url,
291 net::URLFetcher::POST, this);
292 std::string mime_header = base::StringPrintf(
293 "--foo\r\n"
294 "Content-Disposition: form-data; name=\"file\"; "
295 "filename=\"usb_gadget-%s.zip\"\r\n"
296 "Content-Type: application/octet-stream\r\n"
297 "\r\n",
298 version.c_str());
299 std::string mime_footer("\r\n--foo--\r\n");
301 std::string package;
302 if (!ReadLocalPackage(&package)) {
303 return false;
306 url_fetcher_->SetUploadData("multipart/form-data; boundary=foo",
307 mime_header + package + mime_footer);
308 url_fetcher_->Start();
309 device_ = nullptr;
310 return true;
313 void OnURLFetchComplete(const net::URLFetcher* source) override {
314 DCHECK(!serial_number_.empty());
316 int response_code = source->GetResponseCode();
317 if (!claimed_) {
318 // Just completed a /claim request.
319 if (response_code == 200) {
320 claimed_ = true;
321 GetVersion();
322 } else {
323 if (response_code != 403) {
324 LOG(WARNING) << "Unexpected HTTP " << response_code
325 << " from /claim.";
327 Reset();
329 } else if (version_.empty()) {
330 // Just completed a /version request.
331 if (response_code != 200) {
332 LOG(WARNING) << "Unexpected HTTP " << response_code
333 << " from /version.";
334 Reset();
335 return;
338 if (!source->GetResponseAsString(&version_)) {
339 LOG(WARNING) << "Failed to read body from /version.";
340 Reset();
341 return;
344 std::string local_version;
345 if (!ReadLocalVersion(&local_version)) {
346 Reset();
347 return;
350 if (version_ == local_version) {
351 run_loop_.Quit();
352 } else {
353 if (!Update(local_version)) {
354 Reset();
357 } else {
358 // Just completed an /update request.
359 if (response_code != 200) {
360 LOG(WARNING) << "Unexpected HTTP " << response_code << " from /update.";
361 Reset();
362 return;
365 // Must wait for the device to reconnect.
369 void Reset() {
370 device_ = nullptr;
371 serial_number_.clear();
372 claimed_ = false;
373 version_.clear();
375 // Wait a bit and then try again to find an available device.
376 base::MessageLoop::current()->PostDelayedTask(
377 FROM_HERE, base::Bind(&UsbGadgetFactory::EnumerateDevices,
378 weak_factory_.GetWeakPtr()),
379 base::TimeDelta::FromMilliseconds(kReenumeratePeriod));
382 UsbService* usb_service_ = nullptr;
383 scoped_refptr<net::URLRequestContextGetter> request_context_getter_;
384 std::string session_id_;
385 scoped_ptr<net::URLFetcher> url_fetcher_;
386 scoped_refptr<UsbDevice> device_;
387 std::string serial_number_;
388 bool claimed_ = false;
389 std::string version_;
390 base::RunLoop run_loop_;
391 ScopedObserver<UsbService, UsbService::Observer> observer_;
392 base::WeakPtrFactory<UsbGadgetFactory> weak_factory_;
395 class DeviceAddListener : public UsbService::Observer {
396 public:
397 DeviceAddListener(UsbService* usb_service,
398 const std::string& serial_number,
399 int product_id)
400 : usb_service_(usb_service),
401 serial_number_(serial_number),
402 product_id_(product_id),
403 observer_(this),
404 weak_factory_(this) {
405 observer_.Add(usb_service_);
407 ~DeviceAddListener() override {}
409 scoped_refptr<UsbDevice> WaitForAdd() {
410 usb_service_->GetDevices(base::Bind(&DeviceAddListener::OnDevicesEnumerated,
411 weak_factory_.GetWeakPtr()));
412 run_loop_.Run();
413 return device_;
416 private:
417 void OnDevicesEnumerated(
418 const std::vector<scoped_refptr<UsbDevice>>& devices) {
419 for (const scoped_refptr<UsbDevice>& device : devices) {
420 OnDeviceAdded(device);
424 void OnDeviceAdded(scoped_refptr<UsbDevice> device) override {
425 if (device->vendor_id() == 0x18D1 && !device->serial_number().empty()) {
426 const uint16 product_id = device->product_id();
427 if (product_id_ == -1) {
428 bool found = false;
429 for (size_t i = 0; i < arraysize(kConfigurations); ++i) {
430 if (product_id == kConfigurations[i].product_id) {
431 found = true;
432 break;
435 if (!found) {
436 return;
438 } else {
439 if (product_id_ != product_id) {
440 return;
444 if (serial_number_ != base::UTF16ToUTF8(device->serial_number())) {
445 return;
448 device_ = device;
449 run_loop_.Quit();
453 UsbService* usb_service_;
454 const std::string serial_number_;
455 const int product_id_;
456 base::RunLoop run_loop_;
457 scoped_refptr<UsbDevice> device_;
458 ScopedObserver<UsbService, UsbService::Observer> observer_;
459 base::WeakPtrFactory<DeviceAddListener> weak_factory_;
461 DISALLOW_COPY_AND_ASSIGN(DeviceAddListener);
464 class DeviceRemoveListener : public UsbService::Observer {
465 public:
466 DeviceRemoveListener(UsbService* usb_service, scoped_refptr<UsbDevice> device)
467 : usb_service_(usb_service),
468 device_(device),
469 observer_(this),
470 weak_factory_(this) {
471 observer_.Add(usb_service_);
473 ~DeviceRemoveListener() override {}
475 void WaitForRemove() {
476 usb_service_->GetDevices(
477 base::Bind(&DeviceRemoveListener::OnDevicesEnumerated,
478 weak_factory_.GetWeakPtr()));
479 run_loop_.Run();
482 private:
483 void OnDevicesEnumerated(
484 const std::vector<scoped_refptr<UsbDevice>>& devices) {
485 bool found = false;
486 for (const scoped_refptr<UsbDevice>& device : devices) {
487 if (device_ == device) {
488 found = true;
491 if (!found) {
492 run_loop_.Quit();
496 void OnDeviceRemoved(scoped_refptr<UsbDevice> device) override {
497 if (device_ == device) {
498 run_loop_.Quit();
502 UsbService* usb_service_;
503 base::RunLoop run_loop_;
504 scoped_refptr<UsbDevice> device_;
505 ScopedObserver<UsbService, UsbService::Observer> observer_;
506 base::WeakPtrFactory<DeviceRemoveListener> weak_factory_;
508 DISALLOW_COPY_AND_ASSIGN(DeviceRemoveListener);
511 } // namespace
513 bool UsbTestGadget::IsTestEnabled() {
514 base::CommandLine* command_line = base::CommandLine::ForCurrentProcess();
515 return command_line->HasSwitch(kCommandLineSwitch);
518 scoped_ptr<UsbTestGadget> UsbTestGadget::Claim(
519 scoped_refptr<base::SingleThreadTaskRunner> io_task_runner) {
520 UsbGadgetFactory gadget_factory(io_task_runner);
521 return gadget_factory.WaitForDevice().Pass();
524 UsbTestGadgetImpl::UsbTestGadgetImpl(
525 scoped_refptr<net::URLRequestContextGetter> request_context_getter_,
526 UsbService* usb_service,
527 scoped_refptr<UsbDevice> device)
528 : device_address_(base::UTF16ToUTF8(device->serial_number())),
529 device_(device),
530 request_context_getter_(request_context_getter_),
531 usb_service_(usb_service) {
534 UsbTestGadgetImpl::~UsbTestGadgetImpl() {
535 if (!device_address_.empty()) {
536 Unclaim();
540 UsbDevice* UsbTestGadgetImpl::GetDevice() const {
541 return device_.get();
544 bool UsbTestGadgetImpl::Unclaim() {
545 VLOG(1) << "Releasing the device at " << device_address_ << ".";
547 GURL url("http://" + device_address_ + "/unclaim");
548 int response_code = SimplePOSTRequest(request_context_getter_, url, "");
550 if (response_code != 200) {
551 LOG(ERROR) << "Unexpected HTTP " << response_code << " from /unclaim.";
552 return false;
555 device_address_.clear();
556 return true;
559 bool UsbTestGadgetImpl::SetType(Type type) {
560 const struct UsbTestGadgetConfiguration* config = NULL;
561 for (size_t i = 0; i < arraysize(kConfigurations); ++i) {
562 if (kConfigurations[i].type == type) {
563 config = &kConfigurations[i];
566 CHECK(config);
568 GURL url("http://" + device_address_ + config->http_resource);
569 int response_code = SimplePOSTRequest(request_context_getter_, url, "");
571 if (response_code != 200) {
572 LOG(ERROR) << "Unexpected HTTP " << response_code
573 << " from " << config->http_resource << ".";
574 return false;
577 // Release the old reference to the device and try to open a new one.
578 DeviceAddListener add_listener(usb_service_, device_address_,
579 config->product_id);
580 device_ = add_listener.WaitForAdd();
581 DCHECK(device_.get());
582 return true;
585 bool UsbTestGadgetImpl::Disconnect() {
586 GURL url("http://" + device_address_ + "/disconnect");
587 int response_code = SimplePOSTRequest(request_context_getter_, url, "");
589 if (response_code != 200) {
590 LOG(ERROR) << "Unexpected HTTP " << response_code << " from " << url << ".";
591 return false;
594 // Release the old reference to the device and wait until it can't be found.
595 DeviceRemoveListener remove_listener(usb_service_, device_);
596 remove_listener.WaitForRemove();
597 device_ = nullptr;
598 return true;
601 bool UsbTestGadgetImpl::Reconnect() {
602 GURL url("http://" + device_address_ + "/reconnect");
603 int response_code = SimplePOSTRequest(request_context_getter_, url, "");
605 if (response_code != 200) {
606 LOG(ERROR) << "Unexpected HTTP " << response_code << " from " << url << ".";
607 return false;
610 DeviceAddListener add_listener(usb_service_, device_address_, -1);
611 device_ = add_listener.WaitForAdd();
612 DCHECK(device_.get());
613 return true;
616 } // namespace device