Update version number and release date.
[python/dscho.git] / Lib / idlelib / rpc.py
blob15946a660ffef65d75d3617abc4342f3ab849cd7
"""RPC Implementation, originally written for the Python Idle IDE

For security reasons, GvR requested that Idle's Python execution server process
connect to the Idle process, which listens for the connection.  Since Idle has
only one client per server, this was not a limitation.
7 +---------------------------------+ +-------------+
8 | SocketServer.BaseRequestHandler | | SocketIO |
9 +---------------------------------+ +-------------+
10 ^ | register() |
11 | | unregister()|
12 | +-------------+
13 | ^ ^
14 | | |
15 | + -------------------+ |
16 | | |
17 +-------------------------+ +-----------------+
18 | RPCHandler | | RPCClient |
19 | [attribute of RPCServer]| | |
20 +-------------------------+ +-----------------+
22 The RPCServer handler class is expected to provide register/unregister methods.
23 RPCHandler inherits the mix-in class SocketIO, which provides these methods.
25 See the Idle run.main() docstring for further information on how this was
26 accomplished in Idle.
28 """
30 import sys
31 import socket
32 import select
33 import SocketServer
34 import struct
35 import cPickle as pickle
36 import threading
37 import traceback
38 import copy_reg
39 import types
40 import marshal
def unpickle_code(ms):
    """Rebuild and return the code object serialized (by marshal) in *ms*.

    This is the reconstructor that pickle_code() hands to the pickler.
    """
    code_obj = marshal.loads(ms)
    assert isinstance(code_obj, types.CodeType)
    return code_obj
def pickle_code(co):
    """copy_reg reducer for code objects.

    Returns the (callable, args) pair stored by the pickler: unpickle_code
    applied to the marshal serialization of *co*.
    """
    assert isinstance(co, types.CodeType)
    return (unpickle_code, (marshal.dumps(co),))
# XXX KBK 24Aug02 function pickling capability not used in Idle
# (NOTE(review): if revived, `type.FunctionType` below should read
# `types.FunctionType`.)
# def unpickle_function(ms):
#     return ms

# def pickle_function(fn):
#     assert isinstance(fn, type.FunctionType)
#     return `fn`

# Register the reducer/reconstructor pair defined above so that code objects
# can be pickled, and therefore carried inside RPC messages (which are
# pickled by SocketIO.putmessage()).
copy_reg.pickle(types.CodeType, pickle_code, unpickle_code)
# copy_reg.pickle(types.FunctionType, pickle_function, unpickle_function)

# Maximum number of bytes requested by each socket recv() in
# SocketIO.pollpacket().
BUFSIZE = 8*1024
class RPCServer(SocketServer.TCPServer):
    """TCPServer whose socket connects out to its peer instead of listening.

    The execution-server side uses it to connect to the RPCClient that is
    listening in the Idle process, then serves requests arriving over that
    single connection.
    """

    def __init__(self, addr, handlerclass=None):
        chosen = handlerclass
        if chosen is None:
            # Looked up at call time, so the later module-level definition
            # of RPCHandler is visible here.
            chosen = RPCHandler
        SocketServer.TCPServer.__init__(self, addr, chosen)

    def server_bind(self):
        "Override TCPServer method: the connecting side has nothing to bind"
        return None

    def server_activate(self):
        """Override TCPServer method: connect() to the peer instead of listen().

        Because the connection is reversed, self.server_address is really the
        address of the listening Idle client we connect to.
        """
        peer = self.server_address
        self.socket.connect(peer)

    def get_request(self):
        "Override TCPServer method: hand back the already connected socket"
        request_sock = self.socket
        return request_sock, self.server_address

    def handle_error(self, request, client_address):
        """Override TCPServer method: report an unexpected error and exit.

        Must be called while an exception is being handled, because it
        re-raises that exception to classify it.  SystemExit propagates,
        EOFError is ignored quietly, and anything else is reported on
        sys.__stderr__ before the process exits via os._exit(0).
        """
        try:
            raise
        except SystemExit:
            raise
        except EOFError:
            pass
        except:
            erf = sys.__stderr__
            separator = '-' * 40
            print>>erf, '\n' + separator
            print>>erf, 'Unhandled server exception!'
            print>>erf, 'Thread: %s' % threading.currentThread().getName()
            print>>erf, 'Client Address: ', client_address
            print>>erf, 'Request: ', repr(request)
            traceback.print_exc(file=erf)
            print>>erf, '\n*** Unrecoverable, server exiting!'
            print>>erf, separator
            import os
            os._exit(0)
