253 lines
8.6 KiB
Python
253 lines
8.6 KiB
Python
#!/usr/bin/env python3
|
|
#
|
|
# This Source Code Form is subject to the terms of the Mozilla Public
|
|
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
|
|
import argparse
|
|
import hashlib
|
|
import os
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
import urllib.request
|
|
from pathlib import Path
|
|
|
|
import yaml
|
|
|
|
# Absolute directory containing this script, with symlinks resolved.
HERE = Path(__file__).resolve().parents[0]

# Taskcluster fetch definitions this script reads (the 'ort.jsep.wasm' entry
# and the git-based model entries). The five '..' hops tie this path to the
# script's position in the source tree.
FETCH_FILE = HERE.joinpath(
    "../../../../../taskcluster/kinds/fetch/onnxruntime-web-fetch.yml"
).resolve()
|
|
|
|
|
|
def is_git_lfs_installed():
    """Return True when the ``git lfs`` extension can be invoked.

    Runs ``git lfs version`` (stderr discarded) and looks for the ``git-lfs``
    marker, case-insensitively, in its output. A missing ``git`` executable
    or a non-zero exit status is reported as False rather than raised.
    """
    command = ["git", "lfs", "version"]
    try:
        version_text = subprocess.check_output(
            command, text=True, stderr=subprocess.DEVNULL
        )
    except (FileNotFoundError, subprocess.CalledProcessError):
        return False
    return "git-lfs" in version_text.lower()
|
|
|
|
|
|
def compute_sha256(file_path):
    """Return the lowercase hex SHA-256 digest of the file at ``file_path``.

    The file is opened in binary mode and fed to the hash 4096 bytes at a
    time, so large files are never loaded into memory whole.

    Args:
        file_path: a ``pathlib.Path`` (anything exposing ``.open("rb")``).
    """
    digest = hashlib.sha256()
    with file_path.open("rb") as stream:
        while block := stream.read(4096):
            digest.update(block)
    return digest.hexdigest()
|
|
|
|
|
|
def download_wasm(fetches, fetches_dir, timeout=60):
    """
    Download and verify ort.jsep.wasm if needed,
    using the 'ort.jsep.wasm' entry in the YAML file.

    The file is stored in ``fetches_dir`` under the last path component of
    the fetch URL (query string and fragment ignored). An existing file whose
    SHA-256 matches is kept. Otherwise the file is (re-)downloaded into a
    temporary ``<name>.part`` file, verified, and only then moved into place.
    An interrupted or corrupt download therefore never leaves a truncated
    file under the final name.

    Args:
        fetches: Parsed fetch YAML. Must provide
            ``fetches["ort.jsep.wasm"]["fetch"]["url"]`` and ``["sha256"]``.
        fetches_dir: ``Path`` of the directory to download into.
        timeout: Socket timeout in seconds for the HTTP request, so a stalled
            connection fails instead of hanging forever.

    Raises:
        KeyError: If the YAML lacks the required entry or fields.
        ValueError: If no file name can be derived from the URL, or the
            downloaded file's checksum does not match.
    """
    import urllib.parse

    wasm_fetch = fetches["ort.jsep.wasm"]["fetch"]
    url = wasm_fetch["url"]
    # compute_sha256 returns lowercase hex; normalize the expected value too.
    expected_sha256 = str(wasm_fetch["sha256"]).lower()

    # Use only the URL path so a query string or fragment cannot leak into
    # the local file name.
    filename = urllib.parse.urlsplit(url).path.split("/")[-1]
    if not filename:
        raise ValueError(f"Cannot derive a file name from URL {url!r}")
    output_file = fetches_dir / filename

    # If the file exists and its checksum matches, skip re-download
    if output_file.exists():
        print(f"Found existing file {output_file}, verifying checksum...")
        if compute_sha256(output_file) == expected_sha256:
            print("Existing file's checksum matches. Skipping download.")
            return
        print("Checksum mismatch on existing file. Removing and re-downloading...")
        output_file.unlink()

    partial_file = output_file.with_name(output_file.name + ".part")
    print(f"Downloading {url} to {output_file}...")
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            with partial_file.open("wb") as out_file:
                shutil.copyfileobj(response, out_file)

        print(f"Verifying SHA-256 of {output_file}...")
        downloaded_sha256 = compute_sha256(partial_file)
        if downloaded_sha256 != expected_sha256:
            raise ValueError(
                f"Checksum mismatch for {filename}! "
                f"Expected: {expected_sha256}, got: {downloaded_sha256}"
            )
        # Atomic rename within the same directory: the final name only ever
        # refers to a fully downloaded, verified file.
        partial_file.replace(output_file)
    finally:
        # No-op after a successful replace(). Otherwise drop the partial file.
        partial_file.unlink(missing_ok=True)

    print(f"File {filename} downloaded and verified successfully!")
