summaryrefslogtreecommitdiffstats
path: root/suricata/update/commands/updatesources.py
diff options
context:
space:
mode:
Diffstat (limited to 'suricata/update/commands/updatesources.py')
-rw-r--r--suricata/update/commands/updatesources.py105
1 files changed, 105 insertions, 0 deletions
diff --git a/suricata/update/commands/updatesources.py b/suricata/update/commands/updatesources.py
new file mode 100644
index 0000000..06a0d11
--- /dev/null
+++ b/suricata/update/commands/updatesources.py
@@ -0,0 +1,105 @@
+# Copyright (C) 2017 Open Information Security Foundation
+#
+# You can copy, redistribute or modify this Program under the terms of
+# the GNU General Public License version 2 as published by the Free
+# Software Foundation.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# version 2 along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+# 02110-1301, USA.
+
+from __future__ import print_function
+
+import io
+import logging
+import os
+
+import yaml
+from suricata.update import config, exceptions, net, sources
+
+logger = logging.getLogger()
+
+
+def register(parser):
+ parser.set_defaults(func=update_sources)
+
+
+def get_initial_content():
+ initial_content = None
+ if os.path.exists(local_index_filename):
+ with open(local_index_filename, "r") as stream:
+ initial_content = yaml.safe_load(stream)
+ return initial_content
+
+
+def get_sources(before, after):
+ all_sources = {source: after[source]
+ for source in after if source not in before}
+ return all_sources
+
+
+def log_sources(sources_map):
+ for name, all_sources in sources_map.items():
+ if not all_sources:
+ continue
+ for source in all_sources:
+ logger.info("Source %s was %s", source, name)
+
+
+def compare_sources(initial_content, final_content):
+ if not initial_content:
+ logger.info("Adding all sources")
+ return
+ if initial_content == final_content:
+ logger.info("No change in sources")
+ return
+ initial_sources = initial_content.get("sources")
+ final_sources = final_content.get("sources")
+ added_sources = get_sources(before=initial_sources, after=final_sources)
+ removed_sources = get_sources(before=final_sources, after=initial_sources)
+ log_sources(sources_map={"added": added_sources,
+ "removed": removed_sources})
+ for source in set(initial_sources) & set(final_sources):
+ if initial_sources[source] != final_sources[source]:
+ logger.info("Source %s was changed", source)
+
+
+def write_and_compare(initial_content, fileobj):
+ try:
+ with open(local_index_filename, "wb") as outobj:
+ outobj.write(fileobj.getvalue())
+ except IOError as ioe:
+ logger.error("Failed to open directory: %s", ioe)
+ return 1
+ with open(local_index_filename, "rb") as stream:
+ final_content = yaml.safe_load(stream)
+ compare_sources(initial_content, final_content)
+ logger.info("Saved %s", local_index_filename)
+
+
+def update_sources():
+ global local_index_filename
+ local_index_filename = sources.get_index_filename()
+ initial_content = get_initial_content()
+ with io.BytesIO() as fileobj:
+ url = sources.get_source_index_url()
+ logger.info("Downloading %s", url)
+ try:
+ net.get(url, fileobj)
+ except Exception as err:
+ raise exceptions.ApplicationError(
+ "Failed to download index: %s: %s" % (url, err))
+ if not os.path.exists(config.get_cache_dir()):
+ try:
+ os.makedirs(config.get_cache_dir())
+ except Exception as err:
+ logger.error("Failed to create directory %s: %s",
+ config.get_cache_dir(), err)
+ return 1
+ write_and_compare(initial_content=initial_content, fileobj=fileobj)