Fix a typo in a plain scalar scanner.
[pyyaml/python3.git] / tests / test_constructor.py
blob cd6695feb7e2efc0c137bbf667eb84488c87e6c1
2 import test_appliance
3 try:
4 import datetime
5 except ImportError:
6 pass
7 try:
8 set
9 except NameError:
10 from sets import Set as set
12 from yaml import *
14 import yaml.tokens
class MyLoader(Loader):
    """Loader subclass that the test-only constructors in this module are
    registered with."""
class MyDumper(Dumper):
    """Dumper subclass that the test-only representers in this module are
    registered with."""
class MyTestClass1:
    # Plain fixture class (old-style on Python 2) with one required and two
    # optional attributes; it is loaded from / dumped to the "!tag1" mapping.

    def __init__(self, x, y=0, z=0):
        self.x = x
        self.y = y
        self.z = z

    def __eq__(self, other):
        # Equal only to an object of the very same class with identical
        # attributes.  Both sides must be parenthesised: written as
        # `a, b == c, d` the comma builds a 3-tuple, which is always truthy,
        # so every comparison used to report "equal".
        if isinstance(other, MyTestClass1):
            return ((self.__class__, self.__dict__) ==
                    (other.__class__, other.__dict__))
        else:
            return False

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; keep the two consistent.
        return not self == other
def construct1(constructor, node):
    """Build a MyTestClass1 from a "!tag1" mapping node; the mapping's keys
    become keyword arguments of the MyTestClass1 constructor."""
    return MyTestClass1(**constructor.construct_mapping(node))


def represent1(representer, native):
    """Dump a MyTestClass1 as a "!tag1" mapping of its instance attributes."""
    attributes = native.__dict__
    return representer.represent_mapping("!tag1", attributes)


# Hook the two functions above into the test-only loader/dumper pair.
add_constructor("!tag1", construct1, Loader=MyLoader)
add_representer(MyTestClass1, represent1, Dumper=MyDumper)
class MyTestClass2(MyTestClass1, YAMLObject):
    # YAMLObject flavour of MyTestClass1: a "!tag2" node is a plain scalar
    # holding only the integer x.

    yaml_tag = "!tag2"
    yaml_loader = MyLoader
    yaml_dumper = MyDumper

    def from_yaml(cls, constructor, node):
        # Only x comes from the document; the other constructor arguments
        # keep their defaults.
        return cls(x=constructor.construct_yaml_int(node))
    from_yaml = classmethod(from_yaml)

    def to_yaml(cls, representer, native):
        # Emit x, as text, as the scalar value under this class's tag.
        value = str(native.x)
        return representer.represent_scalar(cls.yaml_tag, value)
    to_yaml = classmethod(to_yaml)
class MyTestClass3(MyTestClass2):
    # "!tag3" variant stored as a mapping.  On load, a value under the "="
    # key is accepted as an alias for x; all other keys pass through as
    # keyword arguments.

    yaml_tag = "!tag3"

    def from_yaml(cls, constructor, node):
        params = constructor.construct_mapping(node)
        if '=' in params:
            # Rename the "=" entry to "x" (overwriting any explicit x).
            params['x'] = params.pop('=')
        return cls(**params)
    from_yaml = classmethod(from_yaml)

    def to_yaml(cls, representer, native):
        # Dump every instance attribute as an entry of the "!tag3" mapping.
        return representer.represent_mapping(cls.yaml_tag, native.__dict__)
    to_yaml = classmethod(to_yaml)
class YAMLObject1(YAMLObject):
    # Relies on the from_yaml/to_yaml inherited from YAMLObject, registered
    # under the '!foo' tag on MyLoader/MyDumper.

    yaml_loader = MyLoader
    yaml_dumper = MyDumper
    yaml_tag = '!foo'

    def __init__(self, my_parameter=None, my_another_parameter=None):
        self.my_parameter = my_parameter
        self.my_another_parameter = my_another_parameter

    def __eq__(self, other):
        # Equal only to an object of the same class with identical
        # attributes.  The parentheses matter: `a, b == c, d` builds a
        # 3-tuple, which is always truthy, so the old code called every pair
        # of YAMLObject1 instances equal.
        if isinstance(other, YAMLObject1):
            return ((self.__class__, self.__dict__) ==
                    (other.__class__, other.__dict__))
        else:
            return False

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; keep the two consistent.
        return not self == other
