Add style information to events generated by Parser.
[pyyaml/python3.git] / tests / test_emitter.py
blob2da6696f7b70222a429f32e31a7f6f59fcd6f3f3
2 import test_appliance, sys, StringIO
4 from yaml import *
5 import yaml
class TestEmitterOnCanonical(test_appliance.TestAppliance):
    """Round-trip test: parse a '.canonical' file, emit it, parse the result.

    The re-parsed event stream must contain the same number of events as the
    original, and each event must have the same class as its counterpart.
    Only event classes are compared, not anchors, tags or values.
    """

    def _testEmitterOnCanonical(self, test_name, canonical_filename):
        # test_name is unused here; presumably it is part of the signature
        # that add_tests expects -- TODO confirm against test_appliance.
        # Build the full event list while the file is open, then close it
        # explicitly so no file handle is leaked.
        source = open(canonical_filename, 'rb')
        try:
            events = list(Parser(Scanner(Reader(source))))
        finally:
            source.close()
        # Emit every event into an in-memory buffer.
        writer = StringIO.StringIO()
        emitter = Emitter(writer)
        for event in events:
            emitter.emit(event)
        data = writer.getvalue()
        # Parse the emitted text back and compare it against the original.
        new_events = list(parse(data))
        self.failUnlessEqual(len(events), len(new_events))
        for event, new_event in zip(events, new_events):
            self.failUnlessEqual(event.__class__, new_event.__class__)


# Presumably generates one test method per '.canonical' fixture that calls
# _testEmitterOnCanonical -- see test_appliance.TestAppliance.add_tests.
TestEmitterOnCanonical.add_tests('testEmitterOnCanonical', '.canonical')
class EventsConstructor(Constructor):
    """Constructor that builds yaml event objects from a YAML description.

    A node's tag, with its first character dropped and 'Event' appended,
    names an event class looked up on the ``yaml`` module. For example, a
    tag such as ``!Scalar`` resolves to ``yaml.ScalarEvent``. Mapping nodes
    supply the keyword arguments for that class; scalar nodes supply none.
    """

    def construct_event(self, node):
        """Return an instance of the event class named by ``node.tag``.

        Keyword arguments the description leaves out are filled with
        defaults: ``anchor`` (alias/scalar/collection-start events), ``tag``
        (scalar/collection-start events) and ``value`` (scalar events).
        """
        if isinstance(node, ScalarNode):
            kwargs = {}
        else:
            kwargs = self.construct_mapping(node)
        event_name = str(node.tag[1:]) + 'Event'
        # (argument name, default value, event classes that take it)
        optional_args = (
            ('anchor', None, ('AliasEvent', 'ScalarEvent',
                              'SequenceStartEvent', 'MappingStartEvent')),
            ('tag', None, ('ScalarEvent', 'SequenceStartEvent',
                           'MappingStartEvent')),
            ('value', '', ('ScalarEvent',)),
        )
        for arg_name, default, accepting_events in optional_args:
            if event_name in accepting_events:
                kwargs.setdefault(arg_name, default)
        event_class = getattr(yaml, event_name)
        return event_class(**kwargs)


# Register construct_event under the tag key None. This is presumably the
# fallback used for every tag -- confirm against Constructor.add_constructor.
EventsConstructor.add_constructor(None, EventsConstructor.construct_event)
class TestEmitter(test_appliance.TestAppliance):
    """Round-trip test driven by '.events' fixtures.

    Each fixture is loaded with EventsConstructor into a list of event
    objects. The events are emitted, the output is parsed again, and the
    re-parsed events must match the originals in number and class.
    """

    def _testEmitter(self, test_name, events_filename):
        # test_name is unused here; presumably it is part of the signature
        # that add_tests expects -- TODO confirm against test_appliance.
        # load_document returns a fully built list (it is measured with
        # len() and iterated more than once below), so the file can be
        # closed as soon as it returns.
        source = open(events_filename, 'rb')
        try:
            events = load_document(source, Constructor=EventsConstructor)
        finally:
            source.close()
        # For debugging, call self._dump(events_filename, events) here to
        # print the fixture and the emitted text; it is off by default so
        # normal test runs stay quiet.
        writer = StringIO.StringIO()
        emitter = Emitter(writer)
        for event in events:
            emitter.emit(event)
        data = writer.getvalue()
        new_events = list(parse(data))
        self.failUnlessEqual(len(events), len(new_events))
        for event, new_event in zip(events, new_events):
            self.failUnlessEqual(event.__class__, new_event.__class__)

    def _dump(self, events_filename, events):
        """Debug aid: print the fixture text, then emit ``events`` to stdout."""
        writer = sys.stdout
        emitter = Emitter(writer)
        print("=" * 30)
        print("EVENTS:")
        source = open(events_filename, 'rb')
        try:
            print(source.read())
        finally:
            source.close()
        print('-' * 30)
        print("OUTPUT:")
        for event in events:
            emitter.emit(event)


# Presumably generates one test method per '.events' fixture that calls
# _testEmitter -- see test_appliance.TestAppliance.add_tests.
TestEmitter.add_tests('testEmitter', '.events')