Use the types module instead of constructing type objects by hand. Fix #41. Thanks...
[pyyaml/python3.git] / tests / test_constructor.py
blob 04f54ef808f1894c229e3dde2979cc77c1564ddd
2 import test_appliance
4 import datetime
5 try:
6 set
7 except NameError:
8 from sets import Set as set
10 from yaml import *
12 import yaml.tokens
class MyLoader(Loader):
    """Loader subclass on which this module's custom constructors are registered."""
class MyDumper(Dumper):
    """Dumper subclass on which this module's custom representers are registered."""
class MyTestClass1:
    """Plain test object with attributes x, y and z.

    Instances are equal when they have the same class and the same
    attribute dictionary; anything that is not a MyTestClass1 is unequal.
    """

    def __init__(self, x, y=0, z=0):
        self.x = x
        self.y = y
        self.z = z

    def __eq__(self, other):
        if isinstance(other, MyTestClass1):
            # The parentheses are essential: without them the expression
            # is the 3-tuple (cls, dict == cls, dict), which is always true.
            # NOTE(review): comparisons were effectively unchecked before
            # this fix; confirm the '.code' fixtures still match.
            return (self.__class__, self.__dict__) == \
                   (other.__class__, other.__dict__)
        else:
            return False

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so define it explicitly.
        return not self.__eq__(other)
def construct1(constructor, node):
    """Build a MyTestClass1 from a mapping node, using its keys as keyword arguments."""
    kwargs = constructor.construct_mapping(node)
    return MyTestClass1(**kwargs)
def represent1(representer, native):
    """Represent *native* as a '!tag1' mapping of its instance attributes."""
    attributes = native.__dict__
    return representer.represent_mapping("!tag1", attributes)
# Wire MyTestClass1 into the custom loader/dumper pair: nodes tagged '!tag1'
# are built by construct1, and MyTestClass1 instances are dumped by
# represent1. The Loader=/Dumper= arguments presumably scope the
# registration to MyLoader/MyDumper rather than the defaults — confirm
# against yaml's add_constructor/add_representer.
add_constructor("!tag1", construct1, Loader=MyLoader)
add_representer(MyTestClass1, represent1, Dumper=MyDumper)
class MyTestClass2(MyTestClass1, YAMLObject):
    """MyTestClass1 variant serialized as a '!tag2' scalar holding only x."""

    yaml_loader = MyLoader
    yaml_dumper = MyDumper
    yaml_tag = "!tag2"

    def from_yaml(cls, constructor, node):
        """Create an instance whose x is the node's value parsed as an int."""
        parsed = constructor.construct_yaml_int(node)
        instance = cls(x=parsed)
        return instance
    from_yaml = classmethod(from_yaml)

    def to_yaml(cls, representer, native):
        """Emit str(native.x) as a scalar tagged with cls.yaml_tag."""
        tag = cls.yaml_tag
        return representer.represent_scalar(tag, str(native.x))
    to_yaml = classmethod(to_yaml)
class MyTestClass3(MyTestClass2):
    """MyTestClass2 variant serialized as a full '!tag3' attribute mapping."""

    yaml_tag = "!tag3"

    def from_yaml(cls, constructor, node):
        """Build an instance from a mapping, accepting '=' as an alias for 'x'."""
        kwargs = constructor.construct_mapping(node)
        if '=' in kwargs:
            # Move the '=' entry over to the 'x' keyword.
            kwargs['x'] = kwargs.pop('=')
        return cls(**kwargs)
    from_yaml = classmethod(from_yaml)

    def to_yaml(cls, representer, native):
        """Emit every instance attribute as a mapping tagged cls.yaml_tag."""
        tag = cls.yaml_tag
        return representer.represent_mapping(tag, native.__dict__)
    to_yaml = classmethod(to_yaml)
