# Source: Import/projects/Gameloft/bne_lib/tools/metaserver/connection.py
import socket
import threading
import struct
import json
import traceback
import zlib

# Handshake version markers understood by this module. This Python
# implementation fully supports protocol version 4 and is only partially
# compatible with version 5. A client always announces the first entry.
versionMarkers = [
  0xDEBC0004, # version 4
  0xDEBC0005, # version 5
]

def _socketReadBytes(socket, bytesCount):
  """Read exactly `bytesCount` bytes from `socket` and return them.

  recv() may return fewer bytes than requested, so this keeps reading
  until the requested count has been collected.

  Raises EOFError if the peer closes the connection first (recv() returns
  an empty string).
  """
  chunks = []
  remaining = bytesCount
  while remaining > 0:
    data = socket.recv(remaining)
    # Explicit check rather than `assert`: python -O strips asserts, and a
    # closed peer would then make this loop spin forever on empty reads.
    if not data:
      raise EOFError('Connection closed by peer with %d of %d bytes still unread' % (remaining, bytesCount))
    chunks.append(data)
    remaining -= len(data)
  # join() instead of repeated `+=` keeps reassembly linear in the size.
  return b''.join(chunks)

def _socketSend(socket, obj):
  """Serialize `obj` as JSON and write it to `socket` as one framed message.

  Frame layout:
    - little-endian uint32: size of the JSON payload in bytes
    - little-endian uint32: size of the compressed payload, or 0 if the
      payload is sent uncompressed
    - the payload itself

  The payload is compressed only when that makes it smaller. Compression
  is raw deflate: the zlib stream minus its 2-byte header and 4-byte
  Adler-32 trailer, which is what zlib.decompress(..., -15) expects.
  """
  data = json.dumps(obj)
  if not isinstance(data, bytes):
    # Python 3's json.dumps returns text, but the wire format and the size
    # prefix are in bytes. On Python 2 this is already a byte string.
    data = data.encode('utf-8')

  zdata = zlib.compress(data)[2:-4] # strip zlib header and trailer -> raw deflate

  if len(zdata) < len(data):
    socket.sendall(struct.pack('<II', len(data), len(zdata)) + zdata)
  else:
    socket.sendall(struct.pack('<II', len(data), 0) + data)

def _socketRead(socket):
  """Read one framed JSON message from `socket` and return the decoded value.

  Expects the frame written by _socketSend:
    - little-endian uint32: payload size
    - little-endian uint32: compressed size (0 means the payload is raw)
    - the payload, raw-deflate compressed (wbits=-15) when the
      compressed size is non-zero

  Raises ValueError if the inflated payload does not have the announced
  size. Errors from reading the socket, inflating, or JSON decoding
  propagate unchanged.
  """
  dataSize, zdataSize = struct.unpack('<II', _socketReadBytes(socket, 8))

  if zdataSize > 0:
    data = zlib.decompress(_socketReadBytes(socket, zdataSize), -15)
    # Explicit check rather than `assert`, which python -O would strip.
    if len(data) != dataSize:
      raise ValueError('Decompressed message size mismatch: expected %d bytes, got %d' % (dataSize, len(data)))
  else:
    data = _socketReadBytes(socket, dataSize)

  if not isinstance(data, str):
    # Python 3: the socket yields bytes. On Python 2 `data` is already str.
    data = data.decode('utf-8')

  return json.loads(data)

