"Fossies" - the Fresh Open Source Software Archive

Member "ospd-2.0.1/ospd/server.py" (12 May 2020, 9299 Bytes) of package /linux/misc/openvas/ospd-2.0.1.tar.gz:


As a special service, "Fossies" has tried to format the requested source page into HTML using (guessed) Python source code syntax highlighting (style: standard) with prefixed line numbers. Alternatively, you can view or download the uninterpreted source code file here. For more information about "server.py", see the Fossies "Dox" file reference documentation and the latest Fossies "Diffs" side-by-side code changes report: 2.0.0_vs_2.0.1.

    1 # Copyright (C) 2019 Greenbone Networks GmbH
    2 #
    3 # SPDX-License-Identifier: GPL-2.0-or-later
    4 #
    5 # This program is free software; you can redistribute it and/or
    6 # modify it under the terms of the GNU General Public License
    7 # as published by the Free Software Foundation; either version 2
    8 # of the License, or (at your option) any later version.
    9 #
   10 # This program is distributed in the hope that it will be useful,
   11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
   12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   13 # GNU General Public License for more details.
   14 #
   15 # You should have received a copy of the GNU General Public License
   16 # along with this program; if not, write to the Free Software
   17 # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
   18 """
   19 Module for serving and streaming data
   20 """
   21 
   22 import logging
   23 import socket
   24 import ssl
   25 import time
   26 import os
   27 import threading
   28 import socketserver
   29 
   30 from abc import ABC, abstractmethod
   31 from pathlib import Path
   32 from typing import Callable, Optional, Tuple, Union
   33 
   34 from ospd.errors import OspdError
   35 
   36 logger = logging.getLogger(__name__)
   37 
   38 DEFAULT_BUFSIZE = 1024
   39 
   40 
   41 class Stream:
   42     def __init__(self, sock: socket.socket, stream_timeout: int):
   43         self.socket = sock
   44         self.socket.settimeout(stream_timeout)
   45 
   46     def close(self):
   47         """ Close the stream
   48         """
   49         try:
   50             self.socket.shutdown(socket.SHUT_RDWR)
   51         except OSError as e:
   52             logger.debug(
   53                 "Ignoring error while shutting down the connection. %s", e
   54             )
   55 
   56         self.socket.close()
   57 
   58     def read(self, bufsize: Optional[int] = DEFAULT_BUFSIZE) -> bytes:
   59         """ Read at maximum bufsize data from the stream
   60         """
   61         data = self.socket.recv(bufsize)
   62 
   63         if not data:
   64             logger.debug('Client closed the connection')
   65 
   66         return data
   67 
   68     def write(self, data: bytes):
   69         """ Send data in chunks of DEFAULT_BUFSIZE to the client
   70         """
   71         b_start = 0
   72         b_end = DEFAULT_BUFSIZE
   73 
   74         while True:
   75             if b_end > len(data):
   76                 try:
   77                     self.socket.send(data[b_start:])
   78                 except (socket.error, BrokenPipeError) as e:
   79                     logger.error("Error sending data to the client. %s", e)
   80                 finally:
   81                     return
   82 
   83             try:
   84                 b_sent = self.socket.send(data[b_start:b_end])
   85             except (socket.error, BrokenPipeError) as e:
   86                 logger.error("Error sending data to the client. %s", e)
   87                 return
   88             b_start = b_end
   89             b_end += b_sent
   90 
   91 
# Signature of the callback the servers invoke with a Stream for every
# accepted client connection; its return value is not used.
StreamCallbackType = Callable[[Stream], None]

