diff options
Diffstat (limited to 'src/arrow/python/examples/flight')
-rw-r--r-- | src/arrow/python/examples/flight/client.py | 189 | ||||
-rw-r--r-- | src/arrow/python/examples/flight/middleware.py | 167 | ||||
-rw-r--r-- | src/arrow/python/examples/flight/server.py | 154 |
3 files changed, 510 insertions, 0 deletions
diff --git a/src/arrow/python/examples/flight/client.py b/src/arrow/python/examples/flight/client.py new file mode 100644 index 000000000..ed6ce54ce --- /dev/null +++ b/src/arrow/python/examples/flight/client.py @@ -0,0 +1,189 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
"""An example Flight CLI client."""

import argparse
import sys
import time

import pyarrow
import pyarrow.flight
import pyarrow.csv as csv


def _count_or_unknown(count):
    """Return ``count`` unchanged, or ``"Unknown"`` when it is negative."""
    return count if count >= 0 else "Unknown"


def list_flights(args, client, connection_args=None):
    """Print every flight and every action reported by ``client``.

    ``args`` and ``connection_args`` are not used here; they are accepted
    so that all subcommand handlers can be called with one signature.
    """
    print('Flights\n=======')
    for flight in client.list_flights():
        descriptor = flight.descriptor
        if descriptor.descriptor_type == pyarrow.flight.DescriptorType.PATH:
            print("Path:", descriptor.path)
        elif descriptor.descriptor_type == pyarrow.flight.DescriptorType.CMD:
            print("Command:", descriptor.command)
        else:
            print("Unknown descriptor type")

        # Negative totals are reported as "Unknown".
        print("Total records:", _count_or_unknown(flight.total_records))
        print("Total bytes:", _count_or_unknown(flight.total_bytes))

        print("Number of endpoints:", len(flight.endpoints))
        print("Schema:")
        print(flight.schema)
        print('---')

    print('\nActions\n=======')
    for action in client.list_actions():
        print("Type:", action.type)
        print("Description:", action.description)
        print('---')


def do_action(args, client, connection_args=None):
    """Run the action named ``args.action_type`` and print each result.

    The action is sent with an empty body. An ``ArrowIOError`` raised
    while calling the action is printed rather than propagated.
    """
    try:
        buf = pyarrow.allocate_buffer(0)
        action = pyarrow.flight.Action(args.action_type, buf)
        print('Running action', args.action_type)
        for result in client.do_action(action):
            print("Got result", result.body.to_pybytes())
    except pyarrow.lib.ArrowIOError as e:
        print("Error calling action:", e)


def push_data(args, client, connection_args=None):
    """Read the CSV file ``args.file`` and upload it with ``do_put``.

    The file name is also used as the path of the flight descriptor.
    """
    print('File Name:', args.file)
    my_table = csv.read_csv(args.file)
    print('Table rows=', str(len(my_table)))
    df = my_table.to_pandas()
    print(df.head())
    writer, _ = client.do_put(
        pyarrow.flight.FlightDescriptor.for_path(args.file), my_table.schema)
    try:
        writer.write_table(my_table)
    finally:
        # Close the stream even when the upload fails part-way.
        writer.close()


def get_flight(args, client, connection_args=None):
    """Look up a flight by path or command and print its data.

    For every location of every endpoint a new ``FlightClient`` is opened
    with ``connection_args`` and the endpoint's ticket is fetched and
    printed as a pandas DataFrame.
    """
    if connection_args is None:
        connection_args = {}

    if args.path:
        descriptor = pyarrow.flight.FlightDescriptor.for_path(*args.path)
    else:
        descriptor = pyarrow.flight.FlightDescriptor.for_command(args.command)

    info = client.get_flight_info(descriptor)
    for endpoint in info.endpoints:
        print('Ticket:', endpoint.ticket)
        for location in endpoint.locations:
            print(location)
            get_client = pyarrow.flight.FlightClient(location,
                                                     **connection_args)
            reader = get_client.do_get(endpoint.ticket)
            df = reader.read_pandas()
            print(df)


def _add_common_arguments(parser):
    """Add the connection options shared by all subcommands."""
    parser.add_argument('--tls', action='store_true',
                        help='Enable transport-level security')
    parser.add_argument('--tls-roots', default=None,
                        help='Path to trusted TLS certificate(s)')
    parser.add_argument("--mtls", nargs=2, default=None,
                        metavar=('CERTFILE', 'KEYFILE'),
                        help="Client certificate and key to present for "
                             "mutual TLS")
    parser.add_argument('host', type=str,
                        help="Address or hostname to connect to")


