diff options
Diffstat (limited to 'eos_downloader/object_downloader.py')
-rw-r--r-- | eos_downloader/object_downloader.py | 223 |
1 files changed, 139 insertions, 84 deletions
diff --git a/eos_downloader/object_downloader.py b/eos_downloader/object_downloader.py index 0420acb..d7b1418 100644 --- a/eos_downloader/object_downloader.py +++ b/eos_downloader/object_downloader.py @@ -8,8 +8,13 @@ eos_downloader class definition """ -from __future__ import (absolute_import, division, print_function, - unicode_literals, annotations) +from __future__ import ( + absolute_import, + annotations, + division, + print_function, + unicode_literals, +) import base64 import glob @@ -26,9 +31,14 @@ from loguru import logger from rich import console from tqdm import tqdm -from eos_downloader import (ARISTA_DOWNLOAD_URL, ARISTA_GET_SESSION, - ARISTA_SOFTWARE_FOLDER_TREE, EVE_QEMU_FOLDER_PATH, - MSG_INVALID_DATA, MSG_TOKEN_EXPIRED) +from eos_downloader import ( + ARISTA_DOWNLOAD_URL, + ARISTA_GET_SESSION, + ARISTA_SOFTWARE_FOLDER_TREE, + EVE_QEMU_FOLDER_PATH, + MSG_INVALID_DATA, + MSG_TOKEN_EXPIRED, +) from eos_downloader.data import DATA_MAPPING from eos_downloader.download import DownloadProgressBar @@ -37,11 +47,19 @@ from eos_downloader.download import DownloadProgressBar console = rich.get_console() -class ObjectDownloader(): +class ObjectDownloader: """ ObjectDownloader Generic Object to download from Arista.com """ - def __init__(self, image: str, version: str, token: str, software: str = 'EOS', hash_method: str = 'md5sum'): + + def __init__( + self, + image: str, + version: str, + token: str, + software: str = "EOS", + hash_method: str = "md5sum", + ): """ __init__ Class constructor @@ -70,10 +88,10 @@ class ObjectDownloader(): self.hash_method = hash_method self.timeout = 5 # Logging - logger.debug(f'Filename built by _build_filename is {self.filename}') + logger.debug(f"Filename built by _build_filename is {self.filename}") def __str__(self) -> str: - return f'{self.software} - {self.image} - {self.version}' + return f"{self.software} - {self.image} - {self.version}" # def __repr__(self): # return str(self.__dict__) @@ -102,16 +120,18 @@ class ObjectDownloader(): str: Filename to search for on Arista.com """ - logger.info('start build') + logger.info("start build") if self.software in DATA_MAPPING: - logger.info(f'software in data mapping: {self.software}') + logger.info(f"software in data mapping: {self.software}") if self.image in DATA_MAPPING[self.software]: - logger.info(f'image in data mapping: {self.image}') + logger.info(f"image in data mapping: {self.image}") return f"{DATA_MAPPING[self.software][self.image]['prepend']}-{self.version}{DATA_MAPPING[self.software][self.image]['extension']}" return f"{DATA_MAPPING[self.software]['default']['prepend']}-{self.version}{DATA_MAPPING[self.software]['default']['extension']}" - raise ValueError(f'Incorrect value for software {self.software}') + raise ValueError(f"Incorrect value for software {self.software}") - def _parse_xml_for_path(self, root_xml: ET.ElementTree, xpath: str, search_file: str) -> str: + def _parse_xml_for_path( + self, root_xml: ET.ElementTree, xpath: str, search_file: str + ) -> str: # sourcery skip: remove-unnecessary-cast """ _parse_xml Read and extract data from XML using XPATH @@ -132,18 +152,18 @@ class ObjectDownloader(): str File Path on Arista server side """ - logger.debug(f'Using xpath {xpath}') - logger.debug(f'Search for file {search_file}') - console.print(f'🔎 Searching file {search_file}') + logger.debug(f"Using xpath {xpath}") + logger.debug(f"Search for file {search_file}") + console.print(f"🔎 Searching file {search_file}") for node in root_xml.findall(xpath): # logger.debug('Found {}', node.text) if str(node.text).lower() == search_file.lower(): - path = node.get('path') - console.print(f' -> Found file at {path}') + path = node.get("path") + console.print(f" -> Found file at {path}") logger.info(f'Found {node.text} at {node.get("path")}') - return str(node.get('path')) if node.get('path') is not None else '' - logger.error(f'Requested file ({self.filename}) not found !') - return '' + return str(node.get("path")) if node.get("path") is not None else "" + logger.error(f"Requested file ({self.filename}) not found !") + return "" def _get_hash(self, file_path: str) -> str: """ @@ -165,10 +185,10 @@ class ObjectDownloader(): dl_rich_progress_bar = DownloadProgressBar() dl_rich_progress_bar.download(urls=[hash_url], dest_dir=file_path) hash_downloaded = f"{file_path}/{os.path.basename(remote_hash_file)}" - hash_content = 'unset' - with open(hash_downloaded, 'r', encoding='utf-8') as f: + hash_content = "unset" + with open(hash_downloaded, "r", encoding="utf-8") as f: hash_content = f.read() - return hash_content.split(' ')[0] + return hash_content.split(" ")[0] @staticmethod def _compute_hash_md5sum(file: str, hash_expected: str) -> bool: @@ -195,7 +215,9 @@ class ObjectDownloader(): hash_md5.update(chunk) if hash_md5.hexdigest() == hash_expected: return True - logger.warning(f'Downloaded file is corrupt: local md5 ({hash_md5.hexdigest()}) is different to md5 from arista ({hash_expected})') + logger.warning( + f"Downloaded file is corrupt: local md5 ({hash_md5.hexdigest()}) is different to md5 from arista ({hash_expected})" + ) return False @staticmethod @@ -223,10 +245,12 @@ class ObjectDownloader(): hash_sha512.update(chunk) if hash_sha512.hexdigest() == hash_expected: return True - logger.warning(f'Downloaded file is corrupt: local sha512 ({hash_sha512.hexdigest()}) is different to sha512 from arista ({hash_expected})') + logger.warning( + f"Downloaded file is corrupt: local sha512 ({hash_sha512.hexdigest()}) is different to sha512 from arista ({hash_expected})" + ) return False - def _get_folder_tree(self) -> ET.ElementTree: + def get_folder_tree(self) -> ET.ElementTree: """ _get_folder_tree Download XML tree from Arista server @@ -237,15 +261,17 @@ class ObjectDownloader(): """ if self.session_id is None: self.authenticate() - jsonpost = {'sessionCode': self.session_id} - result = requests.post(ARISTA_SOFTWARE_FOLDER_TREE, data=json.dumps(jsonpost), timeout=self.timeout) + jsonpost = {"sessionCode": self.session_id} + result = requests.post( + ARISTA_SOFTWARE_FOLDER_TREE, data=json.dumps(jsonpost), timeout=self.timeout + ) try: folder_tree = result.json()["data"]["xml"] return ET.ElementTree(ET.fromstring(folder_tree)) except KeyError as error: logger.error(MSG_INVALID_DATA) - logger.error(f'Server returned: {error}') - console.print(f'❌ {MSG_INVALID_DATA}', style="bold red") + logger.error(f"Server returned: {error}") + console.print(f"❌ {MSG_INVALID_DATA}", style="bold red") sys.exit(1) def _get_remote_filepath(self) -> str: @@ -259,12 +285,14 @@ class ObjectDownloader(): str Remote path of the file to download """ - root = self._get_folder_tree() + root = self.get_folder_tree() logger.debug("GET XML content from ARISTA.com") xpath = f'.//dir[@label="{self.software}"]//file' - return self._parse_xml_for_path(root_xml=root, xpath=xpath, search_file=self.filename) + return self._parse_xml_for_path( + root_xml=root, xpath=xpath, search_file=self.filename + ) - def _get_remote_hashpath(self, hash_method: str = 'md5sum') -> str: + def _get_remote_hashpath(self, hash_method: str = "md5sum") -> str: """ _get_remote_hashpath Helper to get path of the hash's file to download @@ -275,16 +303,16 @@ class ObjectDownloader(): str Remote path of the hash's file to download """ - root = self._get_folder_tree() + root = self.get_folder_tree() logger.debug("GET XML content from ARISTA.com") xpath = f'.//dir[@label="{self.software}"]//file' return self._parse_xml_for_path( root_xml=root, xpath=xpath, - search_file=f'{self.filename}.{hash_method}', + search_file=f"{self.filename}.{hash_method}", ) - def _get_url(self, remote_file_path: str) -> str: + def _get_url(self, remote_file_path: str) -> str: """ _get_url Get URL to use for downloading file from Arista server @@ -302,13 +330,15 @@ class ObjectDownloader(): """ if self.session_id is None: self.authenticate() - jsonpost = {'sessionCode': self.session_id, 'filePath': remote_file_path} - result = requests.post(ARISTA_DOWNLOAD_URL, data=json.dumps(jsonpost), timeout=self.timeout) - if 'data' in result.json() and 'url' in result.json()['data']: + jsonpost = {"sessionCode": self.session_id, "filePath": remote_file_path} + result = requests.post( + ARISTA_DOWNLOAD_URL, data=json.dumps(jsonpost), timeout=self.timeout + ) + if "data" in result.json() and "url" in result.json()["data"]: # logger.debug('URL to download file is: {}', result.json()) return result.json()["data"]["url"] - logger.critical(f'Server returns following message: {result.json()}') - return '' + logger.critical(f"Server returns following message: {result.json()}") + return "" @staticmethod def _download_file_raw(url: str, file_path: str) -> str: @@ -331,31 +361,40 @@ class ObjectDownloader(): """ chunkSize = 1024 r = requests.get(url, stream=True, timeout=5) - with open(file_path, 'wb') as f: - pbar = tqdm(unit="B", total=int(r.headers['Content-Length']), unit_scale=True, unit_divisor=1024) + with open(file_path, "wb") as f: + pbar = tqdm( + unit="B", + total=int(r.headers["Content-Length"]), + unit_scale=True, + unit_divisor=1024, + ) for chunk in r.iter_content(chunk_size=chunkSize): if chunk: pbar.update(len(chunk)) f.write(chunk) return file_path - def _download_file(self, file_path: str, filename: str, rich_interface: bool = True) -> Union[None, str]: + def _download_file( + self, file_path: str, filename: str, rich_interface: bool = True + ) -> Union[None, str]: remote_file_path = self._get_remote_filepath() - logger.info(f'File found on arista server: {remote_file_path}') + logger.info(f"File found on arista server: {remote_file_path}") file_url = self._get_url(remote_file_path=remote_file_path) if file_url is not False: if not rich_interface: - return self._download_file_raw(url=file_url, file_path=os.path.join(file_path, filename)) + return self._download_file_raw( + url=file_url, file_path=os.path.join(file_path, filename) + ) rich_downloader = DownloadProgressBar() rich_downloader.download(urls=[file_url], dest_dir=file_path) return os.path.join(file_path, filename) - logger.error(f'Cannot download file {file_path}') + logger.error(f"Cannot download file {file_path}") return None @staticmethod def _create_destination_folder(path: str) -> None: # os.makedirs(path, mode, exist_ok=True) - os.system(f'mkdir -p {path}') + os.system(f"mkdir -p {path}") @staticmethod def _disable_ztp(file_path: str) -> None: @@ -379,24 +418,29 @@ class ObjectDownloader(): """ credentials = (base64.b64encode(self.token.encode())).decode("utf-8") session_code_url = ARISTA_GET_SESSION - jsonpost = {'accessToken': credentials} + jsonpost = {"accessToken": credentials} - result = requests.post(session_code_url, data=json.dumps(jsonpost), timeout=self.timeout) + result = requests.post( + session_code_url, data=json.dumps(jsonpost), timeout=self.timeout + ) - if result.json()["status"]["message"] in[ 'Access token expired', 'Invalid access token']: - console.print(f'❌ {MSG_TOKEN_EXPIRED}', style="bold red") + if result.json()["status"]["message"] in [ + "Access token expired", + "Invalid access token", + ]: + console.print(f"❌ {MSG_TOKEN_EXPIRED}", style="bold red") logger.error(MSG_TOKEN_EXPIRED) return False try: - if 'data' in result.json(): + if "data" in result.json(): self.session_id = result.json()["data"]["session_code"] - logger.info('Authenticated on arista.com') + logger.info("Authenticated on arista.com") return True - logger.debug(f'{result.json()}') + logger.debug(f"{result.json()}") return False except KeyError as error_arista: - logger.error(f'Error: {error_arista}') + logger.error(f"Error: {error_arista}") sys.exit(1) def download_local(self, file_path: str, checksum: bool = False) -> bool: @@ -422,25 +466,33 @@ class ObjectDownloader(): bool True if everything went well, False if any problem appears """ - file_downloaded = str(self._download_file(file_path=file_path, filename=self.filename)) + file_downloaded = str( + self._download_file(file_path=file_path, filename=self.filename) + ) # Check file HASH hash_result = False if checksum: - logger.info('🚀 Running checksum validation') - console.print('🚀 Running checksum validation') - if self.hash_method == 'md5sum': + logger.info("🚀 Running checksum validation") + console.print("🚀 Running checksum validation") + if self.hash_method == "md5sum": hash_expected = self._get_hash(file_path=file_path) - hash_result = self._compute_hash_md5sum(file=file_downloaded, hash_expected=hash_expected) - elif self.hash_method == 'sha512sum': + hash_result = self._compute_hash_md5sum( + file=file_downloaded, hash_expected=hash_expected + ) + elif self.hash_method == "sha512sum": hash_expected = self._get_hash(file_path=file_path) - hash_result = self._compute_hash_sh512sum(file=file_downloaded, hash_expected=hash_expected) + hash_result = self._compute_hash_sh512sum( + file=file_downloaded, hash_expected=hash_expected + ) if not hash_result: - logger.error('Downloaded file is corrupted, please check your connection') - console.print('❌ Downloaded file is corrupted, please check your connection') + logger.error("Downloaded file is corrupted, please check your connection") + console.print( + "❌ Downloaded file is corrupted, please check your connection" + ) return False - logger.info('Downloaded file is correct.') - console.print('✅ Downloaded file is correct.') + logger.info("Downloaded file is correct.") + console.print("✅ Downloaded file is correct.") return True def provision_eve(self, noztp: bool = False, checksum: bool = True) -> None: @@ -466,7 +518,7 @@ class ObjectDownloader(): # Build image name to use in folder path eos_image_name = self.filename.rstrip(".vmdk").lower() if noztp: - eos_image_name = f'{eos_image_name}-noztp' + eos_image_name = f"{eos_image_name}-noztp" # Create full path for EVE-NG file_path = os.path.join(EVE_QEMU_FOLDER_PATH, eos_image_name.rstrip()) # Create folders in filesystem @@ -474,20 +526,23 @@ class ObjectDownloader(): # Download file to local destination file_downloaded = self._download_file( - file_path=file_path, filename=self.filename) + file_path=file_path, filename=self.filename + ) # Convert to QCOW2 format file_qcow2 = os.path.join(file_path, "hda.qcow2") - logger.info('Converting VMDK to QCOW2 format') - console.print('🚀 Converting VMDK to QCOW2 format...') + logger.info("Converting VMDK to QCOW2 format") + console.print("🚀 Converting VMDK to QCOW2 format...") - os.system(f'$(which qemu-img) convert -f vmdk -O qcow2 {file_downloaded} {file_qcow2}') + os.system( + f"$(which qemu-img) convert -f vmdk -O qcow2 {file_downloaded} {file_qcow2}" + ) - logger.info('Applying unl_wrapper to fix permissions') - console.print('Applying unl_wrapper to fix permissions') + logger.info("Applying unl_wrapper to fix permissions") + console.print("Applying unl_wrapper to fix permissions") - os.system('/opt/unetlab/wrappers/unl_wrapper -a fixpermissions') - os.system(f'rm -f {file_downloaded}') + os.system("/opt/unetlab/wrappers/unl_wrapper -a fixpermissions") + os.system(f"rm -f {file_downloaded}") if noztp: self._disable_ztp(file_path=file_path) @@ -502,12 +557,12 @@ class ObjectDownloader(): version (str): image_name (str, optional): Image name to use. Defaults to "arista/ceos". """ - docker_image = f'{image_name}:{self.version}' - logger.info(f'Importing image {self.filename} to {docker_image}') - console.print(f'🚀 Importing image {self.filename} to {docker_image}') - os.system(f'$(which docker) import {self.filename} {docker_image}') - for filename in glob.glob(f'{self.filename}*'): + docker_image = f"{image_name}:{self.version}" + logger.info(f"Importing image {self.filename} to {docker_image}") + console.print(f"🚀 Importing image {self.filename} to {docker_image}") + os.system(f"$(which docker) import {self.filename} {docker_image}") + for filename in glob.glob(f"{self.filename}*"): try: os.remove(filename) except FileNotFoundError: - console.print(f'File not found: {filename}') + console.print(f"File not found: {filename}") |