more types restricted, 3.3.3
[QuestHelper.git] / S3.py
blob640821a8e691710cd3083abd4363524b982e7c96
1 #!/usr/bin/env python
3 # This software code is made available "AS IS" without warranties of any
4 # kind. You may copy, display, modify and redistribute the software
5 # code either by itself or as incorporated into your code; provided that
6 # you do not remove any proprietary notices. Your use of this software
7 # code is at your own risk and you waive any claim against Amazon
8 # Digital Services, Inc. or its affiliates with respect to your use of
9 # this software code. (c) 2006-2007 Amazon Digital Services, Inc. or its
10 # affiliates.
12 import base64
13 import hmac
14 import http.client
15 import re
16 import hashlib
17 import sys
18 import time
19 import urllib.request, urllib.parse, urllib.error
20 import urllib.parse
21 import xml.sax
# Default S3 endpoint host name.
DEFAULT_HOST = 's3.amazonaws.com'
# Port used when none is given, keyed by whether the connection is HTTPS.
PORTS_BY_SECURITY = {True: 443, False: 80}
# Prefix shared by all Amazon-specific headers; such headers are signed.
AMAZON_HEADER_PREFIX = 'x-amz-'
# Prefix that marks a header as user metadata ("x-amz-meta-").
METADATA_PREFIX = AMAZON_HEADER_PREFIX + 'meta-'
def canonical_string(method, bucket="", key="", query_args={}, headers={}, expires=None):
    """Build the AWS canonical string that is HMAC-signed for a request.

    Layout, one item per line:
      * the HTTP method;
      * the signed headers in sorted (lower-cased) name order -- Content-MD5,
        Content-Type and Date as bare values, x-amz-* headers as "name:value";
        Content-MD5 and Content-Type are always present (empty if missing);
    followed by "/bucket" (only when bucket is non-empty), "/" plus the
    quote_plus-encoded key, and at most one of ?acl, ?torrent, ?logging,
    ?location (checked in that order) taken from query_args.

    Args:
        method: HTTP verb.
        bucket: bucket name, or '' for none.
        key: object key, or '' for none.
        query_args: only inspected for the sub-resource names above.
        headers: request headers; names are matched case-insensitively and
            values are whitespace-stripped.
        expires: when truthy, its str() replaces the Date line (used for
            query-string authentication).
    """
    signed = {}
    for name in headers:
        lowered = name.lower()
        if lowered in ['content-md5', 'content-type', 'date'] or lowered.startswith(AMAZON_HEADER_PREFIX):
            signed[lowered] = headers[name].strip()

    # These two always take part in the signature, even when absent.
    signed.setdefault('content-type', '')
    signed.setdefault('content-md5', '')

    # An explicit x-amz-date supersedes Date, which is then signed as empty.
    if 'x-amz-date' in signed:
        signed['date'] = ''

    # Query-string auth signs the expiry time in place of the date.
    if expires:
        signed['date'] = str(expires)

    pieces = ["%s\n" % method]
    for name in sorted(signed):
        if name.startswith(AMAZON_HEADER_PREFIX):
            pieces.append("%s:%s\n" % (name, signed[name]))
        else:
            pieces.append("%s\n" % signed[name])

    if bucket != "":
        pieces.append("/%s" % bucket)

    # The slash after the bucket is always present, even with an empty key.
    pieces.append("/%s" % urllib.parse.quote_plus(key))

    for sub_resource in ("acl", "torrent", "logging", "location"):
        if sub_resource in query_args:
            pieces.append("?" + sub_resource)
            break

    return "".join(pieces)
def encode(aws_secret_access_key, dat, urlencode=False):
    """Return the base64 HMAC-SHA1 of *dat* keyed by the secret access key.

    Args:
        aws_secret_access_key: secret key, as str or bytes.
        dat: the string to sign (normally a canonical string), str or bytes.
        urlencode: when True, the base64 result is additionally
            quote_plus-encoded so it can be placed in a query string.

    Returns:
        The signature as a str.
    """
    def _as_bytes(value):
        # str inputs are UTF-8 encoded (identical to ASCII for ASCII text);
        # bytes are used as-is.
        return value if isinstance(value, bytes) else value.encode("utf-8")

    digest = hmac.new(_as_bytes(aws_secret_access_key), _as_bytes(dat), hashlib.sha1).digest()
    # b64encode replaces base64.encodestring (removed in Python 3.9); it adds
    # no trailing newline, so no strip() is needed.
    b64_hmac = base64.b64encode(digest).decode("ascii")
    if urlencode:
        return urllib.parse.quote_plus(b64_hmac)
    return b64_hmac
