# This Source Code Form is subject to the terms of the Mozilla Public # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. import http.server import os import socket import socketserver import threading from pathlib import Path THREADS = [] class CustomHTTPRequestHandler(http.server.SimpleHTTPRequestHandler): hub_root = "" def translate_path(self, path): # Remove front slash and query args to match the files return str(self.hub_root / Path(path.lstrip("/").split("?")[0])) def send_head(self): path = Path(self.translate_path(self.path)) if path.is_dir(): return super().send_head() # when dealing with a file, we set the ETag header using the file size. if path.is_file(): file_size = path.stat().st_size etag = f'"{file_size}"' # Handle conditional GET requests if_match = self.headers.get("If-None-Match") if if_match == etag: self.send_response(304) self.end_headers() return None self.send_response(200) self.send_header("Content-type", self.guess_type(str(path))) self.send_header("Content-Length", str(file_size)) self.send_header("ETag", etag) self.end_headers() return path.open("rb") self.send_error(404, "File not found") def serve_directory(directory, port): """Serves the directory at the given port.""" CustomHTTPRequestHandler.hub_root = directory with socketserver.TCPServer(("", port), CustomHTTPRequestHandler) as httpd: print(f"Serving {directory} at http://localhost:{port}") httpd.serve_forever() def start_hub(root_directory): """Starts a local hub server and returns the port and thread.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: s.bind(("", 0)) port = s.getsockname()[1] server_thread = threading.Thread( target=serve_directory, args=(root_directory, port), daemon=True ) server_thread.start() return port, server_thread def before_runs(env): """Runs before all performance tests. We grab MOZ_ML_LOCAL_DIR. If set we serve MOZ_ML_LOCAL_DIR/onnx-models as our local hub. MOZ_FETCHES_DIR is used in the CI as an alternate localtion. """ fetches_dir = os.environ.get("MOZ_ML_LOCAL_DIR") if fetches_dir is None: fetches_dir = os.environ.get("MOZ_FETCHES_DIR") if fetches_dir is None: return hub_dir = Path(fetches_dir) / "onnx-models" if not hub_dir.is_dir(): return port, server_thread = start_hub(hub_dir) os.environ["MOZ_MODELS_HUB"] = f"http://localhost:{port}" THREADS.append(server_thread) def after_runs(env): if len(THREADS) > 0: print("Shutting down") THREADS[0].join(timeout=0) THREADS.clear()