|
|
|
|
|
|
def list_models(fetches):
    """
    Print every YAML entry whose ``fetch.type`` is 'git', together with the
    'path-prefix' it declares (or a placeholder when none is declared),
    followed by a usage hint for ``--model``.
    """
    print("Available git-based models from the YAML:\n")
    entries = ((name, entry.get("fetch")) for name, entry in fetches.items())
    for name, fetch in entries:
        # Skip entries without a fetch section or with a non-git fetch type.
        if not fetch or fetch.get("type") != "git":
            continue
        prefix = fetch.get("path-prefix", "[no path-prefix specified]")
        print(f"- {name} -> path-prefix: {prefix}")
    print("\n(Use `--model <key>` to clone one of these repositories.)")
|
|
|
|
|
|
def _git_output(args, cwd):
    """Return the stripped stdout of ``git <args>`` run in ``cwd``, or None if git exits non-zero."""
    try:
        return subprocess.check_output(["git", *args], cwd=cwd, text=True).strip()
    except subprocess.CalledProcessError:
        return None


def _head_matches_revision(head, revision):
    """Return True if the commit SHA ``head`` satisfies the requested ``revision``.

    An exact match always succeeds. A ``revision`` that looks like an
    abbreviated commit hash (at least 7 hex characters) also matches any
    ``head`` it is a prefix of. Branch and tag names will generally not
    match, so such checkouts are re-cloned on every run (as before).
    """
    if head is None:
        return False
    if head == revision:
        return True
    candidate = str(revision).lower()
    return (
        len(candidate) >= 7
        and set(candidate) <= set("0123456789abcdef")
        and head.lower().startswith(candidate)
    )


def _stale_checkout_reason(repo_dir, repo_url, revision):
    """Explain why the existing ``repo_dir`` must be replaced, or return None if it can be reused."""
    # 1. Must be a git checkout.
    if not (repo_dir / ".git").is_dir():
        return f"Directory '{repo_dir}' exists but is not a git repo. Removing it."

    # 2. Its 'origin' remote must be the requested repository.
    existing_url = _git_output(["remote", "get-url", "origin"], repo_dir)
    if existing_url != repo_url:
        return (
            f"Repository at '{repo_dir}' has remote '{existing_url}' "
            f"instead of '{repo_url}'. Removing it."
        )

    # 3. HEAD must be at the requested revision.
    current_revision = _git_output(["rev-parse", "HEAD"], repo_dir)
    if not _head_matches_revision(current_revision, revision):
        return (
            f"Repo at '{repo_dir}' has HEAD {current_revision}, "
            f"but we need '{revision}'. Removing it."
        )
    return None


def clone_model(key, data, fetches_dir):
    """
    Clone (or re-clone) a model if needed.

    The directory is determined by 'path-prefix' from the YAML,
    relative to --fetches-dir. Example:

        path-prefix: "onnx-models/Xenova/all-MiniLM-L6-v2/main/"

    We'll end up cloning to <fetches-dir>/onnx-models/Xenova/all-MiniLM-L6-v2/main

    An existing directory is reused only if it is a git checkout whose
    'origin' remote equals the YAML 'repo' and whose HEAD matches the YAML
    'revision' (default "main"). Otherwise it is deleted and cloned again.

    Args:
        key: YAML key of the model (not used by this function).
        data: The YAML entry. ``data["fetch"]`` must provide 'repo' and
            'path-prefix', and may provide 'revision'.
        fetches_dir: ``Path`` under which 'path-prefix' is resolved.

    Raises:
        KeyError: If 'repo' or 'path-prefix' is missing.
        OSError: If a stale directory cannot be removed.
        subprocess.CalledProcessError: If ``git clone`` or ``git checkout`` fails.
    """
    fetch_data = data["fetch"]
    repo_url = fetch_data["repo"]
    path_prefix = fetch_data["path-prefix"]
    revision = fetch_data.get("revision", "main")

    # Compute the final directory from --fetches-dir + path-prefix
    repo_dir = fetches_dir / path_prefix
    repo_dir.parent.mkdir(parents=True, exist_ok=True)

    if repo_dir.exists():
        reason = _stale_checkout_reason(repo_dir, repo_url, revision)
        if reason is None:
            print(f"{repo_dir} already exists and is up to date. Skipping clone.")
            return
        print(reason)
        # Removal errors are deliberately NOT ignored. A directory that
        # silently survived deletion would otherwise be mistaken for an
        # up-to-date checkout, or make `git clone` fail confusingly.
        if repo_dir.is_dir() and not repo_dir.is_symlink():
            shutil.rmtree(repo_dir)
        else:
            repo_dir.unlink()

    print(f"Cloning {repo_url} into '{repo_dir}'...")
    subprocess.run(["git", "clone", repo_url, str(repo_dir)], check=True)
    # Then checkout the desired revision (branch, commit, or tag)
    subprocess.run(["git", "checkout", revision], cwd=repo_dir, check=True)
    print(f"Checked out revision '{revision}' in '{repo_dir}'.")