class YAMLObject1(YAMLObject):
    """YAMLObject tagged '!foo', loaded with MyLoader and dumped with MyDumper.

    Instances are equal when they have the same class and the same
    attribute dictionary; anything that is not a YAMLObject1 is unequal.
    """

    yaml_loader = MyLoader
    yaml_dumper = MyDumper
    yaml_tag = '!foo'

    def __init__(self, my_parameter=None, my_another_parameter=None):
        self.my_parameter = my_parameter
        self.my_another_parameter = my_another_parameter

    def __eq__(self, other):
        if isinstance(other, YAMLObject1):
            # Parenthesize both sides; the unparenthesized form builds an
            # always-true 3-tuple instead of comparing.
            # NOTE(review): confirm the '.code' fixtures still match now
            # that this comparison is real.
            return (self.__class__, self.__dict__) == \
                   (other.__class__, other.__dict__)
        else:
            return False

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so define it explicitly.
        return not self.__eq__(other)
class YAMLObject2(YAMLObject):
    """YAMLObject tagged '!bar' whose state is a dict keyed by 1, 2 and 3.

    __getstate__/__setstate__ map foo/bar/baz to the integer keys 1/2/3.
    Instances are equal when they have the same class and the same
    attribute dictionary; anything that is not a YAMLObject2 is unequal.
    """

    yaml_loader = MyLoader
    yaml_dumper = MyDumper
    yaml_tag = '!bar'

    def __init__(self, foo=1, bar=2, baz=3):
        self.foo = foo
        self.bar = bar
        self.baz = baz

    def __getstate__(self):
        # Integer keys rather than attribute names; __setstate__ reads the same keys.
        return {1: self.foo, 2: self.bar, 3: self.baz}

    def __setstate__(self, state):
        self.foo = state[1]
        self.bar = state[2]
        self.baz = state[3]

    def __eq__(self, other):
        if isinstance(other, YAMLObject2):
            # Parenthesize both sides; the unparenthesized form builds an
            # always-true 3-tuple instead of comparing.
            # NOTE(review): confirm the '.code' fixtures still match now
            # that this comparison is real.
            return (self.__class__, self.__dict__) == \
                   (other.__class__, other.__dict__)
        else:
            return False

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so define it explicitly.
        return not self.__eq__(other)
class AnObject(object):
    """New-style object whose foo/bar/baz fields are set in __new__, not __init__.

    Equality requires exactly the same type and equal foo, bar and baz.
    """

    def __new__(cls, foo=None, bar=None, baz=None):
        instance = object.__new__(cls)
        instance.foo, instance.bar, instance.baz = foo, bar, baz
        return instance

    def __cmp__(self, other):
        # Python 2 ordering hook: compare type first, then the three fields.
        left = (type(self), self.foo, self.bar, self.baz)
        right = (type(other), other.foo, other.bar, other.baz)
        return cmp(left, right)

    def __eq__(self, other):
        if type(self) is not type(other):
            return False
        mine = (self.foo, self.bar, self.baz)
        theirs = (other.foo, other.bar, other.baz)
        return mine == theirs
class AnInstance:
    """Classic class holding foo/bar/baz, all set by __init__.

    Equality requires exactly the same type and equal foo, bar and baz.
    """

    def __init__(self, foo=None, bar=None, baz=None):
        self.foo, self.bar, self.baz = foo, bar, baz

    def __cmp__(self, other):
        # Python 2 ordering hook: compare type first, then the three fields.
        left = (type(self), self.foo, self.bar, self.baz)
        right = (type(other), other.foo, other.bar, other.baz)
        return cmp(left, right)

    def __eq__(self, other):
        if type(self) is not type(other):
            return False
        mine = (self.foo, self.bar, self.baz)
        theirs = (other.foo, other.bar, other.baz)
        return mine == theirs
class AState(AnInstance):
    """AnInstance whose pickle state uses the keys '_foo', '_bar' and '_baz'."""

    def __getstate__(self):
        return dict(_foo=self.foo, _bar=self.bar, _baz=self.baz)

    def __setstate__(self, state):
        # Restore foo, bar, baz in that order from their '_'-prefixed keys.
        for name in ('foo', 'bar', 'baz'):
            setattr(self, name, state['_' + name])
