Supervised user import: Listen for profile creation/deletion
[chromium-blink-merge.git] / device / test / usb_test_gadget_impl.cc
blob2eea05808acc1bc825d332df275b47e286469707
1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
5 #include "device/test/usb_test_gadget.h"
7 #include <string>
8 #include <vector>
10 #include "base/command_line.h"
11 #include "base/compiler_specific.h"
12 #include "base/files/file.h"
13 #include "base/files/file_path.h"
14 #include "base/logging.h"
15 #include "base/macros.h"
16 #include "base/memory/ref_counted.h"
17 #include "base/memory/scoped_ptr.h"
18 #include "base/path_service.h"
19 #include "base/process/process_handle.h"
20 #include "base/run_loop.h"
21 #include "base/scoped_observer.h"
22 #include "base/strings/stringprintf.h"
23 #include "base/strings/utf_string_conversions.h"
24 #include "base/thread_task_runner_handle.h"
25 #include "base/time/time.h"
26 #include "device/usb/usb_device.h"
27 #include "device/usb/usb_device_handle.h"
28 #include "device/usb/usb_service.h"
29 #include "net/proxy/proxy_service.h"
30 #include "net/url_request/url_fetcher.h"
31 #include "net/url_request/url_fetcher_delegate.h"
32 #include "net/url_request/url_request_context.h"
33 #include "net/url_request/url_request_context_builder.h"
34 #include "net/url_request/url_request_context_getter.h"
35 #include "url/gurl.h"
37 namespace device {
39 class UsbTestGadgetImpl : public UsbTestGadget {
40 public:
41 UsbTestGadgetImpl(
42 scoped_refptr<net::URLRequestContextGetter> request_context_getter,
43 UsbService* usb_service,
44 scoped_refptr<UsbDevice> device);
45 ~UsbTestGadgetImpl() override;
47 bool Unclaim() override;
48 bool Disconnect() override;
49 bool Reconnect() override;
50 bool SetType(Type type) override;
51 UsbDevice* GetDevice() const override;
53 private:
54 std::string device_address_;
55 scoped_refptr<UsbDevice> device_;
56 scoped_refptr<net::URLRequestContextGetter> request_context_getter_;
57 UsbService* usb_service_;
59 DISALLOW_COPY_AND_ASSIGN(UsbTestGadgetImpl);
62 namespace {
// Switch checked by UsbTestGadget::IsTestEnabled().
const char kCommandLineSwitch[] = "enable-gadget-tests";
// Delay between device re-enumerations, in milliseconds (0.1 seconds).
const int kReenumeratePeriod = 100;
67 struct UsbTestGadgetConfiguration {
68 UsbTestGadget::Type type;
69 const char* http_resource;
70 uint16 product_id;
73 static const struct UsbTestGadgetConfiguration kConfigurations[] = {
74 {UsbTestGadget::DEFAULT, "/unconfigure", 0x58F0},
75 {UsbTestGadget::KEYBOARD, "/keyboard/configure", 0x58F1},
76 {UsbTestGadget::MOUSE, "/mouse/configure", 0x58F2},
77 {UsbTestGadget::HID_ECHO, "/hid_echo/configure", 0x58F3},
78 {UsbTestGadget::ECHO, "/echo/configure", 0x58F4},
81 bool ReadFile(const base::FilePath& file_path, std::string* content) {
82 base::File file(file_path, base::File::FLAG_OPEN | base::File::FLAG_READ);
83 if (!file.IsValid()) {
84 LOG(ERROR) << "Cannot open " << file_path.MaybeAsASCII() << ": "
85 << base::File::ErrorToString(file.error_details());
86 return false;
89 STLClearObject(content);
90 int rv;
91 do {
92 char buf[4096];
93 rv = file.ReadAtCurrentPos(buf, sizeof buf);
94 if (rv == -1) {
95 LOG(ERROR) << "Cannot read " << file_path.MaybeAsASCII() << ": "
96 << base::File::ErrorToString(file.error_details());
97 return false;
99 content->append(buf, rv);
100 } while (rv > 0);
102 return true;
105 bool ReadLocalVersion(std::string* version) {
106 base::FilePath file_path;
107 CHECK(PathService::Get(base::DIR_EXE, &file_path));
108 file_path = file_path.AppendASCII("usb_gadget.zip.md5");
110 return ReadFile(file_path, version);
113 bool ReadLocalPackage(std::string* package) {
114 base::FilePath file_path;
115 CHECK(PathService::Get(base::DIR_EXE, &file_path));
116 file_path = file_path.AppendASCII("usb_gadget.zip");
118 return ReadFile(file_path, package);
121 scoped_ptr<net::URLFetcher> CreateURLFetcher(
122 scoped_refptr<net::URLRequestContextGetter> request_context_getter,
123 const GURL& url,
124 net::URLFetcher::RequestType request_type,
125 net::URLFetcherDelegate* delegate) {
126 scoped_ptr<net::URLFetcher> url_fetcher(
127 net::URLFetcher::Create(url, request_type, delegate));
129 url_fetcher->SetRequestContext(request_context_getter.get());
131 return url_fetcher;
134 class URLRequestContextGetter : public net::URLRequestContextGetter {
135 public:
136 URLRequestContextGetter(
137 scoped_refptr<base::SingleThreadTaskRunner> network_task_runner)
138 : network_task_runner_(network_task_runner) {}
140 private:
141 ~URLRequestContextGetter() override {}
143 // net::URLRequestContextGetter implementation
144 net::URLRequestContext* GetURLRequestContext() override {
145 context_builder_.set_proxy_service(net::ProxyService::CreateDirect());
146 return context_builder_.Build();
149 scoped_refptr<base::SingleThreadTaskRunner> GetNetworkTaskRunner()
150 const override {
151 return network_task_runner_;
154 net::URLRequestContextBuilder context_builder_;
155 scoped_refptr<base::SingleThreadTaskRunner> network_task_runner_;
158 class URLFetcherDelegate : public net::URLFetcherDelegate {
159 public:
160 URLFetcherDelegate() {}
161 ~URLFetcherDelegate() override {}
163 void WaitForCompletion() { run_loop_.Run(); }
165 void OnURLFetchComplete(const net::URLFetcher* source) override {
166 run_loop_.Quit();
169 private:
170 base::RunLoop run_loop_;
172 DISALLOW_COPY_AND_ASSIGN(URLFetcherDelegate);
175 int SimplePOSTRequest(
176 scoped_refptr<net::URLRequestContextGetter> request_context_getter,
177 const GURL& url,
178 const std::string& form_data) {
179 URLFetcherDelegate delegate;
180 scoped_ptr<net::URLFetcher> url_fetcher = CreateURLFetcher(
181 request_context_getter, url, net::URLFetcher::POST, &delegate);
183 url_fetcher->SetUploadData("application/x-www-form-urlencoded", form_data);
184 url_fetcher->Start();
185 delegate.WaitForCompletion();
187 return url_fetcher->GetResponseCode();
190 class UsbGadgetFactory : public UsbService::Observer,
191 public net::URLFetcherDelegate {
192 public:
193 UsbGadgetFactory(scoped_refptr<base::SingleThreadTaskRunner> io_task_runner)
194 : observer_(this), weak_factory_(this) {
195 usb_service_ = UsbService::GetInstance(io_task_runner);
196 request_context_getter_ = new URLRequestContextGetter(io_task_runner);
198 static uint32 next_session_id;
199 base::ProcessId process_id = base::GetCurrentProcId();
200 session_id_ = base::StringPrintf("%d-%d", process_id, next_session_id++);
202 observer_.Add(usb_service_);
205 ~UsbGadgetFactory() override {}
207 scoped_ptr<UsbTestGadget> WaitForDevice() {
208 EnumerateDevices();
209 run_loop_.Run();
210 return make_scoped_ptr(
211 new UsbTestGadgetImpl(request_context_getter_, usb_service_, device_));
214 private:
215 void EnumerateDevices() {
216 if (!device_) {
217 usb_service_->GetDevices(base::Bind(
218 &UsbGadgetFactory::OnDevicesEnumerated, weak_factory_.GetWeakPtr()));
222 void OnDevicesEnumerated(
223 const std::vector<scoped_refptr<UsbDevice>>& devices) {
224 for (const scoped_refptr<UsbDevice>& device : devices) {
225 OnDeviceAdded(device);
228 if (!device_) {
229 // TODO(reillyg): This timer could be replaced by a way to use long-
230 // polling to wait for claimed devices to become unclaimed.
231 base::MessageLoop::current()->PostDelayedTask(
232 FROM_HERE, base::Bind(&UsbGadgetFactory::EnumerateDevices,
233 weak_factory_.GetWeakPtr()),
234 base::TimeDelta::FromMilliseconds(kReenumeratePeriod));
238 void OnDeviceAdded(scoped_refptr<UsbDevice> device) override {
239 if (device_.get()) {
240 // Already trying to claim a device.
241 return;
244 if (device->vendor_id() != 0x18D1 || device->product_id() != 0x58F0 ||
245 device->serial_number().empty()) {
246 return;
249 std::string serial_number = base::UTF16ToUTF8(device->serial_number());
250 if (serial_number == serial_number_) {
251 // We were waiting for the device to reappear after upgrade.
252 device_ = device;
253 run_loop_.Quit();
254 return;
257 device_ = device;
258 serial_number_ = serial_number;
259 Claim();
262 void Claim() {
263 VLOG(1) << "Trying to claim " << serial_number_ << ".";
265 GURL url("http://" + serial_number_ + "/claim");
266 std::string form_data = base::StringPrintf(
267 "session_id=%s", net::EscapeUrlEncodedData(session_id_, true).c_str());
268 url_fetcher_ = CreateURLFetcher(request_context_getter_, url,
269 net::URLFetcher::POST, this);
270 url_fetcher_->SetUploadData("application/x-www-form-urlencoded", form_data);
271 url_fetcher_->Start();
274 void GetVersion() {
275 GURL url("http://" + serial_number_ + "/version");
276 url_fetcher_ = CreateURLFetcher(request_context_getter_, url,
277 net::URLFetcher::GET, this);
278 url_fetcher_->Start();
281 bool Update(const std::string& version) {
282 LOG(INFO) << "Updating " << serial_number_ << " to " << version << "...";
284 GURL url("http://" + serial_number_ + "/update");
285 url_fetcher_ = CreateURLFetcher(request_context_getter_, url,
286 net::URLFetcher::POST, this);
287 std::string mime_header = base::StringPrintf(
288 "--foo\r\n"
289 "Content-Disposition: form-data; name=\"file\"; "
290 "filename=\"usb_gadget-%s.zip\"\r\n"
291 "Content-Type: application/octet-stream\r\n"
292 "\r\n",
293 version.c_str());
294 std::string mime_footer("\r\n--foo--\r\n");
296 std::string package;
297 if (!ReadLocalPackage(&package)) {
298 return false;
301 url_fetcher_->SetUploadData("multipart/form-data; boundary=foo",
302 mime_header + package + mime_footer);
303 url_fetcher_->Start();
304 device_ = nullptr;
305 return true;
308 void OnURLFetchComplete(const net::URLFetcher* source) override {
309 DCHECK(!serial_number_.empty());
311 int response_code = source->GetResponseCode();
312 if (!claimed_) {
313 // Just completed a /claim request.
314 if (response_code == 200) {
315 claimed_ = true;
316 GetVersion();
317 } else {
318 if (response_code != 403) {
319 LOG(WARNING) << "Unexpected HTTP " << response_code
320 << " from /claim.";
322 Reset();
324 } else if (version_.empty()) {
325 // Just completed a /version request.
326 if (response_code != 200) {
327 LOG(WARNING) << "Unexpected HTTP " << response_code
328 << " from /version.";
329 Reset();
330 return;
333 if (!source->GetResponseAsString(&version_)) {
334 LOG(WARNING) << "Failed to read body from /version.";
335 Reset();
336 return;
339 std::string local_version;
340 if (!ReadLocalVersion(&local_version)) {
341 Reset();
342 return;
345 if (version_ == local_version) {
346 run_loop_.Quit();
347 } else {
348 if (!Update(local_version)) {
349 Reset();
352 } else {
353 // Just completed an /update request.
354 if (response_code != 200) {
355 LOG(WARNING) << "Unexpected HTTP " << response_code << " from /update.";
356 Reset();
357 return;
360 // Must wait for the device to reconnect.
364 void Reset() {
365 device_ = nullptr;
366 serial_number_.clear();
367 claimed_ = false;
368 version_.clear();
370 // Wait a bit and then try again to find an available device.
371 base::MessageLoop::current()->PostDelayedTask(
372 FROM_HERE, base::Bind(&UsbGadgetFactory::EnumerateDevices,
373 weak_factory_.GetWeakPtr()),
374 base::TimeDelta::FromMilliseconds(kReenumeratePeriod));
377 UsbService* usb_service_ = nullptr;
378 scoped_refptr<net::URLRequestContextGetter> request_context_getter_;
379 std::string session_id_;
380 scoped_ptr<net::URLFetcher> url_fetcher_;
381 scoped_refptr<UsbDevice> device_;
382 std::string serial_number_;
383 bool claimed_ = false;
384 std::string version_;
385 base::RunLoop run_loop_;
386 ScopedObserver<UsbService, UsbService::Observer> observer_;
387 base::WeakPtrFactory<UsbGadgetFactory> weak_factory_;
390 class DeviceAddListener : public UsbService::Observer {
391 public:
392 DeviceAddListener(UsbService* usb_service,
393 const std::string& serial_number,
394 int product_id)
395 : usb_service_(usb_service),
396 serial_number_(serial_number),
397 product_id_(product_id),
398 observer_(this),
399 weak_factory_(this) {
400 observer_.Add(usb_service_);
402 virtual ~DeviceAddListener() {}
404 scoped_refptr<UsbDevice> WaitForAdd() {
405 usb_service_->GetDevices(base::Bind(&DeviceAddListener::OnDevicesEnumerated,
406 weak_factory_.GetWeakPtr()));
407 run_loop_.Run();
408 return device_;
411 private:
412 void OnDevicesEnumerated(
413 const std::vector<scoped_refptr<UsbDevice>>& devices) {
414 for (const scoped_refptr<UsbDevice>& device : devices) {
415 OnDeviceAdded(device);
419 void OnDeviceAdded(scoped_refptr<UsbDevice> device) override {
420 if (device->vendor_id() == 0x18D1 && !device->serial_number().empty()) {
421 const uint16 product_id = device->product_id();
422 if (product_id_ == -1) {
423 bool found = false;
424 for (size_t i = 0; i < arraysize(kConfigurations); ++i) {
425 if (product_id == kConfigurations[i].product_id) {
426 found = true;
427 break;
430 if (!found) {
431 return;
433 } else {
434 if (product_id_ != product_id) {
435 return;
439 if (serial_number_ != base::UTF16ToUTF8(device->serial_number())) {
440 return;
443 device_ = device;
444 run_loop_.Quit();
448 UsbService* usb_service_;
449 const std::string serial_number_;
450 const int product_id_;
451 base::RunLoop run_loop_;
452 scoped_refptr<UsbDevice> device_;
453 ScopedObserver<UsbService, UsbService::Observer> observer_;
454 base::WeakPtrFactory<DeviceAddListener> weak_factory_;
456 DISALLOW_COPY_AND_ASSIGN(DeviceAddListener);
459 class DeviceRemoveListener : public UsbService::Observer {
460 public:
461 DeviceRemoveListener(UsbService* usb_service, scoped_refptr<UsbDevice> device)
462 : usb_service_(usb_service),
463 device_(device),
464 observer_(this),
465 weak_factory_(this) {
466 observer_.Add(usb_service_);
468 virtual ~DeviceRemoveListener() {}
470 void WaitForRemove() {
471 usb_service_->GetDevices(
472 base::Bind(&DeviceRemoveListener::OnDevicesEnumerated,
473 weak_factory_.GetWeakPtr()));
474 run_loop_.Run();
477 private:
478 void OnDevicesEnumerated(
479 const std::vector<scoped_refptr<UsbDevice>>& devices) {
480 bool found = false;
481 for (const scoped_refptr<UsbDevice>& device : devices) {
482 if (device_ == device) {
483 found = true;
486 if (!found) {
487 run_loop_.Quit();
491 void OnDeviceRemoved(scoped_refptr<UsbDevice> device) override {
492 if (device_ == device) {
493 run_loop_.Quit();
497 UsbService* usb_service_;
498 base::RunLoop run_loop_;
499 scoped_refptr<UsbDevice> device_;
500 ScopedObserver<UsbService, UsbService::Observer> observer_;
501 base::WeakPtrFactory<DeviceRemoveListener> weak_factory_;
503 DISALLOW_COPY_AND_ASSIGN(DeviceRemoveListener);
506 } // namespace
508 bool UsbTestGadget::IsTestEnabled() {
509 base::CommandLine* command_line = base::CommandLine::ForCurrentProcess();
510 return command_line->HasSwitch(kCommandLineSwitch);
513 scoped_ptr<UsbTestGadget> UsbTestGadget::Claim(
514 scoped_refptr<base::SingleThreadTaskRunner> io_task_runner) {
515 UsbGadgetFactory gadget_factory(io_task_runner);
516 return gadget_factory.WaitForDevice().Pass();
519 UsbTestGadgetImpl::UsbTestGadgetImpl(
520 scoped_refptr<net::URLRequestContextGetter> request_context_getter_,
521 UsbService* usb_service,
522 scoped_refptr<UsbDevice> device)
523 : device_address_(base::UTF16ToUTF8(device->serial_number())),
524 device_(device),
525 request_context_getter_(request_context_getter_),
526 usb_service_(usb_service) {
529 UsbTestGadgetImpl::~UsbTestGadgetImpl() {
530 if (!device_address_.empty()) {
531 Unclaim();
535 UsbDevice* UsbTestGadgetImpl::GetDevice() const {
536 return device_.get();
539 bool UsbTestGadgetImpl::Unclaim() {
540 VLOG(1) << "Releasing the device at " << device_address_ << ".";
542 GURL url("http://" + device_address_ + "/unclaim");
543 int response_code = SimplePOSTRequest(request_context_getter_, url, "");
545 if (response_code != 200) {
546 LOG(ERROR) << "Unexpected HTTP " << response_code << " from /unclaim.";
547 return false;
550 device_address_.clear();
551 return true;
554 bool UsbTestGadgetImpl::SetType(Type type) {
555 const struct UsbTestGadgetConfiguration* config = NULL;
556 for (size_t i = 0; i < arraysize(kConfigurations); ++i) {
557 if (kConfigurations[i].type == type) {
558 config = &kConfigurations[i];
561 CHECK(config);
563 GURL url("http://" + device_address_ + config->http_resource);
564 int response_code = SimplePOSTRequest(request_context_getter_, url, "");
566 if (response_code != 200) {
567 LOG(ERROR) << "Unexpected HTTP " << response_code
568 << " from " << config->http_resource << ".";
569 return false;
572 // Release the old reference to the device and try to open a new one.
573 DeviceAddListener add_listener(usb_service_, device_address_,
574 config->product_id);
575 device_ = add_listener.WaitForAdd();
576 DCHECK(device_.get());
577 return true;
580 bool UsbTestGadgetImpl::Disconnect() {
581 GURL url("http://" + device_address_ + "/disconnect");
582 int response_code = SimplePOSTRequest(request_context_getter_, url, "");
584 if (response_code != 200) {
585 LOG(ERROR) << "Unexpected HTTP " << response_code << " from " << url << ".";
586 return false;
589 // Release the old reference to the device and wait until it can't be found.
590 DeviceRemoveListener remove_listener(usb_service_, device_);
591 remove_listener.WaitForRemove();
592 device_ = nullptr;
593 return true;
596 bool UsbTestGadgetImpl::Reconnect() {
597 GURL url("http://" + device_address_ + "/reconnect");
598 int response_code = SimplePOSTRequest(request_context_getter_, url, "");
600 if (response_code != 200) {
601 LOG(ERROR) << "Unexpected HTTP " << response_code << " from " << url << ".";
602 return false;
605 DeviceAddListener add_listener(usb_service_, device_address_, -1);
606 device_ = add_listener.WaitForAdd();
607 DCHECK(device_.get());
608 return true;
611 } // namespace device