summaryrefslogtreecommitdiffstats
path: root/src/arrow/python/examples/flight/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/arrow/python/examples/flight/client.py')
-rw-r--r--src/arrow/python/examples/flight/client.py189
1 files changed, 189 insertions, 0 deletions
diff --git a/src/arrow/python/examples/flight/client.py b/src/arrow/python/examples/flight/client.py
new file mode 100644
index 000000000..ed6ce54ce
--- /dev/null
+++ b/src/arrow/python/examples/flight/client.py
@@ -0,0 +1,189 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""An example Flight CLI client."""
+
+import argparse
+import sys
+
+import pyarrow
+import pyarrow.flight
+import pyarrow.csv as csv
+
+
def list_flights(args, client, connection_args={}):
    """Print every flight and every action advertised by the server.

    For each flight returned by ``client.list_flights()`` this prints the
    descriptor (path or command), the total record and byte counts
    ("Unknown" when the reported value is negative), the number of
    endpoints and the schema.  Afterwards each entry returned by
    ``client.list_actions()`` is printed with its type and description.

    ``args`` and ``connection_args`` are not read by this handler; they
    are part of the signature shared by all subcommand handlers.
    """
    descriptor_types = pyarrow.flight.DescriptorType

    print('Flights\n=======')
    for info in client.list_flights():
        desc = info.descriptor
        kind = desc.descriptor_type
        if kind == descriptor_types.PATH:
            print("Path:", desc.path)
        elif kind == descriptor_types.CMD:
            print("Command:", desc.command)
        else:
            print("Unknown descriptor type")

        # Negative totals are reported as "Unknown".
        record_count = info.total_records
        print("Total records:",
              record_count if record_count >= 0 else "Unknown")
        byte_count = info.total_bytes
        print("Total bytes:", byte_count if byte_count >= 0 else "Unknown")

        print("Number of endpoints:", len(info.endpoints))
        print("Schema:", info.schema, sep="\n")
        print('---')

    print('\nActions\n=======')
    for available in client.list_actions():
        print("Type:", available.type)
        print("Description:", available.description)
        print('---')
+
+
def do_action(args, client, connection_args={}):
    """Run the action named by ``args.action_type`` and print its results.

    The action is sent with a zero-length buffer as its body.  The body
    of every result yielded by ``client.do_action`` is printed as bytes.
    An ``ArrowIOError`` raised along the way is reported on stdout
    instead of being propagated.

    ``connection_args`` is not read by this handler.
    """
    try:
        empty_body = pyarrow.allocate_buffer(0)
        request = pyarrow.flight.Action(args.action_type, empty_body)
        print('Running action', args.action_type)
        results = client.do_action(request)
        for item in results:
            payload = item.body.to_pybytes()
            print("Got result", payload)
    except pyarrow.lib.ArrowIOError as err:
        print("Error calling action:", err)
+
+
def push_data(args, client, connection_args=None):
    """Upload the CSV file ``args.file`` to the server with DoPut.

    The file is read into a table, its row count and the first rows (via
    pandas) are printed, and the table is then written to a DoPut stream
    whose descriptor path is the file name itself.

    ``connection_args`` is not read by this handler; it is accepted so
    that all subcommand handlers share one signature.
    """
    print('File Name:', args.file)
    my_table = csv.read_csv(args.file)
    print('Table rows=', str(len(my_table)))
    df = my_table.to_pandas()
    print(df.head())

    descriptor = pyarrow.flight.FlightDescriptor.for_path(args.file)
    writer, _ = client.do_put(descriptor, my_table.schema)
    try:
        writer.write_table(my_table)
    finally:
        # Always close the stream, even when the upload fails part-way,
        # so the writer is not left open.
        writer.close()
+
+
def get_flight(args, client, connection_args=None):
    """Fetch a flight and print its data as pandas DataFrames.

    The descriptor is built from ``args.path`` when given, otherwise from
    ``args.command``.  For every endpoint of the resulting flight info the
    ticket is printed and then redeemed:

    * at each location listed by the endpoint, using a new
      ``FlightClient`` created with ``connection_args``; or
    * on ``client`` itself when the endpoint lists no locations, because
      such a ticket can only be redeemed on the service that issued it.

    ``connection_args`` are the keyword arguments passed to every
    ``FlightClient`` created here (``None`` means no extra arguments).
    """
    connection_args = connection_args or {}

    if args.path:
        descriptor = pyarrow.flight.FlightDescriptor.for_path(*args.path)
    else:
        descriptor = pyarrow.flight.FlightDescriptor.for_command(args.command)

    info = client.get_flight_info(descriptor)
    for endpoint in info.endpoints:
        print('Ticket:', endpoint.ticket)

        if not endpoint.locations:
            # No explicit location: fall back to the client that produced
            # the flight info instead of silently fetching nothing.
            reader = client.do_get(endpoint.ticket)
            print(reader.read_pandas())
            continue

        for location in endpoint.locations:
            print(location)
            get_client = pyarrow.flight.FlightClient(location,
                                                     **connection_args)
            reader = get_client.do_get(endpoint.ticket)
            df = reader.read_pandas()
            print(df)
+
+
def _add_common_arguments(parser):
    """Add the connection options shared by every subcommand to *parser*.

    Adds the optional ``--tls`` flag, ``--tls-roots`` and ``--mtls``
    (two values: certificate file and key file), plus the positional
    ``host`` argument.
    """
    parser.add_argument('--tls', action='store_true',
                        help='Enable transport-level security')
    parser.add_argument('--tls-roots', default=None,
                        help='Path to trusted TLS certificate(s)')
    # --mtls only supplies the client credentials; the TLS scheme itself
    # is selected by --tls, so the help text says so explicitly.
    parser.add_argument("--mtls", nargs=2, default=None,
                        metavar=('CERTFILE', 'KEYFILE'),
                        help="Client certificate and private key for "
                             "mutual TLS (use together with --tls)")
    parser.add_argument('host', type=str,
                        help="Address or hostname to connect to, "
                             "as HOST:PORT")
+
+
def main():
    """Parse the command line, connect to the server and run a subcommand.

    Subcommands are ``list``, ``do``, ``put`` and ``get``.  When none is
    given the help text is printed and the process exits with status 1.

    The positional ``host`` argument must have the form ``HOST:PORT``; it
    is split on the last colon so that a bracketed IPv6 literal keeps its
    own colons.  A malformed value is reported through ``parser.error``.

    Before dispatching, a ``healthcheck`` action with a one second
    timeout is retried until it succeeds.  Only a deadline/timeout error
    is treated as "server not ready"; any other error is propagated.
    """
    parser = argparse.ArgumentParser()
    subcommands = parser.add_subparsers()

    cmd_list = subcommands.add_parser('list')
    cmd_list.set_defaults(action='list')
    _add_common_arguments(cmd_list)
    cmd_list.add_argument('-l', '--list', action='store_true',
                          help="Print more details.")

    cmd_do = subcommands.add_parser('do')
    cmd_do.set_defaults(action='do')
    _add_common_arguments(cmd_do)
    cmd_do.add_argument('action_type', type=str,
                        help="The action type to run.")

    cmd_put = subcommands.add_parser('put')
    cmd_put.set_defaults(action='put')
    _add_common_arguments(cmd_put)
    cmd_put.add_argument('file', type=str,
                         help="CSV file to upload.")

    cmd_get = subcommands.add_parser('get')
    cmd_get.set_defaults(action='get')
    _add_common_arguments(cmd_get)
    cmd_get_descriptor = cmd_get.add_mutually_exclusive_group(required=True)
    cmd_get_descriptor.add_argument('-p', '--path', type=str, action='append',
                                    help="The path for the descriptor.")
    cmd_get_descriptor.add_argument('-c', '--command', type=str,
                                    help="The command for the descriptor.")

    args = parser.parse_args()
    if not hasattr(args, 'action'):
        parser.print_help()
        sys.exit(1)

    commands = {
        'list': list_flights,
        'do': do_action,
        'get': get_flight,
        'put': push_data,
    }

    # Split on the LAST colon so "[::1]:5005" works, and turn a malformed
    # value into a usage error instead of an unpacking traceback.
    host, colon, port = args.host.rpartition(':')
    if not colon or not host:
        parser.error(f"host must be given as HOST:PORT, got {args.host!r}")
    try:
        port = int(port)
    except ValueError:
        parser.error(f"invalid port in host argument: {port!r}")

    scheme = "grpc+tcp"
    connection_args = {}
    if args.tls:
        scheme = "grpc+tls"
        if args.tls_roots:
            with open(args.tls_roots, "rb") as root_certs:
                connection_args["tls_root_certs"] = root_certs.read()
    if args.mtls:
        with open(args.mtls[0], "rb") as cert_file:
            tls_cert_chain = cert_file.read()
        with open(args.mtls[1], "rb") as key_file:
            tls_private_key = key_file.read()
        connection_args["cert_chain"] = tls_cert_chain
        connection_args["private_key"] = tls_private_key
    client = pyarrow.flight.FlightClient(f"{scheme}://{host}:{port}",
                                         **connection_args)

    # Wait until the server answers a healthcheck.  The one second call
    # timeout paces the retries.
    while True:
        try:
            action = pyarrow.flight.Action("healthcheck", b"")
            options = pyarrow.flight.FlightCallOptions(timeout=1)
            list(client.do_action(action, options=options))
            break
        except pyarrow.flight.FlightTimedOutError:
            # Deadline expiry is reported as a dedicated Flight error.
            print("Server is not ready, waiting...")
        except pyarrow.ArrowIOError as e:
            if "Deadline" not in str(e):
                # Not a readiness problem: retrying would spin forever
                # without telling the user anything.
                raise
            print("Server is not ready, waiting...")

    commands[args.action](args, client, connection_args)
+
+
# Run the command-line client only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()