#!/usr/bin/python3 -u
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

import argparse
import bz2
import concurrent.futures
import contextlib
import datetime
import gzip
import hashlib
import json
import lzma
import multiprocessing
import os
import pathlib
import random
import re
import ssl
import stat
import subprocess
import sys
import tarfile
import tempfile
import time
import urllib.parse
import urllib.request
import zipfile

try:
    import zstandard
except ImportError:
    zstandard = None

try:
    import certifi
except ImportError:
    certifi = None


# Size of the thread pool used by fetch_urls(): one worker per CPU as
# reported by multiprocessing.cpu_count().
CONCURRENCY = multiprocessing.cpu_count()


def log(msg):
    """Write ``msg`` followed by a newline to stderr and flush immediately."""
    print(msg, file=sys.stderr, flush=True)


class IntegrityError(Exception):
    """Represents an integrity error when downloading a URL.

    Raised by ``stream_download()`` when the downloaded content does not
    match the expected size or SHA-256 digest. The download helpers in this
    file re-raise it immediately instead of retrying.
    """


def ZstdCompressor(*args, **kwargs):
    """Build a ``zstandard.ZstdCompressor``, forwarding all arguments.

    Raises ``ValueError`` when the optional ``zstandard`` package could not
    be imported.
    """
    if zstandard:
        return zstandard.ZstdCompressor(*args, **kwargs)
    raise ValueError("zstandard Python package not available")


def ZstdDecompressor(*args, **kwargs):
    """Build a ``zstandard.ZstdDecompressor``, forwarding all arguments.

    Raises ``ValueError`` when the optional ``zstandard`` package could not
    be imported.
    """
    if zstandard:
        return zstandard.ZstdDecompressor(*args, **kwargs)
    raise ValueError("zstandard Python package not available")


@contextlib.contextmanager
def rename_after_close(fname, *args, **kwargs):
    """
    Context manager that opens a temporary file to use as a writer,
    and closes the file on context exit, renaming it to the expected
    file name in case of success, or removing it in case of failure.

    The temporary file lives next to ``fname`` and is named
    ``<name>.tmp``. Takes the same options as open(), but must be used as
    a context manager.

    Any exception raised by the body (including ``KeyboardInterrupt`` and
    ``SystemExit``) is propagated unchanged after the temporary file has
    been removed.
    """
    path = pathlib.Path(fname)
    tmp = path.with_name("%s.tmp" % path.name)
    try:
        with tmp.open(*args, **kwargs) as fh:
            yield fh
    except BaseException:
        # BaseException so that an interrupted write does not leave a stale
        # .tmp file behind. The file may not exist at all (e.g. open()
        # itself failed); do not let that mask the original error.
        try:
            tmp.unlink()
        except FileNotFoundError:
            pass
        raise
    else:
        # replace() overwrites an existing destination on every platform,
        # whereas rename() fails on Windows if the destination exists.
        tmp.replace(fname)


# The following is copied from
# https://github.com/mozilla-releng/redo/blob/6d07678a014e0c525e54a860381a165d34db10ff/redo/__init__.py#L15-L85
def retrier(attempts=5, sleeptime=10, max_sleeptime=300, sleepscale=1.5, jitter=1):
    """
    Generator that paces a retry loop with exponential backoff and jitter.

    The caller performs the action to be retried in the body of a ``for``
    loop over this generator and breaks out on success. Between iterations
    the generator sleeps; no sleep happens after the final attempt.

    After each attempt the next sleep duration is computed as
    sleeptime + random.randint(-jitter, jitter), capped at max_sleeptime.
    Both sleeptime and jitter are then scaled by sleepscale for the
    following iteration.

    Args:
        attempts (int): maximum number of times to try; defaults to 5
        sleeptime (float): how many seconds to sleep between tries; defaults
                           to 10s
        max_sleeptime (float): the longest we'll sleep, in seconds; defaults to
                               300s (five minutes)
        sleepscale (float): how much to multiply the sleep time by each
                            iteration; defaults to 1.5
        jitter (int): random jitter to introduce to sleep time each iteration.
                      the amount is chosen at random between [-jitter, +jitter]
                      defaults to 1

    Yields:
        The most recently computed sleep duration (the initial ``sleeptime``
        on the first iteration), a maximum of `attempts` number of times

    Raises:
        Exception: on first iteration, if ``jitter`` is greater than
            ``sleeptime`` (which could produce negative sleep times)

    Example:
        >>> n = 0
        >>> for _ in retrier(sleeptime=0, jitter=0):
        ...     if n == 3:
        ...         # We did the thing!
        ...         break
        ...     n += 1
        >>> n
        3

        >>> n = 0
        >>> for _ in retrier(sleeptime=0, jitter=0):
        ...     if n == 6:
        ...         # We did the thing!
        ...         break
        ...     n += 1
        ... else:
        ...     print("max tries hit")
        max tries hit
    """
    # A jitter of None is treated as "no jitter" so the comparison below works.
    jitter = jitter or 0
    if jitter > sleeptime:
        # A jitter larger than the base delay could yield a negative sleep.
        raise Exception(
            "jitter ({}) must be less than sleep time ({})".format(jitter, sleeptime)
        )

    pending_sleep = sleeptime
    final_index = attempts - 1
    for attempt in range(attempts):
        log("attempt %i/%i" % (attempt + 1, attempts))

        yield pending_sleep

        if jitter:
            pending_sleep = sleeptime + random.randint(-jitter, jitter)
            # Keep the jitter proportional to the growing base delay.
            jitter = int(jitter * sleepscale)
        else:
            pending_sleep = sleeptime

        # Clamp this round's delay, then grow the base for the next round.
        pending_sleep = min(pending_sleep, max_sleeptime)
        sleeptime *= sleepscale

        # Nothing follows the last attempt, so there is no point in waiting.
        if attempt < final_index:
            log(
                "sleeping for %.2fs (attempt %i/%i)"
                % (pending_sleep, attempt + 1, attempts)
            )
            time.sleep(pending_sleep)