def merge_meta(headers, metadata):
    """Return a copy of *headers* with every metadata item added as a header.

    Each metadata name is prefixed with METADATA_PREFIX; a metadata entry
    overrides a header of the same prefixed name. Neither argument is
    modified.
    """
    merged = headers.copy()
    merged.update((METADATA_PREFIX + name, metadata[name]) for name in list(metadata.keys()))
    return merged
def query_args_hash_to_string(query_args):
    """Serialize a dict of query arguments into an '&'-joined query string.

    Keys are emitted verbatim in the dict's iteration order. A value of None
    yields a bare key (e.g. "acl"); any other value is converted with str()
    and quote_plus-encoded after an '='.
    """
    def _render(name, value):
        if value is None:
            return name
        return name + "=%s" % urllib.parse.quote_plus(str(value))

    return '&'.join([_render(name, value) for name, value in list(query_args.items())])
class CallingFormat:
    """Where the bucket name goes in a request URL.

    PATH:      <protocol>://<server>:<port>/<bucket>
    SUBDOMAIN: <protocol>://<bucket>.<server>:<port>
    VANITY:    <protocol>://<bucket>:<port>  (bucket used directly as host)
    """
    PATH = 1
    SUBDOMAIN = 2
    VANITY = 3

    @staticmethod
    def build_url_base(protocol, server, port, bucket, calling_format):
        """Return the URL prefix (scheme, host, port, optional /bucket).

        With an empty bucket the result is always <protocol>://<server>:<port>.
        Unknown calling formats are treated like PATH for the host part.
        """
        if bucket == '':
            host = server
        elif calling_format == CallingFormat.SUBDOMAIN:
            host = "%s.%s" % (bucket, server)
        elif calling_format == CallingFormat.VANITY:
            host = bucket
        else:
            host = server

        # Only the PATH style carries the bucket in the path component.
        if bucket != '' and calling_format == CallingFormat.PATH:
            bucket_path = "/%s" % bucket
        else:
            bucket_path = ''

        return '%s://%s:%s%s' % (protocol, host, port, bucket_path)
class Location:
    """Location constraints understood by create_located_bucket."""
    # No constraint: create_located_bucket sends an empty request body.
    DEFAULT = None
    # Value placed inside <LocationConstraint> for the EU location.
    EU = 'EU'