def main():
    """Parse the command line, connect to the server and run a subcommand."""
    parser = argparse.ArgumentParser()
    subcommands = parser.add_subparsers()

    cmd_list = subcommands.add_parser('list')
    cmd_list.set_defaults(action='list')
    _add_common_arguments(cmd_list)
    cmd_list.add_argument('-l', '--list', action='store_true',
                          help="Print more details.")

    cmd_do = subcommands.add_parser('do')
    cmd_do.set_defaults(action='do')
    _add_common_arguments(cmd_do)
    cmd_do.add_argument('action_type', type=str,
                        help="The action type to run.")

    cmd_put = subcommands.add_parser('put')
    cmd_put.set_defaults(action='put')
    _add_common_arguments(cmd_put)
    cmd_put.add_argument('file', type=str,
                         help="CSV file to upload.")

    cmd_get = subcommands.add_parser('get')
    cmd_get.set_defaults(action='get')
    _add_common_arguments(cmd_get)
    cmd_get_descriptor = cmd_get.add_mutually_exclusive_group(required=True)
    cmd_get_descriptor.add_argument('-p', '--path', type=str, action='append',
                                    help="The path for the descriptor.")
    cmd_get_descriptor.add_argument('-c', '--command', type=str,
                                    help="The command for the descriptor.")

    args = parser.parse_args()
    if not hasattr(args, 'action'):
        # No subcommand was given.
        parser.print_help()
        sys.exit(1)

    commands = {
        'list': list_flights,
        'do': do_action,
        'get': get_flight,
        'put': push_data,
    }

    # Split on the LAST colon so that a host which itself contains colons
    # (e.g. a bracketed IPv6 literal) still yields the port, and report a
    # usage error instead of an unpacking traceback when no port is given.
    host, sep, port_text = args.host.rpartition(':')
    try:
        port = int(port_text) if sep else None
    except ValueError:
        port = None
    if port is None:
        parser.error(f"host must be given as HOST:PORT, got {args.host!r}")

    scheme = "grpc+tcp"
    connection_args = {}
    if args.tls:
        scheme = "grpc+tls"
        if args.tls_roots:
            with open(args.tls_roots, "rb") as root_certs:
                connection_args["tls_root_certs"] = root_certs.read()
    if args.mtls:
        with open(args.mtls[0], "rb") as cert_file:
            tls_cert_chain = cert_file.read()
        with open(args.mtls[1], "rb") as key_file:
            tls_private_key = key_file.read()
        connection_args["cert_chain"] = tls_cert_chain
        connection_args["private_key"] = tls_private_key
    client = pyarrow.flight.FlightClient(f"{scheme}://{host}:{port}",
                                         **connection_args)

    # Poll the server's "healthcheck" action until one call succeeds.
    action = pyarrow.flight.Action("healthcheck", b"")
    options = pyarrow.flight.FlightCallOptions(timeout=1)
    while True:
        try:
            list(client.do_action(action, options=options))
            break
        except pyarrow.ArrowIOError as e:
            if "Deadline" in str(e):
                print("Server is not ready, waiting...")
            else:
                # Such a failure returns immediately rather than after the
                # call timeout, so report it and pause instead of
                # retrying in a silent tight loop.
                print("Health check failed, retrying:", e)
                time.sleep(1)
    commands[args.action](args, client, connection_args)


if __name__ == '__main__':
    main()
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Example of invisibly propagating a request ID with middleware."""

import argparse
import sys
import threading
import uuid

import pyarrow as pa
import pyarrow.flight as flight


class TraceContext:
    """Per-thread storage for the current trace ID."""

    _locals = threading.local()
    # This assignment only affects the thread that imports the module;
    # other threads rely on the getattr() default in current_trace_id().
    _locals.trace_id = None

    @classmethod
    def current_trace_id(cls):
        """Return this thread's trace ID, generating one if none is set."""
        if not getattr(cls._locals, "trace_id", None):
            cls.set_trace_id(uuid.uuid4().hex)
        return cls._locals.trace_id

    @classmethod
    def set_trace_id(cls, trace_id):
        """Store ``trace_id`` as this thread's trace ID."""
        cls._locals.trace_id = trace_id


TRACE_HEADER = "x-tracing-id"


class TracingServerMiddleware(flight.ServerMiddleware):
    """Server middleware that sends the call's trace ID back as a header."""

    def __init__(self, trace_id):
        self.trace_id = trace_id

    def sending_headers(self):
        return {
            TRACE_HEADER: self.trace_id,
        }


class TracingServerMiddlewareFactory(flight.ServerMiddlewareFactory):
    """Create a TracingServerMiddleware for each incoming call."""

    def start_call(self, info, headers):
        print("Starting new call:", info)
        if TRACE_HEADER in headers:
            trace_id = headers[TRACE_HEADER][0]
            print("Found trace header with value:", trace_id)
            TraceContext.set_trace_id(trace_id)
        else:
            # The thread-local may still hold the ID of an earlier call
            # handled on this thread; clear it so that a request without
            # the header gets a freshly generated ID instead.
            TraceContext.set_trace_id(None)
        return TracingServerMiddleware(TraceContext.current_trace_id())