def stream_download(url, sha256=None, size=None):
    """Download a URL to a generator, optionally with content verification.

    Yields the response body in chunks of up to 64 KiB. If the URL does not
    end in ``.gz`` but the server reports ``Content-Encoding: gzip``, the
    body is transparently decompressed before being yielded (and hashed).

    If ``sha256`` or ``size`` are defined, the downloaded URL will be
    validated against those requirements and ``IntegrityError`` will be
    raised if expectations do not match.

    Because verification cannot occur until the file is completely downloaded
    it is recommended for consumers to not do anything meaningful with the
    data if content verification is being used. To securely handle retrieved
    content, it should be streamed to a file or memory and only operated
    on after the generator is exhausted without raising.

    Args:
        url (str): URL to fetch.
        sha256 (str): expected hex SHA-256 digest of the yielded content;
            skipped when falsy.
        size (int): expected number of yielded bytes; skipped when falsy.

    Raises:
        IntegrityError: after the last chunk, if size or digest mismatch.
    """
    log("Downloading %s" % url)

    h = hashlib.sha256()
    length = 0

    # urlopen()'s ``cafile`` argument was deprecated in Python 3.6 and
    # removed in 3.13; an explicit SSL context built from the certifi bundle
    # is the supported way to get the same behavior. Without certifi,
    # context=None keeps urlopen()'s default certificate handling.
    if certifi:
        context = ssl.create_default_context(cafile=certifi.where())
    else:
        context = None

    t0 = time.time()
    with urllib.request.urlopen(url, timeout=60, context=context) as fh:
        if not url.endswith(".gz") and fh.info().get("Content-Encoding") == "gzip":
            fh = gzip.GzipFile(fileobj=fh)

        while True:
            chunk = fh.read(65536)
            if not chunk:
                break

            h.update(chunk)
            length += len(chunk)

            yield chunk

    duration = time.time() - t0
    digest = h.hexdigest()

    log(
        "%s resolved to %d bytes with sha256 %s in %.3fs"
        % (url, length, digest, duration)
    )

    if size:
        if size == length:
            log("Verified size of %s" % url)
        else:
            raise IntegrityError(
                "size mismatch on %s: wanted %d; got %d" % (url, size, length)
            )

    if sha256:
        if digest == sha256:
            log("Verified sha256 integrity of %s" % url)
        else:
            raise IntegrityError(
                "sha256 mismatch on %s: wanted %s; got %s" % (url, sha256, digest)
            )


def download_to_path(url, path, sha256=None, size=None):
    """Download ``url`` into the file at ``path``, optionally verifying it.

    Up to 5 attempts are made, paced by ``retrier``. An ``IntegrityError``
    is never retried; any other exception is logged and triggers the next
    attempt. Raises a generic ``Exception`` once all attempts are used up.
    """

    # Drop any stale copy first. The payload is written to a temporary file
    # that is only renamed into place on success, so ``path`` never holds
    # partially written or unverified data.
    with contextlib.suppress(FileNotFoundError):
        path.unlink()

    for _ in retrier(attempts=5, sleeptime=60):
        try:
            log("Downloading %s to %s" % (url, path))

            with rename_after_close(path, "wb") as out:
                for piece in stream_download(url, sha256=sha256, size=size):
                    out.write(piece)
        except IntegrityError:
            raise
        except Exception as err:
            log("Download failed: {}".format(err))
        else:
            return

    raise Exception("Download failed, no more retries!")


