Initial import of v2.0.0beta
[protobuf.git] / python / google / protobuf / internal / reflection_test.py
blob5947f97a88d0083c4a1cbd1557a25bd98d298007
1 # Protocol Buffers - Google's data interchange format
2 # Copyright 2008 Google Inc.
3 # http://code.google.com/p/protobuf/
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
9 # http://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
"""Tests for reflection.py.

The message classes exercised here are emitted by the pure-Python protocol
compiler, so these tests also indirectly check that compiler's output.
"""

__author__ = 'robinson@google.com (Will Robinson)'
23 import operator
25 import unittest
26 # TODO(robinson): When we split this test in two, only some of these imports
27 # will be necessary in each test.
28 from google.protobuf import unittest_import_pb2
29 from google.protobuf import unittest_mset_pb2
30 from google.protobuf import unittest_pb2
31 from google.protobuf import descriptor_pb2
32 from google.protobuf import descriptor
33 from google.protobuf import message
34 from google.protobuf import reflection
35 from google.protobuf.internal import more_extensions_pb2
36 from google.protobuf.internal import more_messages_pb2
37 from google.protobuf.internal import wire_format
38 from google.protobuf.internal import test_util
39 from google.protobuf.internal import decoder
class RefectionTest(unittest.TestCase):
  """Tests for message classes built through reflection.py.

  Covers "has" bits, ListFields(), default values, type and bounds checking,
  repeated fields, extensions, Clear() and IsInitialized().

  NOTE(review): "Refection" is a misspelling of "Reflection"; the class name
  is deliberately left unchanged so anything selecting these tests by name
  keeps working.
  """

  def testSimpleHasBits(self):
    """Setting a scalar sets its "has" bit; reading the default does not."""
    proto = unittest_pb2.TestAllTypes()
    self.assertFalse(proto.HasField('optional_int32'))
    self.assertEqual(0, proto.optional_int32)
    # HasField() shouldn't be true if all we've done is read the default value.
    self.assertFalse(proto.HasField('optional_int32'))
    proto.optional_int32 = 1
    # Setting a value however *should* set the "has" bit.
    self.assertTrue(proto.HasField('optional_int32'))
    proto.ClearField('optional_int32')
    # And clearing that value should unset the "has" bit.
    self.assertFalse(proto.HasField('optional_int32'))

  def testHasBitsWithSinglyNestedScalar(self):
    """Writing a scalar inside a singular composite marks the composite."""

    def TestCompositeHasBits(composite_field_name, scalar_field_name):
      """Checks "has"-bit behavior for one composite/scalar field pair.

      Args:
        composite_field_name: Name of a non-repeated composite (i.e. foreign
          message or group) field in TestAllTypes.
        scalar_field_name: Name of an integer-valued scalar field within that
          composite.

      Semantically this helper is just:

        assert proto.composite_field.scalar_field == 0
        assert not proto.composite_field.HasField('scalar_field')
        assert not proto.HasField('composite_field')

        proto.composite_field.scalar_field = 10
        old_composite_field = proto.composite_field

        assert proto.composite_field.scalar_field == 10
        assert proto.composite_field.HasField('scalar_field')
        assert proto.HasField('composite_field')

        proto.ClearField('composite_field')

        assert not proto.composite_field.HasField('scalar_field')
        assert not proto.HasField('composite_field')
        assert proto.composite_field.scalar_field == 0

        # Now ensure that ClearField('composite_field') disconnected
        # the old field object from the object tree...
        assert old_composite_field is not proto.composite_field
        old_composite_field.scalar_field = 20
        assert not proto.composite_field.HasField('scalar_field')
        assert not proto.HasField('composite_field')
      """
      proto = unittest_pb2.TestAllTypes()
      # First, check that we can get the scalar value, and see that it's the
      # default (0), but that proto.HasField('composite') and
      # proto.composite.HasField('scalar') will still return False.
      composite_field = getattr(proto, composite_field_name)
      original_scalar_value = getattr(composite_field, scalar_field_name)
      self.assertEqual(0, original_scalar_value)
      # Assert that the composite object does not "have" the scalar.
      self.assertFalse(composite_field.HasField(scalar_field_name))
      # Assert that proto does not "have" the composite field.
      self.assertFalse(proto.HasField(composite_field_name))

      # Now set the scalar within the composite field.  Ensure that the
      # setting is reflected, and that proto.HasField('composite') and
      # proto.composite.HasField('scalar') now both return True.
      new_val = 20
      setattr(composite_field, scalar_field_name, new_val)
      self.assertEqual(new_val, getattr(composite_field, scalar_field_name))
      # Hold on to a reference to the current composite_field object.
      old_composite_field = composite_field
      # Assert that the has methods now return true.
      self.assertTrue(composite_field.HasField(scalar_field_name))
      self.assertTrue(proto.HasField(composite_field_name))

      # Now call the clear method...
      proto.ClearField(composite_field_name)

      # ...and ensure that the "has" bits are all back to False...
      composite_field = getattr(proto, composite_field_name)
      self.assertFalse(composite_field.HasField(scalar_field_name))
      self.assertFalse(proto.HasField(composite_field_name))
      # ...and ensure that the scalar field has returned to its default.
      self.assertEqual(0, getattr(composite_field, scalar_field_name))

      # Finally, ensure that modifications to the old composite field object
      # don't have any effect on the parent.
      #
      # (NOTE that when we clear the composite field in the parent, we
      # actually don't recursively clear down the tree.  Instead, we just
      # disconnect the cleared composite from the tree.)
      self.assertTrue(old_composite_field is not composite_field)
      setattr(old_composite_field, scalar_field_name, new_val)
      self.assertFalse(composite_field.HasField(scalar_field_name))
      self.assertFalse(proto.HasField(composite_field_name))
      self.assertEqual(0, getattr(composite_field, scalar_field_name))

    # Test simple, single-level nesting when we set a scalar.
    TestCompositeHasBits('optionalgroup', 'a')
    TestCompositeHasBits('optional_nested_message', 'bb')
    TestCompositeHasBits('optional_foreign_message', 'c')
    TestCompositeHasBits('optional_import_message', 'd')

  def testReferencesToNestedMessage(self):
    """A nested message stays writable after its parent is deleted."""
    proto = unittest_pb2.TestAllTypes()
    nested = proto.optional_nested_message
    del proto
    # A previous version had a bug where this would raise an exception when
    # hitting a now-dead weak reference.
    nested.bb = 23

  def testDisconnectingNestedMessageBeforeSettingField(self):
    """ClearField() detaches a nested message that was only read so far."""
    proto = unittest_pb2.TestAllTypes()
    nested = proto.optional_nested_message
    proto.ClearField('optional_nested_message')  # Should disconnect from parent
    self.assertTrue(nested is not proto.optional_nested_message)
    nested.bb = 23
    self.assertFalse(proto.HasField('optional_nested_message'))
    self.assertEqual(0, proto.optional_nested_message.bb)

  def testHasBitsWhenModifyingRepeatedFields(self):
    """Adding to a repeated field in a submessage marks the submessage."""
    # Test nesting when we add an element to a repeated field in a submessage.
    proto = unittest_pb2.TestNestedMessageHasBits()
    proto.optional_nested_message.nestedmessage_repeated_int32.append(5)
    self.assertEqual(
        [5], proto.optional_nested_message.nestedmessage_repeated_int32)
    self.assertTrue(proto.HasField('optional_nested_message'))

    # Do the same test, but with a repeated composite field within the
    # submessage.
    proto.ClearField('optional_nested_message')
    self.assertFalse(proto.HasField('optional_nested_message'))
    proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add()
    self.assertTrue(proto.HasField('optional_nested_message'))

  def testHasBitsForManyLevelsOfNesting(self):
    """A write several levels down sets every ancestor's "has" bit."""
    recursive_proto = unittest_pb2.TestMutualRecursionA()
    self.assertFalse(recursive_proto.HasField('bb'))
    self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32)
    self.assertFalse(recursive_proto.HasField('bb'))
    recursive_proto.bb.a.bb.a.bb.optional_int32 = 5
    self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32)
    self.assertTrue(recursive_proto.HasField('bb'))
    self.assertTrue(recursive_proto.bb.HasField('a'))
    self.assertTrue(recursive_proto.bb.a.HasField('bb'))
    self.assertTrue(recursive_proto.bb.a.bb.HasField('a'))
    self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb'))
    self.assertFalse(recursive_proto.bb.a.bb.a.bb.HasField('a'))
    self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32'))

  def testSingularListFields(self):
    """ListFields() reports exactly the singular fields that were set."""
    proto = unittest_pb2.TestAllTypes()
    proto.optional_fixed32 = 1
    proto.optional_int32 = 5
    proto.optional_string = 'foo'
    # The expected order intentionally differs from the assignment order.
    self.assertEqual(
        [(proto.DESCRIPTOR.fields_by_name['optional_int32'], 5),
         (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
         (proto.DESCRIPTOR.fields_by_name['optional_string'], 'foo')],
        proto.ListFields())

  def testRepeatedListFields(self):
    """ListFields() reports non-empty repeated fields as lists."""
    proto = unittest_pb2.TestAllTypes()
    proto.repeated_fixed32.append(1)
    proto.repeated_int32.append(5)
    proto.repeated_int32.append(11)
    proto.repeated_string.append('foo')
    proto.repeated_string.append('bar')
    proto.repeated_string.append('baz')
    proto.optional_int32 = 21
    self.assertEqual(
        [(proto.DESCRIPTOR.fields_by_name['optional_int32'], 21),
         (proto.DESCRIPTOR.fields_by_name['repeated_int32'], [5, 11]),
         (proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]),
         (proto.DESCRIPTOR.fields_by_name['repeated_string'],
          ['foo', 'bar', 'baz'])],
        proto.ListFields())

  def testSingularListExtensions(self):
    """ListFields() reports set singular extensions by their handles."""
    proto = unittest_pb2.TestAllExtensions()
    proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1
    proto.Extensions[unittest_pb2.optional_int32_extension] = 5
    proto.Extensions[unittest_pb2.optional_string_extension] = 'foo'
    self.assertEqual(
        [(unittest_pb2.optional_int32_extension, 5),
         (unittest_pb2.optional_fixed32_extension, 1),
         (unittest_pb2.optional_string_extension, 'foo')],
        proto.ListFields())

  def testRepeatedListExtensions(self):
    """ListFields() reports non-empty repeated extensions as lists."""
    proto = unittest_pb2.TestAllExtensions()
    proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1)
    proto.Extensions[unittest_pb2.repeated_int32_extension].append(5)
    proto.Extensions[unittest_pb2.repeated_int32_extension].append(11)
    proto.Extensions[unittest_pb2.repeated_string_extension].append('foo')
    proto.Extensions[unittest_pb2.repeated_string_extension].append('bar')
    proto.Extensions[unittest_pb2.repeated_string_extension].append('baz')
    proto.Extensions[unittest_pb2.optional_int32_extension] = 21
    self.assertEqual(
        [(unittest_pb2.optional_int32_extension, 21),
         (unittest_pb2.repeated_int32_extension, [5, 11]),
         (unittest_pb2.repeated_fixed32_extension, [1]),
         (unittest_pb2.repeated_string_extension, ['foo', 'bar', 'baz'])],
        proto.ListFields())

  def testListFieldsAndExtensions(self):
    """ListFields() interleaves regular fields and extensions."""
    proto = unittest_pb2.TestFieldOrderings()
    test_util.SetAllFieldsAndExtensions(proto)
    self.assertEqual(
        [(proto.DESCRIPTOR.fields_by_name['my_int'], 1),
         (unittest_pb2.my_extension_int, 23),
         (proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'),
         (unittest_pb2.my_extension_string, 'bar'),
         (proto.DESCRIPTOR.fields_by_name['my_float'], 1.0)],
        proto.ListFields())

  def testDefaultValues(self):
    """Unset fields read back as their type default or declared default."""
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(0, proto.optional_int32)
    self.assertEqual(0, proto.optional_int64)
    self.assertEqual(0, proto.optional_uint32)
    self.assertEqual(0, proto.optional_uint64)
    self.assertEqual(0, proto.optional_sint32)
    self.assertEqual(0, proto.optional_sint64)
    self.assertEqual(0, proto.optional_fixed32)
    self.assertEqual(0, proto.optional_fixed64)
    self.assertEqual(0, proto.optional_sfixed32)
    self.assertEqual(0, proto.optional_sfixed64)
    self.assertEqual(0.0, proto.optional_float)
    self.assertEqual(0.0, proto.optional_double)
    self.assertEqual(False, proto.optional_bool)
    self.assertEqual('', proto.optional_string)
    self.assertEqual('', proto.optional_bytes)

    self.assertEqual(41, proto.default_int32)
    self.assertEqual(42, proto.default_int64)
    self.assertEqual(43, proto.default_uint32)
    self.assertEqual(44, proto.default_uint64)
    self.assertEqual(-45, proto.default_sint32)
    self.assertEqual(46, proto.default_sint64)
    self.assertEqual(47, proto.default_fixed32)
    self.assertEqual(48, proto.default_fixed64)
    self.assertEqual(49, proto.default_sfixed32)
    self.assertEqual(-50, proto.default_sfixed64)
    self.assertEqual(51.5, proto.default_float)
    self.assertEqual(52e3, proto.default_double)
    self.assertEqual(True, proto.default_bool)
    self.assertEqual('hello', proto.default_string)
    self.assertEqual('world', proto.default_bytes)
    self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum)
    self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum)
    self.assertEqual(unittest_import_pb2.IMPORT_BAR,
                     proto.default_import_enum)

  def testHasFieldWithUnknownFieldName(self):
    """HasField() rejects a name that is not a field."""
    proto = unittest_pb2.TestAllTypes()
    self.assertRaises(ValueError, proto.HasField, 'nonexistent_field')

  def testClearFieldWithUnknownFieldName(self):
    """ClearField() rejects a name that is not a field."""
    proto = unittest_pb2.TestAllTypes()
    self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field')

  def testDisallowedAssignments(self):
    """Repeated fields, singular composites and unknown names are read-only."""
    # It's illegal to assign values directly to repeated fields
    # or to nonrepeated composite fields.  Ensure that this fails.
    proto = unittest_pb2.TestAllTypes()
    # Repeated fields.
    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10)
    # Lists shouldn't work, either.
    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10])
    # Composite fields.
    self.assertRaises(AttributeError, setattr, proto,
                      'optional_nested_message', 23)
    # proto.nonexistent_field = 23 should fail as well.
    self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23)

  # TODO(robinson): Add type-safety check for enums.
  def testSingleScalarTypeSafety(self):
    """Singular scalar fields reject values of the wrong type."""
    proto = unittest_pb2.TestAllTypes()
    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1)
    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo')
    self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
    self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)

  def testSingleScalarBoundsChecking(self):
    """Integer-valued fields accept exactly their type's value range."""

    def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
      """Accepts both bounds of field_name and rejects one past each."""
      pb = unittest_pb2.TestAllTypes()
      setattr(pb, field_name, expected_min)
      setattr(pb, field_name, expected_max)
      self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1)
      self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1)

    TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1)
    TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
    TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
    TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)
    TestMinAndMaxIntegers('optional_nested_enum', -(1 << 31), (1 << 31) - 1)

  def testRepeatedScalarTypeSafety(self):
    """Repeated scalar fields reject elements of the wrong type."""
    proto = unittest_pb2.TestAllTypes()
    self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
    self.assertRaises(TypeError, proto.repeated_int32.append, 'foo')
    # Go through append() so the element type check is what gets exercised;
    # calling the container itself (proto.repeated_string(10)) would not
    # touch the type checker at all.
    self.assertRaises(TypeError, proto.repeated_string.append, 10)
    self.assertRaises(TypeError, proto.repeated_bytes.append, 10)

    proto.repeated_int32.append(10)
    proto.repeated_int32[0] = 23
    self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
    self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')

  def testSingleScalarGettersAndSetters(self):
    """A singular scalar reads back what was written."""
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(0, proto.optional_int32)
    proto.optional_int32 = 1
    self.assertEqual(1, proto.optional_int32)
    # TODO(robinson): Test all other scalar field types.

  def testSingleScalarClearField(self):
    """ClearField() restores the default and unsets the "has" bit."""
    proto = unittest_pb2.TestAllTypes()
    # Should be allowed to clear something that's not there (a no-op).
    proto.ClearField('optional_int32')
    proto.optional_int32 = 1
    self.assertTrue(proto.HasField('optional_int32'))
    proto.ClearField('optional_int32')
    self.assertEqual(0, proto.optional_int32)
    self.assertFalse(proto.HasField('optional_int32'))
    # TODO(robinson): Test all other scalar field types.

  def testEnums(self):
    """Nested enum values are reachable via both instance and class."""
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(1, proto.FOO)
    self.assertEqual(1, unittest_pb2.TestAllTypes.FOO)
    self.assertEqual(2, proto.BAR)
    self.assertEqual(2, unittest_pb2.TestAllTypes.BAR)
    self.assertEqual(3, proto.BAZ)
    self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ)

  def testRepeatedScalars(self):
    """Repeated scalars support len, truthiness, indexing, iteration, clear."""
    proto = unittest_pb2.TestAllTypes()

    self.assertFalse(proto.repeated_int32)
    self.assertEqual(0, len(proto.repeated_int32))
    proto.repeated_int32.append(5)
    proto.repeated_int32.append(10)
    self.assertTrue(proto.repeated_int32)
    self.assertEqual(2, len(proto.repeated_int32))

    self.assertEqual([5, 10], proto.repeated_int32)
    self.assertEqual(5, proto.repeated_int32[0])
    self.assertEqual(10, proto.repeated_int32[-1])
    # Test out-of-bounds indices.
    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234)
    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234)
    # Test incorrect types passed to __getitem__.
    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo')
    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None)

    # Test that we can use the field as an iterator.
    result = []
    for i in proto.repeated_int32:
      result.append(i)
    self.assertEqual([5, 10], result)

    # Test clearing.
    proto.ClearField('repeated_int32')
    self.assertFalse(proto.repeated_int32)
    self.assertEqual(0, len(proto.repeated_int32))

  def testRepeatedComposites(self):
    """Repeated messages support add(), indexing, iteration and clearing."""
    proto = unittest_pb2.TestAllTypes()
    self.assertFalse(proto.repeated_nested_message)
    self.assertEqual(0, len(proto.repeated_nested_message))
    m0 = proto.repeated_nested_message.add()
    m1 = proto.repeated_nested_message.add()
    self.assertTrue(proto.repeated_nested_message)
    self.assertEqual(2, len(proto.repeated_nested_message))
    self.assertTrue(m0 is proto.repeated_nested_message[0])
    self.assertTrue(m1 is proto.repeated_nested_message[1])
    self.assertTrue(isinstance(m0, unittest_pb2.TestAllTypes.NestedMessage))

    # Test out-of-bounds indices.
    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
                      1234)
    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
                      -1234)

    # Test incorrect types passed to __getitem__.
    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
                      'foo')
    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
                      None)

    # Test that we can use the field as an iterator.
    result = []
    for i in proto.repeated_nested_message:
      result.append(i)
    self.assertEqual(2, len(result))
    self.assertTrue(m0 is result[0])
    self.assertTrue(m1 is result[1])

    # Test clearing.
    proto.ClearField('repeated_nested_message')
    self.assertFalse(proto.repeated_nested_message)
    self.assertEqual(0, len(proto.repeated_nested_message))

  def testHandWrittenReflection(self):
    """A hand-built Descriptor plus the metaclass yields a usable class."""
    # TODO(robinson): We probably need a better way to specify
    # protocol types by hand.  But then again, this isn't something
    # we expect many people to do.  Hmm.
    FieldDescriptor = descriptor.FieldDescriptor
    foo_field_descriptor = FieldDescriptor(
        name='foo_field', full_name='MyProto.foo_field',
        index=0, number=1, type=FieldDescriptor.TYPE_INT64,
        cpp_type=FieldDescriptor.CPPTYPE_INT64,
        label=FieldDescriptor.LABEL_OPTIONAL, default_value=0,
        containing_type=None, message_type=None, enum_type=None,
        is_extension=False, extension_scope=None,
        options=descriptor_pb2.FieldOptions())
    mydescriptor = descriptor.Descriptor(
        name='MyProto', full_name='MyProto', filename='ignored',
        containing_type=None, nested_types=[], enum_types=[],
        fields=[foo_field_descriptor], extensions=[],
        options=descriptor_pb2.MessageOptions())

    class MyProtoClass(message.Message):
      DESCRIPTOR = mydescriptor
      __metaclass__ = reflection.GeneratedProtocolMessageType

    myproto_instance = MyProtoClass()
    self.assertEqual(0, myproto_instance.foo_field)
    self.assertFalse(myproto_instance.HasField('foo_field'))
    myproto_instance.foo_field = 23
    self.assertEqual(23, myproto_instance.foo_field)
    self.assertTrue(myproto_instance.HasField('foo_field'))

  def testTopLevelExtensionsForOptionalScalar(self):
    """Optional scalar extensions behave like optional scalar fields."""
    extendee_proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.optional_int32_extension
    self.assertFalse(extendee_proto.HasExtension(extension))
    self.assertEqual(0, extendee_proto.Extensions[extension])
    # As with normal scalar fields, just doing a read doesn't actually set the
    # "has" bit.
    self.assertFalse(extendee_proto.HasExtension(extension))
    # Actually set the thing.
    extendee_proto.Extensions[extension] = 23
    self.assertEqual(23, extendee_proto.Extensions[extension])
    self.assertTrue(extendee_proto.HasExtension(extension))
    # Ensure that clearing works as well.
    extendee_proto.ClearExtension(extension)
    self.assertEqual(0, extendee_proto.Extensions[extension])
    self.assertFalse(extendee_proto.HasExtension(extension))

  def testTopLevelExtensionsForRepeatedScalar(self):
    """Repeated scalar extensions append, clear, and refuse assignment."""
    extendee_proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.repeated_string_extension
    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
    extendee_proto.Extensions[extension].append('foo')
    self.assertEqual(['foo'], extendee_proto.Extensions[extension])
    string_list = extendee_proto.Extensions[extension]
    extendee_proto.ClearExtension(extension)
    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
    self.assertTrue(string_list is not extendee_proto.Extensions[extension])
    # Shouldn't be allowed to do Extensions[extension] = 'a'
    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
                      extension, 'a')

  def testTopLevelExtensionsForOptionalMessage(self):
    """Optional message extensions track "has" bits and detach on clear."""
    extendee_proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.optional_foreign_message_extension
    self.assertFalse(extendee_proto.HasExtension(extension))
    self.assertEqual(0, extendee_proto.Extensions[extension].c)
    # As with normal (non-extension) fields, merely reading from the
    # thing shouldn't set the "has" bit.
    self.assertFalse(extendee_proto.HasExtension(extension))
    extendee_proto.Extensions[extension].c = 23
    self.assertEqual(23, extendee_proto.Extensions[extension].c)
    self.assertTrue(extendee_proto.HasExtension(extension))
    # Save a reference here.
    foreign_message = extendee_proto.Extensions[extension]
    extendee_proto.ClearExtension(extension)
    self.assertTrue(foreign_message is not extendee_proto.Extensions[extension])
    # Setting a field on foreign_message now shouldn't set
    # any "has" bits on extendee_proto.
    foreign_message.c = 42
    self.assertEqual(42, foreign_message.c)
    self.assertTrue(foreign_message.HasField('c'))
    self.assertFalse(extendee_proto.HasExtension(extension))
    # Shouldn't be allowed to do Extensions[extension] = 'a'
    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
                      extension, 'a')

  def testTopLevelExtensionsForRepeatedMessage(self):
    """Repeated message extensions add, clear, and refuse assignment."""
    extendee_proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.repeatedgroup_extension
    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
    group = extendee_proto.Extensions[extension].add()
    group.a = 23
    self.assertEqual(23, extendee_proto.Extensions[extension][0].a)
    group.a = 42
    self.assertEqual(42, extendee_proto.Extensions[extension][0].a)
    group_list = extendee_proto.Extensions[extension]
    extendee_proto.ClearExtension(extension)
    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
    self.assertTrue(group_list is not extendee_proto.Extensions[extension])
    # Shouldn't be allowed to do Extensions[extension] = 'a'
    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
                      extension, 'a')

  def testNestedExtensions(self):
    """Extensions declared inside a message scope work like top-level ones."""
    extendee_proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.TestRequired.single

    # We just test the non-repeated case.
    self.assertFalse(extendee_proto.HasExtension(extension))
    required = extendee_proto.Extensions[extension]
    self.assertEqual(0, required.a)
    self.assertFalse(extendee_proto.HasExtension(extension))
    required.a = 23
    self.assertEqual(23, extendee_proto.Extensions[extension].a)
    self.assertTrue(extendee_proto.HasExtension(extension))
    extendee_proto.ClearExtension(extension)
    self.assertTrue(required is not extendee_proto.Extensions[extension])
    self.assertFalse(extendee_proto.HasExtension(extension))

  def testHasBitsForAncestorsOfExtendedMessage(self):
    """Mutating an extension of a submessage marks the submessage as set.

    If message A directly contains message B, and a.HasField('b') is
    currently False, then mutating any extension in B should change
    a.HasField('b') to True (and so on up the object tree).
    """
    # Optional scalar extension.
    toplevel = more_extensions_pb2.TopLevelMessage()
    self.assertFalse(toplevel.HasField('submessage'))
    self.assertEqual(0, toplevel.submessage.Extensions[
        more_extensions_pb2.optional_int_extension])
    self.assertFalse(toplevel.HasField('submessage'))
    toplevel.submessage.Extensions[
        more_extensions_pb2.optional_int_extension] = 23
    self.assertEqual(23, toplevel.submessage.Extensions[
        more_extensions_pb2.optional_int_extension])
    self.assertTrue(toplevel.HasField('submessage'))

    # Repeated scalar extension.
    toplevel = more_extensions_pb2.TopLevelMessage()
    self.assertFalse(toplevel.HasField('submessage'))
    self.assertEqual([], toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_int_extension])
    self.assertFalse(toplevel.HasField('submessage'))
    toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_int_extension].append(23)
    self.assertEqual([23], toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_int_extension])
    self.assertTrue(toplevel.HasField('submessage'))

    # Optional message extension.
    toplevel = more_extensions_pb2.TopLevelMessage()
    self.assertFalse(toplevel.HasField('submessage'))
    self.assertEqual(0, toplevel.submessage.Extensions[
        more_extensions_pb2.optional_message_extension].foreign_message_int)
    self.assertFalse(toplevel.HasField('submessage'))
    toplevel.submessage.Extensions[
        more_extensions_pb2.optional_message_extension].foreign_message_int = 23
    self.assertEqual(23, toplevel.submessage.Extensions[
        more_extensions_pb2.optional_message_extension].foreign_message_int)
    self.assertTrue(toplevel.HasField('submessage'))

    # Repeated message extension.
    toplevel = more_extensions_pb2.TopLevelMessage()
    self.assertFalse(toplevel.HasField('submessage'))
    self.assertEqual(0, len(toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_message_extension]))
    self.assertFalse(toplevel.HasField('submessage'))
    foreign = toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_message_extension].add()
    self.assertTrue(foreign is toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_message_extension][0])
    self.assertTrue(toplevel.HasField('submessage'))

  def testDisconnectionAfterClearingEmptyMessage(self):
    """ClearExtension() detaches an extension message that was only read."""
    toplevel = more_extensions_pb2.TopLevelMessage()
    extendee_proto = toplevel.submessage
    extension = more_extensions_pb2.optional_message_extension
    extension_proto = extendee_proto.Extensions[extension]
    extendee_proto.ClearExtension(extension)
    extension_proto.foreign_message_int = 23

    self.assertFalse(toplevel.HasField('submessage'))
    self.assertTrue(extension_proto is not extendee_proto.Extensions[extension])

  def testExtensionFailureModes(self):
    """Invalid or foreign extension handles raise KeyError."""
    extendee_proto = unittest_pb2.TestAllExtensions()

    # Try non-extension-handle arguments to HasExtension(),
    # ClearExtension(), and Extensions[]...
    self.assertRaises(KeyError, extendee_proto.HasExtension, 1234)
    self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234)
    self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234)
    self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5)

    # Try something that *is* an extension handle, just not for
    # this message...
    unknown_handle = more_extensions_pb2.optional_int_extension
    self.assertRaises(KeyError, extendee_proto.HasExtension,
                      unknown_handle)
    self.assertRaises(KeyError, extendee_proto.ClearExtension,
                      unknown_handle)
    self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__,
                      unknown_handle)
    self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__,
                      unknown_handle, 5)

    # Try calling HasExtension() with a valid handle, but for a
    # *repeated* field.  (Just as with non-extension repeated
    # fields, Has*() isn't supported for extension repeated fields.)
    self.assertRaises(KeyError, extendee_proto.HasExtension,
                      unittest_pb2.repeated_string_extension)

  def testCopyFrom(self):
    # TODO(robinson): Implement.
    pass

  def testClear(self):
    """Clear() resets both regular fields and extensions."""
    proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(proto)
    # Clear the message.
    proto.Clear()
    self.assertEqual(0, proto.ByteSize())
    empty_proto = unittest_pb2.TestAllTypes()
    self.assertEqual(empty_proto, proto)

    # Test if extensions which were set are cleared.
    proto = unittest_pb2.TestAllExtensions()
    test_util.SetAllExtensions(proto)
    # Clear the message.
    proto.Clear()
    self.assertEqual(0, proto.ByteSize())
    empty_proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(empty_proto, proto)

  def testIsInitialized(self):
    """IsInitialized() tracks required fields, including nested ones."""
    # Trivial cases - all optional fields and extensions.
    proto = unittest_pb2.TestAllTypes()
    self.assertTrue(proto.IsInitialized())
    proto = unittest_pb2.TestAllExtensions()
    self.assertTrue(proto.IsInitialized())

    # The case of uninitialized required fields.
    proto = unittest_pb2.TestRequired()
    self.assertFalse(proto.IsInitialized())
    proto.a = proto.b = proto.c = 2
    self.assertTrue(proto.IsInitialized())

    # The case of uninitialized submessage.
    proto = unittest_pb2.TestRequiredForeign()
    self.assertTrue(proto.IsInitialized())
    proto.optional_message.a = 1
    self.assertFalse(proto.IsInitialized())
    proto.optional_message.b = 0
    proto.optional_message.c = 0
    self.assertTrue(proto.IsInitialized())

    # Uninitialized repeated submessage.
    message1 = proto.repeated_message.add()
    self.assertFalse(proto.IsInitialized())
    message1.a = message1.b = message1.c = 0
    self.assertTrue(proto.IsInitialized())

    # Uninitialized repeated group in an extension.
    proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.TestRequired.multi
    message1 = proto.Extensions[extension].add()
    message2 = proto.Extensions[extension].add()
    self.assertFalse(proto.IsInitialized())
    message1.a = 1
    message1.b = 1
    message1.c = 1
    self.assertFalse(proto.IsInitialized())
    message2.a = 2
    message2.b = 2
    message2.c = 2
    self.assertTrue(proto.IsInitialized())

    # Uninitialized nonrepeated message in an extension.
    proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.TestRequired.single
    proto.Extensions[extension].a = 1
    self.assertFalse(proto.IsInitialized())
    proto.Extensions[extension].b = 2
    proto.Extensions[extension].c = 3
    self.assertTrue(proto.IsInitialized())
