# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc.
# http://code.google.com/p/protobuf/
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for google.protobuf.internal.input_stream."""

__author__ = 'robinson@google.com (Will Robinson)'
import unittest

from google.protobuf import message
from google.protobuf.internal import input_stream
from google.protobuf.internal import wire_format
class InputStreamTest(unittest.TestCase):
  """Unit tests for input_stream.InputStream.

  Each test builds an InputStream over a byte string (written as str
  literals) and checks the values returned by its read methods together
  with the offset reported by Position().
  """

  # Eleven bytes, the first ten with the continuation bit set: something
  # that looks like a varint with more than 10 bytes.
  _OVERLONG_VARINT = '\xff' * 10 + '\x00'

  # A 10-byte varint whose last byte sets a bit beyond the 64 bits that
  # fit in a 64-bit integer ("the mythical 64th bit").
  _VARINT_BIT_64_SET = '\x80' * 9 + '\x02'

  def testEndOfStream(self):
    """EndOfStream() becomes True once every byte has been consumed."""
    stream = input_stream.InputStream('abcd')
    self.assertFalse(stream.EndOfStream())
    # Asking for more bytes than remain returns only what is left.
    self.assertEqual('abcd', stream.ReadString(10))
    self.assertTrue(stream.EndOfStream())

  def testPosition(self):
    """Position() reports the read offset and never passes the end."""
    stream = input_stream.InputStream('abcd')
    self.assertEqual(0, stream.Position())
    self.assertEqual(0, stream.Position())  # No side-effects.
    stream.ReadString(1)
    self.assertEqual(1, stream.Position())
    stream.ReadString(1)
    self.assertEqual(2, stream.Position())
    stream.ReadString(10)
    self.assertEqual(4, stream.Position())  # Can't go past end of stream.

  def testGetSubBuffer(self):
    """GetSubBuffer() views bytes starting at the current position."""
    stream = input_stream.InputStream('abcd')
    # Try leaving out the size.
    self.assertEqual('abcd', str(stream.GetSubBuffer()))
    stream.SkipBytes(1)
    # GetSubBuffer() always starts at the current position.
    self.assertEqual('bcd', str(stream.GetSubBuffer()))
    # Try a size of 0.
    self.assertEqual('', str(stream.GetSubBuffer(0)))
    # Negative sizes should raise an error.
    self.assertRaises(message.DecodeError, stream.GetSubBuffer, -1)
    # Positive sizes should work as expected.  Each call below still
    # starts at 'b', i.e. GetSubBuffer() does not advance the stream.
    self.assertEqual('b', str(stream.GetSubBuffer(1)))
    self.assertEqual('bc', str(stream.GetSubBuffer(2)))
    # Sizes longer than remaining bytes in the buffer should
    # return the whole remaining buffer.
    self.assertEqual('bcd', str(stream.GetSubBuffer(1000)))

  def testSkipBytes(self):
    """SkipBytes() advances the position, clamped to the stream end."""
    stream = input_stream.InputStream('')
    # Skipping bytes when at the end of stream
    # should have no effect.
    stream.SkipBytes(0)
    stream.SkipBytes(1)
    stream.SkipBytes(2)
    self.assertTrue(stream.EndOfStream())
    self.assertEqual(0, stream.Position())

    # Try skipping within a stream.
    stream = input_stream.InputStream('abcd')
    self.assertEqual(0, stream.Position())
    stream.SkipBytes(1)
    self.assertEqual(1, stream.Position())
    stream.SkipBytes(10)  # Can't skip past the end.
    self.assertEqual(4, stream.Position())

    # Ensure that a negative skip raises an exception.
    stream = input_stream.InputStream('abcd')
    stream.SkipBytes(1)
    self.assertRaises(message.DecodeError, stream.SkipBytes, -1)

  def testReadString(self):
    """ReadString(n) returns up to n bytes and rejects a negative n."""
    s = 'abcd'
    # Also test going past the total stream length.
    for i in range(len(s) + 10):
      stream = input_stream.InputStream(s)
      self.assertEqual(s[:i], stream.ReadString(i))
      self.assertEqual(min(i, len(s)), stream.Position())
    stream = input_stream.InputStream(s)
    self.assertRaises(message.DecodeError, stream.ReadString, -1)

  def _AssertDecodeError(self, encoded, read_method):
    """Asserts that read_method raises message.DecodeError on `encoded`.

    Args:
      encoded: Bytes used to construct a fresh InputStream.
      read_method: Unbound InputStream method, called with the new stream
        as its only argument.
    """
    stream = input_stream.InputStream(encoded)
    self.assertRaises(message.DecodeError, read_method, stream)

  def EnsureFailureOnEmptyStream(self, input_stream_method):
    """Helper for integer-parsing tests below.

    Ensures that the given InputStream method raises a DecodeError
    if called on a stream with no bytes remaining.

    Args:
      input_stream_method: Unbound InputStream method, called with an
        empty stream as its only argument.
    """
    self._AssertDecodeError('', input_stream_method)

  def testReadLittleEndian32(self):
    """ReadLittleEndian32() decodes 4-byte little-endian unsigned ints."""
    self.EnsureFailureOnEmptyStream(
        input_stream.InputStream.ReadLittleEndian32)
    s = ''.join([
        '\x00\x00\x00\x00',  # Read 0.
        '\x01\x00\x00\x00',  # Read 1.
        '\x01\x02\x03\x04',  # Read a bunch of different bytes.
        '\xff\xff\xff\xff',  # Read max unsigned 32-bit int.
        # Try a read with fewer than 4 bytes left in the stream.
        '\x00\x00\x00',
    ])
    stream = input_stream.InputStream(s)
    self.assertEqual(0, stream.ReadLittleEndian32())
    self.assertEqual(4, stream.Position())
    self.assertEqual(1, stream.ReadLittleEndian32())
    self.assertEqual(8, stream.Position())
    self.assertEqual(0x04030201, stream.ReadLittleEndian32())
    self.assertEqual(12, stream.Position())
    self.assertEqual(wire_format.UINT32_MAX, stream.ReadLittleEndian32())
    self.assertEqual(16, stream.Position())
    # Only 3 bytes remain, so a full 4-byte read must fail.
    self.assertRaises(message.DecodeError, stream.ReadLittleEndian32)

  def testReadLittleEndian64(self):
    """ReadLittleEndian64() decodes 8-byte little-endian unsigned ints."""
    self.EnsureFailureOnEmptyStream(
        input_stream.InputStream.ReadLittleEndian64)
    s = ''.join([
        '\x00\x00\x00\x00\x00\x00\x00\x00',  # Read 0.
        '\x01\x00\x00\x00\x00\x00\x00\x00',  # Read 1.
        '\x01\x02\x03\x04\x05\x06\x07\x08',  # Read a bunch of different bytes.
        '\xff\xff\xff\xff\xff\xff\xff\xff',  # Read max unsigned 64-bit int.
        # Try a read with fewer than 8 bytes left in the stream.
        '\x00\x00\x00\x00\x00\x00\x00',
    ])
    stream = input_stream.InputStream(s)
    self.assertEqual(0, stream.ReadLittleEndian64())
    self.assertEqual(8, stream.Position())
    self.assertEqual(1, stream.ReadLittleEndian64())
    self.assertEqual(16, stream.Position())
    self.assertEqual(0x0807060504030201, stream.ReadLittleEndian64())
    self.assertEqual(24, stream.Position())
    self.assertEqual(wire_format.UINT64_MAX, stream.ReadLittleEndian64())
    self.assertEqual(32, stream.Position())
    # Only 7 bytes remain, so a full 8-byte read must fail.
    self.assertRaises(message.DecodeError, stream.ReadLittleEndian64)

  def ReadVarintSuccessTestHelper(self, varints_and_ints, read_method):
    """Helper for tests below that test successful reads of various varints.

    Args:
      varints_and_ints: Iterable of (str, integer) pairs, where the string
        gives the wire encoding and the integer gives the value we expect
        to be returned by the read_method upon encountering this string.
      read_method: Unbound InputStream method that is capable of reading
        the encoded strings provided in the first elements of varints_and_ints.
    """
    # The pairs are walked twice (once to build the stream, once to read
    # it back), so a one-shot iterator must be materialized first.
    varints_and_ints = list(varints_and_ints)
    encoded = ''.join(wire for wire, _ in varints_and_ints)
    stream = input_stream.InputStream(encoded)
    expected_pos = 0
    self.assertEqual(expected_pos, stream.Position())
    for wire, expected_int in varints_and_ints:
      self.assertEqual(expected_int, read_method(stream))
      expected_pos += len(wire)
      self.assertEqual(expected_pos, stream.Position())

  def testReadVarint32Success(self):
    """ReadVarint32() decodes signed 32-bit values, including negatives."""
    varints_and_ints = [
        ('\x00', 0),
        ('\x01', 1),
        ('\x7f', 127),
        ('\x80\x01', 128),
        ('\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01', -1),
        ('\xff\xff\xff\xff\x07', wire_format.INT32_MAX),
        ('\x80\x80\x80\x80\xf8\xff\xff\xff\xff\x01', wire_format.INT32_MIN),
    ]
    self.ReadVarintSuccessTestHelper(varints_and_ints,
                                     input_stream.InputStream.ReadVarint32)

  def testReadVarint32Failure(self):
    """ReadVarint32() rejects out-of-range values and overlong varints."""
    read = input_stream.InputStream.ReadVarint32
    self.EnsureFailureOnEmptyStream(read)

    # Try and fail to read INT32_MAX + 1.
    self._AssertDecodeError('\x80\x80\x80\x80\x08', read)

    # Try and fail to read INT32_MIN - 1: every bit of the 64-bit
    # two's-complement encoding is set except bit 31.
    self._AssertDecodeError('\xff\xff\xff\xff\xf7\xff\xff\xff\xff\x01', read)

    # Try and fail to read something that looks like
    # a varint with more than 10 bytes.
    self._AssertDecodeError(self._OVERLONG_VARINT, read)

  def testReadVarUInt32Success(self):
    """ReadVarUInt32() decodes unsigned values up to UINT32_MAX."""
    varints_and_ints = [
        ('\x00', 0),
        ('\x01', 1),
        ('\x7f', 127),
        ('\x80\x01', 128),
        ('\xff\xff\xff\xff\x0f', wire_format.UINT32_MAX),
    ]
    self.ReadVarintSuccessTestHelper(varints_and_ints,
                                     input_stream.InputStream.ReadVarUInt32)

  def testReadVarUInt32Failure(self):
    """ReadVarUInt32() rejects values above UINT32_MAX and overlong varints."""
    read = input_stream.InputStream.ReadVarUInt32
    self.EnsureFailureOnEmptyStream(read)
    # Try and fail to read UINT32_MAX + 1.
    self._AssertDecodeError('\x80\x80\x80\x80\x10', read)

    # Try and fail to read something that looks like
    # a varint with more than 10 bytes.
    self._AssertDecodeError(self._OVERLONG_VARINT, read)

  def testReadVarint64Success(self):
    """ReadVarint64() decodes signed 64-bit values, including negatives."""
    varints_and_ints = [
        ('\x00', 0),
        ('\x01', 1),
        ('\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01', -1),
        ('\x7f', 127),
        ('\x80\x01', 128),
        ('\xff\xff\xff\xff\xff\xff\xff\xff\x7f', wire_format.INT64_MAX),
        ('\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01', wire_format.INT64_MIN),
    ]
    self.ReadVarintSuccessTestHelper(varints_and_ints,
                                     input_stream.InputStream.ReadVarint64)

  def testReadVarint64Failure(self):
    """ReadVarint64() rejects values wider than 64 bits and overlong varints."""
    read = input_stream.InputStream.ReadVarint64
    self.EnsureFailureOnEmptyStream(read)
    # Try and fail to read something with the mythical 64th bit set.
    self._AssertDecodeError(self._VARINT_BIT_64_SET, read)

    # Try and fail to read something that looks like
    # a varint with more than 10 bytes.
    self._AssertDecodeError(self._OVERLONG_VARINT, read)

  def testReadVarUInt64Success(self):
    """ReadVarUInt64() decodes unsigned values, including 1 << 63."""
    varints_and_ints = [
        ('\x00', 0),
        ('\x01', 1),
        ('\x7f', 127),
        ('\x80\x01', 128),
        ('\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01', 1 << 63),
    ]
    self.ReadVarintSuccessTestHelper(varints_and_ints,
                                     input_stream.InputStream.ReadVarUInt64)

  def testReadVarUInt64Failure(self):
    """ReadVarUInt64() rejects values wider than 64 bits and overlong varints."""
    read = input_stream.InputStream.ReadVarUInt64
    self.EnsureFailureOnEmptyStream(read)
    # Try and fail to read something with the mythical 64th bit set.
    self._AssertDecodeError(self._VARINT_BIT_64_SET, read)

    # Try and fail to read something that looks like
    # a varint with more than 10 bytes.
    self._AssertDecodeError(self._OVERLONG_VARINT, read)
# Run the test cases in this module when it is executed as a script.
if __name__ == '__main__':
  unittest.main()