class AWSAuthConnection:
    """Issues signed requests to S3 and wraps the raw http.client responses.

    Every request carries an "Authorization: AWS <access key id>:<signature>"
    header, where the signature is computed over the request's canonical
    string with the secret access key.
    """

    # Maximum number of 3xx redirects _make_request follows before returning
    # the redirect response itself; guards against endless redirect loops.
    _MAX_REDIRECTS = 10

    def __init__(self, aws_access_key_id, aws_secret_access_key, is_secure=True,
                 server=DEFAULT_HOST, port=None, calling_format=CallingFormat.SUBDOMAIN):
        # Fall back to the protocol's default port when none is given.
        if not port:
            port = PORTS_BY_SECURITY[is_secure]

        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key
        self.is_secure = is_secure
        self.server = server
        self.port = port
        self.calling_format = calling_format

    def create_bucket(self, bucket, headers={}):
        """PUT a new bucket; returns a Response."""
        return Response(self._make_request('PUT', bucket, '', {}, headers))

    def create_located_bucket(self, bucket, location=Location.DEFAULT, headers={}):
        """PUT a new bucket with a location constraint body.

        When *location* is Location.DEFAULT the body is empty.
        """
        if location == Location.DEFAULT:
            body = ""
        else:
            body = ("<CreateBucketConstraint><LocationConstraint>" +
                    location +
                    "</LocationConstraint></CreateBucketConstraint>")
        return Response(self._make_request('PUT', bucket, '', {}, headers, body))

    def check_bucket_exists(self, bucket):
        """HEAD the bucket; returns the raw http.client response (no wrapper)."""
        return self._make_request('HEAD', bucket, '', {}, {})

    def list_bucket(self, bucket, options={}, headers={}):
        """GET the bucket listing; *options* are sent as query arguments."""
        return ListBucketResponse(self._make_request('GET', bucket, '', options, headers))

    def delete_bucket(self, bucket, headers={}):
        """DELETE the bucket; returns a Response."""
        return Response(self._make_request('DELETE', bucket, '', {}, headers))

    def put(self, bucket, key, object, headers={}):
        """PUT an object; *object* may be an S3Object or raw data."""
        if not isinstance(object, S3Object):
            object = S3Object(object)

        return Response(
            self._make_request(
                'PUT',
                bucket,
                key,
                {},
                headers,
                object.data,
                object.metadata))

    def get(self, bucket, key, headers={}):
        """GET an object; returns a GetResponse."""
        return GetResponse(
            self._make_request('GET', bucket, key, {}, headers))

    def delete(self, bucket, key, headers={}):
        """DELETE an object; returns a Response."""
        return Response(
            self._make_request('DELETE', bucket, key, {}, headers))

    def get_bucket_logging(self, bucket, headers={}):
        """GET the bucket's ?logging sub-resource."""
        return GetResponse(self._make_request('GET', bucket, '', {'logging': None}, headers))

    def put_bucket_logging(self, bucket, logging_xml_doc, headers={}):
        """PUT *logging_xml_doc* to the bucket's ?logging sub-resource."""
        return Response(self._make_request('PUT', bucket, '', {'logging': None}, headers, logging_xml_doc))

    def get_bucket_acl(self, bucket, headers={}):
        """GET the bucket's ACL document."""
        return self.get_acl(bucket, '', headers)

    def get_acl(self, bucket, key, headers={}):
        """GET the ?acl sub-resource of *bucket*/*key*."""
        return GetResponse(
            self._make_request('GET', bucket, key, {'acl': None}, headers))

    def put_bucket_acl(self, bucket, acl_xml_document, headers={}):
        """PUT the bucket's ACL document."""
        return self.put_acl(bucket, '', acl_xml_document, headers)

    def put_acl(self, bucket, key, acl_xml_document, headers={}):
        """PUT *acl_xml_document* to the ?acl sub-resource of *bucket*/*key*."""
        return Response(
            self._make_request(
                'PUT',
                bucket,
                key,
                {'acl': None},
                headers,
                acl_xml_document))

    def list_all_my_buckets(self, headers={}):
        """GET the service root, listing the caller's buckets."""
        return ListAllMyBucketsResponse(self._make_request('GET', '', '', {}, headers))

    def get_bucket_location(self, bucket):
        """GET the bucket's ?location sub-resource."""
        return LocationResponse(self._make_request('GET', bucket, '', {'location': None}))

    # end public methods

    def _make_request(self, method, bucket='', key='', query_args={}, headers={}, data='', metadata={}):
        """Send one signed request, following redirects, and return the response.

        The returned http.client response is unread; the Response wrappers
        read its body.

        Raises:
            http.client.InvalidURL: a redirect points to a non-http(s) URL.
        """
        if bucket == '':
            server = self.server
        elif self.calling_format == CallingFormat.SUBDOMAIN:
            server = "%s.%s" % (bucket, self.server)
        elif self.calling_format == CallingFormat.VANITY:
            server = bucket
        else:
            server = self.server

        path = ''
        if (bucket != '') and (self.calling_format == CallingFormat.PATH):
            path += "/%s" % bucket

        # The slash after the bucket is always added; the (possibly empty)
        # key follows it.
        path += "/%s" % urllib.parse.quote_plus(key)

        # The query string is appended only when there are arguments.
        if len(query_args):
            path += "?" + query_args_hash_to_string(query_args)

        is_secure = self.is_secure
        host = "%s:%d" % (server, self.port)
        redirects = 0
        while True:
            if is_secure:
                connection = http.client.HTTPSConnection(host)
            else:
                connection = http.client.HTTPConnection(host)

            final_headers = merge_meta(headers, metadata)
            # Signed afresh on every attempt so the Date header is current.
            self._add_aws_auth_header(final_headers, method, bucket, key, query_args)

            connection.request(method, path, data, final_headers)
            resp = connection.getresponse()
            if resp.status < 300 or resp.status >= 400:
                return resp

            # Handle a redirect.
            location = resp.getheader('location')
            if not location or redirects >= self._MAX_REDIRECTS:
                return resp
            redirects += 1

            # Drain and release the redirect response before retrying.
            resp.read()
            connection.close()

            scheme, host, path, params, query, fragment = urllib.parse.urlparse(location)
            if scheme == "https":
                is_secure = True
            elif scheme == "http":
                is_secure = False
            else:
                raise http.client.InvalidURL("Not http/https: " + location)
            if query:
                path += "?" + query
            # Loop: retry against the redirect target.

    def _add_aws_auth_header(self, headers, method, bucket, key, query_args):
        """Add Date (if missing) and Authorization headers to *headers* in place."""
        from email.utils import formatdate

        # Header names are case-insensitive; don't add a second Date header.
        if not any(name.lower() == 'date' for name in headers):
            # RFC 1123 GMT date. Unlike strftime's %a/%b/%X, formatdate does
            # not depend on the process locale.
            headers['Date'] = formatdate(usegmt=True)

        c_string = canonical_string(method, bucket, key, query_args, headers)
        headers['Authorization'] = \
            "AWS %s:%s" % (self.aws_access_key_id, encode(self.aws_secret_access_key, c_string))