def download_to_memory(url, sha256=None, size=None):
    """Download a URL to memory, possibly with verification.

    Up to 5 attempts are made, paced by ``retrier``. Returns the complete
    body as ``bytes``.

    Raises:
        IntegrityError: immediately (no retry) if verification fails.
        Exception: if every attempt failed for another reason.
    """

    for _ in retrier(attempts=5, sleeptime=60):
        try:
            log("Downloading %s" % (url))

            # Assemble the body from scratch on every attempt: chunks received
            # before a mid-transfer failure must not leak into the result of
            # a later, successful attempt.
            return b"".join(stream_download(url, sha256=sha256, size=size))
        except IntegrityError:
            raise
        except Exception as e:
            log("Download failed: {}".format(e))
            continue

    raise Exception("Download failed, no more retries!")


def gpg_verify_path(path: pathlib.Path, public_key_data: bytes, signature_data: bytes):
    """Check a file against a GPG signature using a throwaway keyring.

    ``path`` is the file to verify. ``public_key_data`` holds the GPG public
    key to import and ``signature_data`` the signed GPG document that is fed
    to ``gpg --verify`` on stdin. Both ``gpg`` invocations use ``check=True``
    and therefore raise ``subprocess.CalledProcessError`` on failure.
    """
    log("Validating GPG signature of %s" % path)
    log("GPG key data:\n%s" % public_key_data.decode("ascii"))

    with tempfile.TemporaryDirectory() as homedir:
        # --batch because nobody is around to answer prompts.
        base_cmd = ["gpg", "--homedir", homedir, "--batch"]
        try:
            log("Importing GPG key...")
            subprocess.run(base_cmd + ["--import"], input=public_key_data, check=True)

            log("Verifying GPG signature...")
            verify_cmd = base_cmd + ["--verify", "-", str(path)]
            subprocess.run(verify_cmd, input=signature_data, check=True)

            log("GPG signature verified!")
        finally:
            # The agent shutting itself down can race with the removal of the
            # temporary directory and make the cleanup raise. Stop the agent
            # explicitly before the directory goes away.
            agent_env = dict(os.environ, GNUPGHOME=homedir)
            subprocess.run(["gpgconf", "--kill", "gpg-agent"], env=agent_env)


def open_tar_stream(path: pathlib.Path):
    """Open ``path`` as a binary stream of uncompressed tar data.

    The decompressor is picked from the last file suffix: ``.bz2``,
    ``.gz``/``.tgz``, ``.xz`` and ``.zst`` are decompressed on the fly, and
    ``.tar`` is opened as-is. Raises ``ValueError`` for any other suffix.
    """
    suffix = path.suffix

    if suffix == ".tar":
        return path.open("rb")

    if suffix == ".zst":
        return ZstdDecompressor().stream_reader(path.open("rb"))

    # Standard-library codecs all share the same open(name, mode) interface.
    stdlib_openers = {
        ".bz2": bz2.open,
        ".gz": gzip.open,
        ".tgz": gzip.open,
        ".xz": lzma.open,
    }
    opener = stdlib_openers.get(suffix)
    if opener is None:
        raise ValueError("unknown archive format for tar file: %s" % path)
    return opener(str(path), "rb")


def archive_type(path: pathlib.Path):
    """Classify ``path`` as ``"tar"``, ``"zip"`` or ``None`` from its name.

    A name counts as tar when its second-to-last suffix is ``.tar`` (e.g.
    ``x.tar.gz``) or its last suffix is ``.tgz``. Note that a bare ``x.tar``
    has no second-to-last suffix and is therefore reported as ``None``.
    """
    suffixes = path.suffixes
    compressed_tar = len(suffixes) >= 2 and suffixes[-2] == ".tar"
    tgz = bool(suffixes) and suffixes[-1] == ".tgz"

    if compressed_tar or tgz:
        return "tar"
    if path.suffix == ".zip":
        return "zip"
    return None