734 # Since we had so many tests for protocol buffer equality, we broke these out
735 # into separate TestCase classes.
class TestAllTypesEqualityTest(unittest.TestCase):

  """Equality tests on default-constructed TestAllTypes messages."""

  def setUp(self):
    self.first_proto, self.second_proto = (unittest_pb2.TestAllTypes(),
                                           unittest_pb2.TestAllTypes())

  def testSelfEquality(self):
    # A message must compare equal to itself.
    proto = self.first_proto
    self.assertEqual(proto, proto)

  def testEmptyProtosEqual(self):
    # Two freshly constructed messages with nothing set are equal.
    self.assertEqual(self.first_proto, self.second_proto)
class FullProtosEqualityTest(unittest.TestCase):

  """Equality tests that start from two identical, fully-populated protos."""

  def setUp(self):
    # Build both messages first, then populate them in order.
    self.first_proto, self.second_proto = [
        unittest_pb2.TestAllTypes() for _ in range(2)]
    for proto in (self.first_proto, self.second_proto):
      test_util.SetAllFields(proto)

  def testAllFieldsFilledEquality(self):
    self.assertEqual(self.first_proto, self.second_proto)

  def testNonRepeatedScalar(self):
    first, second = self.first_proto, self.second_proto
    # Changing a nonrepeated scalar's value breaks equality...
    first.optional_int32 += 1
    self.assertNotEqual(first, second)
    # ...and so does clearing the field outright.
    first.ClearField('optional_int32')
    self.assertNotEqual(first, second)

  def testNonRepeatedComposite(self):
    first, second = self.first_proto, self.second_proto
    # Mutate, then restore, a field of a nonrepeated submessage.
    first.optional_nested_message.bb += 1
    self.assertNotEqual(first, second)
    first.optional_nested_message.bb -= 1
    self.assertEqual(first, second)
    # Clearing a field inside the submessage is also a difference.
    first.optional_nested_message.ClearField('bb')
    self.assertNotEqual(first, second)
    first.optional_nested_message.bb = second.optional_nested_message.bb
    self.assertEqual(first, second)
    # Dropping the submessage entirely is a difference too.
    first.ClearField('optional_nested_message')
    self.assertNotEqual(first, second)

  def testRepeatedScalar(self):
    first, second = self.first_proto, self.second_proto
    # Appending to, or clearing, a repeated scalar breaks equality.
    first.repeated_int32.append(5)
    self.assertNotEqual(first, second)
    first.ClearField('repeated_int32')
    self.assertNotEqual(first, second)

  def testRepeatedComposite(self):
    first, second = self.first_proto, self.second_proto
    # Mutate, then restore, a value inside a repeated composite field.
    first.repeated_nested_message[0].bb += 1
    self.assertNotEqual(first, second)
    first.repeated_nested_message[0].bb -= 1
    self.assertEqual(first, second)
    # An extra element on one side only is a difference...
    first.repeated_nested_message.add()
    self.assertNotEqual(first, second)
    # ...that disappears once the other side gains a matching element.
    second.repeated_nested_message.add()
    self.assertEqual(first, second)

  def testNonRepeatedScalarHasBits(self):
    # "Has" bits are compared, not just values: unset != explicitly set 0.
    first, second = self.first_proto, self.second_proto
    first.ClearField('optional_int32')
    second.optional_int32 = 0
    self.assertNotEqual(first, second)

  def testNonRepeatedCompositeHasBits(self):
    # "Has" bits on a nonrepeated composite field are compared as well.
    first, second = self.first_proto, self.second_proto
    first.ClearField('optional_nested_message')
    second.optional_nested_message.ClearField('bb')
    self.assertNotEqual(first, second)
    # TODO(robinson): Replace next two lines with method
    # to set the "has" bit without changing the value,
    # if/when such a method exists.
    first.optional_nested_message.bb = 0
    first.optional_nested_message.ClearField('bb')
    self.assertEqual(first, second)
