Diffstat
40 files changed, 7833 insertions, 0 deletions
diff --git a/suricata/update/__init__.py b/suricata/update/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/suricata/update/__init__.py diff --git a/suricata/update/commands/__init__.py b/suricata/update/commands/__init__.py new file mode 100644 index 0000000..e75c80a --- /dev/null +++ b/suricata/update/commands/__init__.py @@ -0,0 +1,23 @@ +# Copyright (C) 2017 Open Information Security Foundation +# +# You can copy, redistribute or modify this Program under the terms of +# the GNU General Public License version 2 as published by the Free +# Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# version 2 along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. + +from suricata.update.commands import addsource +from suricata.update.commands import listsources +from suricata.update.commands import updatesources +from suricata.update.commands import enablesource +from suricata.update.commands import disablesource +from suricata.update.commands import removesource +from suricata.update.commands import checkversions diff --git a/suricata/update/commands/addsource.py b/suricata/update/commands/addsource.py new file mode 100644 index 0000000..a87095c --- /dev/null +++ b/suricata/update/commands/addsource.py @@ -0,0 +1,72 @@ +# Copyright (C) 2017 Open Information Security Foundation +# +# You can copy, redistribute or modify this Program under the terms of +# the GNU General Public License version 2 as published by the Free +# Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# version 2 along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. 
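Every command module added in this commit follows the same convention: it exposes a register(parser) function that declares the command's arguments and binds its handler through parser.set_defaults(func=...). Below is a minimal, self-contained sketch of that dispatch pattern; the real suricata-update entry point (and the config.args() plumbing the real handlers use instead of an args parameter) is not part of this diff, so all names here are illustrative only.

    import argparse

    def greet(args):
        # Handler bound via set_defaults(func=...). The real command
        # handlers take no arguments and read them from config.args().
        print("hello, %s" % args.name)
        return 0

    def register(parser):
        parser.add_argument("name")
        parser.set_defaults(func=greet)

    def main(argv=None):
        parser = argparse.ArgumentParser(prog="demo")
        subparsers = parser.add_subparsers()
        register(subparsers.add_parser("greet"))
        args = parser.parse_args(argv)
        return args.func(args)

    if __name__ == "__main__":
        raise SystemExit(main(["greet", "world"]))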
+ +from __future__ import print_function + +import logging + +from suricata.update import config +from suricata.update import sources + +try: + input = raw_input +except NameError: + pass + +logger = logging.getLogger() + + +def register(parser): + parser.add_argument("name", metavar="<name>", nargs="?", + help="Name of source") + parser.add_argument("url", metavar="<url>", nargs="?", help="Source URL") + parser.add_argument("--http-header", metavar="<http-header>", + help="Additional HTTP header to add to requests") + parser.add_argument("--no-checksum", action="store_false", + help="Skip downloading the checksum URL") + parser.set_defaults(func=add_source) + + +def add_source(): + args = config.args() + + if args.name: + name = args.name + else: + while True: + name = input("Name of source: ").strip() + if name: + break + + if sources.source_name_exists(name): + logger.error("A source with name %s already exists.", name) + return 1 + + if args.url: + url = args.url + else: + while True: + url = input("URL: ").strip() + if url: + break + + checksum = args.no_checksum + + header = args.http_header if args.http_header else None + + source_config = sources.SourceConfiguration( + name, header=header, url=url, checksum=checksum) + sources.save_source_config(source_config) diff --git a/suricata/update/commands/checkversions.py b/suricata/update/commands/checkversions.py new file mode 100644 index 0000000..3492317 --- /dev/null +++ b/suricata/update/commands/checkversions.py @@ -0,0 +1,83 @@ +# Copyright (C) 2019 Open Information Security Foundation +# +# You can copy, redistribute or modify this Program under the terms of +# the GNU General Public License version 2 as published by the Free +# Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# version 2 along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. + +import sys +import os.path +import logging +from suricata.update import sources, engine + +logger = logging.getLogger() + + +def is_gt(v1, v2): + if v1.full == v2.full: + return False + + if v1.major < v2.major: + return False + elif v1.major > v2.major: + return True + + if v1.minor < v2.minor: + return False + elif v1.minor > v2.minor: + return True + + if v1.patch < v2.patch: + return False + + # major.minor.patch are equal but the full strings differ (e.g. a + # "-beta" suffix on v2); treat v1 as the newer version. + return True + + +def register(parser): + parser.set_defaults(func=check_version) + + +def check_version(suricata_version): + if "dev" in suricata_version.full: + logger.warning("Development version of Suricata found: %s. 
" + "Skipping version check.", suricata_version.full) + return + + index_filename = sources.get_index_filename() + if not os.path.exists(index_filename): + logger.warning("No index exists, will use bundled index.") + logger.warning("Please run suricata-update update-sources.") + index = sources.Index(index_filename) + version = index.get_versions() + recommended = engine.parse_version(version["suricata"]["recommended"]) + if not recommended: + logger.error("Recommended version was not parsed properly") + sys.exit(1) + # In case index is out of date + if is_gt(suricata_version, recommended): + return + # Evaluate if the installed version is present in index + upgrade_version = version["suricata"].get(suricata_version.short) + if not upgrade_version: + logger.warning("Suricata version %s has reached EOL. Please upgrade to %s.", + suricata_version.full, recommended.full) + return + if suricata_version.full == upgrade_version: + logger.info("Suricata version %s is up to date", suricata_version.full) + elif upgrade_version == recommended.full: + logger.warning( + "Suricata version %s is outdated. Please upgrade to %s.", + suricata_version.full, recommended.full) + else: + logger.warning( + "Suricata version %s is outdated. Please upgrade to %s or %s.", + suricata_version.full, upgrade_version, recommended.full) + diff --git a/suricata/update/commands/disablesource.py b/suricata/update/commands/disablesource.py new file mode 100644 index 0000000..6a64a7b --- /dev/null +++ b/suricata/update/commands/disablesource.py @@ -0,0 +1,40 @@ +# Copyright (C) 2017 Open Information Security Foundation +# +# You can copy, redistribute or modify this Program under the terms of +# the GNU General Public License version 2 as published by the Free +# Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# version 2 along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. + +from __future__ import print_function + +import os +import logging + +from suricata.update import config +from suricata.update import sources + +logger = logging.getLogger() + +def register(parser): + parser.add_argument("name") + parser.set_defaults(func=disable_source) + +def disable_source(): + name = config.args().name + filename = sources.get_enabled_source_filename(name) + if not os.path.exists(filename): + logger.debug("Filename %s does not exist.", filename) + logger.warning("Source %s is not enabled.", name) + return 0 + logger.debug("Renaming %s to %s.disabled.", filename, filename) + os.rename(filename, "%s.disabled" % (filename)) + logger.info("Source %s has been disabled", name) diff --git a/suricata/update/commands/enablesource.py b/suricata/update/commands/enablesource.py new file mode 100644 index 0000000..53bb68a --- /dev/null +++ b/suricata/update/commands/enablesource.py @@ -0,0 +1,162 @@ +# Copyright (C) 2017 Open Information Security Foundation +# +# You can copy, redistribute or modify this Program under the terms of +# the GNU General Public License version 2 as published by the Free +# Software Foundation. 
+# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# version 2 along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. + +from __future__ import print_function + +import os +import logging + +import yaml + +from suricata.update import config +from suricata.update import sources + +try: + input = raw_input +except NameError: + pass + +logger = logging.getLogger() + +default_source = "et/open" + +def register(parser): + parser.add_argument("name") + parser.add_argument("params", nargs="*", metavar="param=val") + parser.set_defaults(func=enable_source) + +def enable_source(): + name = config.args().name + update_params = False + + # Check if source is already enabled. + enabled_source_filename = sources.get_enabled_source_filename(name) + if os.path.exists(enabled_source_filename): + logger.warning("The source %s is already enabled.", name) + update_params = True + + # First check if this source was previously disabled and then just + # re-enable it. + disabled_source_filename = sources.get_disabled_source_filename(name) + if os.path.exists(disabled_source_filename): + logger.info("Re-enabling previously disabled source for %s.", name) + os.rename(disabled_source_filename, enabled_source_filename) + update_params = True + + if not os.path.exists(sources.get_index_filename()): + logger.warning("Source index does not exist, will use bundled one.") + logger.warning("Please run suricata-update update-sources.") + + source_index = sources.load_source_index(config) + + if name not in source_index.get_sources() and name not in sources.get_sources_from_dir(): + logger.error("Unknown source: %s", name) + return 1 + + # Parse key=val options. + opts = {} + for param in config.args().params: + key, val = param.split("=", 1) + opts[key] = val + + params = {} + if update_params: + source = yaml.safe_load(open(sources.get_enabled_source_filename(name), "rb")) + else: + source = source_index.get_sources()[name] + + if "params" in source: + params = source["params"] + for old_param in source["params"]: + if old_param in opts and source["params"][old_param] != opts[old_param]: + logger.info("Updating source parameter '%s': '%s' -> '%s'." % ( + old_param, source["params"][old_param], opts[old_param])) + params[old_param] = opts[old_param] + + if "subscribe-url" in source: + print("The source %s requires a subscription. Subscribe here:" % (name)) + print(" %s" % source["subscribe-url"]) + + if "parameters" in source: + for param in source["parameters"]: + if param in opts: + params[param] = opts[param] + else: + prompt = source["parameters"][param]["prompt"] + while True: + r = input("%s (%s): " % (prompt, param)) + r = r.strip() + if r: + break + params[param] = r.strip() + + # Honor an explicit per-source checksum setting; default to True. + checksum = source.get("checksum", True) + + new_source = sources.SourceConfiguration( + name, params=params, checksum=checksum)
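For reference, the param=val handling above reduces to one split per argument. Stand-alone, with an illustrative parameter name (not taken from a real index entry):

    # Same logic as the "Parse key=val options" loop in enable_source.
    args_params = ["secret-code=abc123"]
    opts = {}
    for param in args_params:
        key, val = param.split("=", 1)  # split once, so values may contain "="
        opts[key] = val
    assert opts == {"secret-code": "abc123"}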
+ + # If the source directory does not exist, create it. Also create + # the default rule-source of et/open, unless the source being + # enabled replaces it. + source_directory = sources.get_source_directory() + if not os.path.exists(source_directory): + try: + logger.info("Creating directory %s", source_directory) + os.makedirs(source_directory) + except Exception as err: + logger.error( + "Failed to create directory %s: %s", source_directory, err) + return 1 + + if "replaces" in source and default_source in source["replaces"]: + logger.debug( + "Not enabling default source as selected source replaces it") + elif new_source.name == default_source: + logger.debug( + "Not enabling default source as selected source is the default") + else: + logger.info("Enabling default source %s", default_source) + if not source_index.get_source_by_name(default_source): + logger.error("Default source %s not in index", default_source) + else: + default_source_config = sources.SourceConfiguration( + default_source) + write_source_config(default_source_config, True) + + write_source_config(new_source, True) + logger.info("Source %s enabled", new_source.name) + + if "replaces" in source: + for replaces in source["replaces"]: + filename = sources.get_enabled_source_filename(replaces) + if os.path.exists(filename): + logger.info( + "Removing source %s as it is replaced by %s", replaces, + new_source.name) + logger.debug("Deleting %s", filename) + os.unlink(filename) + +def write_source_config(source_config, enabled): + if enabled: + filename = sources.get_enabled_source_filename(source_config.name) + else: + filename = sources.get_disabled_source_filename(source_config.name) + with open(filename, "w") as fileobj: + logger.debug("Writing %s", filename) + fileobj.write(yaml.safe_dump(source_config.dict(), default_flow_style=False)) diff --git a/suricata/update/commands/listsources.py b/suricata/update/commands/listsources.py new file mode 100644 index 0000000..d35c3cd --- /dev/null +++ b/suricata/update/commands/listsources.py @@ -0,0 +1,116 @@ +# Copyright (C) 2017 Open Information Security Foundation +# +# You can copy, redistribute or modify this Program under the terms of +# the GNU General Public License version 2 as published by the Free +# Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# version 2 along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. + +from __future__ import print_function + +import logging + +from suricata.update import config +from suricata.update import sources +from suricata.update import util +from suricata.update import exceptions + +logger = logging.getLogger() + +def register(parser): + parser.add_argument("--free", action="store_true", + default=False, help="List all freely available sources") + parser.add_argument("--enabled", action="store_true", + help="List all enabled sources") + parser.add_argument("--all", action="store_true", + help="List all sources (including deprecated and obsolete)") + parser.set_defaults(func=list_sources) + +def list_sources(): + enabled = config.args().enabled or \ + config.args().subcommand == "list-enabled-sources" + + if enabled: + found = False + + # First list sources from the main config. 
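+ # (Then local files and the per-source files written by add-source / + # enable-source; each group is printed in turn below.)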
+ config_sources = config.get("sources") + if config_sources: + found = True + print("From %s:" % (config.filename)) + for source in config_sources: + print(" - %s" % (source)) + + # And local files. + local = config.get("local") + if local: + found = True + print("Local files/directories:") + for filename in local: + print(" - %s" % (filename)) + + enabled_sources = sources.get_enabled_sources() + if enabled_sources: + found = True + print("Enabled sources:") + for source in enabled_sources.values(): + print(" - %s" % (source["source"])) + + # If no enabled sources were found, log it. + if not found: + logger.warning("No enabled sources.") + return 0 + + free_only = config.args().free + if not sources.source_index_exists(config): + logger.warning("Source index does not exist, will use bundled one.") + logger.warning("Please run suricata-update update-sources.") + + index = sources.load_source_index(config) + for name, source in index.get_sources().items(): + is_not_free = source.get("subscribe-url") + if free_only and is_not_free: + continue + if not config.args().all: + if source.get("deprecated") is not None or \ + source.get("obsolete") is not None: + continue + print("%s: %s" % (util.bright_cyan("Name"), util.bright_magenta(name))) + print(" %s: %s" % ( + util.bright_cyan("Vendor"), util.bright_magenta(source["vendor"]))) + print(" %s: %s" % ( + util.bright_cyan("Summary"), util.bright_magenta(source["summary"]))) + print(" %s: %s" % ( + util.bright_cyan("License"), util.bright_magenta(source["license"]))) + if "tags" in source: + print(" %s: %s" % ( + util.bright_cyan("Tags"), + util.bright_magenta(", ".join(source["tags"])))) + if "replaces" in source: + print(" %s: %s" % ( + util.bright_cyan("Replaces"), + util.bright_magenta(", ".join(source["replaces"])))) + if "parameters" in source: + print(" %s: %s" % ( + util.bright_cyan("Parameters"), + util.bright_magenta(", ".join(source["parameters"])))) + if "subscribe-url" in source: + print(" %s: %s" % ( + util.bright_cyan("Subscription"), + util.bright_magenta(source["subscribe-url"]))) + if "deprecated" in source: + print(" %s: %s" % ( + util.orange("Deprecated"), + util.bright_magenta(source["deprecated"]))) + if "obsolete" in source: + print(" %s: %s" % ( + util.orange("Obsolete"), + util.bright_magenta(source["obsolete"]))) diff --git a/suricata/update/commands/removesource.py b/suricata/update/commands/removesource.py new file mode 100644 index 0000000..f75d5ca --- /dev/null +++ b/suricata/update/commands/removesource.py @@ -0,0 +1,49 @@ +# Copyright (C) 2017 Open Information Security Foundation +# +# You can copy, redistribute or modify this Program under the terms of +# the GNU General Public License version 2 as published by the Free +# Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# version 2 along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. 
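remove-source, below, deletes whichever per-source file exists, enabled or disabled. As disable-source above shows, "disabled" is purely a filename suffix: the enabled file is renamed to "<filename>.disabled". A sketch of that convention follows; the actual paths come from helpers in suricata/update/sources.py, which is not part of this diff, so the directory, the ".yaml" extension, and the name-flattening below are assumptions.

    import os

    # Assumed location; the real value comes from sources.get_source_directory().
    source_directory = "/var/lib/suricata/update/sources"

    def enabled_source_filename(name):
        # Assumed naming scheme: "et/open" -> ".../et-open.yaml"
        return os.path.join(source_directory, "%s.yaml" % name.replace("/", "-"))

    def disabled_source_filename(name):
        # disable-source renames the enabled file with this suffix.
        return enabled_source_filename(name) + ".disabled"

    print(enabled_source_filename("et/open"))   # .../et-open.yaml
    print(disabled_source_filename("et/open"))  # .../et-open.yaml.disabled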
+ +from __future__ import print_function + +import os +import logging + +from suricata.update import config +from suricata.update import sources + +logger = logging.getLogger() + +def register(parser): + parser.add_argument("name") + parser.set_defaults(func=remove_source) + +def remove_source(): + name = config.args().name + + enabled_source_filename = sources.get_enabled_source_filename(name) + if os.path.exists(enabled_source_filename): + logger.debug("Deleting file %s.", enabled_source_filename) + os.remove(enabled_source_filename) + logger.info("Source %s removed, previously enabled.", name) + return 0 + + disabled_source_filename = sources.get_disabled_source_filename(name) + if os.path.exists(disabled_source_filename): + logger.debug("Deleting file %s.", disabled_source_filename) + os.remove(disabled_source_filename) + logger.info("Source %s removed, previously disabled.", name) + return 0 + + logger.warning("Source %s does not exist.", name) + return 1 diff --git a/suricata/update/commands/updatesources.py b/suricata/update/commands/updatesources.py new file mode 100644 index 0000000..06a0d11 --- /dev/null +++ b/suricata/update/commands/updatesources.py @@ -0,0 +1,105 @@ +# Copyright (C) 2017 Open Information Security Foundation +# +# You can copy, redistribute or modify this Program under the terms of +# the GNU General Public License version 2 as published by the Free +# Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# version 2 along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. 
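update-sources, below, reports index changes by diffing the previously cached index against the newly downloaded one: names present only in the new mapping are "added", names present only in the old one are "removed", and names in both whose bodies differ are "changed". The same comparison stand-alone, with made-up data:

    # Mirrors the get_sources()/compare_sources() helpers defined below.
    before = {"et/open": {"summary": "ET open"}, "old/feed": {}}
    after = {"et/open": {"summary": "ET open ruleset"}, "new/feed": {}}

    added = {name: after[name] for name in after if name not in before}
    removed = {name: before[name] for name in before if name not in after}
    changed = [name for name in set(before) & set(after)
               if before[name] != after[name]]

    assert set(added) == {"new/feed"}
    assert set(removed) == {"old/feed"}
    assert changed == ["et/open"]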
+ +from __future__ import print_function + +import io +import logging +import os + +import yaml +from suricata.update import config, exceptions, net, sources + +logger = logging.getLogger() + + +def register(parser): + parser.set_defaults(func=update_sources) + + +# local_index_filename is a module-level name assigned by update_sources() +# before any of the helpers below run. +def get_initial_content(): + initial_content = None + if os.path.exists(local_index_filename): + with open(local_index_filename, "r") as stream: + initial_content = yaml.safe_load(stream) + return initial_content + + +def get_sources(before, after): + # Return the entries present in "after" but missing from "before". + all_sources = {source: after[source] + for source in after if source not in before} + return all_sources + + +def log_sources(sources_map): + for name, all_sources in sources_map.items(): + if not all_sources: + continue + for source in all_sources: + logger.info("Source %s was %s", source, name) + + +def compare_sources(initial_content, final_content): + if not initial_content: + logger.info("Adding all sources") + return + if initial_content == final_content: + logger.info("No change in sources") + return + initial_sources = initial_content.get("sources") + final_sources = final_content.get("sources") + added_sources = get_sources(before=initial_sources, after=final_sources) + removed_sources = get_sources(before=final_sources, after=initial_sources) + log_sources(sources_map={"added": added_sources, + "removed": removed_sources}) + for source in set(initial_sources) & set(final_sources): + if initial_sources[source] != final_sources[source]: + logger.info("Source %s was changed", source) + + +def write_and_compare(initial_content, fileobj): + try: + with open(local_index_filename, "wb") as outobj: + outobj.write(fileobj.getvalue()) + except IOError as ioe: + logger.error("Failed to write %s: %s", local_index_filename, ioe) + return 1 + with open(local_index_filename, "rb") as stream: + final_content = yaml.safe_load(stream) + compare_sources(initial_content, final_content) + logger.info("Saved %s", local_index_filename) + + +def update_sources(): + global local_index_filename + local_index_filename = sources.get_index_filename() + initial_content = get_initial_content() + with io.BytesIO() as fileobj: + url = sources.get_source_index_url() + logger.info("Downloading %s", url) + try: + net.get(url, fileobj) + except Exception as err: + raise exceptions.ApplicationError( + "Failed to download index: %s: %s" % (url, err)) + if not os.path.exists(config.get_cache_dir()): + try: + os.makedirs(config.get_cache_dir()) + except Exception as err: + logger.error("Failed to create directory %s: %s", + config.get_cache_dir(), err) + return 1 + write_and_compare(initial_content=initial_content, fileobj=fileobj) diff --git a/suricata/update/compat/__init__.py b/suricata/update/compat/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/suricata/update/compat/__init__.py diff --git a/suricata/update/compat/argparse/LICENSE.txt b/suricata/update/compat/argparse/LICENSE.txt new file mode 100644 index 0000000..640bc78 --- /dev/null +++ b/suricata/update/compat/argparse/LICENSE.txt @@ -0,0 +1,20 @@ +argparse is (c) 2006-2009 Steven J. Bethard <steven.bethard@gmail.com>. + +The argparse module was contributed to Python as of Python 2.7 and thus +was licensed under the Python license. Same license applies to all files in +the argparse package project. + +For details about the Python License, please see doc/Python-License.txt. + +History +------- + +Before (and including) argparse 1.1, the argparse package was licensed under +Apache License v2.0. 
+ +After argparse 1.1, all project files from the argparse project were deleted +due to license compatibility issues between Apache License 2.0 and GNU GPL v2. + +The project repository then had a clean start with some files taken from +Python 2.7.1, so definitely all files are under Python License now. + diff --git a/suricata/update/compat/argparse/__init__.py b/suricata/update/compat/argparse/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/suricata/update/compat/argparse/__init__.py diff --git a/suricata/update/compat/argparse/argparse.py b/suricata/update/compat/argparse/argparse.py new file mode 100644 index 0000000..5a68b70 --- /dev/null +++ b/suricata/update/compat/argparse/argparse.py @@ -0,0 +1,2378 @@ +# Author: Steven J. Bethard <steven.bethard@gmail.com>. + +"""Command-line parsing library + +This module is an optparse-inspired command-line parsing library that: + + - handles both optional and positional arguments + - produces highly informative usage messages + - supports parsers that dispatch to sub-parsers + +The following is a simple usage example that sums integers from the +command-line and writes the result to a file:: + + parser = argparse.ArgumentParser( + description='sum the integers at the command line') + parser.add_argument( + 'integers', metavar='int', nargs='+', type=int, + help='an integer to be summed') + parser.add_argument( + '--log', default=sys.stdout, type=argparse.FileType('w'), + help='the file where the sum should be written') + args = parser.parse_args() + args.log.write('%s' % sum(args.integers)) + args.log.close() + +The module contains the following public classes: + + - ArgumentParser -- The main entry point for command-line parsing. As the + example above shows, the add_argument() method is used to populate + the parser with actions for optional and positional arguments. Then + the parse_args() method is invoked to convert the args at the + command-line into an object with attributes. + + - ArgumentError -- The exception raised by ArgumentParser objects when + there are errors with the parser's actions. Errors raised while + parsing the command-line are caught by ArgumentParser and emitted + as command-line messages. + + - FileType -- A factory for defining types of files to be created. As the + example above shows, instances of FileType are typically passed as + the type= argument of add_argument() calls. + + - Action -- The base class for parser actions. Typically actions are + selected by passing strings like 'store_true' or 'append_const' to + the action= argument of add_argument(). However, for greater + customization of ArgumentParser actions, subclasses of Action may + be defined and passed as the action= argument. + + - HelpFormatter, RawDescriptionHelpFormatter, RawTextHelpFormatter, + ArgumentDefaultsHelpFormatter -- Formatter classes which + may be passed as the formatter_class= argument to the + ArgumentParser constructor. HelpFormatter is the default, + RawDescriptionHelpFormatter and RawTextHelpFormatter tell the parser + not to change the formatting for help text, and + ArgumentDefaultsHelpFormatter adds information about argument defaults + to the help. + +All other classes in this module are considered implementation details. +(Also note that HelpFormatter and RawDescriptionHelpFormatter are only +considered public as object names -- the API of the formatter objects is +still considered an implementation detail.) 
+""" + +__version__ = '1.3.0' # we use our own version number independant of the + # one in stdlib and we release this on pypi. + +__external_lib__ = True # to make sure the tests really test THIS lib, + # not the builtin one in Python stdlib + +__all__ = [ + 'ArgumentParser', + 'ArgumentError', + 'ArgumentTypeError', + 'FileType', + 'HelpFormatter', + 'ArgumentDefaultsHelpFormatter', + 'RawDescriptionHelpFormatter', + 'RawTextHelpFormatter', + 'Namespace', + 'Action', + 'ONE_OR_MORE', + 'OPTIONAL', + 'PARSER', + 'REMAINDER', + 'SUPPRESS', + 'ZERO_OR_MORE', +] + + +import copy as _copy +import os as _os +import re as _re +import sys as _sys +import textwrap as _textwrap + +from gettext import gettext as _ + +try: + set +except NameError: + # for python < 2.4 compatibility (sets module is there since 2.3): + from sets import Set as set + +try: + basestring +except NameError: + basestring = str + +try: + sorted +except NameError: + # for python < 2.4 compatibility: + def sorted(iterable, reverse=False): + result = list(iterable) + result.sort() + if reverse: + result.reverse() + return result + + +def _callable(obj): + return hasattr(obj, '__call__') or hasattr(obj, '__bases__') + + +SUPPRESS = '==SUPPRESS==' + +OPTIONAL = '?' +ZERO_OR_MORE = '*' +ONE_OR_MORE = '+' +PARSER = 'A...' +REMAINDER = '...' +_UNRECOGNIZED_ARGS_ATTR = '_unrecognized_args' + +# ============================= +# Utility functions and classes +# ============================= + +class _AttributeHolder(object): + """Abstract base class that provides __repr__. + + The __repr__ method returns a string in the format:: + ClassName(attr=name, attr=name, ...) + The attributes are determined either by a class-level attribute, + '_kwarg_names', or by inspecting the instance __dict__. + """ + + def __repr__(self): + type_name = type(self).__name__ + arg_strings = [] + for arg in self._get_args(): + arg_strings.append(repr(arg)) + for name, value in self._get_kwargs(): + arg_strings.append('%s=%r' % (name, value)) + return '%s(%s)' % (type_name, ', '.join(arg_strings)) + + def _get_kwargs(self): + return sorted(self.__dict__.items()) + + def _get_args(self): + return [] + + +def _ensure_value(namespace, name, value): + if getattr(namespace, name, None) is None: + setattr(namespace, name, value) + return getattr(namespace, name) + + +# =============== +# Formatting Help +# =============== + +class HelpFormatter(object): + """Formatter for generating usage messages and argument help strings. + + Only the name of this class is considered a public API. All the methods + provided by the class are considered an implementation detail. 
+ """ + + def __init__(self, + prog, + indent_increment=2, + max_help_position=24, + width=None): + + # default setting for width + if width is None: + try: + width = int(_os.environ['COLUMNS']) + except (KeyError, ValueError): + width = 80 + width -= 2 + + self._prog = prog + self._indent_increment = indent_increment + self._max_help_position = max_help_position + self._width = width + + self._current_indent = 0 + self._level = 0 + self._action_max_length = 0 + + self._root_section = self._Section(self, None) + self._current_section = self._root_section + + self._whitespace_matcher = _re.compile(r'\s+') + self._long_break_matcher = _re.compile(r'\n\n\n+') + + # =============================== + # Section and indentation methods + # =============================== + def _indent(self): + self._current_indent += self._indent_increment + self._level += 1 + + def _dedent(self): + self._current_indent -= self._indent_increment + assert self._current_indent >= 0, 'Indent decreased below 0.' + self._level -= 1 + + class _Section(object): + + def __init__(self, formatter, parent, heading=None): + self.formatter = formatter + self.parent = parent + self.heading = heading + self.items = [] + + def format_help(self): + # format the indented section + if self.parent is not None: + self.formatter._indent() + join = self.formatter._join_parts + for func, args in self.items: + func(*args) + item_help = join([func(*args) for func, args in self.items]) + if self.parent is not None: + self.formatter._dedent() + + # return nothing if the section was empty + if not item_help: + return '' + + # add the heading if the section was non-empty + if self.heading is not SUPPRESS and self.heading is not None: + current_indent = self.formatter._current_indent + heading = '%*s%s:\n' % (current_indent, '', self.heading) + else: + heading = '' + + # join the section-initial newline, the heading and the help + return join(['\n', heading, item_help, '\n']) + + def _add_item(self, func, args): + self._current_section.items.append((func, args)) + + # ======================== + # Message building methods + # ======================== + def start_section(self, heading): + self._indent() + section = self._Section(self, self._current_section, heading) + self._add_item(section.format_help, []) + self._current_section = section + + def end_section(self): + self._current_section = self._current_section.parent + self._dedent() + + def add_text(self, text): + if text is not SUPPRESS and text is not None: + self._add_item(self._format_text, [text]) + + def add_usage(self, usage, actions, groups, prefix=None): + if usage is not SUPPRESS: + args = usage, actions, groups, prefix + self._add_item(self._format_usage, args) + + def add_argument(self, action): + if action.help is not SUPPRESS: + + # find all invocations + get_invocation = self._format_action_invocation + invocations = [get_invocation(action)] + for subaction in self._iter_indented_subactions(action): + invocations.append(get_invocation(subaction)) + + # update the maximum item length + invocation_length = max([len(s) for s in invocations]) + action_length = invocation_length + self._current_indent + self._action_max_length = max(self._action_max_length, + action_length) + + # add the item to the list + self._add_item(self._format_action, [action]) + + def add_arguments(self, actions): + for action in actions: + self.add_argument(action) + + # ======================= + # Help-formatting methods + # ======================= + def format_help(self): + help = 
self._root_section.format_help() + if help: + help = self._long_break_matcher.sub('\n\n', help) + help = help.strip('\n') + '\n' + return help + + def _join_parts(self, part_strings): + return ''.join([part + for part in part_strings + if part and part is not SUPPRESS]) + + def _format_usage(self, usage, actions, groups, prefix): + if prefix is None: + prefix = _('usage: ') + + # if usage is specified, use that + if usage is not None: + usage = usage % dict(prog=self._prog) + + # if no optionals or positionals are available, usage is just prog + elif usage is None and not actions: + usage = '%(prog)s' % dict(prog=self._prog) + + # if optionals and positionals are available, calculate usage + elif usage is None: + prog = '%(prog)s' % dict(prog=self._prog) + + # split optionals from positionals + optionals = [] + positionals = [] + for action in actions: + if action.option_strings: + optionals.append(action) + else: + positionals.append(action) + + # build full usage string + format = self._format_actions_usage + action_usage = format(optionals + positionals, groups) + usage = ' '.join([s for s in [prog, action_usage] if s]) + + # wrap the usage parts if it's too long + text_width = self._width - self._current_indent + if len(prefix) + len(usage) > text_width: + + # break usage into wrappable parts + part_regexp = r'\(.*?\)+|\[.*?\]+|\S+' + opt_usage = format(optionals, groups) + pos_usage = format(positionals, groups) + opt_parts = _re.findall(part_regexp, opt_usage) + pos_parts = _re.findall(part_regexp, pos_usage) + assert ' '.join(opt_parts) == opt_usage + assert ' '.join(pos_parts) == pos_usage + + # helper for wrapping lines + def get_lines(parts, indent, prefix=None): + lines = [] + line = [] + if prefix is not None: + line_len = len(prefix) - 1 + else: + line_len = len(indent) - 1 + for part in parts: + if line_len + 1 + len(part) > text_width: + lines.append(indent + ' '.join(line)) + line = [] + line_len = len(indent) - 1 + line.append(part) + line_len += len(part) + 1 + if line: + lines.append(indent + ' '.join(line)) + if prefix is not None: + lines[0] = lines[0][len(indent):] + return lines + + # if prog is short, follow it with optionals or positionals + if len(prefix) + len(prog) <= 0.75 * text_width: + indent = ' ' * (len(prefix) + len(prog) + 1) + if opt_parts: + lines = get_lines([prog] + opt_parts, indent, prefix) + lines.extend(get_lines(pos_parts, indent)) + elif pos_parts: + lines = get_lines([prog] + pos_parts, indent, prefix) + else: + lines = [prog] + + # if prog is long, put it on its own line + else: + indent = ' ' * len(prefix) + parts = opt_parts + pos_parts + lines = get_lines(parts, indent) + if len(lines) > 1: + lines = [] + lines.extend(get_lines(opt_parts, indent)) + lines.extend(get_lines(pos_parts, indent)) + lines = [prog] + lines + + # join lines into usage + usage = '\n'.join(lines) + + # prefix with 'usage:' + return '%s%s\n\n' % (prefix, usage) + + def _format_actions_usage(self, actions, groups): + # find group indices and identify actions in groups + group_actions = set() + inserts = {} + for group in groups: + try: + start = actions.index(group._group_actions[0]) + except ValueError: + continue + else: + end = start + len(group._group_actions) + if actions[start:end] == group._group_actions: + for action in group._group_actions: + group_actions.add(action) + if not group.required: + if start in inserts: + inserts[start] += ' [' + else: + inserts[start] = '[' + inserts[end] = ']' + else: + if start in inserts: + inserts[start] += ' (' + else: + 
inserts[start] = '(' + inserts[end] = ')' + for i in range(start + 1, end): + inserts[i] = '|' + + # collect all actions format strings + parts = [] + for i, action in enumerate(actions): + + # suppressed arguments are marked with None + # remove | separators for suppressed arguments + if action.help is SUPPRESS: + parts.append(None) + if inserts.get(i) == '|': + inserts.pop(i) + elif inserts.get(i + 1) == '|': + inserts.pop(i + 1) + + # produce all arg strings + elif not action.option_strings: + part = self._format_args(action, action.dest) + + # if it's in a group, strip the outer [] + if action in group_actions: + if part[0] == '[' and part[-1] == ']': + part = part[1:-1] + + # add the action string to the list + parts.append(part) + + # produce the first way to invoke the option in brackets + else: + option_string = action.option_strings[0] + + # if the Optional doesn't take a value, format is: + # -s or --long + if action.nargs == 0: + part = '%s' % option_string + + # if the Optional takes a value, format is: + # -s ARGS or --long ARGS + else: + default = action.dest.upper() + args_string = self._format_args(action, default) + part = '%s %s' % (option_string, args_string) + + # make it look optional if it's not required or in a group + if not action.required and action not in group_actions: + part = '[%s]' % part + + # add the action string to the list + parts.append(part) + + # insert things at the necessary indices + for i in sorted(inserts, reverse=True): + parts[i:i] = [inserts[i]] + + # join all the action items with spaces + text = ' '.join([item for item in parts if item is not None]) + + # clean up separators for mutually exclusive groups + open = r'[\[(]' + close = r'[\])]' + text = _re.sub(r'(%s) ' % open, r'\1', text) + text = _re.sub(r' (%s)' % close, r'\1', text) + text = _re.sub(r'%s *%s' % (open, close), r'', text) + text = _re.sub(r'\(([^|]*)\)', r'\1', text) + text = text.strip() + + # return the text + return text + + def _format_text(self, text): + if '%(prog)' in text: + text = text % dict(prog=self._prog) + text_width = self._width - self._current_indent + indent = ' ' * self._current_indent + return self._fill_text(text, text_width, indent) + '\n\n' + + def _format_action(self, action): + # determine the required width and the entry label + help_position = min(self._action_max_length + 2, + self._max_help_position) + help_width = self._width - help_position + action_width = help_position - self._current_indent - 2 + action_header = self._format_action_invocation(action) + + # no help; start on same line and add a final newline + if not action.help: + tup = self._current_indent, '', action_header + action_header = '%*s%s\n' % tup + + # short action name; start on the same line and pad two spaces + elif len(action_header) <= action_width: + tup = self._current_indent, '', action_width, action_header + action_header = '%*s%-*s ' % tup + indent_first = 0 + + # long action name; start on the next line + else: + tup = self._current_indent, '', action_header + action_header = '%*s%s\n' % tup + indent_first = help_position + + # collect the pieces of the action help + parts = [action_header] + + # if there was help for the action, add lines of help text + if action.help: + help_text = self._expand_help(action) + help_lines = self._split_lines(help_text, help_width) + parts.append('%*s%s\n' % (indent_first, '', help_lines[0])) + for line in help_lines[1:]: + parts.append('%*s%s\n' % (help_position, '', line)) + + # or add a newline if the description doesn't end with one 
+ elif not action_header.endswith('\n'): + parts.append('\n') + + # if there are any sub-actions, add their help as well + for subaction in self._iter_indented_subactions(action): + parts.append(self._format_action(subaction)) + + # return a single string + return self._join_parts(parts) + + def _format_action_invocation(self, action): + if not action.option_strings: + metavar, = self._metavar_formatter(action, action.dest)(1) + return metavar + + else: + parts = [] + + # if the Optional doesn't take a value, format is: + # -s, --long + if action.nargs == 0: + parts.extend(action.option_strings) + + # if the Optional takes a value, format is: + # -s ARGS, --long ARGS + else: + default = action.dest.upper() + args_string = self._format_args(action, default) + for option_string in action.option_strings: + parts.append('%s %s' % (option_string, args_string)) + + return ', '.join(parts) + + def _metavar_formatter(self, action, default_metavar): + if action.metavar is not None: + result = action.metavar + elif action.choices is not None: + choice_strs = [str(choice) for choice in action.choices] + result = '{%s}' % ','.join(choice_strs) + else: + result = default_metavar + + def format(tuple_size): + if isinstance(result, tuple): + return result + else: + return (result, ) * tuple_size + return format + + def _format_args(self, action, default_metavar): + get_metavar = self._metavar_formatter(action, default_metavar) + if action.nargs is None: + result = '%s' % get_metavar(1) + elif action.nargs == OPTIONAL: + result = '[%s]' % get_metavar(1) + elif action.nargs == ZERO_OR_MORE: + result = '[%s [%s ...]]' % get_metavar(2) + elif action.nargs == ONE_OR_MORE: + result = '%s [%s ...]' % get_metavar(2) + elif action.nargs == REMAINDER: + result = '...' + elif action.nargs == PARSER: + result = '%s ...' % get_metavar(1) + else: + formats = ['%s' for _ in range(action.nargs)] + result = ' '.join(formats) % get_metavar(action.nargs) + return result + + def _expand_help(self, action): + params = dict(vars(action), prog=self._prog) + for name in list(params): + if params[name] is SUPPRESS: + del params[name] + for name in list(params): + if hasattr(params[name], '__name__'): + params[name] = params[name].__name__ + if params.get('choices') is not None: + choices_str = ', '.join([str(c) for c in params['choices']]) + params['choices'] = choices_str + return self._get_help_string(action) % params + + def _iter_indented_subactions(self, action): + try: + get_subactions = action._get_subactions + except AttributeError: + pass + else: + self._indent() + for subaction in get_subactions(): + yield subaction + self._dedent() + + def _split_lines(self, text, width): + text = self._whitespace_matcher.sub(' ', text).strip() + return _textwrap.wrap(text, width) + + def _fill_text(self, text, width, indent): + text = self._whitespace_matcher.sub(' ', text).strip() + return _textwrap.fill(text, width, initial_indent=indent, + subsequent_indent=indent) + + def _get_help_string(self, action): + return action.help + + +class RawDescriptionHelpFormatter(HelpFormatter): + """Help message formatter which retains any formatting in descriptions. + + Only the name of this class is considered a public API. All the methods + provided by the class are considered an implementation detail. 
+ """ + + def _fill_text(self, text, width, indent): + return ''.join([indent + line for line in text.splitlines(True)]) + + +class RawTextHelpFormatter(RawDescriptionHelpFormatter): + """Help message formatter which retains formatting of all help text. + + Only the name of this class is considered a public API. All the methods + provided by the class are considered an implementation detail. + """ + + def _split_lines(self, text, width): + return text.splitlines() + + +class ArgumentDefaultsHelpFormatter(HelpFormatter): + """Help message formatter which adds default values to argument help. + + Only the name of this class is considered a public API. All the methods + provided by the class are considered an implementation detail. + """ + + def _get_help_string(self, action): + help = action.help + if '%(default)' not in action.help: + if action.default is not SUPPRESS: + defaulting_nargs = [OPTIONAL, ZERO_OR_MORE] + if action.option_strings or action.nargs in defaulting_nargs: + help += ' (default: %(default)s)' + return help + + +# ===================== +# Options and Arguments +# ===================== + +def _get_action_name(argument): + if argument is None: + return None + elif argument.option_strings: + return '/'.join(argument.option_strings) + elif argument.metavar not in (None, SUPPRESS): + return argument.metavar + elif argument.dest not in (None, SUPPRESS): + return argument.dest + else: + return None + + +class ArgumentError(Exception): + """An error from creating or using an argument (optional or positional). + + The string value of this exception is the message, augmented with + information about the argument that caused it. + """ + + def __init__(self, argument, message): + self.argument_name = _get_action_name(argument) + self.message = message + + def __str__(self): + if self.argument_name is None: + format = '%(message)s' + else: + format = 'argument %(argument_name)s: %(message)s' + return format % dict(message=self.message, + argument_name=self.argument_name) + + +class ArgumentTypeError(Exception): + """An error from trying to convert a command line string to a type.""" + pass + + +# ============== +# Action classes +# ============== + +class Action(_AttributeHolder): + """Information about how to convert command line strings to Python objects. + + Action objects are used by an ArgumentParser to represent the information + needed to parse a single argument from one or more strings from the + command line. The keyword arguments to the Action constructor are also + all attributes of Action instances. + + Keyword Arguments: + + - option_strings -- A list of command-line option strings which + should be associated with this action. + + - dest -- The name of the attribute to hold the created object(s) + + - nargs -- The number of command-line arguments that should be + consumed. By default, one argument will be consumed and a single + value will be produced. Other values include: + - N (an integer) consumes N arguments (and produces a list) + - '?' consumes zero or one arguments + - '*' consumes zero or more arguments (and produces a list) + - '+' consumes one or more arguments (and produces a list) + Note that the difference between the default and nargs=1 is that + with the default, a single value will be produced, while with + nargs=1, a list containing a single value will be produced. + + - const -- The value to be produced if the option is specified and the + option uses an action that takes no values. 
+ + - default -- The value to be produced if the option is not specified. + + - type -- The type which the command-line arguments should be converted + to, should be one of 'string', 'int', 'float', 'complex' or a + callable object that accepts a single string argument. If None, + 'string' is assumed. + + - choices -- A container of values that should be allowed. If not None, + after a command-line argument has been converted to the appropriate + type, an exception will be raised if it is not a member of this + collection. + + - required -- True if the action must always be specified at the + command line. This is only meaningful for optional command-line + arguments. + + - help -- The help string describing the argument. + + - metavar -- The name to be used for the option's argument with the + help string. If None, the 'dest' value will be used as the name. + """ + + def __init__(self, + option_strings, + dest, + nargs=None, + const=None, + default=None, + type=None, + choices=None, + required=False, + help=None, + metavar=None): + self.option_strings = option_strings + self.dest = dest + self.nargs = nargs + self.const = const + self.default = default + self.type = type + self.choices = choices + self.required = required + self.help = help + self.metavar = metavar + + def _get_kwargs(self): + names = [ + 'option_strings', + 'dest', + 'nargs', + 'const', + 'default', + 'type', + 'choices', + 'help', + 'metavar', + ] + return [(name, getattr(self, name)) for name in names] + + def __call__(self, parser, namespace, values, option_string=None): + raise NotImplementedError(_('.__call__() not defined')) + + +class _StoreAction(Action): + + def __init__(self, + option_strings, + dest, + nargs=None, + const=None, + default=None, + type=None, + choices=None, + required=False, + help=None, + metavar=None): + if nargs == 0: + raise ValueError('nargs for store actions must be > 0; if you ' + 'have nothing to store, actions such as store ' + 'true or store const may be more appropriate') + if const is not None and nargs != OPTIONAL: + raise ValueError('nargs must be %r to supply const' % OPTIONAL) + super(_StoreAction, self).__init__( + option_strings=option_strings, + dest=dest, + nargs=nargs, + const=const, + default=default, + type=type, + choices=choices, + required=required, + help=help, + metavar=metavar) + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, values) + + +class _StoreConstAction(Action): + + def __init__(self, + option_strings, + dest, + const, + default=None, + required=False, + help=None, + metavar=None): + super(_StoreConstAction, self).__init__( + option_strings=option_strings, + dest=dest, + nargs=0, + const=const, + default=default, + required=required, + help=help) + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, self.const) + + +class _StoreTrueAction(_StoreConstAction): + + def __init__(self, + option_strings, + dest, + default=False, + required=False, + help=None): + super(_StoreTrueAction, self).__init__( + option_strings=option_strings, + dest=dest, + const=True, + default=default, + required=required, + help=help) + + +class _StoreFalseAction(_StoreConstAction): + + def __init__(self, + option_strings, + dest, + default=True, + required=False, + help=None): + super(_StoreFalseAction, self).__init__( + option_strings=option_strings, + dest=dest, + const=False, + default=default, + required=required, + help=help) + + +class _AppendAction(Action): + + def 
__init__(self, + option_strings, + dest, + nargs=None, + const=None, + default=None, + type=None, + choices=None, + required=False, + help=None, + metavar=None): + if nargs == 0: + raise ValueError('nargs for append actions must be > 0; if arg ' + 'strings are not supplying the value to append, ' + 'the append const action may be more appropriate') + if const is not None and nargs != OPTIONAL: + raise ValueError('nargs must be %r to supply const' % OPTIONAL) + super(_AppendAction, self).__init__( + option_strings=option_strings, + dest=dest, + nargs=nargs, + const=const, + default=default, + type=type, + choices=choices, + required=required, + help=help, + metavar=metavar) + + def __call__(self, parser, namespace, values, option_string=None): + items = _copy.copy(_ensure_value(namespace, self.dest, [])) + items.append(values) + setattr(namespace, self.dest, items) + + +class _AppendConstAction(Action): + + def __init__(self, + option_strings, + dest, + const, + default=None, + required=False, + help=None, + metavar=None): + super(_AppendConstAction, self).__init__( + option_strings=option_strings, + dest=dest, + nargs=0, + const=const, + default=default, + required=required, + help=help, + metavar=metavar) + + def __call__(self, parser, namespace, values, option_string=None): + items = _copy.copy(_ensure_value(namespace, self.dest, [])) + items.append(self.const) + setattr(namespace, self.dest, items) + + +class _CountAction(Action): + + def __init__(self, + option_strings, + dest, + default=None, + required=False, + help=None): + super(_CountAction, self).__init__( + option_strings=option_strings, + dest=dest, + nargs=0, + default=default, + required=required, + help=help) + + def __call__(self, parser, namespace, values, option_string=None): + new_count = _ensure_value(namespace, self.dest, 0) + 1 + setattr(namespace, self.dest, new_count) + + +class _HelpAction(Action): + + def __init__(self, + option_strings, + dest=SUPPRESS, + default=SUPPRESS, + help=None): + super(_HelpAction, self).__init__( + option_strings=option_strings, + dest=dest, + default=default, + nargs=0, + help=help) + + def __call__(self, parser, namespace, values, option_string=None): + parser.print_help() + parser.exit() + + +class _VersionAction(Action): + + def __init__(self, + option_strings, + version=None, + dest=SUPPRESS, + default=SUPPRESS, + help="show program's version number and exit"): + super(_VersionAction, self).__init__( + option_strings=option_strings, + dest=dest, + default=default, + nargs=0, + help=help) + self.version = version + + def __call__(self, parser, namespace, values, option_string=None): + version = self.version + if version is None: + version = parser.version + formatter = parser._get_formatter() + formatter.add_text(version) + parser.exit(message=formatter.format_help()) + + +class _SubParsersAction(Action): + + class _ChoicesPseudoAction(Action): + + def __init__(self, name, aliases, help): + metavar = dest = name + if aliases: + metavar += ' (%s)' % ', '.join(aliases) + sup = super(_SubParsersAction._ChoicesPseudoAction, self) + sup.__init__(option_strings=[], dest=dest, help=help, + metavar=metavar) + + def __init__(self, + option_strings, + prog, + parser_class, + dest=SUPPRESS, + help=None, + metavar=None): + + self._prog_prefix = prog + self._parser_class = parser_class + self._name_parser_map = {} + self._choices_actions = [] + + super(_SubParsersAction, self).__init__( + option_strings=option_strings, + dest=dest, + nargs=PARSER, + choices=self._name_parser_map, + help=help, + 
metavar=metavar) + + def add_parser(self, name, **kwargs): + # set prog from the existing prefix + if kwargs.get('prog') is None: + kwargs['prog'] = '%s %s' % (self._prog_prefix, name) + + aliases = kwargs.pop('aliases', ()) + + # create a pseudo-action to hold the choice help + if 'help' in kwargs: + help = kwargs.pop('help') + choice_action = self._ChoicesPseudoAction(name, aliases, help) + self._choices_actions.append(choice_action) + + # create the parser and add it to the map + parser = self._parser_class(**kwargs) + self._name_parser_map[name] = parser + + # make parser available under aliases also + for alias in aliases: + self._name_parser_map[alias] = parser + + return parser + + def _get_subactions(self): + return self._choices_actions + + def __call__(self, parser, namespace, values, option_string=None): + parser_name = values[0] + arg_strings = values[1:] + + # set the parser name if requested + if self.dest is not SUPPRESS: + setattr(namespace, self.dest, parser_name) + + # select the parser + try: + parser = self._name_parser_map[parser_name] + except KeyError: + tup = parser_name, ', '.join(self._name_parser_map) + msg = _('unknown parser %r (choices: %s)' % tup) + raise ArgumentError(self, msg) + + # parse all the remaining options into the namespace + # store any unrecognized options on the object, so that the top + # level parser can decide what to do with them + namespace, arg_strings = parser.parse_known_args(arg_strings, namespace) + if arg_strings: + vars(namespace).setdefault(_UNRECOGNIZED_ARGS_ATTR, []) + getattr(namespace, _UNRECOGNIZED_ARGS_ATTR).extend(arg_strings) + + +# ============== +# Type classes +# ============== + +class FileType(object): + """Factory for creating file object types + + Instances of FileType are typically passed as type= arguments to the + ArgumentParser add_argument() method. + + Keyword Arguments: + - mode -- A string indicating how the file is to be opened. Accepts the + same values as the builtin open() function. + - bufsize -- The file's desired buffer size. Accepts the same values as + the builtin open() function. + """ + + def __init__(self, mode='r', bufsize=None): + self._mode = mode + self._bufsize = bufsize + + def __call__(self, string): + # the special argument "-" means sys.std{in,out} + if string == '-': + if 'r' in self._mode: + return _sys.stdin + elif 'w' in self._mode: + return _sys.stdout + else: + msg = _('argument "-" with mode %r' % self._mode) + raise ValueError(msg) + + # all other arguments are used as file names + if self._bufsize: + return open(string, self._mode, self._bufsize) + else: + return open(string, self._mode) + + def __repr__(self): + args = [self._mode, self._bufsize] + args_str = ', '.join([repr(arg) for arg in args if arg is not None]) + return '%s(%s)' % (type(self).__name__, args_str) + +# =========================== +# Optional and Positional Parsing +# =========================== + +class Namespace(_AttributeHolder): + """Simple object for storing attributes. + + Implements equality by attribute names and values, and provides a simple + string representation. 
+ """ + + def __init__(self, **kwargs): + for name in kwargs: + setattr(self, name, kwargs[name]) + + __hash__ = None + + def __eq__(self, other): + return vars(self) == vars(other) + + def __ne__(self, other): + return not (self == other) + + def __contains__(self, key): + return key in self.__dict__ + + +class _ActionsContainer(object): + + def __init__(self, + description, + prefix_chars, + argument_default, + conflict_handler): + super(_ActionsContainer, self).__init__() + + self.description = description + self.argument_default = argument_default + self.prefix_chars = prefix_chars + self.conflict_handler = conflict_handler + + # set up registries + self._registries = {} + + # register actions + self.register('action', None, _StoreAction) + self.register('action', 'store', _StoreAction) + self.register('action', 'store_const', _StoreConstAction) + self.register('action', 'store_true', _StoreTrueAction) + self.register('action', 'store_false', _StoreFalseAction) + self.register('action', 'append', _AppendAction) + self.register('action', 'append_const', _AppendConstAction) + self.register('action', 'count', _CountAction) + self.register('action', 'help', _HelpAction) + self.register('action', 'version', _VersionAction) + self.register('action', 'parsers', _SubParsersAction) + + # raise an exception if the conflict handler is invalid + self._get_handler() + + # action storage + self._actions = [] + self._option_string_actions = {} + + # groups + self._action_groups = [] + self._mutually_exclusive_groups = [] + + # defaults storage + self._defaults = {} + + # determines whether an "option" looks like a negative number + self._negative_number_matcher = _re.compile(r'^-\d+$|^-\d*\.\d+$') + + # whether or not there are any optionals that look like negative + # numbers -- uses a list so it can be shared and edited + self._has_negative_number_optionals = [] + + # ==================== + # Registration methods + # ==================== + def register(self, registry_name, value, object): + registry = self._registries.setdefault(registry_name, {}) + registry[value] = object + + def _registry_get(self, registry_name, value, default=None): + return self._registries[registry_name].get(value, default) + + # ================================== + # Namespace default accessor methods + # ================================== + def set_defaults(self, **kwargs): + self._defaults.update(kwargs) + + # if these defaults match any existing arguments, replace + # the previous default on the object with the new one + for action in self._actions: + if action.dest in kwargs: + action.default = kwargs[action.dest] + + def get_default(self, dest): + for action in self._actions: + if action.dest == dest and action.default is not None: + return action.default + return self._defaults.get(dest, None) + + + # ======================= + # Adding argument actions + # ======================= + def add_argument(self, *args, **kwargs): + """ + add_argument(dest, ..., name=value, ...) + add_argument(option_string, option_string, ..., name=value, ...) 
+ """ + + # if no positional args are supplied or only one is supplied and + # it doesn't look like an option string, parse a positional + # argument + chars = self.prefix_chars + if not args or len(args) == 1 and args[0][0] not in chars: + if args and 'dest' in kwargs: + raise ValueError('dest supplied twice for positional argument') + kwargs = self._get_positional_kwargs(*args, **kwargs) + + # otherwise, we're adding an optional argument + else: + kwargs = self._get_optional_kwargs(*args, **kwargs) + + # if no default was supplied, use the parser-level default + if 'default' not in kwargs: + dest = kwargs['dest'] + if dest in self._defaults: + kwargs['default'] = self._defaults[dest] + elif self.argument_default is not None: + kwargs['default'] = self.argument_default + + # create the action object, and add it to the parser + action_class = self._pop_action_class(kwargs) + if not _callable(action_class): + raise ValueError('unknown action "%s"' % action_class) + action = action_class(**kwargs) + + # raise an error if the action type is not callable + type_func = self._registry_get('type', action.type, action.type) + if not _callable(type_func): + raise ValueError('%r is not callable' % type_func) + + return self._add_action(action) + + def add_argument_group(self, *args, **kwargs): + group = _ArgumentGroup(self, *args, **kwargs) + self._action_groups.append(group) + return group + + def add_mutually_exclusive_group(self, **kwargs): + group = _MutuallyExclusiveGroup(self, **kwargs) + self._mutually_exclusive_groups.append(group) + return group + + def _add_action(self, action): + # resolve any conflicts + self._check_conflict(action) + + # add to actions list + self._actions.append(action) + action.container = self + + # index the action by any option strings it has + for option_string in action.option_strings: + self._option_string_actions[option_string] = action + + # set the flag if any option strings look like negative numbers + for option_string in action.option_strings: + if self._negative_number_matcher.match(option_string): + if not self._has_negative_number_optionals: + self._has_negative_number_optionals.append(True) + + # return the created action + return action + + def _remove_action(self, action): + self._actions.remove(action) + + def _add_container_actions(self, container): + # collect groups by titles + title_group_map = {} + for group in self._action_groups: + if group.title in title_group_map: + msg = _('cannot merge actions - two groups are named %r') + raise ValueError(msg % (group.title)) + title_group_map[group.title] = group + + # map each action to its group + group_map = {} + for group in container._action_groups: + + # if a group with the title exists, use that, otherwise + # create a new group matching the container's group + if group.title not in title_group_map: + title_group_map[group.title] = self.add_argument_group( + title=group.title, + description=group.description, + conflict_handler=group.conflict_handler) + + # map the actions to their new group + for action in group._group_actions: + group_map[action] = title_group_map[group.title] + + # add container's mutually exclusive groups + # NOTE: if add_mutually_exclusive_group ever gains title= and + # description= then this code will need to be expanded as above + for group in container._mutually_exclusive_groups: + mutex_group = self.add_mutually_exclusive_group( + required=group.required) + + # map the actions to their new mutex group + for action in group._group_actions: + group_map[action] = 
mutex_group + + # add all actions to this container or their group + for action in container._actions: + group_map.get(action, self)._add_action(action) + + def _get_positional_kwargs(self, dest, **kwargs): + # make sure required is not specified + if 'required' in kwargs: + msg = _("'required' is an invalid argument for positionals") + raise TypeError(msg) + + # mark positional arguments as required if at least one is + # always required + if kwargs.get('nargs') not in [OPTIONAL, ZERO_OR_MORE]: + kwargs['required'] = True + if kwargs.get('nargs') == ZERO_OR_MORE and 'default' not in kwargs: + kwargs['required'] = True + + # return the keyword arguments with no option strings + return dict(kwargs, dest=dest, option_strings=[]) + + def _get_optional_kwargs(self, *args, **kwargs): + # determine short and long option strings + option_strings = [] + long_option_strings = [] + for option_string in args: + # error on strings that don't start with an appropriate prefix + if not option_string[0] in self.prefix_chars: + msg = _('invalid option string %r: ' + 'must start with a character %r') + tup = option_string, self.prefix_chars + raise ValueError(msg % tup) + + # strings starting with two prefix characters are long options + option_strings.append(option_string) + if option_string[0] in self.prefix_chars: + if len(option_string) > 1: + if option_string[1] in self.prefix_chars: + long_option_strings.append(option_string) + + # infer destination, '--foo-bar' -> 'foo_bar' and '-x' -> 'x' + dest = kwargs.pop('dest', None) + if dest is None: + if long_option_strings: + dest_option_string = long_option_strings[0] + else: + dest_option_string = option_strings[0] + dest = dest_option_string.lstrip(self.prefix_chars) + if not dest: + msg = _('dest= is required for options like %r') + raise ValueError(msg % option_string) + dest = dest.replace('-', '_') + + # return the updated keyword arguments + return dict(kwargs, dest=dest, option_strings=option_strings) + + def _pop_action_class(self, kwargs, default=None): + action = kwargs.pop('action', default) + return self._registry_get('action', action, action) + + def _get_handler(self): + # determine function from conflict handler string + handler_func_name = '_handle_conflict_%s' % self.conflict_handler + try: + return getattr(self, handler_func_name) + except AttributeError: + msg = _('invalid conflict_resolution value: %r') + raise ValueError(msg % self.conflict_handler) + + def _check_conflict(self, action): + + # find all options that conflict with this option + confl_optionals = [] + for option_string in action.option_strings: + if option_string in self._option_string_actions: + confl_optional = self._option_string_actions[option_string] + confl_optionals.append((option_string, confl_optional)) + + # resolve any conflicts + if confl_optionals: + conflict_handler = self._get_handler() + conflict_handler(action, confl_optionals) + + def _handle_conflict_error(self, action, conflicting_actions): + message = _('conflicting option string(s): %s') + conflict_string = ', '.join([option_string + for option_string, action + in conflicting_actions]) + raise ArgumentError(action, message % conflict_string) + + def _handle_conflict_resolve(self, action, conflicting_actions): + + # remove all conflicting options + for option_string, action in conflicting_actions: + + # remove the conflicting option + action.option_strings.remove(option_string) + self._option_string_actions.pop(option_string, None) + + # if the option now has no option string, remove it from the + # 
container holding it + if not action.option_strings: + action.container._remove_action(action) + + +class _ArgumentGroup(_ActionsContainer): + + def __init__(self, container, title=None, description=None, **kwargs): + # add any missing keyword arguments by checking the container + update = kwargs.setdefault + update('conflict_handler', container.conflict_handler) + update('prefix_chars', container.prefix_chars) + update('argument_default', container.argument_default) + super_init = super(_ArgumentGroup, self).__init__ + super_init(description=description, **kwargs) + + # group attributes + self.title = title + self._group_actions = [] + + # share most attributes with the container + self._registries = container._registries + self._actions = container._actions + self._option_string_actions = container._option_string_actions + self._defaults = container._defaults + self._has_negative_number_optionals = \ + container._has_negative_number_optionals + + def _add_action(self, action): + action = super(_ArgumentGroup, self)._add_action(action) + self._group_actions.append(action) + return action + + def _remove_action(self, action): + super(_ArgumentGroup, self)._remove_action(action) + self._group_actions.remove(action) + + +class _MutuallyExclusiveGroup(_ArgumentGroup): + + def __init__(self, container, required=False): + super(_MutuallyExclusiveGroup, self).__init__(container) + self.required = required + self._container = container + + def _add_action(self, action): + if action.required: + msg = _('mutually exclusive arguments must be optional') + raise ValueError(msg) + action = self._container._add_action(action) + self._group_actions.append(action) + return action + + def _remove_action(self, action): + self._container._remove_action(action) + self._group_actions.remove(action) + + +class ArgumentParser(_AttributeHolder, _ActionsContainer): + """Object for parsing command line strings into Python objects. + + Keyword Arguments: + - prog -- The name of the program (default: sys.argv[0]) + - usage -- A usage message (default: auto-generated from arguments) + - description -- A description of what the program does + - epilog -- Text following the argument descriptions + - parents -- Parsers whose arguments should be copied into this one + - formatter_class -- HelpFormatter class for printing help messages + - prefix_chars -- Characters that prefix optional arguments + - fromfile_prefix_chars -- Characters that prefix files containing + additional arguments + - argument_default -- The default value for all arguments + - conflict_handler -- String indicating how to handle conflicts + - add_help -- Add a -h/-help option + """ + + def __init__(self, + prog=None, + usage=None, + description=None, + epilog=None, + version=None, + parents=[], + formatter_class=HelpFormatter, + prefix_chars='-', + fromfile_prefix_chars=None, + argument_default=None, + conflict_handler='error', + add_help=True): + + if version is not None: + import warnings + warnings.warn( + """The "version" argument to ArgumentParser is deprecated. 
""" + """Please use """ + """"add_argument(..., action='version', version="N", ...)" """ + """instead""", DeprecationWarning) + + superinit = super(ArgumentParser, self).__init__ + superinit(description=description, + prefix_chars=prefix_chars, + argument_default=argument_default, + conflict_handler=conflict_handler) + + # default setting for prog + if prog is None: + prog = _os.path.basename(_sys.argv[0]) + + self.prog = prog + self.usage = usage + self.epilog = epilog + self.version = version + self.formatter_class = formatter_class + self.fromfile_prefix_chars = fromfile_prefix_chars + self.add_help = add_help + + add_group = self.add_argument_group + self._positionals = add_group(_('positional arguments')) + self._optionals = add_group(_('optional arguments')) + self._subparsers = None + + # register types + def identity(string): + return string + self.register('type', None, identity) + + # add help and version arguments if necessary + # (using explicit default to override global argument_default) + if '-' in prefix_chars: + default_prefix = '-' + else: + default_prefix = prefix_chars[0] + if self.add_help: + self.add_argument( + default_prefix+'h', default_prefix*2+'help', + action='help', default=SUPPRESS, + help=_('show this help message and exit')) + if self.version: + self.add_argument( + default_prefix+'v', default_prefix*2+'version', + action='version', default=SUPPRESS, + version=self.version, + help=_("show program's version number and exit")) + + # add parent arguments and defaults + for parent in parents: + self._add_container_actions(parent) + try: + defaults = parent._defaults + except AttributeError: + pass + else: + self._defaults.update(defaults) + + # ======================= + # Pretty __repr__ methods + # ======================= + def _get_kwargs(self): + names = [ + 'prog', + 'usage', + 'description', + 'version', + 'formatter_class', + 'conflict_handler', + 'add_help', + ] + return [(name, getattr(self, name)) for name in names] + + # ================================== + # Optional/Positional adding methods + # ================================== + def add_subparsers(self, **kwargs): + if self._subparsers is not None: + self.error(_('cannot have multiple subparser arguments')) + + # add the parser class to the arguments if it's not present + kwargs.setdefault('parser_class', type(self)) + + if 'title' in kwargs or 'description' in kwargs: + title = _(kwargs.pop('title', 'subcommands')) + description = _(kwargs.pop('description', None)) + self._subparsers = self.add_argument_group(title, description) + else: + self._subparsers = self._positionals + + # prog defaults to the usage message of this parser, skipping + # optional arguments and with no "usage:" prefix + if kwargs.get('prog') is None: + formatter = self._get_formatter() + positionals = self._get_positional_actions() + groups = self._mutually_exclusive_groups + formatter.add_usage(self.usage, positionals, groups, '') + kwargs['prog'] = formatter.format_help().strip() + + # create the parsers action and add it to the positionals list + parsers_class = self._pop_action_class(kwargs, 'parsers') + action = parsers_class(option_strings=[], **kwargs) + self._subparsers._add_action(action) + + # return the created parsers action + return action + + def _add_action(self, action): + if action.option_strings: + self._optionals._add_action(action) + else: + self._positionals._add_action(action) + return action + + def _get_optional_actions(self): + return [action + for action in self._actions + if action.option_strings] 
+ + def _get_positional_actions(self): + return [action + for action in self._actions + if not action.option_strings] + + # ===================================== + # Command line argument parsing methods + # ===================================== + def parse_args(self, args=None, namespace=None): + args, argv = self.parse_known_args(args, namespace) + if argv: + msg = _('unrecognized arguments: %s') + self.error(msg % ' '.join(argv)) + return args + + def parse_known_args(self, args=None, namespace=None): + # args default to the system args + if args is None: + args = _sys.argv[1:] + + # default Namespace built from parser defaults + if namespace is None: + namespace = Namespace() + + # add any action defaults that aren't present + for action in self._actions: + if action.dest is not SUPPRESS: + if not hasattr(namespace, action.dest): + if action.default is not SUPPRESS: + default = action.default + if isinstance(action.default, basestring): + default = self._get_value(action, default) + setattr(namespace, action.dest, default) + + # add any parser defaults that aren't present + for dest in self._defaults: + if not hasattr(namespace, dest): + setattr(namespace, dest, self._defaults[dest]) + + # parse the arguments and exit if there are any errors + try: + namespace, args = self._parse_known_args(args, namespace) + if hasattr(namespace, _UNRECOGNIZED_ARGS_ATTR): + args.extend(getattr(namespace, _UNRECOGNIZED_ARGS_ATTR)) + delattr(namespace, _UNRECOGNIZED_ARGS_ATTR) + return namespace, args + except ArgumentError: + err = _sys.exc_info()[1] + self.error(str(err)) + + def _parse_known_args(self, arg_strings, namespace): + # replace arg strings that are file references + if self.fromfile_prefix_chars is not None: + arg_strings = self._read_args_from_files(arg_strings) + + # map all mutually exclusive arguments to the other arguments + # they can't occur with + action_conflicts = {} + for mutex_group in self._mutually_exclusive_groups: + group_actions = mutex_group._group_actions + for i, mutex_action in enumerate(mutex_group._group_actions): + conflicts = action_conflicts.setdefault(mutex_action, []) + conflicts.extend(group_actions[:i]) + conflicts.extend(group_actions[i + 1:]) + + # find all option indices, and determine the arg_string_pattern + # which has an 'O' if there is an option at an index, + # an 'A' if there is an argument, or a '-' if there is a '--' + option_string_indices = {} + arg_string_pattern_parts = [] + arg_strings_iter = iter(arg_strings) + for i, arg_string in enumerate(arg_strings_iter): + + # all args after -- are non-options + if arg_string == '--': + arg_string_pattern_parts.append('-') + for arg_string in arg_strings_iter: + arg_string_pattern_parts.append('A') + + # otherwise, add the arg to the arg strings + # and note the index if it was an option + else: + option_tuple = self._parse_optional(arg_string) + if option_tuple is None: + pattern = 'A' + else: + option_string_indices[i] = option_tuple + pattern = 'O' + arg_string_pattern_parts.append(pattern) + + # join the pieces together to form the pattern + arg_strings_pattern = ''.join(arg_string_pattern_parts) + + # converts arg strings to the appropriate and then takes the action + seen_actions = set() + seen_non_default_actions = set() + + def take_action(action, argument_strings, option_string=None): + seen_actions.add(action) + argument_values = self._get_values(action, argument_strings) + + # error if this argument is not allowed with other previously + # seen arguments, assuming that actions that use the 
default + # value don't really count as "present" + if argument_values is not action.default: + seen_non_default_actions.add(action) + for conflict_action in action_conflicts.get(action, []): + if conflict_action in seen_non_default_actions: + msg = _('not allowed with argument %s') + action_name = _get_action_name(conflict_action) + raise ArgumentError(action, msg % action_name) + + # take the action if we didn't receive a SUPPRESS value + # (e.g. from a default) + if argument_values is not SUPPRESS: + action(self, namespace, argument_values, option_string) + + # function to convert arg_strings into an optional action + def consume_optional(start_index): + + # get the optional identified at this index + option_tuple = option_string_indices[start_index] + action, option_string, explicit_arg = option_tuple + + # identify additional optionals in the same arg string + # (e.g. -xyz is the same as -x -y -z if no args are required) + match_argument = self._match_argument + action_tuples = [] + while True: + + # if we found no optional action, skip it + if action is None: + extras.append(arg_strings[start_index]) + return start_index + 1 + + # if there is an explicit argument, try to match the + # optional's string arguments to only this + if explicit_arg is not None: + arg_count = match_argument(action, 'A') + + # if the action is a single-dash option and takes no + # arguments, try to parse more single-dash options out + # of the tail of the option string + chars = self.prefix_chars + if arg_count == 0 and option_string[1] not in chars: + action_tuples.append((action, [], option_string)) + char = option_string[0] + option_string = char + explicit_arg[0] + new_explicit_arg = explicit_arg[1:] or None + optionals_map = self._option_string_actions + if option_string in optionals_map: + action = optionals_map[option_string] + explicit_arg = new_explicit_arg + else: + msg = _('ignored explicit argument %r') + raise ArgumentError(action, msg % explicit_arg) + + # if the action expect exactly one argument, we've + # successfully matched the option; exit the loop + elif arg_count == 1: + stop = start_index + 1 + args = [explicit_arg] + action_tuples.append((action, args, option_string)) + break + + # error if a double-dash option did not use the + # explicit argument + else: + msg = _('ignored explicit argument %r') + raise ArgumentError(action, msg % explicit_arg) + + # if there is no explicit argument, try to match the + # optional's string arguments with the following strings + # if successful, exit the loop + else: + start = start_index + 1 + selected_patterns = arg_strings_pattern[start:] + arg_count = match_argument(action, selected_patterns) + stop = start + arg_count + args = arg_strings[start:stop] + action_tuples.append((action, args, option_string)) + break + + # add the Optional to the list and return the index at which + # the Optional's string args stopped + assert action_tuples + for action, args, option_string in action_tuples: + take_action(action, args, option_string) + return stop + + # the list of Positionals left to be parsed; this is modified + # by consume_positionals() + positionals = self._get_positional_actions() + + # function to convert arg_strings into positional actions + def consume_positionals(start_index): + # match as many Positionals as possible + match_partial = self._match_arguments_partial + selected_pattern = arg_strings_pattern[start_index:] + arg_counts = match_partial(positionals, selected_pattern) + + # slice off the appropriate arg strings for each Positional 
+ # and add the Positional and its args to the list + for action, arg_count in zip(positionals, arg_counts): + args = arg_strings[start_index: start_index + arg_count] + start_index += arg_count + take_action(action, args) + + # slice off the Positionals that we just parsed and return the + # index at which the Positionals' string args stopped + positionals[:] = positionals[len(arg_counts):] + return start_index + + # consume Positionals and Optionals alternately, until we have + # passed the last option string + extras = [] + start_index = 0 + if option_string_indices: + max_option_string_index = max(option_string_indices) + else: + max_option_string_index = -1 + while start_index <= max_option_string_index: + + # consume any Positionals preceding the next option + next_option_string_index = min([ + index + for index in option_string_indices + if index >= start_index]) + if start_index != next_option_string_index: + positionals_end_index = consume_positionals(start_index) + + # only try to parse the next optional if we didn't consume + # the option string during the positionals parsing + if positionals_end_index > start_index: + start_index = positionals_end_index + continue + else: + start_index = positionals_end_index + + # if we consumed all the positionals we could and we're not + # at the index of an option string, there were extra arguments + if start_index not in option_string_indices: + strings = arg_strings[start_index:next_option_string_index] + extras.extend(strings) + start_index = next_option_string_index + + # consume the next optional and any arguments for it + start_index = consume_optional(start_index) + + # consume any positionals following the last Optional + stop_index = consume_positionals(start_index) + + # if we didn't consume all the argument strings, there were extras + extras.extend(arg_strings[stop_index:]) + + # if we didn't use all the Positional objects, there were too few + # arg strings supplied. 
+ if positionals: + self.error(_('too few arguments')) + + # make sure all required actions were present + for action in self._actions: + if action.required: + if action not in seen_actions: + name = _get_action_name(action) + self.error(_('argument %s is required') % name) + + # make sure all required groups had one option present + for group in self._mutually_exclusive_groups: + if group.required: + for action in group._group_actions: + if action in seen_non_default_actions: + break + + # if no actions were used, report the error + else: + names = [_get_action_name(action) + for action in group._group_actions + if action.help is not SUPPRESS] + msg = _('one of the arguments %s is required') + self.error(msg % ' '.join(names)) + + # return the updated namespace and the extra arguments + return namespace, extras + + def _read_args_from_files(self, arg_strings): + # expand arguments referencing files + new_arg_strings = [] + for arg_string in arg_strings: + + # for regular arguments, just add them back into the list + if arg_string[0] not in self.fromfile_prefix_chars: + new_arg_strings.append(arg_string) + + # replace arguments referencing files with the file content + else: + try: + args_file = open(arg_string[1:]) + try: + arg_strings = [] + for arg_line in args_file.read().splitlines(): + for arg in self.convert_arg_line_to_args(arg_line): + arg_strings.append(arg) + arg_strings = self._read_args_from_files(arg_strings) + new_arg_strings.extend(arg_strings) + finally: + args_file.close() + except IOError: + err = _sys.exc_info()[1] + self.error(str(err)) + + # return the modified argument list + return new_arg_strings + + def convert_arg_line_to_args(self, arg_line): + return [arg_line] + + def _match_argument(self, action, arg_strings_pattern): + # match the pattern for this action to the arg strings + nargs_pattern = self._get_nargs_pattern(action) + match = _re.match(nargs_pattern, arg_strings_pattern) + + # raise an exception if we weren't able to find a match + if match is None: + nargs_errors = { + None: _('expected one argument'), + OPTIONAL: _('expected at most one argument'), + ONE_OR_MORE: _('expected at least one argument'), + } + default = _('expected %s argument(s)') % action.nargs + msg = nargs_errors.get(action.nargs, default) + raise ArgumentError(action, msg) + + # return the number of arguments matched + return len(match.group(1)) + + def _match_arguments_partial(self, actions, arg_strings_pattern): + # progressively shorten the actions list by slicing off the + # final actions until we find a match + result = [] + for i in range(len(actions), 0, -1): + actions_slice = actions[:i] + pattern = ''.join([self._get_nargs_pattern(action) + for action in actions_slice]) + match = _re.match(pattern, arg_strings_pattern) + if match is not None: + result.extend([len(string) for string in match.groups()]) + break + + # return the list of arg string counts + return result + + def _parse_optional(self, arg_string): + # if it's an empty string, it was meant to be a positional + if not arg_string: + return None + + # if it doesn't start with a prefix, it was meant to be positional + if not arg_string[0] in self.prefix_chars: + return None + + # if the option string is present in the parser, return the action + if arg_string in self._option_string_actions: + action = self._option_string_actions[arg_string] + return action, arg_string, None + + # if it's just a single character, it was meant to be positional + if len(arg_string) == 1: + return None + + # if the option string before 
the "=" is present, return the action + if '=' in arg_string: + option_string, explicit_arg = arg_string.split('=', 1) + if option_string in self._option_string_actions: + action = self._option_string_actions[option_string] + return action, option_string, explicit_arg + + # search through all possible prefixes of the option string + # and all actions in the parser for possible interpretations + option_tuples = self._get_option_tuples(arg_string) + + # if multiple actions match, the option string was ambiguous + if len(option_tuples) > 1: + options = ', '.join([option_string + for action, option_string, explicit_arg in option_tuples]) + tup = arg_string, options + self.error(_('ambiguous option: %s could match %s') % tup) + + # if exactly one action matched, this segmentation is good, + # so return the parsed action + elif len(option_tuples) == 1: + option_tuple, = option_tuples + return option_tuple + + # if it was not found as an option, but it looks like a negative + # number, it was meant to be positional + # unless there are negative-number-like options + if self._negative_number_matcher.match(arg_string): + if not self._has_negative_number_optionals: + return None + + # if it contains a space, it was meant to be a positional + if ' ' in arg_string: + return None + + # it was meant to be an optional but there is no such option + # in this parser (though it might be a valid option in a subparser) + return None, arg_string, None + + def _get_option_tuples(self, option_string): + result = [] + + # option strings starting with two prefix characters are only + # split at the '=' + chars = self.prefix_chars + if option_string[0] in chars and option_string[1] in chars: + if '=' in option_string: + option_prefix, explicit_arg = option_string.split('=', 1) + else: + option_prefix = option_string + explicit_arg = None + for option_string in self._option_string_actions: + if option_string.startswith(option_prefix): + action = self._option_string_actions[option_string] + tup = action, option_string, explicit_arg + result.append(tup) + + # single character options can be concatenated with their arguments + # but multiple character options always have to have their argument + # separate + elif option_string[0] in chars and option_string[1] not in chars: + option_prefix = option_string + explicit_arg = None + short_option_prefix = option_string[:2] + short_explicit_arg = option_string[2:] + + for option_string in self._option_string_actions: + if option_string == short_option_prefix: + action = self._option_string_actions[option_string] + tup = action, option_string, short_explicit_arg + result.append(tup) + elif option_string.startswith(option_prefix): + action = self._option_string_actions[option_string] + tup = action, option_string, explicit_arg + result.append(tup) + + # shouldn't ever get here + else: + self.error(_('unexpected option string: %s') % option_string) + + # return the collected option tuples + return result + + def _get_nargs_pattern(self, action): + # in all examples below, we have to allow for '--' args + # which are represented as '-' in the pattern + nargs = action.nargs + + # the default (None) is assumed to be a single argument + if nargs is None: + nargs_pattern = '(-*A-*)' + + # allow zero or one arguments + elif nargs == OPTIONAL: + nargs_pattern = '(-*A?-*)' + + # allow zero or more arguments + elif nargs == ZERO_OR_MORE: + nargs_pattern = '(-*[A-]*)' + + # allow one or more arguments + elif nargs == ONE_OR_MORE: + nargs_pattern = '(-*A[A-]*)' + + # allow any number of 
options or arguments + elif nargs == REMAINDER: + nargs_pattern = '([-AO]*)' + + # allow one argument followed by any number of options or arguments + elif nargs == PARSER: + nargs_pattern = '(-*A[-AO]*)' + + # all others should be integers + else: + nargs_pattern = '(-*%s-*)' % '-*'.join('A' * nargs) + + # if this is an optional action, -- is not allowed + if action.option_strings: + nargs_pattern = nargs_pattern.replace('-*', '') + nargs_pattern = nargs_pattern.replace('-', '') + + # return the pattern + return nargs_pattern + + # ======================== + # Value conversion methods + # ======================== + def _get_values(self, action, arg_strings): + # for everything but PARSER args, strip out '--' + if action.nargs not in [PARSER, REMAINDER]: + arg_strings = [s for s in arg_strings if s != '--'] + + # optional argument produces a default when not present + if not arg_strings and action.nargs == OPTIONAL: + if action.option_strings: + value = action.const + else: + value = action.default + if isinstance(value, basestring): + value = self._get_value(action, value) + self._check_value(action, value) + + # when nargs='*' on a positional, if there were no command-line + # args, use the default if it is anything other than None + elif (not arg_strings and action.nargs == ZERO_OR_MORE and + not action.option_strings): + if action.default is not None: + value = action.default + else: + value = arg_strings + self._check_value(action, value) + + # single argument or optional argument produces a single value + elif len(arg_strings) == 1 and action.nargs in [None, OPTIONAL]: + arg_string, = arg_strings + value = self._get_value(action, arg_string) + self._check_value(action, value) + + # REMAINDER arguments convert all values, checking none + elif action.nargs == REMAINDER: + value = [self._get_value(action, v) for v in arg_strings] + + # PARSER arguments convert all values, but check only the first + elif action.nargs == PARSER: + value = [self._get_value(action, v) for v in arg_strings] + self._check_value(action, value[0]) + + # all other types of nargs produce a list + else: + value = [self._get_value(action, v) for v in arg_strings] + for v in value: + self._check_value(action, v) + + # return the converted value + return value + + def _get_value(self, action, arg_string): + type_func = self._registry_get('type', action.type, action.type) + if not _callable(type_func): + msg = _('%r is not callable') + raise ArgumentError(action, msg % type_func) + + # convert the value to the appropriate type + try: + result = type_func(arg_string) + + # ArgumentTypeErrors indicate errors + except ArgumentTypeError: + name = getattr(action.type, '__name__', repr(action.type)) + msg = str(_sys.exc_info()[1]) + raise ArgumentError(action, msg) + + # TypeErrors or ValueErrors also indicate errors + except (TypeError, ValueError): + name = getattr(action.type, '__name__', repr(action.type)) + msg = _('invalid %s value: %r') + raise ArgumentError(action, msg % (name, arg_string)) + + # return the converted value + return result + + def _check_value(self, action, value): + # converted value must be one of the choices (if specified) + if action.choices is not None and value not in action.choices: + tup = value, ', '.join(map(repr, action.choices)) + msg = _('invalid choice: %r (choose from %s)') % tup + raise ArgumentError(action, msg) + + # ======================= + # Help-formatting methods + # ======================= + def format_usage(self): + formatter = self._get_formatter() + 
+        formatter.add_usage(self.usage, self._actions,
+                            self._mutually_exclusive_groups)
+        return formatter.format_help()
+
+    def format_help(self):
+        formatter = self._get_formatter()
+
+        # usage
+        formatter.add_usage(self.usage, self._actions,
+                            self._mutually_exclusive_groups)
+
+        # description
+        formatter.add_text(self.description)
+
+        # positionals, optionals and user-defined groups
+        for action_group in self._action_groups:
+            formatter.start_section(action_group.title)
+            formatter.add_text(action_group.description)
+            formatter.add_arguments(action_group._group_actions)
+            formatter.end_section()
+
+        # epilog
+        formatter.add_text(self.epilog)
+
+        # determine help from format above
+        return formatter.format_help()
+
+    def format_version(self):
+        import warnings
+        warnings.warn(
+            'The format_version method is deprecated -- the "version" '
+            'argument to ArgumentParser is no longer supported.',
+            DeprecationWarning)
+        formatter = self._get_formatter()
+        formatter.add_text(self.version)
+        return formatter.format_help()
+
+    def _get_formatter(self):
+        return self.formatter_class(prog=self.prog)
+
+    # =====================
+    # Help-printing methods
+    # =====================
+    def print_usage(self, file=None):
+        if file is None:
+            file = _sys.stdout
+        self._print_message(self.format_usage(), file)
+
+    def print_help(self, file=None):
+        if file is None:
+            file = _sys.stdout
+        self._print_message(self.format_help(), file)
+
+    def print_version(self, file=None):
+        import warnings
+        warnings.warn(
+            'The print_version method is deprecated -- the "version" '
+            'argument to ArgumentParser is no longer supported.',
+            DeprecationWarning)
+        self._print_message(self.format_version(), file)
+
+    def _print_message(self, message, file=None):
+        if message:
+            if file is None:
+                file = _sys.stderr
+            file.write(message)
+
+    # ===============
+    # Exiting methods
+    # ===============
+    def exit(self, status=0, message=None):
+        if message:
+            self._print_message(message, _sys.stderr)
+        _sys.exit(status)
+
+    def error(self, message):
+        """error(message: string)
+
+        Prints a usage message incorporating the message to stderr and
+        exits.
+
+        If you override this in a subclass, it should not return -- it
+        should either exit or raise an exception.
+        """
+        self.print_usage(_sys.stderr)
+        self.exit(2, _('%s: error: %s\n') % (self.prog, message))
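The module above is a vendored copy of the standard-library argparse, so it is driven exactly like stdlib argparse. A minimal sketch of typical use, assuming the module is importable as suricata.update.compat.argparse (the import path mirrors the ordereddict compat module below and is an assumption, not shown in this diff):

    # Sketch only; the import path is assumed from the compat/ layout.
    from suricata.update.compat import argparse

    parser = argparse.ArgumentParser(prog="suricata-update")
    parser.add_argument("-v", "--verbose", action="store_true",
                        help="be more verbose")
    subparsers = parser.add_subparsers(dest="command", metavar="<command>")
    subparsers.add_parser("list-sources", help="List available sources")

    args = parser.parse_args(["-v", "list-sources"])
    print("%s %s" % (args.verbose, args.command))  # -> True list-sources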
diff --git a/suricata/update/compat/ordereddict.py b/suricata/update/compat/ordereddict.py
new file mode 100644
index 0000000..5b0303f
--- /dev/null
+++ b/suricata/update/compat/ordereddict.py
@@ -0,0 +1,127 @@
+# Copyright (c) 2009 Raymond Hettinger
+#
+# Permission is hereby granted, free of charge, to any person
+# obtaining a copy of this software and associated documentation files
+# (the "Software"), to deal in the Software without restriction,
+# including without limitation the rights to use, copy, modify, merge,
+# publish, distribute, sublicense, and/or sell copies of the Software,
+# and to permit persons to whom the Software is furnished to do so,
+# subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be
+# included in all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
+# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+# HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+# WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+# OTHER DEALINGS IN THE SOFTWARE.
+
+from UserDict import DictMixin
+
+class OrderedDict(dict, DictMixin):
+
+    def __init__(self, *args, **kwds):
+        if len(args) > 1:
+            raise TypeError('expected at most 1 arguments, got %d' % len(args))
+        try:
+            self.__end
+        except AttributeError:
+            self.clear()
+        self.update(*args, **kwds)
+
+    def clear(self):
+        self.__end = end = []
+        end += [None, end, end]     # sentinel node for doubly linked list
+        self.__map = {}             # key --> [key, prev, next]
+        dict.clear(self)
+
+    def __setitem__(self, key, value):
+        if key not in self:
+            end = self.__end
+            curr = end[1]
+            curr[2] = end[1] = self.__map[key] = [key, curr, end]
+        dict.__setitem__(self, key, value)
+
+    def __delitem__(self, key):
+        dict.__delitem__(self, key)
+        key, prev, next = self.__map.pop(key)
+        prev[2] = next
+        next[1] = prev
+
+    def __iter__(self):
+        end = self.__end
+        curr = end[2]
+        while curr is not end:
+            yield curr[0]
+            curr = curr[2]
+
+    def __reversed__(self):
+        end = self.__end
+        curr = end[1]
+        while curr is not end:
+            yield curr[0]
+            curr = curr[1]
+
+    def popitem(self, last=True):
+        if not self:
+            raise KeyError('dictionary is empty')
+        if last:
+            key = reversed(self).next()
+        else:
+            key = iter(self).next()
+        value = self.pop(key)
+        return key, value
+
+    def __reduce__(self):
+        items = [[k, self[k]] for k in self]
+        tmp = self.__map, self.__end
+        del self.__map, self.__end
+        inst_dict = vars(self).copy()
+        self.__map, self.__end = tmp
+        if inst_dict:
+            return (self.__class__, (items,), inst_dict)
+        return self.__class__, (items,)
+
+    def keys(self):
+        return list(self)
+
+    setdefault = DictMixin.setdefault
+    update = DictMixin.update
+    pop = DictMixin.pop
+    values = DictMixin.values
+    items = DictMixin.items
+    iterkeys = DictMixin.iterkeys
+    itervalues = DictMixin.itervalues
+    iteritems = DictMixin.iteritems
+
+    def __repr__(self):
+        if not self:
+            return '%s()' % (self.__class__.__name__,)
+        return '%s(%r)' % (self.__class__.__name__, self.items())
+
+    def copy(self):
+        return self.__class__(self)
+
+    @classmethod
+    def fromkeys(cls, iterable, value=None):
+        d = cls()
+        for key in iterable:
+            d[key] = value
+        return d
+
+    def __eq__(self, other):
+        if isinstance(other, OrderedDict):
+            if len(self) != len(other):
+                return False
+            for p, q in zip(self.items(), other.items()):
+                if p != q:
+                    return False
+            return True
+        return dict.__eq__(self, other)
+
+    def __ne__(self, other):
diff --git a/suricata/update/config.py b/suricata/update/config.py new file mode 100644 index 0000000..ad95996 --- /dev/null +++ b/suricata/update/config.py @@ -0,0 +1,266 @@ +# Copyright (C) 2017 Open Information Security Foundation +# Copyright (c) 2015-2017 Jason Ish +# +# You can copy, redistribute or modify this Program under the terms of +# the GNU General Public License version 2 as published by the Free +# Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# version 2 along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. + +import os.path +import logging + +import yaml + +import suricata.update.engine +from suricata.update.exceptions import ApplicationError + +try: + from suricata.config import defaults + has_defaults = True +except: + has_defaults = False + +logger = logging.getLogger() + +DEFAULT_DATA_DIRECTORY = "/var/lib/suricata" + +# Cache directory - relative to the data directory. +CACHE_DIRECTORY = os.path.join("update", "cache") + +# Source directory - relative to the data directory. +SOURCE_DIRECTORY = os.path.join("update", "sources") + +# Configuration keys. +DATA_DIRECTORY_KEY = "data-directory" +CACHE_DIRECTORY_KEY = "cache-directory" +IGNORE_KEY = "ignore" +DISABLE_CONF_KEY = "disable-conf" +ENABLE_CONF_KEY = "enable-conf" +MODIFY_CONF_KEY = "modify-conf" +DROP_CONF_KEY = "drop-conf" +LOCAL_CONF_KEY = "local" +OUTPUT_KEY = "output" +DIST_RULE_DIRECTORY_KEY = "dist-rule-directory" + +if has_defaults: + DEFAULT_UPDATE_YAML_PATH = os.path.join(defaults.sysconfdir, "update.yaml") +else: + DEFAULT_UPDATE_YAML_PATH = "/etc/suricata/update.yaml" + +DEFAULT_SURICATA_YAML_PATH = [ + "/etc/suricata/suricata.yaml", + "/usr/local/etc/suricata/suricata.yaml", + "/etc/suricata/suricata-debian.yaml" +] + +if has_defaults: + DEFAULT_DIST_RULE_PATH = [ + defaults.datarulesdir, + "/etc/suricata/rules", + ] +else: + DEFAULT_DIST_RULE_PATH = [ + "/etc/suricata/rules", + ] + +DEFAULT_CONFIG = { + "sources": [], + LOCAL_CONF_KEY: [], + + # The default file patterns to ignore. + "ignore": [ + "*deleted.rules", + ], +} + +_args = None +_config = {} + +# The filename the config was read from, if any. +filename = None + +def has(key): + """Return true if a configuration key exists.""" + return key in _config + +def set(key, value): + """Set a configuration value.""" + _config[key] = value + +def get(key): + """Get a configuration value.""" + if key in _config: + return _config[key] + return None + +def set_state_dir(directory): + _config[DATA_DIRECTORY_KEY] = directory + +def get_state_dir(): + """Get the data directory. This is more of the Suricata state + directory than a specific Suricata-Update directory, and is used + as the root directory for Suricata-Update data. 
+ """ + if os.getenv("DATA_DIRECTORY"): + return os.getenv("DATA_DIRECTORY") + if DATA_DIRECTORY_KEY in _config: + return _config[DATA_DIRECTORY_KEY] + return DEFAULT_DATA_DIRECTORY + +def set_cache_dir(directory): + """Set an alternate cache directory.""" + _config[CACHE_DIRECTORY_KEY] = directory + +def get_cache_dir(): + """Get the cache directory.""" + if CACHE_DIRECTORY_KEY in _config: + return _config[CACHE_DIRECTORY_KEY] + return os.path.join(get_state_dir(), CACHE_DIRECTORY) + +def get_output_dir(): + """Get the rule output directory.""" + if OUTPUT_KEY in _config: + return _config[OUTPUT_KEY] + return os.path.join(get_state_dir(), "rules") + +def args(): + """Return sthe parsed argument object.""" + return _args + +def get_arg(key): + key = key.replace("-", "_") + if hasattr(_args, key): + val = getattr(_args, key) + if val not in [[], None]: + return val + return None + +def init(args): + global _args + global filename + + _args = args + _config.update(DEFAULT_CONFIG) + + if args.config: + logger.info("Loading %s", args.config) + with open(args.config, "rb") as fileobj: + config = yaml.safe_load(fileobj) + if config: + _config.update(config) + filename = args.config + elif os.path.exists(DEFAULT_UPDATE_YAML_PATH): + logger.info("Loading %s", DEFAULT_UPDATE_YAML_PATH) + with open(DEFAULT_UPDATE_YAML_PATH, "rb") as fileobj: + config = yaml.safe_load(fileobj) + if config: + _config.update(config) + filename = DEFAULT_UPDATE_YAML_PATH + + # Apply command line arguments to the config. + for arg in vars(args): + if arg == "local": + for local in args.local: + logger.debug("Adding local ruleset to config: %s", local) + _config[LOCAL_CONF_KEY].append(local) + elif arg == "data_dir" and args.data_dir: + logger.debug("Setting data directory to %s", args.data_dir) + _config[DATA_DIRECTORY_KEY] = args.data_dir + elif getattr(args, arg) is not None: + key = arg.replace("_", "-") + val = getattr(args, arg) + logger.debug("Setting configuration value %s -> %s", key, val) + _config[key] = val + + # Find and set the path to suricata if not provided. + if "suricata" in _config: + if not os.path.exists(_config["suricata"]): + raise ApplicationError( + "Configured path to suricata does not exist: %s" % ( + _config["suricata"])) + else: + suricata_path = suricata.update.engine.get_path() + if not suricata_path: + logger.warning("No suricata application binary found on path.") + else: + _config["suricata"] = suricata_path + + if "suricata" in _config: + build_info = suricata.update.engine.get_build_info(_config["suricata"]) + + # Set the first suricata.yaml to check for to the one in the + # --sysconfdir provided by build-info. + if not "suricata_conf" in _config and "sysconfdir" in build_info: + DEFAULT_SURICATA_YAML_PATH.insert( + 0, os.path.join( + build_info["sysconfdir"], "suricata/suricata.yaml")) + + # Amend the path to look for Suricata provided rules based on + # the build info. As we are inserting at the front, put the + # highest priority path last. + if "sysconfdir" in build_info: + DEFAULT_DIST_RULE_PATH.insert( + 0, os.path.join(build_info["sysconfdir"], "suricata/rules")) + if "datarootdir" in build_info: + DEFAULT_DIST_RULE_PATH.insert( + 0, os.path.join(build_info["datarootdir"], "suricata/rules")) + + # Set the data-directory prefix to that of the --localstatedir + # found in the build-info. 
diff --git a/suricata/update/configs/__init__.py b/suricata/update/configs/__init__.py
new file mode 100644
index 0000000..e136c7a
--- /dev/null
+++ b/suricata/update/configs/__init__.py
@@ -0,0 +1,31 @@
+# Copyright (C) 2017 Open Information Security Foundation
+#
+# You can copy, redistribute or modify this Program under the terms of
+# the GNU General Public License version 2 as published by the Free
+# Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# version 2 along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+# 02110-1301, USA.
+
+import os.path
+
+# The list of sample config files provided here, for use when asked to
+# dump them.
+filenames = [
+    "update.yaml",
+    "enable.conf",
+    "disable.conf",
+    "modify.conf",
+    "drop.conf",
+    "threshold.in",
+]
+
+directory = os.path.dirname(__file__)
+
diff --git a/suricata/update/configs/disable.conf b/suricata/update/configs/disable.conf
new file mode 100644
index 0000000..59d0e18
--- /dev/null
+++ b/suricata/update/configs/disable.conf
@@ -0,0 +1,19 @@
+# suricata-update - disable.conf
+
+# Example of disabling a rule by signature ID (gid is optional).
+# 1:2019401
+# 2019401
+
+# Example of disabling a rule by regular expression.
+# - All regular expression matches are case insensitive.
+# re:heartbleed
+# re:MS(0[7-9]|10)-\d+
+
+# Examples of disabling a group of rules.
+# group:emerging-icmp.rules
+# group:emerging-dos
+# group:emerging*
+
+# Disable all rules with a metadata of "deployment perimeter". Note that metadata
+# matches are case insensitive.
+# metadata: deployment perimeter
\ No newline at end of file
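For illustration, the re: filters above are documented in this file as case-insensitive matches against the rule. A rough sketch of that semantic (the helper name is hypothetical; this is not the matcher suricata-update actually ships):

    import re

    def re_filter_matches(filter_line, rule_text):
        # "re:heartbleed" -> case-insensitive search over the rule text.
        pattern = filter_line.split(":", 1)[1].strip()
        return re.search(pattern, rule_text, re.IGNORECASE) is not None

    rule = 'alert tcp any any -> any any (msg:"CVE-2014-0160 HeartBleed"; sid:1000001;)'
    print(re_filter_matches("re:heartbleed", rule))  # -> True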
diff --git a/suricata/update/configs/drop.conf b/suricata/update/configs/drop.conf
new file mode 100644
index 0000000..a93268d
--- /dev/null
+++ b/suricata/update/configs/drop.conf
@@ -0,0 +1,11 @@
+# suricata-update - drop.conf
+#
+# Rules matching specifiers in this file will be converted to drop rules.
+#
+# Examples:
+#
+# 1:2019401
+# 2019401
+#
+# re:heartbleed
+# re:MS(0[7-9]|10)-\d+
diff --git a/suricata/update/configs/enable.conf b/suricata/update/configs/enable.conf
new file mode 100644
index 0000000..ad7b4e2
--- /dev/null
+++ b/suricata/update/configs/enable.conf
@@ -0,0 +1,19 @@
+# suricata-update - enable.conf
+
+# Example of enabling a rule by signature ID (gid is optional).
+# 1:2019401
+# 2019401
+
+# Example of enabling a rule by regular expression.
+# - All regular expression matches are case insensitive.
+# re:heartbleed
+# re:MS(0[7-9]|10)-\d+
+
+# Examples of enabling a group of rules.
+# group:emerging-icmp.rules
+# group:emerging-dos
+# group:emerging*
+
+# Enable all rules with a metadata of "deployment perimeter". Note that metadata
+# matches are case insensitive.
+# metadata: deployment perimeter
\ No newline at end of file
diff --git a/suricata/update/configs/modify.conf b/suricata/update/configs/modify.conf
new file mode 100644
index 0000000..70bfb3e
--- /dev/null
+++ b/suricata/update/configs/modify.conf
@@ -0,0 +1,24 @@
+# suricata-update - modify.conf
+
+# Format: <sid> "<from>" "<to>"
+
+# Example changing the seconds for rule 2019401 to 3600.
+# 2019401 "seconds \d+" "seconds 3600"
+#
+# Example converting all alert rules to drop:
+# re:. ^alert drop
+#
+# Example converting all drop rules with noalert back to alert:
+# re:. "^drop(.*)noalert(.*)" "alert\\1noalert\\2"
+
+# Change all trojan-activity rules to drop. It's better to set up a
+# drop.conf for this, but this does show the use of back references.
+# re:classtype:trojan-activity "(alert)(.*)" "drop\\2"
+
+# For compatibility, most Oinkmaster modifysid lines should work as
+# well.
+# modifysid * "^drop(.*)noalert(.*)" | "alert${1}noalert${2}"
+
+# Add metadata.
+#metadata-add re:"SURICATA STREAM" "evebox-action" "archive"
+#metadata-add 2010646 "evebox-action" "archive"
\ No newline at end of file
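The <sid> "<from>" "<to>" entries above behave like a regular-expression substitution applied to each matching rule. Roughly, under that reading (the helper is hypothetical, shown only to make the format concrete):

    import re

    def apply_modify(rule_text, from_pattern, to_pattern):
        # 2019401 "seconds \d+" "seconds 3600" -> rewrite that part of the rule.
        return re.sub(from_pattern, to_pattern, rule_text)

    rule = 'alert dns any any -> any any (msg:"TEST"; threshold: type limit, seconds 60;)'
    print(apply_modify(rule, r"seconds \d+", "seconds 3600"))
    # -> ... threshold: type limit, seconds 3600;)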
diff --git a/suricata/update/configs/threshold.in b/suricata/update/configs/threshold.in
new file mode 100644
index 0000000..377417d
--- /dev/null
+++ b/suricata/update/configs/threshold.in
@@ -0,0 +1,22 @@
+# suricata-update - threshold.in
+
+# This file contains thresholding configurations that will be turned into
+# a Suricata compatible threshold.conf file.
+
+# This file can contain standard threshold.conf configurations:
+#
+# suppress gen_id <gid>, sig_id <sid>
+# suppress gen_id <gid>, sig_id <sid>, track <by_src|by_dst>, ip <ip|subnet>
+# threshold gen_id 0, sig_id 0, type threshold, track by_src, count 10, seconds 10
+# suppress gen_id 1, sig_id 2009557, track by_src, ip 217.110.97.128/25
+
+# Or ones that will be preprocessed...
+
+# Suppress all rules containing "java".
+#
+# suppress re:java
+# suppress re:java, track by_src, ip 217.110.97.128/25
+
+# Threshold all rules containing "java".
+#
+# threshold re:java, type threshold, track by_dst, count 1, seconds 10
diff --git a/suricata/update/configs/update.yaml b/suricata/update/configs/update.yaml
new file mode 100644
index 0000000..358e869
--- /dev/null
+++ b/suricata/update/configs/update.yaml
@@ -0,0 +1,58 @@
+# Configuration with disable filters.
+# - Overridden by --disable-conf
+# - Default: /etc/suricata/disable.conf
+disable-conf: /etc/suricata/disable.conf
+
+# Configuration with enable filters.
+# - Overridden by --enable-conf
+# - Default: /etc/suricata/enable.conf
+enable-conf: /etc/suricata/enable.conf
+
+# Configuration with drop filters.
+# - Overridden by --drop-conf
+# - Default: /etc/suricata/drop.conf
+drop-conf: /etc/suricata/drop.conf
+
+# Configuration with modify filters.
+# - Overridden by --modify-conf
+# - Default: /etc/suricata/modify.conf
+modify-conf: /etc/suricata/modify.conf
+
+# List of files to ignore. Overridden by the --ignore command line option.
+ignore:
+  - "*deleted.rules"
+
+# Override the user-agent string.
+#user-agent: "Suricata-Update"
+
+# Provide an alternate command to the default test command.
+#
+# The following environment variables can be used.
+#    SURICATA_PATH - The path to the discovered suricata program.
+#    OUTPUT_DIR - The directory the rules are written to.
+#    OUTPUT_FILENAME - The name of the rule file. Will be empty if the rules
+#        were not merged.
+#test-command: ${SURICATA_PATH} -T -S ${OUTPUT_FILENAME} -l /tmp
+
+# Provide a command to reload the Suricata rules.
+# May be overridden by the --reload-command command line option.
+# See the documentation of --reload-command for the different options
+# to reload Suricata rules.
+#reload-command: sudo systemctl reload suricata
+
+# Remote rule sources. Simply a list of URLs.
+sources:
+  # Emerging Threats Open with the Suricata version dynamically replaced.
+  - https://rules.emergingthreats.net/open/suricata-%(__version__)s/emerging.rules.tar.gz
+  # The SSL blacklist, which is just a standalone rule file.
+  - https://sslbl.abuse.ch/blacklist/sslblacklist.rules
+
+# A list of local rule sources. Each entry can be a rule file, a
+# directory or a wild card specification.
+local:
+  # A directory of rules.
+  - /etc/suricata/rules
+  # A single rule file.
+  - /etc/suricata/rules/app-layer-events.rules
+  # A wildcard.
+ - /etc/suricata/rules/*.rules diff --git a/suricata/update/data/__init__.py b/suricata/update/data/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/suricata/update/data/__init__.py diff --git a/suricata/update/data/index.py b/suricata/update/data/index.py new file mode 100644 index 0000000..02a9c4f --- /dev/null +++ b/suricata/update/data/index.py @@ -0,0 +1,476 @@ +index = { 'sources': { 'et/open': { 'description': 'Proofpoint ET Open is a ' + 'timely and accurate rule set ' + 'for detecting and blocking ' + 'advanced threats\n', + 'license': 'MIT', + 'summary': 'Emerging Threats Open Ruleset', + 'url': 'https://rules.emergingthreats.net/open/suricata-%(__version__)s/emerging.rules.tar.gz', + 'vendor': 'Proofpoint'}, + 'et/pro': { 'checksum': False, + 'description': 'Proofpoint ET Pro is a timely ' + 'and accurate rule set for ' + 'detecting and blocking ' + 'advanced threats\n', + 'license': 'Commercial', + 'parameters': { 'secret-code': { 'prompt': 'Emerging ' + 'Threats ' + 'Pro ' + 'access ' + 'code'}}, + 'replaces': ['et/open'], + 'subscribe-url': 'https://www.proofpoint.com/us/threat-insight/et-pro-ruleset', + 'summary': 'Emerging Threats Pro Ruleset', + 'url': 'https://rules.emergingthreatspro.com/%(secret-code)s/suricata-%(__version__)s/etpro.rules.tar.gz', + 'vendor': 'Proofpoint'}, + 'etnetera/aggressive': { 'checksum': False, + 'license': 'MIT', + 'min-version': '4.0.0', + 'summary': 'Etnetera aggressive ' + 'IP blacklist', + 'url': 'https://security.etnetera.cz/feeds/etn_aggressive.rules', + 'vendor': 'Etnetera a.s.'}, + 'malsilo/win-malware': { 'checksum': True, + 'description': 'TCP/UDP, DNS and ' + 'HTTP Windows ' + 'threats ' + 'artifacts ' + 'observed at ' + 'runtime.\n', + 'homepage': 'https://raw-data.gitlab.io/post/malsilo_2.1/', + 'license': 'MIT', + 'min-version': '4.1.0', + 'summary': 'Commodity malware ' + 'rules', + 'url': 'https://malsilo.gitlab.io/feeds/dumps/malsilo.rules.tar.gz', + 'vendor': 'malsilo'}, + 'oisf/trafficid': { 'checksum': False, + 'license': 'MIT', + 'min-version': '4.0.0', + 'summary': 'Suricata Traffic ID ' + 'ruleset', + 'support-url': 'https://redmine.openinfosecfoundation.org/', + 'url': 'https://openinfosecfoundation.org/rules/trafficid/trafficid.rules', + 'vendor': 'OISF'}, + 'pawpatrules': { 'checksum': False, + 'description': 'PAW Patrules ruleset ' + 'permit to detect many ' + 'events on\n' + 'network. Suspicious ' + 'flow, malicious tool, ' + 'unsuported and\n' + 'vulnerable system, known ' + 'threat actors with ' + 'various IOCs,\n' + 'lateral movement, bad ' + 'practice, shadow IT... ' + 'Rules are\n' + 'frequently updated.\n', + 'homepage': 'https://pawpatrules.fr/', + 'license': 'CC-BY-SA-4.0', + 'min-version': '6.0.0', + 'summary': 'PAW Patrules is a collection ' + 'of rules for IDPS / NSM ' + 'Suricata engine', + 'url': 'https://rules.pawpatrules.fr/suricata/paw-patrules.tar.gz', + 'vendor': 'pawpatrules'}, + 'ptresearch/attackdetection': { 'description': 'The ' + 'Attack ' + 'Detection ' + 'Team ' + 'searches ' + 'for new ' + 'vulnerabilities ' + 'and ' + '0-days, ' + 'reproduces ' + 'it and ' + 'creates ' + 'PoC ' + 'exploits ' + 'to ' + 'understand ' + 'how these ' + 'security ' + 'flaws ' + 'work and ' + 'how ' + 'related ' + 'attacks ' + 'can be ' + 'detected ' + 'on the ' + 'network ' + 'layer. 
' + 'Additionally, ' + 'we are ' + 'interested ' + 'in ' + 'malware ' + 'and ' + "hackers' " + 'TTPs, so ' + 'we ' + 'develop ' + 'Suricata ' + 'rules for ' + 'detecting ' + 'all sorts ' + 'of such ' + 'activities.\n', + 'license': 'Custom', + 'license-url': 'https://raw.githubusercontent.com/ptresearch/AttackDetection/master/LICENSE', + 'obsolete': 'no longer ' + 'exists', + 'summary': 'Positive ' + 'Technologies ' + 'Attack ' + 'Detection ' + 'Team ruleset', + 'url': 'https://raw.githubusercontent.com/ptresearch/AttackDetection/master/pt.rules.tar.gz', + 'vendor': 'Positive ' + 'Technologies'}, + 'scwx/enhanced': { 'description': 'Broad ruleset composed ' + 'of malware rules and ' + 'other security-related ' + 'countermeasures, and ' + 'curated by the ' + 'Secureworks Counter ' + 'Threat Unit research ' + 'team. This ruleset ' + 'has been enhanced with ' + 'comprehensive and ' + 'fully ' + 'standard-compliant ' + 'BETTER metadata ' + '(https://better-schema.readthedocs.io/).\n', + 'license': 'Commercial', + 'min-version': '3.0.0', + 'parameters': { 'secret-code': { 'prompt': 'Secureworks ' + 'Threat ' + 'Intelligence ' + 'Authentication ' + 'Token'}}, + 'subscribe-url': 'https://www.secureworks.com/contact/ ' + '(Please reference ' + 'CTU Countermeasures)', + 'summary': 'Secureworks ' + 'suricata-enhanced ruleset', + 'url': 'https://ws.secureworks.com/ti/ruleset/%(secret-code)s/Suricata_suricata-enhanced_latest.tgz', + 'vendor': 'Secureworks'}, + 'scwx/malware': { 'description': 'High-fidelity, ' + 'high-priority ruleset ' + 'composed mainly of ' + 'malware-related ' + 'countermeasures and ' + 'curated by the ' + 'Secureworks Counter ' + 'Threat Unit research ' + 'team.\n', + 'license': 'Commercial', + 'min-version': '3.0.0', + 'parameters': { 'secret-code': { 'prompt': 'Secureworks ' + 'Threat ' + 'Intelligence ' + 'Authentication ' + 'Token'}}, + 'subscribe-url': 'https://www.secureworks.com/contact/ ' + '(Please reference CTU ' + 'Countermeasures)', + 'summary': 'Secureworks ' + 'suricata-malware ruleset', + 'url': 'https://ws.secureworks.com/ti/ruleset/%(secret-code)s/Suricata_suricata-malware_latest.tgz', + 'vendor': 'Secureworks'}, + 'scwx/security': { 'description': 'Broad ruleset composed ' + 'of malware rules and ' + 'other security-related ' + 'countermeasures, and ' + 'curated by the ' + 'Secureworks Counter ' + 'Threat Unit research ' + 'team.\n', + 'license': 'Commercial', + 'min-version': '3.0.0', + 'parameters': { 'secret-code': { 'prompt': 'Secureworks ' + 'Threat ' + 'Intelligence ' + 'Authentication ' + 'Token'}}, + 'subscribe-url': 'https://www.secureworks.com/contact/ ' + '(Please reference ' + 'CTU Countermeasures)', + 'summary': 'Secureworks ' + 'suricata-security ruleset', + 'url': 'https://ws.secureworks.com/ti/ruleset/%(secret-code)s/Suricata_suricata-security_latest.tgz', + 'vendor': 'Secureworks'}, + 'sslbl/ja3-fingerprints': { 'checksum': False, + 'description': 'If you are ' + 'running ' + 'Suricata, you ' + 'can use the ' + "SSLBL's " + 'Suricata JA3 ' + 'FingerprintRuleset ' + 'to detect ' + 'and/or block ' + 'malicious SSL ' + 'connections ' + 'in your ' + 'network based ' + 'on the JA3 ' + 'fingerprint. 
' + 'Please note ' + 'that your ' + 'need Suricata ' + '4.1.0 or ' + 'newer in ' + 'order to use ' + 'the JA3 ' + 'fingerprint ' + 'ruleset.\n', + 'license': 'Non-Commercial', + 'min-version': '4.1.0', + 'summary': 'Abuse.ch Suricata ' + 'JA3 Fingerprint ' + 'Ruleset', + 'url': 'https://sslbl.abuse.ch/blacklist/ja3_fingerprints.rules', + 'vendor': 'Abuse.ch'}, + 'sslbl/ssl-fp-blacklist': { 'checksum': False, + 'description': 'The SSL ' + 'Blacklist ' + '(SSLBL) is a ' + 'project of ' + 'abuse.ch with ' + 'the goal of ' + 'detecting ' + 'malicious SSL ' + 'connections, ' + 'by ' + 'identifying ' + 'and ' + 'blacklisting ' + 'SSL ' + 'certificates ' + 'used by ' + 'botnet C&C ' + 'servers. In ' + 'addition, ' + 'SSLBL ' + 'identifies ' + 'JA3 ' + 'fingerprints ' + 'that helps ' + 'you to detect ' + '& block ' + 'malware ' + 'botnet C&C ' + 'communication ' + 'on the TCP ' + 'layer.\n', + 'license': 'Non-Commercial', + 'summary': 'Abuse.ch SSL ' + 'Blacklist', + 'url': 'https://sslbl.abuse.ch/blacklist/sslblacklist.rules', + 'vendor': 'Abuse.ch'}, + 'stamus/lateral': { 'description': 'Suricata ruleset ' + 'specifically focused ' + 'on detecting lateral\n' + 'movement in Microsoft ' + 'Windows environments ' + 'by Stamus Networks\n', + 'license': 'GPL-3.0-only', + 'min-version': '6.0.6', + 'summary': 'Lateral movement rules', + 'support-url': 'https://discord.com/channels/911231224448712714/911238451842666546', + 'url': 'https://ti.stamus-networks.io/open/stamus-lateral-rules.tar.gz', + 'vendor': 'Stamus Networks'}, + 'stamus/nrd-14-open': { 'description': 'Newly Registered ' + 'Domains list ' + '(last 14 days) to ' + 'match on DNS, TLS ' + 'and HTTP ' + 'communication.\n' + 'Produced by ' + 'Stamus Labs ' + 'research team.\n', + 'license': 'Commercial', + 'min-version': '6.0.0', + 'parameters': { 'secret-code': { 'prompt': 'Stamus ' + 'Networks ' + 'License ' + 'code'}}, + 'subscribe-url': 'https://www.stamus-networks.com/stamus-labs/subscribe-to-threat-intel-feed', + 'summary': 'Newly Registered ' + 'Domains Open only - ' + '14 day list, complete', + 'url': 'https://ti.stamus-networks.io/%(secret-code)s/sti-domains-nrd-14.tar.gz', + 'vendor': 'Stamus Networks'}, + 'stamus/nrd-30-open': { 'description': 'Newly Registered ' + 'Domains list ' + '(last 30 days) to ' + 'match on DNS, TLS ' + 'and HTTP ' + 'communication.\n' + 'Produced by ' + 'Stamus Labs ' + 'research team.\n', + 'license': 'Commercial', + 'min-version': '6.0.0', + 'parameters': { 'secret-code': { 'prompt': 'Stamus ' + 'Networks ' + 'License ' + 'code'}}, + 'subscribe-url': 'https://www.stamus-networks.com/stamus-labs/subscribe-to-threat-intel-feed', + 'summary': 'Newly Registered ' + 'Domains Open only - ' + '30 day list, complete', + 'url': 'https://ti.stamus-networks.io/%(secret-code)s/sti-domains-nrd-30.tar.gz', + 'vendor': 'Stamus Networks'}, + 'stamus/nrd-entropy-14-open': { 'description': 'Suspicious ' + 'Newly ' + 'Registered ' + 'Domains ' + 'list with ' + 'high ' + 'entropy ' + '(last 14 ' + 'days) to ' + 'match on ' + 'DNS, TLS ' + 'and HTTP ' + 'communication.\n' + 'Produced ' + 'by Stamus ' + 'Labs ' + 'research ' + 'team.\n', + 'license': 'Commercial', + 'min-version': '6.0.0', + 'parameters': { 'secret-code': { 'prompt': 'Stamus ' + 'Networks ' + 'License ' + 'code'}}, + 'subscribe-url': 'https://www.stamus-networks.com/stamus-labs/subscribe-to-threat-intel-feed', + 'summary': 'Newly ' + 'Registered ' + 'Domains Open ' + 'only - 14 day ' + 'list, high ' + 'entropy', + 'url': 
'https://ti.stamus-networks.io/%(secret-code)s/sti-domains-entropy-14.tar.gz', + 'vendor': 'Stamus ' + 'Networks'}, + 'stamus/nrd-entropy-30-open': { 'description': 'Suspicious ' + 'Newly ' + 'Registered ' + 'Domains ' + 'list with ' + 'high ' + 'entropy ' + '(last 30 ' + 'days) to ' + 'match on ' + 'DNS, TLS ' + 'and HTTP ' + 'communication.\n' + 'Produced ' + 'by Stamus ' + 'Labs ' + 'research ' + 'team.\n', + 'license': 'Commercial', + 'min-version': '6.0.0', + 'parameters': { 'secret-code': { 'prompt': 'Stamus ' + 'Networks ' + 'License ' + 'code'}}, + 'subscribe-url': 'https://www.stamus-networks.com/stamus-labs/subscribe-to-threat-intel-feed', + 'summary': 'Newly ' + 'Registered ' + 'Domains Open ' + 'only - 30 day ' + 'list, high ' + 'entropy', + 'url': 'https://ti.stamus-networks.io/%(secret-code)s/sti-domains-entropy-30.tar.gz', + 'vendor': 'Stamus ' + 'Networks'}, + 'stamus/nrd-phishing-14-open': { 'description': 'Suspicious ' + 'Newly ' + 'Registered ' + 'Domains ' + 'Phishing ' + 'list ' + '(last 14 ' + 'days) to ' + 'match on ' + 'DNS, TLS ' + 'and HTTP ' + 'communication.\n' + 'Produced ' + 'by ' + 'Stamus ' + 'Labs ' + 'research ' + 'team.\n', + 'license': 'Commercial', + 'min-version': '6.0.0', + 'parameters': { 'secret-code': { 'prompt': 'Stamus ' + 'Networks ' + 'License ' + 'code'}}, + 'subscribe-url': 'https://www.stamus-networks.com/stamus-labs/subscribe-to-threat-intel-feed', + 'summary': 'Newly ' + 'Registered ' + 'Domains Open ' + 'only - 14 ' + 'day list, ' + 'phishing', + 'url': 'https://ti.stamus-networks.io/%(secret-code)s/sti-domains-phishing-14.tar.gz', + 'vendor': 'Stamus ' + 'Networks'}, + 'stamus/nrd-phishing-30-open': { 'description': 'Suspicious ' + 'Newly ' + 'Registered ' + 'Domains ' + 'Phishing ' + 'list ' + '(last 30 ' + 'days) to ' + 'match on ' + 'DNS, TLS ' + 'and HTTP ' + 'communication.\n' + 'Produced ' + 'by ' + 'Stamus ' + 'Labs ' + 'research ' + 'team.\n', + 'license': 'Commercial', + 'min-version': '6.0.0', + 'parameters': { 'secret-code': { 'prompt': 'Stamus ' + 'Networks ' + 'License ' + 'code'}}, + 'subscribe-url': 'https://www.stamus-networks.com/stamus-labs/subscribe-to-threat-intel-feed', + 'summary': 'Newly ' + 'Registered ' + 'Domains Open ' + 'only - 30 ' + 'day list, ' + 'phishing', + 'url': 'https://ti.stamus-networks.io/%(secret-code)s/sti-domains-phishing-30.tar.gz', + 'vendor': 'Stamus ' + 'Networks'}, + 'tgreen/hunting': { 'checksum': False, + 'description': 'Heuristic ruleset for ' + 'hunting. Focus on ' + 'anomaly detection and ' + 'showcasing latest ' + 'engine features, not ' + 'performance.\n', + 'license': 'GPLv3', + 'min-version': '4.1.0', + 'summary': 'Threat hunting rules', + 'url': 'https://raw.githubusercontent.com/travisbgreen/hunting-rules/master/hunting.rules', + 'vendor': 'tgreen'}}, + 'version': 1}
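Each url above is a Python %-style template. Before fetching, suricata-update fills in internal parameters such as __version__ and any user-supplied parameters such as secret-code; the actual resolution goes through sources.Index.resolve_url, as seen later in main.py. A minimal sketch of that substitution against this embedded index, assuming the package is importable:

from suricata.update.data.index import index

# The version would normally come from the detected Suricata binary;
# 6.0.4 here is only an example value.
params = {"__version__": "6.0.4"}
url = index["sources"]["et/open"]["url"] % params
print(url)
# https://rules.emergingthreats.net/open/suricata-6.0.4/emerging.rules.tar.gz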
\ No newline at end of file
diff --git a/suricata/update/data/update.py b/suricata/update/data/update.py
new file mode 100644
index 0000000..8b34c40
--- /dev/null
+++ b/suricata/update/data/update.py
@@ -0,0 +1,53 @@
+# Copyright (C) 2018-2022 Open Information Security Foundation
+#
+# You can copy, redistribute or modify this Program under the terms of
+# the GNU General Public License version 2 as published by the Free
+# Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# version 2 along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+# 02110-1301, USA.
+
+import os.path
+import sys
+import pprint
+
+try:
+    from urllib2 import urlopen
+except ImportError:
+    from urllib.request import urlopen
+
+import yaml
+
+DEFAULT_URL = "https://raw.githubusercontent.com/oisf/suricata-intel-index/master/index.yaml"
+
+def embed_index():
+    """Embed a copy of the index as a Python source file. We can't use a
+    data file yet as there is no easy way to do this with distutils."""
+    if len(sys.argv) > 1:
+        url = sys.argv[1]
+    else:
+        url = DEFAULT_URL
+    dist_filename = os.path.join(os.path.dirname(__file__), "index.py")
+    response = urlopen(url)
+    index = yaml.safe_load(response.read())
+
+    # Delete the version info to keep it from going stale around a new
+    # Suricata release, when the index may not yet have been updated to
+    # the latest recommended version. The user will be asked to update
+    # their sources to run the version check.
+    del index["versions"]
+
+    pp = pprint.PrettyPrinter(indent=4)
+
+    with open(dist_filename, "w") as fileobj:
+        fileobj.write("index = {}".format(pp.pformat(index)))
+
+if __name__ == "__main__":
+    embed_index()
diff --git a/suricata/update/engine.py b/suricata/update/engine.py
new file mode 100644
index 0000000..22ad9b3
--- /dev/null
+++ b/suricata/update/engine.py
@@ -0,0 +1,196 @@
+# Copyright (C) 2017 Open Information Security Foundation
+# Copyright (c) 2015 Jason Ish
+#
+# You can copy, redistribute or modify this Program under the terms of
+# the GNU General Public License version 2 as published by the Free
+# Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# version 2 along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+# 02110-1301, USA.
+
+# This module contains functions for interacting with the Suricata
+# application (aka the engine).
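+#
+# Illustrative example (a sketch, not part of the upstream module):
+# parse_version(), defined below, converts the banner printed by
+# "suricata -V" into a SuricataVersion tuple:
+#
+#   >>> parse_version("This is Suricata version 6.0.4 RELEASE")
+#   SuricataVersion(major=6, minor=0, patch=4, full='6.0.4', short='6.0',
+#                   raw='This is Suricata version 6.0.4 RELEASE')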
+
+from __future__ import print_function
+
+import sys
+import os
+import os.path
+import subprocess
+import re
+import logging
+import shutil
+import yaml
+import tempfile
+from collections import namedtuple
+
+logger = logging.getLogger()
+
+SuricataVersion = namedtuple(
+    "SuricataVersion", ["major", "minor", "patch", "full", "short", "raw"])
+
+def get_build_info(suricata):
+    build_info = {
+        "features": [],
+    }
+    build_info_output = subprocess.check_output([suricata, "--build-info"])
+    for line in build_info_output.decode("utf-8").split("\n"):
+        line = line.strip()
+        if line.startswith("--prefix"):
+            build_info["prefix"] = line.split()[-1].strip()
+        elif line.startswith("--sysconfdir"):
+            build_info["sysconfdir"] = line.split()[-1].strip()
+        elif line.startswith("--localstatedir"):
+            build_info["localstatedir"] = line.split()[-1].strip()
+        elif line.startswith("--datarootdir"):
+            build_info["datarootdir"] = line.split()[-1].strip()
+        elif line.startswith("Features:"):
+            build_info["features"] = line.split()[1:]
+        elif line.startswith("This is Suricata version"):
+            build_info["version"] = parse_version(line)
+
+    if "prefix" not in build_info:
+        logger.warning("--prefix not found in build-info.")
+    if "sysconfdir" not in build_info:
+        logger.warning("--sysconfdir not found in build-info.")
+    if "localstatedir" not in build_info:
+        logger.warning("--localstatedir not found in build-info.")
+
+    return build_info
+
+class Configuration:
+    """An abstraction over the Suricata configuration file."""
+
+    def __init__(self, conf, build_info={}):
+        self.conf = conf
+        self.build_info = build_info
+
+    def keys(self):
+        return self.conf.keys()
+
+    def has_key(self, key):
+        return key in self.conf
+
+    def get(self, key):
+        return self.conf.get(key, None)
+
+    def is_true(self, key, truthy=[]):
+        if key not in self.conf:
+            logger.warning(
+                "Suricata configuration key does not exist: %s" % (key))
+            return False
+        val = self.conf[key]
+        if val.lower() in ["1", "yes", "true"] + truthy:
+            return True
+        return False
+
+    @classmethod
+    def load(cls, config_filename, suricata_path=None):
+        env = build_env()
+        env["SC_LOG_LEVEL"] = "Error"
+        if not suricata_path:
+            suricata_path = get_path()
+        if not suricata_path:
+            raise Exception("Suricata program could not be found.")
+        if not os.path.exists(suricata_path):
+            raise Exception(
+                "Suricata program %s does not exist." % (suricata_path))
+        configuration_dump = subprocess.check_output(
+            [suricata_path, "-c", config_filename, "--dump-config"],
+            env=env)
+        conf = {}
+        for line in configuration_dump.splitlines():
+            try:
+                key, val = line.decode().split(" = ")
+                conf[key] = val
+            except ValueError:
+                logger.warning("Failed to parse: %s", line)
+        build_info = get_build_info(suricata_path)
+        return cls(conf, build_info)
+
+def get_path(program="suricata"):
+    """Find Suricata in the shell path."""
+    # First look for Suricata relative to suricata-update.
+    relative_path = os.path.join(os.path.dirname(sys.argv[0]), "suricata")
+    if os.path.exists(relative_path):
+        logger.debug("Found suricata at %s" % (relative_path))
+        return relative_path
+
+    # Otherwise look for it in the path.
+    for path in os.environ["PATH"].split(os.pathsep):
+        if not path:
+            continue
+        suricata_path = os.path.join(path, program)
+        logger.debug("Looking for %s in %s" % (program, path))
+        if os.path.exists(suricata_path):
+            logger.debug("Found %s." % (suricata_path))
+            return suricata_path
+    return None
+
+def parse_version(buf):
+    m = re.search(r"((\d+)\.(\d+)(\.(\d+))?([\w\-]+)?)", str(buf).strip())
+    if m:
+        full = m.group(1)
+        major = int(m.group(2))
+        minor = int(m.group(3))
+        if not m.group(5):
+            patch = 0
+        else:
+            patch = int(m.group(5))
+        short = "%s.%s" % (major, minor)
+        return SuricataVersion(
+            major=major, minor=minor, patch=patch, short=short, full=full,
+            raw=buf)
+    return None
+
+def get_version(path):
+    """Get a SuricataVersion named tuple describing the version of the
+    Suricata program at path.
+
+    Returns None if no path argument was provided.
+    """
+    if not path:
+        return None
+    output = subprocess.check_output([path, "-V"])
+    if output:
+        return parse_version(output)
+    return None
+
+def test_configuration(suricata_path, suricata_conf=None, rule_filename=None):
+    """Test the Suricata configuration with -T."""
+    tempdir = tempfile.mkdtemp()
+    test_command = [
+        suricata_path,
+        "-T",
+        "-l", tempdir,
+    ]
+    if suricata_conf:
+        test_command += ["-c", suricata_conf]
+    if rule_filename:
+        test_command += ["-S", rule_filename]
+
+    env = build_env()
+    env["SC_LOG_LEVEL"] = "Warning"
+
+    logger.debug("Running %s; env=%s", " ".join(test_command), str(env))
+    rc = subprocess.Popen(test_command, env=env).wait()
+    ret = rc == 0
+
+    # Clean up the temp dir.
+    shutil.rmtree(tempdir)
+
+    return ret
+
+def build_env():
+    env = os.environ.copy()
+    env["SC_LOG_FORMAT"] = "%t - <%d> -- "
+    env["SC_LOG_LEVEL"] = "Error"
+    env["ASAN_OPTIONS"] = "detect_leaks=0"
+    return env
diff --git a/suricata/update/exceptions.py b/suricata/update/exceptions.py
new file mode 100644
index 0000000..1f2c547
--- /dev/null
+++ b/suricata/update/exceptions.py
@@ -0,0 +1,21 @@
+# Copyright (C) 2017 Open Information Security Foundation
+#
+# You can copy, redistribute or modify this Program under the terms of
+# the GNU General Public License version 2 as published by the Free
+# Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# version 2 along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+# 02110-1301, USA.
+
+class ApplicationError(Exception):
+    pass
+
+class InvalidConfigurationError(ApplicationError):
+    pass
diff --git a/suricata/update/extract.py b/suricata/update/extract.py
new file mode 100644
index 0000000..20e4156
--- /dev/null
+++ b/suricata/update/extract.py
@@ -0,0 +1,68 @@
+# Copyright (C) 2017 Open Information Security Foundation
+# Copyright (c) 2017 Jason Ish
+#
+# You can copy, redistribute or modify this Program under the terms of
+# the GNU General Public License version 2 as published by the Free
+# Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# version 2 along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+# 02110-1301, USA.
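+
+# Usage sketch (illustrative): try_extract() below returns a dict
+# mapping archive member names to file contents for tar and zip input,
+# or None when the file is neither, in which case the caller (see
+# Fetch.extract_files() in main.py) treats the download as a single
+# rule file:
+#
+#   files = try_extract("emerging.rules.tar.gz")
+#   if files is None:
+#       # Not an archive; treat as an individual rule file.
+#       ...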
+ +from __future__ import print_function + +import tarfile +from zipfile import ZipFile + +def extract_tar(filename): + files = {} + + tf = tarfile.open(filename, mode="r:*") + + try: + while True: + member = tf.next() + if member is None: + break + if not member.isfile(): + continue + fileobj = tf.extractfile(member) + if fileobj: + # Remove leading /. + member_name = member.name.lstrip("/") + files[member_name] = fileobj.read() + finally: + tf.close() + + return files + +def extract_zip(filename): + files = {} + + with ZipFile(filename) as reader: + for name in reader.namelist(): + if name.endswith("/"): + continue + fixed_name = name.lstrip("/") + files[fixed_name] = reader.read(name) + + return files + +def try_extract(filename): + try: + return extract_tar(filename) + except: + pass + + try: + return extract_zip(filename) + except: + pass + + return None diff --git a/suricata/update/loghandler.py b/suricata/update/loghandler.py new file mode 100644 index 0000000..dc10504 --- /dev/null +++ b/suricata/update/loghandler.py @@ -0,0 +1,115 @@ +# Copyright (C) 2017 Open Information Security Foundation +# Copyright (c) 2016 Jason Ish +# +# You can copy, redistribute or modify this Program under the terms of +# the GNU General Public License version 2 as published by the Free +# Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# version 2 along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. + +import sys +import os +import logging +import time + +# A list of secrets that will be replaced in the log output. +secrets = {} + + +def add_secret(secret, replacement): + """Register a secret to be masked. 
The secret will be replaced with: + <replacement> + """ + secrets[str(secret)] = str(replacement) + + +class SuriColourLogHandler(logging.StreamHandler): + """An alternative stream log handler that logs with Suricata inspired + log colours.""" + + GREEN = "\x1b[32m" + BLUE = "\x1b[34m" + REDB = "\x1b[1;31m" + YELLOW = "\x1b[33m" + RED = "\x1b[31m" + YELLOWB = "\x1b[1;33m" + ORANGE = "\x1b[38;5;208m" + RESET = "\x1b[0m" + + def formatTime(self, record): + lt = time.localtime(record.created) + t = "%d/%d/%d -- %02d:%02d:%02d" % (lt.tm_mday, + lt.tm_mon, + lt.tm_year, + lt.tm_hour, + lt.tm_min, + lt.tm_sec) + return "%s" % (t) + + def emit(self, record): + + if record.levelname == "ERROR": + level_prefix = self.REDB + message_prefix = self.REDB + elif record.levelname == "WARNING": + level_prefix = self.ORANGE + message_prefix = self.ORANGE + else: + level_prefix = self.YELLOW + message_prefix = "" + + if os.isatty(self.stream.fileno()): + self.stream.write("%s%s%s - <%s%s%s> -- %s%s%s\n" % ( + self.GREEN, + self.formatTime(record), + self.RESET, + level_prefix, + record.levelname.title(), + self.RESET, + message_prefix, + self.mask_secrets(record.getMessage()), + self.RESET)) + else: + self.stream.write("%s - <%s> -- %s\n" % ( + self.formatTime(record), + record.levelname.title(), + self.mask_secrets(record.getMessage()))) + + def mask_secrets(self, msg): + for secret in secrets: + msg = msg.replace(secret, "<%s>" % secrets[secret]) + return msg + + +class LessThanFilter(logging.Filter): + def __init__(self, exclusive_maximum, name=""): + super(LessThanFilter, self).__init__(name) + self.max_level = exclusive_maximum + + def filter(self, record): + return 1 if record.levelno < self.max_level else 0 + + +def configure_logging(): + if os.fstat(sys.stdout.fileno()) == os.fstat(sys.stderr.fileno()): + filter_stdout = True + else: + filter_stdout = False + logger = logging.getLogger() + logger.setLevel(logging.NOTSET) + logging_handler_out = SuriColourLogHandler(sys.stdout) + logging_handler_out.setLevel(logging.DEBUG) + if filter_stdout: + logging_handler_out.addFilter(LessThanFilter(logging.WARNING)) + logger.addHandler(logging_handler_out) + logging_handler_err = SuriColourLogHandler(sys.stderr) + logging_handler_err.setLevel(logging.WARNING) + logger.addHandler(logging_handler_err) diff --git a/suricata/update/main.py b/suricata/update/main.py new file mode 100644 index 0000000..18af7a8 --- /dev/null +++ b/suricata/update/main.py @@ -0,0 +1,1404 @@ +# Copyright (C) 2017-2022 Open Information Security Foundation +# Copyright (c) 2015-2017 Jason Ish +# +# You can copy, redistribute or modify this Program under the terms of +# the GNU General Public License version 2 as published by the Free +# Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# version 2 along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. 
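+
+# Illustrative example: downloads are cached under a filename derived
+# from the source URL, an md5 prefix plus the URL basename (see
+# Fetch.get_tmp_filename() below). The basename is the URL's final
+# path component with any query string stripped:
+#
+#   >>> Fetch().url_basename("https://example.com/x/rules.tar.gz?key=abc")
+#   'rules.tar.gz'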
+ +from __future__ import print_function + +import sys +import re +import os.path +import logging +import argparse +import time +import hashlib +import fnmatch +import subprocess +import shutil +import glob +import io +import tempfile +import signal +import errno +from collections import namedtuple + +try: + # Python 3. + from urllib.error import URLError +except ImportError: + # Python 2.7. + from urllib2 import URLError + +try: + import yaml +except: + print("error: pyyaml is required") + sys.exit(1) + +from suricata.update import ( + commands, + config, + configs, + engine, + exceptions, + extract, + loghandler, + net, + notes, + parsers, + rule as rule_mod, + sources, + util, + matchers as matchers_mod +) + +from suricata.update.version import version +try: + from suricata.update.revision import revision +except: + revision = None + +SourceFile = namedtuple("SourceFile", ["filename", "content"]) + +if sys.argv[0] == __file__: + sys.path.insert( + 0, os.path.abspath(os.path.join(__file__, "..", "..", ".."))) + +# Initialize logging, use colour if on a tty. +if len(logging.root.handlers) == 0: + logger = logging.getLogger() + loghandler.configure_logging() + logger.setLevel(level=logging.INFO) +else: + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - <%(levelname)s> - %(message)s") + logger = logging.getLogger() + +# If Suricata is not found, default to this version. +DEFAULT_SURICATA_VERSION = "6.0.0" + +# The default filename to use for the output rule file. This is a +# single file concatenating all input rule files together. +DEFAULT_OUTPUT_RULE_FILENAME = "suricata.rules" + +INDEX_EXPIRATION_TIME = 60 * 60 * 24 * 14 + +# Rule keywords that come with files +file_kw = ["filemd5", "filesha1", "filesha256", "dataset"] + +def strict_error(msg): + logger.error(msg) + if config.args().fail: + sys.exit(1) + +class Fetch: + + def __init__(self): + self.istty = os.isatty(sys.stdout.fileno()) + + def check_checksum(self, tmp_filename, url, checksum_url=None): + try: + if not isinstance(checksum_url, str): + checksum_url = url[0] + ".md5" + net_arg=(checksum_url,url[1]) + local_checksum = hashlib.md5( + open(tmp_filename, "rb").read()).hexdigest().strip() + remote_checksum_buf = io.BytesIO() + logger.info("Checking %s." % (checksum_url)) + net.get(net_arg, remote_checksum_buf) + remote_checksum = remote_checksum_buf.getvalue().decode().strip() + logger.debug("Local checksum=|%s|; remote checksum=|%s|" % ( + local_checksum, remote_checksum)) + if local_checksum == remote_checksum: + os.utime(tmp_filename, None) + return True + except Exception as err: + logger.warning("Failed to check remote checksum: %s" % err) + return False + + def progress_hook(self, content_length, bytes_read): + if config.args().quiet or not self.istty: + return + if not content_length or content_length == 0: + percent = 0 + else: + percent = int((bytes_read / float(content_length)) * 100) + buf = " %3d%% - %-30s" % ( + percent, "%d/%d" % (bytes_read, content_length)) + sys.stdout.write(buf) + sys.stdout.flush() + sys.stdout.write("\b" * 38) + + def progress_hook_finish(self): + if config.args().quiet or not self.istty: + return + sys.stdout.write("\n") + sys.stdout.flush() + + def url_basename(self, url): + """ Return the base filename of the URL. 
""" + filename = os.path.basename(url).split("?", 1)[0] + return filename + + def get_tmp_filename(self, url): + url_hash = hashlib.md5(url.encode("utf-8")).hexdigest() + return os.path.join( + config.get_cache_dir(), + "%s-%s" % (url_hash, self.url_basename(url))) + + def fetch(self, url): + net_arg = url + checksum = url[2] + url = url[0] + tmp_filename = self.get_tmp_filename(url) + if config.args().offline: + if config.args().force: + logger.warning("Running offline, skipping download of %s", url) + logger.info("Using latest cached version of rule file: %s", url) + if not os.path.exists(tmp_filename): + logger.error("Can't proceed offline, " + "source {} has not yet been downloaded.".format(url)) + sys.exit(1) + return self.extract_files(tmp_filename) + if not config.args().force and os.path.exists(tmp_filename): + if not config.args().now and \ + time.time() - os.stat(tmp_filename).st_mtime < (60 * 15): + logger.info( + "Last download less than 15 minutes ago. Not downloading %s.", + url) + return self.extract_files(tmp_filename) + if checksum: + if self.check_checksum(tmp_filename, net_arg, checksum): + logger.info("Remote checksum has not changed. " + "Not fetching.") + return self.extract_files(tmp_filename) + if not os.path.exists(config.get_cache_dir()): + os.makedirs(config.get_cache_dir(), mode=0o770) + logger.info("Fetching %s." % (url)) + try: + tmp_fileobj = tempfile.NamedTemporaryFile() + net.get( + net_arg, + tmp_fileobj, + progress_hook=self.progress_hook) + shutil.copyfile(tmp_fileobj.name, tmp_filename) + tmp_fileobj.close() + except URLError as err: + if os.path.exists(tmp_filename): + if config.args().fail: + strict_error("Failed to fetch {}: {}".format(url, err)) + else: + logger.error("Failed to fetch {}, will use latest cached version: {}".format(url, err)) + return self.extract_files(tmp_filename) + raise err + except IOError as err: + self.progress_hook_finish() + logger.error("Failed to copy file: {}".format(err)) + sys.exit(1) + except Exception as err: + raise err + self.progress_hook_finish() + logger.info("Done.") + return self.extract_files(tmp_filename) + + def run(self, url=None): + files = {} + if url: + try: + fetched = self.fetch(url) + files.update(fetched) + except URLError as err: + url = url[0] if isinstance(url, tuple) else url + strict_error("Failed to fetch {}: {}".format(url, err)) + else: + for url in self.args.url: + files.update(self.fetch(url)) + return files + + def extract_files(self, filename): + files = extract.try_extract(filename) + if files: + return files + + # The file is not an archive, treat it as an individual file. 
+    basename = os.path.basename(filename).split("-", 1)[1]
+    if not basename.endswith(".rules"):
+        basename = "{}.rules".format(basename)
+    files = {}
+    files[basename] = open(filename, "rb").read()
+    return files
+
+def load_filters(filename):
+
+    filters = []
+
+    with open(filename) as fileobj:
+        for line in fileobj:
+            line = line.strip()
+            if not line or line.startswith("#"):
+                continue
+            line = line.rsplit(" #")[0].strip()
+
+            try:
+                if line.startswith("metadata-add"):
+                    rule_filter = matchers_mod.AddMetadataFilter.parse(line)
+                    filters.append(rule_filter)
+                else:
+                    line = re.sub(r'\\\$', '$', line)  # needed to escape $ in pp
+                    rule_filter = matchers_mod.ModifyRuleFilter.parse(line)
+                    filters.append(rule_filter)
+            except Exception as err:
+                raise exceptions.ApplicationError(
+                    "Failed to parse modify filter: {}: {}".format(line, err))
+
+    return filters
+
+def load_drop_filters(filename):
+    matchers = load_matchers(filename)
+    filters = []
+
+    for matcher in matchers:
+        filters.append(matchers_mod.DropRuleFilter(matcher))
+
+    return filters
+
+def parse_matchers(fileobj):
+    matchers = []
+
+    for line in fileobj:
+        line = line.strip()
+        if not line or line.startswith("#"):
+            continue
+        line = line.rsplit(" #")[0]
+        matcher = matchers_mod.parse_rule_match(line)
+        if not matcher:
+            logger.warning("Failed to parse: \"%s\"" % (line))
+        else:
+            matchers.append(matcher)
+
+    return matchers
+
+def load_matchers(filename):
+    with open(filename) as fileobj:
+        return parse_matchers(fileobj)
+
+def load_local(local, files):
+    """Load local files into the files list."""
+    if os.path.isdir(local):
+        for dirpath, dirnames, filenames in os.walk(local):
+            for filename in filenames:
+                if filename.endswith(".rules"):
+                    path = os.path.join(dirpath, filename)
+                    load_local(path, files)
+    else:
+        local_files = glob.glob(local)
+        if len(local_files) == 0:
+            local_files.append(local)
+        for filename in local_files:
+            filename = os.path.realpath(filename)
+            logger.info("Loading local file %s" % (filename))
+            if filename in [f.filename for f in files]:
+                logger.warning(
+                    "Local file %s overrides existing file of same name." % (
+                        filename))
+            try:
+                with open(filename, "rb") as fileobj:
+                    files.append(SourceFile(filename, fileobj.read()))
+            except Exception as err:
+                logger.error("Failed to open {}: {}".format(filename, err))
+
+def load_dist_rules(files):
+    """Load the rule files provided by the Suricata distribution."""
+
+    # In the future hopefully we can just pull in all files from
+    # /usr/share/suricata/rules, but for now pull in the set of files
+    # known to have been provided by the Suricata source.
+ filenames = [ + "app-layer-events.rules", + "decoder-events.rules", + "dhcp-events.rules", + "dnp3-events.rules", + "dns-events.rules", + "files.rules", + "http-events.rules", + "ipsec-events.rules", + "kerberos-events.rules", + "modbus-events.rules", + "nfs-events.rules", + "ntp-events.rules", + "smb-events.rules", + "smtp-events.rules", + "stream-events.rules", + "tls-events.rules", + ] + + dist_rule_path = config.get(config.DIST_RULE_DIRECTORY_KEY) + if not dist_rule_path: + logger.warning("No distribution rule directory found.") + return + + if not os.path.exists(dist_rule_path): + logger.warning("Distribution rule directory not found: %s", + dist_rule_path) + return + + if os.path.exists(dist_rule_path): + if not os.access(dist_rule_path, os.R_OK): + logger.warning("Distribution rule path not readable: %s", + dist_rule_path) + return + for filename in filenames: + path = os.path.join(dist_rule_path, filename) + if not os.path.exists(path): + continue + if not os.access(path, os.R_OK): + logger.warning("Distribution rule file not readable: %s", + path) + continue + logger.info("Loading distribution rule file %s", path) + try: + with open(path, "rb") as fileobj: + files.append(SourceFile(path, fileobj.read())) + except Exception as err: + logger.error("Failed to open {}: {}".format(path, err)) + sys.exit(1) + +def load_classification(suriconf, files): + filename = os.path.join("suricata", "classification.config") + dirs = [] + classification_dict = {} + if "sysconfdir" in suriconf.build_info: + dirs.append(os.path.join(suriconf.build_info["sysconfdir"], filename)) + if "datarootdir" in suriconf.build_info: + dirs.append(os.path.join(suriconf.build_info["datarootdir"], filename)) + + for path in dirs: + if os.path.exists(path): + logger.debug("Loading {}".format(path)) + with open(path) as fp: + for line in fp: + if line.startswith("#") or not line.strip(): + continue + config_classification = line.split(":")[1].strip() + key, desc, priority = config_classification.split(",") + if key in classification_dict: + if classification_dict[key][1] >= priority: + continue + classification_dict[key] = [desc, priority, line.strip()] + + # Handle files from the sources + for filep in files: + logger.debug("Loading {}".format(filep[0])) + lines = filep[1].decode().split('\n') + for line in lines: + if line.startswith("#") or not line.strip(): + continue + config_classification = line.split(":")[1].strip() + key, desc, priority = config_classification.split(",") + if key in classification_dict: + if classification_dict[key][1] >= priority: + if classification_dict[key][1] > priority: + logger.warning("Found classification with same shortname \"{}\"," + " keeping the one with higher priority ({})".format( + key, classification_dict[key][1])) + continue + classification_dict[key] = [desc, priority, line.strip()] + + return classification_dict + +def manage_classification(suriconf, files): + if suriconf is None: + # Can't continue without a valid Suricata configuration + # object. 
+ return + classification_dict = load_classification(suriconf, files) + path = os.path.join(config.get_output_dir(), "classification.config") + try: + logger.info("Writing {}".format(path)) + with open(path, "w+") as fp: + fp.writelines("{}\n".format(v[2]) for k, v in classification_dict.items()) + except (OSError, IOError) as err: + logger.error(err) + +def handle_dataset_files(rule, dep_files): + if not rule.enabled: + return + dataset_load = [el for el in (el.strip() for el in rule.dataset.split(",")) if el.startswith("load")] + if not dataset_load: + # No dataset load found. + return + dataset_filename = dataset_load[0].split(maxsplit=1)[1].strip() + + # Get the directory name the rule is from. + prefix = os.path.dirname(rule.group) + + # Construct the source filename. + source_filename = os.path.join(prefix, dataset_filename) + + # If a source filename starts with a "/", look for it on the filesystem. The archive + # unpackers will take care of removing a leading / so this shouldn't happen for + # downloaded rulesets. + if source_filename.startswith("/"): + if not os.path.exists(source_filename): + logger.warn("Local dataset file '{}' was not found for rule {}, rule will be disabled".format(source_filename, rule.idstr)) + rule.enabled = False + return + dataset_contents = open(source_filename, "rb").read() + else: + if not source_filename in dep_files: + logger.warn("Dataset file '{}' was not found for rule {}, rule will be disabled".format(dataset_filename, rule.idstr)) + rule.enabled = False + return + dataset_contents = dep_files[source_filename] + + source_filename_hash = hashlib.md5(source_filename.encode()).hexdigest() + new_rule = re.sub(r"(dataset.*?load\s+){}".format(dataset_filename), r"\g<1>datasets/{}".format(source_filename_hash), rule.format()) + dest_filename = os.path.join(config.get_output_dir(), "datasets", source_filename_hash) + dest_dir = os.path.dirname(dest_filename) + logger.debug("Copying dataset file {} to {}".format(dataset_filename, dest_filename)) + try: + os.makedirs(dest_dir, exist_ok=True) + except Exception as err: + logger.error("Failed to create directory {}: {}".format(dest_dir, err)) + return + with open(dest_filename, "w") as fp: + fp.write(dataset_contents.decode("utf-8")) + return new_rule + +def handle_filehash_files(rule, dep_files, fhash): + if not rule.enabled: + return + filehash_fname = rule.get(fhash) + + # Get the directory name the rule is from. + prefix = os.path.dirname(rule.group) + + source_filename = os.path.join(prefix, filehash_fname) + dest_filename = source_filename[len(prefix) + len(os.path.sep):] + logger.debug("dest_filename={}".format(dest_filename)) + + if source_filename not in dep_files: + logger.error("{} file {} was not found".format(fhash, filehash_fname)) + else: + logger.debug("Copying %s file %s to output directory" % (fhash, filehash_fname)) + filepath = os.path.join(config.get_output_dir(), os.path.dirname(dest_filename)) + logger.debug("filepath: %s" % filepath) + try: + os.makedirs(filepath) + except OSError as oserr: + if oserr.errno != errno.EEXIST: + logger.error(oserr) + sys.exit(1) + output_filename = os.path.join(filepath, os.path.basename(filehash_fname)) + logger.debug("output fname: %s" % output_filename) + with open(output_filename, "w") as fp: + fp.write(dep_files[source_filename].decode("utf-8")) + +def write_merged(filename, rulemap, dep_files): + + if not args.quiet: + # List of rule IDs that have been added. + added = [] + # List of rule objects that have been removed. 
+ removed = [] + # List of rule IDs that have been modified. + modified = [] + + oldset = {} + if os.path.exists(filename): + for rule in rule_mod.parse_file(filename): + oldset[rule.id] = True + if not rule.id in rulemap: + removed.append(rule) + elif rule.format() != rulemap[rule.id].format(): + modified.append(rulemap[rule.id]) + + for key in rulemap: + if not key in oldset: + added.append(key) + + enabled = len([rule for rule in rulemap.values() if rule.enabled]) + logger.info("Writing rules to %s: total: %d; enabled: %d; " + "added: %d; removed %d; modified: %d" % ( + filename, + len(rulemap), + enabled, + len(added), + len(removed), + len(modified))) + tmp_filename = ".".join([filename, "tmp"]) + with io.open(tmp_filename, encoding="utf-8", mode="w") as fileobj: + for sid in rulemap: + rule = rulemap[sid] + reformatted = None + for kw in file_kw: + if kw in rule: + if "dataset" == kw: + reformatted = handle_dataset_files(rule, dep_files) + else: + handle_filehash_files(rule, dep_files, kw) + if reformatted: + print(reformatted, file=fileobj) + else: + print(rule.format(), file=fileobj) + os.rename(tmp_filename, filename) + +def write_to_directory(directory, files, rulemap, dep_files): + # List of rule IDs that have been added. + added = [] + # List of rule objects that have been removed. + removed = [] + # List of rule IDs that have been modified. + modified = [] + + oldset = {} + if not args.quiet: + for file in files: + outpath = os.path.join( + directory, os.path.basename(file.filename)) + + if os.path.exists(outpath): + for rule in rule_mod.parse_file(outpath): + oldset[rule.id] = True + if not rule.id in rulemap: + removed.append(rule) + elif rule.format() != rulemap[rule.id].format(): + modified.append(rule.id) + for key in rulemap: + if not key in oldset: + added.append(key) + + enabled = len([rule for rule in rulemap.values() if rule.enabled]) + logger.info("Writing rule files to directory %s: total: %d; " + "enabled: %d; added: %d; removed %d; modified: %d" % ( + directory, + len(rulemap), + enabled, + len(added), + len(removed), + len(modified))) + + for file in sorted(files): + outpath = os.path.join( + directory, os.path.basename(file.filename)) + logger.debug("Writing %s." % outpath) + if not file.filename.endswith(".rules"): + open(outpath, "wb").write(file.content) + else: + content = [] + for line in io.StringIO(file.content.decode("utf-8")): + rule = rule_mod.parse(line) + if not rule or rule.id not in rulemap: + content.append(line.strip()) + else: + reformatted = None + for kw in file_kw: + if kw in rule: + if "dataset" == kw: + reformatted = handle_dataset_files(rulemap[rule.id], dep_files) + else: + handle_filehash_files(rulemap[rule.id], dep_files, kw) + if reformatted: + content.append(reformatted) + else: + content.append(rulemap[rule.id].format()) + tmp_filename = ".".join([outpath, "tmp"]) + io.open(tmp_filename, encoding="utf-8", mode="w").write( + u"\n".join(content)) + os.rename(tmp_filename, outpath) + +def write_yaml_fragment(filename, files): + logger.info( + "Writing YAML configuration fragment: %s" % (filename)) + with open(filename, "w") as fileobj: + print("%YAML 1.1", file=fileobj) + print("---", file=fileobj) + print("rule-files:", file=fileobj) + for fn in sorted(files): + if fn.endswith(".rules"): + print(" - %s" % os.path.basename(fn), file=fileobj) + +def write_sid_msg_map(filename, rulemap, version=1): + logger.info("Writing %s." 
% (filename)) + with io.open(filename, encoding="utf-8", mode="w") as fileobj: + for key in rulemap: + rule = rulemap[key] + if version == 2: + formatted = rule_mod.format_sidmsgmap_v2(rule) + if formatted: + print(formatted, file=fileobj) + else: + formatted = rule_mod.format_sidmsgmap(rule) + if formatted: + print(formatted, file=fileobj) + +def build_rule_map(rules): + """Turn a list of rules into a mapping of rules. + + In case of gid:sid conflict, the rule with the higher revision + number will be used. + """ + rulemap = {} + + for rule in rules: + if rule.id not in rulemap: + rulemap[rule.id] = rule + else: + if rule["rev"] == rulemap[rule.id]["rev"]: + logger.warning( + "Found duplicate rule SID {} with same revision, " + "keeping the first rule seen.".format(rule.sid)) + if rule["rev"] > rulemap[rule.id]["rev"]: + logger.warning( + "Found duplicate rule SID {}, " + "keeping the rule with greater revision.".format(rule.sid)) + rulemap[rule.id] = rule + + return rulemap + +def dump_sample_configs(): + + for filename in configs.filenames: + if os.path.exists(filename): + logger.info("File already exists, not dumping %s." % (filename)) + else: + logger.info("Creating %s." % (filename)) + shutil.copy(os.path.join(configs.directory, filename), filename) + +def resolve_flowbits(rulemap, disabled_rules): + flowbit_resolver = rule_mod.FlowbitResolver() + flowbit_enabled = set() + pass_ = 1 + while True: + logger.debug("Checking flowbits for pass %d of rules.", pass_) + flowbits = flowbit_resolver.get_required_flowbits(rulemap) + logger.debug("Found %d required flowbits.", len(flowbits)) + required_rules = flowbit_resolver.get_required_rules(rulemap, flowbits) + logger.debug( + "Found %d rules to enable for flowbit requirements (pass %d)", + len(required_rules), pass_) + if not required_rules: + logger.debug("All required rules enabled.") + break + for rule in required_rules: + if not rule.enabled and rule in disabled_rules: + logger.debug( + "Enabling previously disabled rule for flowbits: %s" % ( + rule.brief())) + rule.enabled = True + rule.noalert = True + flowbit_enabled.add(rule) + pass_ = pass_ + 1 + logger.info("Enabled %d rules for flowbit dependencies." % ( + len(flowbit_enabled))) + +class ThresholdProcessor: + + patterns = [ + re.compile(r"\s+(re:\"(.*)\")"), + re.compile(r"\s+(re:(.*?)),.*"), + re.compile(r"\s+(re:(.*))"), + ] + + def extract_regex(self, buf): + for pattern in self.patterns: + m = pattern.search(buf) + if m: + return m.group(2) + + def extract_pattern(self, buf): + regex = self.extract_regex(buf) + if regex: + return re.compile(regex, re.I) + + def replace(self, threshold, rule): + for pattern in self.patterns: + m = pattern.search(threshold) + if m: + return threshold.replace( + m.group(1), "gen_id %d, sig_id %d" % (rule.gid, rule.sid)) + return threshold + + def process(self, filein, fileout, rulemap): + count = 0 + for line in filein: + line = line.rstrip() + if not line or line.startswith("#"): + print(line, file=fileout) + continue + pattern = self.extract_pattern(line) + if not pattern: + print(line, file=fileout) + else: + for rule in rulemap.values(): + if rule.enabled: + if pattern.search(rule.format()): + count += 1 + print("# %s" % (rule.brief()), file=fileout) + print(self.replace(line, rule), file=fileout) + print("", file=fileout) + logger.info("Generated %d thresholds to %s." % (count, fileout.name)) + +class FileTracker: + """Used to check if files are modified. + + Usage: Add files with add(filename) prior to modification. 
Test + with any_modified() which will return True if any of the checksums + have been modified. + + """ + + def __init__(self): + self.hashes = {} + + def add(self, filename): + checksum = self.md5(filename) + if not checksum: + logger.debug("Recording new file %s" % (filename)) + else: + logger.debug("Recording existing file %s with hash '%s'.", + filename, checksum) + self.hashes[filename] = checksum + + def md5(self, filename): + if not os.path.exists(filename): + return "" + else: + return hashlib.md5(open(filename, "rb").read()).hexdigest() + + def any_modified(self): + for filename in self.hashes: + if self.md5(filename) != self.hashes[filename]: + return True + return False + +def ignore_file(ignore_files, filename): + if not ignore_files: + return False + for pattern in ignore_files: + if fnmatch.fnmatch(os.path.basename(filename), pattern): + return True + return False + +def check_vars(suriconf, rulemap): + """Check that all vars referenced by a rule exist. If a var is not + found, disable the rule. + """ + if suriconf is None: + # Can't continue without a valid Suricata configuration + # object. + return + for rule_id in rulemap: + rule = rulemap[rule_id] + disable = False + for var in rule_mod.parse_var_names(rule["source_addr"]): + if not suriconf.has_key("vars.address-groups.%s" % (var)): + logger.warning( + "Rule has unknown source address var and will be disabled: %s: %s" % ( + var, rule.brief())) + notes.address_group_vars.add(var) + disable = True + for var in rule_mod.parse_var_names(rule["dest_addr"]): + if not suriconf.has_key("vars.address-groups.%s" % (var)): + logger.warning( + "Rule has unknown dest address var and will be disabled: %s: %s" % ( + var, rule.brief())) + notes.address_group_vars.add(var) + disable = True + for var in rule_mod.parse_var_names(rule["source_port"]): + if not suriconf.has_key("vars.port-groups.%s" % (var)): + logger.warning( + "Rule has unknown source port var and will be disabled: %s: %s" % ( + var, rule.brief())) + notes.port_group_vars.add(var) + disable = True + for var in rule_mod.parse_var_names(rule["dest_port"]): + if not suriconf.has_key("vars.port-groups.%s" % (var)): + logger.warning( + "Rule has unknown dest port var and will be disabled: %s: %s" % ( + var, rule.brief())) + notes.port_group_vars.add(var) + disable = True + + if disable: + rule.enabled = False + +def test_suricata(suricata_path): + if not suricata_path: + logger.info("No suricata application binary found, skipping test.") + return True + + if config.get("no-test"): + logger.info("Skipping test, disabled by configuration.") + return True + + if config.get("test-command"): + test_command = config.get("test-command") + logger.info("Testing Suricata configuration with: %s" % ( + test_command)) + env = { + "SURICATA_PATH": suricata_path, + "OUTPUT_DIR": config.get_output_dir(), + } + if not config.get("no-merge"): + env["OUTPUT_FILENAME"] = os.path.join( + config.get_output_dir(), DEFAULT_OUTPUT_RULE_FILENAME) + rc = subprocess.Popen(test_command, shell=True, env=env).wait() + if rc != 0: + return False + else: + logger.info("Testing with suricata -T.") + suricata_conf = config.get("suricata-conf") + if not config.get("no-merge"): + if not engine.test_configuration( + suricata_path, suricata_conf, + os.path.join( + config.get_output_dir(), + DEFAULT_OUTPUT_RULE_FILENAME)): + return False + else: + if not engine.test_configuration(suricata_path, suricata_conf): + return False + + return True + +def copytree(src, dst): + """A shutil.copytree like function that will 
copy the files from one + tree to another even if the path exists. + + """ + + for dirpath, dirnames, filenames in os.walk(src): + for filename in filenames: + src_path = os.path.join(dirpath, filename) + dst_path = os.path.join(dst, src_path[len(src) + 1:]) + if not os.path.exists(os.path.dirname(dst_path)): + os.makedirs(os.path.dirname(dst_path), mode=0o770) + shutil.copyfile(src_path, dst_path) + + # Also attempt to copy the stat bits, but this may fail + # if the owner of the file is not the same as the user + # running the program. + try: + shutil.copystat(src_path, dst_path) + except OSError as err: + logger.debug( + "Failed to copy stat info from %s to %s", src_path, + dst_path) + +def load_sources(suricata_version): + urls = [] + + http_header = None + checksum = True + + # Add any URLs added with the --url command line parameter. + if config.args().url: + for url in config.args().url: + urls.append((url, http_header, checksum)) + + # Get the new style sources. + enabled_sources = sources.get_enabled_sources() + + # Convert the Suricata version to a version string. + version_string = "%d.%d.%d" % ( + suricata_version.major, suricata_version.minor, + suricata_version.patch) + + # Construct the URL replacement parameters that are internal to + # suricata-update. + internal_params = {"__version__": version_string} + + # If we have new sources, we also need to load the index. + if enabled_sources: + index_filename = sources.get_index_filename() + if not os.path.exists(index_filename): + logger.warning("No index exists, will use bundled index.") + logger.warning("Please run suricata-update update-sources.") + if os.path.exists(index_filename) and time.time() - \ + os.stat(index_filename).st_mtime > INDEX_EXPIRATION_TIME: + logger.warning( + "Source index is older than 2 weeks. " + "Please update with suricata-update update-sources.") + index = sources.Index(index_filename) + + for (name, source) in enabled_sources.items(): + params = source["params"] if "params" in source else {} + params.update(internal_params) + if "url" in source: + # No need to go off to the index. + http_header = source.get("http-header") + checksum = source.get("checksum") + url = (source["url"] % params, http_header, checksum) + logger.debug("Resolved source %s to URL %s.", name, url[0]) + else: + if not index: + raise exceptions.ApplicationError( + "Source index is required for source %s; " + "run suricata-update update-sources" % (source["source"])) + source_config = index.get_source_by_name(name) + if source_config is None: + logger.warn("Source no longer exists in index and will not be fetched: {}".format(name)) + continue + try: + checksum = source_config["checksum"] + except: + checksum = True + url = (index.resolve_url(name, params), http_header, + checksum) + if "deprecated" in source_config: + logger.warn("Source has been deprecated: %s: %s" % ( + name, source_config["deprecated"])) + if "obsolete" in source_config: + logger.warn("Source is obsolete and will not be fetched: %s: %s" % ( + name, source_config["obsolete"])) + continue + logger.debug("Resolved source %s to URL %s.", name, url[0]) + urls.append(url) + + if config.get("sources"): + for url in config.get("sources"): + if not isinstance(url, str): + raise exceptions.InvalidConfigurationError( + "Invalid datatype for source URL: %s" % (str(url))) + url = (url % internal_params, http_header, checksum) + logger.debug("Adding source %s.", url) + urls.append(url) + + # If --etopen is on the command line, make sure its added. 
Or if + # there are no URLs, default to ET/Open. + if config.get("etopen") or not urls: + if not config.args().offline and not urls: + logger.info("No sources configured, will use Emerging Threats Open") + urls.append((sources.get_etopen_url(internal_params), http_header, + checksum)) + + # Converting the URLs to a set removed dupes. + urls = set(urls) + + # Now download each URL. + files = [] + for url in urls: + + # To de-duplicate filenames, add a prefix that is a hash of the URL. + prefix = hashlib.md5(url[0].encode()).hexdigest() + source_files = Fetch().run(url) + for key in source_files: + content = source_files[key] + key = os.path.join(prefix, key) + files.append(SourceFile(key, content)) + + # Now load local rules. + if config.get("local") is not None: + for local in config.get("local"): + load_local(local, files) + + return files + +def copytree_ignore_backup(src, names): + """ Returns files to ignore when doing a backup of the rules. """ + return [".cache"] + +def check_output_directory(output_dir): + """ Check that the output directory exists, creating it if it doesn't. """ + if not os.path.exists(output_dir): + logger.info("Creating directory %s." % (output_dir)) + try: + os.makedirs(output_dir, mode=0o770) + except Exception as err: + raise exceptions.ApplicationError( + "Failed to create directory %s: %s" % ( + output_dir, err)) + +# Check and disable ja3 rules if needed. +# +# Note: This is a bit of a quick fixup job for 5.0, but we should look +# at making feature handling more generic. +def disable_ja3(suriconf, rulemap, disabled_rules): + if suriconf and suriconf.build_info: + enabled = False + reason = None + logged = False + if "HAVE_NSS" not in suriconf.build_info["features"]: + reason = "Disabling ja3 rules as Suricata is built without libnss." + else: + # Check if disabled. Must be explicitly disabled, + # otherwise we'll keep ja3 rules enabled. + val = suriconf.get("app-layer.protocols.tls.ja3-fingerprints") + + # Prior to Suricata 5, leaving ja3-fingerprints undefined + # in the configuration disabled the feature. With 5.0, + # having it undefined will enable it as needed. + if not val: + if suriconf.build_info["version"].major < 5: + val = "no" + else: + val = "auto" + + if val and val.lower() not in ["1", "yes", "true", "auto"]: + reason = "Disabling ja3 rules as ja3 fingerprints are not enabled." + else: + enabled = True + + count = 0 + if not enabled: + for key, rule in rulemap.items(): + if "ja3" in rule["features"]: + if not logged: + logger.warn(reason) + logged = True + rule.enabled = False + disabled_rules.append(rule) + count += 1 + if count: + logger.info("%d ja3_hash rules disabled." % (count)) + +def _main(): + global args + args = parsers.parse_arg() + + # Go verbose or quiet sooner than later. + if args.verbose: + logger.setLevel(logging.DEBUG) + if args.quiet: + logger.setLevel(logging.WARNING) + + logger.debug("This is suricata-update version %s (rev: %s); Python: %s" % ( + version, revision, sys.version.replace("\n", "- "))) + + config.init(args) + + # Error out if any reserved/unimplemented arguments were set. + unimplemented_args = [ + "disable", + "enable", + "modify", + "drop", + ] + for arg in unimplemented_args: + if hasattr(args, arg) and getattr(args, arg): + logger.error("--{} not implemented".format(arg)) + return 1 + + suricata_path = config.get("suricata") + + # Now parse the Suricata version. If provided on the command line, + # use that, otherwise attempt to get it from Suricata. 
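+    # For illustration (assumed shape, inferred from usage below):
+    # engine.parse_version() yields an object with "major", "minor",
+    # "patch" and "full" fields, e.g. parse_version("5.0.3") gives
+    # major=5, minor=0, patch=3, full="5.0.3". These fields are used
+    # later for URL templating and feature checks.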
+ if args.suricata_version: + # The Suricata version was passed on the command line, parse it. + suricata_version = engine.parse_version(args.suricata_version) + if not suricata_version: + logger.error("Failed to parse provided Suricata version: {}".format( + args.suricata_version)) + return 1 + logger.info("Forcing Suricata version to %s." % (suricata_version.full)) + elif suricata_path: + suricata_version = engine.get_version(suricata_path) + if suricata_version: + logger.info("Found Suricata version %s at %s." % ( + str(suricata_version.full), suricata_path)) + else: + logger.error("Failed to get Suricata version.") + return 1 + else: + logger.info( + "Using default Suricata version of %s", DEFAULT_SURICATA_VERSION) + suricata_version = engine.parse_version(DEFAULT_SURICATA_VERSION) + + # Provide the Suricata version to the net module to add to the + # User-Agent. + net.set_user_agent_suricata_version(suricata_version.full) + + if args.subcommand: + if args.subcommand == "check-versions" and hasattr(args, "func"): + return args.func(suricata_version) + elif hasattr(args, "func"): + return args.func() + elif args.subcommand != "update": + logger.error("Unknown command: {}".format(args.subcommand)) + return 1 + + if args.dump_sample_configs: + return dump_sample_configs() + + # If --no-ignore was provided, clear any ignores provided in the + # config. + if args.no_ignore: + config.set(config.IGNORE_KEY, []) + + file_tracker = FileTracker() + + disable_matchers = [] + enable_matchers = [] + modify_filters = [] + drop_filters = [] + + # Load user provided disable filters. + disable_conf_filename = config.get("disable-conf") + if disable_conf_filename: + if os.path.exists(disable_conf_filename): + logger.info("Loading %s.", disable_conf_filename) + disable_matchers += load_matchers(disable_conf_filename) + else: + logger.warn("disable-conf file does not exist: {}".format(disable_conf_filename)) + + # Load user provided enable filters. + enable_conf_filename = config.get("enable-conf") + if enable_conf_filename: + if os.path.exists(enable_conf_filename): + logger.info("Loading %s.", enable_conf_filename) + enable_matchers += load_matchers(enable_conf_filename) + else: + logger.warn("enable-conf file does not exist: {}".format(enable_conf_filename)) + + # Load user provided modify filters. + modify_conf_filename = config.get("modify-conf") + if modify_conf_filename: + if os.path.exists(modify_conf_filename): + logger.info("Loading %s.", modify_conf_filename) + modify_filters += load_filters(modify_conf_filename) + else: + logger.warn("modify-conf file does not exist: {}".format(modify_conf_filename)) + + # Load user provided drop filters. + drop_conf_filename = config.get("drop-conf") + if drop_conf_filename: + if os.path.exists(drop_conf_filename): + logger.info("Loading %s.", drop_conf_filename) + drop_filters += load_drop_filters(drop_conf_filename) + else: + logger.warn("drop-conf file does not exist: {}".format(drop_conf_filename)) + + # Load the Suricata configuration if we can. + suriconf = None + if config.get("suricata-conf") and \ + os.path.exists(config.get("suricata-conf")) and \ + suricata_path and os.path.exists(suricata_path): + logger.info("Loading %s",config.get("suricata-conf")) + try: + suriconf = engine.Configuration.load( + config.get("suricata-conf"), suricata_path=suricata_path) + except subprocess.CalledProcessError: + return 1 + + # Disable rule that are for app-layers that are not enabled. 
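+    # For example, a suricata.yaml containing:
+    #   app-layer.protocols.dnp3.enabled: no
+    # causes a ProtoRuleMatcher("dnp3") to be appended here, disabling
+    # all rules with a protocol of dnp3.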
+ if suriconf: + for key in suriconf.keys(): + m = re.match(r"app-layer\.protocols\.([^\.]+)\.enabled", key) + if m: + proto = m.group(1) + if not suriconf.is_true(key, ["detection-only"]): + logger.info("Disabling rules for protocol %s", proto) + disable_matchers.append(matchers_mod.ProtoRuleMatcher(proto)) + elif proto == "smb" and suriconf.build_info: + # Special case for SMB rules. For versions less + # than 5, disable smb rules if Rust is not + # available. + if suriconf.build_info["version"].major < 5: + if not "RUST" in suriconf.build_info["features"]: + logger.info("Disabling rules for protocol {}".format(proto)) + disable_matchers.append(matchers_mod.ProtoRuleMatcher(proto)) + + # Check that the cache directory exists and is writable. + if not os.path.exists(config.get_cache_dir()): + try: + os.makedirs(config.get_cache_dir(), mode=0o770) + except Exception as err: + logger.warning( + "Cache directory does not exist and could not be created. " + "/var/tmp will be used instead.") + config.set_cache_dir("/var/tmp") + + files = load_sources(suricata_version) + + load_dist_rules(files) + + rules = [] + classification_files = [] + dep_files = {} + for entry in sorted(files, key = lambda e: e.filename): + if "classification.config" in entry.filename: + classification_files.append((entry.filename, entry.content)) + continue + if not entry.filename.endswith(".rules"): + dep_files.update({entry.filename: entry.content}) + continue + if ignore_file(config.get("ignore"), entry.filename): + logger.info("Ignoring file {}".format(entry.filename)) + continue + logger.debug("Parsing {}".format(entry.filename)) + rules += rule_mod.parse_fileobj(io.BytesIO(entry.content), entry.filename) + + rulemap = build_rule_map(rules) + logger.info("Loaded %d rules." % (len(rules))) + + # Counts of user enabled and modified rules. + enable_count = 0 + modify_count = 0 + drop_count = 0 + + # List of rules disabled by user. Used for counting, and to log + # rules that are re-enabled to meet flowbit requirements. + disabled_rules = [] + + for key, rule in rulemap.items(): + + # To avoid duplicate counts when a rule has more than one modification + # to it, we track the actions here then update the counts at the end. + enabled = False + modified = False + dropped = False + + for matcher in disable_matchers: + if rule.enabled and matcher.match(rule): + logger.debug("Disabling: %s" % (rule.brief())) + rule.enabled = False + disabled_rules.append(rule) + + for matcher in enable_matchers: + if not rule.enabled and matcher.match(rule): + logger.debug("Enabling: %s" % (rule.brief())) + rule.enabled = True + enabled = True + + for fltr in drop_filters: + if fltr.match(rule): + rule = fltr.run(rule) + dropped = True + + for fltr in modify_filters: + if fltr.match(rule): + rule = fltr.run(rule) + modified = True + + if enabled: + enable_count += 1 + if modified: + modify_count += 1 + if dropped: + drop_count += 1 + + rulemap[key] = rule + + # Check if we should disable ja3 rules. + try: + disable_ja3(suriconf, rulemap, disabled_rules) + except Exception as err: + logger.error("Failed to dynamically disable ja3 rules: {}".format(err)) + + # Check rule vars, disabling rules that use unknown vars. + check_vars(suriconf, rulemap) + + logger.info("Disabled %d rules." % (len(disabled_rules))) + logger.info("Enabled %d rules." % (enable_count)) + logger.info("Modified %d rules." % (modify_count)) + logger.info("Dropped %d rules." % (drop_count)) + + # Fixup flowbits. 
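+    # A disabled rule that sets a flowbit checked by an enabled rule
+    # (e.g. flowbits:isset,ET.Example; the bit name is illustrative
+    # only) is re-enabled here so the checking rule can still fire.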
+ resolve_flowbits(rulemap, disabled_rules) + + # Check that output directory exists, creating it if needed. + check_output_directory(config.get_output_dir()) + + # Check that output directory is writable. + if not os.access(config.get_output_dir(), os.W_OK): + logger.error( + "Output directory is not writable: {}".format(config.get_output_dir())) + return 1 + + # Backup the output directory. + logger.info("Backing up current rules.") + backup_directory = util.mktempdir() + shutil.copytree(config.get_output_dir(), os.path.join( + backup_directory, "backup"), ignore=copytree_ignore_backup) + + if not args.no_merge: + # The default, write out a merged file. + output_filename = os.path.join( + config.get_output_dir(), DEFAULT_OUTPUT_RULE_FILENAME) + file_tracker.add(output_filename) + write_merged(os.path.join(output_filename), rulemap, dep_files) + else: + for file in files: + file_tracker.add( + os.path.join( + config.get_output_dir(), os.path.basename(file.filename))) + write_to_directory(config.get_output_dir(), files, rulemap, dep_files) + + manage_classification(suriconf, classification_files) + + if args.yaml_fragment: + file_tracker.add(args.yaml_fragment) + write_yaml_fragment(args.yaml_fragment, files) + + if args.sid_msg_map: + write_sid_msg_map(args.sid_msg_map, rulemap, version=1) + if args.sid_msg_map_2: + write_sid_msg_map(args.sid_msg_map_2, rulemap, version=2) + + if args.threshold_in and args.threshold_out: + file_tracker.add(args.threshold_out) + threshold_processor = ThresholdProcessor() + threshold_processor.process( + open(args.threshold_in), open(args.threshold_out, "w"), rulemap) + + if not args.force and not file_tracker.any_modified(): + logger.info("No changes detected, exiting.") + notes.dump_notes() + return 0 + + # Set these containers to None to fee the memory before testing Suricata which + # may consume a lot of memory by itself. Ideally we should refactor this large + # function into multiple methods so these go out of scope and get removed + # automatically. + rulemap = None + rules = None + files = None + + if not test_suricata(suricata_path): + logger.error("Suricata test failed, aborting.") + logger.error("Restoring previous rules.") + copytree( + os.path.join(backup_directory, "backup"), config.get_output_dir()) + return 1 + + if not config.args().no_reload and config.get("reload-command"): + logger.info("Running %s." % (config.get("reload-command"))) + rc = subprocess.Popen(config.get("reload-command"), shell=True).wait() + if rc != 0: + logger.error("Reload command exited with error: {}".format(rc)) + + logger.info("Done.") + + notes.dump_notes() + + return 0 + +def signal_handler(signal, frame): + print('Program interrupted. Aborting...') + sys.exit(1) + +def main(): + signal.signal(signal.SIGINT, signal_handler) + try: + sys.exit(_main()) + except exceptions.ApplicationError as err: + logger.error(err) + sys.exit(1) + +if __name__ == "__main__": + main() diff --git a/suricata/update/maps.py b/suricata/update/maps.py new file mode 100644 index 0000000..8a34f27 --- /dev/null +++ b/suricata/update/maps.py @@ -0,0 +1,215 @@ +# Copyright (C) 2017 Open Information Security Foundation +# Copyright (c) 2013 Jason Ish +# +# You can copy, redistribute or modify this Program under the terms of +# the GNU General Public License version 2 as published by the Free +# Software Foundation. 
+# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# version 2 along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. + +"""Provide mappings from ID's to descriptions. + +Includes mapping classes for event ID messages and classification +information. +""" + +from __future__ import print_function + +import re + +class SignatureMap(object): + """SignatureMap maps signature IDs to a signature info dict. + + The signature map can be build up from classification.config, + gen-msg.map, and new and old-style sid-msg.map files. + + The dict's in the map will have at a minimum the following + fields: + + * gid *(int)* + * sid *(int)* + * msg *(string)* + * refs *(list of strings)* + + Signatures loaded from a new style sid-msg.map file will also have + *rev*, *classification* and *priority* fields. + + Example:: + + >>> from idstools import maps + >>> sigmap = maps.SignatureMap() + >>> sigmap.load_generator_map(open("tests/gen-msg.map")) + >>> sigmap.load_signature_map(open("tests/sid-msg-v2.map")) + >>> print(sigmap.get(1, 2495)) + {'classification': 'misc-attack', 'rev': 8, 'priority': 0, 'gid': 1, + 'sid': 2495, + 'msg': 'GPL NETBIOS SMB DCEPRC ORPCThis request flood attempt', + 'ref': ['bugtraq,8811', 'cve,2003-0813', 'nessus,12206', + 'url,www.microsoft.com/technet/security/bulletin/MS04-011.mspx']} + + """ + + def __init__(self): + self.map = {} + + def size(self): + return len(self.map) + + def get(self, generator_id, signature_id): + """Get signature info by generator_id and signature_id. + + :param generator_id: The generator id of the signature to lookup. + :param signature_id: The signature id of the signature to lookup. + + For convenience, if the generator_id is 3 and the signature is + not found, a second lookup will be done using a generator_id + of 1. + + """ + + key = (generator_id, signature_id) + sig = self.map.get(key) + if sig is None and generator_id == 3: + return self.get(1, signature_id) + return sig + + def load_generator_map(self, fileobj): + """Load the generator message map (gen-msg.map) from a + file-like object. + + """ + for line in fileobj: + line = line.strip() + if not line or line.startswith("#"): + continue + gid, sid, msg = [part.strip() for part in line.split("||")] + entry = { + "gid": int(gid), + "sid": int(sid), + "msg": msg, + "refs": [], + } + self.map[(entry["gid"], entry["sid"])] = entry + + def load_signature_map(self, fileobj, defaultgid=1): + """Load signature message map (sid-msg.map) from a file-like + object. + + """ + + for line in fileobj: + line = line.strip() + if not line or line.startswith("#"): + continue + parts = [p.strip() for p in line.split("||")] + + # If we have at least 6 parts, attempt to parse as a v2 + # signature map file. 
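+            # For example (values illustrative), a v2 line looks like:
+            #   1 || 2019401 || 3 || attempted-dos || 2 || ET DOS ... || url,...
+            # while an old style v1 line is just:
+            #   2019401 || ET DOS ... || url,...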
+ try: + entry = { + "gid": int(parts[0]), + "sid": int(parts[1]), + "rev": int(parts[2]), + "classification": parts[3], + "priority": int(parts[4]), + "msg": parts[5], + "ref": parts[6:], + } + except: + entry = { + "gid": defaultgid, + "sid": int(parts[0]), + "msg": parts[1], + "ref": parts[2:], + } + self.map[(entry["gid"], entry["sid"])] = entry + +class ClassificationMap(object): + """ClassificationMap maps classification IDs and names to a dict + object describing a classification. + + :param fileobj: (Optional) A file like object to load + classifications from on initialization. + + The classification dicts stored in the map have the following + fields: + + * name *(string)* + * description *(string)* + * priority *(int)* + + Example:: + + >>> from idstools import maps + >>> classmap = maps.ClassificationMap() + >>> classmap.load_from_file(open("tests/classification.config")) + + >>> classmap.get(3) + {'priority': 2, 'name': 'bad-unknown', 'description': 'Potentially Bad Traffic'} + >>> classmap.get_by_name("bad-unknown") + {'priority': 2, 'name': 'bad-unknown', 'description': 'Potentially Bad Traffic'} + + """ + + def __init__(self, fileobj=None): + self.id_map = [] + self.name_map = {} + + if fileobj: + self.load_from_file(fileobj) + + def size(self): + return len(self.id_map) + + def add(self, classification): + """Add a classification to the map.""" + self.id_map.append(classification) + self.name_map[classification["name"]] = classification + + def get(self, class_id): + """Get a classification by ID. + + :param class_id: The classification ID to get. + + :returns: A dict describing the classification or None. + + """ + if 0 < class_id <= len(self.id_map): + return self.id_map[class_id - 1] + else: + return None + + def get_by_name(self, name): + """Get a classification by name. + + :param name: The name of the classification + + :returns: A dict describing the classification or None. + + """ + if name in self.name_map: + return self.name_map[name] + else: + return None + + def load_from_file(self, fileobj): + """Load classifications from a Snort style + classification.config file object. + + """ + pattern = "config classification: ([^,]+),([^,]+),([^,]+)" + for line in fileobj: + m = re.match(pattern, line.strip()) + if m: + self.add({ + "name": m.group(1), + "description": m.group(2), + "priority": int(m.group(3))}) diff --git a/suricata/update/matchers.py b/suricata/update/matchers.py new file mode 100644 index 0000000..56a9e29 --- /dev/null +++ b/suricata/update/matchers.py @@ -0,0 +1,331 @@ +# Copyright (C) 2017 Open Information Security Foundation +# +# You can copy, redistribute or modify this Program under the terms of +# the GNU General Public License version 2 as published by the Free +# Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# version 2 along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. + +# This module contains functions for matching rules for disabling, +# enabling, converting to drop or modification. 
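+#
+# The match strings understood by parse_rule_match() at the bottom of
+# this module include, for example:
+#
+#   *                     - match all rules
+#   2019401 or 1:2019401  - match by [gid:]sid
+#   group:emerging-dos    - match by rule file (group) name
+#   re:heartbleed         - match by regular expression
+#   filename:*deleted*    - match by source filename pattern
+#   metadata: key value   - match by a metadata key/value pair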
+ +import re +import os.path +import logging +import shlex +import fnmatch +import suricata.update.rule + + +logger = logging.getLogger() + + +class AllRuleMatcher(object): + """Matcher object to match all rules. """ + + def match(self, rule): + return True + + @classmethod + def parse(cls, buf): + if buf.strip() == "*": + return cls() + return None + + +class ProtoRuleMatcher: + """A rule matcher that matches on the protocol of a rule.""" + + def __init__(self, proto): + self.proto = proto + + def match(self, rule): + return rule.proto == self.proto + + +class IdRuleMatcher(object): + """Matcher object to match an idstools rule object by its signature + ID.""" + + def __init__(self, generatorId=None, signatureId=None): + self.signatureIds = [] + if generatorId and signatureId: + self.signatureIds.append((generatorId, signatureId)) + + def match(self, rule): + for (generatorId, signatureId) in self.signatureIds: + if generatorId == rule.gid and signatureId == rule.sid: + return True + return False + + @classmethod + def parse(cls, buf): + matcher = cls() + + for entry in buf.split(","): + entry = entry.strip() + + parts = entry.split(":", 1) + if not parts: + return None + if len(parts) == 1: + try: + signatureId = int(parts[0]) + matcher.signatureIds.append((1, signatureId)) + except: + return None + else: + try: + generatorId = int(parts[0]) + signatureId = int(parts[1]) + matcher.signatureIds.append((generatorId, signatureId)) + except: + return None + + return matcher + + +class FilenameMatcher(object): + """Matcher object to match a rule by its filename. This is similar to + a group but has no specifier prefix. + """ + + def __init__(self, pattern): + self.pattern = pattern + + def match(self, rule): + if hasattr(rule, "group") and rule.group is not None: + return fnmatch.fnmatch(rule.group, self.pattern) + return False + + @classmethod + def parse(cls, buf): + if buf.startswith("filename:"): + try: + group = buf.split(":", 1)[1] + return cls(group.strip()) + except: + pass + return None + + +class GroupMatcher(object): + """Matcher object to match an idstools rule object by its group (ie: + filename). + + The group is just the basename of the rule file with or without + extension. + + Examples: + - emerging-shellcode + - emerging-trojan.rules + + """ + + def __init__(self, pattern): + self.pattern = pattern + + def match(self, rule): + if hasattr(rule, "group") and rule.group is not None: + if fnmatch.fnmatch(os.path.basename(rule.group), self.pattern): + return True + # Try matching against the rule group without the file + # extension. 
if fnmatch.fnmatch(
+                    os.path.splitext(
+                        os.path.basename(rule.group))[0], self.pattern):
+                return True
+        return False
+
+    @classmethod
+    def parse(cls, buf):
+        if buf.startswith("group:"):
+            try:
+                logger.debug("Parsing group matcher: %s" % (buf))
+                group = buf.split(":", 1)[1]
+                return cls(group.strip())
+            except:
+                pass
+        if buf.endswith(".rules"):
+            return cls(buf.strip())
+        return None
+
+
+class ReRuleMatcher(object):
+    """Matcher object to match an idstools rule object by regular
+    expression."""
+
+    def __init__(self, pattern):
+        self.pattern = pattern
+
+    def match(self, rule):
+        if self.pattern.search(rule.raw):
+            return True
+        return False
+
+    @classmethod
+    def parse(cls, buf):
+        if buf.startswith("re:"):
+            try:
+                logger.debug("Parsing regex matcher: %s" % (buf))
+                patternstr = buf.split(":", 1)[1].strip()
+                pattern = re.compile(patternstr, re.I)
+                return cls(pattern)
+            except:
+                pass
+        return None
+
+
+class MetadataRuleMatch(object):
+    """ Matcher that matches on key/value style metadata fields. Case insensitive. """
+
+    def __init__(self, key, value):
+        self.key = key
+        self.value = value
+
+    def match(self, rule):
+        for entry in rule.metadata:
+            parts = entry.strip().split(" ", 1)
+            # Entries without a value part cannot match a key/value pair.
+            if len(parts) < 2:
+                continue
+            if parts[0].strip().lower() == self.key and parts[1].strip().lower() == self.value:
+                return True
+        return False
+
+    @classmethod
+    def parse(cls, buf):
+        if buf.startswith("metadata:"):
+            buf = buf.split(":", 1)[1].strip()
+            parts = buf.split(" ", 1)
+            if len(parts) == 2:
+                key = parts[0].strip().lower()
+                val = parts[1].strip().lower()
+                return cls(key, val)
+        return None
+
+
+class ModifyRuleFilter(object):
+    """Filter to modify an idstools rule object.
+
+    Important note: This filter does not modify the rule in place, but
+    instead returns a new rule object with the modification.
+    """
+
+    def __init__(self, matcher, pattern, repl):
+        self.matcher = matcher
+        self.pattern = pattern
+        self.repl = repl
+
+    def match(self, rule):
+        return self.matcher.match(rule)
+
+    def run(self, rule):
+        modified_rule = self.pattern.sub(self.repl, rule.format())
+        parsed = suricata.update.rule.parse(modified_rule, rule.group)
+        if parsed is None:
+            logger.error("Modification of rule %s results in invalid rule: %s",
+                         rule.idstr, modified_rule)
+            return rule
+        return parsed
+
+    @classmethod
+    def parse(cls, buf):
+        tokens = shlex.split(buf)
+        if len(tokens) == 3:
+            matchstring, a, b = tokens
+        elif len(tokens) > 3 and tokens[0] == "modifysid":
+            # Oinkmaster style: modifysid <match> "<from>" | "<to>";
+            # tokens[3] is the "|" separator.
+            matchstring, a, b = tokens[1], tokens[2], tokens[4]
+        else:
+            raise Exception("Bad number of arguments.")
+        matcher = parse_rule_match(matchstring)
+        if not matcher:
+            raise Exception("Bad match string: %s" % (matchstring))
+        pattern = re.compile(a)
+
+        # Convert Oinkmaster style ${n} back-references to Python \n
+        # back-references.
+        b = re.sub(r"\$\{(\d+)\}", "\\\\\\1", b)
+
+        return cls(matcher, pattern, b)
+
+
+class DropRuleFilter(object):
+    """ Filter to modify an idstools rule object to a drop rule. 
""" + + def __init__(self, matcher): + self.matcher = matcher + + def match(self, rule): + if rule["noalert"]: + return False + return self.matcher.match(rule) + + def run(self, rule): + drop_rule = suricata.update.rule.parse(re.sub( + r"^\w+", "drop", rule.raw)) + drop_rule.enabled = rule.enabled + return drop_rule + +class AddMetadataFilter(object): + + def __init__(self, matcher, key, val): + self.matcher = matcher + self.key = key + self.val = val + + def match(self, rule): + return self.matcher.match(rule) + + def run(self, rule): + new_rule_string = re.sub(r";\s*\)$", "; metadata: {} {};)".format(self.key, self.val), rule.format()) + new_rule = suricata.update.rule.parse(new_rule_string, rule.group) + if not new_rule: + logger.error("Rule is not valid after adding metadata: [{}]: {}".format(rule.idstr, new_rule_string)) + return rule + return new_rule + + @classmethod + def parse(cls, buf): + try: + command, match_string, key, val = shlex.split(buf) + except: + raise Exception("metadata-add: invalid number of arguments") + matcher = parse_rule_match(match_string) + if not matcher: + raise Exception("Bad match string: %s" % (matchstring)) + return cls(matcher, key, val) + + +def parse_rule_match(match): + matcher = AllRuleMatcher.parse(match) + if matcher: + return matcher + + matcher = IdRuleMatcher.parse(match) + if matcher: + return matcher + + matcher = ReRuleMatcher.parse(match) + if matcher: + return matcher + + matcher = FilenameMatcher.parse(match) + if matcher: + return matcher + + matcher = GroupMatcher.parse(match) + if matcher: + return matcher + + matcher = MetadataRuleMatch.parse(match) + if matcher: + return matcher + + return None diff --git a/suricata/update/net.py b/suricata/update/net.py new file mode 100644 index 0000000..eac060e --- /dev/null +++ b/suricata/update/net.py @@ -0,0 +1,175 @@ +# Copyright (C) 2017 Open Information Security Foundation +# Copyright (c) 2013 Jason Ish +# +# You can copy, redistribute or modify this Program under the terms of +# the GNU General Public License version 2 as published by the Free +# Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# version 2 along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. + +""" Module for network related operations. """ + +import platform +import logging +import ssl +import re + +try: + # Python 3.3... + from urllib.request import urlopen, build_opener + from urllib.error import HTTPError + from urllib.request import HTTPSHandler +except ImportError: + # Python 2.6, 2.7. + from urllib2 import urlopen, build_opener + from urllib2 import HTTPError + from urllib2 import HTTPSHandler + +from suricata.update.version import version +from suricata.update import config +from suricata.update import osinfo + +logger = logging.getLogger() + +# Number of bytes to read at a time in a GET request. 
+GET_BLOCK_SIZE = 8192 + +user_agent_suricata_verison = "Unknown" +custom_user_agent = None + +def set_custom_user_agent(ua): + global custom_user_agent + custom_user_agent = ua + +def set_user_agent_suricata_version(version): + global user_agent_suricata_verison + user_agent_suricata_verison = version + +def build_user_agent(): + params = [] + has_custom_user_agent = config.has("user-agent") + if has_custom_user_agent: + user_agent = config.get("user-agent") + if user_agent is None or len(user_agent.strip()) == 0: + logger.debug("Suppressing HTTP User-Agent header") + return None + return user_agent + + params = [] + try: + params.append("OS: {}".format(platform.system())) + except Exception as err: + logger.error("Failed to set user-agent OS: {}".format(str(err))) + try: + params.append("CPU: {}".format(osinfo.arch())) + except Exception as err: + logger.error("Failed to set user-agent architecture: {}".format(str(err))) + try: + params.append("Python: {}".format(platform.python_version())) + except Exception as err: + logger.error("Failed to set user-agent python version: {}".format(str(err))) + try: + params.append("Dist: {}".format(osinfo.dist())) + except Exception as err: + logger.error("Failed to set user-agent distribution: {}".format(str(err))) + + params.append("Suricata: %s" % (user_agent_suricata_verison)) + + return "Suricata-Update/%s (%s)" % ( + version, "; ".join(params)) + + +def is_header_clean(header): + if len(header) != 2: + return False + name, val = header[0].strip(), header[1].strip() + if re.match( r"^[\w-]+$", name) and re.match(r"^[\w\s -~]+$", val): + return True + return False + + +def get(url, fileobj, progress_hook=None): + """ Perform a GET request against a URL writing the contents into + the provided file-like object. + + :param url: The URL to fetch + :param fileobj: The fileobj to write the content to + :param progress_hook: The function to call with progress updates + + :returns: Returns a tuple containing the number of bytes read and + the result of the info() function from urllib2.urlopen(). + + :raises: Exceptions from urllib2.urlopen() and writing to the + provided fileobj may occur. + """ + + user_agent = build_user_agent() + + try: + # Wrap in a try as Python versions prior to 2.7.9 don't have + # create_default_context, but some distros have backported it. 
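+        # On interpreters without create_default_context() the
+        # AttributeError lands in the except below and a default opener,
+        # without an explicit SSL context, is used instead.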
+ ssl_context = ssl.create_default_context() + if config.get("no-check-certificate"): + logger.debug("Disabling SSL/TLS certificate verification.") + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + opener = build_opener(HTTPSHandler(context=ssl_context)) + except: + opener = build_opener() + + if user_agent: + logger.debug("Setting HTTP User-Agent to %s", user_agent) + http_headers = [("User-Agent", user_agent)] + else: + http_headers = [(header, value) for header, + value in opener.addheaders if header.lower() != "user-agent"] + if isinstance(url, tuple): + header = url[1].split(":") if url[1] is not None else None + if header and is_header_clean(header=header): + name, val = header[0].strip(), header[1].strip() + logger.debug("Setting HTTP header %s to %s", name, val) + http_headers.append((name, val)) + elif header: + logger.error("Header not set as it does not meet the criteria") + url = url[0] + opener.addheaders = http_headers + + try: + remote = opener.open(url, timeout=30) + except ValueError as ve: + logger.error(ve) + else: + info = remote.info() + content_length = info.get("content-length") + content_length = int(content_length) if content_length else 0 + bytes_read = 0 + while True: + buf = remote.read(GET_BLOCK_SIZE) + if not buf: + # EOF + break + bytes_read += len(buf) + fileobj.write(buf) + if progress_hook: + progress_hook(content_length, bytes_read) + remote.close() + fileobj.flush() + return bytes_read, info + + +if __name__ == "__main__": + + import sys + + try: + get(sys.argv[1], sys.stdout) + except Exception as err: + print("ERROR: %s" % (err)) diff --git a/suricata/update/notes.py b/suricata/update/notes.py new file mode 100644 index 0000000..6288781 --- /dev/null +++ b/suricata/update/notes.py @@ -0,0 +1,60 @@ +# Copyright (C) 2018 Open Information Security Foundation +# +# You can copy, redistribute or modify this Program under the terms of +# the GNU General Public License version 2 as published by the Free +# Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# version 2 along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. + +from __future__ import print_function + +import textwrap + +# Address group notes. +address_group_vars = set() + +# Port group notes. +port_group_vars = set() + +# Template for missing address-group variable. +missing_address_group_var_template = """ +A rule has been disabled due to the unknown address-group variable +%(var)s being used. You may want to add this variable to your Suricata +configuration file. +""" + +# Template for missing port-group variable. +missing_port_group_var_template = """ +A rule has been disabled due to the unknown port-group variable +%(var)s being used. You may want to add this variable to your Suricata +configuration file. 
+""" + +def render_note(note): + lines = textwrap.wrap(note.strip().replace("\n", " ")) + print("* %s" % (lines[0])) + for line in lines[1:]: + print(" %s" % (line)) + +def dump_notes(): + notes = [] + + for var in address_group_vars: + notes.append(missing_address_group_var_template % {"var": var}) + + for var in port_group_vars: + notes.append(missing_port_group_var_template % {"var": var}) + + if notes: + print("\nNotes:\n") + for note in notes: + render_note(note) + print("") diff --git a/suricata/update/osinfo.py b/suricata/update/osinfo.py new file mode 100644 index 0000000..c3e417b --- /dev/null +++ b/suricata/update/osinfo.py @@ -0,0 +1,75 @@ +# Copyright (C) 2020 Open Information Security Foundation +# +# You can copy, redistribute or modify this Program under the terms of +# the GNU General Public License version 2 as published by the Free +# Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# version 2 along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. + +import re +import os.path +import platform + +def parse_os_release(filename="/etc/os-release"): + os_release={} + + if not os.path.exists(filename): + return os_release + + with open(filename) as fileobj: + for line in fileobj: + line = line.strip() + m = re.match(r"^(\w+)=\"?(.*?)\"?$", line) + if m: + os_release[m.group(1)] = m.group(2) + return os_release + +def dist(): + os_release = parse_os_release() + if "NAME" in os_release: + version_fields = ["VERSION_ID", "BUILD_ID"] + for vf in version_fields: + if vf in os_release: + return "{}/{}".format(os_release["NAME"], os_release[vf]) + return os_release["NAME"] + + # Arch may or may not have /etc/os-release, but its easy to + # detect. + if os.path.exists("/etc/arch-release"): + return "Arch Linux" + + # Uname fallback. + uname = platform.uname() + return "{}/{}".format(uname[0], uname[2]) + +normalized_arch = { + "amd64": "x86_64", +} + +def arch(): + """Return the machine architecture. """ + machine = platform.machine() + return normalized_arch.get(machine, machine) + +if __name__ == "__main__": + # Build a user agent string. Something like: + # Suricata-Update/1.2.0dev0 (OS: Linux; \ + # CPU: x86_64; \ + # Python: 3.7.7; \ + # Dist: Fedora/31; \ + # Suricata: 4.0.0) + parts = [] + parts.append("OS: {}".format(platform.system())) + parts.append("CPU: {}".format(arch())) + parts.append("Python: {}".format(platform.python_version())) + parts.append("Dist: {}".format(dist())) + + print("Suricata-Update/1.2.0dev0 ({})".format("; ".join(parts))) diff --git a/suricata/update/parsers.py b/suricata/update/parsers.py new file mode 100644 index 0000000..185205c --- /dev/null +++ b/suricata/update/parsers.py @@ -0,0 +1,268 @@ +# Copyright (C) 2017 Open Information Security Foundation +# Copyright (c) 2015-2017 Jason Ish +# +# You can copy, redistribute or modify this Program under the terms of +# the GNU General Public License version 2 as published by the Free +# Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# version 2 along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. + +# This module contains functions for command line parsers for +# suricata-update + +import argparse +import sys +from suricata.update import commands, config + +from suricata.update.version import version + +try: + from suricata.update.revision import revision +except: + revision = None + +default_update_yaml = config.DEFAULT_UPDATE_YAML_PATH + +show_advanced = False + +if "-s" in sys.argv or "--show-advanced" in sys.argv: + show_advanced = True + +# Global arguments - command line options for suricata-update +global_arg = [ + (("-v", "--verbose"), + {'action': 'store_true', 'default': None, + 'help': "Be more verbose"}), + (("-q", "--quiet"), + {'action': 'store_true', 'default': None, + 'help': "Be quiet, warning and error messages only"}), + (("-D", "--data-dir"), + {'metavar': '<directory>', 'dest': 'data_dir', + 'help': "Data directory (default: /var/lib/suricata)"}), + (("-c", "--config"), + {'metavar': '<filename>', + 'help': "configuration file (default: %s)" % (default_update_yaml)}), + (("--suricata-conf",), + {'metavar': '<filename>', + 'help': "configuration file (default: /etc/suricata/suricata.yaml)"}), + (("--suricata",), + {'metavar': '<path>', + 'help': "Path to Suricata program"}), + (("--suricata-version",), + {'metavar': '<version>', + 'help': "Override Suricata version"}), + (("--user-agent",), + {'metavar': '<user-agent>', + 'help': "Set custom user-agent string" + if show_advanced else argparse.SUPPRESS}), + (("--no-check-certificate",), + {'action': 'store_true', 'default': None, + 'help': "Disable server SSL/TLS certificate verification" + if show_advanced else argparse.SUPPRESS}), + (("-V", "--version"), + {'action': 'store_true', 'default': False, + 'help': "Display version"}), + (("-s","--show-advanced"), + {'action': 'store_true', + 'help': "Show advanced options"}), +] + +# Update arguments - command line options for suricata-update +update_arg = [ + (("-o", "--output"), + {'metavar': '<directory>', 'dest': 'output', + 'help': "Directory to write rules to"}), + (("-f", "--force"), + {'action': 'store_true', 'default': False, + 'help': "Force operations that might otherwise be skipped"}), + (("--yaml-fragment",), + {'metavar': '<filename>', + 'help': "Output YAML fragment for rule inclusion" + if show_advanced else argparse.SUPPRESS}), + (("--url",), + {'metavar': '<url>', 'action': 'append', 'default': [], + 'help': "URL to use instead of auto-generating one " + "(can be specified multiple times)" + if show_advanced else argparse.SUPPRESS}), + (("--local",), + {'metavar': '<path>', 'action': 'append', 'default': [], + 'help': "Local rule files or directories " + "(can be specified multiple times)" + if show_advanced else argparse.SUPPRESS}), + (("--sid-msg-map",), + {'metavar': '<filename>', + 'help': "Generate a sid-msg.map file" + if show_advanced else argparse.SUPPRESS}), + (("--sid-msg-map-2",), + {'metavar': '<filename>', + 'help': "Generate a v2 sid-msg.map file" + if show_advanced else argparse.SUPPRESS}), + + (("--disable-conf",), + {'metavar': '<filename>', + 'help': "Filename of rule disable filters"}), + (("--enable-conf",), + {'metavar': '<filename>', + 'help': "Filename of rule enable filters"}), + (("--modify-conf",), + {'metavar': '<filename>', + 'help': 
"Filename of rule modification filters"}), + (("--drop-conf",), + {'metavar': '<filename>', + 'help': "Filename of drop rule filters"}), + + (("--ignore",), + {'metavar': '<pattern>', 'action': 'append', 'default': None, + 'help': "Filenames to ignore " + "(can be specified multiple times; default: *deleted.rules)" + if show_advanced else argparse.SUPPRESS}), + (("--no-ignore",), + {'action': 'store_true', 'default': False, + 'help': "Disables the ignore option." + if show_advanced else argparse.SUPPRESS}), + (("--threshold-in",), + {'metavar': '<filename>', + 'help': "Filename of rule thresholding configuration" + if show_advanced else argparse.SUPPRESS}), + (("--threshold-out",), + {'metavar': '<filename>', + 'help': "Output of processed threshold configuration" + if show_advanced else argparse.SUPPRESS}), + (("--dump-sample-configs",), + {'action': 'store_true', 'default': False, + 'help': "Dump sample config files to current directory" + if show_advanced else argparse.SUPPRESS}), + (("--etopen",), + {'action': 'store_true', + 'help': "Use ET-Open rules (default)" + if show_advanced else argparse.SUPPRESS}), + (("--reload-command",), + {'metavar': '<command>', + 'help': "Command to run after update if modified" + if show_advanced else argparse.SUPPRESS}), + (("--no-reload",), + {'action': 'store_true', 'default': False, + 'help': "Disable reload"}), + (("-T", "--test-command"), + {'metavar': '<command>', + 'help': "Command to test Suricata configuration" + if show_advanced else argparse.SUPPRESS}), + (("--no-test",), + {'action': 'store_true', 'default': None, + 'help': "Disable testing rules with Suricata"}), + (("--no-merge",), + {'action': 'store_true', 'default': False, + 'help': "Do not merge the rules into a single file" + if show_advanced else argparse.SUPPRESS}), + (("--offline",), + {'action': 'store_true', + 'help': "Run offline using most recent cached rules"}), + (("--fail",), + {'action': 'store_true', + 'help': "Strictly fail and exit in case of an error"}), + + # Hidden argument, --now to bypass the timebased bypass of + # updating a ruleset. + (("--now",), + {'default': False, 'action': 'store_true', 'help': argparse.SUPPRESS}), + + # The Python 2.7 argparse module does prefix matching which can be + # undesirable. Reserve some names here that would match existing + # options to prevent prefix matching. + (("--disable",), + {'default': False, 'help': argparse.SUPPRESS}), + (("--enable",), + {'default': False, 'help': argparse.SUPPRESS}), + (("--modify",), + {'default': False, 'help': argparse.SUPPRESS}), + (("--drop",), + {'default': False, 'help': argparse.SUPPRESS}) +] + + +def parse_global(): + global_parser = argparse.ArgumentParser(add_help=False) + + for arg, opts in global_arg: + global_parser.add_argument(*arg, **opts) + + return global_parser + + +def parse_update(subparsers, global_parser): + # The "update" (default) sub-command parser. 
+ update_parser = subparsers.add_parser( + "update", add_help=True, parents=[global_parser], + formatter_class=argparse.RawDescriptionHelpFormatter) + + for arg, opts in update_arg: + update_parser.add_argument(*arg, **opts) + + return update_parser + + +def parse_commands(subparsers, global_parser): + commands.listsources.register(subparsers.add_parser( + "list-sources", parents=[global_parser])) + commands.listsources.register(subparsers.add_parser( + "list-enabled-sources", parents=[global_parser])) + commands.addsource.register(subparsers.add_parser( + "add-source", parents=[global_parser])) + commands.updatesources.register(subparsers.add_parser( + "update-sources", parents=[global_parser])) + commands.enablesource.register(subparsers.add_parser( + "enable-source", parents=[global_parser])) + commands.disablesource.register(subparsers.add_parser( + "disable-source", parents=[global_parser])) + commands.removesource.register(subparsers.add_parser( + "remove-source", parents=[global_parser])) + commands.checkversions.register(subparsers.add_parser( + "check-versions", parents=[global_parser])) + + +def parse_arg(): + global_parser = parse_global() + global_args, rem = global_parser.parse_known_args() + + if global_args.version: + revision_string = " (rev: %s)" % (revision) if revision else "" + print("suricata-update version {}{}".format(version, revision_string)) + sys.exit(0) + + if not rem or rem[0].startswith("-"): + rem.insert(0, "update") + + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(dest="subcommand", metavar="<command>") + update_parser = parse_update(subparsers, global_parser) + + update_parser.epilog = r"""other commands: + update-sources Update the source index + list-sources List available sources + enable-source Enable a source from the index + disable-source Disable an enabled source + remove-source Remove an enabled or disabled source + add-source Add a new source by URL + check-versions Check version of suricata-update +""" + + parse_commands(subparsers, global_parser) + + args = parser.parse_args(rem) + + # Merge global args into args. + for arg in vars(global_args): + if not hasattr(args, arg): + setattr(args, arg, getattr(global_args, arg)) + elif hasattr(args, arg) and getattr(args, arg) is None: + setattr(args, arg, getattr(global_args, arg)) + + return args diff --git a/suricata/update/rule.py b/suricata/update/rule.py new file mode 100644 index 0000000..169af6c --- /dev/null +++ b/suricata/update/rule.py @@ -0,0 +1,439 @@ +# Copyright (C) 2017-2019 Open Information Security Foundation +# Copyright (c) 2011 Jason Ish +# +# You can copy, redistribute or modify this Program under the terms of +# the GNU General Public License version 2 as published by the Free +# Software Foundation. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# version 2 along with this program; if not, write to the Free Software +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA +# 02110-1301, USA. + +""" Module for parsing Snort-like rules. + +Parsing is done using regular expressions and the job of this module +is to do its best at parsing out fields of interest from the rule +rather than perform a sanity check. 
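+
+For example, parsing the line:
+
+    alert tcp any any -> any any (msg:"TEST"; sid:1; rev:1;)
+
+yields a Rule with action "alert", proto "tcp", sid 1 and msg "TEST".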
+ +The methods that parse multiple rules for a provided input +(parse_file, parse_fileobj) return a list of rules instead of dict +keyed by ID as its not the job of this module to detect or deal with +duplicate signature IDs. +""" + +from __future__ import print_function + +import sys +import re +import logging +import io + +logger = logging.getLogger(__name__) + +# Compile an re pattern for basic rule matching. +rule_pattern = re.compile(r"^(?P<enabled>#)*[\s#]*" + r"(?P<raw>" + r"(?P<header>[^()]+)" + r"\((?P<options>.*)\)" + r"$)") + +# Rule actions we expect to see. +actions = ( + "alert", "log", "pass", "activate", "dynamic", "drop", "reject", "sdrop") + +class NoEndOfOptionError(Exception): + """Exception raised when the end of option terminator (semicolon) is + missing.""" + pass + +class Rule(dict): + """Class representing a rule. + + The Rule class is a class that also acts like a dictionary. + + Dictionary fields: + + - **group**: The group the rule belongs to, typically the filename. + - **enabled**: True if rule is enabled (uncommented), False is + disabled (commented) + - **action**: The action of the rule (alert, pass, etc) as a + string + - **proto**: The protocol of the rule. + - **direction**: The direction string of the rule. + - **gid**: The gid of the rule as an integer + - **sid**: The sid of the rule as an integer + - **rev**: The revision of the rule as an integer + - **msg**: The rule message as a string + - **flowbits**: List of flowbit options in the rule + - **metadata**: Metadata values as a list + - **references**: References as a list + - **classtype**: The classification type + - **priority**: The rule priority, 0 if not provided + - **noalert**: Is the rule a noalert rule + - **features**: Features required by this rule + - **raw**: The raw rule as read from the file or buffer + + :param enabled: Optional parameter to set the enabled state of the rule + :param action: Optional parameter to set the action of the rule + :param group: Optional parameter to set the group (filename) of the rule + + """ + + def __init__(self, enabled=None, action=None, group=None): + dict.__init__(self) + self["enabled"] = enabled + self["action"] = action + self["proto"] = None + self["source_addr"] = None + self["source_port"] = None + self["direction"] = None + self["dest_addr"] = None + self["dest_port"] = None + self["group"] = group + self["gid"] = 1 + self["sid"] = None + self["rev"] = 0 + self["msg"] = None + self["flowbits"] = [] + self["metadata"] = [] + self["references"] = [] + self["classtype"] = None + self["priority"] = 0 + self["noalert"] = False + + self["features"] = [] + + self["raw"] = None + + def __getattr__(self, name): + return self[name] + + @property + def id(self): + """ The ID of the rule. + + :returns: A tuple (gid, sid) representing the ID of the rule + :rtype: A tuple of 2 ints + """ + return (int(self.gid), int(self.sid)) + + @property + def idstr(self): + """Return the gid and sid of the rule as a string formatted like: + '[GID:SID]'""" + return "[%s:%s]" % (str(self.gid), str(self.sid)) + + def brief(self): + """ A brief description of the rule. + + :returns: A brief description of the rule + :rtype: string + """ + return "%s[%d:%d] %s" % ( + "" if self.enabled else "# ", self.gid, self.sid, self.msg) + + def __hash__(self): + return self["raw"].__hash__() + + def __str__(self): + """ The string representation of the rule. + + If the rule is disabled it will be returned as commented out. 
+ """ + return self.format() + + def format(self): + if self.noalert and not "noalert;" in self.raw: + self.raw = re.sub(r'( *sid\: *[0-9]+\;)', r' noalert;\1', self.raw) + return u"%s%s" % (u"" if self.enabled else u"# ", self.raw) + +def find_opt_end(options): + """ Find the end of an option (;) handling escapes. """ + offset = 0 + + while True: + i = options[offset:].find(";") + if options[offset + i - 1] == "\\": + offset += 2 + else: + return offset + i + +class BadSidError(Exception): + """Raises exception when sid is of type null""" + +def parse(buf, group=None): + """ Parse a single rule for a string buffer. + + :param buf: A string buffer containing a single Snort-like rule + + :returns: An instance of of :py:class:`.Rule` representing the parsed rule + """ + + if type(buf) == type(b""): + buf = buf.decode("utf-8") + buf = buf.strip() + + m = rule_pattern.match(buf) + if not m: + return None + + if m.group("enabled") == "#": + enabled = False + else: + enabled = True + + header = m.group("header").strip() + + rule = Rule(enabled=enabled, group=group) + + # If a decoder rule, the header will be one word. + if len(header.split(" ")) == 1: + action = header + direction = None + else: + states = ["action", + "proto", + "source_addr", + "source_port", + "direction", + "dest_addr", + "dest_port", + ] + state = 0 + + rem = header + while state < len(states): + if not rem: + return None + if rem[0] == "[": + end = rem.find("]") + if end < 0: + return + end += 1 + token = rem[:end].strip() + rem = rem[end:].strip() + else: + end = rem.find(" ") + if end < 0: + token = rem + rem = "" + else: + token = rem[:end].strip() + rem = rem[end:].strip() + + if states[state] == "action": + action = token + elif states[state] == "proto": + rule["proto"] = token + elif states[state] == "source_addr": + rule["source_addr"] = token + elif states[state] == "source_port": + rule["source_port"] = token + elif states[state] == "direction": + direction = token + elif states[state] == "dest_addr": + rule["dest_addr"] = token + elif states[state] == "dest_port": + rule["dest_port"] = token + + state += 1 + + if action not in actions: + return None + + rule["action"] = action + rule["direction"] = direction + rule["header"] = header + + options = m.group("options") + + while True: + if not options: + break + index = find_opt_end(options) + if index < 0: + raise NoEndOfOptionError("no end of option") + option = options[:index].strip() + options = options[index + 1:].strip() + + if option.find(":") > -1: + name, val = [x.strip() for x in option.split(":", 1)] + else: + name = option + val = None + + if name in ["gid", "sid", "rev"]: + rule[name] = int(val) + elif name == "metadata": + if not name in rule: + rule[name] = [] + rule[name] += [v.strip() for v in val.split(",")] + elif name == "flowbits": + rule.flowbits.append(val) + if val and val.find("noalert") > -1: + rule["noalert"] = True + elif name == "noalert": + rule["noalert"] = True + elif name == "reference": + rule.references.append(val) + elif name == "msg": + if val and val.startswith('"') and val.endswith('"'): + val = val[1:-1] + rule[name] = val + else: + rule[name] = val + + if name.startswith("ja3"): + rule["features"].append("ja3") + + if rule["msg"] is None: + rule["msg"] = "" + + if not rule["sid"]: + raise BadSidError("Sid cannot be of type null") + + rule["raw"] = m.group("raw").strip() + + return rule + +def parse_fileobj(fileobj, group=None): + """ Parse multiple rules from a file like object. 
+ + Note: At this time rules must exist on one line. + + :param fileobj: A file like object to parse rules from. + + :returns: A list of :py:class:`.Rule` instances, one for each rule parsed + """ + rules = [] + buf = "" + for line in fileobj: + try: + if type(line) == type(b""): + line = line.decode() + except: + pass + if line.rstrip().endswith("\\"): + buf = "%s%s " % (buf, line.rstrip()[0:-1]) + continue + buf = buf + line + try: + rule = parse(buf, group) + if rule: + rules.append(rule) + except Exception as err: + logger.error("Failed to parse rule: %s: %s", buf.rstrip(), err) + buf = "" + return rules + +def parse_file(filename, group=None): + """ Parse multiple rules from the provided filename. + + :param filename: Name of file to parse rules from + + :returns: A list of :py:class:`.Rule` instances, one for each rule parsed + """ + with io.open(filename, encoding="utf-8") as fileobj: + return parse_fileobj(fileobj, group) + +class FlowbitResolver(object): + + setters = ["set", "setx", "unset", "toggle"] + getters = ["isset", "isnotset"] + + def __init__(self): + self.enabled = [] + + def resolve(self, rules): + required = self.get_required_flowbits(rules) + enabled = self.set_required_flowbits(rules, required) + if enabled: + self.enabled += enabled + return self.resolve(rules) + return self.enabled + + def set_required_flowbits(self, rules, required): + enabled = [] + for rule in [rule for rule in rules.values() if not rule.enabled]: + for option, value in map(self.parse_flowbit, rule.flowbits): + if option in self.setters and value in required: + rule.enabled = True + enabled.append(rule) + return enabled + + def get_required_rules(self, rulemap, flowbits, include_enabled=False): + """Returns a list of rules that need to be enabled in order to satisfy + the list of required flowbits. + + """ + required = [] + + for rule in [rule for rule in rulemap.values()]: + if not rule: + continue + for option, value in map(self.parse_flowbit, rule.flowbits): + if option in self.setters and value in flowbits: + if rule.enabled and not include_enabled: + continue + required.append(rule) + + return required + + def get_required_flowbits(self, rules): + required_flowbits = set() + for rule in [rule for rule in rules.values() if rule and rule.enabled]: + for option, value in map(self.parse_flowbit, rule.flowbits): + if option in self.getters: + required_flowbits.add(value) + return required_flowbits + + def parse_flowbit(self, flowbit): + tokens = flowbit.split(",", 1) + if len(tokens) == 1: + return tokens[0], None + elif len(tokens) == 2: + return tokens[0], tokens[1] + else: + raise Exception("Flowbit parse error on %s" % (flowbit)) + +def enable_flowbit_dependencies(rulemap): + """Helper function to resolve flowbits, wrapping the FlowbitResolver + class. """ + resolver = FlowbitResolver() + return resolver.resolve(rulemap) + +def format_sidmsgmap(rule): + """ Format a rule as a sid-msg.map entry. """ + try: + return " || ".join([str(rule.sid), rule.msg] + rule.references) + except: + logger.error("Failed to format rule as sid-msg.map: %s" % (str(rule))) + return None + +def format_sidmsgmap_v2(rule): + """ Format a rule as a v2 sid-msg.map entry. 
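+
+    A rule without a classtype is emitted with "NOCLASS" in the
+    classification field.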
+def format_sidmsgmap_v2(rule):
+    """ Format a rule as a v2 sid-msg.map entry.
+
+    e.g.:
+    gid || sid || rev || classification || priority || msg || ref0 || refN
+    """
+    try:
+        return " || ".join([
+            str(rule.gid), str(rule.sid), str(rule.rev),
+            "NOCLASS" if rule.classtype is None else rule.classtype,
+            str(rule.priority), rule.msg] + rule.references)
+    except Exception:
+        logger.error("Failed to format rule as sid-msg-v2.map: %s" % (
+            str(rule)))
+        return None
+
+def parse_var_names(var):
+    """ Parse out the variable names from a string. """
+    if var is None:
+        return []
+    return re.findall(r"\$([\w_]+)", var)
diff --git a/suricata/update/sources.py b/suricata/update/sources.py
new file mode 100644
index 0000000..a5bc673
--- /dev/null
+++ b/suricata/update/sources.py
@@ -0,0 +1,207 @@
+# Copyright (C) 2017 Open Information Security Foundation
+#
+# You can copy, redistribute or modify this Program under the terms of
+# the GNU General Public License version 2 as published by the Free
+# Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# version 2 along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+# 02110-1301, USA.
+
+from __future__ import print_function
+
+import sys
+import os
+import logging
+import io
+import argparse
+
+import yaml
+
+from suricata.update import config
+from suricata.update import net
+from suricata.update import util
+from suricata.update import loghandler
+from suricata.update.data.index import index as bundled_index
+
+logger = logging.getLogger()
+
+DEFAULT_SOURCE_INDEX_URL = "https://www.openinfosecfoundation.org/rules/index.yaml"
+SOURCE_INDEX_FILENAME = "index.yaml"
+
+DEFAULT_ETOPEN_URL = "https://rules.emergingthreats.net/open/suricata-%(__version__)s/emerging.rules.tar.gz"
+
+def get_source_directory():
+    """Return the directory where source configuration files are kept."""
+    return os.path.join(config.get_state_dir(), config.SOURCE_DIRECTORY)
+
+def get_index_filename():
+    return os.path.join(config.get_cache_dir(), SOURCE_INDEX_FILENAME)
+
+def get_sources_from_dir():
+    """Return the names of all source files in the sources directory."""
+    source_dir = get_source_directory()
+    (_, _, fnames) = next(os.walk(source_dir))
+    source_names = [".".join(fname.split('.')[:-1]) for fname in fnames]
+    return source_names
+
+def get_enabled_source_filename(name):
+    return os.path.join(get_source_directory(), "%s.yaml" % (
+        safe_filename(name)))
+
+def get_disabled_source_filename(name):
+    return os.path.join(get_source_directory(), "%s.yaml.disabled" % (
+        safe_filename(name)))
+
+def source_name_exists(name):
+    """Return True if a source already exists with name."""
+    if os.path.exists(get_enabled_source_filename(name)) or \
+       os.path.exists(get_disabled_source_filename(name)):
+        return True
+    return False
+
+def source_index_exists(config):
+    """Return True if the source index file exists."""
+    return os.path.exists(get_index_filename())
+
+def get_source_index_url():
+    if os.getenv("SOURCE_INDEX_URL"):
+        return os.getenv("SOURCE_INDEX_URL")
+    return DEFAULT_SOURCE_INDEX_URL
+
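A sketch of how the filename helpers above behave (this assumes suricata-update's configuration has been initialized so the state directory resolves; the paths in the comments are illustrative):

from suricata.update import sources

# safe_filename() maps "/" to "-", so "et/open" becomes "et-open.yaml"
# under the sources directory of the configured state directory.
print(sources.get_enabled_source_filename("et/open"))
# e.g. /var/lib/suricata/update/sources/et-open.yaml
print(sources.get_disabled_source_filename("et/open"))
# e.g. /var/lib/suricata/update/sources/et-open.yaml.disabled
print(sources.source_name_exists("et/open"))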
+def save_source_config(source_config):
+    if not os.path.exists(get_source_directory()):
+        logger.info("Creating directory %s", get_source_directory())
+        os.makedirs(get_source_directory())
+    with open(get_enabled_source_filename(source_config.name), "w") as fileobj:
+        fileobj.write(yaml.safe_dump(
+            source_config.dict(), default_flow_style=False))
+
+class SourceConfiguration:
+
+    def __init__(self, name, header=None, url=None,
+                 params=None, checksum=True):
+        self.name = name
+        self.url = url
+        # Default to None to avoid a shared mutable default argument.
+        self.params = params if params is not None else {}
+        self.header = header
+        self.checksum = checksum
+
+    def dict(self):
+        d = {
+            "source": self.name,
+        }
+        if self.url:
+            d["url"] = self.url
+        if self.params:
+            d["params"] = self.params
+        if self.header:
+            d["http-header"] = self.header
+        if self.checksum:
+            d["checksum"] = self.checksum
+        return d
+
+class Index:
+
+    def __init__(self, filename):
+        self.filename = filename
+        self.index = {}
+        self.load()
+
+    def load(self):
+        if os.path.exists(self.filename):
+            # Use a context manager so the index file is closed after loading.
+            with open(self.filename, "rb") as fileobj:
+                self.index = yaml.safe_load(fileobj)
+        else:
+            self.index = bundled_index
+
+    def resolve_url(self, name, params=None):
+        params = params if params is not None else {}
+        if name not in self.index["sources"]:
+            raise Exception("Source name not in index: %s" % (name))
+        source = self.index["sources"][name]
+        try:
+            return source["url"] % params
+        except KeyError as err:
+            raise Exception("Missing URL parameter: %s" % (str(err.args[0])))
+
+    def get_sources(self):
+        return self.index["sources"]
+
+    def get_source_by_name(self, name):
+        if name in self.index["sources"]:
+            return self.index["sources"][name]
+        return None
+
+    def get_versions(self):
+        try:
+            return self.index["versions"]
+        except KeyError:
+            logger.error("Version information not in index. "
+                         "Please update with suricata-update update-sources.")
+            sys.exit(1)
+
+def load_source_index(config):
+    return Index(get_index_filename())
+
+def get_enabled_sources():
+    """Return a map of enabled sources, keyed by name."""
+    if not os.path.exists(get_source_directory()):
+        return {}
+    sources = {}
+    for dirpath, dirnames, filenames in os.walk(get_source_directory()):
+        for filename in filenames:
+            if filename.endswith(".yaml"):
+                path = os.path.join(dirpath, filename)
+                logger.debug("Loading source specification file {}".format(path))
+                with open(path, "rb") as fileobj:
+                    source = yaml.safe_load(fileobj)
+
+                if "source" not in source:
+                    logger.error("Source specification file missing field \"source\": filename: {}".format(
+                        path))
+                    continue
+
+                sources[source["source"]] = source
+
+                if "params" in source:
+                    for param in source["params"]:
+                        if param.startswith("secret"):
+                            loghandler.add_secret(source["params"][param], param)
+
+    return sources
+
+def remove_source(config):
+    name = config.args.name
+
+    enabled_source_filename = get_enabled_source_filename(name)
+    if os.path.exists(enabled_source_filename):
+        logger.debug("Deleting file %s.", enabled_source_filename)
+        os.remove(enabled_source_filename)
+        logger.info("Source %s removed, previously enabled.", name)
+        return 0
+
+    disabled_source_filename = get_disabled_source_filename(name)
+    if os.path.exists(disabled_source_filename):
+        logger.debug("Deleting file %s.", disabled_source_filename)
+        os.remove(disabled_source_filename)
+        logger.info("Source %s removed, previously disabled.", name)
+        return 0
+
+    logger.warning("Source %s does not exist.", name)
+    return 1
+
+def safe_filename(name):
+    """Utility function to make a source short-name safe as a
+    filename."""
+    name = name.replace("/", "-")
+    return name
+
+def get_etopen_url(params):
+    if os.getenv("ETOPEN_URL"):
+        return os.getenv("ETOPEN_URL") % params
+    return DEFAULT_ETOPEN_URL % params
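A sketch of creating and persisting a source with the classes above (the source name and URL are invented for illustration; save_source_config() writes under the configured sources directory):

from suricata.update import sources

source = sources.SourceConfiguration(
    "example/custom",
    url="https://rules.example.com/custom.rules.tar.gz")
print(source.dict())
# {'source': 'example/custom',
#  'url': 'https://rules.example.com/custom.rules.tar.gz',
#  'checksum': True}

sources.save_source_config(source)  # writes .../sources/example-custom.yaml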
diff --git a/suricata/update/util.py b/suricata/update/util.py
new file mode 100644
index 0000000..50788d8
--- /dev/null
+++ b/suricata/update/util.py
@@ -0,0 +1,98 @@
+# Copyright (C) 2017 Open Information Security Foundation
+# Copyright (c) 2013 Jason Ish
+#
+# You can copy, redistribute or modify this Program under the terms of
+# the GNU General Public License version 2 as published by the Free
+# Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# version 2 along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+# 02110-1301, USA.
+
+""" Module for utility functions that don't really fit anywhere else. """
+
+import hashlib
+import tempfile
+import atexit
+import shutil
+import zipfile
+
+def md5_hexdigest(filename):
+    """ Compute the MD5 checksum for the contents of the provided filename.
+
+    :param filename: Filename to compute the MD5 checksum of.
+
+    :returns: A string representing the hex value of the computed MD5.
+    """
+    # Read in binary mode so the digest is computed over the raw bytes.
+    with open(filename, "rb") as fileobj:
+        return hashlib.md5(fileobj.read()).hexdigest()
+
+def mktempdir(delete_on_exit=True):
+    """ Create a temporary directory that is removed on exit. """
+    tmpdir = tempfile.mkdtemp("suricata-update")
+    if delete_on_exit:
+        atexit.register(shutil.rmtree, tmpdir, ignore_errors=True)
+    return tmpdir
+
+class ZipArchiveReader:
+
+    def __init__(self, zipfile):
+        self.zipfile = zipfile
+        self.names = self.zipfile.namelist()
+
+    def __iter__(self):
+        return self
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.zipfile.close()
+
+    def next(self):
+        if self.names:
+            name = self.names.pop(0)
+            if name.endswith("/"):
+                # Is a directory, ignore
+                return self.next()
+            return name
+        raise StopIteration
+
+    # Alias for the Python 3 iterator protocol.
+    __next__ = next
+
+    def open(self, name):
+        return self.zipfile.open(name)
+
+    def read(self, name):
+        return self.zipfile.read(name)
+
+    @classmethod
+    def from_fileobj(cls, fileobj):
+        zf = zipfile.ZipFile(fileobj)
+        return cls(zf)
+
+GREEN = "\x1b[32m"
+BLUE = "\x1b[34m"
+REDB = "\x1b[1;31m"
+YELLOW = "\x1b[33m"
+RED = "\x1b[31m"
+YELLOWB = "\x1b[1;33m"
+ORANGE = "\x1b[38;5;208m"
+BRIGHT_MAGENTA = "\x1b[1;35m"
+BRIGHT_CYAN = "\x1b[1;36m"
+RESET = "\x1b[0m"
+
+def blue(msg):
+    return "%s%s%s" % (BLUE, msg, RESET)
+
+def bright_magenta(msg):
+    return "%s%s%s" % (BRIGHT_MAGENTA, msg, RESET)
+
+def bright_cyan(msg):
+    return "%s%s%s" % (BRIGHT_CYAN, msg, RESET)
+
+def orange(msg):
+    return "%s%s%s" % (ORANGE, msg, RESET)
diff --git a/suricata/update/version.py b/suricata/update/version.py
new file mode 100644
index 0000000..75d1205
--- /dev/null
+++ b/suricata/update/version.py
@@ -0,0 +1,7 @@
+# Version format:
+# Release: 1.0.0
+# Beta: 1.0.0b1
+# Alpha: 1.0.0a1
+# Development: 1.0.0dev0
+# Release candidate: 1.0.0rc1
+version = "1.3.2"
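Finally, a sketch of consuming a downloaded archive with ZipArchiveReader and the color helpers from util.py (the archive path is illustrative; iteration relies on the __next__ alias noted above under Python 3):

from suricata.update import util

with open("rules.zip", "rb") as fileobj:
    with util.ZipArchiveReader.from_fileobj(fileobj) as reader:
        for name in reader:               # directories are skipped by next()
            if name.endswith(".rules"):
                print(util.blue(name))    # highlight the member name
                data = reader.read(name)  # raw bytes of the archive member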