class QueryStringAuthGenerator:
    """Builds pre-signed S3 URLs (query-string authentication).

    Methods perform no requests; each returns a URL that carries Signature,
    Expires and AWSAccessKeyId query arguments.
    """
    # by default, expire in 1 minute
    DEFAULT_EXPIRES_IN = 60

    def __init__(self, aws_access_key_id, aws_secret_access_key, is_secure=True,
                 server=DEFAULT_HOST, port=None, calling_format=CallingFormat.SUBDOMAIN):
        # Fall back to the protocol's default port when none is given.
        if not port:
            port = PORTS_BY_SECURITY[is_secure]

        self.aws_access_key_id = aws_access_key_id
        self.aws_secret_access_key = aws_secret_access_key
        if is_secure:
            self.protocol = 'https'
        else:
            self.protocol = 'http'

        self.is_secure = is_secure
        self.server = server
        self.port = port
        self.calling_format = calling_format
        # Exactly one of these is non-None; see set_expires_in / set_expires.
        self.__expires_in = QueryStringAuthGenerator.DEFAULT_EXPIRES_IN
        self.__expires = None

        # for backwards compatibility with older versions
        self.server_name = "%s:%s" % (self.server, self.port)

    def set_expires_in(self, expires_in):
        """Make URLs expire *expires_in* seconds after they are generated."""
        self.__expires_in = expires_in
        self.__expires = None

    def set_expires(self, expires):
        """Make URLs expire at the fixed time *expires* (seconds since epoch)."""
        self.__expires = expires
        self.__expires_in = None

    def create_bucket(self, bucket, headers={}):
        return self.generate_url('PUT', bucket, '', {}, headers)

    def list_bucket(self, bucket, options={}, headers={}):
        return self.generate_url('GET', bucket, '', options, headers)

    def delete_bucket(self, bucket, headers={}):
        return self.generate_url('DELETE', bucket, '', {}, headers)

    def put(self, bucket, key, object, headers={}):
        """URL for a PUT of *object*; its metadata headers are signed too."""
        if not isinstance(object, S3Object):
            object = S3Object(object)

        return self.generate_url(
            'PUT',
            bucket,
            key,
            {},
            merge_meta(headers, object.metadata))

    def get(self, bucket, key, headers={}):
        return self.generate_url('GET', bucket, key, {}, headers)

    def delete(self, bucket, key, headers={}):
        return self.generate_url('DELETE', bucket, key, {}, headers)

    def get_bucket_logging(self, bucket, headers={}):
        return self.generate_url('GET', bucket, '', {'logging': None}, headers)

    def put_bucket_logging(self, bucket, logging_xml_doc, headers={}):
        # The document is not part of the URL; the caller sends it as the body.
        return self.generate_url('PUT', bucket, '', {'logging': None}, headers)

    def get_bucket_acl(self, bucket, headers={}):
        return self.get_acl(bucket, '', headers)

    def get_acl(self, bucket, key='', headers={}):
        return self.generate_url('GET', bucket, key, {'acl': None}, headers)

    def put_bucket_acl(self, bucket, acl_xml_document, headers={}):
        return self.put_acl(bucket, '', acl_xml_document, headers)

    # don't really care what the doc is here.
    def put_acl(self, bucket, key, acl_xml_document, headers={}):
        return self.generate_url('PUT', bucket, key, {'acl': None}, headers)

    def list_all_my_buckets(self, headers={}):
        return self.generate_url('GET', '', '', {}, headers)

    def make_bare_url(self, bucket, key=''):
        """Return the unsigned URL of *bucket*/*key* (no query string)."""
        full_url = self.generate_url('GET', bucket, key)
        return full_url[:full_url.index('?')]

    def generate_url(self, method, bucket='', key='', query_args={}, headers={}):
        """Return a signed URL for *method* on *bucket*/*key*.

        Raises:
            RuntimeError: neither a relative nor an absolute expiry is set.
        """
        if self.__expires_in is not None:
            expires = int(time.time() + self.__expires_in)
        elif self.__expires is not None:
            expires = int(self.__expires)
        else:
            raise RuntimeError("Invalid expires state")

        canonical_str = canonical_string(method, bucket, key, query_args, headers, expires)
        encoded_canonical = encode(self.aws_secret_access_key, canonical_str)

        url = CallingFormat.build_url_base(self.protocol, self.server, self.port, bucket, self.calling_format)

        url += "/%s" % urllib.parse.quote_plus(key)

        # Sign into a copy so neither the caller's dict nor this method's
        # shared default accumulates the authentication arguments.
        signed_args = dict(query_args)
        signed_args['Signature'] = encoded_canonical
        signed_args['Expires'] = expires
        signed_args['AWSAccessKeyId'] = self.aws_access_key_id

        url += "?%s" % query_args_hash_to_string(signed_args)

        return url
