95 lines
2.9 KiB
Python
95 lines
2.9 KiB
Python
# This Source Code Form is subject to the terms of the Mozilla Public
|
|
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
import http.server
|
|
import os
|
|
import socket
|
|
import socketserver
|
|
import threading
|
|
from pathlib import Path
|
|
|
|
# Hub server threads recorded by before_runs(); after_runs() empties this list.
THREADS: list = []
|
|
|
|
|
|
class CustomHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
    """Serve files from ``hub_root`` with size-based ETags.

    Directory requests are delegated to ``SimpleHTTPRequestHandler``. Regular
    files are served directly with an ``ETag`` derived from the file size, so
    clients can revalidate with ``If-None-Match`` and receive a 304.
    """

    # Directory to serve from; assigned at class level by the code that
    # starts the server.
    hub_root = ""

    def translate_path(self, path):
        """Map a request path onto a filesystem path under ``hub_root``.

        The query string and fragment are dropped and percent-escapes are
        decoded. ``..`` segments can never climb above ``hub_root``. Segments
        containing an OS path separator or drive (e.g. ``..\\x`` on Windows)
        are ignored, so a crafted URL cannot reach files outside the served
        directory.
        """
        from urllib.parse import unquote

        # Remove query args and fragment to match the files on disk.
        path = path.split("?", 1)[0].split("#", 1)[0]

        parts = []
        for segment in unquote(path).split("/"):
            if segment in ("", "."):
                continue
            if segment == "..":
                # Clamp at the root: "/../x" maps to "<root>/x".
                if parts:
                    parts.pop()
                continue
            if os.path.dirname(segment):
                # Segment smuggles a platform separator or drive; skip it.
                continue
            parts.append(segment)
        return str(Path(self.hub_root).joinpath(*parts))

    def _etag_matches(self, etag):
        """Return True if the request's ``If-None-Match`` header matches ``etag``.

        Accepts ``*``, comma-separated tag lists and weak (``W/``) tags, using
        the weak comparison that RFC 7232 prescribes for ``If-None-Match``.
        """
        header = self.headers.get("If-None-Match")
        if not header:
            return False
        for candidate in header.split(","):
            candidate = candidate.strip()
            if candidate == "*":
                return True
            if candidate.startswith("W/"):
                candidate = candidate[2:]
            if candidate == etag:
                return True
        return False

    def send_head(self):
        """Send the status line and headers for the requested resource.

        Returns:
            An open binary file object for a 200 response (the caller copies
            and closes it), or ``None`` when the response is already complete
            (304 Not Modified or an error page).
        """
        path = Path(self.translate_path(self.path))
        if path.is_dir():
            # Base class handles trailing-slash redirects, index.html and listings.
            return super().send_head()

        if not path.is_file():
            self.send_error(404, "File not found")
            return None

        # Open before sending anything, so an unreadable file still yields a
        # clean 404 instead of a 200 with a missing body.
        try:
            f = path.open("rb")
        except OSError:
            self.send_error(404, "File not found")
            return None

        try:
            # When dealing with a file, we set the ETag header using the file
            # size, taken from the open handle so it matches what is sent.
            file_size = os.fstat(f.fileno()).st_size
            etag = f'"{file_size}"'

            # Handle conditional GET requests.
            if self._etag_matches(etag):
                f.close()
                self.send_response(304)
                # A 304 must carry the ETag a 200 would have carried.
                self.send_header("ETag", etag)
                self.end_headers()
                return None

            self.send_response(200)
            self.send_header("Content-type", self.guess_type(str(path)))
            self.send_header("Content-Length", str(file_size))
            self.send_header("ETag", etag)
            self.end_headers()
            return f
        except BaseException:
            # Don't leak the handle if sending headers fails.
            f.close()
            raise
|
|
|
|
|
|
def serve_directory(directory, port):
    """Serve ``directory`` over HTTP on ``port`` (all interfaces), forever.

    Each connection is handled on its own daemon thread. One slow download or
    an idle client (e.g. a speculative connection that never sends a request)
    therefore cannot stall every other request, as it would with the
    single-threaded ``TCPServer``.
    """
    CustomHTTPRequestHandler.hub_root = directory

    with socketserver.ThreadingTCPServer(("", port), CustomHTTPRequestHandler) as httpd:
        # Request threads must not keep the process alive on exit.
        httpd.daemon_threads = True
        print(f"Serving {directory} at http://localhost:{port}")
        httpd.serve_forever()
|
|
|
|
|
|
def start_hub(root_directory):
    """Start a local hub server for ``root_directory`` in a background thread.

    The listening socket is bound to an OS-assigned free port before this
    function returns. The returned port therefore accepts connections as soon
    as the caller publishes it, and no other process can grab it in between,
    unlike probing a port, releasing it, and binding later in another thread.

    Returns:
        A ``(port, thread)`` tuple, where ``thread`` is the daemon thread
        running ``serve_forever()``.
    """
    CustomHTTPRequestHandler.hub_root = root_directory

    # Port 0: the OS picks a free port; the constructor binds and listens.
    httpd = socketserver.ThreadingTCPServer(("", 0), CustomHTTPRequestHandler)
    # Request threads must not keep the process alive on exit.
    httpd.daemon_threads = True
    port = httpd.server_address[1]
    print(f"Serving {root_directory} at http://localhost:{port}")

    server_thread = threading.Thread(target=httpd.serve_forever, daemon=True)
    server_thread.start()
    return port, server_thread
|
|
|
|
|
|
def before_runs(env):
    """Start a local model hub before all performance tests run.

    The hub root is ``<dir>/onnx-models``. ``<dir>`` comes from
    ``MOZ_ML_LOCAL_DIR``, or, when that is unset or empty, from
    ``MOZ_FETCHES_DIR`` (the alternate location used in CI). If no directory
    is configured, or it has no ``onnx-models`` subdirectory, nothing is
    started.

    On success, ``MOZ_MODELS_HUB`` is set to the hub URL and the server thread
    is recorded in ``THREADS``.

    Args:
        env: Not used by this hook; accepted for the hook interface
            (NOTE(review): presumably the perftest environment object).
    """
    # ``or`` treats an empty value like an unset one, so MOZ_ML_LOCAL_DIR=""
    # falls back to MOZ_FETCHES_DIR instead of a cwd-relative "onnx-models".
    fetches_dir = os.environ.get("MOZ_ML_LOCAL_DIR") or os.environ.get(
        "MOZ_FETCHES_DIR"
    )
    if not fetches_dir:
        return

    hub_dir = Path(fetches_dir) / "onnx-models"
    if not hub_dir.is_dir():
        return

    port, server_thread = start_hub(hub_dir)
    os.environ["MOZ_MODELS_HUB"] = f"http://localhost:{port}"
    THREADS.append(server_thread)
|
|
|
|
|
|
def after_runs(env):
    """Clean-up hook run after all performance tests.

    If any hub thread was recorded, prints a notice, makes a non-blocking
    join attempt on the first recorded thread, and then forgets all recorded
    threads. With nothing recorded, it does nothing.

    NOTE(review): ``join(timeout=0)`` returns immediately and does not stop
    ``serve_forever()``. The server thread keeps running until the process
    exits.
    """
    if not THREADS:
        return
    print("Shutting down")
    first_thread = THREADS[0]
    first_thread.join(timeout=0)
    THREADS.clear()
|