class Connection:
  """Two-socket JSON request/response channel.

  Requests arriving on `inSocket` are served by a background thread. Each
  request's 'cmd' selects a registered handler, which is called with the
  request's 'args'. The handler's return value is written back on
  `inSocket` as the response.

  Requests issued locally with send() go out on `outSocket`. send() blocks
  until the peer's response has been read from the same socket.
  """

  def __init__(self, inSocket, outSocket, handlers = None, debugName = ''):
    """Wrap the two sockets and start the incoming-request thread.

    handlers: optional mapping of command name -> callable(args). It is
      copied, so later changes to the caller's mapping are not seen.
    debugName: label included in the reading thread's name.
    """
    self._inSocket = inSocket
    self._outSocket = outSocket
    self._handlers = {}
    self._handlers.update(handlers or {})
    self._dead = False
    # Serializes send() so request/response pairs on outSocket don't interleave.
    self._lock = threading.Lock()

    threading.Thread(target=self._threadFn, name='Connection reading thread: "%s"' % debugName).start()

  def isAlive(self):
    """Return True until the connection has been torn down."""
    return not self._dead

  def isDead(self):
    """Return True once either channel failed or close() was called."""
    return self._dead

  def close(self):
    """Tear down both sockets. Calling it again has no effect."""
    if not self._dead:
      self._shutdown()

  def _shutdown(self):
    """Mark the connection dead and release both sockets.

    shutdown() is attempted before close() so that a thread blocked in
    recv() on either socket returns instead of lingering.
    """
    self._dead = True
    for s in (self._inSocket, self._outSocket):
      try:
        s.shutdown(socket.SHUT_RDWR)
      except socket.error:
        pass # best effort: the socket may already be disconnected or closed
      s.close()

  def send(self, cmd, args):
    """Send a synchronous request to the peer and return its response.

    Returns the decoded response. Returns None if the connection is already
    dead, or if sending/receiving fails. A failure tears down the whole
    connection.
    """
    with self._lock:
      if self._dead:
        return None
      try:
        _socketSend(self._outSocket, {'cmd' : cmd, 'args' : args, 'type' : 'syncRequest'})
        return _socketRead(self._outSocket)
      except Exception as e:
        print('Outgoing connection broken %s %s' % (type(e), e))
        self._shutdown()
        return None

  def addHandler(self, cmd, handler):
    """Register `handler` (a callable taking the request args) for `cmd`."""
    assert cmd not in self._handlers, 'Handler already installed'
    self._handlers[cmd] = handler

  def _handleRequest(self, cmd, args):
    """Run the handler registered for `cmd` and return its result.

    Returns None when no handler is registered or the handler raises. The
    failure is logged, and the None is still sent back as the response.
    """
    if cmd not in self._handlers:
      print('Unable to handle request: %s (no handler for command %s)' % (args, cmd))
      return None

    try:
      return self._handlers[cmd](args)
    except Exception: # not bare: let KeyboardInterrupt/SystemExit through
      print('Error handling request: %s,' % (args,))
      traceback.print_exc()
      return None

  def _threadFn(self):
    """Serve requests from inSocket until it fails or is closed."""
    try:
      while True:
        request = _socketRead(self._inSocket)
        response = self._handleRequest(request['cmd'], request['args'])
        _socketSend(self._inSocket, response)
    except Exception as e:
      print('Incoming connection broken %s %s' % (type(e), e))
      self._shutdown()

