summaryrefslogtreecommitdiffstats
path: root/eos_downloader/object_downloader.py
diff options
context:
space:
mode:
Diffstat (limited to 'eos_downloader/object_downloader.py')
-rw-r--r--eos_downloader/object_downloader.py223
1 files changed, 139 insertions, 84 deletions
diff --git a/eos_downloader/object_downloader.py b/eos_downloader/object_downloader.py
index 0420acb..d7b1418 100644
--- a/eos_downloader/object_downloader.py
+++ b/eos_downloader/object_downloader.py
@@ -8,8 +8,13 @@
eos_downloader class definition
"""
-from __future__ import (absolute_import, division, print_function,
- unicode_literals, annotations)
+from __future__ import (
+ absolute_import,
+ annotations,
+ division,
+ print_function,
+ unicode_literals,
+)
import base64
import glob
@@ -26,9 +31,14 @@ from loguru import logger
from rich import console
from tqdm import tqdm
-from eos_downloader import (ARISTA_DOWNLOAD_URL, ARISTA_GET_SESSION,
- ARISTA_SOFTWARE_FOLDER_TREE, EVE_QEMU_FOLDER_PATH,
- MSG_INVALID_DATA, MSG_TOKEN_EXPIRED)
+from eos_downloader import (
+ ARISTA_DOWNLOAD_URL,
+ ARISTA_GET_SESSION,
+ ARISTA_SOFTWARE_FOLDER_TREE,
+ EVE_QEMU_FOLDER_PATH,
+ MSG_INVALID_DATA,
+ MSG_TOKEN_EXPIRED,
+)
from eos_downloader.data import DATA_MAPPING
from eos_downloader.download import DownloadProgressBar
@@ -37,11 +47,19 @@ from eos_downloader.download import DownloadProgressBar
console = rich.get_console()
-class ObjectDownloader():
+class ObjectDownloader:
"""
ObjectDownloader Generic Object to download from Arista.com
"""
- def __init__(self, image: str, version: str, token: str, software: str = 'EOS', hash_method: str = 'md5sum'):
+
+ def __init__(
+ self,
+ image: str,
+ version: str,
+ token: str,
+ software: str = "EOS",
+ hash_method: str = "md5sum",
+ ):
"""
__init__ Class constructor
@@ -70,10 +88,10 @@ class ObjectDownloader():
self.hash_method = hash_method
self.timeout = 5
# Logging
- logger.debug(f'Filename built by _build_filename is {self.filename}')
+ logger.debug(f"Filename built by _build_filename is {self.filename}")
def __str__(self) -> str:
- return f'{self.software} - {self.image} - {self.version}'
+ return f"{self.software} - {self.image} - {self.version}"
# def __repr__(self):
# return str(self.__dict__)
@@ -102,16 +120,18 @@ class ObjectDownloader():
str:
Filename to search for on Arista.com
"""
- logger.info('start build')
+ logger.info("start build")
if self.software in DATA_MAPPING:
- logger.info(f'software in data mapping: {self.software}')
+ logger.info(f"software in data mapping: {self.software}")
if self.image in DATA_MAPPING[self.software]:
- logger.info(f'image in data mapping: {self.image}')
+ logger.info(f"image in data mapping: {self.image}")
return f"{DATA_MAPPING[self.software][self.image]['prepend']}-{self.version}{DATA_MAPPING[self.software][self.image]['extension']}"
return f"{DATA_MAPPING[self.software]['default']['prepend']}-{self.version}{DATA_MAPPING[self.software]['default']['extension']}"
- raise ValueError(f'Incorrect value for software {self.software}')
+ raise ValueError(f"Incorrect value for software {self.software}")
- def _parse_xml_for_path(self, root_xml: ET.ElementTree, xpath: str, search_file: str) -> str:
+ def _parse_xml_for_path(
+ self, root_xml: ET.ElementTree, xpath: str, search_file: str
+ ) -> str:
# sourcery skip: remove-unnecessary-cast
"""
_parse_xml Read and extract data from XML using XPATH
@@ -132,18 +152,18 @@ class ObjectDownloader():
str
File Path on Arista server side
"""
- logger.debug(f'Using xpath {xpath}')
- logger.debug(f'Search for file {search_file}')
- console.print(f'🔎 Searching file {search_file}')
+ logger.debug(f"Using xpath {xpath}")
+ logger.debug(f"Search for file {search_file}")
+ console.print(f"🔎 Searching file {search_file}")
for node in root_xml.findall(xpath):
# logger.debug('Found {}', node.text)
if str(node.text).lower() == search_file.lower():
- path = node.get('path')
- console.print(f' -> Found file at {path}')
+ path = node.get("path")
+ console.print(f" -> Found file at {path}")
logger.info(f'Found {node.text} at {node.get("path")}')
- return str(node.get('path')) if node.get('path') is not None else ''
- logger.error(f'Requested file ({self.filename}) not found !')
- return ''
+ return str(node.get("path")) if node.get("path") is not None else ""
+ logger.error(f"Requested file ({self.filename}) not found !")
+ return ""
def _get_hash(self, file_path: str) -> str:
"""
@@ -165,10 +185,10 @@ class ObjectDownloader():
dl_rich_progress_bar = DownloadProgressBar()
dl_rich_progress_bar.download(urls=[hash_url], dest_dir=file_path)
hash_downloaded = f"{file_path}/{os.path.basename(remote_hash_file)}"
- hash_content = 'unset'
- with open(hash_downloaded, 'r', encoding='utf-8') as f:
+ hash_content = "unset"
+ with open(hash_downloaded, "r", encoding="utf-8") as f:
hash_content = f.read()
- return hash_content.split(' ')[0]
+ return hash_content.split(" ")[0]
@staticmethod
def _compute_hash_md5sum(file: str, hash_expected: str) -> bool:
@@ -195,7 +215,9 @@ class ObjectDownloader():
hash_md5.update(chunk)
if hash_md5.hexdigest() == hash_expected:
return True
- logger.warning(f'Downloaded file is corrupt: local md5 ({hash_md5.hexdigest()}) is different to md5 from arista ({hash_expected})')
+ logger.warning(
+ f"Downloaded file is corrupt: local md5 ({hash_md5.hexdigest()}) is different to md5 from arista ({hash_expected})"
+ )
return False
@staticmethod
@@ -223,10 +245,12 @@ class ObjectDownloader():
hash_sha512.update(chunk)
if hash_sha512.hexdigest() == hash_expected:
return True
- logger.warning(f'Downloaded file is corrupt: local sha512 ({hash_sha512.hexdigest()}) is different to sha512 from arista ({hash_expected})')
+ logger.warning(
+ f"Downloaded file is corrupt: local sha512 ({hash_sha512.hexdigest()}) is different to sha512 from arista ({hash_expected})"
+ )
return False
- def _get_folder_tree(self) -> ET.ElementTree:
+ def get_folder_tree(self) -> ET.ElementTree:
"""
_get_folder_tree Download XML tree from Arista server
@@ -237,15 +261,17 @@ class ObjectDownloader():
"""
if self.session_id is None:
self.authenticate()
- jsonpost = {'sessionCode': self.session_id}
- result = requests.post(ARISTA_SOFTWARE_FOLDER_TREE, data=json.dumps(jsonpost), timeout=self.timeout)
+ jsonpost = {"sessionCode": self.session_id}
+ result = requests.post(
+ ARISTA_SOFTWARE_FOLDER_TREE, data=json.dumps(jsonpost), timeout=self.timeout
+ )
try:
folder_tree = result.json()["data"]["xml"]
return ET.ElementTree(ET.fromstring(folder_tree))
except KeyError as error:
logger.error(MSG_INVALID_DATA)
- logger.error(f'Server returned: {error}')
- console.print(f'❌ {MSG_INVALID_DATA}', style="bold red")
+ logger.error(f"Server returned: {error}")
+ console.print(f"❌ {MSG_INVALID_DATA}", style="bold red")
sys.exit(1)
def _get_remote_filepath(self) -> str:
@@ -259,12 +285,14 @@ class ObjectDownloader():
str
Remote path of the file to download
"""
- root = self._get_folder_tree()
+ root = self.get_folder_tree()
logger.debug("GET XML content from ARISTA.com")
xpath = f'.//dir[@label="{self.software}"]//file'
- return self._parse_xml_for_path(root_xml=root, xpath=xpath, search_file=self.filename)
+ return self._parse_xml_for_path(
+ root_xml=root, xpath=xpath, search_file=self.filename
+ )
- def _get_remote_hashpath(self, hash_method: str = 'md5sum') -> str:
+ def _get_remote_hashpath(self, hash_method: str = "md5sum") -> str:
"""
_get_remote_hashpath Helper to get path of the hash's file to download
@@ -275,16 +303,16 @@ class ObjectDownloader():
str
Remote path of the hash's file to download
"""
- root = self._get_folder_tree()
+ root = self.get_folder_tree()
logger.debug("GET XML content from ARISTA.com")
xpath = f'.//dir[@label="{self.software}"]//file'
return self._parse_xml_for_path(
root_xml=root,
xpath=xpath,
- search_file=f'{self.filename}.{hash_method}',
+ search_file=f"{self.filename}.{hash_method}",
)
- def _get_url(self, remote_file_path: str) -> str:
+ def _get_url(self, remote_file_path: str) -> str:
"""
_get_url Get URL to use for downloading file from Arista server
@@ -302,13 +330,15 @@ class ObjectDownloader():
"""
if self.session_id is None:
self.authenticate()
- jsonpost = {'sessionCode': self.session_id, 'filePath': remote_file_path}
- result = requests.post(ARISTA_DOWNLOAD_URL, data=json.dumps(jsonpost), timeout=self.timeout)
- if 'data' in result.json() and 'url' in result.json()['data']:
+ jsonpost = {"sessionCode": self.session_id, "filePath": remote_file_path}
+ result = requests.post(
+ ARISTA_DOWNLOAD_URL, data=json.dumps(jsonpost), timeout=self.timeout
+ )
+ if "data" in result.json() and "url" in result.json()["data"]:
# logger.debug('URL to download file is: {}', result.json())
return result.json()["data"]["url"]
- logger.critical(f'Server returns following message: {result.json()}')
- return ''
+ logger.critical(f"Server returns following message: {result.json()}")
+ return ""
@staticmethod
def _download_file_raw(url: str, file_path: str) -> str:
@@ -331,31 +361,40 @@ class ObjectDownloader():
"""
chunkSize = 1024
r = requests.get(url, stream=True, timeout=5)
- with open(file_path, 'wb') as f:
- pbar = tqdm(unit="B", total=int(r.headers['Content-Length']), unit_scale=True, unit_divisor=1024)
+ with open(file_path, "wb") as f:
+ pbar = tqdm(
+ unit="B",
+ total=int(r.headers["Content-Length"]),
+ unit_scale=True,
+ unit_divisor=1024,
+ )
for chunk in r.iter_content(chunk_size=chunkSize):
if chunk:
pbar.update(len(chunk))
f.write(chunk)
return file_path
- def _download_file(self, file_path: str, filename: str, rich_interface: bool = True) -> Union[None, str]:
+ def _download_file(
+ self, file_path: str, filename: str, rich_interface: bool = True
+ ) -> Union[None, str]:
remote_file_path = self._get_remote_filepath()
- logger.info(f'File found on arista server: {remote_file_path}')
+ logger.info(f"File found on arista server: {remote_file_path}")
file_url = self._get_url(remote_file_path=remote_file_path)
if file_url is not False:
if not rich_interface:
- return self._download_file_raw(url=file_url, file_path=os.path.join(file_path, filename))
+ return self._download_file_raw(
+ url=file_url, file_path=os.path.join(file_path, filename)
+ )
rich_downloader = DownloadProgressBar()
rich_downloader.download(urls=[file_url], dest_dir=file_path)
return os.path.join(file_path, filename)
- logger.error(f'Cannot download file {file_path}')
+ logger.error(f"Cannot download file {file_path}")
return None
@staticmethod
def _create_destination_folder(path: str) -> None:
# os.makedirs(path, mode, exist_ok=True)
- os.system(f'mkdir -p {path}')
+ os.system(f"mkdir -p {path}")
@staticmethod
def _disable_ztp(file_path: str) -> None:
@@ -379,24 +418,29 @@ class ObjectDownloader():
"""
credentials = (base64.b64encode(self.token.encode())).decode("utf-8")
session_code_url = ARISTA_GET_SESSION
- jsonpost = {'accessToken': credentials}
+ jsonpost = {"accessToken": credentials}
- result = requests.post(session_code_url, data=json.dumps(jsonpost), timeout=self.timeout)
+ result = requests.post(
+ session_code_url, data=json.dumps(jsonpost), timeout=self.timeout
+ )
- if result.json()["status"]["message"] in[ 'Access token expired', 'Invalid access token']:
- console.print(f'❌ {MSG_TOKEN_EXPIRED}', style="bold red")
+ if result.json()["status"]["message"] in [
+ "Access token expired",
+ "Invalid access token",
+ ]:
+ console.print(f"❌ {MSG_TOKEN_EXPIRED}", style="bold red")
logger.error(MSG_TOKEN_EXPIRED)
return False
try:
- if 'data' in result.json():
+ if "data" in result.json():
self.session_id = result.json()["data"]["session_code"]
- logger.info('Authenticated on arista.com')
+ logger.info("Authenticated on arista.com")
return True
- logger.debug(f'{result.json()}')
+ logger.debug(f"{result.json()}")
return False
except KeyError as error_arista:
- logger.error(f'Error: {error_arista}')
+ logger.error(f"Error: {error_arista}")
sys.exit(1)
def download_local(self, file_path: str, checksum: bool = False) -> bool:
@@ -422,25 +466,33 @@ class ObjectDownloader():
bool
True if everything went well, False if any problem appears
"""
- file_downloaded = str(self._download_file(file_path=file_path, filename=self.filename))
+ file_downloaded = str(
+ self._download_file(file_path=file_path, filename=self.filename)
+ )
# Check file HASH
hash_result = False
if checksum:
- logger.info('🚀 Running checksum validation')
- console.print('🚀 Running checksum validation')
- if self.hash_method == 'md5sum':
+ logger.info("🚀 Running checksum validation")
+ console.print("🚀 Running checksum validation")
+ if self.hash_method == "md5sum":
hash_expected = self._get_hash(file_path=file_path)
- hash_result = self._compute_hash_md5sum(file=file_downloaded, hash_expected=hash_expected)
- elif self.hash_method == 'sha512sum':
+ hash_result = self._compute_hash_md5sum(
+ file=file_downloaded, hash_expected=hash_expected
+ )
+ elif self.hash_method == "sha512sum":
hash_expected = self._get_hash(file_path=file_path)
- hash_result = self._compute_hash_sh512sum(file=file_downloaded, hash_expected=hash_expected)
+ hash_result = self._compute_hash_sh512sum(
+ file=file_downloaded, hash_expected=hash_expected
+ )
if not hash_result:
- logger.error('Downloaded file is corrupted, please check your connection')
- console.print('❌ Downloaded file is corrupted, please check your connection')
+ logger.error("Downloaded file is corrupted, please check your connection")
+ console.print(
+ "❌ Downloaded file is corrupted, please check your connection"
+ )
return False
- logger.info('Downloaded file is correct.')
- console.print('✅ Downloaded file is correct.')
+ logger.info("Downloaded file is correct.")
+ console.print("✅ Downloaded file is correct.")
return True
def provision_eve(self, noztp: bool = False, checksum: bool = True) -> None:
@@ -466,7 +518,7 @@ class ObjectDownloader():
# Build image name to use in folder path
eos_image_name = self.filename.rstrip(".vmdk").lower()
if noztp:
- eos_image_name = f'{eos_image_name}-noztp'
+ eos_image_name = f"{eos_image_name}-noztp"
# Create full path for EVE-NG
file_path = os.path.join(EVE_QEMU_FOLDER_PATH, eos_image_name.rstrip())
# Create folders in filesystem
@@ -474,20 +526,23 @@ class ObjectDownloader():
# Download file to local destination
file_downloaded = self._download_file(
- file_path=file_path, filename=self.filename)
+ file_path=file_path, filename=self.filename
+ )
# Convert to QCOW2 format
file_qcow2 = os.path.join(file_path, "hda.qcow2")
- logger.info('Converting VMDK to QCOW2 format')
- console.print('🚀 Converting VMDK to QCOW2 format...')
+ logger.info("Converting VMDK to QCOW2 format")
+ console.print("🚀 Converting VMDK to QCOW2 format...")
- os.system(f'$(which qemu-img) convert -f vmdk -O qcow2 {file_downloaded} {file_qcow2}')
+ os.system(
+ f"$(which qemu-img) convert -f vmdk -O qcow2 {file_downloaded} {file_qcow2}"
+ )
- logger.info('Applying unl_wrapper to fix permissions')
- console.print('Applying unl_wrapper to fix permissions')
+ logger.info("Applying unl_wrapper to fix permissions")
+ console.print("Applying unl_wrapper to fix permissions")
- os.system('/opt/unetlab/wrappers/unl_wrapper -a fixpermissions')
- os.system(f'rm -f {file_downloaded}')
+ os.system("/opt/unetlab/wrappers/unl_wrapper -a fixpermissions")
+ os.system(f"rm -f {file_downloaded}")
if noztp:
self._disable_ztp(file_path=file_path)
@@ -502,12 +557,12 @@ class ObjectDownloader():
version (str):
image_name (str, optional): Image name to use. Defaults to "arista/ceos".
"""
- docker_image = f'{image_name}:{self.version}'
- logger.info(f'Importing image {self.filename} to {docker_image}')
- console.print(f'🚀 Importing image {self.filename} to {docker_image}')
- os.system(f'$(which docker) import {self.filename} {docker_image}')
- for filename in glob.glob(f'{self.filename}*'):
+ docker_image = f"{image_name}:{self.version}"
+ logger.info(f"Importing image {self.filename} to {docker_image}")
+ console.print(f"🚀 Importing image {self.filename} to {docker_image}")
+ os.system(f"$(which docker) import {self.filename} {docker_image}")
+ for filename in glob.glob(f"{self.filename}*"):
try:
os.remove(filename)
except FileNotFoundError:
- console.print(f'File not found: {filename}')
+ console.print(f"File not found: {filename}")