class ExtensionEqualityTest(unittest.TestCase):

  """Equality tests for messages that carry extensions."""

  def testExtensionEquality(self):
    int32_ext = unittest_pb2.optional_int32_extension
    first = unittest_pb2.TestAllExtensions()
    second = unittest_pb2.TestAllExtensions()

    # Empty == empty; populating only one side breaks equality until the
    # other side is populated the same way.
    self.assertEqual(first, second)
    test_util.SetAllExtensions(first)
    self.assertNotEqual(first, second)
    test_util.SetAllExtensions(second)
    self.assertEqual(first, second)

    # Extension values are compared: bump one, then restore it.
    first.Extensions[int32_ext] += 1
    self.assertNotEqual(first, second)
    first.Extensions[int32_ext] -= 1
    self.assertEqual(first, second)

    # "Has" bits are compared too: a cleared extension differs from one that
    # was explicitly set to 0.
    first.ClearExtension(int32_ext)
    second.Extensions[int32_ext] = 0
    self.assertNotEqual(first, second)
    first.Extensions[int32_ext] = 0
    self.assertEqual(first, second)

    # Merely reading an unset extension (which yields 0) must not make two
    # otherwise-empty messages compare unequal.
    first = unittest_pb2.TestAllExtensions()
    second = unittest_pb2.TestAllExtensions()
    self.assertEqual(0, first.Extensions[int32_ext])
    self.assertEqual(first, second)
