import socket
import threading
import struct
import json
import traceback
import zlib
# Protocol version markers accepted during the connection handshake.
# This Python implementation fully supports protocol version 4 and is
# partially compatible with version 5.
versionMarkers = [0xDEBC0004, 0xDEBC0005]
def _socketReadBytes(socket, bytesCount):
    """Read exactly ``bytesCount`` bytes from ``socket``.

    ``recv`` may hand back fewer bytes than requested, so it is called
    repeatedly until the requested amount has been collected.

    Args:
        socket: connected socket-like object providing ``recv(n)``.
        bytesCount: number of bytes to read.

    Returns:
        A byte string of length ``bytesCount`` (empty if ``bytesCount`` is 0
        or negative).

    Raises:
        AssertionError: the peer closed the connection before all requested
            bytes arrived.
    """
    chunks = []
    remaining = bytesCount
    while remaining > 0:
        data = socket.recv(remaining)
        if not data:
            # An empty read means the peer closed the stream. Raised
            # explicitly instead of via ``assert`` so the check survives
            # ``python -O``; otherwise this loop would spin forever.
            # AssertionError is kept because callers saw that type before.
            raise AssertionError(
                'Connection closed with %d of %d bytes still unread'
                % (remaining, bytesCount))
        chunks.append(data)
        remaining -= len(data)
    # Join once at the end instead of growing a string with ``+=``.
    return b''.join(chunks)
def _socketSend(socket, obj):
    """Serialize ``obj`` as JSON and send it as one length-prefixed frame.

    Frame layout (all integers little-endian, unsigned 32-bit):
        1. length of the uncompressed JSON payload,
        2. length of the compressed payload, or 0 if sent uncompressed,
        3. the payload: raw deflate data if it is smaller than the JSON
           text, otherwise the JSON text itself.

    Args:
        socket: connected socket-like object providing ``sendall(data)``.
        obj: any JSON-serializable value.

    Raises:
        TypeError / ValueError: ``obj`` cannot be serialized by ``json.dumps``.
        Whatever ``socket.sendall`` raises on a broken connection.
    """
    data = json.dumps(obj)
    if not isinstance(data, bytes):
        # On Python 3 json.dumps returns text; lengths and compression must
        # operate on bytes. On Python 2 ``str`` already is ``bytes``.
        data = data.encode('utf-8')

    # wbits=-15 produces a raw deflate stream (no zlib header / checksum
    # trailer), which is what the reader inflates with wbits=-15.
    compressor = zlib.compressobj(zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -15)
    zdata = compressor.compress(data) + compressor.flush()

    if len(zdata) < len(data):
        socket.sendall(struct.pack('<II', len(data), len(zdata)) + zdata)
    else:
        # Compression did not help: a compressed size of 0 marks a plain payload.
        socket.sendall(struct.pack('<II', len(data), 0) + data)
def _socketRead(socket):
    """Read one length-prefixed frame from ``socket`` and decode its JSON.

    The frame starts with two little-endian unsigned 32-bit integers: the
    uncompressed payload size and the compressed payload size. A compressed
    size of 0 means the payload follows as plain JSON; otherwise the payload
    is a raw deflate stream that must inflate to the announced size.

    Args:
        socket: connected socket-like object providing ``recv(n)``.

    Returns:
        The decoded JSON value.

    Raises:
        AssertionError: the inflated payload does not match the announced
            size (plus whatever ``_socketReadBytes`` raises on a short read).
        zlib.error: the compressed payload is corrupt.
        ValueError: the payload is not valid JSON.
    """
    dataSize, zdataSize = struct.unpack('<II', _socketReadBytes(socket, 8))
    if zdataSize > 0:
        # wbits=-15: the stream carries no zlib header or checksum trailer.
        data = zlib.decompress(_socketReadBytes(socket, zdataSize), -15)
        if len(data) != dataSize:
            # Raised explicitly so the integrity check survives ``python -O``;
            # AssertionError is the type the former ``assert`` produced.
            raise AssertionError(
                'Inflated payload is %d bytes but the frame announced %d'
                % (len(data), dataSize))
    else:
        data = _socketReadBytes(socket, dataSize)
    return json.loads(data)