def extract_archive(path, dest_dir, typ):
    """Extract an archive to a destination directory.

    Args:
        path (pathlib.Path): archive to extract.
        dest_dir (pathlib.Path): directory to extract into.
        typ (str): ``"tar"`` or ``"zip"`` (see ``archive_type``).

    Raises:
        ValueError: if ``typ`` is neither ``"tar"`` nor ``"zip"``.
        Exception: if the external ``tar``/``unzip`` program exits non-zero.
    """

    # Resolve paths to absolute variants.
    path = path.resolve()
    dest_dir = dest_dir.resolve()

    log("Extracting %s to %s" % (path, dest_dir))
    t0 = time.time()

    # We pipe input to the decompressor program so that we can apply
    # custom decompressors that the program may not know about.
    if typ == "tar":
        ifh = open_tar_stream(path)
        # On Windows, the tar program doesn't support things like symbolic
        # links, while Windows actually support them. The tarfile module in
        # python does. So use that. But since it's significantly slower than
        # the tar program on Linux, only use tarfile on Windows (tarfile is
        # also not much slower on Windows, presumably because of the
        # notoriously bad I/O).
        if sys.platform == "win32":
            # Close both the decompressed stream and the TarFile once done;
            # nothing else will, since no subprocess is spawned below.
            with ifh, tarfile.open(fileobj=ifh, mode="r|") as tar:
                tar.extractall(str(dest_dir))
            args = []
            pipe_stdin = False
        else:
            args = ["tar", "xf", "-"]
            pipe_stdin = True
    elif typ == "zip":
        # unzip from stdin has wonky behavior. We don't use a pipe for it.
        ifh = open(os.devnull, "rb")
        args = ["unzip", "-o", str(path)]
        pipe_stdin = False
    else:
        raise ValueError("unknown archive format: %s" % path)

    if args:
        # Leaving the Popen context closes the child's stdin and waits for
        # it, so returncode is populated afterwards.
        with ifh, subprocess.Popen(
            args, cwd=str(dest_dir), bufsize=0, stdin=subprocess.PIPE
        ) as p:
            if pipe_stdin:
                while True:
                    chunk = ifh.read(131072)
                    if not chunk:
                        break

                    p.stdin.write(chunk)

        if p.returncode:
            raise Exception("%r exited %d" % (args, p.returncode))

    log("%s extracted in %.3fs" % (path, time.time() - t0))


def repack_archive(
    orig: pathlib.Path, dest: pathlib.Path, strip_components=0, prefix=""
):
    """Repack the archive ``orig`` as the ``.tar.zst`` archive ``dest``.

    Directory entries are dropped whenever member names are rewritten or
    the source is a zip file.

    Args:
        orig: source archive; a zip file or a (possibly compressed) tar file.
        dest: destination path; must end in ``.tar.zst``.
        strip_components (int): number of leading path components to remove
            from every member name.
        prefix (str): string prepended verbatim to every member name (after
            stripping).

    Raises:
        Exception: if either archive type is unsupported, if ``dest`` is not
            a ``.tar.zst``, if stripping would leave a member without a
            name, or if a zip member has an unsupported file type.
    """
    assert orig != dest
    log("Repacking as %s" % dest)
    orig_typ = archive_type(orig)
    typ = archive_type(dest)
    if not orig_typ:
        raise Exception("Archive type not supported for %s" % orig.name)
    if not typ:
        raise Exception("Archive type not supported for %s" % dest.name)

    if dest.suffixes[-2:] != [".tar", ".zst"]:
        raise Exception("Only producing .tar.zst archives is supported.")

    if strip_components or prefix:

        def rename(name):
            """Apply ``strip_components`` and ``prefix`` to a member name."""
            if strip_components:
                stripped = "/".join(name.split("/")[strip_components:])
                if not stripped:
                    raise Exception(
                        "Stripping %d components would remove files" % strip_components
                    )
                name = stripped
            return prefix + name

    else:
        rename = None

    with rename_after_close(dest, "wb") as fh:
        ctx = ZstdCompressor()
        if orig_typ == "zip":
            assert typ == "tar"
            # Convert the zip stream to a tar on the fly.
            with zipfile.ZipFile(orig) as zf, ctx.stream_writer(
                fh
            ) as compressor, tarfile.open(fileobj=compressor, mode="w:") as tar:
                for zipinfo in zf.infolist():
                    if zipinfo.is_dir():
                        continue
                    filename = zipinfo.filename
                    tarinfo = tarfile.TarInfo()
                    tarinfo.name = rename(filename) if rename else filename
                    tarinfo.size = zipinfo.file_size
                    # Zip files don't have any knowledge of the timezone
                    # they were created in. Which is not really convenient to
                    # reliably convert to a timestamp. But we don't really
                    # care about accuracy, but rather about reproducibility,
                    # so we pick UTC.
                    mtime = datetime.datetime(
                        *zipinfo.date_time, tzinfo=datetime.timezone.utc
                    )
                    tarinfo.mtime = mtime.timestamp()
                    # 0 is MS-DOS, 3 is UNIX. Only in the latter case do we
                    # get anything useful for the tar file mode.
                    if zipinfo.create_system == 3:
                        mode = zipinfo.external_attr >> 16
                    else:
                        mode = 0o0644
                    tarinfo.mode = stat.S_IMODE(mode)
                    if stat.S_ISLNK(mode):
                        # In a zip, the link target is stored as the member's
                        # content. In a tar it lives in the header: a symlink
                        # entry carries no data blocks and has size 0.
                        tarinfo.type = tarfile.SYMTYPE
                        tarinfo.linkname = zf.read(filename).decode()
                        tarinfo.size = 0
                        tar.addfile(tarinfo)
                    elif stat.S_ISREG(mode) or stat.S_IFMT(mode) == 0:
                        with zf.open(filename) as member:
                            tar.addfile(tarinfo, member)
                    else:
                        raise Exception("Unsupported file mode %o" % stat.S_IFMT(mode))

        elif orig_typ == "tar":
            if typ == "zip":
                raise Exception("Repacking a tar to zip is not supported")
            assert typ == "tar"

            with open_tar_stream(orig) as ifh:
                if rename:
                    # To apply the renaming, we need to open the tar stream
                    # and tweak it.
                    with tarfile.open(fileobj=ifh, mode="r|") as origtar:
                        with ctx.stream_writer(fh) as compressor, tarfile.open(
                            fileobj=compressor,
                            mode="w:",
                            format=origtar.format,
                        ) as tar:
                            for tarinfo in origtar:
                                if tarinfo.isdir():
                                    continue
                                tarinfo.name = rename(tarinfo.name)
                                if "path" in tarinfo.pax_headers:
                                    tarinfo.pax_headers["path"] = rename(
                                        tarinfo.pax_headers["path"]
                                    )
                                if tarinfo.isfile():
                                    tar.addfile(tarinfo, origtar.extractfile(tarinfo))
                                else:
                                    tar.addfile(tarinfo)
                else:
                    # We only change compression here. The tar stream is
                    # unchanged.
                    ctx.copy_stream(ifh, fh)


