#!/usr/bin/env python3
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
"""Download ort.jsep.wasm and optionally clone git-based models.

The artifacts and their checksums/revisions are read from the fetch YAML file
located at FETCH_FILE.
"""

import argparse
import hashlib
import os
import shutil
import subprocess
import sys
import urllib.request
from pathlib import Path

import yaml

HERE = Path(__file__).resolve().parent
FETCH_FILE = (
    HERE / "../../../../../taskcluster/kinds/fetch/onnxruntime-web-fetch.yml"
).resolve()


def is_git_lfs_installed():
    """Return True if `git lfs version` runs and reports git-lfs."""
    try:
        output = subprocess.check_output(
            ["git", "lfs", "version"], stderr=subprocess.DEVNULL, text=True
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        # git is missing, or git does not know the `lfs` subcommand.
        return False
    return "git-lfs" in output.lower()


def compute_sha256(file_path):
    """Compute SHA-256 of a file (binary read).

    The file is read in 4 KiB chunks so it never has to fit in memory.
    Returns the lowercase hexadecimal digest.
    """
    hasher = hashlib.sha256()
    with file_path.open("rb") as f:
        while chunk := f.read(4096):
            hasher.update(chunk)
    return hasher.hexdigest()


def download_wasm(fetches, fetches_dir):
    """
    Download and verify ort.jsep.wasm if needed, using the 'ort.jsep.wasm' entry
    in the YAML file.

    The file is stored in `fetches_dir` under the last path component of the
    URL. An existing file with a matching checksum is reused.

    Raises ValueError if the downloaded file does not match the expected
    SHA-256; the bad file is removed before raising.
    """
    wasm_fetch = fetches["ort.jsep.wasm"]["fetch"]
    url = wasm_fetch["url"]
    expected_sha256 = wasm_fetch["sha256"]

    filename = url.split("/")[-1]
    output_file = fetches_dir / filename

    # If the file exists and its checksum matches, skip re-download
    if output_file.exists():
        print(f"Found existing file {output_file}, verifying checksum...")
        if compute_sha256(output_file) == expected_sha256:
            print("Existing file's checksum matches. Skipping download.")
            return
        print("Checksum mismatch on existing file. Removing and re-downloading...")
        output_file.unlink()

    # Download the file
    print(f"Downloading {url} to {output_file}...")
    with urllib.request.urlopen(url) as response, open(output_file, "wb") as out_file:
        shutil.copyfileobj(response, out_file)

    # Verify SHA-256
    print(f"Verifying SHA-256 of {output_file}...")
    downloaded_sha256 = compute_sha256(output_file)
    if downloaded_sha256 != expected_sha256:
        output_file.unlink(missing_ok=True)
        raise ValueError(
            f"Checksum mismatch for {filename}! "
            f"Expected: {expected_sha256}, got: {downloaded_sha256}"
        )

    print(f"File {filename} downloaded and verified successfully!")


def list_models(fetches):
    """
    List all YAML keys where fetch.type == 'git', along with the path-prefix
    specified in the YAML.
    """
    print("Available git-based models from the YAML:\n")
    for key, data in fetches.items():
        fetch = data.get("fetch")
        if fetch and fetch.get("type") == "git":
            path_prefix = fetch.get("path-prefix", "[no path-prefix specified]")
            print(f"- {key} -> path-prefix: {path_prefix}")
    print("\n(Use `--model KEY` to clone one of these repositories.)")


def _stale_checkout_reason(repo_dir, repo_url, revision):
    """Return a message explaining why `repo_dir` must be re-cloned, or None.

    The existing path is considered usable only if it is a git checkout whose
    'origin' remote equals `repo_url` and whose HEAD commit equals `revision`.
    """
    # 1. Check if .git exists
    if not (repo_dir / ".git").is_dir():
        return f"Directory '{repo_dir}' exists but is not a git repo. Removing it."

    # 2. Check if remote origin URL matches
    try:
        existing_url = subprocess.check_output(
            ["git", "remote", "get-url", "origin"], cwd=repo_dir, text=True
        ).strip()
    except subprocess.CalledProcessError:
        existing_url = None
    if existing_url != repo_url:
        return (
            f"Repository at '{repo_dir}' has remote '{existing_url}' "
            f"instead of '{repo_url}'. Removing it."
        )

    # 3. Check if HEAD commit matches 'revision'
    try:
        current_revision = subprocess.check_output(
            ["git", "rev-parse", "HEAD"],
            cwd=repo_dir,
            text=True,
        ).strip()
    except subprocess.CalledProcessError:
        current_revision = None
    # If the revision is a branch name or tag, matching HEAD exactly
    # might not always be correct. We're keeping it simple:
    # if HEAD != revision, remove & reclone.
    if current_revision != revision:
        return (
            f"Repo at '{repo_dir}' has HEAD {current_revision}, "
            f"but we need '{revision}'. Removing it."
        )

    return None


def clone_model(key, data, fetches_dir):
    """
    Clone (or re-clone) a model if needed.

    The directory is determined by 'path-prefix' from the YAML,
    relative to --fetches-dir.

    Example:
      path-prefix: "onnx-models/Xenova/all-MiniLM-L6-v2/main/"
      We'll end up cloning to FETCHES_DIR/onnx-models/Xenova/all-MiniLM-L6-v2/main

    `key` is the YAML key of the model; it is accepted for the caller's
    convenience and not otherwise used here.

    Raises RuntimeError if a stale directory cannot be removed, and
    subprocess.CalledProcessError if `git clone` / `git checkout` fail.
    """
    fetch_data = data["fetch"]
    repo_url = fetch_data["repo"]
    path_prefix = fetch_data["path-prefix"]
    revision = fetch_data.get("revision", "main")

    # Compute the final directory from --fetches-dir + path-prefix
    repo_dir = fetches_dir / path_prefix

    # Ensure parent directories exist
    repo_dir.parent.mkdir(parents=True, exist_ok=True)

    # If the target directory exists, verify that it matches the correct repo & revision
    if repo_dir.exists():
        reason = _stale_checkout_reason(repo_dir, repo_url, revision)
        if reason is None:
            print(f"{repo_dir} already exists and is up to date. Skipping clone.")
            return

        print(reason)
        shutil.rmtree(repo_dir, ignore_errors=True)
        # rmtree errors are ignored above, so confirm the removal really
        # happened; otherwise a stale directory would be silently kept.
        if repo_dir.exists():
            raise RuntimeError(
                f"Could not remove '{repo_dir}'. Please delete it manually and re-run."
            )

    print(f"Cloning {repo_url} into '{repo_dir}'...")
    # Normal clone first
    subprocess.run(["git", "clone", repo_url, str(repo_dir)], check=True)

    # Then checkout the desired revision (branch, commit, or tag)
    subprocess.run(["git", "checkout", revision], cwd=repo_dir, check=True)
    print(f"Checked out revision '{revision}' in '{repo_dir}'.")


def clone_models(keys, fetches, fetches_dir):
    """
    Clone each model specified by YAML key, if fetch.type == 'git'.
    Uses the path-prefix from the YAML to determine the final directory.

    All keys are validated before anything is cloned.

    Raises ValueError if a key is missing from the YAML or is not a git fetch.
    """
    if not keys:
        return

    # Validate every key up front so a typo does not surface only after
    # earlier (potentially large) repositories have already been cloned.
    for key in keys:
        if key not in fetches:
            raise ValueError(f"Model '{key}' not found in YAML.")
        if fetches[key].get("fetch", {}).get("type") != "git":
            raise ValueError(f"Model '{key}' is not a git fetch type.")

    # Initialize git lfs once (if we have at least one model)
    subprocess.run(["git", "lfs", "install"], check=True)

    for key in keys:
        clone_model(key, fetches[key], fetches_dir)


def main():
    """Parse the command line, then download the wasm file and clone models."""
    parser = argparse.ArgumentParser(
        description="Download ort.jsep.wasm and optionally clone specified models."
    )
    default_dir = os.getenv("MOZ_ML_LOCAL_DIR", None)
    parser.add_argument(
        "--fetches-dir",
        help="Directory to store the downloaded files (and cloned repos). Uses MOZ_ML_LOCAL_DIR if present.",
        default=default_dir,
    )
    parser.add_argument(
        "--list-models",
        action="store_true",
        help="List all available git-based models (keys in the YAML) and exit.",
    )
    parser.add_argument(
        "--model",
        action="append",
        help="YAML key of a model to clone (can be specified multiple times).",
    )
    args = parser.parse_args()

    # Load YAML
    with FETCH_FILE.open("r", encoding="utf-8") as f:
        fetches = yaml.safe_load(f)

    # If listing models, do so and exit
    if args.list_models:
        list_models(fetches)
        return

    # Checked after argument parsing so that --help and --list-models,
    # which never invoke git, still work without git-lfs.
    if not is_git_lfs_installed():
        print("git lfs is required for this program to run:")
        print("\t$ sudo apt install git-lfs")
        print("\t$ sudo yum install git-lfs")
        print("\t$ brew install git-lfs")
        print()
        print("\tor see https://github.com/git-lfs/git-lfs/blob/main/README.md")
        sys.exit(1)

    if args.fetches_dir is None:
        raise ValueError(
            "Missing --fetches-dir argument or MOZ_ML_LOCAL_DIR env var. Please specify a directory to store the downloaded files"
        )

    fetches_dir = Path(args.fetches_dir).resolve()
    fetches_dir.mkdir(parents=True, exist_ok=True)

    # Always download/verify ort.jsep.wasm
    download_wasm(fetches, fetches_dir)

    # Clone requested models
    if args.model:
        clone_models(args.model, fetches, fetches_dir)


if __name__ == "__main__":
    main()