summaryrefslogtreecommitdiffstats
path: root/eos_downloader/object_downloader.py
diff options
context:
space:
mode:
Diffstat (limited to 'eos_downloader/object_downloader.py')
-rw-r--r--eos_downloader/object_downloader.py513
1 files changed, 513 insertions, 0 deletions
diff --git a/eos_downloader/object_downloader.py b/eos_downloader/object_downloader.py
new file mode 100644
index 0000000..0420acb
--- /dev/null
+++ b/eos_downloader/object_downloader.py
@@ -0,0 +1,513 @@
+#!/usr/bin/python
+# coding: utf-8 -*-
+# flake8: noqa: F811
+# pylint: disable=too-many-instance-attributes
+# pylint: disable=too-many-arguments
+
+"""
+eos_downloader class definition
+"""
+
+from __future__ import (absolute_import, division, print_function,
+ unicode_literals, annotations)
+
+import base64
+import glob
+import hashlib
+import json
+import os
+import sys
+import xml.etree.ElementTree as ET
+from typing import Union
+
+import requests
+import rich
+from loguru import logger
+from rich import console
+from tqdm import tqdm
+
+from eos_downloader import (ARISTA_DOWNLOAD_URL, ARISTA_GET_SESSION,
+ ARISTA_SOFTWARE_FOLDER_TREE, EVE_QEMU_FOLDER_PATH,
+ MSG_INVALID_DATA, MSG_TOKEN_EXPIRED)
+from eos_downloader.data import DATA_MAPPING
+from eos_downloader.download import DownloadProgressBar
+
+# logger = logging.getLogger(__name__)
+
+console = rich.get_console()
+
+
+class ObjectDownloader():
+ """
+ ObjectDownloader Generic Object to download from Arista.com
+ """
+ def __init__(self, image: str, version: str, token: str, software: str = 'EOS', hash_method: str = 'md5sum'):
+ """
+ __init__ Class constructor
+
+ generic class constructor
+
+ Parameters
+ ----------
+ image : str
+ Type of image to download
+ version : str
+ Version of the package to download
+ token : str
+ Arista API token
+ software : str, optional
+ Package name to download (vEOS-lab, cEOS, EOS, ...), by default 'EOS'
+ hash_method : str, optional
+ Hash protocol to use to check download, by default 'md5sum'
+ """
+ self.software = software
+ self.image = image
+ self._version = version
+ self.token = token
+ self.folder_level = 0
+ self.session_id = None
+ self.filename = self._build_filename()
+ self.hash_method = hash_method
+ self.timeout = 5
+ # Logging
+ logger.debug(f'Filename built by _build_filename is {self.filename}')
+
+ def __str__(self) -> str:
+ return f'{self.software} - {self.image} - {self.version}'
+
+ # def __repr__(self):
+ # return str(self.__dict__)
+
+ @property
+ def version(self) -> str:
+ """Get version."""
+ return self._version
+
+ @version.setter
+ def version(self, value: str) -> None:
+ """Set version."""
+ self._version = value
+ self.filename = self._build_filename()
+
+ # ------------------------------------------------------------------------ #
+ # Internal METHODS
+ # ------------------------------------------------------------------------ #
+
+ def _build_filename(self) -> str:
+ """
+ _build_filename Helper to build filename to search on arista.com
+
+ Returns
+ -------
+ str:
+ Filename to search for on Arista.com
+ """
+ logger.info('start build')
+ if self.software in DATA_MAPPING:
+ logger.info(f'software in data mapping: {self.software}')
+ if self.image in DATA_MAPPING[self.software]:
+ logger.info(f'image in data mapping: {self.image}')
+ return f"{DATA_MAPPING[self.software][self.image]['prepend']}-{self.version}{DATA_MAPPING[self.software][self.image]['extension']}"
+ return f"{DATA_MAPPING[self.software]['default']['prepend']}-{self.version}{DATA_MAPPING[self.software]['default']['extension']}"
+ raise ValueError(f'Incorrect value for software {self.software}')
+
+ def _parse_xml_for_path(self, root_xml: ET.ElementTree, xpath: str, search_file: str) -> str:
+ # sourcery skip: remove-unnecessary-cast
+ """
+ _parse_xml Read and extract data from XML using XPATH
+
+ Get all interested nodes using XPATH and then get node that match search_file
+
+ Parameters
+ ----------
+ root_xml : ET.ElementTree
+ XML document
+ xpath : str
+ XPATH expression to filter XML
+ search_file : str
+ Filename to search for
+
+ Returns
+ -------
+ str
+ File Path on Arista server side
+ """
+ logger.debug(f'Using xpath {xpath}')
+ logger.debug(f'Search for file {search_file}')
+ console.print(f'🔎 Searching file {search_file}')
+ for node in root_xml.findall(xpath):
+ # logger.debug('Found {}', node.text)
+ if str(node.text).lower() == search_file.lower():
+ path = node.get('path')
+ console.print(f' -> Found file at {path}')
+ logger.info(f'Found {node.text} at {node.get("path")}')
+ return str(node.get('path')) if node.get('path') is not None else ''
+ logger.error(f'Requested file ({self.filename}) not found !')
+ return ''
+
+ def _get_hash(self, file_path: str) -> str:
+ """
+ _get_hash Download HASH file from Arista server
+
+ Parameters
+ ----------
+ file_path : str
+ Path of the HASH file
+
+ Returns
+ -------
+ str
+ Hash string read from HASH file downloaded from Arista.com
+ """
+ remote_hash_file = self._get_remote_hashpath(hash_method=self.hash_method)
+ hash_url = self._get_url(remote_file_path=remote_hash_file)
+ # hash_downloaded = self._download_file_raw(url=hash_url, file_path=file_path + "/" + os.path.basename(remote_hash_file))
+ dl_rich_progress_bar = DownloadProgressBar()
+ dl_rich_progress_bar.download(urls=[hash_url], dest_dir=file_path)
+ hash_downloaded = f"{file_path}/{os.path.basename(remote_hash_file)}"
+ hash_content = 'unset'
+ with open(hash_downloaded, 'r', encoding='utf-8') as f:
+ hash_content = f.read()
+ return hash_content.split(' ')[0]
+
+ @staticmethod
+ def _compute_hash_md5sum(file: str, hash_expected: str) -> bool:
+ """
+ _compute_hash_md5sum Compare MD5 sum
+
+ Do comparison between local md5 of the file and value provided by arista.com
+
+ Parameters
+ ----------
+ file : str
+ Local file to use for MD5 sum
+ hash_expected : str
+ MD5 from arista.com
+
+ Returns
+ -------
+ bool
+ True if both are equal, False if not
+ """
+ hash_md5 = hashlib.md5()
+ with open(file, "rb") as f:
+ for chunk in iter(lambda: f.read(4096), b""):
+ hash_md5.update(chunk)
+ if hash_md5.hexdigest() == hash_expected:
+ return True
+ logger.warning(f'Downloaded file is corrupt: local md5 ({hash_md5.hexdigest()}) is different to md5 from arista ({hash_expected})')
+ return False
+
+ @staticmethod
+ def _compute_hash_sh512sum(file: str, hash_expected: str) -> bool:
+ """
+ _compute_hash_sh512sum Compare SHA512 sum
+
+ Do comparison between local sha512 of the file and value provided by arista.com
+
+ Parameters
+ ----------
+ file : str
+ Local file to use for MD5 sum
+ hash_expected : str
+ SHA512 from arista.com
+
+ Returns
+ -------
+ bool
+ True if both are equal, False if not
+ """
+ hash_sha512 = hashlib.sha512()
+ with open(file, "rb") as f:
+ for chunk in iter(lambda: f.read(4096), b""):
+ hash_sha512.update(chunk)
+ if hash_sha512.hexdigest() == hash_expected:
+ return True
+ logger.warning(f'Downloaded file is corrupt: local sha512 ({hash_sha512.hexdigest()}) is different to sha512 from arista ({hash_expected})')
+ return False
+
+ def _get_folder_tree(self) -> ET.ElementTree:
+ """
+ _get_folder_tree Download XML tree from Arista server
+
+ Returns
+ -------
+ ET.ElementTree
+ XML document
+ """
+ if self.session_id is None:
+ self.authenticate()
+ jsonpost = {'sessionCode': self.session_id}
+ result = requests.post(ARISTA_SOFTWARE_FOLDER_TREE, data=json.dumps(jsonpost), timeout=self.timeout)
+ try:
+ folder_tree = result.json()["data"]["xml"]
+ return ET.ElementTree(ET.fromstring(folder_tree))
+ except KeyError as error:
+ logger.error(MSG_INVALID_DATA)
+ logger.error(f'Server returned: {error}')
+ console.print(f'❌ {MSG_INVALID_DATA}', style="bold red")
+ sys.exit(1)
+
+ def _get_remote_filepath(self) -> str:
+ """
+ _get_remote_filepath Helper to get path of the file to download
+
+ Set XPATH and return result of _parse_xml for the file to download
+
+ Returns
+ -------
+ str
+ Remote path of the file to download
+ """
+ root = self._get_folder_tree()
+ logger.debug("GET XML content from ARISTA.com")
+ xpath = f'.//dir[@label="{self.software}"]//file'
+ return self._parse_xml_for_path(root_xml=root, xpath=xpath, search_file=self.filename)
+
+ def _get_remote_hashpath(self, hash_method: str = 'md5sum') -> str:
+ """
+ _get_remote_hashpath Helper to get path of the hash's file to download
+
+ Set XPATH and return result of _parse_xml for the file to download
+
+ Returns
+ -------
+ str
+ Remote path of the hash's file to download
+ """
+ root = self._get_folder_tree()
+ logger.debug("GET XML content from ARISTA.com")
+ xpath = f'.//dir[@label="{self.software}"]//file'
+ return self._parse_xml_for_path(
+ root_xml=root,
+ xpath=xpath,
+ search_file=f'{self.filename}.{hash_method}',
+ )
+
+ def _get_url(self, remote_file_path: str) -> str:
+ """
+ _get_url Get URL to use for downloading file from Arista server
+
+ Send remote_file_path to get correct URL to use for download
+
+ Parameters
+ ----------
+ remote_file_path : str
+ Filepath from XML to use to get correct download link
+
+ Returns
+ -------
+ str
+ URL link to use for download
+ """
+ if self.session_id is None:
+ self.authenticate()
+ jsonpost = {'sessionCode': self.session_id, 'filePath': remote_file_path}
+ result = requests.post(ARISTA_DOWNLOAD_URL, data=json.dumps(jsonpost), timeout=self.timeout)
+ if 'data' in result.json() and 'url' in result.json()['data']:
+ # logger.debug('URL to download file is: {}', result.json())
+ return result.json()["data"]["url"]
+ logger.critical(f'Server returns following message: {result.json()}')
+ return ''
+
+ @staticmethod
+ def _download_file_raw(url: str, file_path: str) -> str:
+ """
+ _download_file Helper to download file from Arista.com
+
+ [extended_summary]
+
+ Parameters
+ ----------
+ url : str
+ URL provided by server for remote_file_path
+ file_path : str
+ Location where to save local file
+
+ Returns
+ -------
+ str
+ File path
+ """
+ chunkSize = 1024
+ r = requests.get(url, stream=True, timeout=5)
+ with open(file_path, 'wb') as f:
+ pbar = tqdm(unit="B", total=int(r.headers['Content-Length']), unit_scale=True, unit_divisor=1024)
+ for chunk in r.iter_content(chunk_size=chunkSize):
+ if chunk:
+ pbar.update(len(chunk))
+ f.write(chunk)
+ return file_path
+
+ def _download_file(self, file_path: str, filename: str, rich_interface: bool = True) -> Union[None, str]:
+ remote_file_path = self._get_remote_filepath()
+ logger.info(f'File found on arista server: {remote_file_path}')
+ file_url = self._get_url(remote_file_path=remote_file_path)
+ if file_url is not False:
+ if not rich_interface:
+ return self._download_file_raw(url=file_url, file_path=os.path.join(file_path, filename))
+ rich_downloader = DownloadProgressBar()
+ rich_downloader.download(urls=[file_url], dest_dir=file_path)
+ return os.path.join(file_path, filename)
+ logger.error(f'Cannot download file {file_path}')
+ return None
+
+ @staticmethod
+ def _create_destination_folder(path: str) -> None:
+ # os.makedirs(path, mode, exist_ok=True)
+ os.system(f'mkdir -p {path}')
+
+ @staticmethod
+ def _disable_ztp(file_path: str) -> None:
+ pass
+
+ # ------------------------------------------------------------------------ #
+ # Public METHODS
+ # ------------------------------------------------------------------------ #
+
+ def authenticate(self) -> bool:
+ """
+ authenticate Authenticate user on Arista.com server
+
+ Send API token and get a session-id from remote server.
+ Session-id will be used by all other functions.
+
+ Returns
+ -------
+ bool
+ True if authentication succeeds=, False in all other situations.
+ """
+ credentials = (base64.b64encode(self.token.encode())).decode("utf-8")
+ session_code_url = ARISTA_GET_SESSION
+ jsonpost = {'accessToken': credentials}
+
+ result = requests.post(session_code_url, data=json.dumps(jsonpost), timeout=self.timeout)
+
+ if result.json()["status"]["message"] in[ 'Access token expired', 'Invalid access token']:
+ console.print(f'❌ {MSG_TOKEN_EXPIRED}', style="bold red")
+ logger.error(MSG_TOKEN_EXPIRED)
+ return False
+
+ try:
+ if 'data' in result.json():
+ self.session_id = result.json()["data"]["session_code"]
+ logger.info('Authenticated on arista.com')
+ return True
+ logger.debug(f'{result.json()}')
+ return False
+ except KeyError as error_arista:
+ logger.error(f'Error: {error_arista}')
+ sys.exit(1)
+
+ def download_local(self, file_path: str, checksum: bool = False) -> bool:
+ # sourcery skip: move-assign
+ """
+ download_local Entrypoint for local download feature
+
+ Do local downnload feature:
+ - Get remote file path
+ - Get URL from Arista.com
+ - Download file
+ - Do HASH comparison (optional)
+
+ Parameters
+ ----------
+ file_path : str
+ Local path to save downloaded file
+ checksum : bool, optional
+ Execute checksum or not, by default False
+
+ Returns
+ -------
+ bool
+ True if everything went well, False if any problem appears
+ """
+ file_downloaded = str(self._download_file(file_path=file_path, filename=self.filename))
+
+ # Check file HASH
+ hash_result = False
+ if checksum:
+ logger.info('🚀 Running checksum validation')
+ console.print('🚀 Running checksum validation')
+ if self.hash_method == 'md5sum':
+ hash_expected = self._get_hash(file_path=file_path)
+ hash_result = self._compute_hash_md5sum(file=file_downloaded, hash_expected=hash_expected)
+ elif self.hash_method == 'sha512sum':
+ hash_expected = self._get_hash(file_path=file_path)
+ hash_result = self._compute_hash_sh512sum(file=file_downloaded, hash_expected=hash_expected)
+ if not hash_result:
+ logger.error('Downloaded file is corrupted, please check your connection')
+ console.print('❌ Downloaded file is corrupted, please check your connection')
+ return False
+ logger.info('Downloaded file is correct.')
+ console.print('✅ Downloaded file is correct.')
+ return True
+
+ def provision_eve(self, noztp: bool = False, checksum: bool = True) -> None:
+ # pylint: disable=unused-argument
+ """
+ provision_eve Entrypoint for EVE-NG download and provisioning
+
+ Do following actions:
+ - Get remote file path
+ - Get URL from file path
+ - Download file
+ - Convert file to qcow2 format
+ - Create new version to EVE-NG
+ - Disable ZTP (optional)
+
+ Parameters
+ ----------
+ noztp : bool, optional
+ Flag to deactivate ZTP in EOS image, by default False
+ checksum : bool, optional
+ Flag to ask for hash validation, by default True
+ """
+ # Build image name to use in folder path
+ eos_image_name = self.filename.rstrip(".vmdk").lower()
+ if noztp:
+ eos_image_name = f'{eos_image_name}-noztp'
+ # Create full path for EVE-NG
+ file_path = os.path.join(EVE_QEMU_FOLDER_PATH, eos_image_name.rstrip())
+ # Create folders in filesystem
+ self._create_destination_folder(path=file_path)
+
+ # Download file to local destination
+ file_downloaded = self._download_file(
+ file_path=file_path, filename=self.filename)
+
+ # Convert to QCOW2 format
+ file_qcow2 = os.path.join(file_path, "hda.qcow2")
+ logger.info('Converting VMDK to QCOW2 format')
+ console.print('🚀 Converting VMDK to QCOW2 format...')
+
+ os.system(f'$(which qemu-img) convert -f vmdk -O qcow2 {file_downloaded} {file_qcow2}')
+
+ logger.info('Applying unl_wrapper to fix permissions')
+ console.print('Applying unl_wrapper to fix permissions')
+
+ os.system('/opt/unetlab/wrappers/unl_wrapper -a fixpermissions')
+ os.system(f'rm -f {file_downloaded}')
+
+ if noztp:
+ self._disable_ztp(file_path=file_path)
+
+ def docker_import(self, image_name: str = "arista/ceos") -> None:
+ """
+ Import docker container to your docker server.
+
+ Import downloaded container to your local docker engine.
+
+ Args:
+ version (str):
+ image_name (str, optional): Image name to use. Defaults to "arista/ceos".
+ """
+ docker_image = f'{image_name}:{self.version}'
+ logger.info(f'Importing image {self.filename} to {docker_image}')
+ console.print(f'🚀 Importing image {self.filename} to {docker_image}')
+ os.system(f'$(which docker) import {self.filename} {docker_image}')
+ for filename in glob.glob(f'{self.filename}*'):
+ try:
+ os.remove(filename)
+ except FileNotFoundError:
+ console.print(f'File not found: {filename}')