def fetch_and_extract(url, dest_dir, extract=True, sha256=None, size=None):
    """Download ``url`` into ``dest_dir`` and unpack it when it is an archive.

    The file is stored under the last path component of the URL. When
    ``extract`` is true and the name is recognized by ``archive_type``, the
    archive is extracted into ``dest_dir`` and then deleted; otherwise the
    downloaded file is left in place.
    """
    url_path = urllib.parse.urlparse(url).path
    dest_path = dest_dir / url_path.rsplit("/", 1)[-1]

    download_to_path(url, dest_path, sha256=sha256, size=size)

    typ = archive_type(dest_path) if extract else None
    if not typ:
        return

    extract_archive(dest_path, dest_dir, typ)
    log("Removing %s" % dest_path)
    dest_path.unlink()


def fetch_urls(downloads):
    """Run ``fetch_and_extract`` for every entry of ``downloads`` in parallel.

    Each entry is a sequence of positional arguments for
    ``fetch_and_extract``. All jobs are submitted to a thread pool of
    ``CONCURRENCY`` workers; results are then collected in submission order,
    so the first failing job's exception is re-raised here.
    """
    with concurrent.futures.ThreadPoolExecutor(CONCURRENCY) as pool:
        pending = [pool.submit(fetch_and_extract, *job) for job in downloads]

        for future in pending:
            future.result()


def _git_checkout_github_archive(dest_path: pathlib.Path, repo: str,
                                 commit: str, prefix: str):
    """Build ``dest_path`` from GitHub's generated archive instead of cloning.

    Downloads ``<repo>/archive/<commit>.tar.gz`` into a temporary directory
    and repacks it, replacing the archive's top-level directory with
    ``prefix``.
    """
    github_url = '{repo}/archive/{commit}.tar.gz'.format(
        repo=repo.rstrip('/'), commit=commit
    )

    with tempfile.TemporaryDirectory() as scratch:
        tarball = pathlib.Path(scratch) / 'archive.tar.gz'
        download_to_path(github_url, tarball)
        repack_archive(tarball, dest_path, strip_components=1, prefix=prefix + '/')


def _github_submodule_required(repo: str, commit: str):
    'Use github API to check if submodules are used'
    url = '{repo}/blob/{commit}/.gitmodules'.format(**locals())
    try:
        status_code = urllib.request.urlopen(url).getcode()
        return status_code == 200
    except:
        return False


