summaryrefslogtreecommitdiffstats
path: root/src/arrow/python/examples/flight/middleware.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/arrow/python/examples/flight/middleware.py')
-rw-r--r--src/arrow/python/examples/flight/middleware.py167
1 files changed, 167 insertions, 0 deletions
diff --git a/src/arrow/python/examples/flight/middleware.py b/src/arrow/python/examples/flight/middleware.py
new file mode 100644
index 000000000..2056bae1f
--- /dev/null
+++ b/src/arrow/python/examples/flight/middleware.py
@@ -0,0 +1,167 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Example of invisibly propagating a request ID with middleware."""
+
+import argparse
+import sys
+import threading
+import uuid
+
+import pyarrow as pa
+import pyarrow.flight as flight
+
+
+class TraceContext:
+ _locals = threading.local()
+ _locals.trace_id = None
+
+ @classmethod
+ def current_trace_id(cls):
+ if not getattr(cls._locals, "trace_id", None):
+ cls.set_trace_id(uuid.uuid4().hex)
+ return cls._locals.trace_id
+
+ @classmethod
+ def set_trace_id(cls, trace_id):
+ cls._locals.trace_id = trace_id
+
+
+TRACE_HEADER = "x-tracing-id"
+
+
+class TracingServerMiddleware(flight.ServerMiddleware):
+ def __init__(self, trace_id):
+ self.trace_id = trace_id
+
+ def sending_headers(self):
+ return {
+ TRACE_HEADER: self.trace_id,
+ }
+
+
+class TracingServerMiddlewareFactory(flight.ServerMiddlewareFactory):
+ def start_call(self, info, headers):
+ print("Starting new call:", info)
+ if TRACE_HEADER in headers:
+ trace_id = headers[TRACE_HEADER][0]
+ print("Found trace header with value:", trace_id)
+ TraceContext.set_trace_id(trace_id)
+ return TracingServerMiddleware(TraceContext.current_trace_id())
+
+
+class TracingClientMiddleware(flight.ClientMiddleware):
+ def sending_headers(self):
+ print("Sending trace ID:", TraceContext.current_trace_id())
+ return {
+ "x-tracing-id": TraceContext.current_trace_id(),
+ }
+
+ def received_headers(self, headers):
+ if TRACE_HEADER in headers:
+ trace_id = headers[TRACE_HEADER][0]
+ print("Found trace header with value:", trace_id)
+ # Don't overwrite our trace ID
+
+
+class TracingClientMiddlewareFactory(flight.ClientMiddlewareFactory):
+ def start_call(self, info):
+ print("Starting new call:", info)
+ return TracingClientMiddleware()
+
+
+class FlightServer(flight.FlightServerBase):
+ def __init__(self, delegate, **kwargs):
+ super().__init__(**kwargs)
+ if delegate:
+ self.delegate = flight.connect(
+ delegate,
+ middleware=(TracingClientMiddlewareFactory(),))
+ else:
+ self.delegate = None
+
+ def list_actions(self, context):
+ return [
+ ("get-trace-id", "Get the trace context ID."),
+ ]
+
+ def do_action(self, context, action):
+ trace_middleware = context.get_middleware("trace")
+ if trace_middleware:
+ TraceContext.set_trace_id(trace_middleware.trace_id)
+ if action.type == "get-trace-id":
+ if self.delegate:
+ for result in self.delegate.do_action(action):
+ yield result
+ else:
+ trace_id = TraceContext.current_trace_id().encode("utf-8")
+ print("Returning trace ID:", trace_id)
+ buf = pa.py_buffer(trace_id)
+ yield pa.flight.Result(buf)
+ else:
+ raise KeyError(f"Unknown action {action.type!r}")
+
+
+def main():
+ parser = argparse.ArgumentParser()
+
+ subparsers = parser.add_subparsers(dest="command")
+ client = subparsers.add_parser("client", help="Run the client.")
+ client.add_argument("server")
+ client.add_argument("--request-id", default=None)
+
+ server = subparsers.add_parser("server", help="Run the server.")
+ server.add_argument(
+ "--listen",
+ required=True,
+ help="The location to listen on (example: grpc://localhost:5050)",
+ )
+ server.add_argument(
+ "--delegate",
+ required=False,
+ default=None,
+ help=("A location to delegate to. That is, this server will "
+ "simply call the given server for the response. Demonstrates "
+ "propagation of the trace ID between servers."),
+ )
+
+ args = parser.parse_args()
+ if not getattr(args, "command"):
+ parser.print_help()
+ return 1
+
+ if args.command == "server":
+ server = FlightServer(
+ args.delegate,
+ location=args.listen,
+ middleware={"trace": TracingServerMiddlewareFactory()})
+ server.serve()
+ elif args.command == "client":
+ client = flight.connect(
+ args.server,
+ middleware=(TracingClientMiddlewareFactory(),))
+ if args.request_id:
+ TraceContext.set_trace_id(args.request_id)
+ else:
+ TraceContext.set_trace_id("client-chosen-id")
+
+ for result in client.do_action(flight.Action("get-trace-id", b"")):
+ print(result.body.to_pybytes())
+
+
+if __name__ == "__main__":
+ sys.exit(main() or 0)