class MutualRecursionEqualityTest(unittest.TestCase):

  """Equality over mutually recursive message types."""

  def testEqualityWithMutualRecursion(self):
    lhs = unittest_pb2.TestMutualRecursionA()
    rhs = unittest_pb2.TestMutualRecursionA()
    self.assertEqual(lhs, rhs)
    # A value set several levels deep (via bb.a.bb) breaks equality until the
    # same value is set at the same path on the other message.
    lhs.bb.a.bb.optional_int32 = 23
    self.assertNotEqual(lhs, rhs)
    rhs.bb.a.bb.optional_int32 = 23
    self.assertEqual(lhs, rhs)
class ByteSizeTest(unittest.TestCase):

  """Tests for ByteSize(), including invalidation of cached sizes."""

  def setUp(self):
    self.proto = unittest_pb2.TestAllTypes()
    self.extended_proto = more_extensions_pb2.ExtendedMessage()

  def Size(self):
    """Returns the serialized size, in bytes, of self.proto."""
    return self.proto.ByteSize()

  def testEmptyMessage(self):
    self.assertEqual(0, self.proto.ByteSize())

  def testVarints(self):
    def Test(i, expected_varint_size):
      """Checks the size of self.proto with only optional_int64 set to i."""
      self.proto.Clear()
      self.proto.optional_int64 = i
      # Add one byte for optional_int64's tag, which fits in a single byte.
      self.assertEqual(expected_varint_size + 1, self.Size())
    Test(0, 1)
    Test(1, 1)
    # Each varint byte carries 7 payload bits, so (1 << (7 * n)) - 1 is the
    # largest value that fits in n bytes.  The last case, n == 9, is
    # (1 << 63) - 1, the largest int64.
    for num_bytes in range(1, 10):
      Test((1 << (7 * num_bytes)) - 1, num_bytes)
    # Negative int64 values always take the maximum of 10 bytes.
    Test(-1, 10)
    Test(-2, 10)
    Test(-(1 << 63), 10)

  def testStrings(self):
    self.proto.optional_string = ''
    # Need one byte for tag info (tag #14), and one byte for length.
    self.assertEqual(2, self.Size())

    self.proto.optional_string = 'abc'
    # Need one byte for tag info (tag #14), and one byte for length.
    self.assertEqual(2 + len(self.proto.optional_string), self.Size())

    self.proto.optional_string = 'x' * 128
    # Need one byte for tag info (tag #14), and TWO bytes for length.
    self.assertEqual(3 + len(self.proto.optional_string), self.Size())

  def testOtherNumerics(self):
    # Each case starts from a fresh message so only one field is set.
    self.proto.optional_fixed32 = 1234
    # One byte for tag and 4 bytes for fixed32.
    self.assertEqual(5, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_fixed64 = 1234
    # One byte for tag and 8 bytes for fixed64.
    self.assertEqual(9, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_float = 1.234
    # One byte for tag and 4 bytes for float.
    self.assertEqual(5, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_double = 1.234
    # One byte for tag and 8 bytes for double.
    self.assertEqual(9, self.Size())
    self.proto = unittest_pb2.TestAllTypes()

    self.proto.optional_sint32 = 64
    # One byte for tag and 2 bytes for zig-zag-encoded 64.
    self.assertEqual(3, self.Size())

  def testComposites(self):
    # 3 bytes.
    self.proto.optional_nested_message.bb = (1 << 14)
    # Plus one byte for bb tag.
    # Plus 1 byte for optional_nested_message serialized size.
    # Plus two bytes for optional_nested_message tag.
    self.assertEqual(3 + 1 + 1 + 2, self.Size())

  def testGroups(self):
    # 4 bytes.
    self.proto.optionalgroup.a = (1 << 21)
    # Plus two bytes for |a| tag.
    # Plus 2 * two bytes for START_GROUP and END_GROUP tags.
    self.assertEqual(4 + 2 + 2*2, self.Size())

  def testRepeatedScalars(self):
    self.proto.repeated_int32.append(10)  # 1 byte.
    self.proto.repeated_int32.append(128)  # 2 bytes.
    # Also need 2 bytes for each entry for tag.
    self.assertEqual(1 + 2 + 2*2, self.Size())

  def testRepeatedComposites(self):
    # Empty message.  2 bytes tag plus 1 byte length.
    self.proto.repeated_nested_message.add()
    # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
    nested_message = self.proto.repeated_nested_message.add()
    nested_message.bb = 7
    self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())

  def testRepeatedGroups(self):
    # 2-byte START_GROUP plus 2-byte END_GROUP.
    self.proto.repeatedgroup.add()
    # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a|
    # plus 2-byte END_GROUP.
    group = self.proto.repeatedgroup.add()
    group.a = 7
    self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size())

  def testExtensions(self):
    proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(0, proto.ByteSize())
    extension = unittest_pb2.optional_int32_extension  # Field #1, 1 byte.
    proto.Extensions[extension] = 23
    # 1 byte for tag, 1 byte for value.
    self.assertEqual(2, proto.ByteSize())

  def testCacheInvalidationForNonrepeatedScalar(self):
    # Test non-extension.
    self.proto.optional_int32 = 1
    self.assertEqual(2, self.proto.ByteSize())
    self.proto.optional_int32 = 128
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.ClearField('optional_int32')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.optional_int_extension
    self.extended_proto.Extensions[extension] = 1
    self.assertEqual(2, self.extended_proto.ByteSize())
    self.extended_proto.Extensions[extension] = 128
    self.assertEqual(3, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForRepeatedScalar(self):
    # Test non-extension.
    self.proto.repeated_int32.append(1)
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.repeated_int32.append(1)
    self.assertEqual(6, self.proto.ByteSize())
    # Overwriting an element in place must also refresh the cached size.
    self.proto.repeated_int32[1] = 128
    self.assertEqual(7, self.proto.ByteSize())
    self.proto.ClearField('repeated_int32')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.repeated_int_extension
    repeated = self.extended_proto.Extensions[extension]
    repeated.append(1)
    self.assertEqual(2, self.extended_proto.ByteSize())
    repeated.append(1)
    self.assertEqual(4, self.extended_proto.ByteSize())
    repeated[1] = 128
    self.assertEqual(5, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForNonrepeatedMessage(self):
    # Test non-extension.
    self.proto.optional_foreign_message.c = 1
    self.assertEqual(5, self.proto.ByteSize())
    self.proto.optional_foreign_message.c = 128
    self.assertEqual(6, self.proto.ByteSize())
    self.proto.optional_foreign_message.ClearField('c')
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.ClearField('optional_foreign_message')
    self.assertEqual(0, self.proto.ByteSize())
    # A child detached by ClearField() must no longer affect the parent's
    # size, even if it is mutated afterwards.
    child = self.proto.optional_foreign_message
    self.proto.ClearField('optional_foreign_message')
    child.c = 128
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.optional_message_extension
    child = self.extended_proto.Extensions[extension]
    # Merely fetching the extension's submessage does not make it present.
    self.assertEqual(0, self.extended_proto.ByteSize())
    child.foreign_message_int = 1
    self.assertEqual(4, self.extended_proto.ByteSize())
    child.foreign_message_int = 128
    self.assertEqual(5, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())

  def testCacheInvalidationForRepeatedMessage(self):
    # Test non-extension.
    child0 = self.proto.repeated_foreign_message.add()
    self.assertEqual(3, self.proto.ByteSize())
    self.proto.repeated_foreign_message.add()
    self.assertEqual(6, self.proto.ByteSize())
    # Mutating an existing element must propagate to the parent's size.
    child0.c = 1
    self.assertEqual(8, self.proto.ByteSize())
    self.proto.ClearField('repeated_foreign_message')
    self.assertEqual(0, self.proto.ByteSize())

    # Test within extension.
    extension = more_extensions_pb2.repeated_message_extension
    child_list = self.extended_proto.Extensions[extension]
    child0 = child_list.add()
    self.assertEqual(2, self.extended_proto.ByteSize())
    child_list.add()
    self.assertEqual(4, self.extended_proto.ByteSize())
    child0.foreign_message_int = 1
    self.assertEqual(6, self.extended_proto.ByteSize())
    child0.ClearField('foreign_message_int')
    self.assertEqual(4, self.extended_proto.ByteSize())
    self.extended_proto.ClearExtension(extension)
    self.assertEqual(0, self.extended_proto.ByteSize())
1078 # TODO(robinson): We need cross-language serialization consistency tests.
1079 # Issues to be sure to cover include:
1080 # * Handling of unrecognized tags ("uninterpreted_bytes").
1081 # * Handling of MessageSets.
1082 # * Consistent ordering of tags in the wire format,
1083 # including ordering between extensions and non-extension
1084 # fields.
1085 # * Consistent serialization of negative numbers, especially
1086 # negative int32s.
1087 # * Handling of empty submessages (with and without "has"
1088 # bits set).
class SerializationTest(unittest.TestCase):

  """Round-trip, field-ordering, and MessageSet wire-format tests."""

  def testSerializeEmtpyMessage(self):
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    second_proto.MergeFromString(serialized)
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllFields(self):
    first_proto = unittest_pb2.TestAllTypes()
    second_proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(first_proto)
    serialized = first_proto.SerializeToString()
    self.assertEqual(first_proto.ByteSize(), len(serialized))
    second_proto.MergeFromString(serialized)
    self.assertEqual(first_proto, second_proto)

  def testSerializeAllExtensions(self):
    first_proto = unittest_pb2.TestAllExtensions()
    second_proto = unittest_pb2.TestAllExtensions()
    test_util.SetAllExtensions(first_proto)
    serialized = first_proto.SerializeToString()
    second_proto.MergeFromString(serialized)
    self.assertEqual(first_proto, second_proto)

  def testCanonicalSerializationOrder(self):
    proto = more_messages_pb2.OutOfOrderFields()
    # These are also their tag numbers. Even though we're setting these in
    # reverse-tag order AND they're listed in reverse tag-order in the .proto
    # file, they should nonetheless be serialized in tag order.
    proto.optional_sint32 = 5
    proto.Extensions[more_messages_pb2.optional_uint64] = 4
    proto.optional_uint32 = 3
    proto.Extensions[more_messages_pb2.optional_int64] = 2
    proto.optional_int32 = 1
    serialized = proto.SerializeToString()
    self.assertEqual(proto.ByteSize(), len(serialized))
    d = decoder.Decoder(serialized)
    ReadTag = d.ReadFieldNumberAndWireType
    self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(1, d.ReadInt32())
    self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(2, d.ReadInt64())
    self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(3, d.ReadUInt32())
    self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(4, d.ReadUInt64())
    self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag())
    self.assertEqual(5, d.ReadSInt32())

  def testCanonicalSerializationOrderSameAsCpp(self):
    # Copy of the same test we use for C++.
    proto = unittest_pb2.TestFieldOrderings()
    test_util.SetAllFieldsAndExtensions(proto)
    serialized = proto.SerializeToString()
    test_util.ExpectAllFieldsAndExtensionsInOrder(serialized)

  def testMergeFromStringWhenFieldsAlreadySet(self):
    first_proto = unittest_pb2.TestAllTypes()
    first_proto.repeated_string.append('foobar')
    first_proto.optional_int32 = 23
    first_proto.optional_nested_message.bb = 42
    serialized = first_proto.SerializeToString()

    second_proto = unittest_pb2.TestAllTypes()
    second_proto.repeated_string.append('baz')
    second_proto.optional_int32 = 100
    second_proto.optional_nested_message.bb = 999

    second_proto.MergeFromString(serialized)
    # Ensure that we append to repeated fields.
    self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string))
    # Ensure that we overwrite nonrepeated scalars.
    self.assertEqual(23, second_proto.optional_int32)
    # Ensure that we recursively call MergeFromString() on
    # submessages.
    self.assertEqual(42, second_proto.optional_nested_message.bb)

  def testMessageSetWireFormat(self):
    proto = unittest_mset_pb2.TestMessageSet()
    extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
    extension_message2 = unittest_mset_pb2.TestMessageSetExtension2
    extension1 = extension_message1.message_set_extension
    extension2 = extension_message2.message_set_extension
    proto.Extensions[extension1].i = 123
    proto.Extensions[extension2].str = 'foo'

    # Serialize using the MessageSet wire format (this is specified in the
    # .proto file).
    serialized = proto.SerializeToString()

    # Re-parse as a RawMessageSet, which does not use the MessageSet wire
    # format, to inspect the individual items.
    raw = unittest_mset_pb2.RawMessageSet()
    self.assertEqual(False,
                     raw.DESCRIPTOR.GetOptions().message_set_wire_format)
    raw.MergeFromString(serialized)
    self.assertEqual(2, len(raw.item))

    message1 = unittest_mset_pb2.TestMessageSetExtension1()
    message1.MergeFromString(raw.item[0].message)
    self.assertEqual(123, message1.i)

    message2 = unittest_mset_pb2.TestMessageSetExtension2()
    message2.MergeFromString(raw.item[1].message)
    self.assertEqual('foo', message2.str)

    # Deserialize using the MessageSet wire format.
    proto2 = unittest_mset_pb2.TestMessageSet()
    proto2.MergeFromString(serialized)
    self.assertEqual(123, proto2.Extensions[extension1].i)
    self.assertEqual('foo', proto2.Extensions[extension2].str)

    # Check byte size.
    self.assertEqual(proto2.ByteSize(), len(serialized))
    self.assertEqual(proto.ByteSize(), len(serialized))

  def testMessageSetWireFormatUnknownExtension(self):
    # Build a MessageSet, via the plain RawMessageSet layout, that holds one
    # known item followed by two items with unknown type_ids.
    raw = unittest_mset_pb2.RawMessageSet()

    # Add the known TestMessageSetExtension1 item.
    item = raw.item.add()
    item.type_id = 1545008
    known_message = unittest_mset_pb2.TestMessageSetExtension1()
    known_message.i = 12345
    item.message = known_message.SerializeToString()

    # Add a second, unknown extension.
    item = raw.item.add()
    item.type_id = 1545009
    unknown_message1 = unittest_mset_pb2.TestMessageSetExtension1()
    unknown_message1.i = 12346
    item.message = unknown_message1.SerializeToString()

    # Add another unknown extension.
    item = raw.item.add()
    item.type_id = 1545010
    unknown_message2 = unittest_mset_pb2.TestMessageSetExtension2()
    unknown_message2.str = 'foo'
    item.message = unknown_message2.SerializeToString()

    serialized = raw.SerializeToString()

    # Parse message using the message set wire format.
    proto = unittest_mset_pb2.TestMessageSet()
    proto.MergeFromString(serialized)

    # The known item must parse correctly despite the unknown ones.
    extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
    extension1 = extension_message1.message_set_extension
    self.assertEqual(12345, proto.Extensions[extension1].i)

  def testUnknownFields(self):
    proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(proto)

    serialized = proto.SerializeToString()

    # The empty message should be parsable with all of the fields
    # unknown.
    proto2 = unittest_pb2.TestEmptyMessage()

    # Parsing this message should succeed.
    proto2.MergeFromString(serialized)