class YAMLObject2(YAMLObject):
    # '!bar' object on MyLoader/MyDumper whose __getstate__/__setstate__
    # state is a dict keyed by the integers 1, 2 and 3 rather than by
    # attribute names.

    yaml_loader = MyLoader
    yaml_dumper = MyDumper
    yaml_tag = '!bar'

    def __init__(self, foo=1, bar=2, baz=3):
        self.foo = foo
        self.bar = bar
        self.baz = baz

    def __getstate__(self):
        return {1: self.foo, 2: self.bar, 3: self.baz}

    def __setstate__(self, state):
        self.foo = state[1]
        self.bar = state[2]
        self.baz = state[3]

    def __eq__(self, other):
        # Equal only to an object of the same class with identical
        # attributes.  Without the parentheses the comma builds an always-
        # truthy 3-tuple and every comparison "succeeds".
        if isinstance(other, YAMLObject2):
            return ((self.__class__, self.__dict__) ==
                    (other.__class__, other.__dict__))
        else:
            return False

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; keep the two consistent.
        return not self == other
class AnObject(object):
    # New-style fixture that sets its attributes in __new__ (it has no
    # __init__), so the instance is fully populated by __new__ alone.

    def __new__(cls, foo=None, bar=None, baz=None):
        obj = object.__new__(cls)
        obj.foo, obj.bar, obj.baz = foo, bar, baz
        return obj

    def __cmp__(self, other):
        # Python 2 three-way comparison: type first, then the three fields.
        mine = (type(self), self.foo, self.bar, self.baz)
        theirs = (type(other), other.foo, other.bar, other.baz)
        return cmp(mine, theirs)

    def __eq__(self, other):
        # Objects of different exact types are never equal.
        if type(self) is not type(other):
            return False
        return ((self.foo, self.bar, self.baz) ==
                (other.foo, other.bar, other.baz))
class AnInstance:
    # Classic-class (old-style on Python 2) counterpart of AnObject; its
    # attributes are set in __init__.

    def __init__(self, foo=None, bar=None, baz=None):
        self.foo, self.bar, self.baz = foo, bar, baz

    def __cmp__(self, other):
        # Python 2 three-way comparison: type first, then the three fields.
        mine = (type(self), self.foo, self.bar, self.baz)
        theirs = (type(other), other.foo, other.bar, other.baz)
        return cmp(mine, theirs)

    def __eq__(self, other):
        # Objects of different exact types are never equal.
        if type(self) is not type(other):
            return False
        return ((self.foo, self.bar, self.baz) ==
                (other.foo, other.bar, other.baz))
class AState(AnInstance):
    # Fixture whose state dict uses underscore-prefixed keys that differ
    # from the attribute names, so foo/bar/baz are only restored correctly
    # when __setstate__ is used.

    def __getstate__(self):
        # The closing brace below was missing, which left the dict literal
        # (and the whole class) unparseable.
        return {
            '_foo': self.foo,
            '_bar': self.bar,
            '_baz': self.baz,
        }

    def __setstate__(self, state):
        self.foo = state['_foo']
        self.bar = state['_bar']
        self.baz = state['_baz']
class ACustomState(AnInstance):
    # Fixture whose state is a plain (foo, bar, baz) tuple rather than a dict.

    def __getstate__(self):
        return self.foo, self.bar, self.baz

    def __setstate__(self, state):
        (self.foo, self.bar, self.baz) = state
class InitArgs(AnInstance):
    # Pickle-protocol fixture for classic classes: __getinitargs__ supplies
    # all three __init__ arguments, and the reported state is empty.

    def __getstate__(self):
        return {}

    def __getinitargs__(self):
        return self.foo, self.bar, self.baz
class InitArgsWithState(AnInstance):
    # foo and bar travel as __init__ arguments; baz on its own is the state.

    def __getinitargs__(self):
        return self.foo, self.bar

    def __setstate__(self, state):
        self.baz = state

    def __getstate__(self):
        return self.baz