class TracingClientMiddleware(flight.ClientMiddleware):
    """Client middleware that attaches the current trace ID to each call."""

    def sending_headers(self):
        trace_id = TraceContext.current_trace_id()
        print("Sending trace ID:", trace_id)
        return {
            TRACE_HEADER: trace_id,
        }

    def received_headers(self, headers):
        if TRACE_HEADER in headers:
            trace_id = headers[TRACE_HEADER][0]
            print("Found trace header with value:", trace_id)
            # Don't overwrite our trace ID


class TracingClientMiddlewareFactory(flight.ClientMiddlewareFactory):
    """Create a TracingClientMiddleware for each outgoing call."""

    def start_call(self, info):
        print("Starting new call:", info)
        return TracingClientMiddleware()


class FlightServer(flight.FlightServerBase):
    """Flight server exposing a single "get-trace-id" action.

    When ``delegate`` is given, the action is forwarded to the server at
    that location through a client that uses the tracing middleware;
    otherwise the server answers with its own current trace ID.
    """

    def __init__(self, delegate, **kwargs):
        super().__init__(**kwargs)
        if delegate:
            self.delegate = flight.connect(
                delegate,
                middleware=(TracingClientMiddlewareFactory(),))
        else:
            self.delegate = None

    def list_actions(self, context):
        return [
            ("get-trace-id", "Get the trace context ID."),
        ]

    def do_action(self, context, action):
        """Yield the results of ``action``.

        Raises KeyError for any action type other than "get-trace-id".
        """
        # Copy the ID captured by the server middleware into this thread's
        # context before answering or delegating.
        trace_middleware = context.get_middleware("trace")
        if trace_middleware:
            TraceContext.set_trace_id(trace_middleware.trace_id)

        if action.type != "get-trace-id":
            raise KeyError(f"Unknown action {action.type!r}")

        if self.delegate:
            yield from self.delegate.do_action(action)
        else:
            trace_id = TraceContext.current_trace_id().encode("utf-8")
            print("Returning trace ID:", trace_id)
            yield flight.Result(pa.py_buffer(trace_id))


def main():
    """Run either the example server or the example client.

    Returns 1 when no subcommand was given, otherwise None.
    """
    parser = argparse.ArgumentParser()

    subparsers = parser.add_subparsers(dest="command")
    client_parser = subparsers.add_parser("client", help="Run the client.")
    client_parser.add_argument("server")
    client_parser.add_argument("--request-id", default=None)

    server_parser = subparsers.add_parser("server", help="Run the server.")
    server_parser.add_argument(
        "--listen",
        required=True,
        help="The location to listen on (example: grpc://localhost:5050)",
    )
    server_parser.add_argument(
        "--delegate",
        required=False,
        default=None,
        help=("A location to delegate to. That is, this server will "
              "simply call the given server for the response. Demonstrates "
              "propagation of the trace ID between servers."),
    )

    args = parser.parse_args()
    if not args.command:
        parser.print_help()
        return 1

    if args.command == "server":
        server = FlightServer(
            args.delegate,
            location=args.listen,
            middleware={"trace": TracingServerMiddlewareFactory()})
        server.serve()
    elif args.command == "client":
        client = flight.connect(
            args.server,
            middleware=(TracingClientMiddlewareFactory(),))
        TraceContext.set_trace_id(args.request_id or "client-chosen-id")

        for result in client.do_action(flight.Action("get-trace-id", b"")):
            print(result.body.to_pybytes())


if __name__ == "__main__":
    sys.exit(main() or 0)
"""An example Flight Python server."""

import argparse
import ast
import threading
import time

import pyarrow
import pyarrow.flight


def _str_to_bool(text):
    """Convert a command-line string to a bool.

    ``type=bool`` would call ``bool()`` on the raw string, which is True
    for every non-empty value, including "False".
    """
    lowered = text.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(text))


