diff options
Diffstat (limited to 'src/arrow/python/examples/flight/middleware.py')
-rw-r--r-- | src/arrow/python/examples/flight/middleware.py | 167 |
1 files changed, 167 insertions, 0 deletions
diff --git a/src/arrow/python/examples/flight/middleware.py b/src/arrow/python/examples/flight/middleware.py new file mode 100644 index 000000000..2056bae1f --- /dev/null +++ b/src/arrow/python/examples/flight/middleware.py @@ -0,0 +1,167 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Example of invisibly propagating a request ID with middleware.""" + +import argparse +import sys +import threading +import uuid + +import pyarrow as pa +import pyarrow.flight as flight + + +class TraceContext: + _locals = threading.local() + _locals.trace_id = None + + @classmethod + def current_trace_id(cls): + if not getattr(cls._locals, "trace_id", None): + cls.set_trace_id(uuid.uuid4().hex) + return cls._locals.trace_id + + @classmethod + def set_trace_id(cls, trace_id): + cls._locals.trace_id = trace_id + + +TRACE_HEADER = "x-tracing-id" + + +class TracingServerMiddleware(flight.ServerMiddleware): + def __init__(self, trace_id): + self.trace_id = trace_id + + def sending_headers(self): + return { + TRACE_HEADER: self.trace_id, + } + + +class TracingServerMiddlewareFactory(flight.ServerMiddlewareFactory): + def start_call(self, info, headers): + print("Starting new call:", info) + if TRACE_HEADER in headers: + trace_id = headers[TRACE_HEADER][0] + print("Found trace header with value:", trace_id) + TraceContext.set_trace_id(trace_id) + return TracingServerMiddleware(TraceContext.current_trace_id()) + + +class TracingClientMiddleware(flight.ClientMiddleware): + def sending_headers(self): + print("Sending trace ID:", TraceContext.current_trace_id()) + return { + "x-tracing-id": TraceContext.current_trace_id(), + } + + def received_headers(self, headers): + if TRACE_HEADER in headers: + trace_id = headers[TRACE_HEADER][0] + print("Found trace header with value:", trace_id) + # Don't overwrite our trace ID + + +class TracingClientMiddlewareFactory(flight.ClientMiddlewareFactory): + def start_call(self, info): + print("Starting new call:", info) + return TracingClientMiddleware() + + +class FlightServer(flight.FlightServerBase): + def __init__(self, delegate, **kwargs): + super().__init__(**kwargs) + if delegate: + self.delegate = flight.connect( + delegate, + middleware=(TracingClientMiddlewareFactory(),)) + else: + self.delegate = None + + def list_actions(self, context): + return [ + ("get-trace-id", "Get the trace context ID."), + ] + + def do_action(self, context, action): + trace_middleware = context.get_middleware("trace") + if trace_middleware: + TraceContext.set_trace_id(trace_middleware.trace_id) + if action.type == "get-trace-id": + if self.delegate: + for result in self.delegate.do_action(action): + yield result + else: + trace_id = TraceContext.current_trace_id().encode("utf-8") + print("Returning trace ID:", trace_id) + buf = pa.py_buffer(trace_id) + yield pa.flight.Result(buf) + else: + raise KeyError(f"Unknown action {action.type!r}") + + +def main(): + parser = argparse.ArgumentParser() + + subparsers = parser.add_subparsers(dest="command") + client = subparsers.add_parser("client", help="Run the client.") + client.add_argument("server") + client.add_argument("--request-id", default=None) + + server = subparsers.add_parser("server", help="Run the server.") + server.add_argument( + "--listen", + required=True, + help="The location to listen on (example: grpc://localhost:5050)", + ) + server.add_argument( + "--delegate", + required=False, + default=None, + help=("A location to delegate to. That is, this server will " + "simply call the given server for the response. Demonstrates " + "propagation of the trace ID between servers."), + ) + + args = parser.parse_args() + if not getattr(args, "command"): + parser.print_help() + return 1 + + if args.command == "server": + server = FlightServer( + args.delegate, + location=args.listen, + middleware={"trace": TracingServerMiddlewareFactory()}) + server.serve() + elif args.command == "client": + client = flight.connect( + args.server, + middleware=(TracingClientMiddlewareFactory(),)) + if args.request_id: + TraceContext.set_trace_id(args.request_id) + else: + TraceContext.set_trace_id("client-chosen-id") + + for result in client.do_action(flight.Action("get-trace-id", b"")): + print(result.body.to_pybytes()) + + +if __name__ == "__main__": + sys.exit(main() or 0) |