def git_checkout_archive(
    dest_path: pathlib.Path,
    repo: str,
    commit: str,
    prefix=None,
    ssh_key=None,
    include_dot_git=False,
):
    """Produce an archive of the files comprising a Git checkout.

    Args:
        dest_path: archive to create; must end in ``.tar.zst``. Its parent
            directory is created if needed.
        repo: URL of the Git repository.
        commit: commit-ish to check out.
        prefix: top-level directory name inside the archive; defaults to the
            last path component of ``repo``.
        ssh_key: name of a Taskcluster secret holding an SSH private key to
            use for the clone.
        include_dot_git: keep the ``.git`` directory (as a depth-1 clone) in
            the archive.

    For ``https://github.com/`` repositories without a ``.gitmodules`` file,
    and when ``.git`` is not requested, the archive is built from GitHub's
    archive download instead of a clone.

    Raises:
        Exception: if ``dest_path`` is not a ``.tar.zst`` or if ``tar`` fails.
        subprocess.CalledProcessError: if a ``git`` command fails.
    """
    dest_path.parent.mkdir(parents=True, exist_ok=True)

    if not prefix:
        prefix = repo.rstrip("/").rsplit("/", 1)[-1]

    if dest_path.suffixes[-2:] != [".tar", ".zst"]:
        raise Exception("Only producing .tar.zst archives is supported.")

    if repo.startswith('https://github.com/'):
        if not include_dot_git and not _github_submodule_required(repo, commit):
            log("Using github archive service to speedup archive creation")
            # Always log sha1 info, either from commit or resolved from repo.
            if re.match(r"^[a-fA-F0-9]{40}$", commit):
                revision = commit
            else:
                ref_output = subprocess.check_output(["git", "ls-remote", repo,
                                                      'refs/heads/' + commit])
                # ls-remote prints nothing when ``commit`` is not a branch
                # name (e.g. a tag). This lookup only feeds the log line
                # below, so fall back to the commit-ish instead of failing.
                fields = ref_output.decode().split(maxsplit=1)
                revision = fields[0] if fields else commit
            log("Fetching revision {}".format(revision))
            return _git_checkout_github_archive(dest_path, repo, commit, prefix)

    with tempfile.TemporaryDirectory() as td:
        temp_dir = pathlib.Path(td)

        git_dir = temp_dir / prefix

        # This could be faster with a shallow clone. However, Git requires a ref
        # to initiate a clone. Since the commit-ish may not refer to a ref, we
        # simply perform a full clone followed by a checkout.
        print("cloning %s to %s" % (repo, git_dir))

        env = os.environ.copy()
        keypath = ""
        if ssh_key:
            taskcluster_secret_url = api(
                os.environ.get("TASKCLUSTER_PROXY_URL"),
                "secrets",
                "v1",
                "secret/{keypath}".format(keypath=ssh_key),
            )
            taskcluster_secret = b"".join(stream_download(taskcluster_secret_url))
            taskcluster_secret = json.loads(taskcluster_secret)
            sshkey = taskcluster_secret["secret"]["ssh_privkey"]

            keypath = temp_dir.joinpath("ssh-key")
            keypath.write_text(sshkey)
            keypath.chmod(0o600)

            # Add to the inherited environment rather than replacing it, so
            # git and ssh still see PATH, HOME, proxy settings, etc.
            env["GIT_SSH_COMMAND"] = (
                "ssh -o 'StrictHostKeyChecking no' -i {keypath}".format(
                    keypath=keypath
                )
            )

        subprocess.run(["git", "clone", "-n", repo, str(git_dir)], check=True, env=env)

        # Always use a detached head so that git prints out what it checked out.
        subprocess.run(
            ["git", "checkout", "--detach", commit], cwd=str(git_dir), check=True
        )

        # When including the .git, we want --depth 1, but a direct clone would not
        # necessarily be able to give us the right commit.
        if include_dot_git:
            initial_clone = git_dir.with_name(git_dir.name + ".orig")
            git_dir.rename(initial_clone)
            subprocess.run(
                [
                    "git",
                    "clone",
                    "file://" + str(initial_clone),
                    str(git_dir),
                    "--depth",
                    "1",
                ],
                check=True,
            )
            subprocess.run(
                ["git", "remote", "set-url", "origin", repo],
                cwd=str(git_dir),
                check=True,
            )

        # --depth 1 can induce more work on the server side, so only use it for
        # submodule initialization when we want to keep the .git directory.
        depth = ["--depth", "1"] if include_dot_git else []
        subprocess.run(
            ["git", "submodule", "update", "--init"] + depth,
            cwd=str(git_dir),
            check=True,
        )

        if keypath:
            os.remove(keypath)

        print("creating archive %s of commit %s" % (dest_path, commit))
        exclude_dot_git = [] if include_dot_git else ["--exclude=.git"]
        tar_args = ["tar", "cf", "-"] + exclude_dot_git + ["-C", str(temp_dir), prefix]

        with rename_after_close(dest_path, "wb") as out:
            ctx = ZstdCompressor()
            # Leaving the Popen context closes the pipe and waits for tar.
            with subprocess.Popen(tar_args, stdout=subprocess.PIPE) as proc:
                ctx.copy_stream(proc.stdout, out)

            # Fail while still inside rename_after_close() so that a
            # truncated archive is discarded instead of being renamed into
            # place.
            if proc.returncode:
                raise Exception("%r exited %d" % (tar_args, proc.returncode))