class Listener:
  """TCP server that pairs two sockets from the same client into a Connection.

  Each client opens two channels and logs in on each with its name and a
  connectionType of 'in' or 'out'. The channels are keyed by
  (name, client host). Once both channels of a key have arrived, the pair
  is queued, and getConnection() wraps it in a Connection.
  """

  def __init__(self, address):
    """Bind to `address` ((host, port)), listen, and start the accept thread."""
    # (name, host) -> socket of a client's 'in' channel awaiting its 'out' channel.
    self._pendingConnections = {}
    # Handshake threads run concurrently; guards check-then-mutate on the dict.
    self._pendingLock = threading.Lock()

    self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    self._socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    self._socket.bind(address)
    self._socket.listen(16)

    # FIFO of (socket1, socket2, address, name) tuples awaiting getConnection().
    self._handshakedConnetionsData = []
    self._handshakedConnetionsDataCondition = threading.Condition()

    threading.Thread(target=self._listeningThread, name='Listening thread: %s:%s' % address).start()
    print('Server started at %s:%s' % address)

  def _handshakeThread(self, clientSocket, address):
    """Handshake one accepted socket and register it as a client channel.

    A failed handshake (bad version, bad login, peer disconnect) is logged
    and the socket closed, instead of leaking the socket and killing the
    thread with an unhandled exception.
    """
    print('Received connection from %s:%s' % address)
    try:
      name, connectionType = self._readHandshake(clientSocket)
    except Exception as e:
      print('Handshake with %s:%s failed: %s %s' % (address[0], address[1], type(e), e))
      clientSocket.close()
      return
    self._registerChannel(clientSocket, address, name, connectionType)

  def _readHandshake(self, clientSocket):
    """Echo the client's version marker, read its login, and return (name, connectionType)."""
    remoteVersionMarker, = struct.unpack('<I', _socketReadBytes(clientSocket, 4))
    if remoteVersionMarker not in versionMarkers:
      # %s for the joined list: the original '%X' with a str argument raised TypeError.
      raise ValueError('Version mismatch, expected one of: [%s], got %X' % (
        ', '.join('%X' % m for m in versionMarkers), remoteVersionMarker))
    clientSocket.sendall(struct.pack('<I', remoteVersionMarker))

    login = _socketRead(clientSocket)
    if login['cmd'] != 'login':
      raise ValueError('Expected login message, got: %s' % (login,))
    connectionType = login['connectionType']
    if connectionType not in ('in', 'out'):
      raise ValueError('Unknown connection type: %s' % (connectionType,))
    return login['name'], connectionType

  def _registerChannel(self, clientSocket, address, name, connectionType):
    """Park an 'in' channel, or pair an 'out' channel with its parked 'in' and queue the pair."""
    connectionId = (name, address[0])

    if connectionType == 'in': # in/out swapped here, what's 'in' for client is 'out' for server
      with self._pendingLock:
        previousSocket = self._pendingConnections.get(connectionId)
        self._pendingConnections[connectionId] = clientSocket
      if previousSocket is not None:
        previousSocket.close()
        print('Old pending connection erased, new incoming connection pending, waiting for second channel to connect')
      print('Incoming connection pending, waiting for second channel to connect')
      return

    # connectionType == 'out' (validated by _readHandshake)
    with self._pendingLock:
      previousSocket = self._pendingConnections.pop(connectionId, None)
    if previousSocket is None:
      print('Received out connection without in connection pending, discarded')
      clientSocket.close()
      return

    print('Connection fully established: %s (%s)' % connectionId)
    with self._handshakedConnetionsDataCondition:
      self._handshakedConnetionsData.append((clientSocket, previousSocket, address, name))
      self._handshakedConnetionsDataCondition.notify_all()

  def _listeningThread(self):
    """Accept clients forever, handshaking each one on its own thread."""
    while True:
      clientSocket, address = self._socket.accept()
      # Pass the values through `args`. A closure over the loop variables is
      # late-bound and could hand the thread the *next* accepted socket.
      threading.Thread(target=self._handshakeThread, args=(clientSocket, address),
                       name='Handshaking thread: %s:%s' % address).start()

  def getConnection(self, handlers = None):
    """Block until a client has both channels connected.

    Returns (address, name, Connection).
    handlers: optional mapping of command name -> callable(args), passed to
      the Connection.
    """
    with self._handshakedConnetionsDataCondition:
      # Check the queue before waiting, and re-check after every wakeup.
      # Pairs queued before this call, or several pairs announced by one
      # notification, must not be missed.
      while not self._handshakedConnetionsData:
        self._handshakedConnetionsDataCondition.wait()
      # Oldest first, so clients are served in arrival order.
      socket1, socket2, address, name = self._handshakedConnetionsData.pop(0)
    c = Connection(socket1, socket2, handlers if handlers is not None else {}, 'in %s:%s' % address)
    return address, name, c

def _createSocket(host, port, connectionType):
  """Open one channel to the server at (host, port) and log in on it.

  Sends versionMarkers[0] and requires the server to echo it back. Then
  sends a login message carrying this machine's hostname and
  `connectionType`.

  Returns the connected socket. If any step fails, the socket is closed
  before the exception propagates, so it is not leaked.
  """
  s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  try:
    s.connect((host, port))

    versionMarker = versionMarkers[0]
    s.sendall(struct.pack('<I', versionMarker))
    remoteVersionMarker, = struct.unpack('<I', _socketReadBytes(s, 4))
    assert versionMarker == remoteVersionMarker, 'Version mismatch, expected %X, got %X' % (versionMarker, remoteVersionMarker)

    loginMsg = {
      'cmd' : 'login',
      'name' : socket.gethostname(),
      'connectionType' : connectionType,
    }
    _socketSend(s, loginMsg)
  except BaseException:
    s.close()
    raise

  return s

def connect(host, port, handlers = None):
  """Connect to a Listener at (host, port) and return a Connection.

  Opens the 'in' channel first and the 'out' channel second.
  handlers: optional mapping of command name -> callable(args). It is used
    to answer requests that arrive on the 'in' channel.

  If the 'out' channel cannot be opened, the already-open 'in' socket is
  closed before the error propagates.
  """
  print('Connecting to %s:%s' % (host, port))
  inSocket = _createSocket(host, port, 'in')
  try:
    outSocket = _createSocket(host, port, 'out')
  except BaseException:
    inSocket.close()
    raise
  print('Connected to %s:%s' % (host, port))
  return Connection(inSocket, outSocket, handlers if handlers is not None else {}, 'out %s:%s' % (host, port))