# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""An example Flight CLI client."""

import argparse
import sys

import pyarrow
import pyarrow.flight
import pyarrow.csv as csv


def list_flights(args, client, connection_args=None):
    """Print every flight and every action the server advertises.

    ``args`` and ``connection_args`` are accepted so that all sub-command
    handlers share one call signature; this handler does not read them.
    """
    print('Flights\n=======')
    for flight in client.list_flights():
        descriptor = flight.descriptor
        if descriptor.descriptor_type == pyarrow.flight.DescriptorType.PATH:
            print("Path:", descriptor.path)
        elif descriptor.descriptor_type == pyarrow.flight.DescriptorType.CMD:
            print("Command:", descriptor.command)
        else:
            print("Unknown descriptor type")

        # Negative counts are printed as "Unknown".
        print("Total records:", end=" ")
        if flight.total_records >= 0:
            print(flight.total_records)
        else:
            print("Unknown")

        print("Total bytes:", end=" ")
        if flight.total_bytes >= 0:
            print(flight.total_bytes)
        else:
            print("Unknown")

        print("Number of endpoints:", len(flight.endpoints))
        print("Schema:")
        print(flight.schema)
        print('---')

    print('\nActions\n=======')
    for action in client.list_actions():
        print("Type:", action.type)
        print("Description:", action.description)
        print('---')


def do_action(args, client, connection_args=None):
    """Run the action named by ``args.action_type`` with an empty body.

    Each result body is printed as bytes. An ``ArrowIOError`` raised by the
    call is reported on stdout instead of propagating.
    ``connection_args`` is unused here.
    """
    try:
        buf = pyarrow.allocate_buffer(0)
        action = pyarrow.flight.Action(args.action_type, buf)
        print('Running action', args.action_type)
        for result in client.do_action(action):
            print("Got result", result.body.to_pybytes())
    except pyarrow.lib.ArrowIOError as e:
        print("Error calling action:", e)


def push_data(args, client, connection_args=None):
    """Read the CSV file ``args.file`` and upload it with ``do_put``.

    The flight descriptor is a path descriptor built from the file name.
    ``connection_args`` is unused here.
    """
    print('File Name:', args.file)
    my_table = csv.read_csv(args.file)
    print('Table rows=', str(len(my_table)))
    df = my_table.to_pandas()
    print(df.head())
    writer, _ = client.do_put(
        pyarrow.flight.FlightDescriptor.for_path(args.file), my_table.schema)
    try:
        writer.write_table(my_table)
    finally:
        # Close the upload stream even when the write fails part-way.
        writer.close()


def get_flight(args, client, connection_args=None):
    """Look up a flight by path or command and print its data.

    For every endpoint of the flight, and every location of that endpoint,
    a new ``FlightClient`` is created with ``connection_args`` and the
    endpoint's ticket is fetched and printed as a pandas DataFrame.
    """
    if connection_args is None:
        connection_args = {}

    if args.path:
        descriptor = pyarrow.flight.FlightDescriptor.for_path(*args.path)
    else:
        descriptor = pyarrow.flight.FlightDescriptor.for_command(args.command)

    info = client.get_flight_info(descriptor)
    for endpoint in info.endpoints:
        print('Ticket:', endpoint.ticket)
        for location in endpoint.locations:
            print(location)
            get_client = pyarrow.flight.FlightClient(location,
                                                     **connection_args)
            reader = get_client.do_get(endpoint.ticket)
            df = reader.read_pandas()
            print(df)


def _add_common_arguments(parser):
    """Add the connection options shared by every sub-command."""
    parser.add_argument('--tls', action='store_true',
                        help='Enable transport-level security')
    parser.add_argument('--tls-roots', default=None,
                        help='Path to trusted TLS certificate(s)')
    parser.add_argument("--mtls", nargs=2, default=None,
                        metavar=('CERTFILE', 'KEYFILE'),
                        help="Enable mutual TLS using the given client "
                             "certificate chain and private key files")
    parser.add_argument('host', type=str,
                        help="Address or hostname to connect to")


def _split_host_port(address):
    """Split ``HOST:PORT`` on its last colon and return ``(host, port)``.

    Splitting on the last colon keeps bracketed IPv6 literals such as
    ``[::1]:8080`` intact. Raises ``ValueError`` with a readable message
    when the colon is missing or the port is not an integer.
    """
    host, sep, port = address.rpartition(':')
    if not sep:
        raise ValueError(f"expected HOST:PORT, got {address!r}")
    try:
        return host, int(port)
    except ValueError:
        raise ValueError(
            f"invalid port {port!r} in address {address!r}") from None


def main():
    """Parse the command line, connect to the server and run a sub-command."""
    parser = argparse.ArgumentParser()
    subcommands = parser.add_subparsers()

    cmd_list = subcommands.add_parser('list')
    cmd_list.set_defaults(action='list')
    _add_common_arguments(cmd_list)
    cmd_list.add_argument('-l', '--list', action='store_true',
                          help="Print more details.")

    cmd_do = subcommands.add_parser('do')
    cmd_do.set_defaults(action='do')
    _add_common_arguments(cmd_do)
    cmd_do.add_argument('action_type', type=str,
                        help="The action type to run.")

    cmd_put = subcommands.add_parser('put')
    cmd_put.set_defaults(action='put')
    _add_common_arguments(cmd_put)
    cmd_put.add_argument('file', type=str,
                         help="CSV file to upload.")

    cmd_get = subcommands.add_parser('get')
    cmd_get.set_defaults(action='get')
    _add_common_arguments(cmd_get)
    cmd_get_descriptor = cmd_get.add_mutually_exclusive_group(required=True)
    cmd_get_descriptor.add_argument('-p', '--path', type=str, action='append',
                                    help="The path for the descriptor.")
    cmd_get_descriptor.add_argument('-c', '--command', type=str,
                                    help="The command for the descriptor.")

    args = parser.parse_args()
    # No sub-command given: 'action' is only set by the sub-parsers.
    if not hasattr(args, 'action'):
        parser.print_help()
        sys.exit(1)

    commands = {
        'list': list_flights,
        'do': do_action,
        'get': get_flight,
        'put': push_data,
    }

    try:
        host, port = _split_host_port(args.host)
    except ValueError as exc:
        # Report malformed addresses as a usage error, not a traceback.
        parser.error(str(exc))

    scheme = "grpc+tcp"
    connection_args = {}
    if args.tls:
        scheme = "grpc+tls"
        if args.tls_roots:
            with open(args.tls_roots, "rb") as root_certs:
                connection_args["tls_root_certs"] = root_certs.read()
    if args.mtls:
        with open(args.mtls[0], "rb") as cert_file:
            tls_cert_chain = cert_file.read()
        with open(args.mtls[1], "rb") as key_file:
            tls_private_key = key_file.read()
        connection_args["cert_chain"] = tls_cert_chain
        connection_args["private_key"] = tls_private_key

    client = pyarrow.flight.FlightClient(f"{scheme}://{host}:{port}",
                                         **connection_args)

    # Call the "healthcheck" action until one call completes without an
    # ArrowIOError. A message is printed only when the error text contains
    # "Deadline".
    # NOTE(review): any other ArrowIOError is retried silently and
    # immediately - confirm that is intended.
    while True:
        try:
            action = pyarrow.flight.Action("healthcheck", b"")
            options = pyarrow.flight.FlightCallOptions(timeout=1)
            list(client.do_action(action, options=options))
            break
        except pyarrow.ArrowIOError as e:
            if "Deadline" in str(e):
                print("Server is not ready, waiting...")

    commands[args.action](args, client, connection_args)


if __name__ == '__main__':
    main()