class S3Object:
    """An object's body together with its user metadata.

    Args:
        data: the object body.
        metadata: dict of user metadata names (without the x-amz-meta-
            prefix) to values. When omitted, each instance gets its own
            new empty dict.
    """
    def __init__(self, data, metadata=None):
        self.data = data
        # A shared `{}` default would make metadata written to one object
        # appear on every other object created without explicit metadata.
        self.metadata = {} if metadata is None else metadata
class Owner:
    """Owner identity attached to a bucket listing entry.

    Attributes:
        id: owner ID string.
        display_name: owner display name.
    """
    def __init__(self, id='', display_name=''):
        self.id, self.display_name = id, display_name
class ListEntry:
    """One <Contents> item of a bucket listing.

    Attributes hold the raw strings from the listing, except size (int) and
    owner (an Owner, or None when the listing had no owner element).
    """
    def __init__(self, key='', last_modified=None, etag='', size=0, storage_class='', owner=None):
        (self.key, self.last_modified, self.etag,
         self.size, self.storage_class, self.owner) = (
            key, last_modified, etag, size, storage_class, owner)
class CommonPrefixEntry:
    """One <CommonPrefixes> item of a bucket listing.

    Attributes:
        prefix: the common key prefix.
    """
    # The constructor was misspelled `__init`, so it never ran: `prefix` was
    # left unset and CommonPrefixEntry('x') raised TypeError.
    def __init__(self, prefix=''):
        self.prefix = prefix
class Bucket:
    """One bucket from a list-all-my-buckets response.

    Attributes:
        name: bucket name.
        creation_date: creation date string exactly as found in the XML.
    """
    def __init__(self, name='', creation_date=''):
        self.name, self.creation_date = name, creation_date
