From 01ff53fbe692062fd1da43d60fde9e1e55c3a4ff Mon Sep 17 00:00:00 2001 From: xi Date: Sat, 22 Apr 2006 20:40:43 +0000 Subject: [PATCH] Add support for pickling/unpickling python objects. git-svn-id: http://svn.pyyaml.org/pyyaml/trunk@147 18f92427-320e-0410-9341-c67f048884a3 --- lib/yaml/__init__.py | 9 ++ lib/yaml/constructor.py | 92 +++++++++++++++++ lib/yaml/representer.py | 172 +++++++++++++++++++++++++++++--- tests/data/construct-python-object.code | 23 +++++ tests/data/construct-python-object.data | 21 ++++ tests/test_constructor.py | 127 +++++++++++++++++++++++ tests/test_representer.py | 9 +- 7 files changed, 436 insertions(+), 17 deletions(-) create mode 100644 tests/data/construct-python-object.code create mode 100644 tests/data/construct-python-object.data diff --git a/lib/yaml/__init__.py b/lib/yaml/__init__.py index 0de89db..22df18b 100644 --- a/lib/yaml/__init__.py +++ b/lib/yaml/__init__.py @@ -231,6 +231,15 @@ def add_representer(data_type, representer, Dumper=Dumper): """ Dumper.add_representer(data_type, representer) +def add_multi_representer(data_type, multi_representer, Dumper=Dumper): + """ + Add a representer for the given type. + Multi-representer is a function accepting a Dumper instance + and an instance of the given data type or subtype + and producing the corresponding representation node. + """ + Dumper.add_multi_representer(data_type, multi_representer) + class YAMLObjectMetaclass(type): """ The metaclass for YAMLObject. diff --git a/lib/yaml/constructor.py b/lib/yaml/constructor.py index 6eaf043..57ad53d 100644 --- a/lib/yaml/constructor.py +++ b/lib/yaml/constructor.py @@ -492,6 +492,86 @@ class Constructor(SafeConstructor): node.start_mark) return self.find_python_module(suffix, node.start_mark) + class classobj: pass + + def make_python_instance(self, suffix, node, + args=None, kwds=None, newobj=False): + if not args: + args = [] + if not kwds: + kwds = {} + cls = self.find_python_name(suffix, node.start_mark) + if newobj and isinstance(cls, type(self.classobj)) \ + and not args and not kwds: + instance = self.classobj() + instance.__class__ = cls + return instance + elif newobj and isinstance(cls, type): + return cls.__new__(cls, *args, **kwds) + else: + return cls(*args, **kwds) + + def set_python_instance_state(self, instance, state): + if hasattr(instance, '__setstate__'): + instance.__setstate__(state) + else: + slotstate = {} + if isinstance(state, tuple) and len(state) == 2: + state, slotstate = state + if hasattr(instance, '__dict__'): + instance.__dict__.update(state) + elif state: + slotstate.update(state) + for key, value in slotstate.items(): + setattr(object, key, value) + + def construct_python_object(self, suffix, node): + # Format: + # !!python/object:module.name { ... state ... } + instance = self.make_python_instance(suffix, node, newobj=True) + state = self.construct_mapping(node) + self.set_python_instance_state(instance, state) + return instance + + def construct_python_object_apply(self, suffix, node, newobj=False): + # Format: + # !!python/object/apply # (or !!python/object/new) + # args: [ ... arguments ... ] + # kwds: { ... keywords ... } + # state: ... state ... + # listitems: [ ... listitems ... ] + # dictitems: { ... dictitems ... } + # or short format: + # !!python/object/apply [ ... arguments ... ] + # The difference between !!python/object/apply and !!python/object/new + # is how an object is created, check make_python_instance for details. + if isinstance(node, SequenceNode): + args = self.construct_sequence(node) + kwds = {} + state = {} + listitems = [] + dictitems = {} + else: + value = self.construct_mapping(node) + args = value.get('args', []) + kwds = value.get('kwds', {}) + state = value.get('state', {}) + listitems = value.get('listitems', []) + dictitems = value.get('dictitems', {}) + instance = self.make_python_instance(suffix, node, args, kwds, newobj) + if state: + self.set_python_instance_state(instance, state) + if listitems: + instance.extend(listitems) + if dictitems: + for key in dictitems: + instance[key] = dictitems[key] + return instance + + def construct_python_object_new(self, suffix, node): + return self.construct_python_object_apply(suffix, node, newobj=True) + + Constructor.add_constructor( u'tag:yaml.org,2002:python/none', Constructor.construct_yaml_null) @@ -544,3 +624,15 @@ Constructor.add_multi_constructor( u'tag:yaml.org,2002:python/module:', Constructor.construct_python_module) +Constructor.add_multi_constructor( + u'tag:yaml.org,2002:python/object:', + Constructor.construct_python_object) + +Constructor.add_multi_constructor( + u'tag:yaml.org,2002:python/object/apply:', + Constructor.construct_python_object_apply) + +Constructor.add_multi_constructor( + u'tag:yaml.org,2002:python/object/new:', + Constructor.construct_python_object_new) + diff --git a/lib/yaml/representer.py b/lib/yaml/representer.py index 749182d..236487e 100644 --- a/lib/yaml/representer.py +++ b/lib/yaml/representer.py @@ -16,7 +16,7 @@ try: except NameError: from sets import Set as set -import sys +import sys, copy_reg class RepresenterError(YAMLError): pass @@ -24,12 +24,13 @@ class RepresenterError(YAMLError): class BaseRepresenter: yaml_representers = {} + yaml_multi_representers = {} def __init__(self): self.represented_objects = {} def represent(self, data): - node = self.represent_object(data) + node = self.represent_data(data) self.serialize(node) self.represented_objects = {} @@ -49,7 +50,7 @@ class BaseRepresenter: bases.extend(self.get_classobj_bases(base)) return bases - def represent_object(self, data): + def represent_data(self, data): if self.ignore_aliases(data): alias_key = None else: @@ -64,15 +65,20 @@ class BaseRepresenter: data_types = type(data).__mro__ if type(data) is self.instance_type: data_types = self.get_classobj_bases(data.__class__)+list(data_types) - for data_type in data_types: - if data_type in self.yaml_representers: - node = self.yaml_representers[data_type](self, data) - break + if data_types[0] in self.yaml_representers: + node = self.yaml_representers[data_types[0]](self, data) else: - if None in self.yaml_representers: - node = self.yaml_representers[None](self, data) + for data_type in data_types: + if data_type in self.yaml_multi_representers: + node = self.yaml_multi_representers[data_type](self, data) + break else: - node = ScalarNode(None, unicode(data)) + if None in self.yaml_multi_representers: + node = self.yaml_multi_representers[None](self, data) + elif None in self.yaml_representers: + node = self.yaml_representers[None](self, data) + else: + node = ScalarNode(None, unicode(data)) if alias_key is not None: self.represented_objects[alias_key] = node return node @@ -83,27 +89,52 @@ class BaseRepresenter: cls.yaml_representers[data_type] = representer add_representer = classmethod(add_representer) + def add_multi_representer(cls, data_type, representer): + if not 'yaml_multi_representers' in cls.__dict__: + cls.yaml_multi_representers = cls.yaml_multi_representers.copy() + cls.yaml_multi_representers[data_type] = representer + add_multi_representer = classmethod(add_multi_representer) + def represent_scalar(self, tag, value, style=None): return ScalarNode(tag, value, style=style) def represent_sequence(self, tag, sequence, flow_style=None): + best_style = True value = [] for item in sequence: - value.append(self.represent_object(item)) + node_item = self.represent_data(item) + if not (isinstance(node_item, ScalarNode) and not node_item.style): + best_style = False + value.append(self.represent_data(item)) + if flow_style is None: + flow_style = best_style return SequenceNode(tag, value, flow_style=flow_style) def represent_mapping(self, tag, mapping, flow_style=None): + best_style = True if hasattr(mapping, 'keys'): value = {} for item_key in mapping.keys(): item_value = mapping[item_key] - value[self.represent_object(item_key)] = \ - self.represent_object(item_value) + node_key = self.represent_data(item_key) + node_value = self.represent_data(item_value) + if not (isinstance(node_key, ScalarNode) and not node_key.style): + best_style = False + if not (isinstance(node_value, ScalarNode) and not node_value.style): + best_style = False + value[node_key] = node_value else: value = [] for item_key, item_value in mapping: - value.append((self.represent_object(item_key), - self.represent_object(item_value))) + node_key = self.represent_data(item_key) + node_value = self.represent_data(item_value) + if not (isinstance(node_key, ScalarNode) and not node_key.style): + best_style = False + if not (isinstance(node_value, ScalarNode) and not node_value.style): + best_style = False + value.append((node_key, node_value)) + if flow_style is None: + flow_style = best_style return MappingNode(tag, value, flow_style=flow_style) def ignore_aliases(self, data): @@ -258,7 +289,7 @@ SafeRepresenter.add_representer(None, SafeRepresenter.represent_undefined) class Representer(SafeRepresenter): - + def represent_str(self, data): tag = None style = None @@ -312,6 +343,109 @@ class Representer(SafeRepresenter): return self.represent_scalar( u'tag:yaml.org,2002:python/module:'+data.__name__, u'') + def represent_instance(self, data): + # For instances of classic classes, we use __getinitargs__ and + # __getstate__ to serialize the data. + + # If data.__getinitargs__ exists, the object must be reconstructed by + # calling cls(**args), where args is a tuple returned by + # __getinitargs__. Otherwise, the cls.__init__ method should never be + # called and the class instance is created by instantiating a trivial + # class and assigning to the instance's __class__ variable. + + # If data.__getstate__ exists, it returns the state of the object. + # Otherwise, the state of the object is data.__dict__. + + # We produce either a !!python/object or !!python/object/new node. + # If data.__getinitargs__ does not exist and state is a dictionary, we + # produce a !!python/object node . Otherwise we produce a + # !!python/object/new node. + + cls = data.__class__ + class_name = u'%s.%s' % (cls.__module__, cls.__name__) + args = None + state = None + if hasattr(data, '__getinitargs__'): + args = list(data.__getinitargs__()) + if hasattr(data, '__getstate__'): + state = data.__getstate__() + else: + state = data.__dict__ + if args is None and isinstance(state, dict): + return self.represent_mapping( + u'tag:yaml.org,2002:python/object:'+class_name, state) + if isinstance(state, dict) and not state: + return self.represent_sequence( + u'tag:yaml.org,2002:python/object/new:'+class_name, args) + value = {} + if args: + value['args'] = args + value['state'] = state + return self.represent_mapping( + u'tag:yaml.org,2002:python/object/new:'+class_name, value) + + def represent_object(self, data): + # We use __reduce__ API to save the data. data.__reduce__ returns + # a tuple of length 2-5: + # (function, args, state, listitems, dictitems) + + # For reconstructing, we calls function(*args), then set its state, + # listitems, and dictitems if they are not None. + + # A special case is when function.__name__ == '__newobj__'. In this + # case we create the object with args[0].__new__(*args). + + # Another special case is when __reduce__ returns a string - we don't + # support it. + + # We produce a !!python/object, !!python/object/new or + # !!python/object/apply node. + + cls = type(data) + if cls in copy_reg.dispatch_table: + reduce = copy_reg.dispatch_table[cls] + elif hasattr(data, '__reduce_ex__'): + reduce = data.__reduce_ex__(2) + elif hasattr(data, '__reduce__'): + reduce = data.__reduce__() + else: + raise RepresenterError("cannot represent object: %r" % data) + reduce = (list(reduce)+[None]*5)[:5] + function, args, state, listitems, dictitems = reduce + args = list(args) + if state is None: + state = {} + if listitems is not None: + listitems = list(listitems) + if dictitems is not None: + dictitems = dict(dictitems) + if function.__name__ == '__newobj__': + function = args[0] + args = args[1:] + tag = u'tag:yaml.org,2002:python/object/new:' + newobj = True + else: + tag = u'tag:yaml.org,2002:python/object/apply:' + newobj = False + function_name = u'%s.%s' % (function.__module__, function.__name__) + if not args and not listitems and not dictitems \ + and isinstance(state, dict) and newobj: + return self.represent_mapping( + u'tag:yaml.org,2002:python/object:'+function_name, state) + if not listitems and not dictitems \ + and isinstance(state, dict) and not state: + return self.represent_sequence(tag+function_name, args) + value = {} + if args: + value['args'] = args + if state or not isinstance(state, dict): + value['state'] = state + if listitems: + value['listitems'] = listitems + if dictitems: + value['dictitems'] = dictitems + return self.represent_mapping(tag+function_name, value) + Representer.add_representer(str, Representer.represent_str) @@ -342,3 +476,9 @@ Representer.add_representer(Representer.builtin_function_type, Representer.add_representer(Representer.module_type, Representer.represent_module) +Representer.add_multi_representer(Representer.instance_type, + Representer.represent_instance) + +Representer.add_multi_representer(object, + Representer.represent_object) + diff --git a/tests/data/construct-python-object.code b/tests/data/construct-python-object.code new file mode 100644 index 0000000..7f1edf1 --- /dev/null +++ b/tests/data/construct-python-object.code @@ -0,0 +1,23 @@ +[ +AnObject(1, 'two', [3,3,3]), +AnInstance(1, 'two', [3,3,3]), + +AnObject(1, 'two', [3,3,3]), +AnInstance(1, 'two', [3,3,3]), + +AState(1, 'two', [3,3,3]), +ACustomState(1, 'two', [3,3,3]), + +InitArgs(1, 'two', [3,3,3]), +InitArgsWithState(1, 'two', [3,3,3]), + +NewArgs(1, 'two', [3,3,3]), +NewArgsWithState(1, 'two', [3,3,3]), + +Reduce(1, 'two', [3,3,3]), +ReduceWithState(1, 'two', [3,3,3]), + +MyInt(3), +MyList(3), +MyDict(3), +] diff --git a/tests/data/construct-python-object.data b/tests/data/construct-python-object.data new file mode 100644 index 0000000..bce8b2e --- /dev/null +++ b/tests/data/construct-python-object.data @@ -0,0 +1,21 @@ +- !!python/object:test_constructor.AnObject { foo: 1, bar: two, baz: [3,3,3] } +- !!python/object:test_constructor.AnInstance { foo: 1, bar: two, baz: [3,3,3] } + +- !!python/object/new:test_constructor.AnObject { args: [1, two], kwds: {baz: [3,3,3]} } +- !!python/object/apply:test_constructor.AnInstance { args: [1, two], kwds: {baz: [3,3,3]} } + +- !!python/object:test_constructor.AState { _foo: 1, _bar: two, _baz: [3,3,3] } +- !!python/object/new:test_constructor.ACustomState { state: !!python/tuple [1, two, [3,3,3]] } + +- !!python/object/new:test_constructor.InitArgs [1, two, [3,3,3]] +- !!python/object/new:test_constructor.InitArgsWithState { args: [1, two], state: [3,3,3] } + +- !!python/object/new:test_constructor.NewArgs [1, two, [3,3,3]] +- !!python/object/new:test_constructor.NewArgsWithState { args: [1, two], state: [3,3,3] } + +- !!python/object/apply:test_constructor.Reduce [1, two, [3,3,3]] +- !!python/object/apply:test_constructor.ReduceWithState { args: [1, two], state: [3,3,3] } + +- !!python/object/new:test_constructor.MyInt [3] +- !!python/object/new:test_constructor.MyList { listitems: [~, ~, ~] } +- !!python/object/new:test_constructor.MyDict { dictitems: {0, 1, 2} } diff --git a/tests/test_constructor.py b/tests/test_constructor.py index f6e5b8e..23fac0c 100644 --- a/tests/test_constructor.py +++ b/tests/test_constructor.py @@ -113,6 +113,133 @@ class YAMLObject2(YAMLObject): else: return False +class AnObject(object): + + def __new__(cls, foo=None, bar=None, baz=None): + self = object.__new__(cls) + self.foo = foo + self.bar = bar + self.baz = baz + return self + + def __cmp__(self, other): + return cmp((type(self), self.foo, self.bar, self.baz), + (type(other), other.foo, other.bar, other.baz)) + + def __eq__(self, other): + return type(self) is type(other) and \ + (self.foo, self.bar, self.baz) == (other.foo, other.bar, other.baz) + +class AnInstance: + + def __init__(self, foo=None, bar=None, baz=None): + self.foo = foo + self.bar = bar + self.baz = baz + + def __cmp__(self, other): + return cmp((type(self), self.foo, self.bar, self.baz), + (type(other), other.foo, other.bar, other.baz)) + + def __eq__(self, other): + return type(self) is type(other) and \ + (self.foo, self.bar, self.baz) == (other.foo, other.bar, other.baz) + +class AState(AnInstance): + + def __getstate__(self): + return { + '_foo': self.foo, + '_bar': self.bar, + '_baz': self.baz, + } + + def __setstate__(self, state): + self.foo = state['_foo'] + self.bar = state['_bar'] + self.baz = state['_baz'] + +class ACustomState(AnInstance): + + def __getstate__(self): + return (self.foo, self.bar, self.baz) + + def __setstate__(self, state): + self.foo, self.bar, self.baz = state + +class InitArgs(AnInstance): + + def __getinitargs__(self): + return (self.foo, self.bar, self.baz) + + def __getstate__(self): + return {} + +class InitArgsWithState(AnInstance): + + def __getinitargs__(self): + return (self.foo, self.bar) + + def __getstate__(self): + return self.baz + + def __setstate__(self, state): + self.baz = state + +class NewArgs(AnObject): + + def __getnewargs__(self): + return (self.foo, self.bar, self.baz) + + def __getstate__(self): + return {} + +class NewArgsWithState(AnObject): + + def __getnewargs__(self): + return (self.foo, self.bar) + + def __getstate__(self): + return self.baz + + def __setstate__(self, state): + self.baz = state + +class Reduce(AnObject): + + def __reduce__(self): + return self.__class__, (self.foo, self.bar, self.baz) + +class ReduceWithState(AnObject): + + def __reduce__(self): + return self.__class__, (self.foo, self.bar), self.baz + + def __setstate__(self, state): + self.baz = state + +class MyInt(int): + + def __eq__(self, other): + return type(self) is type(other) and int(self) == int(other) + +class MyList(list): + + def __init__(self, n=1): + self.extend([None]*n) + + def __eq__(self, other): + return type(self) is type(other) and list(self) == list(other) + +class MyDict(dict): + + def __init__(self, n=1): + for k in range(n): + self[k] = None + + def __eq__(self, other): + return type(self) is type(other) and dict(self) == dict(other) + class TestConstructorTypes(test_appliance.TestAppliance): def _testTypes(self, test_name, data_filename, code_filename): diff --git a/tests/test_representer.py b/tests/test_representer.py index 343d7f5..6835dfd 100644 --- a/tests/test_representer.py +++ b/tests/test_representer.py @@ -24,7 +24,14 @@ class TestRepresenterTypes(test_appliance.TestAppliance): data2 = data2.items() data2.sort() data2 = repr(data2) - if data1 != data2: + if data1 != data2: + raise + elif isinstance(data1, list): + self.failUnlessEqual(type(data1), type(data2)) + self.failUnlessEqual(len(data1), len(data2)) + for item1, item2 in zip(data1, data2): + self.failUnlessEqual(item1, item2) + else: raise except: print -- 2.11.4.GIT