def command_git_checkout_archive(args):
    """Entry point of the ``git-checkout-archive`` sub-command.

    Builds the archive at ``args.dest``. If anything goes wrong, a
    destination file that may have been created is removed before the
    exception is re-raised.
    """
    dest = pathlib.Path(args.dest)
    options = {
        "prefix": args.path_prefix,
        "ssh_key": args.ssh_key_secret,
        "include_dot_git": args.include_dot_git,
    }

    try:
        git_checkout_archive(dest, args.repo, args.commit, **options)
    except Exception:
        with contextlib.suppress(FileNotFoundError):
            dest.unlink()
        raise


def command_static_url(args):
    """Entry point of the ``static-url`` sub-command.

    Downloads ``args.url`` with size and SHA-256 verification, optionally
    checks a GPG signature, and repacks the result as ``args.dest`` when the
    downloaded name differs from the destination or when member names must
    be rewritten. Returns 1 if only one of the two GPG options is given.
    """
    gpg_sig_url = args.gpg_sig_url
    gpg_env_key = args.gpg_key_env
    verify_gpg = bool(gpg_sig_url)

    if verify_gpg != bool(gpg_env_key):
        print("--gpg-sig-url and --gpg-key-env must both be defined")
        return 1

    if verify_gpg:
        gpg_signature = b"".join(stream_download(gpg_sig_url))
        gpg_key = os.environb[gpg_env_key.encode("ascii")]

    dest = pathlib.Path(args.dest)
    dest.parent.mkdir(parents=True, exist_ok=True)

    # Download straight to the destination when the URL's file name already
    # ends with the destination's suffixes; otherwise download under the
    # URL's own name next to it and repack afterwards.
    basename = urllib.parse.urlparse(args.url).path.rsplit("/", 1)[-1]
    same_format = basename.endswith("".join(dest.suffixes))
    dl_dest = dest if same_format else dest.parent / basename

    try:
        download_to_path(args.url, dl_dest, sha256=args.sha256, size=args.size)

        if verify_gpg:
            gpg_verify_path(dl_dest, gpg_key, gpg_signature)

        needs_repack = dl_dest != dest or args.strip_components or args.add_prefix
        if needs_repack:
            repack_archive(dl_dest, dest, args.strip_components, args.add_prefix)
    except Exception:
        with contextlib.suppress(FileNotFoundError):
            dl_dest.unlink()
        raise

    if dl_dest != dest:
        log("Removing %s" % dl_dest)
        dl_dest.unlink()


def api(root_url, service, version, path):
    """Return the Taskcluster API URL for ``path`` of a service version.

    taskcluster-lib-urls is not available when this script runs, so its
    behavior is simulated here: the result is
    ``<root_url>/api/<service>/<version>/<path>``.
    """
    template = "{}/api/{}/{}/{}"
    return template.format(root_url, service, version, path)


def get_hash(fetch, root_url):
    """Look up the SHA-256 of ``fetch["artifact"]`` in its task's chain of trust.

    Downloads ``public/chain-of-trust.json`` of task ``fetch["task"]`` from
    the queue service at ``root_url`` and returns the digest recorded there
    for the artifact.
    """
    cot_path = "task/{task}/artifacts/{artifact}".format(
        task=fetch["task"], artifact="public/chain-of-trust.json"
    )
    cot_url = api(root_url, "queue", "v1", cot_path)
    chain_of_trust = json.loads(download_to_memory(cot_url))
    artifact_entry = chain_of_trust["artifacts"][fetch["artifact"]]
    return artifact_entry["sha256"]