class Response:
    """Wrap an http.client response, eagerly reading its body.

    Attributes:
        http_response: the wrapped response object.
        body: whatever ``http_response.read()`` returned (bytes for
            http.client responses).
        message: always a str. For responses with status >= 300 and a
            non-empty body it is the body (decoded as UTF-8, invalid bytes
            replaced); otherwise "<status> <reason>".
    """
    def __init__(self, http_response):
        self.http_response = http_response
        # you have to do this read, even if you don't expect a body.
        # otherwise, the next request fails.
        self.body = http_response.read()
        if http_response.status >= 300 and self.body:
            message = self.body
            # Under Python 3 the body is bytes; keep `message` a str in both
            # branches so callers can format/print it uniformly.
            if isinstance(message, bytes):
                message = message.decode('utf-8', 'replace')
            self.message = message
        else:
            self.message = "%03d %s" % (http_response.status, http_response.reason)
class ListBucketResponse(Response):
    """Response to a bucket-listing GET.

    On success (status < 300) the listing fields are parsed from the XML
    body. On failure every listing attribute still exists and holds the
    parser's initial (empty) value, instead of only `entries` being set.
    """
    def __init__(self, http_response):
        Response.__init__(self, http_response)
        handler = ListBucketHandler()
        if http_response.status < 300:
            xml.sax.parseString(self.body, handler)
        self.entries = handler.entries
        self.common_prefixes = handler.common_prefixes
        self.name = handler.name
        self.marker = handler.marker
        self.prefix = handler.prefix
        self.is_truncated = handler.is_truncated
        self.delimiter = handler.delimiter
        self.max_keys = handler.max_keys
        self.next_marker = handler.next_marker
class ListAllMyBucketsResponse(Response):
    """Response to a list-all-my-buckets GET.

    `entries` holds the parsed buckets on success (status < 300) and is an
    empty list otherwise.
    """
    def __init__(self, http_response):
        Response.__init__(self, http_response)
        if http_response.status >= 300:
            self.entries = []
            return
        parser = ListAllMyBucketsHandler()
        xml.sax.parseString(self.body, parser)
        self.entries = parser.entries
class GetResponse(Response):
    """Response whose body and x-amz-meta-* headers form an S3Object.

    `object.data` is the raw body; `object.metadata` maps metadata names
    (prefix stripped) to header values.
    """
    def __init__(self, http_response):
        Response.__init__(self, http_response)
        # `msg` is used because older pythons don't have getheaders.
        self.object = S3Object(self.body, self.get_aws_metadata(http_response.msg))

    def get_aws_metadata(self, headers):
        """Move x-amz-meta-* entries out of *headers* into a new dict.

        The prefix match is case-insensitive; the returned names keep the
        rest of the header name as-is. Matching headers are deleted from
        *headers*.
        """
        prefix_len = len(METADATA_PREFIX)
        found = {}
        for name in list(headers.keys()):
            if not name.lower().startswith(METADATA_PREFIX):
                continue
            found[name[prefix_len:]] = headers[name]
            del headers[name]
        return found
class LocationResponse(Response):
    """Response to a bucket ?location GET.

    `location` is the parsed LocationConstraint text on success, and None
    when the request failed (previously the attribute was missing then).
    """
    def __init__(self, http_response):
        Response.__init__(self, http_response)
        self.location = None
        if http_response.status < 300:
            handler = LocationHandler()
            xml.sax.parseString(self.body, handler)
            self.location = handler.location
