summaryrefslogtreecommitdiffstats
path: root/eos_downloader/download.py
blob: df3c3815d9a44d5cb86f16549daec08fefd5c4e1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# flake8: noqa: F811
# pylint: disable=unused-argument
# pylint: disable=too-few-public-methods

"""download module"""

import os.path
import signal
from concurrent.futures import ThreadPoolExecutor
from threading import Event
from typing import Any, Iterable

import requests
import rich
from rich import console
from rich.progress import (
    BarColumn,
    DownloadColumn,
    Progress,
    TaskID,
    TextColumn,
    TimeElapsedColumn,
    TransferSpeedColumn,
)

# Shared Rich console used for the progress display built in this module.
# NOTE: this rebinds the name `console` imported from `rich` above, which is
# why F811 is silenced for flake8 at the top of the file.
console = rich.get_console()
# Set by the SIGINT handler; download workers poll it after each chunk and
# stop early once it is set.
done_event = Event()

# HTTP headers sent with every download request issued by this module.
REQUEST_HEADERS = {
    "Content-Type": "application/json",
    "User-Agent": "Chrome/123.0.0.0",
}


def handle_sigint(signum: Any, frame: Any) -> None:
    """
    SIGINT handler: flag that in-flight downloads should stop.

    Both arguments are supplied by the ``signal`` machinery and are ignored;
    the handler only sets the module-level ``done_event``.
    """
    done_event.set()


# Install the handler at import time so that Ctrl-C sets `done_event`
# (letting download workers stop after their current chunk) rather than
# triggering Python's default SIGINT behaviour.
signal.signal(signal.SIGINT, handle_sigint)


class DownloadProgressBar:
    """
    Download files concurrently while rendering one Rich progress bar per file.

    Downloads run in a small thread pool. Each worker streams its URL to disk
    and polls the module-level ``done_event`` so that a SIGINT stops the
    transfer after the current chunk.
    """

    def __init__(self) -> None:
        """
        Build the Rich ``Progress`` display shared by every download task.
        """
        self.progress = Progress(
            TextColumn(
                "💾  Downloading [bold blue]{task.fields[filename]}", justify="right"
            ),
            BarColumn(bar_width=None),
            "[progress.percentage]{task.percentage:>3.1f}%",
            "•",
            TransferSpeedColumn(),
            "•",
            DownloadColumn(),
            "•",
            TimeElapsedColumn(),
            "•",
            console=console,
        )

    def _copy_url(
        self, task_id: TaskID, url: str, path: str, block_size: int = 1024
    ) -> bool:
        """
        Stream the content of ``url`` into the local file ``path``.

        Args:
            task_id: Progress task to update as bytes are written.
            url: Address to download from.
            path: Destination file; opened in binary write mode.
            block_size: Number of bytes requested per streamed chunk.

        Returns:
            True if the transfer was interrupted because ``done_event`` was
            set, False if the whole body was written.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            requests.RequestException: On connection problems or timeouts.
            OSError: If the destination file cannot be opened or written.
        """
        # Use the response as a context manager so the streamed connection is
        # always released, including on early return or error.
        with requests.get(
            url, stream=True, timeout=5, headers=REQUEST_HEADERS
        ) as response:
            # Fail before touching the disk: otherwise the body of a 4xx/5xx
            # answer would be saved as if it were the requested file.
            response.raise_for_status()

            # Content-Length is optional (e.g. chunked transfer). Without a
            # usable value the task keeps whatever total it was created with.
            content_length = response.headers.get("Content-Length", "")
            if content_length.isdigit():
                self.progress.update(task_id, total=int(content_length))

            with open(path, "wb") as dest_file:
                self.progress.start_task(task_id)
                for data in response.iter_content(chunk_size=block_size):
                    dest_file.write(data)
                    self.progress.update(task_id, advance=len(data))
                    if done_event.is_set():
                        # Interrupted: whatever was written so far stays on disk.
                        return True
        return False

    def download(self, urls: Iterable[str], dest_dir: str) -> None:
        """
        Download multiple files to the given directory.

        Each file is named after the last path segment of its URL with any
        query string removed. Up to four downloads run at the same time.
        A failed download is reported on the progress console and does not
        raise, so the remaining downloads still run.

        Args:
            urls: Addresses to download.
            dest_dir: Directory in which the files are written.
        """
        with self.progress:
            with ThreadPoolExecutor(max_workers=4) as pool:
                submitted = []
                for url in urls:
                    filename = url.split("/")[-1].split("?")[0]
                    dest_path = os.path.join(dest_dir, filename)
                    task_id = self.progress.add_task(
                        "download", filename=filename, start=False
                    )
                    future = pool.submit(self._copy_url, task_id, url, dest_path)
                    submitted.append((filename, future))

                # An exception raised inside a worker is stored on its future
                # and is lost unless the result is requested, so surface it.
                for filename, future in submitted:
                    try:
                        future.result()
                    except Exception as error:  # pylint: disable=broad-except
                        self.progress.console.print(
                            f"Download of {filename} failed: {error}",
                            style="bold red",
                            markup=False,
                        )