class ACustomState(AnInstance):
    """AnInstance whose pickle state is the plain tuple (foo, bar, baz)."""

    def __getstate__(self):
        state = (self.foo, self.bar, self.baz)
        return state

    def __setstate__(self, state):
        foo, bar, baz = state
        self.foo = foo
        self.bar = bar
        self.baz = baz
class InitArgs(AnInstance):
    """AnInstance that reports all three fields as __init__ arguments and an empty state."""

    def __getinitargs__(self):
        args = (self.foo, self.bar, self.baz)
        return args

    def __getstate__(self):
        # All fields already travel in the init args.
        return {}
class InitArgsWithState(AnInstance):
    """AnInstance passing foo/bar as __init__ arguments and baz as its state."""

    def __getinitargs__(self):
        args = (self.foo, self.bar)
        return args

    def __getstate__(self):
        # The state is the bare baz value, not a dict.
        state = self.baz
        return state

    def __setstate__(self, state):
        setattr(self, 'baz', state)
class NewArgs(AnObject):
    """AnObject that reports all three fields as __new__ arguments and an empty state."""

    def __getnewargs__(self):
        args = (self.foo, self.bar, self.baz)
        return args

    def __getstate__(self):
        # All fields already travel in the __new__ args.
        return {}
class NewArgsWithState(AnObject):
    """AnObject passing foo/bar as __new__ arguments and baz as its state."""

    def __getnewargs__(self):
        args = (self.foo, self.bar)
        return args

    def __getstate__(self):
        # The state is the bare baz value, not a dict.
        state = self.baz
        return state

    def __setstate__(self, state):
        setattr(self, 'baz', state)
class Reduce(AnObject):
    """AnObject whose __reduce__ rebuilds it as cls(foo, bar, baz)."""

    def __reduce__(self):
        factory = self.__class__
        args = (self.foo, self.bar, self.baz)
        return (factory, args)
class ReduceWithState(AnObject):
    """AnObject whose __reduce__ rebuilds it as cls(foo, bar) and restores baz as state."""

    def __reduce__(self):
        factory = self.__class__
        args = (self.foo, self.bar)
        return (factory, args, self.baz)

    def __setstate__(self, state):
        setattr(self, 'baz', state)
class MyInt(int):
    """int subclass equal only to instances of exactly the same type and value."""

    def __eq__(self, other):
        if type(other) is not type(self):
            return False
        return int(self) == int(other)
class MyList(list):
    """list subclass pre-filled with n None items; equality requires the exact same type."""

    def __init__(self, n=1):
        padding = [None] * n
        self.extend(padding)

    def __eq__(self, other):
        if type(other) is not type(self):
            return False
        return list(self) == list(other)
class MyDict(dict):
    """dict subclass pre-filled with keys 0..n-1 mapped to None; equality requires the exact same type."""

    def __init__(self, n=1):
        # Equivalent to assigning self[k] = None for k = 0, 1, ..., n-1.
        self.update(dict.fromkeys(range(n)))

    def __eq__(self, other):
        if type(other) is not type(self):
            return False
        return dict(self) == dict(other)
class FixedOffset(datetime.tzinfo):
    """tzinfo with a constant UTC offset (in minutes), a fixed name, and zero DST."""

    def __init__(self, offset, name):
        self.__offset = datetime.timedelta(minutes=offset)
        self.__name = name

    def utcoffset(self, dt):
        # dt is ignored: the offset is the same for every moment.
        return self.__offset

    def tzname(self, dt):
        return self.__name

    def dst(self, dt):
        # No daylight-saving adjustment, ever.
        return datetime.timedelta()
def execute(code):
    """Run the source string *code* and return the ``value`` it binds.

    Python 2 only. The bare ``exec`` statement runs *code* in this
    function's own globals/locals. Its presence also makes the compiler
    look ``value`` up by name at runtime, so a ``value = ...`` executed by
    *code* is what gets returned. If neither *code* nor the module globals
    define ``value``, the return raises NameError.

    NOTE(review): no caller is visible in this file; presumably it is used
    from the '.code' fixtures that _testTypes eval()s — confirm. It must only
    ever be given trusted fixture text.
    """
    exec code
    return value