class OptionsTest(unittest.TestCase):

  """Tests for message-level options exposed through DESCRIPTOR."""

  def testMessageOptions(self):
    # TestMessageSet opts into the MessageSet wire format...
    mset_options = unittest_mset_pb2.TestMessageSet().DESCRIPTOR.GetOptions()
    self.assertEqual(True, mset_options.message_set_wire_format)
    # ...while an ordinary message does not.
    plain_options = unittest_pb2.TestAllTypes().DESCRIPTOR.GetOptions()
    self.assertEqual(False, plain_options.message_set_wire_format)
class UtilityTest(unittest.TestCase):

  """Tests for private helpers in the reflection module."""

  def testImergeSorted(self):
    merge = reflection._ImergeSorted
    # Each entry is (expected merged output, positional iterables to merge).
    cases = (
        # Various types of emptiness.
        ([], ()),
        ([], ([],)),
        ([], ([], [])),
        # One nonempty list.
        ([1, 2, 3], ([1, 2, 3],)),
        ([1, 2, 3], ([1, 2, 3], [])),
        ([1, 2, 3], ([], [1, 2, 3])),
        # Merging some nonempty lists together.
        ([1, 2, 3], ([1, 3], [2])),
        ([1, 2, 3], ([1], [3], [2])),
        ([1, 2, 3], ([1], [3], [2], [])),
        # Elements repeated across component iterators.
        ([1, 2, 2, 3, 3], ([1, 2], [3], [2, 3])),
        # Elements repeated within an iterator.
        ([1, 2, 2, 3, 3], ([1, 2, 2], [3], [3])),
    )
    for expected, iterables in cases:
      self.assertEqual(expected, list(merge(*iterables)))
if __name__ == '__main__':
  # Discover and run every TestCase defined in this module.
  unittest.main(module='__main__')