Don't clobber the download metadata file when downloading a file for which we don...
[chromium-blink-merge.git] / device / test / usb_test_gadget_impl.cc
blob0d88d57661d916707be5bc362f856d7c8dbff602
1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "device/test/usb_test_gadget.h"
7 #include <string>
8 #include <vector>
10 #include "base/command_line.h"
11 #include "base/compiler_specific.h"
12 #include "base/files/file.h"
13 #include "base/files/file_path.h"
14 #include "base/logging.h"
15 #include "base/macros.h"
16 #include "base/memory/ref_counted.h"
17 #include "base/memory/scoped_ptr.h"
18 #include "base/path_service.h"
19 #include "base/process/process_handle.h"
20 #include "base/run_loop.h"
21 #include "base/scoped_observer.h"
22 #include "base/strings/stringprintf.h"
23 #include "base/strings/utf_string_conversions.h"
24 #include "base/thread_task_runner_handle.h"
25 #include "base/time/time.h"
26 #include "device/core/device_client.h"
27 #include "device/usb/usb_device.h"
28 #include "device/usb/usb_device_handle.h"
29 #include "device/usb/usb_service.h"
30 #include "net/proxy/proxy_service.h"
31 #include "net/url_request/url_fetcher.h"
32 #include "net/url_request/url_fetcher_delegate.h"
33 #include "net/url_request/url_request_context.h"
34 #include "net/url_request/url_request_context_builder.h"
35 #include "net/url_request/url_request_context_getter.h"
36 #include "url/gurl.h"
38 namespace device {
40 class UsbTestGadgetImpl : public UsbTestGadget {
41 public:
42 UsbTestGadgetImpl(
43 scoped_refptr<net::URLRequestContextGetter> request_context_getter,
44 UsbService* usb_service,
45 scoped_refptr<UsbDevice> device);
46 ~UsbTestGadgetImpl() override;
48 bool Unclaim() override;
49 bool Disconnect() override;
50 bool Reconnect() override;
51 bool SetType(Type type) override;
52 UsbDevice* GetDevice() const override;
54 private:
55 std::string device_address_;
56 scoped_refptr<UsbDevice> device_;
57 scoped_refptr<net::URLRequestContextGetter> request_context_getter_;
58 UsbService* usb_service_;
60 DISALLOW_COPY_AND_ASSIGN(UsbTestGadgetImpl);
63 namespace {
65 static const char kCommandLineSwitch[] = "enable-gadget-tests";
66 static const int kReenumeratePeriod = 100; // 0.1 seconds
68 struct UsbTestGadgetConfiguration {
69 UsbTestGadget::Type type;
70 const char* http_resource;
71 uint16 product_id;
74 static const struct UsbTestGadgetConfiguration kConfigurations[] = {
75 {UsbTestGadget::DEFAULT, "/unconfigure", 0x58F0},
76 {UsbTestGadget::KEYBOARD, "/keyboard/configure", 0x58F1},
77 {UsbTestGadget::MOUSE, "/mouse/configure", 0x58F2},
78 {UsbTestGadget::HID_ECHO, "/hid_echo/configure", 0x58F3},
79 {UsbTestGadget::ECHO, "/echo/configure", 0x58F4},
82 bool ReadFile(const base::FilePath& file_path, std::string* content) {
83 base::File file(file_path, base::File::FLAG_OPEN | base::File::FLAG_READ);
84 if (!file.IsValid()) {
85 LOG(ERROR) << "Cannot open " << file_path.MaybeAsASCII() << ": "
86 << base::File::ErrorToString(file.error_details());
87 return false;
90 STLClearObject(content);
91 int rv;
92 do {
93 char buf[4096];
94 rv = file.ReadAtCurrentPos(buf, sizeof buf);
95 if (rv == -1) {
96 LOG(ERROR) << "Cannot read " << file_path.MaybeAsASCII() << ": "
97 << base::File::ErrorToString(file.error_details());
98 return false;
100 content->append(buf, rv);
101 } while (rv > 0);
103 return true;
106 bool ReadLocalVersion(std::string* version) {
107 base::FilePath file_path;
108 CHECK(PathService::Get(base::DIR_EXE, &file_path));
109 file_path = file_path.AppendASCII("usb_gadget.zip.md5");
111 return ReadFile(file_path, version);
114 bool ReadLocalPackage(std::string* package) {
115 base::FilePath file_path;
116 CHECK(PathService::Get(base::DIR_EXE, &file_path));
117 file_path = file_path.AppendASCII("usb_gadget.zip");
119 return ReadFile(file_path, package);
122 scoped_ptr<net::URLFetcher> CreateURLFetcher(
123 scoped_refptr<net::URLRequestContextGetter> request_context_getter,
124 const GURL& url,
125 net::URLFetcher::RequestType request_type,
126 net::URLFetcherDelegate* delegate) {
127 scoped_ptr<net::URLFetcher> url_fetcher =
128 net::URLFetcher::Create(url, request_type, delegate);
130 url_fetcher->SetRequestContext(request_context_getter.get());
132 return url_fetcher;
135 class URLRequestContextGetter : public net::URLRequestContextGetter {
136 public:
137 URLRequestContextGetter(
138 scoped_refptr<base::SingleThreadTaskRunner> network_task_runner)
139 : network_task_runner_(network_task_runner) {}
141 private:
142 ~URLRequestContextGetter() override {}
144 // net::URLRequestContextGetter implementation
145 net::URLRequestContext* GetURLRequestContext() override {
146 if (!context_) {
147 net::URLRequestContextBuilder context_builder;
148 context_builder.set_proxy_service(
149 make_scoped_ptr(net::ProxyService::CreateDirect()));
150 context_ = context_builder.Build().Pass();
152 return context_.get();
155 scoped_refptr<base::SingleThreadTaskRunner> GetNetworkTaskRunner()
156 const override {
157 return network_task_runner_;
160 scoped_ptr<net::URLRequestContext> context_;
161 scoped_refptr<base::SingleThreadTaskRunner> network_task_runner_;
164 class URLFetcherDelegate : public net::URLFetcherDelegate {
165 public:
166 URLFetcherDelegate() {}
167 ~URLFetcherDelegate() override {}
169 void WaitForCompletion() { run_loop_.Run(); }
171 void OnURLFetchComplete(const net::URLFetcher* source) override {
172 run_loop_.Quit();
175 private:
176 base::RunLoop run_loop_;
178 DISALLOW_COPY_AND_ASSIGN(URLFetcherDelegate);
181 int SimplePOSTRequest(
182 scoped_refptr<net::URLRequestContextGetter> request_context_getter,
183 const GURL& url,
184 const std::string& form_data) {
185 URLFetcherDelegate delegate;
186 scoped_ptr<net::URLFetcher> url_fetcher = CreateURLFetcher(
187 request_context_getter, url, net::URLFetcher::POST, &delegate);
189 url_fetcher->SetUploadData("application/x-www-form-urlencoded", form_data);
190 url_fetcher->Start();
191 delegate.WaitForCompletion();
193 return url_fetcher->GetResponseCode();
196 class UsbGadgetFactory : public UsbService::Observer,
197 public net::URLFetcherDelegate {
198 public:
199 UsbGadgetFactory(scoped_refptr<base::SingleThreadTaskRunner> io_task_runner)
200 : observer_(this), weak_factory_(this) {
201 usb_service_ = DeviceClient::Get()->GetUsbService();
202 request_context_getter_ = new URLRequestContextGetter(io_task_runner);
204 static uint32 next_session_id;
205 base::ProcessId process_id = base::GetCurrentProcId();
206 session_id_ = base::StringPrintf("%d-%d", process_id, next_session_id++);
208 observer_.Add(usb_service_);
211 ~UsbGadgetFactory() override {}
213 scoped_ptr<UsbTestGadget> WaitForDevice() {
214 EnumerateDevices();
215 run_loop_.Run();
216 return make_scoped_ptr(
217 new UsbTestGadgetImpl(request_context_getter_, usb_service_, device_));
220 private:
221 void EnumerateDevices() {
222 if (!device_) {
223 usb_service_->GetDevices(base::Bind(
224 &UsbGadgetFactory::OnDevicesEnumerated, weak_factory_.GetWeakPtr()));
228 void OnDevicesEnumerated(
229 const std::vector<scoped_refptr<UsbDevice>>& devices) {
230 for (const scoped_refptr<UsbDevice>& device : devices) {
231 OnDeviceAdded(device);
234 if (!device_) {
235 // TODO(reillyg): This timer could be replaced by a way to use long-
236 // polling to wait for claimed devices to become unclaimed.
237 base::MessageLoop::current()->PostDelayedTask(
238 FROM_HERE, base::Bind(&UsbGadgetFactory::EnumerateDevices,
239 weak_factory_.GetWeakPtr()),
240 base::TimeDelta::FromMilliseconds(kReenumeratePeriod));
244 void OnDeviceAdded(scoped_refptr<UsbDevice> device) override {
245 if (device_.get()) {
246 // Already trying to claim a device.
247 return;
250 if (device->vendor_id() != 0x18D1 || device->product_id() != 0x58F0 ||
251 device->serial_number().empty()) {
252 return;
255 std::string serial_number = base::UTF16ToUTF8(device->serial_number());
256 if (serial_number == serial_number_) {
257 // We were waiting for the device to reappear after upgrade.
258 device_ = device;
259 run_loop_.Quit();
260 return;
263 device_ = device;
264 serial_number_ = serial_number;
265 Claim();
268 void Claim() {
269 VLOG(1) << "Trying to claim " << serial_number_ << ".";
271 GURL url("http://" + serial_number_ + "/claim");
272 std::string form_data = base::StringPrintf(
273 "session_id=%s", net::EscapeUrlEncodedData(session_id_, true).c_str());
274 url_fetcher_ = CreateURLFetcher(request_context_getter_, url,
275 net::URLFetcher::POST, this);
276 url_fetcher_->SetUploadData("application/x-www-form-urlencoded", form_data);
277 url_fetcher_->Start();
280 void GetVersion() {
281 GURL url("http://" + serial_number_ + "/version");
282 url_fetcher_ = CreateURLFetcher(request_context_getter_, url,
283 net::URLFetcher::GET, this);
284 url_fetcher_->Start();
287 bool Update(const std::string& version) {
288 LOG(INFO) << "Updating " << serial_number_ << " to " << version << "...";
290 GURL url("http://" + serial_number_ + "/update");
291 url_fetcher_ = CreateURLFetcher(request_context_getter_, url,
292 net::URLFetcher::POST, this);
293 std::string mime_header = base::StringPrintf(
294 "--foo\r\n"
295 "Content-Disposition: form-data; name=\"file\"; "
296 "filename=\"usb_gadget-%s.zip\"\r\n"
297 "Content-Type: application/octet-stream\r\n"
298 "\r\n",
299 version.c_str());
300 std::string mime_footer("\r\n--foo--\r\n");
302 std::string package;
303 if (!ReadLocalPackage(&package)) {
304 return false;
307 url_fetcher_->SetUploadData("multipart/form-data; boundary=foo",
308 mime_header + package + mime_footer);
309 url_fetcher_->Start();
310 device_ = nullptr;
311 return true;
314 void OnURLFetchComplete(const net::URLFetcher* source) override {
315 DCHECK(!serial_number_.empty());
317 int response_code = source->GetResponseCode();
318 if (!claimed_) {
319 // Just completed a /claim request.
320 if (response_code == 200) {
321 claimed_ = true;
322 GetVersion();
323 } else {
324 if (response_code != 403) {
325 LOG(WARNING) << "Unexpected HTTP " << response_code
326 << " from /claim.";
328 Reset();
330 } else if (version_.empty()) {
331 // Just completed a /version request.
332 if (response_code != 200) {
333 LOG(WARNING) << "Unexpected HTTP " << response_code
334 << " from /version.";
335 Reset();
336 return;
339 if (!source->GetResponseAsString(&version_)) {
340 LOG(WARNING) << "Failed to read body from /version.";
341 Reset();
342 return;
345 std::string local_version;
346 if (!ReadLocalVersion(&local_version)) {
347 Reset();
348 return;
351 if (version_ == local_version) {
352 run_loop_.Quit();
353 } else {
354 if (!Update(local_version)) {
355 Reset();
358 } else {
359 // Just completed an /update request.
360 if (response_code != 200) {
361 LOG(WARNING) << "Unexpected HTTP " << response_code << " from /update.";
362 Reset();
363 return;
366 // Must wait for the device to reconnect.
370 void Reset() {
371 device_ = nullptr;
372 serial_number_.clear();
373 claimed_ = false;
374 version_.clear();
376 // Wait a bit and then try again to find an available device.
377 base::MessageLoop::current()->PostDelayedTask(
378 FROM_HERE, base::Bind(&UsbGadgetFactory::EnumerateDevices,
379 weak_factory_.GetWeakPtr()),
380 base::TimeDelta::FromMilliseconds(kReenumeratePeriod));
383 UsbService* usb_service_ = nullptr;
384 scoped_refptr<net::URLRequestContextGetter> request_context_getter_;
385 std::string session_id_;
386 scoped_ptr<net::URLFetcher> url_fetcher_;
387 scoped_refptr<UsbDevice> device_;
388 std::string serial_number_;
389 bool claimed_ = false;
390 std::string version_;
391 base::RunLoop run_loop_;
392 ScopedObserver<UsbService, UsbService::Observer> observer_;
393 base::WeakPtrFactory<UsbGadgetFactory> weak_factory_;
396 class DeviceAddListener : public UsbService::Observer {
397 public:
398 DeviceAddListener(UsbService* usb_service,
399 const std::string& serial_number,
400 int product_id)
401 : usb_service_(usb_service),
402 serial_number_(serial_number),
403 product_id_(product_id),
404 observer_(this),
405 weak_factory_(this) {
406 observer_.Add(usb_service_);
408 ~DeviceAddListener() override {}
410 scoped_refptr<UsbDevice> WaitForAdd() {
411 usb_service_->GetDevices(base::Bind(&DeviceAddListener::OnDevicesEnumerated,
412 weak_factory_.GetWeakPtr()));
413 run_loop_.Run();
414 return device_;
417 private:
418 void OnDevicesEnumerated(
419 const std::vector<scoped_refptr<UsbDevice>>& devices) {
420 for (const scoped_refptr<UsbDevice>& device : devices) {
421 OnDeviceAdded(device);
425 void OnDeviceAdded(scoped_refptr<UsbDevice> device) override {
426 if (device->vendor_id() == 0x18D1 && !device->serial_number().empty()) {
427 const uint16 product_id = device->product_id();
428 if (product_id_ == -1) {
429 bool found = false;
430 for (size_t i = 0; i < arraysize(kConfigurations); ++i) {
431 if (product_id == kConfigurations[i].product_id) {
432 found = true;
433 break;
436 if (!found) {
437 return;
439 } else {
440 if (product_id_ != product_id) {
441 return;
445 if (serial_number_ != base::UTF16ToUTF8(device->serial_number())) {
446 return;
449 device_ = device;
450 run_loop_.Quit();
454 UsbService* usb_service_;
455 const std::string serial_number_;
456 const int product_id_;
457 base::RunLoop run_loop_;
458 scoped_refptr<UsbDevice> device_;
459 ScopedObserver<UsbService, UsbService::Observer> observer_;
460 base::WeakPtrFactory<DeviceAddListener> weak_factory_;
462 DISALLOW_COPY_AND_ASSIGN(DeviceAddListener);
465 class DeviceRemoveListener : public UsbService::Observer {
466 public:
467 DeviceRemoveListener(UsbService* usb_service, scoped_refptr<UsbDevice> device)
468 : usb_service_(usb_service),
469 device_(device),
470 observer_(this),
471 weak_factory_(this) {
472 observer_.Add(usb_service_);
474 ~DeviceRemoveListener() override {}
476 void WaitForRemove() {
477 usb_service_->GetDevices(
478 base::Bind(&DeviceRemoveListener::OnDevicesEnumerated,
479 weak_factory_.GetWeakPtr()));
480 run_loop_.Run();
483 private:
484 void OnDevicesEnumerated(
485 const std::vector<scoped_refptr<UsbDevice>>& devices) {
486 bool found = false;
487 for (const scoped_refptr<UsbDevice>& device : devices) {
488 if (device_ == device) {
489 found = true;
492 if (!found) {
493 run_loop_.Quit();
497 void OnDeviceRemoved(scoped_refptr<UsbDevice> device) override {
498 if (device_ == device) {
499 run_loop_.Quit();
503 UsbService* usb_service_;
504 base::RunLoop run_loop_;
505 scoped_refptr<UsbDevice> device_;
506 ScopedObserver<UsbService, UsbService::Observer> observer_;
507 base::WeakPtrFactory<DeviceRemoveListener> weak_factory_;
509 DISALLOW_COPY_AND_ASSIGN(DeviceRemoveListener);
512 } // namespace
514 bool UsbTestGadget::IsTestEnabled() {
515 base::CommandLine* command_line = base::CommandLine::ForCurrentProcess();
516 return command_line->HasSwitch(kCommandLineSwitch);
519 scoped_ptr<UsbTestGadget> UsbTestGadget::Claim(
520 scoped_refptr<base::SingleThreadTaskRunner> io_task_runner) {
521 UsbGadgetFactory gadget_factory(io_task_runner);
522 return gadget_factory.WaitForDevice().Pass();
525 UsbTestGadgetImpl::UsbTestGadgetImpl(
526 scoped_refptr<net::URLRequestContextGetter> request_context_getter_,
527 UsbService* usb_service,
528 scoped_refptr<UsbDevice> device)
529 : device_address_(base::UTF16ToUTF8(device->serial_number())),
530 device_(device),
531 request_context_getter_(request_context_getter_),
532 usb_service_(usb_service) {
535 UsbTestGadgetImpl::~UsbTestGadgetImpl() {
536 if (!device_address_.empty()) {
537 Unclaim();
541 UsbDevice* UsbTestGadgetImpl::GetDevice() const {
542 return device_.get();
545 bool UsbTestGadgetImpl::Unclaim() {
546 VLOG(1) << "Releasing the device at " << device_address_ << ".";
548 GURL url("http://" + device_address_ + "/unclaim");
549 int response_code = SimplePOSTRequest(request_context_getter_, url, "");
551 if (response_code != 200) {
552 LOG(ERROR) << "Unexpected HTTP " << response_code << " from /unclaim.";
553 return false;
556 device_address_.clear();
557 return true;
560 bool UsbTestGadgetImpl::SetType(Type type) {
561 const struct UsbTestGadgetConfiguration* config = NULL;
562 for (size_t i = 0; i < arraysize(kConfigurations); ++i) {
563 if (kConfigurations[i].type == type) {
564 config = &kConfigurations[i];
567 CHECK(config);
569 GURL url("http://" + device_address_ + config->http_resource);
570 int response_code = SimplePOSTRequest(request_context_getter_, url, "");
572 if (response_code != 200) {
573 LOG(ERROR) << "Unexpected HTTP " << response_code
574 << " from " << config->http_resource << ".";
575 return false;
578 // Release the old reference to the device and try to open a new one.
579 DeviceAddListener add_listener(usb_service_, device_address_,
580 config->product_id);
581 device_ = add_listener.WaitForAdd();
582 DCHECK(device_.get());
583 return true;
586 bool UsbTestGadgetImpl::Disconnect() {
587 GURL url("http://" + device_address_ + "/disconnect");
588 int response_code = SimplePOSTRequest(request_context_getter_, url, "");
590 if (response_code != 200) {
591 LOG(ERROR) << "Unexpected HTTP " << response_code << " from " << url << ".";
592 return false;
595 // Release the old reference to the device and wait until it can't be found.
596 DeviceRemoveListener remove_listener(usb_service_, device_);
597 remove_listener.WaitForRemove();
598 device_ = nullptr;
599 return true;
602 bool UsbTestGadgetImpl::Reconnect() {
603 GURL url("http://" + device_address_ + "/reconnect");
604 int response_code = SimplePOSTRequest(request_context_getter_, url, "");
606 if (response_code != 200) {
607 LOG(ERROR) << "Unexpected HTTP " << response_code << " from " << url << ".";
608 return false;
611 DeviceAddListener add_listener(usb_service_, device_address_, -1);
612 device_ = add_listener.WaitForAdd();
613 DCHECK(device_.get());
614 return true;
617 } // namespace device