class NewArgs(AnObject):
    # New-style pickle-protocol fixture: __getnewargs__ supplies all three
    # arguments for AnObject.__new__, and the reported state is empty.

    def __getstate__(self):
        return {}

    def __getnewargs__(self):
        return self.foo, self.bar, self.baz
class NewArgsWithState(AnObject):
    # foo and bar travel as __new__ arguments; baz on its own is the state.

    def __getnewargs__(self):
        return self.foo, self.bar

    def __setstate__(self, state):
        self.baz = state

    def __getstate__(self):
        return self.baz
class Reduce(AnObject):
    # Describes itself purely through __reduce__: the class to call plus the
    # three arguments for that call, with no separate state.

    def __reduce__(self):
        args = (self.foo, self.bar, self.baz)
        return self.__class__, args
class ReduceWithState(AnObject):
    # __reduce__ returns (class, args, state): foo and bar are passed to the
    # class call, and baz comes back through __setstate__.

    def __reduce__(self):
        args = (self.foo, self.bar)
        return (self.__class__, args, self.baz)

    def __setstate__(self, state):
        self.baz = state
class MyInt(int):
    # int subclass fixture; equality additionally requires the exact same
    # type, so a plain int never equals a MyInt.

    def __eq__(self, other):
        return type(self) is type(other) and int(self) == int(other)

    def __ne__(self, other):
        # Without this the inherited int.__ne__ answers, so MyInt(1) != 1
        # was False even though MyInt(1) == 1 is also False.
        return not self == other

    # Overriding __eq__ alone makes the class unhashable on Python 3.  Equal
    # MyInts are equal ints, so reusing the int hash stays consistent.
    __hash__ = int.__hash__
class MyList(list):
    # list subclass fixture pre-filled with n Nones; equality additionally
    # requires the exact same type.

    def __init__(self, n=1):
        self.extend([None]*n)

    def __eq__(self, other):
        return type(self) is type(other) and list(self) == list(other)

    def __ne__(self, other):
        # Without this the inherited list.__ne__ answers, so MyList(2) is
        # "not unequal" to [None, None] while also not equal to it.
        return not self == other
class MyDict(dict):
    # dict subclass fixture mapping 0..n-1 to None; equality additionally
    # requires the exact same type.

    def __init__(self, n=1):
        for k in range(n):
            self[k] = None

    def __eq__(self, other):
        return type(self) is type(other) and dict(self) == dict(other)

    def __ne__(self, other):
        # Without this the inherited dict.__ne__ answers, so MyDict(1) is
        # "not unequal" to {0: None} while also not equal to it.
        return not self == other
243 class TestConstructorTypes(test_appliance.TestAppliance):
245 def _testTypes(self, test_name, data_filename, code_filename):
246 data1 = None
247 data2 = None
248 try:
249 data1 = list(load_all(file(data_filename, 'rb'), Loader=MyLoader))
250 if len(data1) == 1:
251 data1 = data1[0]
252 data2 = eval(file(code_filename, 'rb').read())
253 self.failUnlessEqual(type(data1), type(data2))
254 try:
255 self.failUnlessEqual(data1, data2)
256 except AssertionError:
257 if isinstance(data1, dict):
258 data1 = [(repr(key), value) for key, value in data1.items()]
259 data1.sort()
260 data1 = repr(data1)
261 data2 = [(repr(key), value) for key, value in data2.items()]
262 data2.sort()
263 data2 = repr(data2)
264 if data1 != data2:
265 raise
266 elif isinstance(data1, list):
267 self.failUnlessEqual(type(data1), type(data2))
268 self.failUnlessEqual(len(data1), len(data2))
269 for item1, item2 in zip(data1, data2):
270 if (item1 != item1 or (item1 == 0.0 and item1 == 1.0)) and \
271 (item2 != item2 or (item2 == 0.0 and item2 == 1.0)):
272 continue
273 self.failUnlessEqual(item1, item2)
274 else:
275 raise
276 except:
277 print
278 print "DATA:"
279 print file(data_filename, 'rb').read()
280 print "CODE:"
281 print file(code_filename, 'rb').read()
282 print "NATIVES1:", data1
283 print "NATIVES2:", data2
284 raise
286 TestConstructorTypes.add_tests('testTypes', '.data', '.code')