diff options
Diffstat (limited to '')
-rw-r--r-- | eos_downloader/__init__.py | 47 | ||||
-rw-r--r-- | eos_downloader/cli/__init__.py | 0 | ||||
-rw-r--r-- | eos_downloader/cli/cli.py | 76 | ||||
-rw-r--r-- | eos_downloader/cli/debug/__init__.py | 0 | ||||
-rw-r--r-- | eos_downloader/cli/debug/commands.py | 53 | ||||
-rw-r--r-- | eos_downloader/cli/get/__init__.py | 0 | ||||
-rw-r--r-- | eos_downloader/cli/get/commands.py | 137 | ||||
-rw-r--r-- | eos_downloader/cli/info/__init__.py | 0 | ||||
-rw-r--r-- | eos_downloader/cli/info/commands.py | 87 | ||||
-rw-r--r-- | eos_downloader/cvp.py | 276 | ||||
-rw-r--r-- | eos_downloader/data.py | 93 | ||||
-rw-r--r-- | eos_downloader/download.py | 77 | ||||
-rw-r--r-- | eos_downloader/eos.py | 177 | ||||
-rw-r--r-- | eos_downloader/models/__init__.py | 0 | ||||
-rw-r--r-- | eos_downloader/models/version.py | 272 | ||||
-rw-r--r-- | eos_downloader/object_downloader.py | 513 | ||||
-rw-r--r-- | eos_downloader/tools.py | 13 |
17 files changed, 1821 insertions, 0 deletions
diff --git a/eos_downloader/__init__.py b/eos_downloader/__init__.py new file mode 100644 index 0000000..345ccf7 --- /dev/null +++ b/eos_downloader/__init__.py @@ -0,0 +1,47 @@ +#!/usr/bin/python +# coding: utf-8 -*- + +""" +EOS Downloader module. +""" + +from __future__ import (absolute_import, division, + print_function, unicode_literals, annotations) +import dataclasses +from typing import Any +import json +import importlib.metadata + +__author__ = '@titom73' +__email__ = 'tom@inetsix.net' +__date__ = '2022-03-16' +__version__ = importlib.metadata.version("eos-downloader") + +# __all__ = ["CvpAuthenticationItem", "CvFeatureManager", "EOSDownloader", "ObjectDownloader", "reverse"] + +ARISTA_GET_SESSION = "https://www.arista.com/custom_data/api/cvp/getSessionCode/" + +ARISTA_SOFTWARE_FOLDER_TREE = "https://www.arista.com/custom_data/api/cvp/getFolderTree/" + +ARISTA_DOWNLOAD_URL = "https://www.arista.com/custom_data/api/cvp/getDownloadLink/" + +MSG_TOKEN_EXPIRED = """The API token has expired. Please visit arista.com, click on your profile and +select Regenerate Token then re-run the script with the new token. +""" + +MSG_TOKEN_INVALID = """The API token is incorrect. Please visit arista.com, click on your profile and +check the Access Token. Then re-run the script with the correct token. +""" + +MSG_INVALID_DATA = """Invalid data returned by server +""" + +EVE_QEMU_FOLDER_PATH = '/opt/unetlab/addons/qemu/' + + +class EnhancedJSONEncoder(json.JSONEncoder): + """Custom JSon encoder.""" + def default(self, o: Any) -> Any: + if dataclasses.is_dataclass(o): + return dataclasses.asdict(o) + return super().default(o) diff --git a/eos_downloader/cli/__init__.py b/eos_downloader/cli/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/eos_downloader/cli/__init__.py diff --git a/eos_downloader/cli/cli.py b/eos_downloader/cli/cli.py new file mode 100644 index 0000000..ddd0dea --- /dev/null +++ b/eos_downloader/cli/cli.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python +# coding: utf-8 -*- +# pylint: disable=no-value-for-parameter +# pylint: disable=cyclic-import +# pylint: disable=too-many-arguments +# pylint: disable=unused-argument + + +""" +ARDL CLI Baseline. +""" + +import click +from rich.console import Console +import eos_downloader +from eos_downloader.cli.get import commands as get_commands +from eos_downloader.cli.debug import commands as debug_commands +from eos_downloader.cli.info import commands as info_commands + + +@click.group() +@click.pass_context +@click.option('--token', show_envvar=True, default=None, help='Arista Token from your customer account') +def ardl(ctx: click.Context, token: str) -> None: + """Arista Network Download CLI""" + ctx.ensure_object(dict) + ctx.obj['token'] = token + + +@click.command() +def version() -> None: + """Display version of ardl""" + console = Console() + console.print(f'ardl is running version {eos_downloader.__version__}') + + +@ardl.group(no_args_is_help=True) +@click.pass_context +def get(ctx: click.Context) -> None: + # pylint: disable=redefined-builtin + """Download Arista from Arista website""" + + +@ardl.group(no_args_is_help=True) +@click.pass_context +def info(ctx: click.Context) -> None: + # pylint: disable=redefined-builtin + """List information from Arista website""" + + +@ardl.group(no_args_is_help=True) +@click.pass_context +def debug(ctx: click.Context) -> None: + # pylint: disable=redefined-builtin + """Debug commands to work with ardl""" + +# ANTA CLI Execution + + +def cli() -> None: + """Load ANTA CLI""" + # Load group commands + get.add_command(get_commands.eos) + get.add_command(get_commands.cvp) + info.add_command(info_commands.eos_versions) + debug.add_command(debug_commands.xml) + ardl.add_command(version) + # Load CLI + ardl( + obj={}, + auto_envvar_prefix='arista' + ) + + +if __name__ == '__main__': + cli() diff --git a/eos_downloader/cli/debug/__init__.py b/eos_downloader/cli/debug/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/eos_downloader/cli/debug/__init__.py diff --git a/eos_downloader/cli/debug/commands.py b/eos_downloader/cli/debug/commands.py new file mode 100644 index 0000000..107b8a0 --- /dev/null +++ b/eos_downloader/cli/debug/commands.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python +# coding: utf-8 -*- +# pylint: disable=no-value-for-parameter +# pylint: disable=too-many-arguments +# pylint: disable=line-too-long +# pylint: disable=duplicate-code +# flake8: noqa E501 + +""" +Commands for ARDL CLI to get data. +""" + +import xml.etree.ElementTree as ET +from xml.dom import minidom + +import click +from loguru import logger +from rich.console import Console + +import eos_downloader.eos + + +@click.command() +@click.pass_context +@click.option('--output', default=str('arista.xml'), help='Path to save XML file', type=click.Path(), show_default=True) +@click.option('--log-level', '--log', help='Logging level of the command', default=None, type=click.Choice(['debug', 'info', 'warning', 'error', 'critical'], case_sensitive=False)) +def xml(ctx: click.Context, output: str, log_level: str) -> None: + # sourcery skip: remove-unnecessary-cast + """Extract XML directory structure""" + console = Console() + # Get from Context + token = ctx.obj['token'] + + logger.remove() + if log_level is not None: + logger.add("eos-downloader.log", rotation="10 MB", level=log_level.upper()) + + my_download = eos_downloader.eos.EOSDownloader( + image='unset', + software='EOS', + version='unset', + token=token, + hash_method='sha512sum') + + my_download.authenticate() + xml_object: ET.ElementTree = my_download._get_folder_tree() # pylint: disable=protected-access + xml_content = xml_object.getroot() + + xmlstr = minidom.parseString(ET.tostring(xml_content)).toprettyxml(indent=" ", newl='') + with open(output, "w", encoding='utf-8') as f: + f.write(str(xmlstr)) + + console.print(f'XML file saved in: { output }') diff --git a/eos_downloader/cli/get/__init__.py b/eos_downloader/cli/get/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/eos_downloader/cli/get/__init__.py diff --git a/eos_downloader/cli/get/commands.py b/eos_downloader/cli/get/commands.py new file mode 100644 index 0000000..13a8eec --- /dev/null +++ b/eos_downloader/cli/get/commands.py @@ -0,0 +1,137 @@ +#!/usr/bin/env python +# coding: utf-8 -*- +# pylint: disable=no-value-for-parameter +# pylint: disable=too-many-arguments +# pylint: disable=line-too-long +# pylint: disable=redefined-builtin +# flake8: noqa E501 + +""" +Commands for ARDL CLI to get data. +""" + +import os +import sys +from typing import Union + +import click +from loguru import logger +from rich.console import Console + +import eos_downloader.eos +from eos_downloader.models.version import BASE_VERSION_STR, RTYPE_FEATURE, RTYPES + +EOS_IMAGE_TYPE = ['64', 'INT', '2GB-INT', 'cEOS', 'cEOS64', 'vEOS', 'vEOS-lab', 'EOS-2GB', 'default'] +CVP_IMAGE_TYPE = ['ova', 'rpm', 'kvm', 'upgrade'] + +@click.command(no_args_is_help=True) +@click.pass_context +@click.option('--image-type', default='default', help='EOS Image type', type=click.Choice(EOS_IMAGE_TYPE), required=True) +@click.option('--version', default=None, help='EOS version', type=str, required=False) +@click.option('--latest', '-l', is_flag=True, type=click.BOOL, default=False, help='Get latest version in given branch. If --branch is not use, get the latest branch with specific release type') +@click.option('--release-type', '-rtype', type=click.Choice(RTYPES, case_sensitive=False), default=RTYPE_FEATURE, help='EOS release type to search') +@click.option('--branch', '-b', type=click.STRING, default=None, help='EOS Branch to list releases') +@click.option('--docker-name', default='arista/ceos', help='Docker image name (default: arista/ceos)', type=str, show_default=True) +@click.option('--output', default=str(os.path.relpath(os.getcwd(), start=os.curdir)), help='Path to save image', type=click.Path(),show_default=True) +# Debugging +@click.option('--log-level', '--log', help='Logging level of the command', default=None, type=click.Choice(['debug', 'info', 'warning', 'error', 'critical'], case_sensitive=False)) +# Boolean triggers +@click.option('--eve-ng', is_flag=True, help='Run EVE-NG vEOS provisioning (only if CLI runs on an EVE-NG server)', default=False) +@click.option('--disable-ztp', is_flag=True, help='Disable ZTP process in vEOS image (only available with --eve-ng)', default=False) +@click.option('--import-docker', is_flag=True, help='Import docker image (only available with --image_type cEOSlab)', default=False) +def eos( + ctx: click.Context, image_type: str, output: str, log_level: str, eve_ng: bool, disable_ztp: bool, + import_docker: bool, docker_name: str, version: Union[str, None] = None, release_type: str = RTYPE_FEATURE, + latest: bool = False, branch: Union[str,None] = None + ) -> int: + """Download EOS image from Arista website""" + console = Console() + # Get from Context + token = ctx.obj['token'] + if token is None or token == '': + console.print('❗ Token is unset ! Please configure ARISTA_TOKEN or use --token option', style="bold red") + sys.exit(1) + + logger.remove() + if log_level is not None: + logger.add("eos-downloader.log", rotation="10 MB", level=log_level.upper()) + + console.print("🪐 [bold blue]eos-downloader[/bold blue] is starting...", ) + console.print(f' - Image Type: {image_type}') + console.print(f' - Version: {version}') + + + if version is not None: + my_download = eos_downloader.eos.EOSDownloader( + image=image_type, + software='EOS', + version=version, + token=token, + hash_method='sha512sum') + my_download.authenticate() + + elif latest: + my_download = eos_downloader.eos.EOSDownloader( + image=image_type, + software='EOS', + version='unset', + token=token, + hash_method='sha512sum') + my_download.authenticate() + if branch is None: + branch = str(my_download.latest_branch(rtype=release_type).branch) + latest_version = my_download.latest_eos(branch, rtype=release_type) + if str(latest_version) == BASE_VERSION_STR: + console.print(f'[red]Error[/red], cannot find any version in {branch} for {release_type} release type') + sys.exit(1) + my_download.version = str(latest_version) + + if eve_ng: + my_download.provision_eve(noztp=disable_ztp, checksum=True) + else: + my_download.download_local(file_path=output, checksum=True) + + if import_docker: + my_download.docker_import( + image_name=docker_name + ) + console.print('✅ processing done !') + sys.exit(0) + + + +@click.command(no_args_is_help=True) +@click.pass_context +@click.option('--format', default='upgrade', help='CVP Image type', type=click.Choice(CVP_IMAGE_TYPE), required=True) +@click.option('--version', default=None, help='CVP version', type=str, required=True) +@click.option('--output', default=str(os.path.relpath(os.getcwd(), start=os.curdir)), help='Path to save image', type=click.Path(),show_default=True) +@click.option('--log-level', '--log', help='Logging level of the command', default=None, type=click.Choice(['debug', 'info', 'warning', 'error', 'critical'], case_sensitive=False)) +def cvp(ctx: click.Context, version: str, format: str, output: str, log_level: str) -> int: + """Download CVP image from Arista website""" + console = Console() + # Get from Context + token = ctx.obj['token'] + if token is None or token == '': + console.print('❗ Token is unset ! Please configure ARISTA_TOKEN or use --token option', style="bold red") + sys.exit(1) + + logger.remove() + if log_level is not None: + logger.add("eos-downloader.log", rotation="10 MB", level=log_level.upper()) + + console.print("🪐 [bold blue]eos-downloader[/bold blue] is starting...", ) + console.print(f' - Image Type: {format}') + console.print(f' - Version: {version}') + + my_download = eos_downloader.eos.EOSDownloader( + image=format, + software='CloudVision', + version=version, + token=token, + hash_method='md5sum') + + my_download.authenticate() + + my_download.download_local(file_path=output, checksum=False) + console.print('✅ processing done !') + sys.exit(0) diff --git a/eos_downloader/cli/info/__init__.py b/eos_downloader/cli/info/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/eos_downloader/cli/info/__init__.py diff --git a/eos_downloader/cli/info/commands.py b/eos_downloader/cli/info/commands.py new file mode 100644 index 0000000..b51003b --- /dev/null +++ b/eos_downloader/cli/info/commands.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python +# coding: utf-8 -*- +# pylint: disable=no-value-for-parameter +# pylint: disable=too-many-arguments +# pylint: disable=line-too-long +# pylint: disable=redefined-builtin +# flake8: noqa E501 + +""" +Commands for ARDL CLI to list data. +""" + +import sys +from typing import Union + +import click +from loguru import logger +from rich.console import Console +from rich.pretty import pprint + +import eos_downloader.eos +from eos_downloader.models.version import BASE_VERSION_STR, RTYPE_FEATURE, RTYPES + + +@click.command(no_args_is_help=True) +@click.pass_context +@click.option('--latest', '-l', is_flag=True, type=click.BOOL, default=False, help='Get latest version in given branch. If --branch is not use, get the latest branch with specific release type') +@click.option('--release-type', '-rtype', type=click.Choice(RTYPES, case_sensitive=False), default=RTYPE_FEATURE, help='EOS release type to search') +@click.option('--branch', '-b', type=click.STRING, default=None, help='EOS Branch to list releases') +@click.option('--verbose', '-v', is_flag=True, type=click.BOOL, default=False, help='Human readable output. Default is none to use output in script)') +@click.option('--log-level', '--log', help='Logging level of the command', default='warning', type=click.Choice(['debug', 'info', 'warning', 'error', 'critical'], case_sensitive=False)) +def eos_versions(ctx: click.Context, log_level: str, branch: Union[str,None] = None, release_type: str = RTYPE_FEATURE, latest: bool = False, verbose: bool = False) -> None: + # pylint: disable = too-many-branches + """ + List Available EOS version on Arista.com website. + + Comes with some filters to get latest release (F or M) as well as branch filtering + + - To get latest M release available (without any branch): ardl info eos-versions --latest -rtype m + + - To get latest F release available: ardl info eos-versions --latest -rtype F + """ + console = Console() + # Get from Context + token = ctx.obj['token'] + + logger.remove() + if log_level is not None: + logger.add("eos-downloader.log", rotation="10 MB", level=log_level.upper()) + + my_download = eos_downloader.eos.EOSDownloader( + image='unset', + software='EOS', + version='unset', + token=token, + hash_method='sha512sum') + + auth = my_download.authenticate() + if verbose and auth: + console.print('✅ Authenticated on arista.com') + + if release_type is not None: + release_type = release_type.upper() + + if latest: + if branch is None: + branch = str(my_download.latest_branch(rtype=release_type).branch) + latest_version = my_download.latest_eos(branch, rtype=release_type) + if str(latest_version) == BASE_VERSION_STR: + console.print(f'[red]Error[/red], cannot find any version in {branch} for {release_type} release type') + sys.exit(1) + if verbose: + console.print(f'Branch {branch} has been selected with release type {release_type}') + if branch is not None: + console.print(f'Latest release for {branch}: {latest_version}') + else: + console.print(f'Latest EOS release: {latest_version}') + else: + console.print(f'{ latest_version }') + else: + versions = my_download.get_eos_versions(branch=branch, rtype=release_type) + if verbose: + console.print(f'List of available versions for {branch if branch is not None else "all branches"}') + for version in versions: + console.print(f' → {str(version)}') + else: + pprint([str(version) for version in versions]) diff --git a/eos_downloader/cvp.py b/eos_downloader/cvp.py new file mode 100644 index 0000000..6f14eb0 --- /dev/null +++ b/eos_downloader/cvp.py @@ -0,0 +1,276 @@ +#!/usr/bin/python +# coding: utf-8 -*- + +""" +CVP Uploader content +""" + +import os +from typing import List, Optional, Any +from dataclasses import dataclass +from loguru import logger +from cvprac.cvp_client import CvpClient +from cvprac.cvp_client_errors import CvpLoginError + +# from eos_downloader.tools import exc_to_str + +# logger = logging.getLogger(__name__) + + +@dataclass +class CvpAuthenticationItem: + """ + Data structure to represent Cloudvision Authentication + """ + server: str + port: int = 443 + token: Optional[str] = None + timeout: int = 1200 + validate_cert: bool = False + + +class Filer(): + # pylint: disable=too-few-public-methods + """ + Filer Helper for file management + """ + def __init__(self, path: str) -> None: + self.file_exist = False + self.filename = '' + self.absolute_path = '' + self.relative_path = path + if os.path.exists(path): + self.file_exist = True + self.filename = os.path.basename(path) + self.absolute_path = os.path.realpath(path) + + def __repr__(self) -> str: + return self.absolute_path if self.file_exist else '' + + +class CvFeatureManager(): + """ + CvFeatureManager Object to interect with Cloudvision + """ + def __init__(self, authentication: CvpAuthenticationItem) -> None: + """ + __init__ Class Creator + + Parameters + ---------- + authentication : CvpAuthenticationItem + Authentication information to use to connect to Cloudvision + """ + self._authentication = authentication + # self._cv_instance = CvpClient() + self._cv_instance = self._connect(authentication=authentication) + self._cv_images = self.__get_images() + # self._cv_bundles = self.__get_bundles() + + def _connect(self, authentication: CvpAuthenticationItem) -> CvpClient: + """ + _connect Connection management + + Parameters + ---------- + authentication : CvpAuthenticationItem + Authentication information to use to connect to Cloudvision + + Returns + ------- + CvpClient + cvprac session to cloudvision + """ + client = CvpClient() + if authentication.token is not None: + try: + client.connect( + nodes=[authentication.server], + username='', + password='', + api_token=authentication.token, + is_cvaas=True, + port=authentication.port, + cert=authentication.validate_cert, + request_timeout=authentication.timeout + ) + except CvpLoginError as error_data: + logger.error(f'Cannot connect to Cloudvision server {authentication.server}') + logger.debug(f'Error message: {error_data}') + logger.info('connected to Cloudvision server') + logger.debug(f'Connection info: {authentication}') + return client + + def __get_images(self) -> List[Any]: + """ + __get_images Collect information about images on Cloudvision + + Returns + ------- + dict + Fact returned by Cloudvision + """ + images = [] + logger.debug(' -> Collecting images') + images = self._cv_instance.api.get_images()['data'] + return images if self.__check_api_result(images) else [] + + # def __get_bundles(self): + # """ + # __get_bundles [Not In use] Collect information about bundles on Cloudvision + + # Returns + # ------- + # dict + # Fact returned by Cloudvision + # """ + # bundles = [] + # logger.debug(' -> Collecting images bundles') + # bundles = self._cv_instance.api.get_image_bundles()['data'] + # # bundles = self._cv_instance.post(url='/cvpservice/image/getImageBundles.do?queryparam=&startIndex=0&endIndex=0')['data'] + # return bundles if self.__check_api_result(bundles) else None + + def __check_api_result(self, arg0: Any) -> bool: + """ + __check_api_result Check API calls return content + + Parameters + ---------- + arg0 : any + Element to test + + Returns + ------- + bool + True if data are correct False in other cases + """ + logger.debug(arg0) + return len(arg0) > 0 + + def _does_image_exist(self, image_name: str) -> bool: + """ + _does_image_exist Check if an image is referenced in Cloudvision facts + + Parameters + ---------- + image_name : str + Name of the image to search for + + Returns + ------- + bool + True if present + """ + return any(image_name == image['name'] for image in self._cv_images) if isinstance(self._cv_images, list) else False + + def _does_bundle_exist(self, bundle_name: str) -> bool: + # pylint: disable=unused-argument + """ + _does_bundle_exist Check if an image is referenced in Cloudvision facts + + Returns + ------- + bool + True if present + """ + # return any(bundle_name == bundle['name'] for bundle in self._cv_bundles) + return False + + def upload_image(self, image_path: str) -> bool: + """ + upload_image Upload an image to Cloudvision server + + Parameters + ---------- + image_path : str + Path to the local file to upload + + Returns + ------- + bool + True if succeeds + """ + image_item = Filer(path=image_path) + if image_item.file_exist is False: + logger.error(f'File not found: {image_item.relative_path}') + return False + logger.info(f'File path for image: {image_item}') + if self._does_image_exist(image_name=image_item.filename): + logger.error("Image found in Cloudvision , Please delete it before running this script") + return False + try: + upload_result = self._cv_instance.api.add_image(filepath=image_item.absolute_path) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error('An error occurred during upload, check CV connection') + logger.error(f'Exception message is: {e}') + return False + logger.debug(f'Upload Result is : {upload_result}') + return True + + def build_image_list(self, image_list: List[str]) -> List[Any]: + """ + Builds a list of the image data structures, for a given list of image names. + Parameters + ---------- + image_list : list + List of software image names + Returns + ------- + List: + Returns a list of images, with complete data or None in the event of failure + """ + internal_image_list = [] + image_data = None + success = True + + for entry in image_list: + for image in self._cv_images: + if image["imageFileName"] == entry: + image_data = image + + if image_data is not None: + internal_image_list.append(image_data) + image_data = None + else: + success = False + + return internal_image_list if success else [] + + def create_bundle(self, name: str, images_name: List[str]) -> bool: + """ + create_bundle Create a bundle with a list of images. + + Parameters + ---------- + name : str + Name of the bundle + images_name : List[str] + List of images available on Cloudvision + + Returns + ------- + bool + True if succeeds + """ + logger.debug(f'Init creation of an image bundle {name} with following images {images_name}') + all_images_present: List[bool] = [] + self._cv_images = self.__get_images() + all_images_present.extend( + self._does_image_exist(image_name=image_name) + for image_name in images_name + ) + # Bundle Create + if self._does_bundle_exist(bundle_name=name) is False: + logger.debug(f'Creating image bundle {name} with following images {images_name}') + images_data = self.build_image_list(image_list=images_name) + if images_data is not None: + logger.debug('Images information: {images_data}') + try: + data = self._cv_instance.api.save_image_bundle(name=name, images=images_data) + except Exception as e: # pylint: disable=broad-exception-caught + logger.critical(f'{e}') + else: + logger.debug(data) + return True + logger.critical('No data found for images') + return False diff --git a/eos_downloader/data.py b/eos_downloader/data.py new file mode 100644 index 0000000..74f2f8e --- /dev/null +++ b/eos_downloader/data.py @@ -0,0 +1,93 @@ +#!/usr/bin/python +# coding: utf-8 -*- + +""" +EOS Downloader Information to use in +eos_downloader.object_downloader.ObjectDownloader._build_filename. + +Data are built from content of Arista XML file +""" + + +# [platform][image][version] +DATA_MAPPING = { + "CloudVision": { + "ova": { + "extension": ".ova", + "prepend": "cvp", + "folder_level": 0 + }, + "rpm": { + "extension": "", + "prepend": "cvp-rpm-installer", + "folder_level": 0 + }, + "kvm": { + "extension": "-kvm.tgz", + "prepend": "cvp", + "folder_level": 0 + }, + "upgrade": { + "extension": ".tgz", + "prepend": "cvp-upgrade", + "folder_level": 0 + }, + }, + "EOS": { + "64": { + "extension": ".swi", + "prepend": "EOS64", + "folder_level": 0 + }, + "INT": { + "extension": "-INT.swi", + "prepend": "EOS", + "folder_level": 1 + }, + "2GB-INT": { + "extension": "-INT.swi", + "prepend": "EOS-2GB", + "folder_level": 1 + }, + "cEOS": { + "extension": ".tar.xz", + "prepend": "cEOS-lab", + "folder_level": 0 + }, + "cEOS64": { + "extension": ".tar.xz", + "prepend": "cEOS64-lab", + "folder_level": 0 + }, + "vEOS": { + "extension": ".vmdk", + "prepend": "vEOS", + "folder_level": 0 + }, + "vEOS-lab": { + "extension": ".vmdk", + "prepend": "vEOS-lab", + "folder_level": 0 + }, + "EOS-2GB": { + "extension": ".swi", + "prepend": "EOS-2GB", + "folder_level": 0 + }, + "RN": { + "extension": "-", + "prepend": "RN", + "folder_level": 0 + }, + "SOURCE": { + "extension": "-source.tar", + "prepend": "EOS", + "folder_level": 0 + }, + "default": { + "extension": ".swi", + "prepend": "EOS", + "folder_level": 0 + } + } +} diff --git a/eos_downloader/download.py b/eos_downloader/download.py new file mode 100644 index 0000000..2297b04 --- /dev/null +++ b/eos_downloader/download.py @@ -0,0 +1,77 @@ +# flake8: noqa: F811 +# pylint: disable=unused-argument +# pylint: disable=too-few-public-methods + +"""download module""" + +import os.path +import signal +from concurrent.futures import ThreadPoolExecutor +from threading import Event +from typing import Iterable, Any + +import requests +import rich +from rich import console +from rich.progress import (BarColumn, DownloadColumn, Progress, TaskID, + TextColumn, TimeElapsedColumn, TransferSpeedColumn) + +console = rich.get_console() +done_event = Event() + + +def handle_sigint(signum: Any, frame: Any) -> None: + """Progress bar handler""" + done_event.set() + + +signal.signal(signal.SIGINT, handle_sigint) + + +class DownloadProgressBar(): + """ + Object to manage Download process with Progress Bar from Rich + """ + + def __init__(self) -> None: + """ + Class Constructor + """ + self.progress = Progress( + TextColumn("💾 Downloading [bold blue]{task.fields[filename]}", justify="right"), + BarColumn(bar_width=None), + "[progress.percentage]{task.percentage:>3.1f}%", + "•", + TransferSpeedColumn(), + "•", + DownloadColumn(), + "•", + TimeElapsedColumn(), + "•", + console=console + ) + + def _copy_url(self, task_id: TaskID, url: str, path: str, block_size: int = 1024) -> bool: + """Copy data from a url to a local file.""" + response = requests.get(url, stream=True, timeout=5) + # This will break if the response doesn't contain content length + self.progress.update(task_id, total=int(response.headers['Content-Length'])) + with open(path, "wb") as dest_file: + self.progress.start_task(task_id) + for data in response.iter_content(chunk_size=block_size): + dest_file.write(data) + self.progress.update(task_id, advance=len(data)) + if done_event.is_set(): + return True + # console.print(f"Downloaded {path}") + return False + + def download(self, urls: Iterable[str], dest_dir: str) -> None: + """Download multuple files to the given directory.""" + with self.progress: + with ThreadPoolExecutor(max_workers=4) as pool: + for url in urls: + filename = url.split("/")[-1].split('?')[0] + dest_path = os.path.join(dest_dir, filename) + task_id = self.progress.add_task("download", filename=filename, start=False) + pool.submit(self._copy_url, task_id, url, dest_path) diff --git a/eos_downloader/eos.py b/eos_downloader/eos.py new file mode 100644 index 0000000..e5f3670 --- /dev/null +++ b/eos_downloader/eos.py @@ -0,0 +1,177 @@ +#!/usr/bin/python +# coding: utf-8 -*- +# flake8: noqa: F811 + +""" +Specific EOS inheritance from object_download +""" + +import os +import xml.etree.ElementTree as ET +from typing import List, Union + +import rich +from loguru import logger +from rich import console + +from eos_downloader.models.version import BASE_BRANCH_STR, BASE_VERSION_STR, REGEX_EOS_VERSION, RTYPE_FEATURE, EosVersion +from eos_downloader.object_downloader import ObjectDownloader + +# logger = logging.getLogger(__name__) + +console = rich.get_console() + +class EOSDownloader(ObjectDownloader): + """ + EOSDownloader Object to download EOS images from Arista.com website + + Supercharge ObjectDownloader to support EOS specific actions + + Parameters + ---------- + ObjectDownloader : ObjectDownloader + Base object + """ + + eos_versions: Union[List[EosVersion], None] = None + + @staticmethod + def _disable_ztp(file_path: str) -> None: + """ + _disable_ztp Method to disable ZTP in EOS image + + Create a file in the EOS image to disable ZTP process during initial boot + + Parameters + ---------- + file_path : str + Path where EOS image is located + """ + logger.info('Mounting volume to disable ZTP') + console.print('🚀 Mounting volume to disable ZTP') + raw_folder = os.path.join(file_path, "raw") + os.system(f"rm -rf {raw_folder}") + os.system(f"mkdir -p {raw_folder}") + os.system( + f'guestmount -a {os.path.join(file_path, "hda.qcow2")} -m /dev/sda2 {os.path.join(file_path, "raw")}') + ztp_file = os.path.join(file_path, 'raw/zerotouch-config') + with open(ztp_file, 'w', encoding='ascii') as zfile: + zfile.write('DISABLE=True') + logger.info(f'Unmounting volume in {file_path}') + os.system(f"guestunmount {os.path.join(file_path, 'raw')}") + os.system(f"rm -rf {os.path.join(file_path, 'raw')}") + logger.info(f"Volume has been successfully unmounted at {file_path}") + + def _parse_xml_for_version(self,root_xml: ET.ElementTree, xpath: str = './/dir[@label="Active Releases"]/dir/dir/[@label]') -> List[EosVersion]: + """ + Extract list of available EOS versions from Arista.com website + + Create a list of EosVersion object for all versions available on Arista.com + + Args: + root_xml (ET.ElementTree): XML file with all versions available + xpath (str, optional): XPATH to use to extract EOS version. Defaults to './/dir[@label="Active Releases"]/dir/dir/[@label]'. + + Returns: + List[EosVersion]: List of EosVersion representing all available EOS versions + """ + # XPATH: .//dir[@label="Active Releases"]/dir/dir/[@label] + if self.eos_versions is None: + logger.debug(f'Using xpath {xpath}') + eos_versions = [] + for node in root_xml.findall(xpath): + if 'label' in node.attrib and node.get('label') is not None: + label = node.get('label') + if label is not None and REGEX_EOS_VERSION.match(label): + eos_version = EosVersion.from_str(label) + eos_versions.append(eos_version) + logger.debug(f"Found {label} - {eos_version}") + logger.debug(f'List of versions found on arista.com is: {eos_versions}') + self.eos_versions = eos_versions + else: + logger.debug('receiving instruction to download versions, but already available') + return self.eos_versions + + def _get_branches(self, with_rtype: str = RTYPE_FEATURE) -> List[str]: + """ + Extract all EOS branches available from arista.com + + Call self._parse_xml_for_version and then build list of available branches + + Args: + rtype (str, optional): Release type to find. Can be M or F, default to F + + Returns: + List[str]: A lsit of string that represent all availables EOS branches + """ + root = self._get_folder_tree() + versions = self._parse_xml_for_version(root_xml=root) + return list({version.branch for version in versions if version.rtype == with_rtype}) + + def latest_branch(self, rtype: str = RTYPE_FEATURE) -> EosVersion: + """ + Get latest branch from semver standpoint + + Args: + rtype (str, optional): Release type to find. Can be M or F, default to F + + Returns: + EosVersion: Latest Branch object + """ + selected_branch = EosVersion.from_str(BASE_BRANCH_STR) + for branch in self._get_branches(with_rtype=rtype): + branch = EosVersion.from_str(branch) + if branch > selected_branch: + selected_branch = branch + return selected_branch + + def get_eos_versions(self, branch: Union[str,None] = None, rtype: Union[str,None] = None) -> List[EosVersion]: + """ + Get a list of available EOS version available on arista.com + + If a branch is provided, only version in this branch are listed. + Otherwise, all versions are provided. + + Args: + branch (str, optional): An EOS branch to filter. Defaults to None. + rtype (str, optional): Release type to find. Can be M or F, default to F + + Returns: + List[EosVersion]: A list of versions available + """ + root = self._get_folder_tree() + result = [] + for version in self._parse_xml_for_version(root_xml=root): + if branch is None and (version.rtype == rtype or rtype is None): + result.append(version) + elif branch is not None and version.is_in_branch(branch) and version.rtype == rtype: + result.append(version) + return result + + def latest_eos(self, branch: Union[str,None] = None, rtype: str = RTYPE_FEATURE) -> EosVersion: + """ + Get latest version of EOS + + If a branch is provided, only version in this branch are listed. + Otherwise, all versions are provided. + You can select what type of version to consider: M or F + + Args: + branch (str, optional): An EOS branch to filter. Defaults to None. + rtype (str, optional): An EOS version type to filter, Can be M or F. Defaults to None. + + Returns: + EosVersion: latest version selected + """ + selected_version = EosVersion.from_str(BASE_VERSION_STR) + if branch is None: + latest_branch = self.latest_branch(rtype=rtype) + else: + latest_branch = EosVersion.from_str(branch) + for version in self.get_eos_versions(branch=str(latest_branch.branch), rtype=rtype): + if version > selected_version: + if rtype is not None and version.rtype == rtype: + selected_version = version + if rtype is None: + selected_version = version + return selected_version diff --git a/eos_downloader/models/__init__.py b/eos_downloader/models/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/eos_downloader/models/__init__.py diff --git a/eos_downloader/models/version.py b/eos_downloader/models/version.py new file mode 100644 index 0000000..14448c1 --- /dev/null +++ b/eos_downloader/models/version.py @@ -0,0 +1,272 @@ +#!/usr/bin/python +# coding: utf-8 -*- + +"""Module for EOS version management""" + +from __future__ import annotations + +import re +import typing +from typing import Any, Optional + +from loguru import logger +from pydantic import BaseModel + +from eos_downloader.tools import exc_to_str + +# logger = logging.getLogger(__name__) + +BASE_VERSION_STR = '4.0.0F' +BASE_BRANCH_STR = '4.0' + +RTYPE_FEATURE = 'F' +RTYPE_MAINTENANCE = 'M' +RTYPES = [RTYPE_FEATURE, RTYPE_MAINTENANCE] + +# Regular Expression to capture multiple EOS version format +# 4.24 +# 4.23.0 +# 4.21.1M +# 4.28.10.F +# 4.28.6.1M +REGEX_EOS_VERSION = re.compile(r"^.*(?P<major>4)\.(?P<minor>\d{1,2})\.(?P<patch>\d{1,2})(?P<other>\.\d*)*(?P<rtype>[M,F])*$") +REGEX_EOS_BRANCH = re.compile(r"^.*(?P<major>4)\.(?P<minor>\d{1,2})(\.?P<patch>\d)*(\.\d)*(?P<rtype>[M,F])*$") + + +class EosVersion(BaseModel): + """ + EosVersion object to play with version management in code + + Since EOS is not using strictly semver approach, this class mimic some functions from semver lib for Arista EOS versions + It is based on Pydantic and provides helpers for comparison: + + Examples: + >>> eos_version_str = '4.23.2F' + >>> eos_version = EosVersion.from_str(eos_version_str) + >>> print(f'str representation is: {str(eos_version)}') + str representation is: 4.23.2F + + >>> other_version = EosVersion.from_str(other_version_str) + >>> print(f'eos_version < other_version: {eos_version < other_version}') + eos_version < other_version: True + + >>> print(f'Is eos_version match("<=4.23.3M"): {eos_version.match("<=4.23.3M")}') + Is eos_version match("<=4.23.3M"): True + + >>> print(f'Is eos_version in branch 4.23: {eos_version.is_in_branch("4.23.0")}') + Is eos_version in branch 4.23: True + + Args: + BaseModel (Pydantic): Pydantic Base Model + """ + major: int = 4 + minor: int = 0 + patch: int = 0 + rtype: Optional[str] = 'F' + other: Any + + @classmethod + def from_str(cls, eos_version: str) -> EosVersion: + """ + Class constructor from a string representing EOS version + + Use regular expresion to extract fields from string. + It supports following formats: + - 4.24 + - 4.23.0 + - 4.21.1M + - 4.28.10.F + - 4.28.6.1M + + Args: + eos_version (str): EOS version in str format + + Returns: + EosVersion object + """ + logger.debug(f'receiving version: {eos_version}') + if REGEX_EOS_VERSION.match(eos_version): + matches = REGEX_EOS_VERSION.match(eos_version) + # assert matches is not None + assert matches is not None + return cls(**matches.groupdict()) + if REGEX_EOS_BRANCH.match(eos_version): + matches = REGEX_EOS_BRANCH.match(eos_version) + # assert matches is not None + assert matches is not None + return cls(**matches.groupdict()) + logger.error(f'Error occured with {eos_version}') + return EosVersion() + + @property + def branch(self) -> str: + """ + Extract branch of version + + Returns: + str: branch from version + """ + return f'{self.major}.{self.minor}' + + def __str__(self) -> str: + """ + Standard str representation + + Return string for EOS version like 4.23.3M + + Returns: + str: A standard EOS version string representing <MAJOR>.<MINOR>.<PATCH><RTYPE> + """ + if self.other is None: + return f'{self.major}.{self.minor}.{self.patch}{self.rtype}' + return f'{self.major}.{self.minor}.{self.patch}{self.other}{self.rtype}' + + def _compare(self, other: EosVersion) -> float: + """ + An internal comparison function to compare 2 EosVersion objects + + Do a deep comparison from Major to Release Type + The return value is + - negative if ver1 < ver2, + - zero if ver1 == ver2 + - strictly positive if ver1 > ver2 + + Args: + other (EosVersion): An EosVersion to compare with this object + + Raises: + ValueError: Raise ValueError if input is incorrect type + + Returns: + float: -1 if ver1 < ver2, 0 if ver1 == ver2, 1 if ver1 > ver2 + """ + if not isinstance(other, EosVersion): + raise ValueError(f'could not compare {other} as it is not an EosVersion object') + comparison_flag: float = 0 + logger.warning(f'current version {self.__str__()} - other {str(other)}') # pylint: disable = unnecessary-dunder-call + for key, _ in self.dict().items(): + if comparison_flag == 0 and self.dict()[key] is None or other.dict()[key] is None: + logger.debug(f'{key}: local None - remote None') + logger.debug(f'{key}: local {self.dict()} - remote {other.dict()}') + return comparison_flag + logger.debug(f'{key}: local {self.dict()[key]} - remote {other.dict()[key]}') + if comparison_flag == 0 and self.dict()[key] < other.dict()[key]: + comparison_flag = -1 + if comparison_flag == 0 and self.dict()[key] > other.dict()[key]: + comparison_flag = 1 + if comparison_flag != 0: + logger.info(f'comparison result is {comparison_flag}') + return comparison_flag + logger.info(f'comparison result is {comparison_flag}') + return comparison_flag + + @typing.no_type_check + def __eq__(self, other): + """ Implement __eq__ function (==) """ + return self._compare(other) == 0 + + @typing.no_type_check + def __ne__(self, other): + # type: ignore + """ Implement __nw__ function (!=) """ + return self._compare(other) != 0 + + @typing.no_type_check + def __lt__(self, other): + # type: ignore + """ Implement __lt__ function (<) """ + return self._compare(other) < 0 + + @typing.no_type_check + def __le__(self, other): + # type: ignore + """ Implement __le__ function (<=) """ + return self._compare(other) <= 0 + + @typing.no_type_check + def __gt__(self, other): + # type: ignore + """ Implement __gt__ function (>) """ + return self._compare(other) > 0 + + @typing.no_type_check + def __ge__(self, other): + # type: ignore + """ Implement __ge__ function (>=) """ + return self._compare(other) >= 0 + + def match(self, match_expr: str) -> bool: + """ + Compare self to match a match expression. + + Example: + >>> eos_version.match("<=4.23.3M") + True + >>> eos_version.match("==4.23.3M") + False + + Args: + match_expr (str): optional operator and version; valid operators are + ``<`` smaller than + ``>`` greater than + ``>=`` greator or equal than + ``<=`` smaller or equal than + ``==`` equal + ``!=`` not equal + + Raises: + ValueError: If input has no match_expr nor match_ver + + Returns: + bool: True if the expression matches the version, otherwise False + """ + prefix = match_expr[:2] + if prefix in (">=", "<=", "==", "!="): + match_version = match_expr[2:] + elif prefix and prefix[0] in (">", "<"): + prefix = prefix[0] + match_version = match_expr[1:] + elif match_expr and match_expr[0] in "0123456789": + prefix = "==" + match_version = match_expr + else: + raise ValueError( + "match_expr parameter should be in format <op><ver>, " + "where <op> is one of " + "['<', '>', '==', '<=', '>=', '!=']. " + f"You provided: {match_expr}" + ) + logger.debug(f'work on comparison {prefix} with base release {match_version}') + possibilities_dict = { + ">": (1,), + "<": (-1,), + "==": (0,), + "!=": (-1, 1), + ">=": (0, 1), + "<=": (-1, 0), + } + possibilities = possibilities_dict[prefix] + cmp_res = self._compare(EosVersion.from_str(match_version)) + + return cmp_res in possibilities + + def is_in_branch(self, branch_str: str) -> bool: + """ + Check if current version is part of a branch version + + Comparison is done across MAJOR and MINOR + + Args: + branch_str (str): a string for EOS branch. It supports following formats 4.23 or 4.23.0 + + Returns: + bool: True if current version is in provided branch, otherwise False + """ + try: + logger.debug(f'reading branch str:{branch_str}') + branch = EosVersion.from_str(branch_str) + except Exception as error: # pylint: disable = broad-exception-caught + logger.error(exc_to_str(error)) + else: + return self.major == branch.major and self.minor == branch.minor + return False diff --git a/eos_downloader/object_downloader.py b/eos_downloader/object_downloader.py new file mode 100644 index 0000000..0420acb --- /dev/null +++ b/eos_downloader/object_downloader.py @@ -0,0 +1,513 @@ +#!/usr/bin/python +# coding: utf-8 -*- +# flake8: noqa: F811 +# pylint: disable=too-many-instance-attributes +# pylint: disable=too-many-arguments + +""" +eos_downloader class definition +""" + +from __future__ import (absolute_import, division, print_function, + unicode_literals, annotations) + +import base64 +import glob +import hashlib +import json +import os +import sys +import xml.etree.ElementTree as ET +from typing import Union + +import requests +import rich +from loguru import logger +from rich import console +from tqdm import tqdm + +from eos_downloader import (ARISTA_DOWNLOAD_URL, ARISTA_GET_SESSION, + ARISTA_SOFTWARE_FOLDER_TREE, EVE_QEMU_FOLDER_PATH, + MSG_INVALID_DATA, MSG_TOKEN_EXPIRED) +from eos_downloader.data import DATA_MAPPING +from eos_downloader.download import DownloadProgressBar + +# logger = logging.getLogger(__name__) + +console = rich.get_console() + + +class ObjectDownloader(): + """ + ObjectDownloader Generic Object to download from Arista.com + """ + def __init__(self, image: str, version: str, token: str, software: str = 'EOS', hash_method: str = 'md5sum'): + """ + __init__ Class constructor + + generic class constructor + + Parameters + ---------- + image : str + Type of image to download + version : str + Version of the package to download + token : str + Arista API token + software : str, optional + Package name to download (vEOS-lab, cEOS, EOS, ...), by default 'EOS' + hash_method : str, optional + Hash protocol to use to check download, by default 'md5sum' + """ + self.software = software + self.image = image + self._version = version + self.token = token + self.folder_level = 0 + self.session_id = None + self.filename = self._build_filename() + self.hash_method = hash_method + self.timeout = 5 + # Logging + logger.debug(f'Filename built by _build_filename is {self.filename}') + + def __str__(self) -> str: + return f'{self.software} - {self.image} - {self.version}' + + # def __repr__(self): + # return str(self.__dict__) + + @property + def version(self) -> str: + """Get version.""" + return self._version + + @version.setter + def version(self, value: str) -> None: + """Set version.""" + self._version = value + self.filename = self._build_filename() + + # ------------------------------------------------------------------------ # + # Internal METHODS + # ------------------------------------------------------------------------ # + + def _build_filename(self) -> str: + """ + _build_filename Helper to build filename to search on arista.com + + Returns + ------- + str: + Filename to search for on Arista.com + """ + logger.info('start build') + if self.software in DATA_MAPPING: + logger.info(f'software in data mapping: {self.software}') + if self.image in DATA_MAPPING[self.software]: + logger.info(f'image in data mapping: {self.image}') + return f"{DATA_MAPPING[self.software][self.image]['prepend']}-{self.version}{DATA_MAPPING[self.software][self.image]['extension']}" + return f"{DATA_MAPPING[self.software]['default']['prepend']}-{self.version}{DATA_MAPPING[self.software]['default']['extension']}" + raise ValueError(f'Incorrect value for software {self.software}') + + def _parse_xml_for_path(self, root_xml: ET.ElementTree, xpath: str, search_file: str) -> str: + # sourcery skip: remove-unnecessary-cast + """ + _parse_xml Read and extract data from XML using XPATH + + Get all interested nodes using XPATH and then get node that match search_file + + Parameters + ---------- + root_xml : ET.ElementTree + XML document + xpath : str + XPATH expression to filter XML + search_file : str + Filename to search for + + Returns + ------- + str + File Path on Arista server side + """ + logger.debug(f'Using xpath {xpath}') + logger.debug(f'Search for file {search_file}') + console.print(f'🔎 Searching file {search_file}') + for node in root_xml.findall(xpath): + # logger.debug('Found {}', node.text) + if str(node.text).lower() == search_file.lower(): + path = node.get('path') + console.print(f' -> Found file at {path}') + logger.info(f'Found {node.text} at {node.get("path")}') + return str(node.get('path')) if node.get('path') is not None else '' + logger.error(f'Requested file ({self.filename}) not found !') + return '' + + def _get_hash(self, file_path: str) -> str: + """ + _get_hash Download HASH file from Arista server + + Parameters + ---------- + file_path : str + Path of the HASH file + + Returns + ------- + str + Hash string read from HASH file downloaded from Arista.com + """ + remote_hash_file = self._get_remote_hashpath(hash_method=self.hash_method) + hash_url = self._get_url(remote_file_path=remote_hash_file) + # hash_downloaded = self._download_file_raw(url=hash_url, file_path=file_path + "/" + os.path.basename(remote_hash_file)) + dl_rich_progress_bar = DownloadProgressBar() + dl_rich_progress_bar.download(urls=[hash_url], dest_dir=file_path) + hash_downloaded = f"{file_path}/{os.path.basename(remote_hash_file)}" + hash_content = 'unset' + with open(hash_downloaded, 'r', encoding='utf-8') as f: + hash_content = f.read() + return hash_content.split(' ')[0] + + @staticmethod + def _compute_hash_md5sum(file: str, hash_expected: str) -> bool: + """ + _compute_hash_md5sum Compare MD5 sum + + Do comparison between local md5 of the file and value provided by arista.com + + Parameters + ---------- + file : str + Local file to use for MD5 sum + hash_expected : str + MD5 from arista.com + + Returns + ------- + bool + True if both are equal, False if not + """ + hash_md5 = hashlib.md5() + with open(file, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + if hash_md5.hexdigest() == hash_expected: + return True + logger.warning(f'Downloaded file is corrupt: local md5 ({hash_md5.hexdigest()}) is different to md5 from arista ({hash_expected})') + return False + + @staticmethod + def _compute_hash_sh512sum(file: str, hash_expected: str) -> bool: + """ + _compute_hash_sh512sum Compare SHA512 sum + + Do comparison between local sha512 of the file and value provided by arista.com + + Parameters + ---------- + file : str + Local file to use for MD5 sum + hash_expected : str + SHA512 from arista.com + + Returns + ------- + bool + True if both are equal, False if not + """ + hash_sha512 = hashlib.sha512() + with open(file, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_sha512.update(chunk) + if hash_sha512.hexdigest() == hash_expected: + return True + logger.warning(f'Downloaded file is corrupt: local sha512 ({hash_sha512.hexdigest()}) is different to sha512 from arista ({hash_expected})') + return False + + def _get_folder_tree(self) -> ET.ElementTree: + """ + _get_folder_tree Download XML tree from Arista server + + Returns + ------- + ET.ElementTree + XML document + """ + if self.session_id is None: + self.authenticate() + jsonpost = {'sessionCode': self.session_id} + result = requests.post(ARISTA_SOFTWARE_FOLDER_TREE, data=json.dumps(jsonpost), timeout=self.timeout) + try: + folder_tree = result.json()["data"]["xml"] + return ET.ElementTree(ET.fromstring(folder_tree)) + except KeyError as error: + logger.error(MSG_INVALID_DATA) + logger.error(f'Server returned: {error}') + console.print(f'❌ {MSG_INVALID_DATA}', style="bold red") + sys.exit(1) + + def _get_remote_filepath(self) -> str: + """ + _get_remote_filepath Helper to get path of the file to download + + Set XPATH and return result of _parse_xml for the file to download + + Returns + ------- + str + Remote path of the file to download + """ + root = self._get_folder_tree() + logger.debug("GET XML content from ARISTA.com") + xpath = f'.//dir[@label="{self.software}"]//file' + return self._parse_xml_for_path(root_xml=root, xpath=xpath, search_file=self.filename) + + def _get_remote_hashpath(self, hash_method: str = 'md5sum') -> str: + """ + _get_remote_hashpath Helper to get path of the hash's file to download + + Set XPATH and return result of _parse_xml for the file to download + + Returns + ------- + str + Remote path of the hash's file to download + """ + root = self._get_folder_tree() + logger.debug("GET XML content from ARISTA.com") + xpath = f'.//dir[@label="{self.software}"]//file' + return self._parse_xml_for_path( + root_xml=root, + xpath=xpath, + search_file=f'{self.filename}.{hash_method}', + ) + + def _get_url(self, remote_file_path: str) -> str: + """ + _get_url Get URL to use for downloading file from Arista server + + Send remote_file_path to get correct URL to use for download + + Parameters + ---------- + remote_file_path : str + Filepath from XML to use to get correct download link + + Returns + ------- + str + URL link to use for download + """ + if self.session_id is None: + self.authenticate() + jsonpost = {'sessionCode': self.session_id, 'filePath': remote_file_path} + result = requests.post(ARISTA_DOWNLOAD_URL, data=json.dumps(jsonpost), timeout=self.timeout) + if 'data' in result.json() and 'url' in result.json()['data']: + # logger.debug('URL to download file is: {}', result.json()) + return result.json()["data"]["url"] + logger.critical(f'Server returns following message: {result.json()}') + return '' + + @staticmethod + def _download_file_raw(url: str, file_path: str) -> str: + """ + _download_file Helper to download file from Arista.com + + [extended_summary] + + Parameters + ---------- + url : str + URL provided by server for remote_file_path + file_path : str + Location where to save local file + + Returns + ------- + str + File path + """ + chunkSize = 1024 + r = requests.get(url, stream=True, timeout=5) + with open(file_path, 'wb') as f: + pbar = tqdm(unit="B", total=int(r.headers['Content-Length']), unit_scale=True, unit_divisor=1024) + for chunk in r.iter_content(chunk_size=chunkSize): + if chunk: + pbar.update(len(chunk)) + f.write(chunk) + return file_path + + def _download_file(self, file_path: str, filename: str, rich_interface: bool = True) -> Union[None, str]: + remote_file_path = self._get_remote_filepath() + logger.info(f'File found on arista server: {remote_file_path}') + file_url = self._get_url(remote_file_path=remote_file_path) + if file_url is not False: + if not rich_interface: + return self._download_file_raw(url=file_url, file_path=os.path.join(file_path, filename)) + rich_downloader = DownloadProgressBar() + rich_downloader.download(urls=[file_url], dest_dir=file_path) + return os.path.join(file_path, filename) + logger.error(f'Cannot download file {file_path}') + return None + + @staticmethod + def _create_destination_folder(path: str) -> None: + # os.makedirs(path, mode, exist_ok=True) + os.system(f'mkdir -p {path}') + + @staticmethod + def _disable_ztp(file_path: str) -> None: + pass + + # ------------------------------------------------------------------------ # + # Public METHODS + # ------------------------------------------------------------------------ # + + def authenticate(self) -> bool: + """ + authenticate Authenticate user on Arista.com server + + Send API token and get a session-id from remote server. + Session-id will be used by all other functions. + + Returns + ------- + bool + True if authentication succeeds=, False in all other situations. + """ + credentials = (base64.b64encode(self.token.encode())).decode("utf-8") + session_code_url = ARISTA_GET_SESSION + jsonpost = {'accessToken': credentials} + + result = requests.post(session_code_url, data=json.dumps(jsonpost), timeout=self.timeout) + + if result.json()["status"]["message"] in[ 'Access token expired', 'Invalid access token']: + console.print(f'❌ {MSG_TOKEN_EXPIRED}', style="bold red") + logger.error(MSG_TOKEN_EXPIRED) + return False + + try: + if 'data' in result.json(): + self.session_id = result.json()["data"]["session_code"] + logger.info('Authenticated on arista.com') + return True + logger.debug(f'{result.json()}') + return False + except KeyError as error_arista: + logger.error(f'Error: {error_arista}') + sys.exit(1) + + def download_local(self, file_path: str, checksum: bool = False) -> bool: + # sourcery skip: move-assign + """ + download_local Entrypoint for local download feature + + Do local downnload feature: + - Get remote file path + - Get URL from Arista.com + - Download file + - Do HASH comparison (optional) + + Parameters + ---------- + file_path : str + Local path to save downloaded file + checksum : bool, optional + Execute checksum or not, by default False + + Returns + ------- + bool + True if everything went well, False if any problem appears + """ + file_downloaded = str(self._download_file(file_path=file_path, filename=self.filename)) + + # Check file HASH + hash_result = False + if checksum: + logger.info('🚀 Running checksum validation') + console.print('🚀 Running checksum validation') + if self.hash_method == 'md5sum': + hash_expected = self._get_hash(file_path=file_path) + hash_result = self._compute_hash_md5sum(file=file_downloaded, hash_expected=hash_expected) + elif self.hash_method == 'sha512sum': + hash_expected = self._get_hash(file_path=file_path) + hash_result = self._compute_hash_sh512sum(file=file_downloaded, hash_expected=hash_expected) + if not hash_result: + logger.error('Downloaded file is corrupted, please check your connection') + console.print('❌ Downloaded file is corrupted, please check your connection') + return False + logger.info('Downloaded file is correct.') + console.print('✅ Downloaded file is correct.') + return True + + def provision_eve(self, noztp: bool = False, checksum: bool = True) -> None: + # pylint: disable=unused-argument + """ + provision_eve Entrypoint for EVE-NG download and provisioning + + Do following actions: + - Get remote file path + - Get URL from file path + - Download file + - Convert file to qcow2 format + - Create new version to EVE-NG + - Disable ZTP (optional) + + Parameters + ---------- + noztp : bool, optional + Flag to deactivate ZTP in EOS image, by default False + checksum : bool, optional + Flag to ask for hash validation, by default True + """ + # Build image name to use in folder path + eos_image_name = self.filename.rstrip(".vmdk").lower() + if noztp: + eos_image_name = f'{eos_image_name}-noztp' + # Create full path for EVE-NG + file_path = os.path.join(EVE_QEMU_FOLDER_PATH, eos_image_name.rstrip()) + # Create folders in filesystem + self._create_destination_folder(path=file_path) + + # Download file to local destination + file_downloaded = self._download_file( + file_path=file_path, filename=self.filename) + + # Convert to QCOW2 format + file_qcow2 = os.path.join(file_path, "hda.qcow2") + logger.info('Converting VMDK to QCOW2 format') + console.print('🚀 Converting VMDK to QCOW2 format...') + + os.system(f'$(which qemu-img) convert -f vmdk -O qcow2 {file_downloaded} {file_qcow2}') + + logger.info('Applying unl_wrapper to fix permissions') + console.print('Applying unl_wrapper to fix permissions') + + os.system('/opt/unetlab/wrappers/unl_wrapper -a fixpermissions') + os.system(f'rm -f {file_downloaded}') + + if noztp: + self._disable_ztp(file_path=file_path) + + def docker_import(self, image_name: str = "arista/ceos") -> None: + """ + Import docker container to your docker server. + + Import downloaded container to your local docker engine. + + Args: + version (str): + image_name (str, optional): Image name to use. Defaults to "arista/ceos". + """ + docker_image = f'{image_name}:{self.version}' + logger.info(f'Importing image {self.filename} to {docker_image}') + console.print(f'🚀 Importing image {self.filename} to {docker_image}') + os.system(f'$(which docker) import {self.filename} {docker_image}') + for filename in glob.glob(f'{self.filename}*'): + try: + os.remove(filename) + except FileNotFoundError: + console.print(f'File not found: {filename}') diff --git a/eos_downloader/tools.py b/eos_downloader/tools.py new file mode 100644 index 0000000..a0f971a --- /dev/null +++ b/eos_downloader/tools.py @@ -0,0 +1,13 @@ +#!/usr/bin/python +# coding: utf-8 -*- + +"""Module for tools related to ardl""" + + +def exc_to_str(exception: Exception) -> str: + """ + Helper function to parse Exceptions + """ + return ( + f"{type(exception).__name__}{f' ({str(exception)})' if str(exception) else ''}" + ) |