# (host, port) address of a TCP socket; used for the address argument of
# SocketServerMixin next to a plain path string for Unix sockets.
InetAddress = Tuple[str, int]
   95 
   96 
   97 def validate_cacert_file(cacert: str):
   98     """ Check if provided file is a valid CA Certificate """
   99     try:
  100         context = ssl.create_default_context(cafile=cacert)
  101     except AttributeError:
  102         # Python version < 2.7.9
  103         return
  104     except IOError:
  105         raise OspdError('CA Certificate not found')
  106 
  107     try:
  108         not_after = context.get_ca_certs()[0]['notAfter']
  109         not_after = ssl.cert_time_to_seconds(not_after)
  110         not_before = context.get_ca_certs()[0]['notBefore']
  111         not_before = ssl.cert_time_to_seconds(not_before)
  112     except (KeyError, IndexError):
  113         raise OspdError('CA Certificate is erroneous')
  114 
  115     now = int(time.time())
  116     if not_after < now:
  117         raise OspdError('CA Certificate expired')
  118 
  119     if not_before > now:
  120         raise OspdError('CA Certificate not active yet')
  121 
  122 
  123 class RequestHandler(socketserver.BaseRequestHandler):
  124     """ Class to handle the request."""
  125 
  126     def handle(self):
  127         self.server.handle_request(self.request, self.client_address)
  128 
  129 
  130 class BaseServer(ABC):
  131     def __init__(self, stream_timeout: int):
  132         self.server = None
  133         self.stream_timeout = stream_timeout
  134 
  135     @abstractmethod
  136     def start(self, stream_callback: StreamCallbackType):
  137         """ Starts a server with capabilities to handle multiple client
  138         connections simultaneously.
  139         If a new client connects the stream_callback is called with a Stream
  140 
  141         Arguments:
  142             stream_callback (function): Callback function to be called when
  143                 a stream is ready
  144         """
  145 
  146     def close(self):
  147         """ Shutdown the server"""
  148         self.server.shutdown()
  149         self.server.server_close()
  150 
  151     @abstractmethod
  152     def handle_request(self, request, client_address):
  153         """ Handle an incoming client request"""
  154 
  155     def _start_threading_server(self):
  156         server_thread = threading.Thread(target=self.server.serve_forever)
  157         server_thread.daemon = True
  158         server_thread.start()
  159 
  160 
  161 class SocketServerMixin:
  162     def __init__(self, server: BaseServer, address: Union[str, InetAddress]):
  163         self.server = server
  164         super().__init__(address, RequestHandler, bind_and_activate=True)
  165 
  166     def handle_request(self, request, client_address):
  167         self.server.handle_request(request, client_address)
  168 
  169 
  170 class ThreadedUnixSocketServer(
  171     SocketServerMixin,
  172     socketserver.ThreadingMixIn,
  173     socketserver.UnixStreamServer,
  174 ):
  175     pass
  176 
  177 
  178 class ThreadedTlsSocketServer(
  179     SocketServerMixin, socketserver.ThreadingMixIn, socketserver.TCPServer
  180 ):
  181     pass
  182 
  183 
  184 class UnixSocketServer(BaseServer):
  185     """ Server for accepting connections via a Unix domain socket
  186     """
  187 
  188     def __init__(self, socket_path: str, socket_mode: str, stream_timeout: int):
  189         super().__init__(stream_timeout)
  190         self.socket_path = Path(socket_path)
  191         self.socket_mode = int(socket_mode, 8)
  192 
  193     def _cleanup_socket(self):
  194         if self.socket_path.exists():
  195             self.socket_path.unlink()
  196 
  197     def _create_parent_dirs(self):
  198         # create all parent directories for the socket path
  199         parent = self.socket_path.parent
  200         parent.mkdir(parents=True, exist_ok=True)
  201 
  202     def start(self, stream_callback: StreamCallbackType):
  203         self._cleanup_socket()
  204         self._create_parent_dirs()
  205 
  206         try:
  207             self.stream_callback = stream_callback
  208             self.server = ThreadedUnixSocketServer(self, str(self.socket_path))
  209             self._start_threading_server()
  210         except OSError as e:
  211             logger.error("Couldn't bind socket on %s", str(self.socket_path))
  212             raise OspdError(
  213                 "Couldn't bind socket on {}. {}".format(
  214                     str(self.socket_path), e
  215                 )
  216             )
  217 
  218         if self.socket_path.exists():
  219             self.socket_path.chmod(self.socket_mode)
  220 
  221     def close(self):
  222         super().close()
  223         self._cleanup_socket()
  224 
  225     def handle_request(self, request, client_address):
  226         logger.debug("New connection from %s", str(self.socket_path))
  227 
  228         stream = Stream(request, self.stream_timeout)
  229         self.stream_callback(stream)
  230 
  231 
  232 class TlsServer(BaseServer):
  233     """ Server for accepting TLS encrypted connections via a TCP socket
  234     """
  235 
  236     def __init__(
  237         self,
  238         address: str,
  239         port: int,
  240         cert_file: str,
  241         key_file: str,
  242         ca_file: str,
  243         stream_timeout: int,
  244     ):
  245         super().__init__(stream_timeout)
  246         self.socket = (address, port)
  247 
  248         if not Path(cert_file).exists():
  249             raise OspdError('cert file {} not found'.format(cert_file))
  250 
  251         if not Path(key_file).exists():
  252             raise OspdError('key file {} not found'.format(key_file))
  253 
  254         if not Path(ca_file).exists():
  255             raise OspdError('CA file {} not found'.format(ca_file))
  256 
  257         validate_cacert_file(ca_file)
  258 
  259         # Despite the name, ssl.PROTOCOL_SSLv23 selects the highest
  260         # protocol version that both the client and server support. In modern
  261         # Python versions (>= 3.4) it supports TLS >= 1.0 with SSLv2 and SSLv3
  262         # being disabled. For Python > 3.5, PROTOCOL_SSLv23 is an alias for
  263         # PROTOCOL_TLS which should be used once compatibility with Python 3.5
  264         # is no longer desired.
  265 
  266         if hasattr(ssl, 'PROTOCOL_TLS'):
  267             protocol = ssl.PROTOCOL_TLS
  268         else:
  269             protocol = ssl.PROTOCOL_SSLv23
  270 
  271         self.tls_context = ssl.SSLContext(protocol)
  272         self.tls_context.verify_mode = ssl.CERT_REQUIRED
  273 
  274         self.tls_context.load_cert_chain(cert_file, keyfile=key_file)
  275         self.tls_context.load_verify_locations(ca_file)
  276 
  277     def start(self, stream_callback: StreamCallbackType):
  278         try:
  279             self.stream_callback = stream_callback
  280             self.server = ThreadedTlsSocketServer(self, self.socket)
  281             self._start_threading_server()
  282         except OSError as e:
  283             logger.error(
  284                 "Couldn't bind socket on %s:%s", self.socket[0], self.socket[1]
  285             )
  286             raise OspdError(
  287                 "Couldn't bind socket on {}:{}. {}".format(
  288                     self.socket[0], str(self.socket[1]), e
  289                 )
  290             )
  291 
  292     def handle_request(self, request, client_address):
  293         logger.debug("New connection from %s", client_address)
  294 
  295         req_socket = self.tls_context.wrap_socket(request, server_side=True)
  296 
  297         stream = Stream(req_socket, self.stream_timeout)
  298         self.stream_callback(stream)