class ListBucketHandler(xml.sax.ContentHandler):
    """SAX handler for a bucket listing (ListBucketResult) document.

    Collects ListEntry objects into `entries`, CommonPrefixEntry objects
    into `common_prefixes`, and the listing's top-level fields (name,
    marker, prefix, is_truncated, delimiter, max_keys, next_marker).
    """
    def __init__(self):
        xml.sax.ContentHandler.__init__(self)
        self.entries = []
        self.curr_entry = None
        self.curr_text = ''
        self.common_prefixes = []
        self.curr_common_prefix = None
        self.name = ''
        self.marker = ''
        self.prefix = ''
        self.is_truncated = False
        self.delimiter = ''
        self.max_keys = 0
        self.next_marker = ''
        # The first <Prefix> is the echoed request prefix; later ones belong
        # to <CommonPrefixes> elements.
        self.is_echoed_prefix_set = False

    def startElement(self, name, attrs):
        # Text belongs only to the element just opened; discard whitespace
        # seen between tags so it does not leak into the next value.
        self.curr_text = ''
        if name == 'Contents':
            self.curr_entry = ListEntry()
        elif name == 'Owner':
            self.curr_entry.owner = Owner()
        elif name == 'CommonPrefixes':
            self.curr_common_prefix = CommonPrefixEntry()

    def endElement(self, name):
        if name == 'Contents':
            self.entries.append(self.curr_entry)
        elif name == 'CommonPrefixes':
            self.common_prefixes.append(self.curr_common_prefix)
        elif name == 'Key':
            self.curr_entry.key = self.curr_text
        elif name == 'LastModified':
            self.curr_entry.last_modified = self.curr_text
        elif name == 'ETag':
            self.curr_entry.etag = self.curr_text
        elif name == 'Size':
            self.curr_entry.size = int(self.curr_text)
        elif name == 'ID':
            self.curr_entry.owner.id = self.curr_text
        elif name == 'DisplayName':
            self.curr_entry.owner.display_name = self.curr_text
        elif name == 'StorageClass':
            self.curr_entry.storage_class = self.curr_text
        elif name == 'Name':
            self.name = self.curr_text
        elif name == 'Prefix' and self.is_echoed_prefix_set:
            self.curr_common_prefix.prefix = self.curr_text
        elif name == 'Prefix':
            self.prefix = self.curr_text
            self.is_echoed_prefix_set = True
        elif name == 'Marker':
            self.marker = self.curr_text
        elif name == 'IsTruncated':
            self.is_truncated = self.curr_text == 'true'
        elif name == 'Delimiter':
            self.delimiter = self.curr_text
        elif name == 'MaxKeys':
            self.max_keys = int(self.curr_text)
        elif name == 'NextMarker':
            self.next_marker = self.curr_text

        self.curr_text = ''

    def characters(self, content):
        # SAX may deliver one text node in several chunks; accumulate them.
        self.curr_text += content
class ListAllMyBucketsHandler(xml.sax.ContentHandler):
    """SAX handler for a list-all-my-buckets document.

    Builds a Bucket for each <Bucket> element (name and creation date) and
    collects them in `entries`.
    """
    def __init__(self):
        xml.sax.ContentHandler.__init__(self)
        self.entries = []
        self.curr_entry = None
        self.curr_text = ''

    def startElement(self, name, attrs):
        # Start collecting text afresh for the element just opened.
        self.curr_text = ''
        if name == 'Bucket':
            self.curr_entry = Bucket()

    def endElement(self, name):
        if name == 'Name':
            self.curr_entry.name = self.curr_text
        elif name == 'CreationDate':
            self.curr_entry.creation_date = self.curr_text
        elif name == 'Bucket':
            self.entries.append(self.curr_entry)

    def characters(self, content):
        # SAX may split one text node into several calls (e.g. around
        # entity references); append rather than overwrite.
        self.curr_text += content
class LocationHandler(xml.sax.ContentHandler):
    """SAX handler for a bucket ?location document.

    Accepts a document that is exactly one <LocationConstraint> element.
    `state` moves 'init' -> 'tag_location' -> 'done'; any other structure
    sets it to 'bad'. `location` stays None until the element opens, then
    collects its text ('' for an empty element).
    """
    def __init__(self):
        self.location = None
        self.state = 'init'

    def startElement(self, name, attrs):
        # Only a single top-level LocationConstraint is acceptable.
        if self.state != 'init' or name != 'LocationConstraint':
            self.state = 'bad'
            return
        self.state = 'tag_location'
        self.location = ''

    def endElement(self, name):
        closes_constraint = self.state == 'tag_location' and name == 'LocationConstraint'
        self.state = 'done' if closes_constraint else 'bad'

    def characters(self, content):
        if self.state != 'tag_location':
            return
        self.location += content