import uuid

from socketio import packet
from socketio.pubsub_manager import PubSubManager
from werkzeug.test import EnvironBuilder


class SocketIOTestClient(object):
    """
    This class is useful for testing a Flask-SocketIO server. It works in a
    similar way to the Flask Test Client, but adapted to the Socket.IO server.

    :param app: The Flask application instance.
    :param socketio: The application's ``SocketIO`` instance.
    :param namespace: The namespace for the client. If not provided, the
                      client connects to the server on the global
                      namespace.
    :param query_string: A string with custom query string arguments.
    :param headers: A dictionary with custom HTTP headers.
    :param auth: Optional authentication data, given as a dictionary.
    :param flask_test_client: The instance of the Flask test client
                              currently in use. Passing the Flask test
                              client is optional, but is necessary if you
                              want the Flask user session and any other
                              cookies set in HTTP routes accessible from
                              Socket.IO events.
    """
    # Registry of test clients keyed by Engine.IO session id. It is a class
    # attribute, so it is shared by every instance: the mocked send hooks
    # installed on the server use it to route outgoing packets to the
    # client that owns the destination session.
    clients = {}

    def __init__(self, app, socketio, namespace=None, query_string=None,
                 headers=None, auth=None, flask_test_client=None):
        # Validate before touching any shared state, so that a rejected
        # configuration does not leave a half-registered client in
        # ``clients`` or a patched server behind.
        if isinstance(socketio.server.manager, PubSubManager):
            raise RuntimeError('Test client cannot be used with a message '
                               'queue. Disable the queue on your test '
                               'configuration.')

        def _mock_send_packet(eio_sid, pkt):
            """Record a Socket.IO packet the server sends to ``eio_sid``."""
            # make sure the packet can be encoded and decoded
            epkt = pkt.encode()
            if not isinstance(epkt, list):
                pkt = packet.Packet(encoded_packet=epkt)
            else:
                # binary packets encode to [header, attachment, ...]
                pkt = packet.Packet(encoded_packet=epkt[0])
                for att in epkt[1:]:
                    pkt.add_attachment(att)
            client = self.clients.get(eio_sid)
            if not client:
                return
            pkt_namespace = pkt.namespace or '/'
            if pkt.packet_type in (packet.EVENT, packet.BINARY_EVENT):
                event_name = pkt.data[0]
                if event_name in ('message', 'json'):
                    # send()-style events carry a single payload value
                    event_args = pkt.data[1]
                else:
                    event_args = pkt.data[1:]
                client.queue.append({
                    'name': event_name, 'args': event_args,
                    'namespace': pkt_namespace})
            elif pkt.packet_type in (packet.ACK, packet.BINARY_ACK):
                client.acks = {'args': pkt.data,
                               'namespace': pkt_namespace}
            elif pkt.packet_type in (packet.DISCONNECT,
                                     packet.CONNECT_ERROR):
                client.connected[pkt_namespace] = False

        # A binary packet arrives as a header followed by separate
        # attachment packets; this holds the header until all of its
        # attachments have been collected.
        _current_packet = None

        def _mock_send_eio_packet(eio_sid, eio_pkt):
            """Reassemble Engine.IO payloads into Socket.IO packets."""
            nonlocal _current_packet
            if _current_packet is not None:
                _current_packet.add_attachment(eio_pkt.data)
                if _current_packet.attachment_count == \
                        len(_current_packet.attachments):
                    _mock_send_packet(eio_sid, _current_packet)
                    _current_packet = None
            else:
                pkt = packet.Packet(encoded_packet=eio_pkt.data)
                if pkt.attachment_count == 0:
                    _mock_send_packet(eio_sid, pkt)
                else:
                    _current_packet = pkt

        self.app = app
        self.flask_test_client = flask_test_client
        self.eio_sid = uuid.uuid4().hex
        self.clients[self.eio_sid] = self
        self.callback_counter = 0
        self.socketio = socketio
        self.connected = {}
        self.queue = []
        self.acks = None
        # Intercept everything the server sends so it lands in this client.
        socketio.server._send_packet = _mock_send_packet
        socketio.server._send_eio_packet = _mock_send_eio_packet
        socketio.server.environ[self.eio_sid] = {}
        socketio.server.async_handlers = False      # easier to test when
        socketio.server.eio.async_handlers = False  # events are sync
        socketio.server.manager.initialize()
        self.connect(namespace=namespace, query_string=query_string,
                     headers=headers, auth=auth)

    def is_connected(self, namespace=None):
        """Check if a namespace is connected.

        :param namespace: The namespace to check. The global namespace is
                         assumed if this argument is not provided.
        """
        return self.connected.get(namespace or '/', False)

    def connect(self, namespace=None, query_string=None, headers=None,
                auth=None):
        """Connect the client.

        :param namespace: The namespace for the client. If not provided, the
                          client connects to the server on the global
                          namespace.
        :param query_string: A string with custom query string arguments.
        :param headers: A dictionary with custom HTTP headers.
        :param auth: Optional authentication data, given as a dictionary.

        Note that it is usually not necessary to explicitly call this method,
        since a connection is automatically established when an instance of
        this class is created. An example where this method would be useful
        is when the application accepts multiple namespace connections.
        """
        url = '/socket.io'
        namespace = namespace or '/'
        if query_string:
            if query_string[0] != '?':
                query_string = '?' + query_string
            url += query_string
        environ = EnvironBuilder(url, headers=headers).get_environ()
        environ['flask.app'] = self.app
        if self.flask_test_client:
            # inject cookies from Flask
            if hasattr(self.flask_test_client, '_add_cookies_to_wsgi'):
                # flask >= 2.3
                self.flask_test_client._add_cookies_to_wsgi(environ)
            else:  # pragma: no cover
                # flask < 2.3
                self.flask_test_client.cookie_jar.inject_wsgi(environ)
        self.socketio.server._handle_eio_connect(self.eio_sid, environ)
        pkt = packet.Packet(packet.CONNECT, auth, namespace=namespace)
        self.socketio.server._handle_eio_message(self.eio_sid, pkt.encode())
        # The server only assigns a sid when the connection was accepted.
        sid = self.socketio.server.manager.sid_from_eio_sid(self.eio_sid,
                                                            namespace)
        if sid:
            self.connected[namespace] = True

    def disconnect(self, namespace=None):
        """Disconnect the client.

        :param namespace: The namespace to disconnect. The global namespace
                         is assumed if this argument is not provided.
        :raises RuntimeError: if the namespace is not connected.
        """
        if not self.is_connected(namespace):
            raise RuntimeError('not connected')
        pkt = packet.Packet(packet.DISCONNECT, namespace=namespace)
        self.socketio.server._handle_eio_message(self.eio_sid, pkt.encode())
        del self.connected[namespace or '/']

    def emit(self, event, *args, **kwargs):
        """Emit an event to the server.

        :param event: The event name.
        :param *args: The event arguments.
        :param callback: ``True`` if the client requests a callback, ``False``
                         if not. Note that client-side callbacks are not
                         implemented, a callback request will just tell the
                         server to provide the arguments to invoke the
                         callback, but no callback is invoked. Instead, the
                         arguments that the server provided for the callback
                         are returned by this function.
        :param namespace: The namespace of the event. The global namespace is
                          assumed if this argument is not provided.
        :raises RuntimeError: if the namespace is not connected.
        """
        namespace = kwargs.pop('namespace', None)
        if not self.is_connected(namespace):
            raise RuntimeError('not connected')
        callback = kwargs.pop('callback', False)
        ack_id = None
        if callback:
            self.callback_counter += 1
            ack_id = self.callback_counter
        pkt = packet.Packet(packet.EVENT, data=[event] + list(args),
                            namespace=namespace, id=ack_id)
        encoded_pkt = pkt.encode()
        # Binary payloads encode to a list of parts; feed each one in order.
        if not isinstance(encoded_pkt, list):
            encoded_pkt = [encoded_pkt]
        for epkt in encoded_pkt:
            self.socketio.server._handle_eio_message(self.eio_sid, epkt)
        # Handlers run synchronously, so any ACK has already been recorded.
        if self.acks is not None:
            ack, self.acks = self.acks, None
            ack_args = ack['args']
            return ack_args[0] if len(ack_args) == 1 else ack_args

    def send(self, data, json=False, callback=False, namespace=None):
        """Send a text or JSON message to the server.

        :param data: A string, dictionary or list to send to the server.
        :param json: ``True`` to send a JSON message, ``False`` to send a text
                     message.
        :param callback: ``True`` if the client requests a callback, ``False``
                         if not. Note that client-side callbacks are not
                         implemented, a callback request will just tell the
                         server to provide the arguments to invoke the
                         callback, but no callback is invoked. Instead, the
                         arguments that the server provided for the callback
                         are returned by this function.
        :param namespace: The namespace of the event. The global namespace is
                          assumed if this argument is not provided.
        """
        msg = 'json' if json else 'message'
        return self.emit(msg, data, callback=callback, namespace=namespace)

    def get_received(self, namespace=None):
        """Return the list of messages received from the server.

        Since this is not a real client, any time the server emits an event,
        the event is simply stored. The test code can invoke this method to
        obtain the list of events that were received since the last call.
        Events for other namespaces stay queued, in their original order.

        :param namespace: The namespace to get events from. The global
                          namespace is assumed if this argument is not
                          provided.
        :raises RuntimeError: if the namespace is not connected.
        """
        if not self.is_connected(namespace):
            raise RuntimeError('not connected')
        namespace = namespace or '/'
        # Single-pass partition (the previous ``pkt not in r`` filter was
        # quadratic in the queue length).
        received = []
        remaining = []
        for pkt in self.queue:
            if pkt['namespace'] == namespace:
                received.append(pkt)
            else:
                remaining.append(pkt)
        self.queue = remaining
        return received