# Module-wide default registry mapping object ids to local objects that the
# peer may call through SocketIO.localcall().  SocketIO instances created
# without an explicit objtable share this dict, and remoteref() adds entries
# to it (entries are not removed there).
objecttable = {}
class SocketIO:
    """Bidirectional RPC endpoint over a connected stream socket.

    Wire format: each message is a pickled (seq, payload) tuple preceded by
    its length as a 4-byte little-endian int ("<i").  A payload is either a
    request ("call", (oid, methodname, args, kwargs)) or a response
    (how, what) where how is "OK", "EXCEPTION" or "ERROR".

    Threading: the thread that constructs the instance (self.mainthread)
    does all reading from the socket.  Other threads that make remote calls
    sleep on self.cvar until the main thread files their response in
    self.responses.

    Subclasses supply `location`, a tag used in debug output.
    """

    # Sequence numbers advance by 2 per request; RPCClient starts at 1, so
    # requests it originates are odd while those from the other end are even.
    nextseq = 0
    # Default so debug() works even when neither a subclass nor the
    # constructor provides a value.
    debugging = False

    def __init__(self, sock, objtable=None, debugging=None):
        self.mainthread = threading.currentThread()
        if debugging is not None:
            self.debugging = debugging
        self.sock = sock
        if objtable is None:
            objtable = objecttable
        self.objtable = objtable
        # self.cvar guards both self.responses and self.cvars.
        self.cvar = threading.Condition()
        self.responses = {}  # seq -> response awaiting its (non-main) thread
        self.cvars = {}      # seq -> condition its waiting thread sleeps on
        self.interrupted = False

    def close(self):
        """Close the socket; later putmessage() calls raise IOError."""
        sock = self.sock
        self.sock = None
        if sock is not None:
            sock.close()

    def debug(self, *args):
        """Write location, thread name and args to sys.__stderr__ if debugging."""
        if not self.debugging:
            return
        parts = [self.location, str(threading.currentThread().getName())]
        parts.extend([str(a) for a in args])
        sys.__stderr__.write(" ".join(parts) + "\n")

    def register(self, oid, object):
        """Make *object* callable by the peer under object id *oid*."""
        self.objtable[oid] = object

    def unregister(self, oid):
        """Forget *oid*; unknown ids are ignored."""
        try:
            del self.objtable[oid]
        except KeyError:
            pass

    def localcall(self, request):
        """Dispatch a "call" request to a registered local object.

        Returns a (how, what) response: ("OK", result) on success,
        ("ERROR", message) for a malformed request or an unknown object or
        method, and ("EXCEPTION", None) if the method raised (its traceback
        is printed to sys.__stderr__).  SystemExit and socket.error raised
        by the method propagate to the caller.
        """
        self.debug("localcall:", request)
        try:
            how, (oid, methodname, args, kwargs) = request
        except (TypeError, ValueError):
            # TypeError: not unpackable at all; ValueError: wrong item count.
            return ("ERROR", "Bad request format")
        assert how == "call"
        if oid not in self.objtable:
            return ("ERROR", "Unknown object id: %s" % repr(oid))
        obj = self.objtable[oid]
        if methodname == "__methods__":
            methods = {}
            _getmethods(obj, methods)
            return ("OK", methods)
        if methodname == "__attributes__":
            attributes = {}
            _getattributes(obj, attributes)
            return ("OK", attributes)
        if not hasattr(obj, methodname):
            return ("ERROR", "Unsupported method name: %s" % repr(methodname))
        method = getattr(obj, methodname)
        try:
            ret = method(*args, **kwargs)
            if isinstance(ret, RemoteObject):
                ret = remoteref(ret)
            return ("OK", ret)
        except SystemExit:
            raise
        except socket.error:
            # Swallowing this used to return None, which pollresponse() then
            # sent as an undecodable (seq, None) reply.  Propagate instead.
            raise
        except:
            self.debug("localcall:EXCEPTION")
            traceback.print_exc(file=sys.__stderr__)
            return ("EXCEPTION", None)

    def remotecall(self, oid, methodname, args, kwargs):
        """Call methodname on the peer's object *oid* and return the result."""
        self.debug("remotecall:asynccall: ", oid, methodname)
        # XXX KBK 06Feb03 self.interrupted logic may not be necessary if
        # subprocess is threaded.
        if self.interrupted:
            self.interrupted = False
            raise KeyboardInterrupt
        seq = self.asynccall(oid, methodname, args, kwargs)
        return self.asyncreturn(seq)

    def asynccall(self, oid, methodname, args, kwargs):
        """Send a call request without waiting; return its sequence number.

        When called from a thread other than self.mainthread, the sequence
        number is registered in self.cvars here, so the caller must follow
        up with asyncreturn(seq) to collect (and unregister) the response.
        """
        request = ("call", (oid, methodname, args, kwargs))
        self.cvar.acquire()
        try:
            seq = self.newseq()
            if threading.currentThread() is not self.mainthread:
                # Register interest *before* the request goes out:
                # pollresponse() discards responses whose seq has no waiter,
                # so registering after sending could lose a fast reply and
                # leave this thread waiting forever.
                self.cvars[seq] = self.cvar
        finally:
            self.cvar.release()
        self.debug(("asynccall:%d:" % seq), oid, methodname, args, kwargs)
        self.putmessage((seq, request))
        return seq

    def asyncreturn(self, seq):
        """Block until the response for *seq* arrives, then decode it."""
        self.debug("asyncreturn:%d:call getresponse(): " % seq)
        response = self.getresponse(seq, wait=None)
        self.debug(("asyncreturn:%d:response: " % seq), response)
        return self.decoderesponse(response)

    def decoderesponse(self, response):
        """Turn a (how, what) response into a return value or an exception.

        "OK" -> what; "EXCEPTION" -> None (the peer already printed the
        traceback); "ERROR" -> RuntimeError(what); anything else ->
        SystemError(how, what).
        """
        how, what = response
        if how == "OK":
            return what
        if how == "EXCEPTION":
            self.debug("decoderesponse: EXCEPTION")
            return None
        if how == "ERROR":
            self.debug("decoderesponse: Internal ERROR:", what)
            raise RuntimeError(what)
        raise SystemError(how, what)

    def mainloop(self):
        """Listen on socket until I/O not ready or EOF

        Main thread pollresponse() will loop looking for seq number None, which
        never comes, and exit on EOFError.

        """
        try:
            self.getresponse(myseq=None, wait=None)
        except EOFError:
            pass

    def getresponse(self, myseq, wait):
        """Return the response for *myseq*, turning RemoteProxy tokens in an
        "OK" result into RPCProxy objects bound to this connection."""
        response = self._getresponse(myseq, wait)
        if response is not None:
            how, what = response
            if how == "OK":
                response = how, self._proxify(what)
        return response

    def _proxify(self, obj):
        """Replace RemoteProxy tokens (also inside a list) with RPCProxies."""
        if isinstance(obj, RemoteProxy):
            return RPCProxy(self, obj.oid)
        if isinstance(obj, list):
            return [self._proxify(item) for item in obj]
        # XXX Check for other types -- not currently needed
        return obj

    def _getresponse(self, myseq, wait):
        self.debug("_getresponse:myseq:", myseq)
        if threading.currentThread() is self.mainthread:
            # Main thread: does all reading of requests or responses.
            # Loop here, blocking each time until socket is ready.
            while 1:
                response = self.pollresponse(myseq, wait)
                if response is not None:
                    return response
        else:
            # Auxiliary thread: wait for the main thread to deliver.
            self.cvar.acquire()
            try:
                # Normally already registered by asynccall(); make sure.
                self.cvars[myseq] = self.cvar
                while myseq not in self.responses:
                    self.cvar.wait()
                response = self.responses[myseq]
                del self.responses[myseq]
                del self.cvars[myseq]
            finally:
                self.cvar.release()
            return response

    def newseq(self):
        """Advance and return this endpoint's request sequence number."""
        self.nextseq = seq = self.nextseq + 2
        return seq

    def putmessage(self, message):
        """Pickle *message*, frame it with its length and send all of it.

        Raises IOError if the socket has been closed.
        """
        self.debug("putmessage:%d:" % message[0])
        try:
            s = pickle.dumps(message)
        except:
            sys.__stderr__.write("Cannot pickle: %s\n" % repr(message))
            raise
        s = struct.pack("<i", len(s)) + s
        while len(s) > 0:
            try:
                n = self.sock.send(s)
            except AttributeError:
                # socket was closed (close() set self.sock to None)
                raise IOError("socket no longer exists")
            else:
                s = s[n:]

    def ioready(self, wait=0.0):
        """Return a true value if the socket is readable within *wait* seconds
        (wait=None blocks until it is)."""
        r, w, x = select.select([self.sock.fileno()], [], [], wait)
        return len(r)

    # Incoming-data reassembly state (see _stage0/_stage1).
    buffer = ""
    bufneed = 4
    bufstate = 0 # meaning: 0 => reading count; 1 => reading data

    def pollpacket(self, wait=0.0):
        """Return the next complete packet, or None if none is available yet.

        Reads from the socket at most once.  Raises EOFError when the peer
        closes the connection or recv() fails.
        """
        self._stage0()
        if len(self.buffer) < self.bufneed:
            if not self.ioready(wait):
                return None
            try:
                s = self.sock.recv(BUFSIZE)
            except socket.error:
                raise EOFError
            if len(s) == 0:
                raise EOFError
            self.buffer += s
            self._stage0()
        return self._stage1()

    def _stage0(self):
        # Once the 4-byte length prefix is buffered, consume it and switch
        # to collecting that many data bytes.
        if self.bufstate == 0 and len(self.buffer) >= 4:
            s = self.buffer[:4]
            self.buffer = self.buffer[4:]
            self.bufneed = struct.unpack("<i", s)[0]
            self.bufstate = 1

    def _stage1(self):
        # Once the whole body is buffered, split it off and reset to
        # waiting for the next length prefix.  Returns None otherwise.
        if self.bufstate == 1 and len(self.buffer) >= self.bufneed:
            packet = self.buffer[:self.bufneed]
            self.buffer = self.buffer[self.bufneed:]
            self.bufneed = 4
            self.bufstate = 0
            return packet

    def pollmessage(self, wait=0.0):
        """Return the next unpickled message, or None if none is ready."""
        packet = self.pollpacket(wait)
        if packet is None:
            return None
        try:
            # NOTE(security): unpickling can execute arbitrary code; this is
            # only acceptable because the peer is trusted (RPCClient.accept()
            # rejects connections from hosts other than 127.0.0.1).
            message = pickle.loads(packet)
        except:
            err = sys.__stderr__
            err.write("-----------------------\n")
            err.write("cannot unpickle packet: %s\n" % repr(packet))
            traceback.print_stack(file=err)
            err.write("-----------------------\n")
            raise
        return message

    def pollresponse(self, myseq, wait=0.0):
        """Handle messages received on the socket.

        Some messages received may be asynchronous 'call' commands, and
        some may be responses intended for other threads.

        Loop until message with myseq sequence number is received.  Save others
        in self.responses and notify the owning thread, except that 'call'
        commands are handed off to localcall() and the response sent back
        across the link with the appropriate sequence number.

        """
        while 1:
            message = self.pollmessage(wait)
            if message is None: # socket not ready
                return None
            #wait = 0.0  # poll on subsequent passes instead of blocking
            seq, resq = message
            self.debug("pollresponse:%d:myseq:%s" % (seq, myseq))
            if resq[0] == "call":
                self.debug("pollresponse:%d:localcall:call:" % seq)
                response = self.localcall(resq)
                self.debug("pollresponse:%d:localcall:response:%s"
                           % (seq, response))
                self.putmessage((seq, response))
                continue
            elif seq == myseq:
                return resq
            else:
                self.cvar.acquire()
                try:
                    cv = self.cvars.get(seq)
                    # response involving unknown sequence number is discarded,
                    # probably intended for prior incarnation
                    if cv is not None:
                        self.responses[seq] = resq
                        # Every waiting thread shares this one condition, so
                        # wake them all; each re-checks for its own seq.
                        cv.notifyAll()
                finally:
                    self.cvar.release()
                continue