|
|
|
|
|
|
def clone_models(keys, fetches, fetches_dir):
    """
    Clone each model specified by YAML key, if fetch.type == 'git'.
    Uses the path-prefix from the YAML to determine the final directory.

    Every key is validated before any git command runs. A typo in one key
    therefore fails immediately instead of after other models were already
    cloned. Duplicate keys are processed once, in first-seen order.

    Args:
        keys: YAML keys of the models to clone. Empty or None is a no-op.
        fetches: Parsed fetch YAML mapping.
        fetches_dir: ``Path`` under which each model's path-prefix is resolved.

    Raises:
        ValueError: If a key is missing from the YAML or is not a git fetch
            (including entries whose value or 'fetch' section is null).
        subprocess.CalledProcessError: If ``git lfs install`` or a clone fails.
    """
    if not keys:
        return

    # dict.fromkeys de-duplicates while preserving the caller's order.
    unique_keys = list(dict.fromkeys(keys))

    for key in unique_keys:
        if key not in fetches:
            raise ValueError(f"Model '{key}' not found in YAML.")
        # `or {}` covers YAML nulls (`key:` or `fetch:` with no value).
        fetch = (fetches[key] or {}).get("fetch") or {}
        if fetch.get("type") != "git":
            raise ValueError(f"Model '{key}' is not a git fetch type.")

    # Initialize git lfs once (we have at least one validated model)
    subprocess.run(["git", "lfs", "install"], check=True)

    for key in unique_keys:
        clone_model(key, fetches[key], fetches_dir)
|
|
|
|
|
|
def main():
    """Command-line entry point.

    Parses arguments and loads FETCH_FILE. With ``--list-models`` it prints
    the git-based models and exits. Otherwise it downloads/verifies
    ``ort.jsep.wasm`` into the fetches directory and clones every model
    requested with ``--model``. git-lfs is only required (and only checked)
    when models are to be cloned.
    """
    parser = argparse.ArgumentParser(
        description="Download ort.jsep.wasm and optionally clone specified models."
    )

    # Treat an empty MOZ_ML_LOCAL_DIR as unset instead of silently meaning
    # the current working directory.
    default_dir = os.getenv("MOZ_ML_LOCAL_DIR") or None

    parser.add_argument(
        "--fetches-dir",
        help=(
            "Directory to store the downloaded files (and cloned repos). "
            "Uses MOZ_ML_LOCAL_DIR if present."
        ),
        default=default_dir,
    )
    parser.add_argument(
        "--list-models",
        action="store_true",
        help="List all available git-based models (keys in the YAML) and exit.",
    )
    parser.add_argument(
        "--model",
        action="append",
        help="YAML key of a model to clone (can be specified multiple times).",
    )
    args = parser.parse_args()

    # Load YAML
    with FETCH_FILE.open("r", encoding="utf-8") as f:
        fetches = yaml.safe_load(f)

    # If listing models, do so and exit
    if args.list_models:
        list_models(fetches)
        return

    if args.fetches_dir is None:
        # parser.error prints usage plus the message and exits with status 2.
        parser.error(
            "Missing --fetches-dir argument or MOZ_ML_LOCAL_DIR env var. "
            "Please specify a directory to store the downloaded files"
        )

    # Only cloning needs git-lfs. Check before downloading anything so a
    # missing install is reported immediately.
    if args.model and not is_git_lfs_installed():
        print("git lfs is required to clone models (--model):", file=sys.stderr)
        print("\t$ sudo apt install git-lfs", file=sys.stderr)
        print("\t$ sudo yum install git-lfs", file=sys.stderr)
        print("\t$ brew install git-lfs", file=sys.stderr)
        print(file=sys.stderr)
        print(
            "\tor see https://github.com/git-lfs/git-lfs/blob/main/README.md",
            file=sys.stderr,
        )
        sys.exit(1)

    fetches_dir = Path(args.fetches_dir).resolve()
    fetches_dir.mkdir(parents=True, exist_ok=True)

    # Always download/verify ort.jsep.wasm
    download_wasm(fetches, fetches_dir)

    # Clone requested models
    if args.model:
        clone_models(args.model, fetches, fetches_dir)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI; main() returns None, which sys.exit maps to status 0.
    sys.exit(main())
|