def command_task_artifacts(args):
    """Entry point of the ``task-artifacts`` sub-command.

    Reads the list of fetches from the ``MOZ_FETCHES`` environment variable
    (JSON), downloads (and optionally extracts) each artifact below
    ``args.dest``, then prints the elapsed time as a PERFHERDER_DATA line on
    stderr.
    """
    started = time.monotonic()
    downloads = []
    for fetch in json.loads(os.environ["MOZ_FETCHES"]):
        extdir = pathlib.Path(args.dest)
        if "dest" in fetch:
            # Note: normpath doesn't like pathlib.Path in python 3.5
            extdir = pathlib.Path(os.path.normpath(str(extdir / fetch["dest"])))
        extdir.mkdir(parents=True, exist_ok=True)

        root_url = os.environ["TASKCLUSTER_ROOT_URL"]
        sha256 = get_hash(fetch, root_url) if fetch.get("verify-hash") else None

        # Public artifacts are fetched from the root URL; anything else goes
        # through the Taskcluster proxy.
        artifact = fetch["artifact"]
        if artifact.startswith("public/"):
            base_url = root_url
        else:
            base_url = os.environ["TASKCLUSTER_PROXY_URL"]
        artifact_path = "task/{task}/artifacts/{artifact}".format(
            task=fetch["task"], artifact=artifact
        )
        url = api(base_url, "queue", "v1", artifact_path)

        downloads.append((url, extdir, fetch["extract"], sha256))

    fetch_urls(downloads)
    elapsed = time.monotonic() - started

    suite = {
        "name": "fetch_content",
        "value": elapsed,
        "lowerIsBetter": True,
        "shouldAlert": False,
        "subtests": [],
    }
    perfherder_data = {
        "framework": {"name": "build_metrics"},
        "suites": [suite],
    }
    print("PERFHERDER_DATA: {}".format(json.dumps(perfherder_data)), file=sys.stderr)


def main():
    """Parse the command line and dispatch to the selected sub-command.

    Returns whatever the sub-command handler returns, which is used as the
    process exit status by the ``__main__`` guard. Exits with a usage error
    when no sub-command or no destination is given.
    """
    parser = argparse.ArgumentParser()
    # Sub-commands are mandatory: without one there is neither ``args.func``
    # nor ``args.dest`` and the code below would fail with AttributeError.
    # ``required`` is set as an attribute (rather than a keyword) so this
    # also works on Python versions older than 3.7; ``dest`` gives argparse
    # a name to use in the resulting error message.
    subparsers = parser.add_subparsers(title="sub commands", dest="command")
    subparsers.required = True

    git_checkout = subparsers.add_parser(
        "git-checkout-archive",
        help="Obtain an archive of files from a Git repository checkout",
    )
    git_checkout.set_defaults(func=command_git_checkout_archive)
    git_checkout.add_argument(
        "--path-prefix", help="Prefix for paths in produced archive"
    )
    git_checkout.add_argument("repo", help="URL to Git repository to be cloned")
    git_checkout.add_argument("commit", help="Git commit to check out")
    git_checkout.add_argument("dest", help="Destination path of archive")
    git_checkout.add_argument(
        "--ssh-key-secret", help="The scope path of the ssh key to used for checkout"
    )
    git_checkout.add_argument(
        "--include-dot-git", action="store_true", help="Include the .git directory"
    )

    url = subparsers.add_parser("static-url", help="Download a static URL")
    url.set_defaults(func=command_static_url)
    url.add_argument("--sha256", required=True, help="SHA-256 of downloaded content")
    url.add_argument(
        "--size", required=True, type=int, help="Size of downloaded content, in bytes"
    )
    url.add_argument(
        "--gpg-sig-url",
        help="URL containing signed GPG document validating URL to fetch",
    )
    url.add_argument(
        "--gpg-key-env", help="Environment variable containing GPG key to validate"
    )
    url.add_argument(
        "--strip-components",
        type=int,
        default=0,
        help="Number of leading components to strip from file "
        "names in the downloaded archive",
    )
    url.add_argument(
        "--add-prefix",
        default="",
        help="Prefix to add to file names in the downloaded archive",
    )
    url.add_argument("url", help="URL to fetch")
    url.add_argument("dest", help="Destination path")

    artifacts = subparsers.add_parser("task-artifacts", help="Fetch task artifacts")
    artifacts.set_defaults(func=command_task_artifacts)
    artifacts.add_argument(
        "-d",
        "--dest",
        default=os.environ.get("MOZ_FETCHES_DIR"),
        help="Destination directory which will contain all "
        "artifacts (defaults to $MOZ_FETCHES_DIR)",
    )

    args = parser.parse_args()

    # Only reachable with an empty value for task-artifacts, whose --dest
    # is optional and falls back to $MOZ_FETCHES_DIR.
    if not args.dest:
        parser.error(
            "no destination directory specified, either pass in --dest "
            "or set $MOZ_FETCHES_DIR"
        )

    return args.func(args)


# Run main() only when executed as a script; its return value becomes the
# process exit status.
if __name__ == "__main__":
    sys.exit(main())