summaryrefslogtreecommitdiffstats
path: root/tools/tryselect/lando.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tools/tryselect/lando.py452
1 files changed, 452 insertions, 0 deletions
diff --git a/tools/tryselect/lando.py b/tools/tryselect/lando.py
new file mode 100644
index 0000000000..7abd2ddfae
--- /dev/null
+++ b/tools/tryselect/lando.py
@@ -0,0 +1,452 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+"""Implements Auth0 Device Code flow and Lando try submission.
+
+See https://auth0.com/blog/securing-a-python-cli-application-with-auth0/ for more.
+"""
+
+from __future__ import annotations
+
+import base64
+import configparser
+import json
+import os
+import time
+import webbrowser
+from dataclasses import (
+ dataclass,
+ field,
+)
+from pathlib import Path
+from typing import (
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
+
+import requests
+from mach.util import get_state_dir
+from mozbuild.base import MozbuildObject
+from mozversioncontrol import (
+ GitRepository,
+ HgRepository,
+)
+
+from .task_config import (
+ try_config_commit,
+)
+
+TOKEN_FILE = (
+ Path(get_state_dir(specific_to_topsrcdir=False)) / "lando_auth0_user_token.json"
+)
+
+# The supported variants of `Repository` for this workflow.
+SupportedVcsRepository = Union[GitRepository, HgRepository]
+
+here = os.path.abspath(os.path.dirname(__file__))
+build = MozbuildObject.from_environment(cwd=here)
+
+
+def convert_bytes_patch_to_base64(patch_bytes: bytes) -> str:
+ """Return a base64 encoded `str` representing the passed `bytes` patch."""
+ return base64.b64encode(patch_bytes).decode("ascii")
+
+
+def load_token_from_disk() -> Optional[dict]:
+ """Load and validate an existing Auth0 token from disk.
+
+ Return the token as a `dict` if it can be validated, or return `None`
+ if any error was encountered.
+ """
+ if not TOKEN_FILE.exists():
+ print("No existing Auth0 token found.")
+ return None
+
+ try:
+ user_token = json.loads(TOKEN_FILE.read_bytes())
+ except json.JSONDecodeError:
+ print("Existing Auth0 token could not be decoded as JSON.")
+ return None
+
+ return user_token
+
+
+def get_stack_info(vcs: SupportedVcsRepository) -> Tuple[str, List[str]]:
+ """Retrieve information about the current stack for submission via Lando.
+
+ Returns a tuple of the current public base commit as a Mercurial SHA,
+ and a list of ordered base64 encoded patches.
+ """
+ base_commit = vcs.base_ref_as_hg()
+ if not base_commit:
+ raise ValueError(
+ "Could not determine base Mercurial commit hash for submission."
+ )
+ print("Using", base_commit, "as the hg base commit.")
+
+ # Reuse the base revision when on Mercurial to avoid multiple calls to `hg log`.
+ branch_nodes_kwargs = {}
+ if isinstance(vcs, HgRepository):
+ branch_nodes_kwargs["base_ref"] = base_commit
+
+ nodes = vcs.get_branch_nodes(**branch_nodes_kwargs)
+ if not nodes:
+ raise ValueError("Could not find any commit hashes for submission.")
+ elif len(nodes) == 1:
+ print("Submitting a single try config commit.")
+ elif len(nodes) == 2:
+ print("Submitting 1 node and the try commit.")
+ else:
+ print("Submitting stack of", len(nodes) - 1, "nodes and the try commit.")
+
+ patches = vcs.get_commit_patches(nodes)
+ base64_patches = [
+ convert_bytes_patch_to_base64(patch_bytes) for patch_bytes in patches
+ ]
+ print("Patches gathered for submission.")
+
+ return base_commit, base64_patches
+
+
+@dataclass
+class Auth0Config:
+ """Helper class to interact with Auth0."""
+
+ domain: str
+ client_id: str
+ audience: str
+ scope: str
+ algorithms: list[str] = field(default_factory=lambda: ["RS256"])
+
+ @property
+ def base_url(self) -> str:
+ """Auth0 base URL."""
+ return f"https://{self.domain}"
+
+ @property
+ def device_code_url(self) -> str:
+ """URL of the Device Code API endpoint."""
+ return f"{self.base_url}/oauth/device/code"
+
+ @property
+ def issuer(self) -> str:
+ """Token issuer URL."""
+ return f"{self.base_url}/"
+
+ @property
+ def jwks_url(self) -> str:
+ """URL of the JWKS file."""
+ return f"{self.base_url}/.well-known/jwks.json"
+
+ @property
+ def oauth_token_url(self) -> str:
+ """URL of the OAuth Token endpoint."""
+ return f"{self.base_url}/oauth/token"
+
+ def request_device_code(self) -> dict:
+ """Request authorization from Auth0 using the Device Code Flow.
+
+ See https://auth0.com/docs/api/authentication#get-device-code for more.
+ """
+ response = requests.post(
+ self.device_code_url,
+ headers={"Content-Type": "application/x-www-form-urlencoded"},
+ data={
+ "audience": self.audience,
+ "client_id": self.client_id,
+ "scope": self.scope,
+ },
+ )
+
+ response.raise_for_status()
+
+ return response.json()
+
+ def validate_token(self, user_token: dict) -> Optional[dict]:
+ """Verify the given user token is valid.
+
+ Validate the ID token, and validate the access token's expiration claim.
+ """
+ # Import `auth0-python` here to avoid `ImportError` in tests, since
+ # the `python-test` site won't have `auth0-python` installed.
+ import jwt
+ from auth0.authentication.token_verifier import (
+ AsymmetricSignatureVerifier,
+ TokenVerifier,
+ )
+ from auth0.exceptions import (
+ TokenValidationError,
+ )
+
+ signature_verifier = AsymmetricSignatureVerifier(self.jwks_url)
+ token_verifier = TokenVerifier(
+ audience=self.client_id,
+ issuer=self.issuer,
+ signature_verifier=signature_verifier,
+ )
+
+ try:
+ token_verifier.verify(user_token["id_token"])
+ except TokenValidationError as e:
+ print("Could not validate existing Auth0 ID token:", str(e))
+ return None
+
+ decoded_access_token = jwt.decode(
+ user_token["access_token"],
+ algorithms=self.algorithms,
+ options={"verify_signature": False},
+ )
+
+ access_token_expiration = decoded_access_token["exp"]
+
+ # Assert that the access token isn't expired or expiring within a minute.
+ if time.time() > access_token_expiration + 60:
+ print("Access token is expired.")
+ return None
+
+ user_token.update(
+ jwt.decode(
+ user_token["id_token"],
+ algorithms=self.algorithms,
+ options={"verify_signature": False},
+ )
+ )
+ print("Auth0 token validated.")
+ return user_token
+
+ def device_authorization_flow(self) -> dict:
+ """Perform the Device Authorization Flow.
+
+ See https://auth0.com/docs/get-started/authentication-and-authorization-flow/device-authorization-flow
+ for more.
+ """
+ start = time.perf_counter()
+
+ device_code_data = self.request_device_code()
+ print(
+ "1. On your computer or mobile device navigate to:",
+ device_code_data["verification_uri_complete"],
+ )
+ print("2. Enter the following code:", device_code_data["user_code"])
+
+ auth_msg = f"Auth0 token validation required at: {device_code_data['verification_uri_complete']}"
+ build.notify(auth_msg)
+
+ try:
+ webbrowser.open(device_code_data["verification_uri_complete"])
+ except webbrowser.Error:
+ print("Could not automatically open the web browser.")
+
+ device_code_lifetime_s = device_code_data["expires_in"]
+
+ # Print successive periods on the same line to avoid moving the link
+ # while the user is trying to click it.
+ print("Waiting...", end="", flush=True)
+ while time.perf_counter() - start < device_code_lifetime_s:
+ response = requests.post(
+ self.oauth_token_url,
+ data={
+ "client_id": self.client_id,
+ "device_code": device_code_data["device_code"],
+ "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
+ "scope": self.scope,
+ },
+ )
+ response_data = response.json()
+
+ if response.status_code == 200:
+ print("\nLogin successful.")
+ return response_data
+
+ if response_data["error"] not in ("authorization_pending", "slow_down"):
+ raise RuntimeError(response_data["error_description"])
+
+ time.sleep(device_code_data["interval"])
+ print(".", end="", flush=True)
+
+ raise ValueError("Timed out waiting for Auth0 device code authentication!")
+
+ def get_token(self) -> dict:
+ """Retrieve an access token for authentication.
+
+ If a cached token is found and can be confirmed to be valid, return it.
+ Otherwise, perform the Device Code Flow authorization to request a new
+ token, validate it and save it to disk.
+ """
+ # Load a cached token and validate it if one is available.
+ cached_token = load_token_from_disk()
+ user_token = self.validate_token(cached_token) if cached_token else None
+
+ # Login with the Device Authorization Flow if an existing token isn't found.
+ if not user_token:
+ new_token = self.device_authorization_flow()
+ user_token = self.validate_token(new_token)
+
+ if not user_token:
+ raise ValueError("Could not get an Auth0 token.")
+
+ # Save token to disk.
+ with TOKEN_FILE.open("w") as f:
+ json.dump(user_token, f, indent=2, sort_keys=True)
+
+ return user_token
+
+
+class LandoAPIException(Exception):
+ """Raised when Lando throws an exception."""
+
+ def __init__(self, detail: Optional[str] = None):
+ super().__init__(detail or "")
+
+
+@dataclass
+class LandoAPI:
+ """Helper class to interact with Lando-API."""
+
+ access_token: str
+ api_url: str
+
+ @property
+ def lando_try_api_url(self) -> str:
+ """URL of the Lando Try endpoint."""
+ return f"https://{self.api_url}/try/patches"
+
+ @property
+ def api_headers(self) -> dict[str, str]:
+ """Headers for use accessing and authenticating against the API."""
+ return {
+ "Authorization": f"Bearer {self.access_token}",
+ "Content-Type": "application/json",
+ }
+
+ @classmethod
+ def from_lando_config_file(cls, config_path: Path, section: str) -> LandoAPI:
+ """Build a `LandoConfig` from `section` in the file at `config_path`."""
+ if not config_path.exists():
+ raise ValueError(f"Could not find a Lando config file at `{config_path}`.")
+
+ lando_ini_contents = config_path.read_text()
+
+ parser = configparser.ConfigParser(delimiters="=")
+ parser.read_string(lando_ini_contents)
+
+ if not parser.has_section(section):
+ raise ValueError(f"Lando config file does not have a {section} section.")
+
+ auth0 = Auth0Config(
+ domain=parser.get(section, "auth0_domain"),
+ client_id=parser.get(section, "auth0_client_id"),
+ audience=parser.get(section, "auth0_audience"),
+ scope=parser.get(section, "auth0_scope"),
+ )
+
+ token = auth0.get_token()
+
+ return LandoAPI(
+ api_url=parser.get(section, "api_domain"),
+ access_token=token["access_token"],
+ )
+
+ def post(self, url: str, body: dict) -> dict:
+ """Make a POST request to Lando."""
+ response = requests.post(url, headers=self.api_headers, json=body)
+
+ try:
+ response_json = response.json()
+ except json.JSONDecodeError:
+ # If the server didn't send back a valid JSON object, raise a stack
+ # trace to the terminal which includes error details.
+ response.raise_for_status()
+
+ # Raise `ValueError` if the response wasn't JSON and we didn't raise
+ # from an invalid status.
+ raise LandoAPIException(
+ detail="Response was not valid JSON yet status was valid."
+ )
+
+ if response.status_code >= 400:
+ raise LandoAPIException(detail=response_json["detail"])
+
+ return response_json
+
+ def post_try_push_patches(
+ self,
+ patches: List[str],
+ patch_format: str,
+ base_commit: str,
+ ) -> dict:
+ """Send try push contents to Lando.
+
+ Send the list of base64-encoded `patches` in `patch_format` to Lando, to be applied to
+ the Mercurial `base_commit`, using the Auth0 `access_token` for authorization.
+ """
+ request_json_body = {
+ "base_commit": base_commit,
+ "patch_format": patch_format,
+ "patches": patches,
+ }
+
+ print("Submitting patches to Lando.")
+ response_json = self.post(self.lando_try_api_url, request_json_body)
+
+ return response_json
+
+
+def push_to_lando_try(vcs: SupportedVcsRepository, commit_message: str):
+ """Push a set of patches to Lando's try endpoint."""
+ # Map `Repository` subclasses to the `patch_format` value Lando expects.
+ PATCH_FORMAT_STRING_MAPPING = {
+ GitRepository: "git-format-patch",
+ HgRepository: "hgexport",
+ }
+ patch_format = PATCH_FORMAT_STRING_MAPPING.get(type(vcs))
+ if not patch_format:
+ # Other VCS types (namely `src`) are unsupported.
+ raise ValueError(f"Try push via Lando is not supported for `{vcs.name}`.")
+
+ # Use Lando Prod unless the `LANDO_TRY_USE_DEV` environment variable is defined.
+ lando_config_section = (
+ "lando-prod" if not os.getenv("LANDO_TRY_USE_DEV") else "lando-dev"
+ )
+
+ # Load Auth0 config from `.lando.ini`.
+ lando_ini_path = Path(vcs.path) / ".lando.ini"
+ lando_api = LandoAPI.from_lando_config_file(lando_ini_path, lando_config_section)
+
+ # Get the time when the push was initiated, not including Auth0 login time.
+ push_start_time = time.perf_counter()
+
+ with try_config_commit(vcs, commit_message):
+ try:
+ base_commit, patches = get_stack_info(vcs)
+ except ValueError as exc:
+ error_msg = "abort: error gathering patches for submission."
+ print(error_msg)
+ print(str(exc))
+ build.notify(error_msg)
+ return
+
+ try:
+ # Make the try request to Lando.
+ response_json = lando_api.post_try_push_patches(
+ patches, patch_format, base_commit
+ )
+ except LandoAPIException as exc:
+ error_msg = "abort: error submitting patches to Lando."
+ print(error_msg)
+ print(str(exc))
+ build.notify(error_msg)
+ return
+
+ duration = round(time.perf_counter() - push_start_time, ndigits=2)
+
+ job_id = response_json["id"]
+ success_msg = (
+ f"Lando try submission success, took {duration} seconds. "
+ f"Landing job id: {job_id}."
+ )
+ print(success_msg)
+ build.notify(success_msg)