diff options
Diffstat (limited to 'src/arrow/python/pyarrow/tests/test_flight.py')
-rw-r--r-- | src/arrow/python/pyarrow/tests/test_flight.py | 2047 |
1 files changed, 2047 insertions, 0 deletions
diff --git a/src/arrow/python/pyarrow/tests/test_flight.py b/src/arrow/python/pyarrow/tests/test_flight.py new file mode 100644 index 000000000..5c40467a5 --- /dev/null +++ b/src/arrow/python/pyarrow/tests/test_flight.py @@ -0,0 +1,2047 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import ast +import base64 +import itertools +import os +import signal +import struct +import tempfile +import threading +import time +import traceback +import json + +import numpy as np +import pytest +import pyarrow as pa + +from pyarrow.lib import tobytes +from pyarrow.util import pathlib, find_free_port +from pyarrow.tests import util + +try: + from pyarrow import flight + from pyarrow.flight import ( + FlightClient, FlightServerBase, + ServerAuthHandler, ClientAuthHandler, + ServerMiddleware, ServerMiddlewareFactory, + ClientMiddleware, ClientMiddlewareFactory, + ) +except ImportError: + flight = None + FlightClient, FlightServerBase = object, object + ServerAuthHandler, ClientAuthHandler = object, object + ServerMiddleware, ServerMiddlewareFactory = object, object + ClientMiddleware, ClientMiddlewareFactory = object, object + +# Marks all of the tests in this module +# Ignore these with pytest ... -m 'not flight' +pytestmark = pytest.mark.flight + + +def test_import(): + # So we see the ImportError somewhere + import pyarrow.flight # noqa + + +def resource_root(): + """Get the path to the test resources directory.""" + if not os.environ.get("ARROW_TEST_DATA"): + raise RuntimeError("Test resources not found; set " + "ARROW_TEST_DATA to <repo root>/testing/data") + return pathlib.Path(os.environ["ARROW_TEST_DATA"]) / "flight" + + +def read_flight_resource(path): + """Get the contents of a test resource file.""" + root = resource_root() + if not root: + return None + try: + with (root / path).open("rb") as f: + return f.read() + except FileNotFoundError: + raise RuntimeError( + "Test resource {} not found; did you initialize the " + "test resource submodule?\n{}".format(root / path, + traceback.format_exc())) + + +def example_tls_certs(): + """Get the paths to test TLS certificates.""" + return { + "root_cert": read_flight_resource("root-ca.pem"), + "certificates": [ + flight.CertKeyPair( + cert=read_flight_resource("cert0.pem"), + key=read_flight_resource("cert0.key"), + ), + flight.CertKeyPair( + cert=read_flight_resource("cert1.pem"), + key=read_flight_resource("cert1.key"), + ), + ] + } + + +def simple_ints_table(): + data = [ + pa.array([-10, -5, 0, 5, 10]) + ] + return pa.Table.from_arrays(data, names=['some_ints']) + + +def simple_dicts_table(): + dict_values = pa.array(["foo", "baz", "quux"], type=pa.utf8()) + data = [ + pa.chunked_array([ + pa.DictionaryArray.from_arrays([1, 0, None], dict_values), + pa.DictionaryArray.from_arrays([2, 1], dict_values) + ]) + ] + return pa.Table.from_arrays(data, names=['some_dicts']) + + +class ConstantFlightServer(FlightServerBase): + """A Flight server that always returns the same data. + + See ARROW-4796: this server implementation will segfault if Flight + does not properly hold a reference to the Table object. + """ + + CRITERIA = b"the expected criteria" + + def __init__(self, location=None, options=None, **kwargs): + super().__init__(location, **kwargs) + # Ticket -> Table + self.table_factories = { + b'ints': simple_ints_table, + b'dicts': simple_dicts_table, + } + self.options = options + + def list_flights(self, context, criteria): + if criteria == self.CRITERIA: + yield flight.FlightInfo( + pa.schema([]), + flight.FlightDescriptor.for_path('/foo'), + [], + -1, -1 + ) + + def do_get(self, context, ticket): + # Return a fresh table, so that Flight is the only one keeping a + # reference. + table = self.table_factories[ticket.ticket]() + return flight.RecordBatchStream(table, options=self.options) + + +class MetadataFlightServer(FlightServerBase): + """A Flight server that numbers incoming/outgoing data.""" + + def __init__(self, options=None, **kwargs): + super().__init__(**kwargs) + self.options = options + + def do_get(self, context, ticket): + data = [ + pa.array([-10, -5, 0, 5, 10]) + ] + table = pa.Table.from_arrays(data, names=['a']) + return flight.GeneratorStream( + table.schema, + self.number_batches(table), + options=self.options) + + def do_put(self, context, descriptor, reader, writer): + counter = 0 + expected_data = [-10, -5, 0, 5, 10] + while True: + try: + batch, buf = reader.read_chunk() + assert batch.equals(pa.RecordBatch.from_arrays( + [pa.array([expected_data[counter]])], + ['a'] + )) + assert buf is not None + client_counter, = struct.unpack('<i', buf.to_pybytes()) + assert counter == client_counter + writer.write(struct.pack('<i', counter)) + counter += 1 + except StopIteration: + return + + @staticmethod + def number_batches(table): + for idx, batch in enumerate(table.to_batches()): + buf = struct.pack('<i', idx) + yield batch, buf + + +class EchoFlightServer(FlightServerBase): + """A Flight server that returns the last data uploaded.""" + + def __init__(self, location=None, expected_schema=None, **kwargs): + super().__init__(location, **kwargs) + self.last_message = None + self.expected_schema = expected_schema + + def do_get(self, context, ticket): + return flight.RecordBatchStream(self.last_message) + + def do_put(self, context, descriptor, reader, writer): + if self.expected_schema: + assert self.expected_schema == reader.schema + self.last_message = reader.read_all() + + def do_exchange(self, context, descriptor, reader, writer): + for chunk in reader: + pass + + +class EchoStreamFlightServer(EchoFlightServer): + """An echo server that streams individual record batches.""" + + def do_get(self, context, ticket): + return flight.GeneratorStream( + self.last_message.schema, + self.last_message.to_batches(max_chunksize=1024)) + + def list_actions(self, context): + return [] + + def do_action(self, context, action): + if action.type == "who-am-i": + return [context.peer_identity(), context.peer().encode("utf-8")] + raise NotImplementedError + + +class GetInfoFlightServer(FlightServerBase): + """A Flight server that tests GetFlightInfo.""" + + def get_flight_info(self, context, descriptor): + return flight.FlightInfo( + pa.schema([('a', pa.int32())]), + descriptor, + [ + flight.FlightEndpoint(b'', ['grpc://test']), + flight.FlightEndpoint( + b'', + [flight.Location.for_grpc_tcp('localhost', 5005)], + ), + ], + -1, + -1, + ) + + def get_schema(self, context, descriptor): + info = self.get_flight_info(context, descriptor) + return flight.SchemaResult(info.schema) + + +class ListActionsFlightServer(FlightServerBase): + """A Flight server that tests ListActions.""" + + @classmethod + def expected_actions(cls): + return [ + ("action-1", "description"), + ("action-2", ""), + flight.ActionType("action-3", "more detail"), + ] + + def list_actions(self, context): + yield from self.expected_actions() + + +class ListActionsErrorFlightServer(FlightServerBase): + """A Flight server that tests ListActions.""" + + def list_actions(self, context): + yield ("action-1", "") + yield "foo" + + +class CheckTicketFlightServer(FlightServerBase): + """A Flight server that compares the given ticket to an expected value.""" + + def __init__(self, expected_ticket, location=None, **kwargs): + super().__init__(location, **kwargs) + self.expected_ticket = expected_ticket + + def do_get(self, context, ticket): + assert self.expected_ticket == ticket.ticket + data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())] + table = pa.Table.from_arrays(data1, names=['a']) + return flight.RecordBatchStream(table) + + def do_put(self, context, descriptor, reader): + self.last_message = reader.read_all() + + +class InvalidStreamFlightServer(FlightServerBase): + """A Flight server that tries to return messages with differing schemas.""" + + schema = pa.schema([('a', pa.int32())]) + + def do_get(self, context, ticket): + data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())] + data2 = [pa.array([-10.0, -5.0, 0.0, 5.0, 10.0], type=pa.float64())] + assert data1.type != data2.type + table1 = pa.Table.from_arrays(data1, names=['a']) + table2 = pa.Table.from_arrays(data2, names=['a']) + assert table1.schema == self.schema + + return flight.GeneratorStream(self.schema, [table1, table2]) + + +class NeverSendsDataFlightServer(FlightServerBase): + """A Flight server that never actually yields data.""" + + schema = pa.schema([('a', pa.int32())]) + + def do_get(self, context, ticket): + if ticket.ticket == b'yield_data': + # Check that the server handler will ignore empty tables + # up to a certain extent + data = [ + self.schema.empty_table(), + self.schema.empty_table(), + pa.RecordBatch.from_arrays([range(5)], schema=self.schema), + ] + return flight.GeneratorStream(self.schema, data) + return flight.GeneratorStream( + self.schema, itertools.repeat(self.schema.empty_table())) + + +class SlowFlightServer(FlightServerBase): + """A Flight server that delays its responses to test timeouts.""" + + def do_get(self, context, ticket): + return flight.GeneratorStream(pa.schema([('a', pa.int32())]), + self.slow_stream()) + + def do_action(self, context, action): + time.sleep(0.5) + return [] + + @staticmethod + def slow_stream(): + data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())] + yield pa.Table.from_arrays(data1, names=['a']) + # The second message should never get sent; the client should + # cancel before we send this + time.sleep(10) + yield pa.Table.from_arrays(data1, names=['a']) + + +class ErrorFlightServer(FlightServerBase): + """A Flight server that uses all the Flight-specific errors.""" + + def do_action(self, context, action): + if action.type == "internal": + raise flight.FlightInternalError("foo") + elif action.type == "timedout": + raise flight.FlightTimedOutError("foo") + elif action.type == "cancel": + raise flight.FlightCancelledError("foo") + elif action.type == "unauthenticated": + raise flight.FlightUnauthenticatedError("foo") + elif action.type == "unauthorized": + raise flight.FlightUnauthorizedError("foo") + elif action.type == "protobuf": + err_msg = b'this is an error message' + raise flight.FlightUnauthorizedError("foo", err_msg) + raise NotImplementedError + + def list_flights(self, context, criteria): + yield flight.FlightInfo( + pa.schema([]), + flight.FlightDescriptor.for_path('/foo'), + [], + -1, -1 + ) + raise flight.FlightInternalError("foo") + + def do_put(self, context, descriptor, reader, writer): + if descriptor.command == b"internal": + raise flight.FlightInternalError("foo") + elif descriptor.command == b"timedout": + raise flight.FlightTimedOutError("foo") + elif descriptor.command == b"cancel": + raise flight.FlightCancelledError("foo") + elif descriptor.command == b"unauthenticated": + raise flight.FlightUnauthenticatedError("foo") + elif descriptor.command == b"unauthorized": + raise flight.FlightUnauthorizedError("foo") + elif descriptor.command == b"protobuf": + err_msg = b'this is an error message' + raise flight.FlightUnauthorizedError("foo", err_msg) + + +class ExchangeFlightServer(FlightServerBase): + """A server for testing DoExchange.""" + + def __init__(self, options=None, **kwargs): + super().__init__(**kwargs) + self.options = options + + def do_exchange(self, context, descriptor, reader, writer): + if descriptor.descriptor_type != flight.DescriptorType.CMD: + raise pa.ArrowInvalid("Must provide a command descriptor") + elif descriptor.command == b"echo": + return self.exchange_echo(context, reader, writer) + elif descriptor.command == b"get": + return self.exchange_do_get(context, reader, writer) + elif descriptor.command == b"put": + return self.exchange_do_put(context, reader, writer) + elif descriptor.command == b"transform": + return self.exchange_transform(context, reader, writer) + else: + raise pa.ArrowInvalid( + "Unknown command: {}".format(descriptor.command)) + + def exchange_do_get(self, context, reader, writer): + """Emulate DoGet with DoExchange.""" + data = pa.Table.from_arrays([ + pa.array(range(0, 10 * 1024)) + ], names=["a"]) + writer.begin(data.schema) + writer.write_table(data) + + def exchange_do_put(self, context, reader, writer): + """Emulate DoPut with DoExchange.""" + num_batches = 0 + for chunk in reader: + if not chunk.data: + raise pa.ArrowInvalid("All chunks must have data.") + num_batches += 1 + writer.write_metadata(str(num_batches).encode("utf-8")) + + def exchange_echo(self, context, reader, writer): + """Run a simple echo server.""" + started = False + for chunk in reader: + if not started and chunk.data: + writer.begin(chunk.data.schema, options=self.options) + started = True + if chunk.app_metadata and chunk.data: + writer.write_with_metadata(chunk.data, chunk.app_metadata) + elif chunk.app_metadata: + writer.write_metadata(chunk.app_metadata) + elif chunk.data: + writer.write_batch(chunk.data) + else: + assert False, "Should not happen" + + def exchange_transform(self, context, reader, writer): + """Sum rows in an uploaded table.""" + for field in reader.schema: + if not pa.types.is_integer(field.type): + raise pa.ArrowInvalid("Invalid field: " + repr(field)) + table = reader.read_all() + sums = [0] * table.num_rows + for column in table: + for row, value in enumerate(column): + sums[row] += value.as_py() + result = pa.Table.from_arrays([pa.array(sums)], names=["sum"]) + writer.begin(result.schema) + writer.write_table(result) + + +class HttpBasicServerAuthHandler(ServerAuthHandler): + """An example implementation of HTTP basic authentication.""" + + def __init__(self, creds): + super().__init__() + self.creds = creds + + def authenticate(self, outgoing, incoming): + buf = incoming.read() + auth = flight.BasicAuth.deserialize(buf) + if auth.username not in self.creds: + raise flight.FlightUnauthenticatedError("unknown user") + if self.creds[auth.username] != auth.password: + raise flight.FlightUnauthenticatedError("wrong password") + outgoing.write(tobytes(auth.username)) + + def is_valid(self, token): + if not token: + raise flight.FlightUnauthenticatedError("token not provided") + if token not in self.creds: + raise flight.FlightUnauthenticatedError("unknown user") + return token + + +class HttpBasicClientAuthHandler(ClientAuthHandler): + """An example implementation of HTTP basic authentication.""" + + def __init__(self, username, password): + super().__init__() + self.basic_auth = flight.BasicAuth(username, password) + self.token = None + + def authenticate(self, outgoing, incoming): + auth = self.basic_auth.serialize() + outgoing.write(auth) + self.token = incoming.read() + + def get_token(self): + return self.token + + +class TokenServerAuthHandler(ServerAuthHandler): + """An example implementation of authentication via handshake.""" + + def __init__(self, creds): + super().__init__() + self.creds = creds + + def authenticate(self, outgoing, incoming): + username = incoming.read() + password = incoming.read() + if username in self.creds and self.creds[username] == password: + outgoing.write(base64.b64encode(b'secret:' + username)) + else: + raise flight.FlightUnauthenticatedError( + "invalid username/password") + + def is_valid(self, token): + token = base64.b64decode(token) + if not token.startswith(b'secret:'): + raise flight.FlightUnauthenticatedError("invalid token") + return token[7:] + + +class TokenClientAuthHandler(ClientAuthHandler): + """An example implementation of authentication via handshake.""" + + def __init__(self, username, password): + super().__init__() + self.username = username + self.password = password + self.token = b'' + + def authenticate(self, outgoing, incoming): + outgoing.write(self.username) + outgoing.write(self.password) + self.token = incoming.read() + + def get_token(self): + return self.token + + +class NoopAuthHandler(ServerAuthHandler): + """A no-op auth handler.""" + + def authenticate(self, outgoing, incoming): + """Do nothing.""" + + def is_valid(self, token): + """ + Returning an empty string. + Returning None causes Type error. + """ + return "" + + +def case_insensitive_header_lookup(headers, lookup_key): + """Lookup the value of given key in the given headers. + The key lookup is case insensitive. + """ + for key in headers: + if key.lower() == lookup_key.lower(): + return headers.get(key) + + +class ClientHeaderAuthMiddlewareFactory(ClientMiddlewareFactory): + """ClientMiddlewareFactory that creates ClientAuthHeaderMiddleware.""" + + def __init__(self): + self.call_credential = [] + + def start_call(self, info): + return ClientHeaderAuthMiddleware(self) + + def set_call_credential(self, call_credential): + self.call_credential = call_credential + + +class ClientHeaderAuthMiddleware(ClientMiddleware): + """ + ClientMiddleware that extracts the authorization header + from the server. + + This is an example of a ClientMiddleware that can extract + the bearer token authorization header from a HTTP header + authentication enabled server. + + Parameters + ---------- + factory : ClientHeaderAuthMiddlewareFactory + This factory is used to set call credentials if an + authorization header is found in the headers from the server. + """ + + def __init__(self, factory): + self.factory = factory + + def received_headers(self, headers): + auth_header = case_insensitive_header_lookup(headers, 'Authorization') + self.factory.set_call_credential([ + b'authorization', + auth_header[0].encode("utf-8")]) + + +class HeaderAuthServerMiddlewareFactory(ServerMiddlewareFactory): + """Validates incoming username and password.""" + + def start_call(self, info, headers): + auth_header = case_insensitive_header_lookup( + headers, + 'Authorization' + ) + values = auth_header[0].split(' ') + token = '' + error_message = 'Invalid credentials' + + if values[0] == 'Basic': + decoded = base64.b64decode(values[1]) + pair = decoded.decode("utf-8").split(':') + if not (pair[0] == 'test' and pair[1] == 'password'): + raise flight.FlightUnauthenticatedError(error_message) + token = 'token1234' + elif values[0] == 'Bearer': + token = values[1] + if not token == 'token1234': + raise flight.FlightUnauthenticatedError(error_message) + else: + raise flight.FlightUnauthenticatedError(error_message) + + return HeaderAuthServerMiddleware(token) + + +class HeaderAuthServerMiddleware(ServerMiddleware): + """A ServerMiddleware that transports incoming username and passowrd.""" + + def __init__(self, token): + self.token = token + + def sending_headers(self): + return {'authorization': 'Bearer ' + self.token} + + +class HeaderAuthFlightServer(FlightServerBase): + """A Flight server that tests with basic token authentication. """ + + def do_action(self, context, action): + middleware = context.get_middleware("auth") + if middleware: + auth_header = case_insensitive_header_lookup( + middleware.sending_headers(), 'Authorization') + values = auth_header.split(' ') + return [values[1].encode("utf-8")] + raise flight.FlightUnauthenticatedError( + 'No token auth middleware found.') + + +class ArbitraryHeadersServerMiddlewareFactory(ServerMiddlewareFactory): + """A ServerMiddlewareFactory that transports arbitrary headers.""" + + def start_call(self, info, headers): + return ArbitraryHeadersServerMiddleware(headers) + + +class ArbitraryHeadersServerMiddleware(ServerMiddleware): + """A ServerMiddleware that transports arbitrary headers.""" + + def __init__(self, incoming): + self.incoming = incoming + + def sending_headers(self): + return self.incoming + + +class ArbitraryHeadersFlightServer(FlightServerBase): + """A Flight server that tests multiple arbitrary headers.""" + + def do_action(self, context, action): + middleware = context.get_middleware("arbitrary-headers") + if middleware: + headers = middleware.sending_headers() + header_1 = case_insensitive_header_lookup( + headers, + 'test-header-1' + ) + header_2 = case_insensitive_header_lookup( + headers, + 'test-header-2' + ) + value1 = header_1[0].encode("utf-8") + value2 = header_2[0].encode("utf-8") + return [value1, value2] + raise flight.FlightServerError("No headers middleware found") + + +class HeaderServerMiddleware(ServerMiddleware): + """Expose a per-call value to the RPC method body.""" + + def __init__(self, special_value): + self.special_value = special_value + + +class HeaderServerMiddlewareFactory(ServerMiddlewareFactory): + """Expose a per-call hard-coded value to the RPC method body.""" + + def start_call(self, info, headers): + return HeaderServerMiddleware("right value") + + +class HeaderFlightServer(FlightServerBase): + """Echo back the per-call hard-coded value.""" + + def do_action(self, context, action): + middleware = context.get_middleware("test") + if middleware: + return [middleware.special_value.encode()] + return [b""] + + +class MultiHeaderFlightServer(FlightServerBase): + """Test sending/receiving multiple (binary-valued) headers.""" + + def do_action(self, context, action): + middleware = context.get_middleware("test") + headers = repr(middleware.client_headers).encode("utf-8") + return [headers] + + +class SelectiveAuthServerMiddlewareFactory(ServerMiddlewareFactory): + """Deny access to certain methods based on a header.""" + + def start_call(self, info, headers): + if info.method == flight.FlightMethod.LIST_ACTIONS: + # No auth needed + return + + token = headers.get("x-auth-token") + if not token: + raise flight.FlightUnauthenticatedError("No token") + + token = token[0] + if token != "password": + raise flight.FlightUnauthenticatedError("Invalid token") + + return HeaderServerMiddleware(token) + + +class SelectiveAuthClientMiddlewareFactory(ClientMiddlewareFactory): + def start_call(self, info): + return SelectiveAuthClientMiddleware() + + +class SelectiveAuthClientMiddleware(ClientMiddleware): + def sending_headers(self): + return { + "x-auth-token": "password", + } + + +class RecordingServerMiddlewareFactory(ServerMiddlewareFactory): + """Record what methods were called.""" + + def __init__(self): + super().__init__() + self.methods = [] + + def start_call(self, info, headers): + self.methods.append(info.method) + return None + + +class RecordingClientMiddlewareFactory(ClientMiddlewareFactory): + """Record what methods were called.""" + + def __init__(self): + super().__init__() + self.methods = [] + + def start_call(self, info): + self.methods.append(info.method) + return None + + +class MultiHeaderClientMiddlewareFactory(ClientMiddlewareFactory): + """Test sending/receiving multiple (binary-valued) headers.""" + + def __init__(self): + # Read in test_middleware_multi_header below. + # The middleware instance will update this value. + self.last_headers = {} + + def start_call(self, info): + return MultiHeaderClientMiddleware(self) + + +class MultiHeaderClientMiddleware(ClientMiddleware): + """Test sending/receiving multiple (binary-valued) headers.""" + + EXPECTED = { + "x-text": ["foo", "bar"], + "x-binary-bin": [b"\x00", b"\x01"], + } + + def __init__(self, factory): + self.factory = factory + + def sending_headers(self): + return self.EXPECTED + + def received_headers(self, headers): + # Let the test code know what the last set of headers we + # received were. + self.factory.last_headers = headers + + +class MultiHeaderServerMiddlewareFactory(ServerMiddlewareFactory): + """Test sending/receiving multiple (binary-valued) headers.""" + + def start_call(self, info, headers): + return MultiHeaderServerMiddleware(headers) + + +class MultiHeaderServerMiddleware(ServerMiddleware): + """Test sending/receiving multiple (binary-valued) headers.""" + + def __init__(self, client_headers): + self.client_headers = client_headers + + def sending_headers(self): + return MultiHeaderClientMiddleware.EXPECTED + + +class LargeMetadataFlightServer(FlightServerBase): + """Regression test for ARROW-13253.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._metadata = b' ' * (2 ** 31 + 1) + + def do_get(self, context, ticket): + schema = pa.schema([('a', pa.int64())]) + return flight.GeneratorStream(schema, [ + (pa.record_batch([[1]], schema=schema), self._metadata), + ]) + + def do_exchange(self, context, descriptor, reader, writer): + writer.write_metadata(self._metadata) + + +def test_flight_server_location_argument(): + locations = [ + None, + 'grpc://localhost:0', + ('localhost', find_free_port()), + ] + for location in locations: + with FlightServerBase(location) as server: + assert isinstance(server, FlightServerBase) + + +def test_server_exit_reraises_exception(): + with pytest.raises(ValueError): + with FlightServerBase(): + raise ValueError() + + +@pytest.mark.slow +def test_client_wait_for_available(): + location = ('localhost', find_free_port()) + server = None + + def serve(): + global server + time.sleep(0.5) + server = FlightServerBase(location) + server.serve() + + client = FlightClient(location) + thread = threading.Thread(target=serve, daemon=True) + thread.start() + + started = time.time() + client.wait_for_available(timeout=5) + elapsed = time.time() - started + assert elapsed >= 0.5 + + +def test_flight_list_flights(): + """Try a simple list_flights call.""" + with ConstantFlightServer() as server: + client = flight.connect(('localhost', server.port)) + assert list(client.list_flights()) == [] + flights = client.list_flights(ConstantFlightServer.CRITERIA) + assert len(list(flights)) == 1 + + +def test_flight_do_get_ints(): + """Try a simple do_get call.""" + table = simple_ints_table() + + with ConstantFlightServer() as server: + client = flight.connect(('localhost', server.port)) + data = client.do_get(flight.Ticket(b'ints')).read_all() + assert data.equals(table) + + options = pa.ipc.IpcWriteOptions( + metadata_version=pa.ipc.MetadataVersion.V4) + with ConstantFlightServer(options=options) as server: + client = flight.connect(('localhost', server.port)) + data = client.do_get(flight.Ticket(b'ints')).read_all() + assert data.equals(table) + + # Also test via RecordBatchReader interface + data = client.do_get(flight.Ticket(b'ints')).to_reader().read_all() + assert data.equals(table) + + with pytest.raises(flight.FlightServerError, + match="expected IpcWriteOptions, got <class 'int'>"): + with ConstantFlightServer(options=42) as server: + client = flight.connect(('localhost', server.port)) + data = client.do_get(flight.Ticket(b'ints')).read_all() + + +@pytest.mark.pandas +def test_do_get_ints_pandas(): + """Try a simple do_get call.""" + table = simple_ints_table() + + with ConstantFlightServer() as server: + client = flight.connect(('localhost', server.port)) + data = client.do_get(flight.Ticket(b'ints')).read_pandas() + assert list(data['some_ints']) == table.column(0).to_pylist() + + +def test_flight_do_get_dicts(): + table = simple_dicts_table() + + with ConstantFlightServer() as server: + client = flight.connect(('localhost', server.port)) + data = client.do_get(flight.Ticket(b'dicts')).read_all() + assert data.equals(table) + + +def test_flight_do_get_ticket(): + """Make sure Tickets get passed to the server.""" + data1 = [pa.array([-10, -5, 0, 5, 10], type=pa.int32())] + table = pa.Table.from_arrays(data1, names=['a']) + with CheckTicketFlightServer(expected_ticket=b'the-ticket') as server: + client = flight.connect(('localhost', server.port)) + data = client.do_get(flight.Ticket(b'the-ticket')).read_all() + assert data.equals(table) + + +def test_flight_get_info(): + """Make sure FlightEndpoint accepts string and object URIs.""" + with GetInfoFlightServer() as server: + client = FlightClient(('localhost', server.port)) + info = client.get_flight_info(flight.FlightDescriptor.for_command(b'')) + assert info.total_records == -1 + assert info.total_bytes == -1 + assert info.schema == pa.schema([('a', pa.int32())]) + assert len(info.endpoints) == 2 + assert len(info.endpoints[0].locations) == 1 + assert info.endpoints[0].locations[0] == flight.Location('grpc://test') + assert info.endpoints[1].locations[0] == \ + flight.Location.for_grpc_tcp('localhost', 5005) + + +def test_flight_get_schema(): + """Make sure GetSchema returns correct schema.""" + with GetInfoFlightServer() as server: + client = FlightClient(('localhost', server.port)) + info = client.get_schema(flight.FlightDescriptor.for_command(b'')) + assert info.schema == pa.schema([('a', pa.int32())]) + + +def test_list_actions(): + """Make sure the return type of ListActions is validated.""" + # ARROW-6392 + with ListActionsErrorFlightServer() as server: + client = FlightClient(('localhost', server.port)) + with pytest.raises( + flight.FlightServerError, + match=("Results of list_actions must be " + "ActionType or tuple") + ): + list(client.list_actions()) + + with ListActionsFlightServer() as server: + client = FlightClient(('localhost', server.port)) + assert list(client.list_actions()) == \ + ListActionsFlightServer.expected_actions() + + +class ConvenienceServer(FlightServerBase): + """ + Server for testing various implementation conveniences (auto-boxing, etc.) + """ + + @property + def simple_action_results(self): + return [b'foo', b'bar', b'baz'] + + def do_action(self, context, action): + if action.type == 'simple-action': + return self.simple_action_results + elif action.type == 'echo': + return [action.body] + elif action.type == 'bad-action': + return ['foo'] + elif action.type == 'arrow-exception': + raise pa.ArrowMemoryError() + + +def test_do_action_result_convenience(): + with ConvenienceServer() as server: + client = FlightClient(('localhost', server.port)) + + # do_action as action type without body + results = [x.body for x in client.do_action('simple-action')] + assert results == server.simple_action_results + + # do_action with tuple of type and body + body = b'the-body' + results = [x.body for x in client.do_action(('echo', body))] + assert results == [body] + + +def test_nicer_server_exceptions(): + with ConvenienceServer() as server: + client = FlightClient(('localhost', server.port)) + with pytest.raises(flight.FlightServerError, + match="a bytes-like object is required"): + list(client.do_action('bad-action')) + # While Flight/C++ sends across the original status code, it + # doesn't get mapped to the equivalent code here, since we + # want to be able to distinguish between client- and server- + # side errors. + with pytest.raises(flight.FlightServerError, + match="ArrowMemoryError"): + list(client.do_action('arrow-exception')) + + +def test_get_port(): + """Make sure port() works.""" + server = GetInfoFlightServer("grpc://localhost:0") + try: + assert server.port > 0 + finally: + server.shutdown() + + +@pytest.mark.skipif(os.name == 'nt', + reason="Unix sockets can't be tested on Windows") +def test_flight_domain_socket(): + """Try a simple do_get call over a Unix domain socket.""" + with tempfile.NamedTemporaryFile() as sock: + sock.close() + location = flight.Location.for_grpc_unix(sock.name) + with ConstantFlightServer(location=location): + client = FlightClient(location) + + reader = client.do_get(flight.Ticket(b'ints')) + table = simple_ints_table() + assert reader.schema.equals(table.schema) + data = reader.read_all() + assert data.equals(table) + + reader = client.do_get(flight.Ticket(b'dicts')) + table = simple_dicts_table() + assert reader.schema.equals(table.schema) + data = reader.read_all() + assert data.equals(table) + + +@pytest.mark.slow +def test_flight_large_message(): + """Try sending/receiving a large message via Flight. + + See ARROW-4421: by default, gRPC won't allow us to send messages > + 4MiB in size. + """ + data = pa.Table.from_arrays([ + pa.array(range(0, 10 * 1024 * 1024)) + ], names=['a']) + + with EchoFlightServer(expected_schema=data.schema) as server: + client = FlightClient(('localhost', server.port)) + writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'), + data.schema) + # Write a single giant chunk + writer.write_table(data, 10 * 1024 * 1024) + writer.close() + result = client.do_get(flight.Ticket(b'')).read_all() + assert result.equals(data) + + +def test_flight_generator_stream(): + """Try downloading a flight of RecordBatches in a GeneratorStream.""" + data = pa.Table.from_arrays([ + pa.array(range(0, 10 * 1024)) + ], names=['a']) + + with EchoStreamFlightServer() as server: + client = FlightClient(('localhost', server.port)) + writer, _ = client.do_put(flight.FlightDescriptor.for_path('test'), + data.schema) + writer.write_table(data) + writer.close() + result = client.do_get(flight.Ticket(b'')).read_all() + assert result.equals(data) + + +def test_flight_invalid_generator_stream(): + """Try streaming data with mismatched schemas.""" + with InvalidStreamFlightServer() as server: + client = FlightClient(('localhost', server.port)) + with pytest.raises(pa.ArrowException): + client.do_get(flight.Ticket(b'')).read_all() + + +def test_timeout_fires(): + """Make sure timeouts fire on slow requests.""" + # Do this in a separate thread so that if it fails, we don't hang + # the entire test process + with SlowFlightServer() as server: + client = FlightClient(('localhost', server.port)) + action = flight.Action("", b"") + options = flight.FlightCallOptions(timeout=0.2) + # gRPC error messages change based on version, so don't look + # for a particular error + with pytest.raises(flight.FlightTimedOutError): + list(client.do_action(action, options=options)) + + +def test_timeout_passes(): + """Make sure timeouts do not fire on fast requests.""" + with ConstantFlightServer() as server: + client = FlightClient(('localhost', server.port)) + options = flight.FlightCallOptions(timeout=5.0) + client.do_get(flight.Ticket(b'ints'), options=options).read_all() + + +basic_auth_handler = HttpBasicServerAuthHandler(creds={ + b"test": b"p4ssw0rd", +}) + +token_auth_handler = TokenServerAuthHandler(creds={ + b"test": b"p4ssw0rd", +}) + + +@pytest.mark.slow +def test_http_basic_unauth(): + """Test that auth fails when not authenticated.""" + with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server: + client = FlightClient(('localhost', server.port)) + action = flight.Action("who-am-i", b"") + with pytest.raises(flight.FlightUnauthenticatedError, + match=".*unauthenticated.*"): + list(client.do_action(action)) + + +@pytest.mark.skipif(os.name == 'nt', + reason="ARROW-10013: gRPC on Windows corrupts peer()") +def test_http_basic_auth(): + """Test a Python implementation of HTTP basic authentication.""" + with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server: + client = FlightClient(('localhost', server.port)) + action = flight.Action("who-am-i", b"") + client.authenticate(HttpBasicClientAuthHandler('test', 'p4ssw0rd')) + results = client.do_action(action) + identity = next(results) + assert identity.body.to_pybytes() == b'test' + peer_address = next(results) + assert peer_address.body.to_pybytes() != b'' + + +def test_http_basic_auth_invalid_password(): + """Test that auth fails with the wrong password.""" + with EchoStreamFlightServer(auth_handler=basic_auth_handler) as server: + client = FlightClient(('localhost', server.port)) + action = flight.Action("who-am-i", b"") + with pytest.raises(flight.FlightUnauthenticatedError, + match=".*wrong password.*"): + client.authenticate(HttpBasicClientAuthHandler('test', 'wrong')) + next(client.do_action(action)) + + +def test_token_auth(): + """Test an auth mechanism that uses a handshake.""" + with EchoStreamFlightServer(auth_handler=token_auth_handler) as server: + client = FlightClient(('localhost', server.port)) + action = flight.Action("who-am-i", b"") + client.authenticate(TokenClientAuthHandler('test', 'p4ssw0rd')) + identity = next(client.do_action(action)) + assert identity.body.to_pybytes() == b'test' + + +def test_token_auth_invalid(): + """Test an auth mechanism that uses a handshake.""" + with EchoStreamFlightServer(auth_handler=token_auth_handler) as server: + client = FlightClient(('localhost', server.port)) + with pytest.raises(flight.FlightUnauthenticatedError): + client.authenticate(TokenClientAuthHandler('test', 'wrong')) + + +header_auth_server_middleware_factory = HeaderAuthServerMiddlewareFactory() +no_op_auth_handler = NoopAuthHandler() + + +def test_authenticate_basic_token(): + """Test authenticate_basic_token with bearer token and auth headers.""" + with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={ + "auth": HeaderAuthServerMiddlewareFactory() + }) as server: + client = FlightClient(('localhost', server.port)) + token_pair = client.authenticate_basic_token(b'test', b'password') + assert token_pair[0] == b'authorization' + assert token_pair[1] == b'Bearer token1234' + + +def test_authenticate_basic_token_invalid_password(): + """Test authenticate_basic_token with an invalid password.""" + with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={ + "auth": HeaderAuthServerMiddlewareFactory() + }) as server: + client = FlightClient(('localhost', server.port)) + with pytest.raises(flight.FlightUnauthenticatedError): + client.authenticate_basic_token(b'test', b'badpassword') + + +def test_authenticate_basic_token_and_action(): + """Test authenticate_basic_token and doAction after authentication.""" + with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={ + "auth": HeaderAuthServerMiddlewareFactory() + }) as server: + client = FlightClient(('localhost', server.port)) + token_pair = client.authenticate_basic_token(b'test', b'password') + assert token_pair[0] == b'authorization' + assert token_pair[1] == b'Bearer token1234' + options = flight.FlightCallOptions(headers=[token_pair]) + result = list(client.do_action( + action=flight.Action('test-action', b''), options=options)) + assert result[0].body.to_pybytes() == b'token1234' + + +def test_authenticate_basic_token_with_client_middleware(): + """Test authenticate_basic_token with client middleware + to intercept authorization header returned by the + HTTP header auth enabled server. + """ + with HeaderAuthFlightServer(auth_handler=no_op_auth_handler, middleware={ + "auth": HeaderAuthServerMiddlewareFactory() + }) as server: + client_auth_middleware = ClientHeaderAuthMiddlewareFactory() + client = FlightClient( + ('localhost', server.port), + middleware=[client_auth_middleware] + ) + encoded_credentials = base64.b64encode(b'test:password') + options = flight.FlightCallOptions(headers=[ + (b'authorization', b'Basic ' + encoded_credentials) + ]) + result = list(client.do_action( + action=flight.Action('test-action', b''), options=options)) + assert result[0].body.to_pybytes() == b'token1234' + assert client_auth_middleware.call_credential[0] == b'authorization' + assert client_auth_middleware.call_credential[1] == \ + b'Bearer ' + b'token1234' + result2 = list(client.do_action( + action=flight.Action('test-action', b''), options=options)) + assert result2[0].body.to_pybytes() == b'token1234' + assert client_auth_middleware.call_credential[0] == b'authorization' + assert client_auth_middleware.call_credential[1] == \ + b'Bearer ' + b'token1234' + + +def test_arbitrary_headers_in_flight_call_options(): + """Test passing multiple arbitrary headers to the middleware.""" + with ArbitraryHeadersFlightServer( + auth_handler=no_op_auth_handler, + middleware={ + "auth": HeaderAuthServerMiddlewareFactory(), + "arbitrary-headers": ArbitraryHeadersServerMiddlewareFactory() + }) as server: + client = FlightClient(('localhost', server.port)) + token_pair = client.authenticate_basic_token(b'test', b'password') + assert token_pair[0] == b'authorization' + assert token_pair[1] == b'Bearer token1234' + options = flight.FlightCallOptions(headers=[ + token_pair, + (b'test-header-1', b'value1'), + (b'test-header-2', b'value2') + ]) + result = list(client.do_action(flight.Action( + "test-action", b""), options=options)) + assert result[0].body.to_pybytes() == b'value1' + assert result[1].body.to_pybytes() == b'value2' + + +def test_location_invalid(): + """Test constructing invalid URIs.""" + with pytest.raises(pa.ArrowInvalid, match=".*Cannot parse URI:.*"): + flight.connect("%") + + with pytest.raises(pa.ArrowInvalid, match=".*Cannot parse URI:.*"): + ConstantFlightServer("%") + + +def test_location_unknown_scheme(): + """Test creating locations for unknown schemes.""" + assert flight.Location("s3://foo").uri == b"s3://foo" + assert flight.Location("https://example.com/bar.parquet").uri == \ + b"https://example.com/bar.parquet" + + +@pytest.mark.slow +@pytest.mark.requires_testing_data +def test_tls_fails(): + """Make sure clients cannot connect when cert verification fails.""" + certs = example_tls_certs() + + with ConstantFlightServer(tls_certificates=certs["certificates"]) as s: + # Ensure client doesn't connect when certificate verification + # fails (this is a slow test since gRPC does retry a few times) + client = FlightClient("grpc+tls://localhost:" + str(s.port)) + + # gRPC error messages change based on version, so don't look + # for a particular error + with pytest.raises(flight.FlightUnavailableError): + client.do_get(flight.Ticket(b'ints')).read_all() + + +@pytest.mark.requires_testing_data +def test_tls_do_get(): + """Try a simple do_get call over TLS.""" + table = simple_ints_table() + certs = example_tls_certs() + + with ConstantFlightServer(tls_certificates=certs["certificates"]) as s: + client = FlightClient(('localhost', s.port), + tls_root_certs=certs["root_cert"]) + data = client.do_get(flight.Ticket(b'ints')).read_all() + assert data.equals(table) + + +@pytest.mark.requires_testing_data +def test_tls_disable_server_verification(): + """Try a simple do_get call over TLS with server verification disabled.""" + table = simple_ints_table() + certs = example_tls_certs() + + with ConstantFlightServer(tls_certificates=certs["certificates"]) as s: + try: + client = FlightClient(('localhost', s.port), + disable_server_verification=True) + except NotImplementedError: + pytest.skip('disable_server_verification feature is not available') + data = client.do_get(flight.Ticket(b'ints')).read_all() + assert data.equals(table) + + +@pytest.mark.requires_testing_data +def test_tls_override_hostname(): + """Check that incorrectly overriding the hostname fails.""" + certs = example_tls_certs() + + with ConstantFlightServer(tls_certificates=certs["certificates"]) as s: + client = flight.connect(('localhost', s.port), + tls_root_certs=certs["root_cert"], + override_hostname="fakehostname") + with pytest.raises(flight.FlightUnavailableError): + client.do_get(flight.Ticket(b'ints')) + + +def test_flight_do_get_metadata(): + """Try a simple do_get call with metadata.""" + data = [ + pa.array([-10, -5, 0, 5, 10]) + ] + table = pa.Table.from_arrays(data, names=['a']) + + batches = [] + with MetadataFlightServer() as server: + client = FlightClient(('localhost', server.port)) + reader = client.do_get(flight.Ticket(b'')) + idx = 0 + while True: + try: + batch, metadata = reader.read_chunk() + batches.append(batch) + server_idx, = struct.unpack('<i', metadata.to_pybytes()) + assert idx == server_idx + idx += 1 + except StopIteration: + break + data = pa.Table.from_batches(batches) + assert data.equals(table) + + +def test_flight_do_get_metadata_v4(): + """Try a simple do_get call with V4 metadata version.""" + table = pa.Table.from_arrays( + [pa.array([-10, -5, 0, 5, 10])], names=['a']) + options = pa.ipc.IpcWriteOptions( + metadata_version=pa.ipc.MetadataVersion.V4) + with MetadataFlightServer(options=options) as server: + client = FlightClient(('localhost', server.port)) + reader = client.do_get(flight.Ticket(b'')) + data = reader.read_all() + assert data.equals(table) + + +def test_flight_do_put_metadata(): + """Try a simple do_put call with metadata.""" + data = [ + pa.array([-10, -5, 0, 5, 10]) + ] + table = pa.Table.from_arrays(data, names=['a']) + + with MetadataFlightServer() as server: + client = FlightClient(('localhost', server.port)) + writer, metadata_reader = client.do_put( + flight.FlightDescriptor.for_path(''), + table.schema) + with writer: + for idx, batch in enumerate(table.to_batches(max_chunksize=1)): + metadata = struct.pack('<i', idx) + writer.write_with_metadata(batch, metadata) + buf = metadata_reader.read() + assert buf is not None + server_idx, = struct.unpack('<i', buf.to_pybytes()) + assert idx == server_idx + + +def test_flight_do_put_limit(): + """Try a simple do_put call with a size limit.""" + large_batch = pa.RecordBatch.from_arrays([ + pa.array(np.ones(768, dtype=np.int64())), + ], names=['a']) + + with EchoFlightServer() as server: + client = FlightClient(('localhost', server.port), + write_size_limit_bytes=4096) + writer, metadata_reader = client.do_put( + flight.FlightDescriptor.for_path(''), + large_batch.schema) + with writer: + with pytest.raises(flight.FlightWriteSizeExceededError, + match="exceeded soft limit") as excinfo: + writer.write_batch(large_batch) + assert excinfo.value.limit == 4096 + smaller_batches = [ + large_batch.slice(0, 384), + large_batch.slice(384), + ] + for batch in smaller_batches: + writer.write_batch(batch) + expected = pa.Table.from_batches([large_batch]) + actual = client.do_get(flight.Ticket(b'')).read_all() + assert expected == actual + + +@pytest.mark.slow +def test_cancel_do_get(): + """Test canceling a DoGet operation on the client side.""" + with ConstantFlightServer() as server: + client = FlightClient(('localhost', server.port)) + reader = client.do_get(flight.Ticket(b'ints')) + reader.cancel() + with pytest.raises(flight.FlightCancelledError, match=".*Cancel.*"): + reader.read_chunk() + + +@pytest.mark.slow +def test_cancel_do_get_threaded(): + """Test canceling a DoGet operation from another thread.""" + with SlowFlightServer() as server: + client = FlightClient(('localhost', server.port)) + reader = client.do_get(flight.Ticket(b'ints')) + + read_first_message = threading.Event() + stream_canceled = threading.Event() + result_lock = threading.Lock() + raised_proper_exception = threading.Event() + + def block_read(): + reader.read_chunk() + read_first_message.set() + stream_canceled.wait(timeout=5) + try: + reader.read_chunk() + except flight.FlightCancelledError: + with result_lock: + raised_proper_exception.set() + + thread = threading.Thread(target=block_read, daemon=True) + thread.start() + read_first_message.wait(timeout=5) + reader.cancel() + stream_canceled.set() + thread.join(timeout=1) + + with result_lock: + assert raised_proper_exception.is_set() + + +def test_roundtrip_types(): + """Make sure serializable types round-trip.""" + ticket = flight.Ticket("foo") + assert ticket == flight.Ticket.deserialize(ticket.serialize()) + + desc = flight.FlightDescriptor.for_command("test") + assert desc == flight.FlightDescriptor.deserialize(desc.serialize()) + + desc = flight.FlightDescriptor.for_path("a", "b", "test.arrow") + assert desc == flight.FlightDescriptor.deserialize(desc.serialize()) + + info = flight.FlightInfo( + pa.schema([('a', pa.int32())]), + desc, + [ + flight.FlightEndpoint(b'', ['grpc://test']), + flight.FlightEndpoint( + b'', + [flight.Location.for_grpc_tcp('localhost', 5005)], + ), + ], + -1, + -1, + ) + info2 = flight.FlightInfo.deserialize(info.serialize()) + assert info.schema == info2.schema + assert info.descriptor == info2.descriptor + assert info.total_bytes == info2.total_bytes + assert info.total_records == info2.total_records + assert info.endpoints == info2.endpoints + + +def test_roundtrip_errors(): + """Ensure that Flight errors propagate from server to client.""" + with ErrorFlightServer() as server: + client = FlightClient(('localhost', server.port)) + + with pytest.raises(flight.FlightInternalError, match=".*foo.*"): + list(client.do_action(flight.Action("internal", b""))) + with pytest.raises(flight.FlightTimedOutError, match=".*foo.*"): + list(client.do_action(flight.Action("timedout", b""))) + with pytest.raises(flight.FlightCancelledError, match=".*foo.*"): + list(client.do_action(flight.Action("cancel", b""))) + with pytest.raises(flight.FlightUnauthenticatedError, match=".*foo.*"): + list(client.do_action(flight.Action("unauthenticated", b""))) + with pytest.raises(flight.FlightUnauthorizedError, match=".*foo.*"): + list(client.do_action(flight.Action("unauthorized", b""))) + with pytest.raises(flight.FlightInternalError, match=".*foo.*"): + list(client.list_flights()) + + data = [pa.array([-10, -5, 0, 5, 10])] + table = pa.Table.from_arrays(data, names=['a']) + + exceptions = { + 'internal': flight.FlightInternalError, + 'timedout': flight.FlightTimedOutError, + 'cancel': flight.FlightCancelledError, + 'unauthenticated': flight.FlightUnauthenticatedError, + 'unauthorized': flight.FlightUnauthorizedError, + } + + for command, exception in exceptions.items(): + + with pytest.raises(exception, match=".*foo.*"): + writer, reader = client.do_put( + flight.FlightDescriptor.for_command(command), + table.schema) + writer.write_table(table) + writer.close() + + with pytest.raises(exception, match=".*foo.*"): + writer, reader = client.do_put( + flight.FlightDescriptor.for_command(command), + table.schema) + writer.close() + + +def test_do_put_independent_read_write(): + """Ensure that separate threads can read/write on a DoPut.""" + # ARROW-6063: previously this would cause gRPC to abort when the + # writer was closed (due to simultaneous reads), or would hang + # forever. + data = [ + pa.array([-10, -5, 0, 5, 10]) + ] + table = pa.Table.from_arrays(data, names=['a']) + + with MetadataFlightServer() as server: + client = FlightClient(('localhost', server.port)) + writer, metadata_reader = client.do_put( + flight.FlightDescriptor.for_path(''), + table.schema) + + count = [0] + + def _reader_thread(): + while metadata_reader.read() is not None: + count[0] += 1 + + thread = threading.Thread(target=_reader_thread) + thread.start() + + batches = table.to_batches(max_chunksize=1) + with writer: + for idx, batch in enumerate(batches): + metadata = struct.pack('<i', idx) + writer.write_with_metadata(batch, metadata) + # Causes the server to stop writing and end the call + writer.done_writing() + # Thus reader thread will break out of loop + thread.join() + # writer.close() won't segfault since reader thread has + # stopped + assert count[0] == len(batches) + + +def test_server_middleware_same_thread(): + """Ensure that server middleware run on the same thread as the RPC.""" + with HeaderFlightServer(middleware={ + "test": HeaderServerMiddlewareFactory(), + }) as server: + client = FlightClient(('localhost', server.port)) + results = list(client.do_action(flight.Action(b"test", b""))) + assert len(results) == 1 + value = results[0].body.to_pybytes() + assert b"right value" == value + + +def test_middleware_reject(): + """Test rejecting an RPC with server middleware.""" + with HeaderFlightServer(middleware={ + "test": SelectiveAuthServerMiddlewareFactory(), + }) as server: + client = FlightClient(('localhost', server.port)) + # The middleware allows this through without auth. + with pytest.raises(pa.ArrowNotImplementedError): + list(client.list_actions()) + + # But not anything else. + with pytest.raises(flight.FlightUnauthenticatedError): + list(client.do_action(flight.Action(b"", b""))) + + client = FlightClient( + ('localhost', server.port), + middleware=[SelectiveAuthClientMiddlewareFactory()] + ) + response = next(client.do_action(flight.Action(b"", b""))) + assert b"password" == response.body.to_pybytes() + + +def test_middleware_mapping(): + """Test that middleware records methods correctly.""" + server_middleware = RecordingServerMiddlewareFactory() + client_middleware = RecordingClientMiddlewareFactory() + with FlightServerBase(middleware={"test": server_middleware}) as server: + client = FlightClient( + ('localhost', server.port), + middleware=[client_middleware] + ) + + descriptor = flight.FlightDescriptor.for_command(b"") + with pytest.raises(NotImplementedError): + list(client.list_flights()) + with pytest.raises(NotImplementedError): + client.get_flight_info(descriptor) + with pytest.raises(NotImplementedError): + client.get_schema(descriptor) + with pytest.raises(NotImplementedError): + client.do_get(flight.Ticket(b"")) + with pytest.raises(NotImplementedError): + writer, _ = client.do_put(descriptor, pa.schema([])) + writer.close() + with pytest.raises(NotImplementedError): + list(client.do_action(flight.Action(b"", b""))) + with pytest.raises(NotImplementedError): + list(client.list_actions()) + with pytest.raises(NotImplementedError): + writer, _ = client.do_exchange(descriptor) + writer.close() + + expected = [ + flight.FlightMethod.LIST_FLIGHTS, + flight.FlightMethod.GET_FLIGHT_INFO, + flight.FlightMethod.GET_SCHEMA, + flight.FlightMethod.DO_GET, + flight.FlightMethod.DO_PUT, + flight.FlightMethod.DO_ACTION, + flight.FlightMethod.LIST_ACTIONS, + flight.FlightMethod.DO_EXCHANGE, + ] + assert server_middleware.methods == expected + assert client_middleware.methods == expected + + +def test_extra_info(): + with ErrorFlightServer() as server: + client = FlightClient(('localhost', server.port)) + try: + list(client.do_action(flight.Action("protobuf", b""))) + assert False + except flight.FlightUnauthorizedError as e: + assert e.extra_info is not None + ei = e.extra_info + assert ei == b'this is an error message' + + +@pytest.mark.requires_testing_data +def test_mtls(): + """Test mutual TLS (mTLS) with gRPC.""" + certs = example_tls_certs() + table = simple_ints_table() + + with ConstantFlightServer( + tls_certificates=[certs["certificates"][0]], + verify_client=True, + root_certificates=certs["root_cert"]) as s: + client = FlightClient( + ('localhost', s.port), + tls_root_certs=certs["root_cert"], + cert_chain=certs["certificates"][0].cert, + private_key=certs["certificates"][0].key) + data = client.do_get(flight.Ticket(b'ints')).read_all() + assert data.equals(table) + + +def test_doexchange_get(): + """Emulate DoGet with DoExchange.""" + expected = pa.Table.from_arrays([ + pa.array(range(0, 10 * 1024)) + ], names=["a"]) + + with ExchangeFlightServer() as server: + client = FlightClient(("localhost", server.port)) + descriptor = flight.FlightDescriptor.for_command(b"get") + writer, reader = client.do_exchange(descriptor) + with writer: + table = reader.read_all() + assert expected == table + + +def test_doexchange_put(): + """Emulate DoPut with DoExchange.""" + data = pa.Table.from_arrays([ + pa.array(range(0, 10 * 1024)) + ], names=["a"]) + batches = data.to_batches(max_chunksize=512) + + with ExchangeFlightServer() as server: + client = FlightClient(("localhost", server.port)) + descriptor = flight.FlightDescriptor.for_command(b"put") + writer, reader = client.do_exchange(descriptor) + with writer: + writer.begin(data.schema) + for batch in batches: + writer.write_batch(batch) + writer.done_writing() + chunk = reader.read_chunk() + assert chunk.data is None + expected_buf = str(len(batches)).encode("utf-8") + assert chunk.app_metadata == expected_buf + + +def test_doexchange_echo(): + """Try a DoExchange echo server.""" + data = pa.Table.from_arrays([ + pa.array(range(0, 10 * 1024)) + ], names=["a"]) + batches = data.to_batches(max_chunksize=512) + + with ExchangeFlightServer() as server: + client = FlightClient(("localhost", server.port)) + descriptor = flight.FlightDescriptor.for_command(b"echo") + writer, reader = client.do_exchange(descriptor) + with writer: + # Read/write metadata before starting data. + for i in range(10): + buf = str(i).encode("utf-8") + writer.write_metadata(buf) + chunk = reader.read_chunk() + assert chunk.data is None + assert chunk.app_metadata == buf + + # Now write data without metadata. + writer.begin(data.schema) + for batch in batches: + writer.write_batch(batch) + assert reader.schema == data.schema + chunk = reader.read_chunk() + assert chunk.data == batch + assert chunk.app_metadata is None + + # And write data with metadata. + for i, batch in enumerate(batches): + buf = str(i).encode("utf-8") + writer.write_with_metadata(batch, buf) + chunk = reader.read_chunk() + assert chunk.data == batch + assert chunk.app_metadata == buf + + +def test_doexchange_echo_v4(): + """Try a DoExchange echo server using the V4 metadata version.""" + data = pa.Table.from_arrays([ + pa.array(range(0, 10 * 1024)) + ], names=["a"]) + batches = data.to_batches(max_chunksize=512) + + options = pa.ipc.IpcWriteOptions( + metadata_version=pa.ipc.MetadataVersion.V4) + with ExchangeFlightServer(options=options) as server: + client = FlightClient(("localhost", server.port)) + descriptor = flight.FlightDescriptor.for_command(b"echo") + writer, reader = client.do_exchange(descriptor) + with writer: + # Now write data without metadata. + writer.begin(data.schema, options=options) + for batch in batches: + writer.write_batch(batch) + assert reader.schema == data.schema + chunk = reader.read_chunk() + assert chunk.data == batch + assert chunk.app_metadata is None + + +def test_doexchange_transform(): + """Transform a table with a service.""" + data = pa.Table.from_arrays([ + pa.array(range(0, 1024)), + pa.array(range(1, 1025)), + pa.array(range(2, 1026)), + ], names=["a", "b", "c"]) + expected = pa.Table.from_arrays([ + pa.array(range(3, 1024 * 3 + 3, 3)), + ], names=["sum"]) + + with ExchangeFlightServer() as server: + client = FlightClient(("localhost", server.port)) + descriptor = flight.FlightDescriptor.for_command(b"transform") + writer, reader = client.do_exchange(descriptor) + with writer: + writer.begin(data.schema) + writer.write_table(data) + writer.done_writing() + table = reader.read_all() + assert expected == table + + +def test_middleware_multi_header(): + """Test sending/receiving multiple (binary-valued) headers.""" + with MultiHeaderFlightServer(middleware={ + "test": MultiHeaderServerMiddlewareFactory(), + }) as server: + headers = MultiHeaderClientMiddlewareFactory() + client = FlightClient(('localhost', server.port), middleware=[headers]) + response = next(client.do_action(flight.Action(b"", b""))) + # The server echoes the headers it got back to us. + raw_headers = response.body.to_pybytes().decode("utf-8") + client_headers = ast.literal_eval(raw_headers) + # Don't directly compare; gRPC may add headers like User-Agent. + for header, values in MultiHeaderClientMiddleware.EXPECTED.items(): + assert client_headers.get(header) == values + assert headers.last_headers.get(header) == values + + +@pytest.mark.requires_testing_data +def test_generic_options(): + """Test setting generic client options.""" + certs = example_tls_certs() + + with ConstantFlightServer(tls_certificates=certs["certificates"]) as s: + # Try setting a string argument that will make requests fail + options = [("grpc.ssl_target_name_override", "fakehostname")] + client = flight.connect(('localhost', s.port), + tls_root_certs=certs["root_cert"], + generic_options=options) + with pytest.raises(flight.FlightUnavailableError): + client.do_get(flight.Ticket(b'ints')) + # Try setting an int argument that will make requests fail + options = [("grpc.max_receive_message_length", 32)] + client = flight.connect(('localhost', s.port), + tls_root_certs=certs["root_cert"], + generic_options=options) + with pytest.raises(pa.ArrowInvalid): + client.do_get(flight.Ticket(b'ints')) + + +class CancelFlightServer(FlightServerBase): + """A server for testing StopToken.""" + + def do_get(self, context, ticket): + schema = pa.schema([]) + rb = pa.RecordBatch.from_arrays([], schema=schema) + return flight.GeneratorStream(schema, itertools.repeat(rb)) + + def do_exchange(self, context, descriptor, reader, writer): + schema = pa.schema([]) + rb = pa.RecordBatch.from_arrays([], schema=schema) + writer.begin(schema) + while not context.is_cancelled(): + writer.write_batch(rb) + time.sleep(0.5) + + +def test_interrupt(): + if threading.current_thread().ident != threading.main_thread().ident: + pytest.skip("test only works from main Python thread") + # Skips test if not available + raise_signal = util.get_raise_signal() + + def signal_from_thread(): + time.sleep(0.5) + raise_signal(signal.SIGINT) + + exc_types = (KeyboardInterrupt, pa.ArrowCancelled) + + def test(read_all): + try: + try: + t = threading.Thread(target=signal_from_thread) + with pytest.raises(exc_types) as exc_info: + t.start() + read_all() + finally: + t.join() + except KeyboardInterrupt: + # In case KeyboardInterrupt didn't interrupt read_all + # above, at least prevent it from stopping the test suite + pytest.fail("KeyboardInterrupt didn't interrupt Flight read_all") + e = exc_info.value.__context__ + assert isinstance(e, pa.ArrowCancelled) or \ + isinstance(e, KeyboardInterrupt) + + with CancelFlightServer() as server: + client = FlightClient(("localhost", server.port)) + + reader = client.do_get(flight.Ticket(b"")) + test(reader.read_all) + + descriptor = flight.FlightDescriptor.for_command(b"echo") + writer, reader = client.do_exchange(descriptor) + test(reader.read_all) + + +def test_never_sends_data(): + # Regression test for ARROW-12779 + match = "application server implementation error" + with NeverSendsDataFlightServer() as server: + client = flight.connect(('localhost', server.port)) + with pytest.raises(flight.FlightServerError, match=match): + client.do_get(flight.Ticket(b'')).read_all() + + # Check that the server handler will ignore empty tables + # up to a certain extent + table = client.do_get(flight.Ticket(b'yield_data')).read_all() + assert table.num_rows == 5 + + +@pytest.mark.large_memory +@pytest.mark.slow +def test_large_descriptor(): + # Regression test for ARROW-13253. Placed here with appropriate marks + # since some CI pipelines can't run the C++ equivalent + large_descriptor = flight.FlightDescriptor.for_command( + b' ' * (2 ** 31 + 1)) + with FlightServerBase() as server: + client = flight.connect(('localhost', server.port)) + with pytest.raises(OSError, + match="Failed to serialize Flight descriptor"): + writer, _ = client.do_put(large_descriptor, pa.schema([])) + writer.close() + with pytest.raises(pa.ArrowException, + match="Failed to serialize Flight descriptor"): + client.do_exchange(large_descriptor) + + +@pytest.mark.large_memory +@pytest.mark.slow +def test_large_metadata_client(): + # Regression test for ARROW-13253 + descriptor = flight.FlightDescriptor.for_command(b'') + metadata = b' ' * (2 ** 31 + 1) + with EchoFlightServer() as server: + client = flight.connect(('localhost', server.port)) + with pytest.raises(pa.ArrowCapacityError, + match="app_metadata size overflow"): + writer, _ = client.do_put(descriptor, pa.schema([])) + with writer: + writer.write_metadata(metadata) + writer.close() + with pytest.raises(pa.ArrowCapacityError, + match="app_metadata size overflow"): + writer, reader = client.do_exchange(descriptor) + with writer: + writer.write_metadata(metadata) + + del metadata + with LargeMetadataFlightServer() as server: + client = flight.connect(('localhost', server.port)) + with pytest.raises(flight.FlightServerError, + match="app_metadata size overflow"): + reader = client.do_get(flight.Ticket(b'')) + reader.read_all() + with pytest.raises(pa.ArrowException, + match="app_metadata size overflow"): + writer, reader = client.do_exchange(descriptor) + with writer: + reader.read_all() + + +class ActionNoneFlightServer(EchoFlightServer): + """A server that implements a side effect to a non iterable action.""" + VALUES = [] + + def do_action(self, context, action): + if action.type == "get_value": + return [json.dumps(self.VALUES).encode('utf-8')] + elif action.type == "append": + self.VALUES.append(True) + return None + raise NotImplementedError + + +def test_none_action_side_effect(): + """Ensure that actions are executed even when we don't consume iterator. + + See https://issues.apache.org/jira/browse/ARROW-14255 + """ + + with ActionNoneFlightServer() as server: + client = FlightClient(('localhost', server.port)) + client.do_action(flight.Action("append", b"")) + r = client.do_action(flight.Action("get_value", b"")) + assert json.loads(next(r).body.to_pybytes()) == [True] |