"Fossies" - the Fresh Open Source Software Archive  

Source code changes of the file "ospd/ospd.py" between
ospd-2.0.0.tar.gz and ospd-2.0.1.tar.gz

About: OSPd is a base class for scanner wrappers which share the same communication protocol: OSP (OpenVAS Scanner Protocol).

ospd.py  (ospd-2.0.0):ospd.py  (ospd-2.0.1)
skipping to change at line 48 skipping to change at line 48
from xml.etree.ElementTree import Element, SubElement from xml.etree.ElementTree import Element, SubElement
import defusedxml.ElementTree as secET import defusedxml.ElementTree as secET
from ospd import __version__ from ospd import __version__
from ospd.errors import OspdCommandError, OspdError from ospd.errors import OspdCommandError, OspdError
from ospd.misc import ScanCollection, ResultType, ScanStatus, valid_uuid from ospd.misc import ScanCollection, ResultType, ScanStatus, valid_uuid
from ospd.network import resolve_hostname, target_str_to_list from ospd.network import resolve_hostname, target_str_to_list
from ospd.server import BaseServer from ospd.server import BaseServer
from ospd.vtfilter import VtsFilter from ospd.vtfilter import VtsFilter
from ospd.xml import simple_response_str, get_result_xml from ospd.xml import simple_response_str, get_result_xml, XmlStringHelper
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
PROTOCOL_VERSION = "1.2" PROTOCOL_VERSION = "1.2"
SCHEDULER_CHECK_PERIOD = 5 # in seconds SCHEDULER_CHECK_PERIOD = 5 # in seconds
GVMCG_TITLES = [ GVMCG_TITLES = [
'cpu-*', 'cpu-*',
'proc', 'proc',
skipping to change at line 195 skipping to change at line 195
self.daemon_info['version'] = __version__ self.daemon_info['version'] = __version__
self.daemon_info['description'] = "No description" self.daemon_info['description'] = "No description"
self.scanner_info = dict() self.scanner_info = dict()
self.scanner_info['name'] = 'No name' self.scanner_info['name'] = 'No name'
self.scanner_info['version'] = 'No version' self.scanner_info['version'] = 'No version'
self.scanner_info['description'] = 'No description' self.scanner_info['description'] = 'No description'
self.server_version = None # Set by the subclass. self.server_version = None # Set by the subclass.
self.scaninfo_store_time = kwargs.get('scaninfo_store_time')
self.protocol_version = PROTOCOL_VERSION self.protocol_version = PROTOCOL_VERSION
self.commands = COMMANDS_TABLE self.commands = COMMANDS_TABLE
self.scanner_params = dict() self.scanner_params = dict()
for name, param in BASE_SCANNER_PARAMS.items(): for name, param in BASE_SCANNER_PARAMS.items():
self.add_scanner_param(name, param) self.add_scanner_param(name, param)
self.vts = dict() self.vts = dict()
self.vt_id_pattern = re.compile("[0-9a-zA-Z_\\-:.]{1,80}") self.vt_id_pattern = re.compile("[0-9a-zA-Z_\\-:.]{1,80}")
self.vts_version = None self.vts_version = None
if customvtfilter: if customvtfilter:
self.vts_filter = customvtfilter self.vts_filter = customvtfilter
else: else:
self.vts_filter = VtsFilter() self.vts_filter = VtsFilter()
self.is_cache_available = False
def init(self): def init(self):
""" Should be overridden by a subclass if the initialization is costly. """ Should be overridden by a subclass if the initialization is costly.
Will be called before check. Will be called before check.
""" """
self.is_cache_available = True
def set_command_attributes(self, name, attributes): def set_command_attributes(self, name, attributes):
""" Sets the xml attributes of a specified command. """ """ Sets the xml attributes of a specified command. """
if self.command_exists(name): if self.command_exists(name):
command = self.commands.get(name) command = self.commands.get(name)
command['attributes'] = attributes command['attributes'] = attributes
def add_scanner_param(self, name, scanner_param): def add_scanner_param(self, name, scanner_param):
""" Add a scanner parameter. """ """ Add a scanner parameter. """
skipping to change at line 540 skipping to change at line 545
['192.168.0.0/24', '22', {'smb': {'type': type, ['192.168.0.0/24', '22', {'smb': {'type': type,
'port': port, 'port': port,
'username': username, 'username': username,
'password': pass, 'password': pass,
}}], ''] }}], '']
""" """
target_list = [] target_list = []
for target in scanner_target: for target in scanner_target:
exclude_hosts = '' exclude_hosts = ''
finished_hosts = ''
ports = '' ports = ''
credentials = {} credentials = {}
for child in target: for child in target:
if child.tag == 'hosts': if child.tag == 'hosts':
hosts = child.text hosts = child.text
if child.tag == 'exclude_hosts': if child.tag == 'exclude_hosts':
exclude_hosts = child.text exclude_hosts = child.text
if child.tag == 'finished_hosts':
finished_hosts = child.text
if child.tag == 'ports': if child.tag == 'ports':
ports = child.text ports = child.text
if child.tag == 'credentials': if child.tag == 'credentials':
credentials = cls.process_credentials_elements(child) credentials = cls.process_credentials_elements(child)
if hosts: if hosts:
target_list.append([hosts, ports, credentials, exclude_hosts]) target_list.append(
[hosts, ports, credentials, exclude_hosts, finished_hosts]
)
else: else:
raise OspdCommandError('No target to scan', 'start_scan') raise OspdCommandError('No target to scan', 'start_scan')
return target_list return target_list
def handle_start_scan_command(self, scan_et): def handle_start_scan_command(self, scan_et):
""" Handles <start_scan> command. """ Handles <start_scan> command.
@return: Response string for <start_scan> command. @return: Response string for <start_scan> command.
""" """
skipping to change at line 577 skipping to change at line 587
# <targets> element is ignored. # <targets> element is ignored.
if target_str is None or ports_str is None: if target_str is None or ports_str is None:
target_list = scan_et.find('targets') target_list = scan_et.find('targets')
if target_list is None or len(target_list) == 0: if target_list is None or len(target_list) == 0:
raise OspdCommandError('No targets or ports', 'start_scan') raise OspdCommandError('No targets or ports', 'start_scan')
else: else:
scan_targets = self.process_targets_element(target_list) scan_targets = self.process_targets_element(target_list)
else: else:
scan_targets = [] scan_targets = []
for single_target in target_str_to_list(target_str): for single_target in target_str_to_list(target_str):
scan_targets.append([single_target, ports_str, '', '']) scan_targets.append([single_target, ports_str, '', '', ''])
scan_id = scan_et.attrib.get('scan_id') scan_id = scan_et.attrib.get('scan_id')
if scan_id is not None and scan_id != '' and not valid_uuid(scan_id): if scan_id is not None and scan_id != '' and not valid_uuid(scan_id):
raise OspdCommandError('Invalid scan_id UUID', 'start_scan') raise OspdCommandError('Invalid scan_id UUID', 'start_scan')
try: try:
parallel = int(scan_et.attrib.get('parallel', '1')) parallel = int(scan_et.attrib.get('parallel', '1'))
if parallel < 1 or parallel > 20: if parallel < 1 or parallel > 20:
parallel = 1 parallel = 1
except ValueError: except ValueError:
skipping to change at line 616 skipping to change at line 626
vt_selection = self.process_vts_params(scanner_vts) vt_selection = self.process_vts_params(scanner_vts)
# Dry run case. # Dry run case.
if 'dry_run' in params and int(params['dry_run']): if 'dry_run' in params and int(params['dry_run']):
scan_func = self.dry_run_scan scan_func = self.dry_run_scan
scan_params = None scan_params = None
else: else:
scan_func = self.start_scan scan_func = self.start_scan
scan_params = self.process_scan_params(params) scan_params = self.process_scan_params(params)
scan_id_aux = scan_id
scan_id = self.create_scan( scan_id = self.create_scan(
scan_id, scan_targets, scan_params, vt_selection scan_id, scan_targets, scan_params, vt_selection
) )
if not scan_id:
id_ = Element('id')
id_.text = scan_id_aux
return simple_response_str('start_scan', 100, 'Continue', id_)
scan_process = multiprocessing.Process( scan_process = multiprocessing.Process(
target=scan_func, args=(scan_id, scan_targets, parallel) target=scan_func, args=(scan_id, scan_targets, parallel)
) )
self.scan_processes[scan_id] = scan_process self.scan_processes[scan_id] = scan_process
scan_process.start() scan_process.start()
id_ = Element('id') id_ = Element('id')
id_.text = scan_id id_.text = scan_id
return simple_response_str('start_scan', 200, 'OK', id_) return simple_response_str('start_scan', 200, 'OK', id_)
def handle_stop_scan_command(self, scan_et): def handle_stop_scan_command(self, scan_et):
skipping to change at line 660 skipping to change at line 676
) )
self.set_scan_status(scan_id, ScanStatus.STOPPED) self.set_scan_status(scan_id, ScanStatus.STOPPED)
logger.info('%s: Scan stopping %s.', scan_id, scan_process.ident) logger.info('%s: Scan stopping %s.', scan_id, scan_process.ident)
self.stop_scan_cleanup(scan_id) self.stop_scan_cleanup(scan_id)
try: try:
scan_process.terminate() scan_process.terminate()
except AttributeError: except AttributeError:
logger.debug('%s: The scanner task stopped unexpectedly.', scan_id) logger.debug('%s: The scanner task stopped unexpectedly.', scan_id)
os.killpg(os.getpgid(scan_process.ident), 15) try:
os.killpg(os.getpgid(scan_process.ident), 15)
except ProcessLookupError as e:
logger.info(
'%s: Scan already stopped %s.', scan_id, scan_process.ident
)
if scan_process.ident != os.getpid(): if scan_process.ident != os.getpid():
scan_process.join() scan_process.join()
logger.info('%s: Scan stopped.', scan_id) logger.info('%s: Scan stopped.', scan_id)
@staticmethod @staticmethod
def stop_scan_cleanup(scan_id): def stop_scan_cleanup(scan_id):
""" Should be implemented by subclass in case of a clean up before """ Should be implemented by subclass in case of a clean up before
terminating is needed. """ terminating is needed. """
@staticmethod @staticmethod
skipping to change at line 759 skipping to change at line 781
except (ssl.SSLError) as exception: except (ssl.SSLError) as exception:
logger.debug('Error: %s', exception) logger.debug('Error: %s', exception)
break break
except (socket.timeout) as exception: except (socket.timeout) as exception:
break break
if len(data) <= 0: if len(data) <= 0:
logger.debug("Empty client stream") logger.debug("Empty client stream")
return return
response = None
try: try:
response = self.handle_command(data) self.handle_command(data, stream)
except OspdCommandError as exception: except OspdCommandError as exception:
response = exception.as_xml() response = exception.as_xml()
logger.debug('Command error: %s', exception.message) logger.debug('Command error: %s', exception.message)
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
logger.exception('While handling client command:') logger.exception('While handling client command:')
exception = OspdCommandError('Fatal error', 'error') exception = OspdCommandError('Fatal error', 'error')
response = exception.as_xml() response = exception.as_xml()
if response:
stream.write(response) stream.write(response)
stream.close() stream.close()
def parallel_scan(self, scan_id, target): def parallel_scan(self, scan_id, target):
""" Starts the scan with scan_id. """ """ Starts the scan with scan_id. """
try: try:
ret = self.exec_scan(scan_id, target) ret = self.exec_scan(scan_id, target)
if ret == 0: if ret == 0:
logger.info("%s: Host scan dead.", target) logger.info("%s: Host scan dead.", target)
elif ret == 1: elif ret == 1:
logger.info("%s: Host scan alived.", target) logger.info("%s: Host scan alived.", target)
skipping to change at line 839 skipping to change at line 862
def calculate_progress(self, scan_id): def calculate_progress(self, scan_id):
""" Calculate the total scan progress from the """ Calculate the total scan progress from the
partial target progress. """ partial target progress. """
t_prog = dict() t_prog = dict()
for target in self.get_scan_target(scan_id): for target in self.get_scan_target(scan_id):
t_prog[target] = self.get_scan_target_progress(scan_id, target) t_prog[target] = self.get_scan_target_progress(scan_id, target)
return sum(t_prog.values()) / len(t_prog) return sum(t_prog.values()) / len(t_prog)
def process_exclude_hosts(self, scan_id, target_list): def process_exclude_hosts(self, scan_id, target_list):
""" Process the exclude hosts before launching the scans. """ Process the exclude hosts before launching the scans."""
Set exclude hosts as finished with 100% to calculate
the scan progress."""
for target, _, _, exclude_hosts in target_list: for target, _, _, exclude_hosts, _ in target_list:
exc_hosts_list = '' exc_hosts_list = ''
if not exclude_hosts: if not exclude_hosts:
continue continue
exc_hosts_list = target_str_to_list(exclude_hosts) exc_hosts_list = target_str_to_list(exclude_hosts)
self.remove_scan_hosts_from_target_progress(
scan_id, target, exc_hosts_list
)
def process_finished_hosts(self, scan_id, target_list):
""" Process the finished hosts before launching the scans.
Set finished hosts as finished with 100% to calculate
the scan progress."""
for target, _, _, _, finished_hosts in target_list:
exc_hosts_list = ''
if not finished_hosts:
continue
exc_hosts_list = target_str_to_list(finished_hosts)
for host in exc_hosts_list: for host in exc_hosts_list:
self.set_scan_host_finished(scan_id, target, host) self.set_scan_host_finished(scan_id, target, host)
self.set_scan_host_progress(scan_id, target, host, 100) self.set_scan_host_progress(scan_id, target, host, 100)
def start_scan(self, scan_id, targets, parallel=1): def start_scan(self, scan_id, targets, parallel=1):
""" Handle N parallel scans if 'parallel' is greater than 1. """ """ Handle N parallel scans if 'parallel' is greater than 1. """
os.setsid() os.setsid()
multiscan_proc = [] multiscan_proc = []
logger.info("%s: Scan started.", scan_id) logger.info("%s: Scan started.", scan_id)
target_list = targets target_list = targets
if target_list is None or not target_list: if target_list is None or not target_list:
raise OspdCommandError('Erroneous targets list', 'start_scan') raise OspdCommandError('Erroneous targets list', 'start_scan')
self.process_exclude_hosts(scan_id, target_list) self.process_exclude_hosts(scan_id, target_list)
self.process_finished_hosts(scan_id, target_list)
for _index, target in enumerate(target_list): for _index, target in enumerate(target_list):
while len(multiscan_proc) >= parallel: while len(multiscan_proc) >= parallel:
progress = self.calculate_progress(scan_id) progress = self.calculate_progress(scan_id)
self.set_scan_progress(scan_id, progress) self.set_scan_progress(scan_id, progress)
multiscan_proc = self.check_pending_target( multiscan_proc = self.check_pending_target(
scan_id, multiscan_proc scan_id, multiscan_proc
) )
time.sleep(1) time.sleep(1)
# If the scan status is stopped, does not launch anymore target # If the scan status is stopped, does not launch anymore target
# scans # scans
if self.get_scan_status(scan_id) == ScanStatus.STOPPED: if self.get_scan_status(scan_id) == ScanStatus.STOPPED:
return return
logger.info( logger.debug(
"%s: Host scan started on ports %s.", target[0], target[1] "%s: Host scan started on ports %s.", target[0], target[1]
) )
scan_process = multiprocessing.Process( scan_process = multiprocessing.Process(
target=self.parallel_scan, args=(scan_id, target[0]) target=self.parallel_scan, args=(scan_id, target[0])
) )
multiscan_proc.append((scan_process, target[0])) multiscan_proc.append((scan_process, target[0]))
scan_process.start() scan_process.start()
self.set_scan_status(scan_id, ScanStatus.RUNNING) self.set_scan_status(scan_id, ScanStatus.RUNNING)
# Wait until all single target were scanned # Wait until all single target were scanned
skipping to change at line 925 skipping to change at line 962
def handle_timeout(self, scan_id, host): def handle_timeout(self, scan_id, host):
""" Handles scanner reaching timeout error. """ """ Handles scanner reaching timeout error. """
self.add_scan_error( self.add_scan_error(
scan_id, scan_id,
host=host, host=host,
name="Timeout", name="Timeout",
value="{0} exec timeout.".format(self.get_scanner_name()), value="{0} exec timeout.".format(self.get_scanner_name()),
) )
def remove_scan_hosts_from_target_progress(
self, scan_id, target, exc_hosts_list
):
""" Remove a list of hosts from the main scan progress table."""
self.scan_collection.remove_hosts_from_target_progress(
scan_id, target, exc_hosts_list
)
def set_scan_host_finished(self, scan_id, target, host): def set_scan_host_finished(self, scan_id, target, host):
""" Add the host in a list of finished hosts """ """ Add the host in a list of finished hosts """
self.scan_collection.set_host_finished(scan_id, target, host) self.scan_collection.set_host_finished(scan_id, target, host)
def set_scan_progress(self, scan_id, progress): def set_scan_progress(self, scan_id, progress):
""" Sets scan_id scan's progress which is a number """ Sets scan_id scan's progress which is a number
between 0 and 100. """ between 0 and 100. """
self.scan_collection.set_progress(scan_id, progress) self.scan_collection.set_progress(scan_id, progress)
def set_scan_host_progress(self, scan_id, target, host, progress): def set_scan_host_progress(self, scan_id, target, host, progress):
""" Sets host's progress which is part of target. """ """ Sets host's progress which is part of target. """
self.scan_collection.set_host_progress( self.scan_collection.set_host_progress(scan_id, target, host, progress)
scan_id, target, host, progress
)
def set_scan_status(self, scan_id, status): def set_scan_status(self, scan_id, status):
""" Set the scan's status.""" """ Set the scan's status."""
self.scan_collection.set_status(scan_id, status) self.scan_collection.set_status(scan_id, status)
def get_scan_status(self, scan_id): def get_scan_status(self, scan_id):
""" Get scan_id scans's status.""" """ Get scan_id scans's status."""
return self.scan_collection.get_status(scan_id) return self.scan_collection.get_status(scan_id)
def scan_exists(self, scan_id): def scan_exists(self, scan_id):
skipping to change at line 990 skipping to change at line 1033
return simple_response_str('get_scans', 404, text) return simple_response_str('get_scans', 404, text)
else: else:
for scan_id in self.scan_collection.ids_iterator(): for scan_id in self.scan_collection.ids_iterator():
self.check_scan_process(scan_id) self.check_scan_process(scan_id)
scan = self.get_scan_xml(scan_id, details, pop_res) scan = self.get_scan_xml(scan_id, details, pop_res)
responses.append(scan) responses.append(scan)
return simple_response_str('get_scans', 200, 'OK', responses) return simple_response_str('get_scans', 200, 'OK', responses)
def handle_get_vts_command(self, vt_et): def handle_get_vts_command(self, vt_et):
""" Handles <get_vts> command. """ Handles <get_vts> command.
The <get_vts> element accept two optional arguments.
vt_id argument receives a single vt id.
filter argument receives a filter selecting a sub set of vts.
If both arguments are given, the vts which match with the filter
are return.
@return: Response string for <get_vts> command. @return: Response string for <get_vts> command.
""" """
if not self.is_cache_available:
try:
yield simple_response_str(
'get_vts',
409,
'Conflict',
'A vts update is being performed.',
)
finally:
return
self.is_cache_available = False
xml_helper = XmlStringHelper()
vt_id = vt_et.attrib.get('vt_id') vt_id = vt_et.attrib.get('vt_id')
vt_filter = vt_et.attrib.get('filter') vt_filter = vt_et.attrib.get('filter')
if vt_id and vt_id not in self.vts: if vt_id and vt_id not in self.vts:
text = "Failed to find vulnerability test '{0}'".format(vt_id) try:
return simple_response_str('get_vts', 404, text) text = "Failed to find vulnerability test '{0}'".format(vt_id)
yield simple_response_str('get_vts', 404, text)
finally:
self.is_cache_available = True
return
filtered_vts = None filtered_vts = None
if vt_filter: if not vt_id and vt_filter:
filtered_vts = self.vts_filter.get_filtered_vts_list( try:
self.vts, vt_filter filtered_vts = self.vts_filter.get_filtered_vts_list(
) self.vts, vt_filter
)
except OspdCommandError as filter_error:
self.is_cache_available = True
raise OspdCommandError(filter_error)
elif vt_id:
filtered_vts = vt_id
else:
filtered_vts = self.vts.keys()
responses = [] yield xml_helper.create_response('get_vts')
yield xml_helper.create_element('vts')
vts_xml = self.get_vts_xml(vt_id, filtered_vts) for vt in self.get_vt_iterator():
vt_id, _ = vt
if vt_id not in filtered_vts:
continue
yield xml_helper.add_element(self.get_vt_xml(vt))
responses.append(vts_xml) yield xml_helper.create_element('vts', end=True)
yield xml_helper.create_response('get_vts', end=True)
return simple_response_str('get_vts', 200, 'OK', responses) self.is_cache_available = True
def handle_get_performance(self, scan_et): def handle_get_performance(self, scan_et):
""" Handles <get_performance> command. """ Handles <get_performance> command.
@return: Response string for <get_performance> command. @return: Response string for <get_performance> command.
""" """
start = scan_et.attrib.get('start') start = scan_et.attrib.get('start')
end = scan_et.attrib.get('end') end = scan_et.attrib.get('end')
titles = scan_et.attrib.get('titles') titles = scan_et.attrib.get('titles')
cmd = ['gvmcg'] cmd = ['gvmcg']
if start: if start:
try: try:
int(start) int(start)
except ValueError: except ValueError:
raise OspdCommandError( raise OspdCommandError(
'Start argument must be integer.', 'Start argument must be integer.', 'get_performance'
'get_performance' )
)
cmd.append(start) cmd.append(start)
if end: if end:
try: try:
int(end) int(end)
except ValueError: except ValueError:
raise OspdCommandError( raise OspdCommandError(
'End argument must be integer.', 'End argument must be integer.', 'get_performance'
'get_performance'
) )
cmd.append(end) cmd.append(end)
if titles: if titles:
combined = "(" + ")|(".join(GVMCG_TITLES) + ")" combined = "(" + ")|(".join(GVMCG_TITLES) + ")"
forbidden = "^[^|&;]+$" forbidden = "^[^|&;]+$"
if re.match(combined, titles) and re.match(forbidden, titles): if re.match(combined, titles) and re.match(forbidden, titles):
cmd.append(titles) cmd.append(titles)
else: else:
raise OspdCommandError( raise OspdCommandError(
'Arguments not allowed', 'Arguments not allowed', 'get_performance'
'get_performance'
) )
try: try:
output = subprocess.check_output(cmd) output = subprocess.check_output(cmd)
except ( except (
subprocess.CalledProcessError, subprocess.CalledProcessError,
PermissionError, PermissionError,
FileNotFoundError, FileNotFoundError,
) as e: ) as e:
raise OspdCommandError( raise OspdCommandError(
'Bogus get_performance format. %s' % e, 'Bogus get_performance format. %s' % e, 'get_performance'
'get_performance'
) )
return simple_response_str('get_performance', 200, 'OK', output.decode() return simple_response_str(
) 'get_performance', 200, 'OK', output.decode()
)
def handle_help_command(self, scan_et): def handle_help_command(self, scan_et):
""" Handles <help> command. """ Handles <help> command.
@return: Response string for <help> command. @return: Response string for <help> command.
""" """
help_format = scan_et.attrib.get('format') help_format = scan_et.attrib.get('format')
if help_format is None or help_format == "text": if help_format is None or help_format == "text":
# Default help format is text. # Default help format is text.
return simple_response_str('help', 200, 'OK', self.get_help_text()) return simple_response_str('help', 200, 'OK', self.get_help_text())
skipping to change at line 1166 skipping to change at line 1244
def get_scan_results_xml(self, scan_id, pop_res): def get_scan_results_xml(self, scan_id, pop_res):
""" Gets scan_id scan's results in XML format. """ Gets scan_id scan's results in XML format.
@return: String of scan results in xml. @return: String of scan results in xml.
""" """
results = Element('results') results = Element('results')
for result in self.scan_collection.results_iterator(scan_id, pop_res): for result in self.scan_collection.results_iterator(scan_id, pop_res):
results.append(get_result_xml(result)) results.append(get_result_xml(result))
logger.info('Returning %d results', len(results)) logger.debug('Returning %d results', len(results))
return results return results
def get_xml_str(self, data): def get_xml_str(self, data):
""" Creates a string in XML Format using the provided data structure. """ Creates a string in XML Format using the provided data structure.
@param: Dictionary of xml tags and their elements. @param: Dictionary of xml tags and their elements.
@return: String of data in xml format. @return: String of data in xml format.
""" """
skipping to change at line 1425 skipping to change at line 1503
This needs to be implemented by each ospd wrapper, in case This needs to be implemented by each ospd wrapper, in case
severities elements for VTs are used. severities elements for VTs are used.
The severities XML objects which are returned will be embedded The severities XML objects which are returned will be embedded
into a <severities></severities> element. into a <severities></severities> element.
@return: XML object as string for severities data. @return: XML object as string for severities data.
""" """
return '' return ''
def get_vt_xml(self, vt_id): def get_vt_iterator(self):
for vt_id, val in self.vts.items():
yield (vt_id, val)
def get_vt_xml(self, single_vt):
""" Gets a single vulnerability test information in XML format. """ Gets a single vulnerability test information in XML format.
@return: String of single vulnerability test information in XML format. @return: String of single vulnerability test information in XML format.
""" """
if not vt_id: if not single_vt:
return Element('vt') return Element('vt')
vt = self.vts.get(vt_id) vt_id, vt = single_vt
name = vt.get('name') name = vt.get('name')
vt_xml = Element('vt') vt_xml = Element('vt')
vt_xml.set('id', vt_id) vt_xml.set('id', vt_id)
for name, value in [('name', name)]: for name, value in [('name', name)]:
elem = SubElement(vt_xml, name) elem = SubElement(vt_xml, name)
elem.text = str(value) elem.text = str(value)
if vt.get('vt_params'): if vt.get('vt_params'):
params_xml_str = self.get_params_vt_as_xml_str( params_xml_str = self.get_params_vt_as_xml_str(
skipping to change at line 1521 skipping to change at line 1602
vt_xml.append(secET.fromstring(severities_xml_str)) vt_xml.append(secET.fromstring(severities_xml_str))
if vt.get('custom'): if vt.get('custom'):
custom_xml_str = self.get_custom_vt_as_xml_str( custom_xml_str = self.get_custom_vt_as_xml_str(
vt_id, vt.get('custom') vt_id, vt.get('custom')
) )
vt_xml.append(secET.fromstring(custom_xml_str)) vt_xml.append(secET.fromstring(custom_xml_str))
return vt_xml return vt_xml
def get_vts_xml(self, vt_id=None, filtered_vts=None):
""" Gets collection of vulnerability test information in XML format.
If vt_id is specified, the collection will contain only this vt, if
found.
If no vt_id is specified, the collection will contain all vts or those
passed in filtered_vts.
Arguments:
vt_id (vt_id, optional): ID of the vt to get.
filtered_vts (dict, optional): Filtered VTs collection.
Return:
String of collection of vulnerability test information in
XML format.
"""
vts_xml = Element('vts')
if vt_id:
vts_xml.append(self.get_vt_xml(vt_id))
elif filtered_vts:
for vt_id in filtered_vts:
vts_xml.append(self.get_vt_xml(vt_id))
else:
for vt_id in self.vts:
vts_xml.append(self.get_vt_xml(vt_id))
return vts_xml
def handle_get_scanner_details(self): def handle_get_scanner_details(self):
""" Handles <get_scanner_details> command. """ Handles <get_scanner_details> command.
@return: Response string for <get_scanner_details> command. @return: Response string for <get_scanner_details> command.
""" """
desc_xml = Element('description') desc_xml = Element('description')
desc_xml.text = self.get_scanner_description() desc_xml.text = self.get_scanner_description()
details = [desc_xml, self.get_scanner_params_xml()] details = [desc_xml, self.get_scanner_params_xml()]
return simple_response_str('get_scanner_details', 200, 'OK', details) return simple_response_str('get_scanner_details', 200, 'OK', details)
skipping to change at line 1599 skipping to change at line 1651
content = [protocol, daemon, scanner] content = [protocol, daemon, scanner]
if self.get_vts_version(): if self.get_vts_version():
vts = Element('vts') vts = Element('vts')
elem = SubElement(vts, 'version') elem = SubElement(vts, 'version')
elem.text = self.get_vts_version() elem.text = self.get_vts_version()
content.append(vts) content.append(vts)
return simple_response_str('get_version', 200, 'OK', content) return simple_response_str('get_version', 200, 'OK', content)
def handle_command(self, command): def handle_command(self, command, stream):
""" Handles an osp command in a string. """ Handles an osp command in a string.
@return: OSP Response to command. @return: OSP Response to command.
""" """
try: try:
tree = secET.fromstring(command) tree = secET.fromstring(command)
except secET.ParseError: except secET.ParseError:
logger.debug("Erroneous client input: %s", command) logger.debug("Erroneous client input: %s", command)
raise OspdCommandError('Invalid data') raise OspdCommandError('Invalid data')
if not self.command_exists(tree.tag) and tree.tag != "authenticate": if not self.command_exists(tree.tag) and tree.tag != "authenticate":
raise OspdCommandError('Bogus command name') raise OspdCommandError('Bogus command name')
if tree.tag == "get_version": if tree.tag == "get_version":
return self.handle_get_version_command() stream.write(self.handle_get_version_command())
elif tree.tag == "start_scan": elif tree.tag == "start_scan":
return self.handle_start_scan_command(tree) stream.write(self.handle_start_scan_command(tree))
elif tree.tag == "stop_scan": elif tree.tag == "stop_scan":
return self.handle_stop_scan_command(tree) stream.write(self.handle_stop_scan_command(tree))
elif tree.tag == "get_scans": elif tree.tag == "get_scans":
return self.handle_get_scans_command(tree) stream.write(self.handle_get_scans_command(tree))
elif tree.tag == "get_vts": elif tree.tag == "get_vts":
return self.handle_get_vts_command(tree) response = self.handle_get_vts_command(tree)
for data in response:
stream.write(data)
return
elif tree.tag == "delete_scan": elif tree.tag == "delete_scan":
return self.handle_delete_scan_command(tree) stream.write(self.handle_delete_scan_command(tree))
elif tree.tag == "help": elif tree.tag == "help":
return self.handle_help_command(tree) stream.write(self.handle_help_command(tree))
elif tree.tag == "get_scanner_details": elif tree.tag == "get_scanner_details":
return self.handle_get_scanner_details() stream.write(self.handle_get_scanner_details())
elif tree.tag == "get_performance": elif tree.tag == "get_performance":
return self.handle_get_performance(tree) stream.write(self.handle_get_performance(tree))
else: else:
assert False, "Unhandled command: {0}".format(tree.tag) assert False, "Unhandled command: {0}".format(tree.tag)
def check(self): def check(self):
""" Asserts to False. Should be implemented by subclass. """ """ Asserts to False. Should be implemented by subclass. """
raise NotImplementedError raise NotImplementedError
def run(self, server: BaseServer): def run(self, server: BaseServer):
""" Starts the Daemon, handling commands until interrupted. """ Starts the Daemon, handling commands until interrupted.
""" """
server.start(self.handle_client_stream) server.start(self.handle_client_stream)
try: try:
while True: while True:
time.sleep(10) time.sleep(10)
self.scheduler() self.scheduler()
self.clean_forgotten_scans()
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info("Received Ctrl-C shutting-down ...") logger.info("Received Ctrl-C shutting-down ...")
finally: finally:
logger.info("Shutting-down server ...") logger.info("Shutting-down server ...")
server.close() server.close()
def scheduler(self): def scheduler(self):
""" Should be implemented by subclass in case of need """ Should be implemented by subclass in case of need
to run tasks periodically. """ to run tasks periodically. """
def create_scan(self, scan_id, targets, options, vts): def create_scan(self, scan_id, targets, options, vts):
""" Creates a new scan. """ Creates a new scan.
@target: Target to scan. @target: Target to scan.
@options: Miscellaneous scan options. @options: Miscellaneous scan options.
@return: New scan's ID. @return: New scan's ID. None if the scan_id already exists and the
scan status is RUNNING or FINISHED.
""" """
if self.scan_exists(scan_id): status = None
logger.info("Scan %s exists. Resuming scan.", scan_id) scan_exists = self.scan_exists(scan_id)
if scan_id and scan_exists:
status = self.get_scan_status(scan_id)
if scan_exists and status == ScanStatus.STOPPED:
logger.info("Scan %s exists. Resuming scan.", scan_id)
elif scan_exists and (
status == ScanStatus.RUNNING or status == ScanStatus.FINISHED
):
logger.info(
"Scan %s exists with status %s.", scan_id, status.name.lower()
)
return
return self.scan_collection.create_scan(scan_id, targets, options, vts) return self.scan_collection.create_scan(scan_id, targets, options, vts)
def get_scan_options(self, scan_id): def get_scan_options(self, scan_id):
""" Gives a scan's list of options. """ """ Gives a scan's list of options. """
return self.scan_collection.get_options(scan_id) return self.scan_collection.get_options(scan_id)
def set_scan_option(self, scan_id, name, value): def set_scan_option(self, scan_id, name, value):
""" Sets a scan's option to a provided value. """ """ Sets a scan's option to a provided value. """
return self.scan_collection.set_option(scan_id, name, value) return self.scan_collection.set_option(scan_id, name, value)
def clean_forgotten_scans(self):
""" Check for old stopped or finished scans which have not been
deleted and delete them if the are older than the set value."""
if not self.scaninfo_store_time:
return
for scan_id in list(self.scan_collection.ids_iterator()):
end_time = int(self.get_scan_end_time(scan_id))
scan_status = self.get_scan_status(scan_id)
if (
scan_status == ScanStatus.STOPPED
or scan_status == ScanStatus.FINISHED
) and end_time:
stored_time = int(time.time()) - end_time
if stored_time > self.scaninfo_store_time * 3600:
logger.debug(
'Scan %s is older than %d hours and seems have been '
'forgotten. Scan info will be deleted from the '
'scan table',
scan_id,
self.scaninfo_store_time,
)
self.delete_scan(scan_id)
def check_scan_process(self, scan_id): def check_scan_process(self, scan_id):
""" Check the scan's process, and terminate the scan if not alive. """ """ Check the scan's process, and terminate the scan if not alive. """
scan_process = self.scan_processes[scan_id] scan_process = self.scan_processes[scan_id]
progress = self.get_scan_progress(scan_id) progress = self.get_scan_progress(scan_id)
if progress < 100 and not scan_process.is_alive(): if progress < 100 and not scan_process.is_alive():
if not (self.get_scan_status(scan_id) == ScanStatus.STOPPED): if not (self.get_scan_status(scan_id) == ScanStatus.STOPPED):
self.set_scan_status(scan_id, ScanStatus.STOPPED) self.set_scan_status(scan_id, ScanStatus.STOPPED)
self.add_scan_error( self.add_scan_error(
scan_id, name="", host="", value="Scan process failure." scan_id, name="", host="", value="Scan process failure."
) )
skipping to change at line 1765 skipping to change at line 1859
hostname, hostname,
name, name,
value, value,
port, port,
test_id, test_id,
0.0, 0.0,
qod, qod,
) )
def add_scan_error( def add_scan_error(
self, scan_id, host='', hostname='', name='', value='', port='' self,
scan_id,
host='',
hostname='',
name='',
value='',
port='',
test_id='',
): ):
""" Adds an error result to scan_id scan. """ """ Adds an error result to scan_id scan. """
self.scan_collection.add_result( self.scan_collection.add_result(
scan_id, ResultType.ERROR, host, hostname, name, value, port scan_id,
ResultType.ERROR,
host,
hostname,
name,
value,
port,
test_id,
) )
def add_scan_host_detail( def add_scan_host_detail(
self, scan_id, host='', hostname='', name='', value='' self, scan_id, host='', hostname='', name='', value=''
): ):
""" Adds a host detail result to scan_id scan. """ """ Adds a host detail result to scan_id scan. """
self.scan_collection.add_result( self.scan_collection.add_result(
scan_id, ResultType.HOST_DETAIL, host, hostname, name, value scan_id, ResultType.HOST_DETAIL, host, hostname, name, value
) )
 End of changes. 57 change blocks. 
88 lines changed or deleted 195 lines changed or added

Home  |  About  |  Features  |  All  |  Newest  |  Dox  |  Diffs  |  RSS Feeds  |  Screenshots  |  Comments  |  Imprint  |  Privacy  |  HTTP(S)