class TestConstructorTypes(test_appliance.TestAppliance):
    """Compare YAML loaded from '.data' files with values eval()'d from '.code' files."""

    def _testTypes(self, test_name, data_filename, code_filename):
        """Load data_filename with MyLoader and compare it to eval(code_filename).

        A file containing exactly one document is compared as that document.
        Otherwise the list of all documents is compared. Types must match
        exactly.

        If plain equality fails with AssertionError or TypeError, it retries:
        dicts are compared as the repr of their sorted (repr(key), value)
        pairs, and lists are compared element by element with special cases
        for NaN and datetimes.

        On any failure, both files and both values are printed and the
        exception is re-raised. test_name is not used here.
        """
        # Pre-set so the failure report below can print them even if loading
        # or evaluating fails part-way.
        data1 = None
        data2 = None
        try:
            # NOTE(review): file handles here are never closed explicitly;
            # CPython refcounting closes them, other interpreters may not.
            data1 = list(load_all(file(data_filename, 'rb'), Loader=MyLoader))
            if len(data1) == 1:
                data1 = data1[0]
            # eval() of a trusted test fixture, never of external input.
            data2 = eval(file(code_filename, 'rb').read())
            self.failUnlessEqual(type(data1), type(data2))
            try:
                self.failUnlessEqual(data1, data2)
            except (AssertionError, TypeError):
                if isinstance(data1, dict):
                    # Fallback: compare keys by their repr, sorted.
                    data1 = [(repr(key), value) for key, value in data1.items()]
                    data1.sort()
                    data1 = repr(data1)
                    data2 = [(repr(key), value) for key, value in data2.items()]
                    data2.sort()
                    data2 = repr(data2)
                    if data1 != data2:
                        # Re-raises the original AssertionError/TypeError.
                        raise
                elif isinstance(data1, list):
                    self.failUnlessEqual(type(data1), type(data2))
                    self.failUnlessEqual(len(data1), len(data2))
                    for item1, item2 in zip(data1, data2):
                        # x != x holds only for NaN. The (== 0.0 and == 1.0)
                        # test presumably catches NaN on platforms where it
                        # compares equal to everything. Two NaN-like items
                        # are treated as equal.
                        if (item1 != item1 or (item1 == 0.0 and item1 == 1.0)) and  \
                                (item2 != item2 or (item2 == 0.0 and item2 == 1.0)):
                            continue
                        # utctimetuple() drops microseconds, so check them
                        # separately before converting.
                        if isinstance(item1, datetime.datetime) \
                                and isinstance(item2, datetime.datetime):
                            self.failUnlessEqual(item1.microsecond,
                                    item2.microsecond)
                        # Normalize datetimes to UTC time tuples so equal
                        # instants with different tzinfo compare equal.
                        if isinstance(item1, datetime.datetime):
                            item1 = item1.utctimetuple()
                        if isinstance(item2, datetime.datetime):
                            item2 = item2.utctimetuple()
                        self.failUnlessEqual(item1, item2)
                else:
                    # No fallback for other types: re-raise the original failure.
                    raise
        except:
            # Bare except on purpose: dump context for ANY failure, then
            # re-raise it unchanged. In the dict branch data1/data2 have
            # already been replaced by their repr strings.
            print
            print "DATA:"
            print file(data_filename, 'rb').read()
            print "CODE:"
            print file(code_filename, 'rb').read()
            print "NATIVES1:", data1
            print "NATIVES2:", data2
            raise
# Generate the concrete test methods. add_tests is inherited from
# test_appliance.TestAppliance, which is not shown here. It presumably pairs
# each '.data' fixture with the '.code' file of the same name and routes them
# to _testTypes — confirm in test_appliance.
TestConstructorTypes.add_tests('testTypes', '.data', '.code')