# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc.
# http://code.google.com/p/protobuf/
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unittest for reflection.py, which also indirectly tests the output of the
pure-Python protocol compiler.
"""

__author__ = 'robinson@google.com (Will Robinson)'
import operator
import unittest

# TODO(robinson): When we split this test in two, only some of these imports
# will be necessary in each test.
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2
from google.protobuf import descriptor
from google.protobuf import message
from google.protobuf import reflection
from google.protobuf.internal import more_extensions_pb2
from google.protobuf.internal import more_messages_pb2
from google.protobuf.internal import wire_format
from google.protobuf.internal import test_util
from google.protobuf.internal import decoder
class RefectionTest(unittest.TestCase):
  """Tests for message classes built by reflection.GeneratedProtocolMessageType.

  Covers "has" bits, ListFields(), default values, type/bounds checking,
  repeated fields, extensions, Clear() and IsInitialized().
  """
  # NOTE: the class name keeps its historical spelling ("Refection") so that
  # any test selectors referring to it keep working.

  def testSimpleHasBits(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertFalse(proto.HasField('optional_int32'))
    self.assertEqual(0, proto.optional_int32)
    # HasField() shouldn't be true if all we've done is
    # read the default value.
    self.assertFalse(proto.HasField('optional_int32'))
    proto.optional_int32 = 1
    # Setting a value however *should* set the "has" bit.
    self.assertTrue(proto.HasField('optional_int32'))
    proto.ClearField('optional_int32')
    # And clearing that value should unset the "has" bit.
    self.assertFalse(proto.HasField('optional_int32'))

  def testHasBitsWithSinglyNestedScalar(self):
    # Helper used to test foreign messages and groups.
    #
    # composite_field_name should be the name of a non-repeated
    # composite (i.e., foreign or group) field in TestAllTypes,
    # and scalar_field_name should be the name of an integer-valued
    # scalar field within that composite.
    #
    # I never thought I'd miss C++ macros and templates so much. :(
    # This helper is semantically just:
    #
    #   assert proto.composite_field.scalar_field == 0
    #   assert not proto.composite_field.HasField('scalar_field')
    #   assert not proto.HasField('composite_field')
    #
    #   proto.composite_field.scalar_field = new_val
    #   old_composite_field = proto.composite_field
    #
    #   assert proto.composite_field.scalar_field == new_val
    #   assert proto.composite_field.HasField('scalar_field')
    #   assert proto.HasField('composite_field')
    #
    #   proto.ClearField('composite_field')
    #
    #   assert not proto.composite_field.HasField('scalar_field')
    #   assert not proto.HasField('composite_field')
    #   assert proto.composite_field.scalar_field == 0
    #
    #   # Now ensure that ClearField('composite_field') disconnected
    #   # the old field object from the object tree...
    #   assert old_composite_field is not proto.composite_field
    #   old_composite_field.scalar_field = new_val
    #   assert not proto.composite_field.HasField('scalar_field')
    #   assert not proto.HasField('composite_field')
    def TestCompositeHasBits(composite_field_name, scalar_field_name):
      proto = unittest_pb2.TestAllTypes()
      # First, check that we can get the scalar value, and see that it's the
      # default (0), but that proto.HasField('composite') and
      # proto.composite.HasField('scalar') will still return False.
      composite_field = getattr(proto, composite_field_name)
      original_scalar_value = getattr(composite_field, scalar_field_name)
      self.assertEqual(0, original_scalar_value)
      # Assert that the composite object does not "have" the scalar.
      self.assertFalse(composite_field.HasField(scalar_field_name))
      # Assert that proto does not "have" the composite field.
      self.assertFalse(proto.HasField(composite_field_name))

      # Now set the scalar within the composite field.  Ensure that the setting
      # is reflected, and that proto.HasField('composite') and
      # proto.composite.HasField('scalar') now both return True.
      new_val = 20  # Any non-default value works here.
      setattr(composite_field, scalar_field_name, new_val)
      self.assertEqual(new_val, getattr(composite_field, scalar_field_name))
      # Hold on to a reference to the current composite_field object.
      old_composite_field = composite_field
      # Assert that the has methods now return true.
      self.assertTrue(composite_field.HasField(scalar_field_name))
      self.assertTrue(proto.HasField(composite_field_name))

      # Now call the clear method...
      proto.ClearField(composite_field_name)

      # ...and ensure that the "has" bits are all back to False...
      composite_field = getattr(proto, composite_field_name)
      self.assertFalse(composite_field.HasField(scalar_field_name))
      self.assertFalse(proto.HasField(composite_field_name))
      # ...and ensure that the scalar field has returned to its default.
      self.assertEqual(0, getattr(composite_field, scalar_field_name))

      # Finally, ensure that modifications to the old composite field object
      # don't have any effect on the parent.
      #
      # (NOTE that when we clear the composite field in the parent, we actually
      # don't recursively clear down the tree.  Instead, we just disconnect the
      # cleared composite from the tree.)
      self.assertTrue(old_composite_field is not composite_field)
      setattr(old_composite_field, scalar_field_name, new_val)
      self.assertFalse(composite_field.HasField(scalar_field_name))
      self.assertFalse(proto.HasField(composite_field_name))
      self.assertEqual(0, getattr(composite_field, scalar_field_name))

    # Test simple, single-level nesting when we set a scalar.
    TestCompositeHasBits('optionalgroup', 'a')
    TestCompositeHasBits('optional_nested_message', 'bb')
    TestCompositeHasBits('optional_foreign_message', 'c')
    TestCompositeHasBits('optional_import_message', 'd')

  def testReferencesToNestedMessage(self):
    proto = unittest_pb2.TestAllTypes()
    nested = proto.optional_nested_message
    # Drop the only strong reference to the parent before touching the child.
    del proto
    # A previous version had a bug where this would raise an exception when
    # hitting a now-dead weak reference.
    nested.bb = 23

  def testDisconnectingNestedMessageBeforeSettingField(self):
    proto = unittest_pb2.TestAllTypes()
    nested = proto.optional_nested_message
    proto.ClearField('optional_nested_message')  # Should disconnect from parent
    self.assertTrue(nested is not proto.optional_nested_message)
    # Writing to the disconnected child must not mark the parent.
    nested.bb = 23
    self.assertFalse(proto.HasField('optional_nested_message'))
    self.assertEqual(0, proto.optional_nested_message.bb)

  def testHasBitsWhenModifyingRepeatedFields(self):
    # Test nesting when we add an element to a repeated field in a submessage.
    proto = unittest_pb2.TestNestedMessageHasBits()
    proto.optional_nested_message.nestedmessage_repeated_int32.append(5)
    self.assertEqual(
        [5], proto.optional_nested_message.nestedmessage_repeated_int32)
    self.assertTrue(proto.HasField('optional_nested_message'))

    # Do the same test, but with a repeated composite field within the
    # submessage.
    proto.ClearField('optional_nested_message')
    self.assertFalse(proto.HasField('optional_nested_message'))
    proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add()
    self.assertTrue(proto.HasField('optional_nested_message'))

  def testHasBitsForManyLevelsOfNesting(self):
    # Test nesting many levels deep.
    recursive_proto = unittest_pb2.TestMutualRecursionA()
    self.assertFalse(recursive_proto.HasField('bb'))
    self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32)
    self.assertFalse(recursive_proto.HasField('bb'))
    recursive_proto.bb.a.bb.a.bb.optional_int32 = 5
    self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32)
    self.assertTrue(recursive_proto.HasField('bb'))
    self.assertTrue(recursive_proto.bb.HasField('a'))
    self.assertTrue(recursive_proto.bb.a.HasField('bb'))
    self.assertTrue(recursive_proto.bb.a.bb.HasField('a'))
    self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb'))
    self.assertFalse(recursive_proto.bb.a.bb.a.bb.HasField('a'))
    self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32'))

  def testSingularListFields(self):
    proto = unittest_pb2.TestAllTypes()
    proto.optional_fixed32 = 1
    proto.optional_int32 = 5
    proto.optional_string = 'foo'
    fields = proto.DESCRIPTOR.fields_by_name
    # ListFields() is expected in field-number order, not assignment order.
    self.assertEqual(
        [(fields['optional_int32'], 5),
         (fields['optional_fixed32'], 1),
         (fields['optional_string'], 'foo')],
        proto.ListFields())

  def testRepeatedListFields(self):
    proto = unittest_pb2.TestAllTypes()
    proto.repeated_fixed32.append(1)
    proto.repeated_int32.append(5)
    proto.repeated_int32.append(11)
    proto.repeated_string.append('foo')
    proto.repeated_string.append('bar')
    proto.repeated_string.append('baz')
    proto.optional_int32 = 21
    fields = proto.DESCRIPTOR.fields_by_name
    self.assertEqual(
        [(fields['optional_int32'], 21),
         (fields['repeated_int32'], [5, 11]),
         (fields['repeated_fixed32'], [1]),
         (fields['repeated_string'], ['foo', 'bar', 'baz'])],
        proto.ListFields())

  def testSingularListExtensions(self):
    proto = unittest_pb2.TestAllExtensions()
    proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1
    proto.Extensions[unittest_pb2.optional_int32_extension] = 5
    proto.Extensions[unittest_pb2.optional_string_extension] = 'foo'
    self.assertEqual(
        [(unittest_pb2.optional_int32_extension, 5),
         (unittest_pb2.optional_fixed32_extension, 1),
         (unittest_pb2.optional_string_extension, 'foo')],
        proto.ListFields())

  def testRepeatedListExtensions(self):
    proto = unittest_pb2.TestAllExtensions()
    proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1)
    proto.Extensions[unittest_pb2.repeated_int32_extension].append(5)
    proto.Extensions[unittest_pb2.repeated_int32_extension].append(11)
    proto.Extensions[unittest_pb2.repeated_string_extension].append('foo')
    proto.Extensions[unittest_pb2.repeated_string_extension].append('bar')
    proto.Extensions[unittest_pb2.repeated_string_extension].append('baz')
    proto.Extensions[unittest_pb2.optional_int32_extension] = 21
    self.assertEqual(
        [(unittest_pb2.optional_int32_extension, 21),
         (unittest_pb2.repeated_int32_extension, [5, 11]),
         (unittest_pb2.repeated_fixed32_extension, [1]),
         (unittest_pb2.repeated_string_extension, ['foo', 'bar', 'baz'])],
        proto.ListFields())

  def testListFieldsAndExtensions(self):
    proto = unittest_pb2.TestFieldOrderings()
    test_util.SetAllFieldsAndExtensions(proto)
    fields = proto.DESCRIPTOR.fields_by_name
    # Regular fields and extensions are interleaved by field number.
    self.assertEqual(
        [(fields['my_int'], 1),
         (unittest_pb2.my_extension_int, 23),
         (fields['my_string'], 'foo'),
         (unittest_pb2.my_extension_string, 'bar'),
         (fields['my_float'], 1.0)],
        proto.ListFields())

  def testDefaultValues(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(0, proto.optional_int32)
    self.assertEqual(0, proto.optional_int64)
    self.assertEqual(0, proto.optional_uint32)
    self.assertEqual(0, proto.optional_uint64)
    self.assertEqual(0, proto.optional_sint32)
    self.assertEqual(0, proto.optional_sint64)
    self.assertEqual(0, proto.optional_fixed32)
    self.assertEqual(0, proto.optional_fixed64)
    self.assertEqual(0, proto.optional_sfixed32)
    self.assertEqual(0, proto.optional_sfixed64)
    self.assertEqual(0.0, proto.optional_float)
    self.assertEqual(0.0, proto.optional_double)
    self.assertEqual(False, proto.optional_bool)
    self.assertEqual('', proto.optional_string)
    self.assertEqual('', proto.optional_bytes)

    # Fields with explicit defaults in the .proto definition.
    self.assertEqual(41, proto.default_int32)
    self.assertEqual(42, proto.default_int64)
    self.assertEqual(43, proto.default_uint32)
    self.assertEqual(44, proto.default_uint64)
    self.assertEqual(-45, proto.default_sint32)
    self.assertEqual(46, proto.default_sint64)
    self.assertEqual(47, proto.default_fixed32)
    self.assertEqual(48, proto.default_fixed64)
    self.assertEqual(49, proto.default_sfixed32)
    self.assertEqual(-50, proto.default_sfixed64)
    self.assertEqual(51.5, proto.default_float)
    self.assertEqual(52e3, proto.default_double)
    self.assertEqual(True, proto.default_bool)
    self.assertEqual('hello', proto.default_string)
    self.assertEqual('world', proto.default_bytes)
    self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum)
    self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum)
    self.assertEqual(unittest_import_pb2.IMPORT_BAR,
                     proto.default_import_enum)

  def testHasFieldWithUnknownFieldName(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertRaises(ValueError, proto.HasField, 'nonexistent_field')

  def testClearFieldWithUnknownFieldName(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field')

  def testDisallowedAssignments(self):
    # It's illegal to assign values directly to repeated fields
    # or to nonrepeated composite fields.  Ensure that this fails.
    proto = unittest_pb2.TestAllTypes()
    # Repeated scalar field.
    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10)
    # Lists shouldn't work, either.
    self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10])
    # Nonrepeated composite field.
    self.assertRaises(AttributeError, setattr, proto,
                      'optional_nested_message', 23)
    # proto.nonexistent_field = 23 should fail as well.
    self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23)

  # TODO(robinson): Add type-safety check for enums.
  def testSingleScalarTypeSafety(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1)
    self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo')
    self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
    self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)

  def testSingleScalarBoundsChecking(self):
    def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
      """Checks that field_name accepts exactly [expected_min, expected_max]."""
      pb = unittest_pb2.TestAllTypes()
      setattr(pb, field_name, expected_min)
      setattr(pb, field_name, expected_max)
      self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1)
      self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1)

    TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1)
    TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
    TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
    TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)
    TestMinAndMaxIntegers('optional_nested_enum', -(1 << 31), (1 << 31) - 1)

  def testRepeatedScalarTypeSafety(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
    self.assertRaises(TypeError, proto.repeated_int32.append, 'foo')
    # Must go through .append(): calling the container itself would raise
    # TypeError merely because it is not callable, hiding any type-check bug.
    self.assertRaises(TypeError, proto.repeated_string.append, 10)
    self.assertRaises(TypeError, proto.repeated_bytes.append, 10)

    proto.repeated_int32.append(10)
    proto.repeated_int32[0] = 23
    self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
    self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')

  def testSingleScalarGettersAndSetters(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(0, proto.optional_int32)
    proto.optional_int32 = 1
    self.assertEqual(1, proto.optional_int32)
    # TODO(robinson): Test all other scalar field types.

  def testSingleScalarClearField(self):
    proto = unittest_pb2.TestAllTypes()
    # Should be allowed to clear something that's not there (a no-op).
    proto.ClearField('optional_int32')
    proto.optional_int32 = 1
    self.assertTrue(proto.HasField('optional_int32'))
    proto.ClearField('optional_int32')
    self.assertEqual(0, proto.optional_int32)
    self.assertFalse(proto.HasField('optional_int32'))
    # TODO(robinson): Test all other scalar field types.

  def testEnums(self):
    # Enum values are reachable from both instances and the class.
    proto = unittest_pb2.TestAllTypes()
    self.assertEqual(1, proto.FOO)
    self.assertEqual(1, unittest_pb2.TestAllTypes.FOO)
    self.assertEqual(2, proto.BAR)
    self.assertEqual(2, unittest_pb2.TestAllTypes.BAR)
    self.assertEqual(3, proto.BAZ)
    self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ)

  def testRepeatedScalars(self):
    proto = unittest_pb2.TestAllTypes()

    self.assertFalse(proto.repeated_int32)
    self.assertEqual(0, len(proto.repeated_int32))
    proto.repeated_int32.append(5)
    proto.repeated_int32.append(10)
    self.assertTrue(proto.repeated_int32)
    self.assertEqual(2, len(proto.repeated_int32))

    self.assertEqual([5, 10], proto.repeated_int32)
    self.assertEqual(5, proto.repeated_int32[0])
    self.assertEqual(10, proto.repeated_int32[-1])
    # Test out-of-bounds indices.
    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234)
    self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234)
    # Test incorrect types passed to __getitem__.
    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo')
    self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None)

    # Test that we can use the field as an iterator.
    result = []
    for i in proto.repeated_int32:
      result.append(i)
    self.assertEqual([5, 10], result)

    # Test clearing.
    proto.ClearField('repeated_int32')
    self.assertFalse(proto.repeated_int32)
    self.assertEqual(0, len(proto.repeated_int32))

  def testRepeatedComposites(self):
    proto = unittest_pb2.TestAllTypes()
    self.assertFalse(proto.repeated_nested_message)
    self.assertEqual(0, len(proto.repeated_nested_message))
    m0 = proto.repeated_nested_message.add()
    m1 = proto.repeated_nested_message.add()
    self.assertTrue(proto.repeated_nested_message)
    self.assertEqual(2, len(proto.repeated_nested_message))
    self.assertTrue(m0 is proto.repeated_nested_message[0])
    self.assertTrue(m1 is proto.repeated_nested_message[1])
    self.assertTrue(isinstance(m0, unittest_pb2.TestAllTypes.NestedMessage))

    # Test out-of-bounds indices.
    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
                      1234)
    self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
                      -1234)

    # Test incorrect types passed to __getitem__.
    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
                      'foo')
    self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
                      None)

    # Test that we can use the field as an iterator.
    result = []
    for i in proto.repeated_nested_message:
      result.append(i)
    self.assertEqual(2, len(result))
    self.assertTrue(m0 is result[0])
    self.assertTrue(m1 is result[1])

    # Test clearing.
    proto.ClearField('repeated_nested_message')
    self.assertFalse(proto.repeated_nested_message)
    self.assertEqual(0, len(proto.repeated_nested_message))

  def testHandWrittenReflection(self):
    # TODO(robinson): We probably need a better way to specify
    # protocol types by hand.  But then again, this isn't something
    # we expect many people to do.  Hmm.
    FieldDescriptor = descriptor.FieldDescriptor
    foo_field_descriptor = FieldDescriptor(
        name='foo_field', full_name='MyProto.foo_field',
        index=0, number=1, type=FieldDescriptor.TYPE_INT64,
        cpp_type=FieldDescriptor.CPPTYPE_INT64,
        label=FieldDescriptor.LABEL_OPTIONAL, default_value=0,
        containing_type=None, message_type=None, enum_type=None,
        is_extension=False, extension_scope=None,
        options=descriptor_pb2.FieldOptions())
    mydescriptor = descriptor.Descriptor(
        name='MyProto', full_name='MyProto', filename='ignored',
        containing_type=None, nested_types=[], enum_types=[],
        fields=[foo_field_descriptor], extensions=[],
        options=descriptor_pb2.MessageOptions())
    class MyProtoClass(message.Message):
      DESCRIPTOR = mydescriptor
      # NOTE(review): __metaclass__ is Python 2 syntax; Python 3 ignores this
      # attribute and would need `metaclass=` in the class statement instead.
      __metaclass__ = reflection.GeneratedProtocolMessageType
    myproto_instance = MyProtoClass()
    self.assertEqual(0, myproto_instance.foo_field)
    self.assertFalse(myproto_instance.HasField('foo_field'))
    myproto_instance.foo_field = 23
    self.assertEqual(23, myproto_instance.foo_field)
    self.assertTrue(myproto_instance.HasField('foo_field'))

  def testTopLevelExtensionsForOptionalScalar(self):
    extendee_proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.optional_int32_extension
    self.assertFalse(extendee_proto.HasExtension(extension))
    self.assertEqual(0, extendee_proto.Extensions[extension])
    # As with normal scalar fields, just doing a read doesn't actually set the
    # "has" bit.
    self.assertFalse(extendee_proto.HasExtension(extension))
    # Actually set the thing.
    extendee_proto.Extensions[extension] = 23
    self.assertEqual(23, extendee_proto.Extensions[extension])
    self.assertTrue(extendee_proto.HasExtension(extension))
    # Ensure that clearing works as well.
    extendee_proto.ClearExtension(extension)
    self.assertEqual(0, extendee_proto.Extensions[extension])
    self.assertFalse(extendee_proto.HasExtension(extension))

  def testTopLevelExtensionsForRepeatedScalar(self):
    extendee_proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.repeated_string_extension
    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
    extendee_proto.Extensions[extension].append('foo')
    self.assertEqual(['foo'], extendee_proto.Extensions[extension])
    string_list = extendee_proto.Extensions[extension]
    extendee_proto.ClearExtension(extension)
    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
    self.assertTrue(string_list is not extendee_proto.Extensions[extension])
    # Shouldn't be allowed to do Extensions[extension] = 'a'
    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
                      extension, 'a')

  def testTopLevelExtensionsForOptionalMessage(self):
    extendee_proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.optional_foreign_message_extension
    self.assertFalse(extendee_proto.HasExtension(extension))
    self.assertEqual(0, extendee_proto.Extensions[extension].c)
    # As with normal (non-extension) fields, merely reading from the
    # thing shouldn't set the "has" bit.
    self.assertFalse(extendee_proto.HasExtension(extension))
    extendee_proto.Extensions[extension].c = 23
    self.assertEqual(23, extendee_proto.Extensions[extension].c)
    self.assertTrue(extendee_proto.HasExtension(extension))
    # Save a reference here.
    foreign_message = extendee_proto.Extensions[extension]
    extendee_proto.ClearExtension(extension)
    self.assertTrue(foreign_message is not extendee_proto.Extensions[extension])
    # Setting a field on foreign_message now shouldn't set
    # any "has" bits on extendee_proto.
    foreign_message.c = 42
    self.assertEqual(42, foreign_message.c)
    self.assertTrue(foreign_message.HasField('c'))
    self.assertFalse(extendee_proto.HasExtension(extension))
    # Shouldn't be allowed to do Extensions[extension] = 'a'
    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
                      extension, 'a')

  def testTopLevelExtensionsForRepeatedMessage(self):
    extendee_proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.repeatedgroup_extension
    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
    group = extendee_proto.Extensions[extension].add()
    group.a = 23
    self.assertEqual(23, extendee_proto.Extensions[extension][0].a)
    group.a = 42
    self.assertEqual(42, extendee_proto.Extensions[extension][0].a)
    group_list = extendee_proto.Extensions[extension]
    extendee_proto.ClearExtension(extension)
    self.assertEqual(0, len(extendee_proto.Extensions[extension]))
    self.assertTrue(group_list is not extendee_proto.Extensions[extension])
    # Shouldn't be allowed to do Extensions[extension] = 'a'
    self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
                      extension, 'a')

  def testNestedExtensions(self):
    extendee_proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.TestRequired.single

    # We just test the non-repeated case.
    self.assertFalse(extendee_proto.HasExtension(extension))
    required = extendee_proto.Extensions[extension]
    self.assertEqual(0, required.a)
    self.assertFalse(extendee_proto.HasExtension(extension))
    required.a = 23
    self.assertEqual(23, extendee_proto.Extensions[extension].a)
    self.assertTrue(extendee_proto.HasExtension(extension))
    extendee_proto.ClearExtension(extension)
    self.assertTrue(required is not extendee_proto.Extensions[extension])
    self.assertFalse(extendee_proto.HasExtension(extension))

  # If message A directly contains message B, and
  # a.HasField('b') is currently False, then mutating any
  # extension in B should change a.HasField('b') to True
  # (and so on up the object tree).
  def testHasBitsForAncestorsOfExtendedMessage(self):
    # Optional scalar extension.
    toplevel = more_extensions_pb2.TopLevelMessage()
    self.assertFalse(toplevel.HasField('submessage'))
    self.assertEqual(0, toplevel.submessage.Extensions[
        more_extensions_pb2.optional_int_extension])
    self.assertFalse(toplevel.HasField('submessage'))
    toplevel.submessage.Extensions[
        more_extensions_pb2.optional_int_extension] = 23
    self.assertEqual(23, toplevel.submessage.Extensions[
        more_extensions_pb2.optional_int_extension])
    self.assertTrue(toplevel.HasField('submessage'))

    # Repeated scalar extension.
    toplevel = more_extensions_pb2.TopLevelMessage()
    self.assertFalse(toplevel.HasField('submessage'))
    self.assertEqual([], toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_int_extension])
    self.assertFalse(toplevel.HasField('submessage'))
    toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_int_extension].append(23)
    self.assertEqual([23], toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_int_extension])
    self.assertTrue(toplevel.HasField('submessage'))

    # Optional message extension.
    toplevel = more_extensions_pb2.TopLevelMessage()
    self.assertFalse(toplevel.HasField('submessage'))
    self.assertEqual(0, toplevel.submessage.Extensions[
        more_extensions_pb2.optional_message_extension].foreign_message_int)
    self.assertFalse(toplevel.HasField('submessage'))
    toplevel.submessage.Extensions[
        more_extensions_pb2.optional_message_extension].foreign_message_int = 23
    self.assertEqual(23, toplevel.submessage.Extensions[
        more_extensions_pb2.optional_message_extension].foreign_message_int)
    self.assertTrue(toplevel.HasField('submessage'))

    # Repeated message extension.
    toplevel = more_extensions_pb2.TopLevelMessage()
    self.assertFalse(toplevel.HasField('submessage'))
    self.assertEqual(0, len(toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_message_extension]))
    self.assertFalse(toplevel.HasField('submessage'))
    foreign = toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_message_extension].add()
    self.assertTrue(foreign is toplevel.submessage.Extensions[
        more_extensions_pb2.repeated_message_extension][0])
    self.assertTrue(toplevel.HasField('submessage'))

  def testDisconnectionAfterClearingEmptyMessage(self):
    toplevel = more_extensions_pb2.TopLevelMessage()
    extendee_proto = toplevel.submessage
    extension = more_extensions_pb2.optional_message_extension
    extension_proto = extendee_proto.Extensions[extension]
    extendee_proto.ClearExtension(extension)
    # The cleared (never-set) extension object is now detached, so writing
    # to it must not propagate "has" bits up to toplevel.
    extension_proto.foreign_message_int = 23

    self.assertFalse(toplevel.HasField('submessage'))
    self.assertTrue(extension_proto is not extendee_proto.Extensions[extension])

  def testExtensionFailureModes(self):
    extendee_proto = unittest_pb2.TestAllExtensions()

    # Try non-extension-handle arguments to HasExtension(),
    # ClearExtension(), and Extensions[]...
    self.assertRaises(KeyError, extendee_proto.HasExtension, 1234)
    self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234)
    self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234)
    self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5)

    # Try something that *is* an extension handle, just not for
    # this message...
    unknown_handle = more_extensions_pb2.optional_int_extension
    self.assertRaises(KeyError, extendee_proto.HasExtension,
                      unknown_handle)
    self.assertRaises(KeyError, extendee_proto.ClearExtension,
                      unknown_handle)
    self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__,
                      unknown_handle)
    self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__,
                      unknown_handle, 5)

    # Try calling HasExtension() with a valid handle, but for a
    # *repeated* field.  (Just as with non-extension repeated
    # fields, Has*() isn't supported for extension repeated fields).
    self.assertRaises(KeyError, extendee_proto.HasExtension,
                      unittest_pb2.repeated_string_extension)

  def testCopyFrom(self):
    # TODO(robinson): Implement.
    pass

  def testClear(self):
    proto = unittest_pb2.TestAllTypes()
    test_util.SetAllFields(proto)
    # Clear the message.
    proto.Clear()
    self.assertEqual(proto.ByteSize(), 0)
    empty_proto = unittest_pb2.TestAllTypes()
    self.assertEqual(proto, empty_proto)

    # Test if extensions which were set are cleared.
    proto = unittest_pb2.TestAllExtensions()
    test_util.SetAllExtensions(proto)
    # Clear the message.
    proto.Clear()
    self.assertEqual(proto.ByteSize(), 0)
    empty_proto = unittest_pb2.TestAllExtensions()
    self.assertEqual(proto, empty_proto)

  def testIsInitialized(self):
    # Trivial cases - all optional fields and extensions.
    proto = unittest_pb2.TestAllTypes()
    self.assertTrue(proto.IsInitialized())
    proto = unittest_pb2.TestAllExtensions()
    self.assertTrue(proto.IsInitialized())

    # The case of uninitialized required fields.
    proto = unittest_pb2.TestRequired()
    self.assertFalse(proto.IsInitialized())
    proto.a = proto.b = proto.c = 2
    self.assertTrue(proto.IsInitialized())

    # The case of uninitialized submessage.
    proto = unittest_pb2.TestRequiredForeign()
    self.assertTrue(proto.IsInitialized())
    proto.optional_message.a = 1
    self.assertFalse(proto.IsInitialized())
    proto.optional_message.b = 0
    proto.optional_message.c = 0
    self.assertTrue(proto.IsInitialized())

    # Uninitialized repeated submessage.
    message1 = proto.repeated_message.add()
    self.assertFalse(proto.IsInitialized())
    message1.a = message1.b = message1.c = 0
    self.assertTrue(proto.IsInitialized())

    # Uninitialized repeated group in an extension.
    proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.TestRequired.multi
    message1 = proto.Extensions[extension].add()
    message2 = proto.Extensions[extension].add()
    self.assertFalse(proto.IsInitialized())
    message1.a = 1
    message1.b = 1
    message1.c = 1
    # message2 is still missing its required fields.
    self.assertFalse(proto.IsInitialized())
    message2.a = 2
    message2.b = 2
    message2.c = 2
    self.assertTrue(proto.IsInitialized())

    # Uninitialized nonrepeated message in an extension.
    proto = unittest_pb2.TestAllExtensions()
    extension = unittest_pb2.TestRequired.single
    proto.Extensions[extension].a = 1
    self.assertFalse(proto.IsInitialized())
    proto.Extensions[extension].b = 2
    proto.Extensions[extension].c = 3
    self.assertTrue(proto.IsInitialized())
734 # Since we had so many tests for protocol buffer equality, we broke these out
735 # into separate TestCase classes.
738 class TestAllTypesEqualityTest(unittest
.TestCase
):
741 self
.first_proto
= unittest_pb2
.TestAllTypes()
742 self
.second_proto
= unittest_pb2
.TestAllTypes()
744 def testSelfEquality(self
):
745 self
.assertEqual(self
.first_proto
, self
.first_proto
)
747 def testEmptyProtosEqual(self
):
748 self
.assertEqual(self
.first_proto
, self
.second_proto
)
751 class FullProtosEqualityTest(unittest
.TestCase
):
753 """Equality tests using completely-full protos as a starting point."""
756 self
.first_proto
= unittest_pb2
.TestAllTypes()
757 self
.second_proto
= unittest_pb2
.TestAllTypes()
758 test_util
.SetAllFields(self
.first_proto
)
759 test_util
.SetAllFields(self
.second_proto
)
761 def testAllFieldsFilledEquality(self
):
762 self
.assertEqual(self
.first_proto
, self
.second_proto
)
764 def testNonRepeatedScalar(self
):
765 # Nonrepeated scalar field change should cause inequality.
766 self
.first_proto
.optional_int32
+= 1
767 self
.assertNotEqual(self
.first_proto
, self
.second_proto
)
768 # ...as should clearing a field.
769 self
.first_proto
.ClearField('optional_int32')
770 self
.assertNotEqual(self
.first_proto
, self
.second_proto
)
772 def testNonRepeatedComposite(self
):
773 # Change a nonrepeated composite field.
774 self
.first_proto
.optional_nested_message
.bb
+= 1
775 self
.assertNotEqual(self
.first_proto
, self
.second_proto
)
776 self
.first_proto
.optional_nested_message
.bb
-= 1
777 self
.assertEqual(self
.first_proto
, self
.second_proto
)
778 # Clear a field in the nested message.
779 self
.first_proto
.optional_nested_message
.ClearField('bb')
780 self
.assertNotEqual(self
.first_proto
, self
.second_proto
)
781 self
.first_proto
.optional_nested_message
.bb
= (
782 self
.second_proto
.optional_nested_message
.bb
)
783 self
.assertEqual(self
.first_proto
, self
.second_proto
)
784 # Remove the nested message entirely.
785 self
.first_proto
.ClearField('optional_nested_message')
786 self
.assertNotEqual(self
.first_proto
, self
.second_proto
)
788 def testRepeatedScalar(self
):
789 # Change a repeated scalar field.
790 self
.first_proto
.repeated_int32
.append(5)
791 self
.assertNotEqual(self
.first_proto
, self
.second_proto
)
792 self
.first_proto
.ClearField('repeated_int32')
793 self
.assertNotEqual(self
.first_proto
, self
.second_proto
)
795 def testRepeatedComposite(self
):
796 # Change value within a repeated composite field.
797 self
.first_proto
.repeated_nested_message
[0].bb
+= 1
798 self
.assertNotEqual(self
.first_proto
, self
.second_proto
)
799 self
.first_proto
.repeated_nested_message
[0].bb
-= 1
800 self
.assertEqual(self
.first_proto
, self
.second_proto
)
801 # Add a value to a repeated composite field.
802 self
.first_proto
.repeated_nested_message
.add()
803 self
.assertNotEqual(self
.first_proto
, self
.second_proto
)
804 self
.second_proto
.repeated_nested_message
.add()
805 self
.assertEqual(self
.first_proto
, self
.second_proto
)
807 def testNonRepeatedScalarHasBits(self
):
808 # Ensure that we test "has" bits as well as value for
809 # nonrepeated scalar field.
810 self
.first_proto
.ClearField('optional_int32')
811 self
.second_proto
.optional_int32
= 0
812 self
.assertNotEqual(self
.first_proto
, self
.second_proto
)
814 def testNonRepeatedCompositeHasBits(self
):
815 # Ensure that we test "has" bits as well as value for
816 # nonrepeated composite field.
817 self
.first_proto
.ClearField('optional_nested_message')
818 self
.second_proto
.optional_nested_message
.ClearField('bb')
819 self
.assertNotEqual(self
.first_proto
, self
.second_proto
)
820 # TODO(robinson): Replace next two lines with method
821 # to set the "has" bit without changing the value,
822 # if/when such a method exists.
823 self
.first_proto
.optional_nested_message
.bb
= 0
824 self
.first_proto
.optional_nested_message
.ClearField('bb')
825 self
.assertEqual(self
.first_proto
, self
.second_proto
)
828 class ExtensionEqualityTest(unittest
.TestCase
):
830 def testExtensionEquality(self
):
831 first_proto
= unittest_pb2
.TestAllExtensions()
832 second_proto
= unittest_pb2
.TestAllExtensions()
833 self
.assertEqual(first_proto
, second_proto
)
834 test_util
.SetAllExtensions(first_proto
)
835 self
.assertNotEqual(first_proto
, second_proto
)
836 test_util
.SetAllExtensions(second_proto
)
837 self
.assertEqual(first_proto
, second_proto
)
839 # Ensure that we check value equality.
840 first_proto
.Extensions
[unittest_pb2
.optional_int32_extension
] += 1
841 self
.assertNotEqual(first_proto
, second_proto
)
842 first_proto
.Extensions
[unittest_pb2
.optional_int32_extension
] -= 1
843 self
.assertEqual(first_proto
, second_proto
)
845 # Ensure that we also look at "has" bits.
846 first_proto
.ClearExtension(unittest_pb2
.optional_int32_extension
)
847 second_proto
.Extensions
[unittest_pb2
.optional_int32_extension
] = 0
848 self
.assertNotEqual(first_proto
, second_proto
)
849 first_proto
.Extensions
[unittest_pb2
.optional_int32_extension
] = 0
850 self
.assertEqual(first_proto
, second_proto
)
852 # Ensure that differences in cached values
853 # don't matter if "has" bits are both false.
854 first_proto
= unittest_pb2
.TestAllExtensions()
855 second_proto
= unittest_pb2
.TestAllExtensions()
857 0, first_proto
.Extensions
[unittest_pb2
.optional_int32_extension
])
858 self
.assertEqual(first_proto
, second_proto
)
861 class MutualRecursionEqualityTest(unittest
.TestCase
):
863 def testEqualityWithMutualRecursion(self
):
864 first_proto
= unittest_pb2
.TestMutualRecursionA()
865 second_proto
= unittest_pb2
.TestMutualRecursionA()
866 self
.assertEqual(first_proto
, second_proto
)
867 first_proto
.bb
.a
.bb
.optional_int32
= 23
868 self
.assertNotEqual(first_proto
, second_proto
)
869 second_proto
.bb
.a
.bb
.optional_int32
= 23
870 self
.assertEqual(first_proto
, second_proto
)
873 class ByteSizeTest(unittest
.TestCase
):
876 self
.proto
= unittest_pb2
.TestAllTypes()
877 self
.extended_proto
= more_extensions_pb2
.ExtendedMessage()
880 return self
.proto
.ByteSize()
882 def testEmptyMessage(self
):
883 self
.assertEqual(0, self
.proto
.ByteSize())
885 def testVarints(self
):
886 def Test(i
, expected_varint_size
):
888 self
.proto
.optional_int64
= i
889 # Add one to the varint size for the tag info
891 self
.assertEqual(expected_varint_size
+ 1, self
.Size())
894 for i
, num_bytes
in zip(range(7, 63, 7), range(1, 10000)):
895 Test((1 << i
) - 1, num_bytes
)
900 def testStrings(self
):
901 self
.proto
.optional_string
= ''
902 # Need one byte for tag info (tag #14), and one byte for length.
903 self
.assertEqual(2, self
.Size())
905 self
.proto
.optional_string
= 'abc'
906 # Need one byte for tag info (tag #14), and one byte for length.
907 self
.assertEqual(2 + len(self
.proto
.optional_string
), self
.Size())
909 self
.proto
.optional_string
= 'x' * 128
910 # Need one byte for tag info (tag #14), and TWO bytes for length.
911 self
.assertEqual(3 + len(self
.proto
.optional_string
), self
.Size())
913 def testOtherNumerics(self
):
914 self
.proto
.optional_fixed32
= 1234
915 # One byte for tag and 4 bytes for fixed32.
916 self
.assertEqual(5, self
.Size())
917 self
.proto
= unittest_pb2
.TestAllTypes()
919 self
.proto
.optional_fixed64
= 1234
920 # One byte for tag and 8 bytes for fixed64.
921 self
.assertEqual(9, self
.Size())
922 self
.proto
= unittest_pb2
.TestAllTypes()
924 self
.proto
.optional_float
= 1.234
925 # One byte for tag and 4 bytes for float.
926 self
.assertEqual(5, self
.Size())
927 self
.proto
= unittest_pb2
.TestAllTypes()
929 self
.proto
.optional_double
= 1.234
930 # One byte for tag and 8 bytes for float.
931 self
.assertEqual(9, self
.Size())
932 self
.proto
= unittest_pb2
.TestAllTypes()
934 self
.proto
.optional_sint32
= 64
935 # One byte for tag and 2 bytes for zig-zag-encoded 64.
936 self
.assertEqual(3, self
.Size())
937 self
.proto
= unittest_pb2
.TestAllTypes()
939 def testComposites(self
):
941 self
.proto
.optional_nested_message
.bb
= (1 << 14)
942 # Plus one byte for bb tag.
943 # Plus 1 byte for optional_nested_message serialized size.
944 # Plus two bytes for optional_nested_message tag.
945 self
.assertEqual(3 + 1 + 1 + 2, self
.Size())
947 def testGroups(self
):
949 self
.proto
.optionalgroup
.a
= (1 << 21)
950 # Plus two bytes for |a| tag.
951 # Plus 2 * two bytes for START_GROUP and END_GROUP tags.
952 self
.assertEqual(4 + 2 + 2*2, self
.Size())
954 def testRepeatedScalars(self
):
955 self
.proto
.repeated_int32
.append(10) # 1 byte.
956 self
.proto
.repeated_int32
.append(128) # 2 bytes.
957 # Also need 2 bytes for each entry for tag.
958 self
.assertEqual(1 + 2 + 2*2, self
.Size())
960 def testRepeatedComposites(self
):
961 # Empty message. 2 bytes tag plus 1 byte length.
962 foreign_message_0
= self
.proto
.repeated_nested_message
.add()
963 # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
964 foreign_message_1
= self
.proto
.repeated_nested_message
.add()
965 foreign_message_1
.bb
= 7
966 self
.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self
.Size())
968 def testRepeatedGroups(self
):
969 # 2-byte START_GROUP plus 2-byte END_GROUP.
970 group_0
= self
.proto
.repeatedgroup
.add()
971 # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a|
972 # plus 2-byte END_GROUP.
973 group_1
= self
.proto
.repeatedgroup
.add()
975 self
.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self
.Size())
977 def testExtensions(self
):
978 proto
= unittest_pb2
.TestAllExtensions()
979 self
.assertEqual(0, proto
.ByteSize())
980 extension
= unittest_pb2
.optional_int32_extension
# Field #1, 1 byte.
981 proto
.Extensions
[extension
] = 23
982 # 1 byte for tag, 1 byte for value.
983 self
.assertEqual(2, proto
.ByteSize())
985 def testCacheInvalidationForNonrepeatedScalar(self
):
986 # Test non-extension.
987 self
.proto
.optional_int32
= 1
988 self
.assertEqual(2, self
.proto
.ByteSize())
989 self
.proto
.optional_int32
= 128
990 self
.assertEqual(3, self
.proto
.ByteSize())
991 self
.proto
.ClearField('optional_int32')
992 self
.assertEqual(0, self
.proto
.ByteSize())
994 # Test within extension.
995 extension
= more_extensions_pb2
.optional_int_extension
996 self
.extended_proto
.Extensions
[extension
] = 1
997 self
.assertEqual(2, self
.extended_proto
.ByteSize())
998 self
.extended_proto
.Extensions
[extension
] = 128
999 self
.assertEqual(3, self
.extended_proto
.ByteSize())
1000 self
.extended_proto
.ClearExtension(extension
)
1001 self
.assertEqual(0, self
.extended_proto
.ByteSize())
1003 def testCacheInvalidationForRepeatedScalar(self
):
1004 # Test non-extension.
1005 self
.proto
.repeated_int32
.append(1)
1006 self
.assertEqual(3, self
.proto
.ByteSize())
1007 self
.proto
.repeated_int32
.append(1)
1008 self
.assertEqual(6, self
.proto
.ByteSize())
1009 self
.proto
.repeated_int32
[1] = 128
1010 self
.assertEqual(7, self
.proto
.ByteSize())
1011 self
.proto
.ClearField('repeated_int32')
1012 self
.assertEqual(0, self
.proto
.ByteSize())
1014 # Test within extension.
1015 extension
= more_extensions_pb2
.repeated_int_extension
1016 repeated
= self
.extended_proto
.Extensions
[extension
]
1018 self
.assertEqual(2, self
.extended_proto
.ByteSize())
1020 self
.assertEqual(4, self
.extended_proto
.ByteSize())
1022 self
.assertEqual(5, self
.extended_proto
.ByteSize())
1023 self
.extended_proto
.ClearExtension(extension
)
1024 self
.assertEqual(0, self
.extended_proto
.ByteSize())
1026 def testCacheInvalidationForNonrepeatedMessage(self
):
1027 # Test non-extension.
1028 self
.proto
.optional_foreign_message
.c
= 1
1029 self
.assertEqual(5, self
.proto
.ByteSize())
1030 self
.proto
.optional_foreign_message
.c
= 128
1031 self
.assertEqual(6, self
.proto
.ByteSize())
1032 self
.proto
.optional_foreign_message
.ClearField('c')
1033 self
.assertEqual(3, self
.proto
.ByteSize())
1034 self
.proto
.ClearField('optional_foreign_message')
1035 self
.assertEqual(0, self
.proto
.ByteSize())
1036 child
= self
.proto
.optional_foreign_message
1037 self
.proto
.ClearField('optional_foreign_message')
1039 self
.assertEqual(0, self
.proto
.ByteSize())
1041 # Test within extension.
1042 extension
= more_extensions_pb2
.optional_message_extension
1043 child
= self
.extended_proto
.Extensions
[extension
]
1044 self
.assertEqual(0, self
.extended_proto
.ByteSize())
1045 child
.foreign_message_int
= 1
1046 self
.assertEqual(4, self
.extended_proto
.ByteSize())
1047 child
.foreign_message_int
= 128
1048 self
.assertEqual(5, self
.extended_proto
.ByteSize())
1049 self
.extended_proto
.ClearExtension(extension
)
1050 self
.assertEqual(0, self
.extended_proto
.ByteSize())
1052 def testCacheInvalidationForRepeatedMessage(self
):
1053 # Test non-extension.
1054 child0
= self
.proto
.repeated_foreign_message
.add()
1055 self
.assertEqual(3, self
.proto
.ByteSize())
1056 self
.proto
.repeated_foreign_message
.add()
1057 self
.assertEqual(6, self
.proto
.ByteSize())
1059 self
.assertEqual(8, self
.proto
.ByteSize())
1060 self
.proto
.ClearField('repeated_foreign_message')
1061 self
.assertEqual(0, self
.proto
.ByteSize())
1063 # Test within extension.
1064 extension
= more_extensions_pb2
.repeated_message_extension
1065 child_list
= self
.extended_proto
.Extensions
[extension
]
1066 child0
= child_list
.add()
1067 self
.assertEqual(2, self
.extended_proto
.ByteSize())
1069 self
.assertEqual(4, self
.extended_proto
.ByteSize())
1070 child0
.foreign_message_int
= 1
1071 self
.assertEqual(6, self
.extended_proto
.ByteSize())
1072 child0
.ClearField('foreign_message_int')
1073 self
.assertEqual(4, self
.extended_proto
.ByteSize())
1074 self
.extended_proto
.ClearExtension(extension
)
1075 self
.assertEqual(0, self
.extended_proto
.ByteSize())
1078 # TODO(robinson): We need cross-language serialization consistency tests.
1079 # Issues to be sure to cover include:
1080 # * Handling of unrecognized tags ("uninterpreted_bytes").
1081 # * Handling of MessageSets.
1082 # * Consistent ordering of tags in the wire format,
1083 # including ordering between extensions and non-extension
1085 # * Consistent serialization of negative numbers, especially
1087 # * Handling of empty submessages (with and without "has"
1090 class SerializationTest(unittest
.TestCase
):
1092 def testSerializeEmtpyMessage(self
):
1093 first_proto
= unittest_pb2
.TestAllTypes()
1094 second_proto
= unittest_pb2
.TestAllTypes()
1095 serialized
= first_proto
.SerializeToString()
1096 self
.assertEqual(first_proto
.ByteSize(), len(serialized
))
1097 second_proto
.MergeFromString(serialized
)
1098 self
.assertEqual(first_proto
, second_proto
)
1100 def testSerializeAllFields(self
):
1101 first_proto
= unittest_pb2
.TestAllTypes()
1102 second_proto
= unittest_pb2
.TestAllTypes()
1103 test_util
.SetAllFields(first_proto
)
1104 serialized
= first_proto
.SerializeToString()
1105 self
.assertEqual(first_proto
.ByteSize(), len(serialized
))
1106 second_proto
.MergeFromString(serialized
)
1107 self
.assertEqual(first_proto
, second_proto
)
1109 def testSerializeAllExtensions(self
):
1110 first_proto
= unittest_pb2
.TestAllExtensions()
1111 second_proto
= unittest_pb2
.TestAllExtensions()
1112 test_util
.SetAllExtensions(first_proto
)
1113 serialized
= first_proto
.SerializeToString()
1114 second_proto
.MergeFromString(serialized
)
1115 self
.assertEqual(first_proto
, second_proto
)
1117 def testCanonicalSerializationOrder(self
):
1118 proto
= more_messages_pb2
.OutOfOrderFields()
1119 # These are also their tag numbers. Even though we're setting these in
1120 # reverse-tag order AND they're listed in reverse tag-order in the .proto
1121 # file, they should nonetheless be serialized in tag order.
1122 proto
.optional_sint32
= 5
1123 proto
.Extensions
[more_messages_pb2
.optional_uint64
] = 4
1124 proto
.optional_uint32
= 3
1125 proto
.Extensions
[more_messages_pb2
.optional_int64
] = 2
1126 proto
.optional_int32
= 1
1127 serialized
= proto
.SerializeToString()
1128 self
.assertEqual(proto
.ByteSize(), len(serialized
))
1129 d
= decoder
.Decoder(serialized
)
1130 ReadTag
= d
.ReadFieldNumberAndWireType
1131 self
.assertEqual((1, wire_format
.WIRETYPE_VARINT
), ReadTag())
1132 self
.assertEqual(1, d
.ReadInt32())
1133 self
.assertEqual((2, wire_format
.WIRETYPE_VARINT
), ReadTag())
1134 self
.assertEqual(2, d
.ReadInt64())
1135 self
.assertEqual((3, wire_format
.WIRETYPE_VARINT
), ReadTag())
1136 self
.assertEqual(3, d
.ReadUInt32())
1137 self
.assertEqual((4, wire_format
.WIRETYPE_VARINT
), ReadTag())
1138 self
.assertEqual(4, d
.ReadUInt64())
1139 self
.assertEqual((5, wire_format
.WIRETYPE_VARINT
), ReadTag())
1140 self
.assertEqual(5, d
.ReadSInt32())
1142 def testCanonicalSerializationOrderSameAsCpp(self
):
1143 # Copy of the same test we use for C++.
1144 proto
= unittest_pb2
.TestFieldOrderings()
1145 test_util
.SetAllFieldsAndExtensions(proto
)
1146 serialized
= proto
.SerializeToString()
1147 test_util
.ExpectAllFieldsAndExtensionsInOrder(serialized
)
1149 def testMergeFromStringWhenFieldsAlreadySet(self
):
1150 first_proto
= unittest_pb2
.TestAllTypes()
1151 first_proto
.repeated_string
.append('foobar')
1152 first_proto
.optional_int32
= 23
1153 first_proto
.optional_nested_message
.bb
= 42
1154 serialized
= first_proto
.SerializeToString()
1156 second_proto
= unittest_pb2
.TestAllTypes()
1157 second_proto
.repeated_string
.append('baz')
1158 second_proto
.optional_int32
= 100
1159 second_proto
.optional_nested_message
.bb
= 999
1161 second_proto
.MergeFromString(serialized
)
1162 # Ensure that we append to repeated fields.
1163 self
.assertEqual(['baz', 'foobar'], list(second_proto
.repeated_string
))
1164 # Ensure that we overwrite nonrepeatd scalars.
1165 self
.assertEqual(23, second_proto
.optional_int32
)
1166 # Ensure that we recursively call MergeFromString() on
1168 self
.assertEqual(42, second_proto
.optional_nested_message
.bb
)
1170 def testMessageSetWireFormat(self
):
1171 proto
= unittest_mset_pb2
.TestMessageSet()
1172 extension_message1
= unittest_mset_pb2
.TestMessageSetExtension1
1173 extension_message2
= unittest_mset_pb2
.TestMessageSetExtension2
1174 extension1
= extension_message1
.message_set_extension
1175 extension2
= extension_message2
.message_set_extension
1176 proto
.Extensions
[extension1
].i
= 123
1177 proto
.Extensions
[extension2
].str = 'foo'
1179 # Serialize using the MessageSet wire format (this is specified in the
1181 serialized
= proto
.SerializeToString()
1183 raw
= unittest_mset_pb2
.RawMessageSet()
1184 self
.assertEqual(False,
1185 raw
.DESCRIPTOR
.GetOptions().message_set_wire_format
)
1186 raw
.MergeFromString(serialized
)
1187 self
.assertEqual(2, len(raw
.item
))
1189 message1
= unittest_mset_pb2
.TestMessageSetExtension1()
1190 message1
.MergeFromString(raw
.item
[0].message
)
1191 self
.assertEqual(123, message1
.i
)
1193 message2
= unittest_mset_pb2
.TestMessageSetExtension2()
1194 message2
.MergeFromString(raw
.item
[1].message
)
1195 self
.assertEqual('foo', message2
.str)
1197 # Deserialize using the MessageSet wire format.
1198 proto2
= unittest_mset_pb2
.TestMessageSet()
1199 proto2
.MergeFromString(serialized
)
1200 self
.assertEqual(123, proto2
.Extensions
[extension1
].i
)
1201 self
.assertEqual('foo', proto2
.Extensions
[extension2
].str)
1204 self
.assertEqual(proto2
.ByteSize(), len(serialized
))
1205 self
.assertEqual(proto
.ByteSize(), len(serialized
))
1207 def testMessageSetWireFormatUnknownExtension(self
):
1208 # Create a message using the message set wire format with an unknown
1210 raw
= unittest_mset_pb2
.RawMessageSet()
1213 item
= raw
.item
.add()
1214 item
.type_id
= 1545008
1215 extension_message1
= unittest_mset_pb2
.TestMessageSetExtension1
1216 message1
= unittest_mset_pb2
.TestMessageSetExtension1()
1218 item
.message
= message1
.SerializeToString()
1220 # Add a second, unknown extension.
1221 item
= raw
.item
.add()
1222 item
.type_id
= 1545009
1223 extension_message1
= unittest_mset_pb2
.TestMessageSetExtension1
1224 message1
= unittest_mset_pb2
.TestMessageSetExtension1()
1226 item
.message
= message1
.SerializeToString()
1228 # Add another unknown extension.
1229 item
= raw
.item
.add()
1230 item
.type_id
= 1545010
1231 message1
= unittest_mset_pb2
.TestMessageSetExtension2()
1232 message1
.str = 'foo'
1233 item
.message
= message1
.SerializeToString()
1235 serialized
= raw
.SerializeToString()
1237 # Parse message using the message set wire format.
1238 proto
= unittest_mset_pb2
.TestMessageSet()
1239 proto
.MergeFromString(serialized
)
1241 # Check that the message parsed well.
1242 extension_message1
= unittest_mset_pb2
.TestMessageSetExtension1
1243 extension1
= extension_message1
.message_set_extension
1244 self
.assertEquals(12345, proto
.Extensions
[extension1
].i
)
1246 def testUnknownFields(self
):
1247 proto
= unittest_pb2
.TestAllTypes()
1248 test_util
.SetAllFields(proto
)
1250 serialized
= proto
.SerializeToString()
1252 # The empty message should be parsable with all of the fields
1254 proto2
= unittest_pb2
.TestEmptyMessage()
1256 # Parsing this message should succeed.
1257 proto2
.MergeFromString(serialized
)
1260 class OptionsTest(unittest
.TestCase
):
1262 def testMessageOptions(self
):
1263 proto
= unittest_mset_pb2
.TestMessageSet()
1264 self
.assertEqual(True,
1265 proto
.DESCRIPTOR
.GetOptions().message_set_wire_format
)
1266 proto
= unittest_pb2
.TestAllTypes()
1267 self
.assertEqual(False,
1268 proto
.DESCRIPTOR
.GetOptions().message_set_wire_format
)
1271 class UtilityTest(unittest
.TestCase
):
1273 def testImergeSorted(self
):
1274 ImergeSorted
= reflection
._ImergeSorted
1275 # Various types of emptiness.
1276 self
.assertEqual([], list(ImergeSorted()))
1277 self
.assertEqual([], list(ImergeSorted([])))
1278 self
.assertEqual([], list(ImergeSorted([], [])))
1280 # One nonempty list.
1281 self
.assertEqual([1, 2, 3], list(ImergeSorted([1, 2, 3])))
1282 self
.assertEqual([1, 2, 3], list(ImergeSorted([1, 2, 3], [])))
1283 self
.assertEqual([1, 2, 3], list(ImergeSorted([], [1, 2, 3])))
1285 # Merging some nonempty lists together.
1286 self
.assertEqual([1, 2, 3], list(ImergeSorted([1, 3], [2])))
1287 self
.assertEqual([1, 2, 3], list(ImergeSorted([1], [3], [2])))
1288 self
.assertEqual([1, 2, 3], list(ImergeSorted([1], [3], [2], [])))
1290 # Elements repeated across component iterators.
1291 self
.assertEqual([1, 2, 2, 3, 3],
1292 list(ImergeSorted([1, 2], [3], [2, 3])))
1294 # Elements repeated within an iterator.
1295 self
.assertEqual([1, 2, 2, 3, 3],
1296 list(ImergeSorted([1, 2, 2], [3], [3])))
1299 if __name__
== '__main__':