# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import ast import base64 import itertools import os import signal import struct import tempfile import threading import time import traceback import json import numpy as np import pytest import pyarrow as pa from pyarrow.lib import tobytes from pyarrow.util import pathlib, find_free_port from pyarrow.tests import util try: from pyarrow import flight from pyarrow.flight import ( FlightClient, FlightServerBase, ServerAuthHandler, ClientAuthHandler, ServerMiddleware, ServerMiddlewareFactory, ClientMiddleware, ClientMiddlewareFactory, ) except ImportError: flight = None FlightClient, FlightServerBase = object, object ServerAuthHandler, ClientAuthHandler = object, object ServerMiddleware, ServerMiddlewareFactory = object, object ClientMiddleware, ClientMiddlewareFactory = object, object # Marks all of the tests in this module # Ignore these with pytest ... -m 'not flight' pytestmark = pytest.mark.flight def test_import(): # So we see the ImportError somewhere import pyarrow.flight # noqa def resource_root(): """Get the path to the test resources directory.""" if not os.environ.get("ARROW_TEST_DATA"): raise RuntimeError("Test resources not found; set " "ARROW_TEST_DATA to /testing/data") return pathlib.Path(os.environ["ARROW_TEST_DATA"]) / "flight" def read_flight_resource(path): """Get the contents of a test resource file.""" root = resource_root() if not root: return None try: with (root / path).open("rb") as f: return f.read() except FileNotFoundError: raise RuntimeError( "Test resource {} not found; did you initialize the " "test resource submodule?\n{}".format(root / path, traceback.format_exc())) def example_tls_certs(): """Get the paths to test TLS certificates.""" return { "root_cert": read_flight_resource("root-ca.pem"), "certificates": [ flight.CertKeyPair( cert=read_flight_resource("cert0.pem"), key=read_flight_resource("cert0.key"), ), flight.CertKeyPair( cert=read_flight_resource("cert1.pem"), key=read_flight_resource("cert1.key"), ), ] } def simple_ints_table(): data = [ pa.array([-10, -5, 0, 5, 10]) ] return pa.Table.from_arrays(data, names=['some_ints']) def simple_dicts_table(): dict_values = pa.array(["foo", "baz", "quux"], type=pa.utf8()) data = [ pa.chunked_array([ pa.DictionaryArray.from_arrays([1, 0, None], dict_values), pa.DictionaryArray.from_arrays([2, 1], dict_values) ]) ] return pa.Table.from_arrays(data, names=['some_dicts']) class ConstantFlightServer(FlightServerBase): """A Flight server that always returns the same data. See ARROW-4796: this server implementation will segfault if Flight does not properly hold a reference to the Table object. """ CRITERIA = b"the expected criteria" def __init__(self, location=None, options=None, **kwargs): super().__init__(location, **kwargs) # Ticket -> Table self.table_factories = { b'ints': simple_ints_table, b'dicts': simple_dicts_table, } self.options = options def list_flights(self, context, criteria): if criteria == self.CRITERIA: yield flight.FlightInfo( pa.schema([]), flight.FlightDescriptor.for_path('/foo'), [], -1, -1 ) def do_get(self, context, ticket): # Return a fresh table, so that Flight is the only one keeping a # reference. table = self.table_factories[ticket.ticket]() return flight.RecordBatchStream(table, options=self.options) class MetadataFlightServer(FlightServerBase): """A Flight server that numbers incoming/outgoing data.""" def __init__(self, options=None, **kwargs): super().__init__(**kwargs) self.options = options def do_get(self, context, ticket): data = [ pa.array([-10, -5, 0, 5, 10]) ] table = pa.Table.from_arrays(data, names=['a']) return flight.GeneratorStream( table.schema, self.number_batches(table), options=self.options) def do_put(self, context, descriptor, reader, writer): counter = 0 expected_data = [-10, -5, 0, 5, 10] while True: try: batch, buf = reader.read_chunk() assert batch.equals(pa.RecordBatch.from_arrays( [pa.array([expected_data[counter]])], ['a'] )) assert buf is not None client_counter, = struct.unpack('= 0.5 def test_flight_list_flights(): """Try a simple list_flights call.""" with ConstantFlightServer() as server: client = flight.connect(('localhost', server.port)) assert list(client.list_flights()) == [] flights = client.list_flights(ConstantFlightServer.CRITERIA) assert len(list(flights)) == 1 def test_flight_do_get_ints(): """Try a simple do_get call.""" table = simple_ints_table() with ConstantFlightServer() as server: client = flight.connect(('localhost', server.port)) data = client.do_get(flight.Ticket(b'ints')).read_all() assert data.equals(table) options = pa.ipc.IpcWriteOptions( metadata_version=pa.ipc.MetadataVersion.V4) with ConstantFlightServer(options=options) as server: client = flight.connect(('localhost', server.port)) data = client.do_get(flight.Ticket(b'ints')).read_all() assert data.equals(table) # Also test via RecordBatchReader interface data = client.do_get(flight.Ticket(b'ints')).to_reader().read_all() assert data.equals(table) with pytest.raises(flight.FlightServerError, match="expected IpcWriteOptions, got "): with ConstantFlightServer(options=42) as server: client = flight.connect(('localhost', server.port)) data = client.do_get(flight.Ticket(b'ints')).read_all() @pytest.mark.pandas def test_do_get_ints_pandas(): """Try a simple do_get call.""" table = simple_ints_table() with ConstantFlightServer() as server: client = flight.connect(('localhost', server.port)) data = client.do_get(flight.Ticket(b'ints')).read_pandas() assert list(data['some_ints']) == table.column(0).to_pylist() def test_flight_do_get_dicts(): table = simple_dicts_table() with ConstantFlightServer() as server: client = flight.connect(('localhost', server.port)) data = client.do_get(flight.Ticket(b'dicts')).read_all() assert data.equals(table) def test_flight_do_get_ticket(): """Make sure Tickets get passed to the server.""" data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())] table = pa.Table.from_arrays(data1, names=['a']) with CheckTicketFlightServer(expected_ticket=b'the-ticket') as server: client = flight.connect(('localhost', server.port)) data = client.do_get(flight.Ticket(b'the-ticket')).read_all() assert data.equals(table) def test_flight_get_info(): """Make sure FlightEndpoint accepts string and object URIs.""" with GetInfoFlightServer() as server: client = FlightClient(('localhost', server.port)) info = client.get_flight_info(flight.FlightDescriptor.for_command(b'')) assert info.total_records == -1 assert info.total_bytes == -1 assert info.schema == pa.schema([('a', pa.int32())]) assert len(info.endpoints) == 2 assert len(info.endpoints[0].locations) == 1 assert info.endpoints[0].locations[0] == flight.Location('grpc://test') assert info.endpoints[1].locations[0] == \ flight.Location.for_grpc_tcp('localhost', 5005) def test_flight_get_schema(): """Make sure GetSchema returns correct schema.""" with GetInfoFlightServer() as server: client = FlightClient(('localhost', server.port)) info = client.get_schema(flight.FlightDescriptor.for_command(b'')) assert info.schema == pa.schema([('a', pa.int32())]) def test_list_actions(): """Make sure the return type of ListActions is validated.""" # ARROW-6392 with ListActionsErrorFlightServer() as server: client = FlightClient(('localhost', server.port)) with pytest.raises( flight.FlightServerError, match=("Results of list_actions must be " "ActionType or tuple") ): list(client.list_actions()) with ListActionsFlightServer() as server: client = FlightClient(('localhost', server.port)) assert list(client.list_actions()) == \ ListActionsFlightServer.expected_actions() class ConvenienceServer(FlightServerBase): """ Server for testing various implementation conveniences (auto-boxing, etc.) """ @property def simple_action_results(self): return [b'foo', b'bar', b'baz'] def do_action(self, context, action): if action.type == 'simple-action': return self.simple_action_results elif action.type == 'echo': return [action.body] elif action.type == 'bad-action': return ['foo'] elif action.type == 'arrow-exception': raise pa.ArrowMemoryError() def test_do_action_result_convenience(): with ConvenienceServer() as server: client = FlightClient(('localhost', server.port)) # do_action as action type without body results = [x.body for x in client.do_action('simple-action')] assert results == server.simple_action_results # do_action with tuple of type and body body = b'the-body' results = [x.body for x in client.do_action(('echo', body))] assert results == [body] def test_nicer_server_exceptions(): with ConvenienceServer() as server: client = FlightClient(('localhost', server.port)) with pytest.raises(flight.FlightServerError, match="a bytes-like object is required"): list(client.do_action('bad-action')) # While Flight/C++ sends across the original status code, it # doesn't get mapped to the equivalent code here, since we # want to be able to distinguish between client- and server- # side errors. with pytest.raises(flight.FlightServerError, match="ArrowMemoryError"): list(client.do_action('arrow-exception')) def test_get_port(): """Make sure port() works.""" server = GetInfoFlightServer("grpc://localhost:0") try: assert server.port > 0 finally: server.shutdown() @pytest.mark.skipif(os.name == 'nt', reason="Unix sockets can't be tested on Windows") def test_flight_domain_socket(): """Try a simple do_get call over a Unix domain socket.""" with tempfile.NamedTemporaryFile() as sock: sock.close() location = flight.Location.for_grpc_unix(sock.name) with ConstantFlightServer(location=location): client = FlightClient(location) reader = client.do_get(flight.Ticket(b'ints')) table = simple_ints_table() assert reader.schema.equals(table.schema) data = reader.read_all() assert data.equals(table) reader = client.do_get(flight.Ticket(b'dicts')) table = simple_dicts_table() assert reader.schema.equals(table.schema) data = reader.read_all() assert data.equals(table) @pytest.mark.slow def test_flight_large_message(): """Try sending/receiving a large message via Flight. See ARROW-4421: by default, gRPC won't allow us to send messages > 4MiB in size. """ data = pa.Table.from_arrays([ pa.array(range(0, 10 * 1024 * 1024)) ], names=['a']) with EchoFlightServer(expected_schema=data.schema) as server: client = FlightClient(('localhost', server.port)) writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'), data.schema) # Write a single giant chunk writer.write_table(data, 10 * 1024 * 1024) writer.close() result = client.do_get(flight.Ticket(b'')).read_all() assert result.equals(data) def test_flight_generator_stream(): """Try downloading a flight of RecordBatches in a GeneratorStream.""" data = pa.Table.from_arrays([ pa.array(range(0, 10 * 1024)) ], names=['a']) with EchoStreamFlightServer() as server: client = FlightClient(('localhost', server.port)) writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'), data.schema) writer.write_table(data) writer.close() result = client.do_get(flight.Ticket(b'')).read_all() assert result.equals(data) def test_flight_invalid_generator_stream(): """Try streaming data with mismatched schemas.""" with InvalidStreamFlightServer() as server: client = FlightClient(('localhost', server.port)) with pytest.raises(pa.ArrowException): client.do_get(flight.Ticket(b'')).read_all() def test_timeout_fires(): """Make sure timeouts fire on slow requests.""" # Do this in a separate thread so that if it fails, we don't hang # the entire test process with SlowFlightServer() as server: client = FlightClient(('localhost', server.port)) action = flight.Action("", b"") options = flight.FlightCallOptions(timeout=0.2) # gRPC error messages change based on version, so don't look # for a particular error with pytest.raises(flight.FlightTimedOutError): list(client.do_action(action, options=options)) def test_timeout_passes(): """Make sure timeouts do not fire on fast requests.""" with ConstantFlightServer() as server: client = FlightClient(('localhost', server.port)) options = flight.FlightCallOptions(timeout=5.0) client.do_get(flight.Ticket(b'ints'), options=options).read_all() basic_auth_handler = HttpBasicServerAuthHandler(creds={ b"test": b"p4ssw0rd", }) token_auth_handler = TokenServerAuthHandler(creds={ b"test": b"p4ssw0rd", }) @pytest.mark.slow def test_http_basic_unauth(): """Test that auth fails when not authenticated.""" with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server: client = FlightClient(('localhost', server.port)) action = flight.Action("who-am-i", b"") with pytest.raises(flight.FlightUnauthenticatedError, match=".*unauthenticated.*"): list(client.do_action(action)) @pytest.mark.skipif(os.name == 'nt', reason="ARROW-10013: gRPC on Windows corrupts peer()") def test_http_basic_auth(): """Test a Python implementation of HTTP basic authentication.""" with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server: client = FlightClient(('localhost', server.port)) action = flight.Action("who-am-i", b"") client.authenticate(HttpBasicClientAuthHandler('test', 'p4ssw0rd')) results = client.do_action(action) identity = next(results) assert identity.body.to_pybytes() == b'test' peer_address = next(results) assert peer_address.body.to_pybytes() != b'' def test_http_basic_auth_invalid_password(): """Test that auth fails with the wrong password.""" with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server: client = FlightClient(('localhost', server.port)) action = flight.Action("who-am-i", b"") with pytest.raises(flight.FlightUnauthenticatedError, match=".*wrong password.*"): client.authenticate(HttpBasicClientAuthHandler('test', 'wrong')) next(client.do_action(action)) def test_token_auth(): """Test an auth mechanism that uses a handshake.""" with EchoStreamFlightServer(auth_handler=token_auth_handler) as server: client = FlightClient(('localhost', server.port)) action = flight.Action("who-am-i", b"") client.authenticate(TokenClientAuthHandler('test', 'p4ssw0rd')) identity = next(client.do_action(action)) assert identity.body.to_pybytes() == b'test' def test_token_auth_invalid(): """Test an auth mechanism that uses a handshake.""" with EchoStreamFlightServer(auth_handler=token_auth_handler) as server: client = FlightClient(('localhost', server.port)) with pytest.raises(flight.FlightUnauthenticatedError): client.authenticate(TokenClientAuthHandler('test', 'wrong')) header_auth_server_middleware_factory = HeaderAuthServerMiddlewareFactory() no_op_auth_handler = NoopAuthHandler() def test_authenticate_basic_token(): """Test authenticate_basic_token with bearer token and auth headers.""" with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={ "auth": HeaderAuthServerMiddlewareFactory() }) as server: client = FlightClient(('localhost', server.port)) token_pair = client.authenticate_basic_token(b'test', b'password') assert token_pair[0] == b'authorization' assert token_pair[1] == b'Bearer token1234' def test_authenticate_basic_token_invalid_password(): """Test authenticate_basic_token with an invalid password.""" with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={ "auth": HeaderAuthServerMiddlewareFactory() }) as server: client = FlightClient(('localhost', server.port)) with pytest.raises(flight.FlightUnauthenticatedError): client.authenticate_basic_token(b'test', b'badpassword') def test_authenticate_basic_token_and_action(): """Test authenticate_basic_token and doAction after authentication.""" with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={ "auth": HeaderAuthServerMiddlewareFactory() }) as server: client = FlightClient(('localhost', server.port)) token_pair = client.authenticate_basic_token(b'test', b'password') assert token_pair[0] == b'authorization' assert token_pair[1] == b'Bearer token1234' options = flight.FlightCallOptions(headers=[token_pair]) result = list(client.do_action( action=flight.Action('test-action', b''), options=options)) assert result[0].body.to_pybytes() == b'token1234' def test_authenticate_basic_token_with_client_middleware(): """Test authenticate_basic_token with client middleware to intercept authorization header returned by the HTTP header auth enabled server. """ with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={ "auth": HeaderAuthServerMiddlewareFactory() }) as server: client_auth_middleware = ClientHeaderAuthMiddlewareFactory() client = FlightClient( ('localhost', server.port), middleware=[client_auth_middleware] ) encoded_credentials = base64.b64encode(b'test:password') options = flight.FlightCallOptions(headers=[ (b'authorization', b'Basic ' + encoded_credentials) ]) result = list(client.do_action( action=flight.Action('test-action', b''), options=options)) assert result[0].body.to_pybytes() == b'token1234' assert client_auth_middleware.call_credential[0] == b'authorization' assert client_auth_middleware.call_credential[1] == \ b'Bearer ' + b'token1234' result2 = list(client.do_action( action=flight.Action('test-action', b''), options=options)) assert result2[0].body.to_pybytes() == b'token1234' assert client_auth_middleware.call_credential[0] == b'authorization' assert client_auth_middleware.call_credential[1] == \ b'Bearer ' + b'token1234' def test_arbitrary_headers_in_flight_call_options(): """Test passing multiple arbitrary headers to the middleware.""" with ArbitraryHeadersFlightServer( auth_handler=no_op_auth_handler, middleware={ "auth": HeaderAuthServerMiddlewareFactory(), "arbitrary-headers": ArbitraryHeadersServerMiddlewareFactory() }) as server: client = FlightClient(('localhost', server.port)) token_pair = client.authenticate_basic_token(b'test', b'password') assert token_pair[0] == b'authorization' assert token_pair[1] == b'Bearer token1234' options = flight.FlightCallOptions(headers=[ token_pair, (b'test-header-1', b'value1'), (b'test-header-2', b'value2') ]) result = list(client.do_action(flight.Action( "test-action", b""), options=options)) assert result[0].body.to_pybytes() == b'value1' assert result[1].body.to_pybytes() == b'value2' def test_location_invalid(): """Test constructing invalid URIs.""" with pytest.raises(pa.ArrowInvalid, match=".*Cannot parse URI:.*"): flight.connect("%") with pytest.raises(pa.ArrowInvalid, match=".*Cannot parse URI:.*"): ConstantFlightServer("%") def test_location_unknown_scheme(): """Test creating locations for unknown schemes.""" assert flight.Location("s3://foo").uri == b"s3://foo" assert flight.Location("https://example.com/bar.parquet").uri == \ b"https://example.com/bar.parquet" @pytest.mark.slow @pytest.mark.requires_testing_data def test_tls_fails(): """Make sure clients cannot connect when cert verification fails.""" certs = example_tls_certs() with ConstantFlightServer(tls_certificates=certs["certificates"]) as s: # Ensure client doesn't connect when certificate verification # fails (this is a slow test since gRPC does retry a few times) client = FlightClient("grpc+tls://localhost:" + str(s.port)) # gRPC error messages change based on version, so don't look # for a particular error with pytest.raises(flight.FlightUnavailableError): client.do_get(flight.Ticket(b'ints')).read_all() @pytest.mark.requires_testing_data def test_tls_do_get(): """Try a simple do_get call over TLS.""" table = simple_ints_table() certs = example_tls_certs() with ConstantFlightServer(tls_certificates=certs["certificates"]) as s: client = FlightClient(('localhost', s.port), tls_root_certs=certs["root_cert"]) data = client.do_get(flight.Ticket(b'ints')).read_all() assert data.equals(table) @pytest.mark.requires_testing_data def test_tls_disable_server_verification(): """Try a simple do_get call over TLS with server verification disabled.""" table = simple_ints_table() certs = example_tls_certs() with ConstantFlightServer(tls_certificates=certs["certificates"]) as s: try: client = FlightClient(('localhost', s.port), disable_server_verification=True) except NotImplementedError: pytest.skip('disable_server_verification feature is not available') data = client.do_get(flight.Ticket(b'ints')).read_all() assert data.equals(table) @pytest.mark.requires_testing_data def test_tls_override_hostname(): """Check that incorrectly overriding the hostname fails.""" certs = example_tls_certs() with ConstantFlightServer(tls_certificates=certs["certificates"]) as s: client = flight.connect(('localhost', s.port), tls_root_certs=certs["root_cert"], override_hostname="fakehostname") with pytest.raises(flight.FlightUnavailableError): client.do_get(flight.Ticket(b'ints')) def test_flight_do_get_metadata(): """Try a simple do_get call with metadata.""" data = [ pa.array([-10, -5, 0, 5, 10]) ] table = pa.Table.from_arrays(data, names=['a']) batches = [] with MetadataFlightServer() as server: client = FlightClient(('localhost', server.port)) reader = client.do_get(flight.Ticket(b'')) idx = 0 while True: try: batch, metadata = reader.read_chunk() batches.append(batch) server_idx, = struct.unpack('