# Copyright 2020, Google Inc. # All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are # met: # # * Redistributions of source code must retain the above copyright # notice, this list of conditions and the following disclaimer. # * Redistributions in binary form must reproduce the above # copyright notice, this list of conditions and the following disclaimer # in the documentation and/or other materials provided with the # distribution. # * Neither the name of Google Inc. nor the names of its # contributors may be used to endorse or promote products derived from # this software without specific prior written permission. # # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Standalone WebsocketServer This file deals with the main module of standalone server. Although it is fine to import this file directly to use WebSocketServer, it is strongly recommended to use standalone.py, since it is intended to act as a skeleton of this module. """ from __future__ import absolute_import from six.moves import BaseHTTPServer from six.moves import socketserver import logging import re import select import socket import ssl import threading import traceback from mod_pywebsocket import dispatch from mod_pywebsocket import util from mod_pywebsocket.request_handler import WebSocketRequestHandler def _alias_handlers(dispatcher, websock_handlers_map_file): """Set aliases specified in websock_handler_map_file in dispatcher. Args: dispatcher: dispatch.Dispatcher instance websock_handler_map_file: alias map file """ with open(websock_handlers_map_file) as f: for line in f: if line[0] == '#' or line.isspace(): continue m = re.match('(\S+)\s+(\S+)$', line) if not m: logging.warning('Wrong format in map file:' + line) continue try: dispatcher.add_resource_path_alias(m.group(1), m.group(2)) except dispatch.DispatchException as e: logging.error(str(e)) class WebSocketServer(socketserver.ThreadingMixIn, BaseHTTPServer.HTTPServer): """HTTPServer specialized for WebSocket.""" # Overrides SocketServer.ThreadingMixIn.daemon_threads daemon_threads = True # Overrides BaseHTTPServer.HTTPServer.allow_reuse_address allow_reuse_address = True def __init__(self, options): """Override SocketServer.TCPServer.__init__ to set SSL enabled socket object to self.socket before server_bind and server_activate, if necessary. """ # Share a Dispatcher among request handlers to save time for # instantiation. Dispatcher can be shared because it is thread-safe. options.dispatcher = dispatch.Dispatcher( options.websock_handlers, options.scan_dir, options.allow_handlers_outside_root_dir) if options.websock_handlers_map_file: _alias_handlers(options.dispatcher, options.websock_handlers_map_file) warnings = options.dispatcher.source_warnings() if warnings: for warning in warnings: logging.warning('Warning in source loading: %s' % warning) self._logger = util.get_class_logger(self) self.request_queue_size = options.request_queue_size self.__ws_is_shut_down = threading.Event() self.__ws_serving = False socketserver.BaseServer.__init__(self, (options.server_host, options.port), WebSocketRequestHandler) # Expose the options object to allow handler objects access it. We name # it with websocket_ prefix to avoid conflict. self.websocket_server_options = options self._create_sockets() self.server_bind() self.server_activate() def _create_sockets(self): self.server_name, self.server_port = self.server_address self._sockets = [] if not self.server_name: # On platforms that doesn't support IPv6, the first bind fails. # On platforms that supports IPv6 # - If it binds both IPv4 and IPv6 on call with AF_INET6, the # first bind succeeds and the second fails (we'll see 'Address # already in use' error). # - If it binds only IPv6 on call with AF_INET6, both call are # expected to succeed to listen both protocol. addrinfo_array = [(socket.AF_INET6, socket.SOCK_STREAM, '', '', ''), (socket.AF_INET, socket.SOCK_STREAM, '', '', '')] else: addrinfo_array = socket.getaddrinfo(self.server_name, self.server_port, socket.AF_UNSPEC, socket.SOCK_STREAM, socket.IPPROTO_TCP) for addrinfo in addrinfo_array: self._logger.info('Create socket on: %r', addrinfo) family, socktype, proto, canonname, sockaddr = addrinfo try: socket_ = socket.socket(family, socktype) except Exception as e: self._logger.info('Skip by failure: %r', e) continue server_options = self.websocket_server_options if server_options.use_tls: if server_options.tls_client_auth: if server_options.tls_client_cert_optional: client_cert_ = ssl.CERT_OPTIONAL else: client_cert_ = ssl.CERT_REQUIRED else: client_cert_ = ssl.CERT_NONE socket_ = ssl.wrap_socket( socket_, keyfile=server_options.private_key, certfile=server_options.certificate, ca_certs=server_options.tls_client_ca, cert_reqs=client_cert_) self._sockets.append((socket_, addrinfo)) def server_bind(self): """Override SocketServer.TCPServer.server_bind to enable multiple sockets bind. """ failed_sockets = [] for socketinfo in self._sockets: socket_, addrinfo = socketinfo self._logger.info('Bind on: %r', addrinfo) if self.allow_reuse_address: socket_.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) try: socket_.bind(self.server_address) except Exception as e: self._logger.info('Skip by failure: %r', e) socket_.close() failed_sockets.append(socketinfo) if self.server_address[1] == 0: # The operating system assigns the actual port number for port # number 0. This case, the second and later sockets should use # the same port number. Also self.server_port is rewritten # because it is exported, and will be used by external code. self.server_address = (self.server_name, socket_.getsockname()[1]) self.server_port = self.server_address[1] self._logger.info('Port %r is assigned', self.server_port) for socketinfo in failed_sockets: self._sockets.remove(socketinfo) def server_activate(self): """Override SocketServer.TCPServer.server_activate to enable multiple sockets listen. """ failed_sockets = [] for socketinfo in self._sockets: socket_, addrinfo = socketinfo self._logger.info('Listen on: %r', addrinfo) try: socket_.listen(self.request_queue_size) except Exception as e: self._logger.info('Skip by failure: %r', e) socket_.close() failed_sockets.append(socketinfo) for socketinfo in failed_sockets: self._sockets.remove(socketinfo) if len(self._sockets) == 0: self._logger.critical( 'No sockets activated. Use info log level to see the reason.') def server_close(self): """Override SocketServer.TCPServer.server_close to enable multiple sockets close. """ for socketinfo in self._sockets: socket_, addrinfo = socketinfo self._logger.info('Close on: %r', addrinfo) socket_.close() def fileno(self): """Override SocketServer.TCPServer.fileno.""" self._logger.critical('Not supported: fileno') return self._sockets[0][0].fileno() def handle_error(self, request, client_address): """Override SocketServer.handle_error.""" self._logger.error('Exception in processing request from: %r\n%s', client_address, traceback.format_exc()) # Note: client_address is a tuple. def get_request(self): """Override TCPServer.get_request.""" accepted_socket, client_address = self.socket.accept() server_options = self.websocket_server_options if server_options.use_tls: # Print cipher in use. Handshake is done on accept. self._logger.debug('Cipher: %s', accepted_socket.cipher()) self._logger.debug('Client cert: %r', accepted_socket.getpeercert()) return accepted_socket, client_address def serve_forever(self, poll_interval=0.5): """Override SocketServer.BaseServer.serve_forever.""" self.__ws_serving = True self.__ws_is_shut_down.clear() handle_request = self.handle_request if hasattr(self, '_handle_request_noblock'): handle_request = self._handle_request_noblock else: self._logger.warning('Fallback to blocking request handler') try: while self.__ws_serving: r, w, e = select.select( [socket_[0] for socket_ in self._sockets], [], [], poll_interval) for socket_ in r: self.socket = socket_ handle_request() self.socket = None finally: self.__ws_is_shut_down.set() def shutdown(self): """Override SocketServer.BaseServer.shutdown.""" self.__ws_serving = False self.__ws_is_shut_down.wait() # vi:sts=4 sw=4 et