#----------------- end class SocketIO --------------------
class RemoteObject:
    """Marker mix-in: when a local method returns an instance of this class,
    localcall() sends it back as a remote reference (via remoteref())
    instead of pickling it by value."""
def remoteref(obj):
    """Register *obj* in the module-wide objecttable under id(obj) and
    return a RemoteProxy token carrying that id.

    The entry is not removed here, so it keeps *obj* alive.
    """
    key = id(obj)
    objecttable[key] = obj
    return RemoteProxy(key)
class RemoteProxy:
    """Picklable token standing in for a registered object.

    It carries only the object id; SocketIO._proxify() turns it into an
    RPCProxy when it arrives on the other side.
    """

    def __init__(self, oid):
        self.oid = oid
class RPCHandler(SocketServer.BaseRequestHandler, SocketIO):
    """Request handler created by RPCServer; speaks the SocketIO protocol.

    SocketIO state is initialized before BaseRequestHandler.__init__ runs,
    presumably because (per the SocketServer docs) that constructor is what
    dispatches to handle().
    """

    location = "#S"  # Server
    debugging = False

    def __init__(self, sock, addr, svr):
        # Expose this handler on the server so code holding the server can
        # obtain proxies for the peer's objects. ## cgt xxx
        svr.current_handler = self
        SocketIO.__init__(self, sock)
        SocketServer.BaseRequestHandler.__init__(self, sock, addr, svr)

    def handle(self):
        "handle() method required by SocketServer: serve until EOF"
        self.mainloop()

    def get_remote_proxy(self, oid):
        """Return an RPCProxy for the peer's object *oid*."""
        proxy = RPCProxy(self, oid)
        return proxy