class FlightServer(pyarrow.flight.FlightServerBase):
    """In-memory Flight server that stores uploaded tables by descriptor."""

    def __init__(self, host="localhost", location=None,
                 tls_certificates=None, verify_client=False,
                 root_certificates=None, auth_handler=None):
        # Keyword arguments keep this call independent of the parameter
        # order of the base class.
        super().__init__(
            location=location,
            auth_handler=auth_handler,
            tls_certificates=tls_certificates,
            verify_client=verify_client,
            root_certificates=root_certificates)
        # Maps descriptor_to_key(descriptor) to the uploaded table.
        self.flights = {}
        self.host = host
        self.tls_certificates = tls_certificates

    @classmethod
    def descriptor_to_key(cls, descriptor):
        """Return a hashable (type value, command, path tuple) key."""
        return (descriptor.descriptor_type.value, descriptor.command,
                tuple(descriptor.path or tuple()))

    def _make_flight_info(self, key, descriptor, table):
        """Build the FlightInfo for ``table``, served from this server.

        The endpoint's ticket is ``repr(key)``; ``do_get`` parses it back.
        """
        if self.tls_certificates:
            location = pyarrow.flight.Location.for_grpc_tls(
                self.host, self.port)
        else:
            location = pyarrow.flight.Location.for_grpc_tcp(
                self.host, self.port)
        endpoints = [pyarrow.flight.FlightEndpoint(repr(key), [location]), ]

        # Write the table to a mock sink to obtain the stream size.
        mock_sink = pyarrow.MockOutputStream()
        stream_writer = pyarrow.RecordBatchStreamWriter(
            mock_sink, table.schema)
        stream_writer.write_table(table)
        stream_writer.close()
        data_size = mock_sink.size()

        return pyarrow.flight.FlightInfo(table.schema,
                                         descriptor, endpoints,
                                         table.num_rows, data_size)

    def list_flights(self, context, criteria):
        """Yield a FlightInfo for every stored table."""
        for key, table in self.flights.items():
            # key is (descriptor type value, command, path tuple).
            if key[1] is not None:
                descriptor = \
                    pyarrow.flight.FlightDescriptor.for_command(key[1])
            else:
                descriptor = pyarrow.flight.FlightDescriptor.for_path(*key[2])

            yield self._make_flight_info(key, descriptor, table)

    def get_flight_info(self, context, descriptor):
        """Return the FlightInfo for ``descriptor``.

        Raises KeyError when no table is stored for the descriptor.
        """
        key = FlightServer.descriptor_to_key(descriptor)
        if key in self.flights:
            table = self.flights[key]
            return self._make_flight_info(key, descriptor, table)
        raise KeyError('Flight not found.')

    def do_put(self, context, descriptor, reader, writer):
        """Read the whole upload and store it under the descriptor's key."""
        key = FlightServer.descriptor_to_key(descriptor)
        print(key)
        self.flights[key] = reader.read_all()
        print(self.flights[key])

    def do_get(self, context, ticket):
        """Return a stream of the table identified by ``ticket``.

        Raises KeyError when the ticket does not name a stored flight,
        matching ``get_flight_info``, instead of returning None.
        """
        # NOTE(review): the ticket comes from the client. literal_eval only
        # accepts Python literals, but a malformed ticket still raises
        # ValueError or SyntaxError here.
        key = ast.literal_eval(ticket.ticket.decode())
        if key not in self.flights:
            raise KeyError('Flight not found.')
        return pyarrow.flight.RecordBatchStream(self.flights[key])

    def list_actions(self, context):
        return [
            ("clear", "Clear the stored flights."),
            ("shutdown", "Shut down this server."),
        ]

    def do_action(self, context, action):
        """Run ``action`` and yield its results.

        Because the body contains ``yield`` this method is a generator, so
        the exceptions below are raised when the result is iterated.
        """
        if action.type == "clear":
            raise NotImplementedError(
                "{} is not implemented.".format(action.type))
        elif action.type == "healthcheck":
            pass
        elif action.type == "shutdown":
            yield pyarrow.flight.Result(pyarrow.py_buffer(b'Shutdown!'))
            # Shut down on background thread to avoid blocking current
            # request
            threading.Thread(target=self._shutdown).start()
        else:
            raise KeyError("Unknown action {!r}".format(action.type))

    def _shutdown(self):
        """Shut down after a delay."""
        print("Server is shutting down...")
        time.sleep(2)
        self.shutdown()


def main():
    """Parse the command line and serve until the server is shut down."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost",
                        help="Address or hostname to listen on")
    parser.add_argument("--port", type=int, default=5005,
                        help="Port number to listen on")
    parser.add_argument("--tls", nargs=2, default=None,
                        metavar=('CERTFILE', 'KEYFILE'),
                        help="Enable transport-level security")
    parser.add_argument("--verify_client", type=_str_to_bool, default=False,
                        help="enable mutual TLS and verify the client if True")

    args = parser.parse_args()
    tls_certificates = []
    scheme = "grpc+tcp"
    if args.tls:
        scheme = "grpc+tls"
        with open(args.tls[0], "rb") as cert_file:
            tls_cert_chain = cert_file.read()
        with open(args.tls[1], "rb") as key_file:
            tls_private_key = key_file.read()
        tls_certificates.append((tls_cert_chain, tls_private_key))

    location = "{}://{}:{}".format(scheme, args.host, args.port)

    server = FlightServer(args.host, location,
                          tls_certificates=tls_certificates,
                          verify_client=args.verify_client)
    print("Serving on", location)
    server.serve()


if __name__ == '__main__':
    main()