1 # -*- coding: utf-8 -*-
2 # Copyright 2015 Google Inc. All Rights Reserved.
4 # Licensed under the Apache License, Version 2.0 (the "License");
5 # you may not use this file except in compliance with the License.
6 # You may obtain a copy of the License at
8 # http://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15 """Unit tests for daisy chain wrapper class."""
from __future__ import absolute_import

import os
import pkgutil

import gslib.cloud_api
from gslib.daisy_chain_wrapper import DaisyChainWrapper
from gslib.storage_url import StorageUrlFromString
import gslib.tests.testcase as testcase
from gslib.util import TRANSFER_BUFFER_SIZE
29 _TEST_FILE
= 'test.txt'
class TestDaisyChainWrapper(testcase.GsUtilUnitTestCase):
    """Unit tests for the DaisyChainWrapper class."""

    # Path of the temp copy of the test data file; set lazily by _GetTestFile().
    _temp_test_file = None
    # Placeholder source URL; the mock API below ignores bucket/object names.
    _dummy_url = StorageUrlFromString('gs://bucket/object')

    def setUp(self):
        """Materializes the test data file and records its size in bytes."""
        super(TestDaisyChainWrapper, self).setUp()
        self.test_data_file = self._GetTestFile()
        self.test_data_file_len = os.path.getsize(self.test_data_file)

    def _GetTestFile(self):
        """Returns the path of a temp file holding the packaged test data.

        The temp file is created on first call and its path is cached on the
        instance, so later calls return the same path.
        """
        if not self._temp_test_file:
            contents = pkgutil.get_data(
                'gslib', 'tests/test_data/%s' % _TEST_FILE)
            # Write to a temp file because pkgutil doesn't expose a stream
            # interface.
            self._temp_test_file = self.CreateTempFile(
                file_name=_TEST_FILE, contents=contents)
        return self._temp_test_file

    class MockDownloadCloudApi(gslib.cloud_api.CloudApi):
        """Mock CloudApi that implements GetObjectMedia for testing."""

        def __init__(self, write_values):
            """Initialize the mock that will be used by the download thread.

            Args:
              write_values: List of values that will be used for calls to
                  write(), in order, by the download thread. An Exception
                  instance may be part of the list; if so, it will be raised
                  after the previous values are consumed.
            """
            self._write_values = write_values
            # Number of GetObjectMedia calls made; the tests assert on it.
            self.get_calls = 0

        def GetObjectMedia(self, unused_bucket_name, unused_object_name,
                           download_stream, start_byte=0, end_byte=None,
                           **unused_kwargs):
            """Writes self._write_values to the download_stream.

            Args:
              unused_bucket_name: Ignored.
              unused_object_name: Ignored.
              download_stream: Stream whose write() receives each value.
              start_byte: Values are skipped until this many bytes have been
                  accounted for.
              end_byte: If set, writing stops once this many bytes have been
                  accounted for.
              **unused_kwargs: Any other GetObjectMedia arguments; ignored.

            Raises:
              Any Exception instance found in the write_values list.
            """
            # Writes from start_byte up to, but not including end_byte (if not
            # None). Does not slice values; self._write_values must line up
            # with start/end_byte.
            self.get_calls += 1
            bytes_read = 0
            for write_value in self._write_values:
                if bytes_read < start_byte:
                    bytes_read += len(write_value)
                    continue
                # NOTE(review): an end_byte of 0 is falsy and therefore treated
                # like None here; kept as-is to preserve the mock's behavior.
                if end_byte and bytes_read >= end_byte:
                    break
                if isinstance(write_value, Exception):
                    raise write_value
                download_stream.write(write_value)
                bytes_read += len(write_value)

    def _ReadTestDataChunks(self):
        """Returns the test data file as a list of TRANSFER_BUFFER_SIZE reads."""
        chunks = []
        with open(self.test_data_file, 'rb') as stream:
            while True:
                data = stream.read(TRANSFER_BUFFER_SIZE)
                if not data:
                    break
                chunks.append(data)
        return chunks

    def _AssertFileMatchesTestData(self, file_path):
        """Asserts that file_path holds exactly the test data file's bytes."""
        with open(file_path, 'rb') as upload_stream:
            with open(self.test_data_file, 'rb') as download_stream:
                self.assertEqual(upload_stream.read(), download_stream.read())

    def _WriteFromWrapperToFile(self, daisy_chain_wrapper, file_path):
        """Writes all contents from the DaisyChainWrapper to the named file."""
        with open(file_path, 'wb') as upload_stream:
            while True:
                data = daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE)
                if not data:
                    break
                upload_stream.write(data)

    def testDownloadSingleChunk(self):
        """Tests a single call to GetObjectMedia."""
        write_values = self._ReadTestDataChunks()
        upload_file = self.CreateTempFile()
        # Test for a single call even if the chunk size is larger than the data.
        for chunk_size in (self.test_data_file_len,
                           self.test_data_file_len + 1):
            mock_api = self.MockDownloadCloudApi(write_values)
            daisy_chain_wrapper = DaisyChainWrapper(
                self._dummy_url, self.test_data_file_len, mock_api,
                download_chunk_size=chunk_size)
            self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file)
            # Since the chunk size is >= the file size, only a single
            # GetObjectMedia call should be made.
            self.assertEqual(mock_api.get_calls, 1)
            self._AssertFileMatchesTestData(upload_file)

    def testDownloadMultiChunk(self):
        """Tests multiple calls to GetObjectMedia."""
        upload_file = self.CreateTempFile()
        write_values = self._ReadTestDataChunks()
        mock_api = self.MockDownloadCloudApi(write_values)
        daisy_chain_wrapper = DaisyChainWrapper(
            self._dummy_url, self.test_data_file_len, mock_api,
            download_chunk_size=TRANSFER_BUFFER_SIZE)
        self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file)
        # Ceiling division with integers only: '/' would yield a float under
        # true division and never compare equal to a non-exact call count.
        full_chunks, remainder = divmod(self.test_data_file_len,
                                        TRANSFER_BUFFER_SIZE)
        num_expected_calls = full_chunks + (1 if remainder else 0)
        # Since the chunk size is < the file size, multiple calls to
        # GetObjectMedia should be made.
        self.assertEqual(mock_api.get_calls, num_expected_calls)
        self._AssertFileMatchesTestData(upload_file)

    def testDownloadWithZeroWrites(self):
        """Tests 0-byte writes to the download stream from GetObjectMedia."""
        write_values = []
        with open(self.test_data_file, 'rb') as stream:
            while True:
                # Place empty writes both before and after every read,
                # including the final empty read at EOF.
                write_values.append(b'')
                data = stream.read(TRANSFER_BUFFER_SIZE)
                write_values.append(b'')
                if not data:
                    break
                write_values.append(data)
        upload_file = self.CreateTempFile()
        mock_api = self.MockDownloadCloudApi(write_values)
        daisy_chain_wrapper = DaisyChainWrapper(
            self._dummy_url, self.test_data_file_len, mock_api,
            download_chunk_size=self.test_data_file_len)
        self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file)
        self.assertEqual(mock_api.get_calls, 1)
        self._AssertFileMatchesTestData(upload_file)

    def testDownloadWithPartialWrite(self):
        """Tests unaligned writes to the download stream from GetObjectMedia."""
        with open(self.test_data_file, 'rb') as stream:
            chunk = stream.read(TRANSFER_BUFFER_SIZE)
        one_byte = chunk[0:1]
        chunk_minus_one_byte = chunk[1:TRANSFER_BUFFER_SIZE]
        # Floor division keeps the slice bound an int under true division.
        half_chunk = chunk[0:TRANSFER_BUFFER_SIZE // 2]

        write_values_dict = {
            'First byte first chunk unaligned':
                (one_byte, chunk_minus_one_byte, chunk, chunk),
            'Last byte first chunk unaligned':
                (chunk_minus_one_byte, chunk, chunk),
            'First byte second chunk unaligned':
                (chunk, one_byte, chunk_minus_one_byte, chunk),
            'Last byte second chunk unaligned':
                (chunk, chunk_minus_one_byte, one_byte, chunk),
            'First byte final chunk unaligned':
                (chunk, chunk, one_byte, chunk_minus_one_byte),
            'Last byte final chunk unaligned':
                (chunk, chunk, chunk_minus_one_byte, one_byte),
            'Half chunks':
                (half_chunk, half_chunk, half_chunk),
            'Many unaligned':
                (one_byte, half_chunk, one_byte, half_chunk, chunk,
                 chunk_minus_one_byte, chunk, one_byte, half_chunk, one_byte),
        }
        upload_file = self.CreateTempFile()
        for case_name, write_values in write_values_dict.items():
            expected_contents = b''.join(write_values)
            mock_api = self.MockDownloadCloudApi(write_values)
            daisy_chain_wrapper = DaisyChainWrapper(
                self._dummy_url, len(expected_contents), mock_api,
                download_chunk_size=self.test_data_file_len)
            self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file)
            with open(upload_file, 'rb') as upload_stream:
                self.assertEqual(
                    upload_stream.read(), expected_contents,
                    'Uploaded file contents for case %s did not match'
                    % case_name)

    def testSeekAndReturn(self):
        """Tests seeking to the end of the wrapper (simulates getting size)."""
        write_values = self._ReadTestDataChunks()
        upload_file = self.CreateTempFile()
        mock_api = self.MockDownloadCloudApi(write_values)
        daisy_chain_wrapper = DaisyChainWrapper(
            self._dummy_url, self.test_data_file_len, mock_api,
            download_chunk_size=self.test_data_file_len)
        with open(upload_file, 'wb') as upload_stream:
            current_position = 0
            daisy_chain_wrapper.seek(0, whence=os.SEEK_END)
            daisy_chain_wrapper.seek(current_position)
            while True:
                data = daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE)
                current_position += len(data)
                # Seek to the end and back after every read.
                daisy_chain_wrapper.seek(0, whence=os.SEEK_END)
                daisy_chain_wrapper.seek(current_position)
                if not data:
                    break
                upload_stream.write(data)
        self.assertEqual(mock_api.get_calls, 1)
        self._AssertFileMatchesTestData(upload_file)

    def testRestartDownloadThread(self):
        """Tests seek to non-stored position; this restarts the download thread."""
        write_values = self._ReadTestDataChunks()
        upload_file = self.CreateTempFile()
        mock_api = self.MockDownloadCloudApi(write_values)
        daisy_chain_wrapper = DaisyChainWrapper(
            self._dummy_url, self.test_data_file_len, mock_api,
            download_chunk_size=self.test_data_file_len)
        daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE)
        daisy_chain_wrapper.read(TRANSFER_BUFFER_SIZE)
        daisy_chain_wrapper.seek(0)
        self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file)
        self.assertEqual(mock_api.get_calls, 2)
        self._AssertFileMatchesTestData(upload_file)

    def testDownloadThreadException(self):
        """Tests that an exception is propagated via the upload thread."""

        class DownloadException(Exception):
            """Exception injected into the mock download to force a failure."""

        write_values = [b'a', b'b',
                        DownloadException('Download thread forces failure')]
        upload_file = self.CreateTempFile()
        mock_api = self.MockDownloadCloudApi(write_values)
        daisy_chain_wrapper = DaisyChainWrapper(
            self._dummy_url, self.test_data_file_len, mock_api,
            download_chunk_size=self.test_data_file_len)
        with self.assertRaises(DownloadException) as context:
            self._WriteFromWrapperToFile(daisy_chain_wrapper, upload_file)
        self.assertIn('Download thread forces failure', str(context.exception))

    def testInvalidSeek(self):
        """Tests that seeking fails for unsupported seek arguments."""
        daisy_chain_wrapper = DaisyChainWrapper(
            self._dummy_url, self.test_data_file_len,
            self.MockDownloadCloudApi([]))

        # NOTE(review): the original except clauses were lost from the source;
        # IOError is assumed as the raised type -- confirm against
        # DaisyChainWrapper.seek.
        # SEEK_CUR is invalid.
        with self.assertRaises(IOError) as context:
            daisy_chain_wrapper.seek(0, whence=os.SEEK_CUR)
        self.assertIn('does not support seek mode', str(context.exception))

        # Seeking from the end with an offset is invalid.
        with self.assertRaises(IOError) as context:
            daisy_chain_wrapper.seek(1, whence=os.SEEK_END)
        self.assertIn('Invalid seek during daisy chain',
                      str(context.exception))