class RPCClient(SocketIO):
    """Listening end of the reversed connection, used in the Idle process.

    The constructor only binds and listens.  accept() must be called to wait
    for the execution server to connect and to initialize the SocketIO state.
    """

    debugging = False
    location = "#C" # Client

    nextseq = 1 # Requests coming from the client are odd numbered

    def __init__(self, address, family=socket.AF_INET, type=socket.SOCK_STREAM):
        sock = socket.socket(family, type)
        try:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.bind(address)
            sock.listen(1)
        except:
            # Don't leak the socket if setup fails (e.g. address in use).
            sock.close()
            raise
        self.listening_sock = sock

    def accept(self):
        """Accept one connection and initialize SocketIO on it.

        Only connections from 127.0.0.1 are allowed.  Any other peer is
        reported on sys.__stderr__, its socket is closed, and socket.error
        is raised.
        """
        working_sock, address = self.listening_sock.accept()
        if self.debugging:
            print>>sys.__stderr__, "****** Connection request from ", address
        if address[0] == '127.0.0.1':
            SocketIO.__init__(self, working_sock)
        else:
            print>>sys.__stderr__, "** Invalid host: ", address
            # Previously the rejected connection was left open (leaked).
            working_sock.close()
            raise socket.error

    def get_remote_proxy(self, oid):
        """Return an RPCProxy for the peer's object *oid*."""
        return RPCProxy(self, oid)