class Connection:
    """Two-socket request/response link to a single peer.

    ``inSocket`` carries requests issued by the peer; a background thread
    reads them, dispatches each one to the handler registered for its
    ``cmd`` and writes the handler's return value back on the same socket.
    ``outSocket`` carries requests issued locally through ``send`` together
    with their replies.
    """

    def __init__(self, inSocket, outSocket, handlers=None, debugName=''):
        """Wrap the two sockets and start the request-reading thread.

        Args:
            inSocket: socket on which the peer sends requests.
            outSocket: socket used by ``send`` for local requests.
            handlers: optional mapping ``cmd -> callable(args)``; it is
                copied, so later changes to the caller's mapping have no
                effect on this connection.
            debugName: text included in the reader thread's name.
        """
        self._inSocket = inSocket
        self._outSocket = outSocket
        # Copy so neither a shared default nor the caller's dict is aliased.
        self._handlers = dict(handlers) if handlers else {}
        self._dead = False
        # Serializes send(): each request must be paired with its own reply.
        self._lock = threading.Lock()
        threading.Thread(
            target=self._threadFn,
            name='Connection reading thread: "%s"' % debugName).start()

    def isAlive(self):
        """Return True while the connection has not been closed or broken."""
        return not self._dead

    def isDead(self):
        """Return True once the connection has been closed or has broken."""
        return self._dead

    def close(self):
        """Mark the connection dead and close both sockets (idempotent)."""
        if self._dead:
            return
        self._dead = True
        try:
            self._inSocket.close()
        finally:
            # Close the second socket even if closing the first one failed.
            self._outSocket.close()

    def send(self, cmd, args):
        """Send a synchronous request to the peer and wait for its reply.

        Args:
            cmd: command name understood by the peer.
            args: JSON-serializable arguments for the command.

        Returns:
            The decoded reply, or None if the exchange failed; in that case
            the connection is closed and marked dead.
        """
        with self._lock:
            try:
                _socketSend(self._outSocket,
                            {'cmd': cmd, 'args': args, 'type': 'syncRequest'})
                return _socketRead(self._outSocket)
            except Exception as e:
                print('Outgoing conenction broken %s %s' % (type(e), e))
                self.close()
                return None

    def addHandler(self, cmd, handler):
        """Register ``handler`` (``callable(args)``) for requests named ``cmd``.

        Registering a second handler for the same ``cmd`` is treated as a
        programming error and trips the assertion.
        """
        assert cmd not in self._handlers, 'Handler already installed'
        self._handlers[cmd] = handler

    def _handleRequest(self, cmd, args):
        """Run the handler for ``cmd`` and return its result.

        Returns None when no handler is installed for ``cmd`` or when the
        handler raises; both situations are reported on stdout/stderr
        instead of being propagated, so one bad request does not tear down
        the connection.
        """
        try:
            handler = self._handlers[cmd]
        except (KeyError, TypeError):
            # TypeError: the peer sent an unhashable ``cmd`` (e.g. a list).
            # ``(args,)`` keeps the formatting safe for tuple arguments.
            print('Unable to handle request: %s' % (args,))
            return None
        try:
            return handler(args)
        except Exception:
            print('Error handling request: %s,' % (args,))
            traceback.print_exc()
            return None

    def _threadFn(self):
        """Reader-thread body: serve peer requests until the socket fails."""
        try:
            while True:
                request = _socketRead(self._inSocket)
                response = self._handleRequest(request['cmd'], request['args'])
                _socketSend(self._inSocket, response)
        except Exception as e:
            print('Incoming connection broken %s %s' % (type(e), e))
        finally:
            # The loop only ends through an exception; always release both
            # sockets so the outgoing side does not linger half-open.
            self.close()