class RPCProxy:
    """Local stand-in for the remote object *oid* reachable through *sockio*.

    Attribute access is resolved lazily.  The remote object's method and
    attribute names are fetched once, via the special "__methods__" and
    "__attributes__" calls, and cached on the proxy.
    """

    __methods = None
    __attributes = None

    def __init__(self, sockio, oid):
        self.sockio = sockio
        self.oid = oid

    def __getattr__(self, name):
        # Only reached for names not found normally; the name-mangled
        # class-level caches above are always found, so no recursion.
        if self.__methods is None:
            self.__getmethods()
        if self.__methods.get(name):
            return MethodProxy(self.sockio, self.oid, name)
        if self.__attributes is None:
            self.__getattributes()
        if name not in self.__attributes:
            raise AttributeError(name)
        # A known data attribute: fetch its current value from the peer
        # rather than silently returning None.
        # NOTE(review): this needs the remote object to provide
        # __getattribute__ (new-style objects do).  Otherwise the peer's
        # localcall() answers "Unsupported method name" and this raises
        # RuntimeError.
        return self.sockio.remotecall(self.oid, "__getattribute__",
                                      (name,), {})
    __getattr__.DebuggerStepThrough=1

    def __getattributes(self):
        self.__attributes = self.sockio.remotecall(self.oid,
                                                   "__attributes__", (), {})

    def __getmethods(self):
        self.__methods = self.sockio.remotecall(self.oid,
                                                "__methods__", (), {})
def _getmethods(obj, methods):
    """Record in dict *methods* (name -> 1) every callable attribute of *obj*.

    For an instance of a classic class (types.InstanceType) its class is
    scanned as well, and for a classic class (types.ClassType) each base
    class is scanned recursively.
    """
    for attr_name in dir(obj):
        if callable(getattr(obj, attr_name)):
            methods[attr_name] = 1
    kind = type(obj)
    if kind == types.InstanceType:
        _getmethods(obj.__class__, methods)
    if kind == types.ClassType:
        for base in obj.__bases__:
            _getmethods(base, methods)
491 def _getattributes(obj, attributes):
492 for name in dir(obj):
493 attr = getattr(obj, name)
494 if not callable(attr):
495 attributes[name] = 1
class MethodProxy:
    """Callable bound to one method name on a remote object.

    Calling it forwards the positional and keyword arguments through
    sockio.remotecall() and returns whatever that returns.
    """

    def __init__(self, sockio, oid, name):
        self.sockio, self.oid, self.name = sockio, oid, name

    def __call__(self, *args, **kwargs):
        return self.sockio.remotecall(self.oid, self.name, args, kwargs)
509 # Self Test
512 def testServer(addr):
513 # XXX 25 Jul 02 KBK needs update to use rpc.py register/unregister methods
514 class RemotePerson:
515 def __init__(self,name):
516 self.name = name
517 def greet(self, name):
518 print "(someone called greet)"
519 print "Hello %s, I am %s." % (name, self.name)
520 print
521 def getName(self):
522 print "(someone called getName)"
523 print
524 return self.name
525 def greet_this_guy(self, name):
526 print "(someone called greet_this_guy)"
527 print "About to greet %s ..." % name
528 remote_guy = self.server.current_handler.get_remote_proxy(name)
529 remote_guy.greet("Thomas Edison")
530 print "Done."
531 print
533 person = RemotePerson("Thomas Edison")
534 svr = RPCServer(addr)
535 svr.register('thomas', person)
536 person.server = svr # only required if callbacks are used
538 # svr.serve_forever()
539 svr.handle_request() # process once only
541 def testClient(addr):
542 "demonstrates RPC Client"
543 # XXX 25 Jul 02 KBK needs update to use rpc.py register/unregister methods
544 import time
545 clt=RPCClient(addr)
546 thomas = clt.get_remote_proxy("thomas")
547 print "The remote person's name is ..."
548 print thomas.getName()
549 # print clt.remotecall("thomas", "getName", (), {})
550 print
551 time.sleep(1)
552 print "Getting remote thomas to say hi..."
553 thomas.greet("Alexander Bell")
554 #clt.remotecall("thomas","greet",("Alexander Bell",), {})
555 print "Done."
556 print
557 time.sleep(2)
558 # demonstrates remote server calling local instance
559 class LocalPerson:
560 def __init__(self,name):
561 self.name = name
562 def greet(self, name):
563 print "You've greeted me!"
564 def getName(self):
565 return self.name
566 person = LocalPerson("Alexander Bell")
567 clt.register("alexander",person)
568 thomas.greet_this_guy("alexander")
569 # clt.remotecall("thomas","greet_this_guy",("alexander",), {})
def test():
    """Run the self test on localhost:8833.

    With the single argument '-server' this runs the server half; anything
    else runs the client half.
    """
    addr = ("localhost", 8833)
    if len(sys.argv) == 2 and sys.argv[1] == '-server':
        testServer(addr)
    else:
        testClient(addr)

if __name__ == '__main__':
    test()