class Listener:
    """TCP server that pairs two client sockets into one ``Connection``.

    Every client opens two sockets. Each one performs a handshake (version
    marker exchange followed by a ``login`` message naming the client and
    the channel type). The socket the client calls ``'in'`` is parked until
    the matching ``'out'`` socket from the same client name and host
    arrives; the pair is then queued for ``getConnection``.
    """

    def __init__(self, address):
        """Bind to ``address`` (``(host, port)``) and start accepting clients.

        Raises:
            Whatever the socket layer raises when binding or listening
            fails; the listening socket is closed before re-raising.
        """
        # (client name, client host) -> parked 'in' socket awaiting its 'out'.
        self._pendingConnections = {}
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self._socket.bind(address)
            self._socket.listen(16)
        except BaseException:
            self._socket.close()
            raise
        # Fully paired connections waiting to be picked up by getConnection().
        self._handshakedConnetionsData = []
        self._handshakedConnetionsDataCondition = threading.Condition()
        threading.Thread(
            target=self._listeningThread,
            name='Listening thread: %s:%s' % address).start()
        print('Server started at %s:%s' % address)

    def _readLogin(self, clientSocket):
        """Run the version exchange and read the login message.

        Returns:
            ``(name, connectionType)`` with ``connectionType`` being
            ``'in'`` or ``'out'``.

        Raises:
            AssertionError: unsupported version marker, a first message that
                is not ``login``, or an unknown connection type.
        """
        remoteVersionMarker, = struct.unpack(
            '<I', _socketReadBytes(clientSocket, 4))
        if remoteVersionMarker not in versionMarkers:
            # Explicit raise: validation of peer input must survive
            # ``python -O``. The supported markers are pre-formatted because
            # ``%X`` cannot format the joined string.
            raise AssertionError(
                'Version mismatch, expected one of: [%s], got %X'
                % (', '.join('%X' % marker for marker in versionMarkers),
                   remoteVersionMarker))
        # Echo the marker back to accept the client's protocol version.
        clientSocket.sendall(struct.pack('<I', remoteVersionMarker))
        login = _socketRead(clientSocket)
        if login['cmd'] != 'login':
            raise AssertionError(
                'Expected login message, got cmd %r' % (login['cmd'],))
        connectionType = login['connectionType']
        if connectionType not in ('in', 'out'):
            raise AssertionError(
                'Unknown connection type %r' % (connectionType,))
        return login['name'], connectionType

    def _handshakeThread(self, clientSocket, address):
        """Handshake one accepted socket and park or pair it.

        A failed handshake is reported and the socket is closed, so a
        misbehaving client cannot leak descriptors.
        """
        print('Received connection from %s:%s' % address)
        host = address[0]
        try:
            name, connectionType = self._readLogin(clientSocket)
        except Exception:
            print('Handshake failed for %s:%s' % address)
            traceback.print_exc()
            clientSocket.close()
            return

        connectionId = (name, host)
        # in/out are swapped here: what is 'in' for the client is 'out' for
        # the server.
        if connectionType == 'in':
            # dict.pop is a single operation, so concurrent handshake
            # threads cannot both claim the same parked socket.
            stale = self._pendingConnections.pop(connectionId, None)
            if stale is not None:
                stale.close()
                print('Old pending connection erased, new incoming connection pending, waiting for second channel to connect')
            self._pendingConnections[connectionId] = clientSocket
            print('Incoming connection pending, waiting for second channel to connect')
            return

        previousSocket = self._pendingConnections.pop(connectionId, None)
        if previousSocket is None:
            print('Received out connection without in connection pending, discarded')
            clientSocket.close()
            return
        print('Connection fully established: %s (%s)' % connectionId)
        condition = self._handshakedConnetionsDataCondition
        with condition:
            self._handshakedConnetionsData.append(
                (clientSocket, previousSocket, address, name))
            condition.notify_all()

    def _listeningThread(self):
        """Accept clients forever, handshaking each on its own thread."""
        while True:
            clientSocket, address = self._socket.accept()
            # Pass the values as thread arguments; a closure would read the
            # loop variables late and could see the next accepted client.
            threading.Thread(
                target=self._handshakeThread,
                args=(clientSocket, address),
                name='Handshaking thread: %s:%s' % address).start()

    def getConnection(self, handlers=None):
        """Block until a fully paired client is available and wrap it.

        Args:
            handlers: optional mapping ``cmd -> callable(args)`` handed to
                the new ``Connection``.

        Returns:
            ``(address, name, connection)`` where ``address`` is the remote
            address of the client's second ('out') socket and ``name`` is
            the name the client sent in its login message.
        """
        condition = self._handshakedConnetionsDataCondition
        with condition:
            # Check before waiting: a pair queued before this call would
            # otherwise be missed until the next client shows up.
            while not self._handshakedConnetionsData:
                condition.wait()
            socket1, socket2, address, name = self._handshakedConnetionsData.pop()
        c = Connection(socket1, socket2,
                       {} if handlers is None else handlers,
                       'in %s:%s' % address)
        return address, name, c
def _createSocket(host, port, connectionType):
    """Open one channel to the server and complete its handshake.

    The handshake sends the first entry of ``versionMarkers``, expects the
    server to echo it back, then sends a ``login`` message carrying the
    local host name and ``connectionType``.

    Args:
        host: server host name or address.
        port: server TCP port.
        connectionType: channel role sent in the login message
            (``'in'`` or ``'out'`` in this file).

    Returns:
        The connected, logged-in socket.

    Raises:
        AssertionError: the server answered with a different version marker.
        Whatever the socket layer raises when connecting or sending fails.
        In every failure case the socket is closed before re-raising.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.connect((host, port))
        versionMarker = versionMarkers[0]
        s.sendall(struct.pack('<I', versionMarker))
        remoteVersionMarker, = struct.unpack('<I', _socketReadBytes(s, 4))
        if versionMarker != remoteVersionMarker:
            # Explicit raise so the check survives ``python -O``;
            # AssertionError is the type the former ``assert`` produced.
            raise AssertionError(
                'Version mismatch, expected %X, got %X'
                % (versionMarker, remoteVersionMarker))
        _socketSend(s, {
            'cmd': 'login',
            'name': socket.gethostname(),
            'connectionType': connectionType,
        })
    except BaseException:
        # Do not leak the descriptor when the handshake fails part-way.
        s.close()
        raise
    return s
def connect(host, port, handlers=None):
    """Connect to a server and return a ready ``Connection``.

    Two channels are opened in order: first ``'in'``, then ``'out'``.

    Args:
        host: server host name or address.
        port: server TCP port.
        handlers: optional mapping ``cmd -> callable(args)`` for requests
            issued by the server.

    Returns:
        A ``Connection`` wrapping both channels.

    Raises:
        Whatever ``_createSocket`` raises; if the second channel fails, the
        first one is closed before re-raising.
    """
    print('Connecting to %s:%s' % (host, port))
    inSocket = _createSocket(host, port, 'in')
    try:
        outSocket = _createSocket(host, port, 'out')
    except BaseException:
        # Without its partner the first channel is useless; release it.
        inSocket.close()
        raise
    print('Connected to %s:%s' % (host, port))
    return Connection(inSocket, outSocket,
                      {} if handlers is None else handlers,
                      'out %s:%s' % (host, port))