summaryrefslogtreecommitdiffstats
path: root/src/arrow/java/flight
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-21 11:54:28 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-21 11:54:28 +0000
commite6918187568dbd01842d8d1d2c808ce16a894239 (patch)
tree64f88b554b444a49f656b6c656111a145cbbaa28 /src/arrow/java/flight
parentInitial commit. (diff)
downloadceph-e6918187568dbd01842d8d1d2c808ce16a894239.tar.xz
ceph-e6918187568dbd01842d8d1d2c808ce16a894239.zip
Adding upstream version 18.2.2.upstream/18.2.2
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/arrow/java/flight')
-rw-r--r--src/arrow/java/flight/flight-core/README.md95
-rw-r--r--src/arrow/java/flight/flight-core/pom.xml392
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Action.java61
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ActionType.java70
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java560
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/AsyncPutListener.java72
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/BackpressureStrategy.java172
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallHeaders.java65
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallInfo.java33
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallOption.java24
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallOptions.java62
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallStatus.java143
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Criteria.java58
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java127
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ErrorFlightMetadata.java81
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightBindingService.java174
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightCallHeaders.java111
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java721
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClientMiddleware.java52
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightConstants.java29
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDescriptor.java180
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightEndpoint.java106
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java208
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightMethod.java64
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java164
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightRuntimeException.java46
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java399
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServerMiddleware.java100
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java427
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStatusCode.java82
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java505
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/HeaderCallOption.java52
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java158
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/LocationSchemes.java32
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/NoOpFlightProducer.java61
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/NoOpStreamListener.java49
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListener.java123
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListenerImpl.java132
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/PutResult.java96
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/RequestContext.java51
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Result.java50
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java96
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ServerHeaderMiddleware.java65
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/StreamPipe.java118
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SyncPutListener.java122
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Ticket.java102
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/AuthConstants.java51
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/BasicClientAuthHandler.java58
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/BasicServerAuthHandler.java74
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ClientAuthHandler.java55
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ClientAuthInterceptor.java73
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java162
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthHandler.java72
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java85
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java144
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/Auth2Constants.java31
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/AuthUtilities.java47
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BasicAuthCredentialWriter.java44
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BasicCallHeaderAuthenticator.java88
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BearerCredentialWriter.java39
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BearerTokenAuthenticator.java62
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/CallHeaderAuthenticator.java86
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientBearerHeaderHandler.java36
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientHandshakeWrapper.java100
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientHeaderHandler.java43
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientIncomingAuthHeaderMiddleware.java78
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/GeneratedBearerTokenAuthenticator.java128
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ServerCallHeaderAuthMiddleware.java74
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/client/ClientCookieMiddleware.java130
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java93
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/ExampleTicket.java141
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/FlightHolder.java131
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java176
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/Stream.java177
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java97
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java74
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java197
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java97
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/MiddlewareScenario.java168
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java45
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java90
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/AddWritableBuffer.java128
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/CallCredentialAdapter.java53
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ClientInterceptorAdapter.java149
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ContextPropagatingExecutorService.java117
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/CredentialCallOption.java41
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java99
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/MetadataAdapter.java98
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/RequestContextAdapter.java57
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ServerInterceptorAdapter.java145
-rw-r--r--src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/StatusUtils.java255
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/FlightTestUtil.java150
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java329
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestAuth.java93
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBackPressure.java262
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java567
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestCallOptions.java191
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestClientMiddleware.java359
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDictionaryUtils.java91
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java536
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestErrorMetadata.java143
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightClient.java225
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java125
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLargeMessage.java165
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLeak.java182
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java319
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerMiddleware.java360
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerOptions.java176
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestTls.java145
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java158
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/auth2/TestBasicAuth2.java232
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/TestCookieHandling.java267
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java117
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/grpc/TestStatusUtils.java51
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java216
-rw-r--r--src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/TestPerf.java199
-rw-r--r--src/arrow/java/flight/flight-core/src/test/protobuf/perf.proto45
-rw-r--r--src/arrow/java/flight/flight-core/src/test/resources/logback.xml28
-rw-r--r--src/arrow/java/flight/flight-grpc/pom.xml132
-rw-r--r--src/arrow/java/flight/flight-grpc/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java161
-rw-r--r--src/arrow/java/flight/flight-grpc/src/test/java/org/apache/arrow/flight/TestFlightGrpcUtils.java193
-rw-r--r--src/arrow/java/flight/flight-grpc/src/test/protobuf/test.proto26
122 files changed, 17321 insertions, 0 deletions
diff --git a/src/arrow/java/flight/flight-core/README.md b/src/arrow/java/flight/flight-core/README.md
new file mode 100644
index 000000000..37b41ede2
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/README.md
@@ -0,0 +1,95 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+# Arrow Flight Java Package
+
+Exposing Apache Arrow data on the wire.
+
+[Protocol Description Slides](https://www.slideshare.net/JacquesNadeau5/apache-arrow-flight-overview)
+
+[GRPC Protocol Definition](https://github.com/apache/arrow/blob/master/format/Flight.proto)
+
+## Example usage
+
+* Compile the java tree:
+
+ ```
+ cd java
+ mvn clean install -DskipTests
+ ```
+
+* Go Into the Flight tree
+
+ ```
+ cd flight/flight-core
+ ```
+
+
+* Start the ExampleFlightServer (supports get/put of streams and listing these streams)
+
+ ```
+ mvn exec:exec
+ ```
+
+* In new terminal, run the TestExampleServer to populate the server with example data
+
+ ```
+ cd arrow/java/flight/flight-core
+ mvn surefire:test -DdisableServer=true -Dtest=TestExampleServer
+ ```
+
+## Python Example Usage
+
+* Compile example python headers
+
+ ```
+ mkdir target/generated-python
+ pip install grpcio-tools # or conda install grpcio
+ python -m grpc_tools.protoc -I./src/main/protobuf/ --python_out=./target/generated-python --grpc_python_out=./target/generated-python ../../format/Flight.proto
+ ```
+
+* Connect to the Flight Service
+
+ ```
+ cd target/generated-python
+ python
+ ```
+
+
+ ```
+ import grpc
+ import flight_pb2
+ import flight_pb2_grpc as flightrpc
+ channel = grpc.insecure_channel('localhost:12233')
+ service = flightrpc.FlightServiceStub(channel)
+ ```
+
+* List the Flight from Python
+
+ ```
+ for f in service.ListFlights(flight_pb2.Criteria()): f
+ ```
+
+* Try to Drop
+
+ ```
+ action = flight_pb2.Action()
+ action.type="drop"
+ service.DoAction(action)
+ ```
diff --git a/src/arrow/java/flight/flight-core/pom.xml b/src/arrow/java/flight/flight-core/pom.xml
new file mode 100644
index 000000000..669c6b744
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/pom.xml
@@ -0,0 +1,392 @@
+<?xml version="1.0"?>
+<!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor
+ license agreements. See the NOTICE file distributed with this work for additional
+ information regarding copyright ownership. The ASF licenses this file to
+ You under the Apache License, Version 2.0 (the "License"); you may not use
+ this file except in compliance with the License. You may obtain a copy of
+ the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required
+ by applicable law or agreed to in writing, software distributed under the
+ License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS
+ OF ANY KIND, either express or implied. See the License for the specific
+ language governing permissions and limitations under the License. -->
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <modelVersion>4.0.0</modelVersion>
+ <parent>
+ <groupId>org.apache.arrow</groupId>
+ <artifactId>arrow-java-root</artifactId>
+ <version>6.0.1</version>
+ <relativePath>../../pom.xml</relativePath>
+ </parent>
+
+ <artifactId>flight-core</artifactId>
+ <name>Arrow Flight Core</name>
+ <description>(Experimental)An RPC mechanism for transferring ValueVectors.</description>
+ <packaging>jar</packaging>
+
+ <properties>
+ <dep.grpc.version>1.41.0</dep.grpc.version>
+ <dep.protobuf.version>3.7.1</dep.protobuf.version>
+ <forkCount>1</forkCount>
+ </properties>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.arrow</groupId>
+ <artifactId>arrow-format</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.arrow</groupId>
+ <artifactId>arrow-vector</artifactId>
+ <version>${project.version}</version>
+ <classifier>${arrow.vector.classifier}</classifier>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.arrow</groupId>
+ <artifactId>arrow-memory-core</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.arrow</groupId>
+ <artifactId>arrow-memory-netty</artifactId>
+ <version>${project.version}</version>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
+ <groupId>io.grpc</groupId>
+ <artifactId>grpc-netty</artifactId>
+ <version>${dep.grpc.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>io.grpc</groupId>
+ <artifactId>grpc-core</artifactId>
+ <version>${dep.grpc.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>io.grpc</groupId>
+ <artifactId>grpc-context</artifactId>
+ <version>${dep.grpc.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>io.grpc</groupId>
+ <artifactId>grpc-protobuf</artifactId>
+ <version>${dep.grpc.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>io.netty</groupId>
+ <artifactId>netty-tcnative-boringssl-static</artifactId>
+ <version>2.0.43.Final</version>
+ </dependency>
+ <dependency>
+ <groupId>io.netty</groupId>
+ <artifactId>netty-buffer</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>io.netty</groupId>
+ <artifactId>netty-handler</artifactId>
+ <version>${dep.netty.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>io.netty</groupId>
+ <artifactId>netty-transport</artifactId>
+ <version>${dep.netty.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>commons-cli</groupId>
+ <artifactId>commons-cli</artifactId>
+ <version>1.4</version>
+ </dependency>
+ <dependency>
+ <groupId>io.grpc</groupId>
+ <artifactId>grpc-stub</artifactId>
+ <version>${dep.grpc.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ <version>${dep.protobuf.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>io.grpc</groupId>
+ <artifactId>grpc-api</artifactId>
+ <version>${dep.grpc.version}</version>
+ </dependency>
+
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-core</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-annotations</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-api</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>javax.annotation</groupId>
+ <artifactId>javax.annotation-api</artifactId>
+ </dependency>
+
+ <dependency>
+ <groupId>com.google.api.grpc</groupId>
+ <artifactId>proto-google-common-protos</artifactId>
+ <version>1.12.0</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.arrow</groupId>
+ <artifactId>arrow-vector</artifactId>
+ <version>${project.version}</version>
+ <classifier>tests</classifier>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
+ </dependencies>
+ <build>
+ <extensions>
+ <extension>
+ <groupId>kr.motd.maven</groupId>
+ <artifactId>os-maven-plugin</artifactId>
+ <version>1.5.0.Final</version>
+ </extension>
+ </extensions>
+ <plugins>
+ <plugin>
+ <artifactId>maven-surefire-plugin</artifactId>
+ <configuration>
+ <enableAssertions>false</enableAssertions>
+ <systemPropertyVariables>
+ <arrow.test.dataRoot>${project.basedir}/../../../testing/data</arrow.test.dataRoot>
+ </systemPropertyVariables>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-shade-plugin</artifactId>
+ <version>3.1.1</version>
+ <executions>
+ <execution>
+ <id>shade-main</id>
+ <phase>package</phase>
+ <goals>
+ <goal>shade</goal>
+ </goals>
+ <configuration>
+ <shadedArtifactAttached>true</shadedArtifactAttached>
+ <shadedClassifierName>shaded</shadedClassifierName>
+ <artifactSet>
+ <includes>
+ <include>io.grpc:*</include>
+ <include>com.google.protobuf:*</include>
+ </includes>
+ </artifactSet>
+ <relocations>
+ <relocation>
+ <pattern>com.google.protobuf</pattern>
+ <shadedPattern>arrow.flight.com.google.protobuf</shadedPattern>
+ </relocation>
+ </relocations>
+ <transformers>
+ <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer" />
+ </transformers>
+ </configuration>
+ </execution>
+ <execution>
+ <id>shade-ext</id>
+ <phase>package</phase>
+ <goals>
+ <goal>shade</goal>
+ </goals>
+ <configuration>
+ <shadedArtifactAttached>true</shadedArtifactAttached>
+ <shadedClassifierName>shaded-ext</shadedClassifierName>
+ <artifactSet>
+ <includes>
+ <include>io.grpc:*</include>
+ <include>com.google.protobuf:*</include>
+ <include>com.google.guava:*</include>
+ </includes>
+ </artifactSet>
+ <relocations>
+ <relocation>
+ <pattern>com.google.protobuf</pattern>
+ <shadedPattern>arrow.flight.com.google.protobuf</shadedPattern>
+ </relocation>
+ <relocation>
+ <pattern>com.google.common</pattern>
+ <shadedPattern>arrow.flight.com.google.common</shadedPattern>
+ </relocation>
+ </relocations>
+ <transformers>
+ <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer" />
+ </transformers>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin>
+ <groupId>org.xolstice.maven.plugins</groupId>
+ <artifactId>protobuf-maven-plugin</artifactId>
+ <version>0.5.0</version>
+ <configuration>
+ <protocArtifact>com.google.protobuf:protoc:${dep.protobuf.version}:exe:${os.detected.classifier}</protocArtifact>
+ <clearOutputDirectory>false</clearOutputDirectory>
+ <pluginId>grpc-java</pluginId>
+ <pluginArtifact>io.grpc:protoc-gen-grpc-java:${dep.grpc.version}:exe:${os.detected.classifier}</pluginArtifact>
+ </configuration>
+ <executions>
+ <execution>
+ <id>src</id>
+ <configuration>
+ <protoSourceRoot>${basedir}/../../../format/</protoSourceRoot>
+ <outputDirectory>${project.build.directory}/generated-sources/protobuf</outputDirectory>
+ </configuration>
+ <goals>
+ <goal>compile</goal>
+ <goal>compile-custom</goal>
+ </goals>
+ </execution>
+ <execution>
+ <id>test</id>
+ <configuration>
+ <protoSourceRoot>${basedir}/src/test/protobuf</protoSourceRoot>
+ <outputDirectory>${project.build.directory}/generated-test-sources//protobuf</outputDirectory>
+ </configuration>
+ <goals>
+ <goal>compile</goal>
+ <goal>compile-custom</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin>
+ <groupId>org.codehaus.mojo</groupId>
+ <artifactId>exec-maven-plugin</artifactId>
+ <version>1.6.0</version>
+ <configuration>
+ <executable>java</executable>
+ <classpathScope>test</classpathScope>
+ <arguments>
+ <argument>-classpath</argument>
+ <classpath />
+ <argument>-Xms64m</argument>
+ <argument>-Xmx64m</argument>
+ <argument>-XX:MaxDirectMemorySize=4g</argument>
+ <argument>org.apache.arrow.flight.example.ExampleFlightServer</argument>
+ </arguments>
+ </configuration>
+ </plugin>
+ <plugin>
+ <groupId>org.apache.maven.plugins</groupId>
+ <artifactId>maven-dependency-plugin</artifactId>
+ <executions>
+ <execution>
+ <id>analyze</id>
+ <phase>verify</phase>
+ <goals>
+ <goal>analyze-only</goal>
+ </goals>
+ <configuration>
+ <ignoredDependencies combine.children="append">
+ <ignoredDependency>io.netty:netty-tcnative-boringssl-static:*</ignoredDependency>
+ </ignoredDependencies>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin> <!-- add generated sources to classpath -->
+ <groupId>org.codehaus.mojo</groupId>
+ <artifactId>build-helper-maven-plugin</artifactId>
+ <version>1.9.1</version>
+ <executions>
+ <execution>
+ <id>add-generated-sources-to-classpath</id>
+ <phase>generate-sources</phase>
+ <goals>
+ <goal>add-source</goal>
+ </goals>
+ <configuration>
+ <sources>
+ <source>${project.build.directory}/generated-sources/protobuf</source>
+ </sources>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
+ <plugin>
+ <artifactId>maven-assembly-plugin</artifactId>
+ <version>3.0.0</version>
+ <configuration>
+ <descriptorRefs>
+ <descriptorRef>jar-with-dependencies</descriptorRef>
+ </descriptorRefs>
+ </configuration>
+ <executions>
+ <execution>
+ <id>make-assembly</id>
+ <phase>package</phase>
+ <goals>
+ <goal>single</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+ <profiles>
+ <profile>
+ <id>linux-netty-native</id>
+ <activation>
+ <os>
+ <family>linux</family>
+ </os>
+ </activation>
+ <dependencies>
+ <dependency>
+ <groupId>io.netty</groupId>
+ <artifactId>netty-transport-native-unix-common</artifactId>
+ <version>${dep.netty.version}</version>
+ <classifier>${os.detected.name}-${os.detected.arch}</classifier>
+ </dependency>
+ <dependency>
+ <groupId>io.netty</groupId>
+ <artifactId>netty-transport-native-epoll</artifactId>
+ <version>${dep.netty.version}</version>
+ <classifier>${os.detected.name}-${os.detected.arch}</classifier>
+ </dependency>
+ </dependencies>
+ </profile>
+ <profile>
+ <id>mac-netty-native</id>
+ <activation>
+ <os>
+ <family>mac</family>
+ </os>
+ </activation>
+ <dependencies>
+ <dependency>
+ <groupId>io.netty</groupId>
+ <artifactId>netty-transport-native-unix-common</artifactId>
+ <version>${dep.netty.version}</version>
+ <classifier>${os.detected.name}-${os.detected.arch}</classifier>
+ </dependency>
+ <dependency>
+ <groupId>io.netty</groupId>
+ <artifactId>netty-transport-native-kqueue</artifactId>
+ <version>${dep.netty.version}</version>
+ <classifier>${os.detected.name}-${os.detected.arch}</classifier>
+ </dependency>
+ </dependencies>
+ </profile>
+ </profiles>
+</project>
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Action.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Action.java
new file mode 100644
index 000000000..524ffcab9
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Action.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import org.apache.arrow.flight.impl.Flight;
+
+import com.google.protobuf.ByteString;
+
+/**
+ * An opaque action for the service to perform.
+ *
+ * <p>This is a POJO wrapper around the message of the same name in Flight.proto.
+ */
+public class Action {
+
+ private final String type;
+ private final byte[] body;
+
+ public Action(String type) {
+ this(type, null);
+ }
+
+ public Action(String type, byte[] body) {
+ this.type = type;
+ this.body = body == null ? new byte[0] : body;
+ }
+
+ Action(Flight.Action action) {
+ this(action.getType(), action.getBody().toByteArray());
+ }
+
+ public String getType() {
+ return type;
+ }
+
+ public byte[] getBody() {
+ return body;
+ }
+
+ Flight.Action toProtocol() {
+ return Flight.Action.newBuilder()
+ .setType(getType())
+ .setBody(ByteString.copyFrom(getBody()))
+ .build();
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ActionType.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ActionType.java
new file mode 100644
index 000000000..d89365612
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ActionType.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import org.apache.arrow.flight.impl.Flight;
+
+/**
+ * POJO wrapper around protocol specifics for Flight actions.
+ */
+public class ActionType {
+ private final String type;
+ private final String description;
+
+ /**
+ * Construct a new instance.
+ *
+ * @param type The type of action to perform
+ * @param description The description of the type.
+ */
+ public ActionType(String type, String description) {
+ super();
+ this.type = type;
+ this.description = description;
+ }
+
+ /**
+ * Constructs a new instance from the corresponding protocol buffer object.
+ */
+ ActionType(Flight.ActionType type) {
+ this.type = type.getType();
+ this.description = type.getDescription();
+ }
+
+ public String getType() {
+ return type;
+ }
+
+ /**
+ * Converts the POJO to the corresponding protocol buffer type.
+ */
+ Flight.ActionType toProtocol() {
+ return Flight.ActionType.newBuilder()
+ .setType(type)
+ .setDescription(description)
+ .build();
+ }
+
+ @Override
+ public String toString() {
+ return "ActionType{" +
+ "type='" + type + '\'' +
+ ", description='" + description + '\'' +
+ '}';
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java
new file mode 100644
index 000000000..b4ee835de
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java
@@ -0,0 +1,560 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.arrow.flight.grpc.AddWritableBuffer;
+import org.apache.arrow.flight.grpc.GetReadableBuffer;
+import org.apache.arrow.flight.impl.Flight.FlightData;
+import org.apache.arrow.flight.impl.Flight.FlightDescriptor;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.util.Preconditions;
+import org.apache.arrow.vector.compression.NoCompressionCodec;
+import org.apache.arrow.vector.ipc.message.ArrowBodyCompression;
+import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.ipc.message.IpcOption;
+import org.apache.arrow.vector.ipc.message.MessageMetadataResult;
+import org.apache.arrow.vector.ipc.message.MessageSerializer;
+import org.apache.arrow.vector.types.MetadataVersion;
+import org.apache.arrow.vector.types.pojo.Schema;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
+import com.google.common.io.ByteStreams;
+import com.google.protobuf.ByteString;
+import com.google.protobuf.CodedInputStream;
+import com.google.protobuf.CodedOutputStream;
+import com.google.protobuf.WireFormat;
+
+import io.grpc.Drainable;
+import io.grpc.MethodDescriptor.Marshaller;
+import io.grpc.protobuf.ProtoUtils;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufInputStream;
+import io.netty.buffer.CompositeByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.buffer.UnpooledByteBufAllocator;
+
+/**
+ * The in-memory representation of FlightData used to manage a stream of Arrow messages.
+ */
+class ArrowMessage implements AutoCloseable {
+
+ // If true, deserialize Arrow data by giving Arrow a reference to the underlying gRPC buffer
+ // instead of copying the data. Defaults to true.
+ public static final boolean ENABLE_ZERO_COPY_READ;
+ // If true, serialize Arrow data by giving gRPC a reference to the underlying Arrow buffer
+ // instead of copying the data. Defaults to false.
+ public static final boolean ENABLE_ZERO_COPY_WRITE;
+
+ static {
+ String zeroCopyReadFlag = System.getProperty("arrow.flight.enable_zero_copy_read");
+ if (zeroCopyReadFlag == null) {
+ zeroCopyReadFlag = System.getenv("ARROW_FLIGHT_ENABLE_ZERO_COPY_READ");
+ }
+ String zeroCopyWriteFlag = System.getProperty("arrow.flight.enable_zero_copy_write");
+ if (zeroCopyWriteFlag == null) {
+ zeroCopyWriteFlag = System.getenv("ARROW_FLIGHT_ENABLE_ZERO_COPY_WRITE");
+ }
+ ENABLE_ZERO_COPY_READ = !"false".equalsIgnoreCase(zeroCopyReadFlag);
+ ENABLE_ZERO_COPY_WRITE = "true".equalsIgnoreCase(zeroCopyWriteFlag);
+ }
+
+ private static final int DESCRIPTOR_TAG =
+ (FlightData.FLIGHT_DESCRIPTOR_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED;
+ private static final int BODY_TAG =
+ (FlightData.DATA_BODY_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED;
+ private static final int HEADER_TAG =
+ (FlightData.DATA_HEADER_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED;
+ private static final int APP_METADATA_TAG =
+ (FlightData.APP_METADATA_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED;
+
+ private static final Marshaller<FlightData> NO_BODY_MARSHALLER =
+ ProtoUtils.marshaller(FlightData.getDefaultInstance());
+
+ /** Get the application-specific metadata in this message. The ArrowMessage retains ownership of the buffer. */
+ public ArrowBuf getApplicationMetadata() {
+ return appMetadata;
+ }
+
+ /** Types of messages that can be sent. */
+ public enum HeaderType {
+ NONE,
+ SCHEMA,
+ DICTIONARY_BATCH,
+ RECORD_BATCH,
+ TENSOR
+ ;
+
+ public static HeaderType getHeader(byte b) {
+ switch (b) {
+ case 0: return NONE;
+ case 1: return SCHEMA;
+ case 2: return DICTIONARY_BATCH;
+ case 3: return RECORD_BATCH;
+ case 4: return TENSOR;
+ default:
+ throw new UnsupportedOperationException("unknown type: " + b);
+ }
+ }
+
+ }
+
+ // Pre-allocated buffers for padding serialized ArrowMessages.
+ private static final List<ByteBuf> PADDING_BUFFERS = Arrays.asList(
+ null,
+ Unpooled.copiedBuffer(new byte[] { 0 }),
+ Unpooled.copiedBuffer(new byte[] { 0, 0 }),
+ Unpooled.copiedBuffer(new byte[] { 0, 0, 0 }),
+ Unpooled.copiedBuffer(new byte[] { 0, 0, 0, 0 }),
+ Unpooled.copiedBuffer(new byte[] { 0, 0, 0, 0, 0 }),
+ Unpooled.copiedBuffer(new byte[] { 0, 0, 0, 0, 0, 0 }),
+ Unpooled.copiedBuffer(new byte[] { 0, 0, 0, 0, 0, 0, 0 })
+ );
+
+ private final IpcOption writeOption;
+ private final FlightDescriptor descriptor;
+ private final MessageMetadataResult message;
+ private final ArrowBuf appMetadata;
+ private final List<ArrowBuf> bufs;
+ private final ArrowBodyCompression bodyCompression;
+ private final boolean tryZeroCopyWrite;
+
+ public ArrowMessage(FlightDescriptor descriptor, Schema schema, IpcOption option) {
+ this.writeOption = option;
+ ByteBuffer serializedMessage = MessageSerializer.serializeMetadata(schema, writeOption);
+ this.message = MessageMetadataResult.create(serializedMessage.slice(),
+ serializedMessage.remaining());
+ bufs = ImmutableList.of();
+ this.descriptor = descriptor;
+ this.appMetadata = null;
+ this.bodyCompression = NoCompressionCodec.DEFAULT_BODY_COMPRESSION;
+ this.tryZeroCopyWrite = false;
+ }
+
+ /**
+ * Create an ArrowMessage from a record batch and app metadata.
+ * @param batch The record batch.
+ * @param appMetadata The app metadata. May be null. Takes ownership of the buffer otherwise.
+ * @param tryZeroCopy Whether to enable the zero-copy optimization.
+ */
+ public ArrowMessage(ArrowRecordBatch batch, ArrowBuf appMetadata, boolean tryZeroCopy, IpcOption option) {
+ this.writeOption = option;
+ ByteBuffer serializedMessage = MessageSerializer.serializeMetadata(batch, writeOption);
+ this.message = MessageMetadataResult.create(serializedMessage.slice(), serializedMessage.remaining());
+ this.bufs = ImmutableList.copyOf(batch.getBuffers());
+ this.descriptor = null;
+ this.appMetadata = appMetadata;
+ this.bodyCompression = batch.getBodyCompression();
+ this.tryZeroCopyWrite = tryZeroCopy;
+ }
+
+ public ArrowMessage(ArrowDictionaryBatch batch, IpcOption option) {
+ this.writeOption = option;
+ ByteBuffer serializedMessage = MessageSerializer.serializeMetadata(batch, writeOption);
+ serializedMessage = serializedMessage.slice();
+ this.message = MessageMetadataResult.create(serializedMessage, serializedMessage.remaining());
+ // asInputStream will free the buffers implicitly, so increment the reference count
+ batch.getDictionary().getBuffers().forEach(buf -> buf.getReferenceManager().retain());
+ this.bufs = ImmutableList.copyOf(batch.getDictionary().getBuffers());
+ this.descriptor = null;
+ this.appMetadata = null;
+ this.bodyCompression = batch.getDictionary().getBodyCompression();
+ this.tryZeroCopyWrite = false;
+ }
+
+ /**
+ * Create an ArrowMessage containing only application metadata.
+ * @param appMetadata The application-provided metadata buffer.
+ */
+ public ArrowMessage(ArrowBuf appMetadata) {
+ // No need to take IpcOption as it's not used to serialize this kind of message.
+ this.writeOption = IpcOption.DEFAULT;
+ this.message = null;
+ this.bufs = ImmutableList.of();
+ this.descriptor = null;
+ this.appMetadata = appMetadata;
+ this.bodyCompression = NoCompressionCodec.DEFAULT_BODY_COMPRESSION;
+ this.tryZeroCopyWrite = false;
+ }
+
+ public ArrowMessage(FlightDescriptor descriptor) {
+ // No need to take IpcOption as it's not used to serialize this kind of message.
+ this.writeOption = IpcOption.DEFAULT;
+ this.message = null;
+ this.bufs = ImmutableList.of();
+ this.descriptor = descriptor;
+ this.appMetadata = null;
+ this.bodyCompression = NoCompressionCodec.DEFAULT_BODY_COMPRESSION;
+ this.tryZeroCopyWrite = false;
+ }
+
+ private ArrowMessage(FlightDescriptor descriptor, MessageMetadataResult message, ArrowBuf appMetadata,
+ ArrowBuf buf) {
+ // No need to take IpcOption as this is used for deserialized ArrowMessage coming from the wire.
+ this.writeOption = message != null ?
+ // avoid writing legacy ipc format by default
+ new IpcOption(false, MetadataVersion.fromFlatbufID(message.getMessage().version())) :
+ IpcOption.DEFAULT;
+ this.message = message;
+ this.descriptor = descriptor;
+ this.appMetadata = appMetadata;
+ this.bufs = buf == null ? ImmutableList.of() : ImmutableList.of(buf);
+ this.bodyCompression = NoCompressionCodec.DEFAULT_BODY_COMPRESSION;
+ this.tryZeroCopyWrite = false;
+ }
+
+ public MessageMetadataResult asSchemaMessage() {
+ return message;
+ }
+
+ public FlightDescriptor getDescriptor() {
+ return descriptor;
+ }
+
+ public HeaderType getMessageType() {
+ if (message == null) {
+ // Null message occurs for metadata-only messages (in DoExchange)
+ return HeaderType.NONE;
+ }
+ return HeaderType.getHeader(message.headerType());
+ }
+
+ public Schema asSchema() {
+ Preconditions.checkArgument(bufs.size() == 0);
+ Preconditions.checkArgument(getMessageType() == HeaderType.SCHEMA);
+ return MessageSerializer.deserializeSchema(message);
+ }
+
+ public ArrowRecordBatch asRecordBatch() throws IOException {
+ Preconditions.checkArgument(bufs.size() == 1, "A batch can only be consumed if it contains a single ArrowBuf.");
+ Preconditions.checkArgument(getMessageType() == HeaderType.RECORD_BATCH);
+
+ ArrowBuf underlying = bufs.get(0);
+
+ underlying.getReferenceManager().retain();
+ return MessageSerializer.deserializeRecordBatch(message, underlying);
+ }
+
+ public ArrowDictionaryBatch asDictionaryBatch() throws IOException {
+ Preconditions.checkArgument(bufs.size() == 1, "A batch can only be consumed if it contains a single ArrowBuf.");
+ Preconditions.checkArgument(getMessageType() == HeaderType.DICTIONARY_BATCH);
+ ArrowBuf underlying = bufs.get(0);
+ // Retain a reference to keep the batch alive when the message is closed
+ underlying.getReferenceManager().retain();
+ // Do not set drained - we still want to release our reference
+ return MessageSerializer.deserializeDictionaryBatch(message, underlying);
+ }
+
+ public Iterable<ArrowBuf> getBufs() {
+ return Iterables.unmodifiableIterable(bufs);
+ }
+
+ private static ArrowMessage frame(BufferAllocator allocator, final InputStream stream) {
+
+ try {
+ FlightDescriptor descriptor = null;
+ MessageMetadataResult header = null;
+ ArrowBuf body = null;
+ ArrowBuf appMetadata = null;
+ while (stream.available() > 0) {
+ int tag = readRawVarint32(stream);
+ switch (tag) {
+
+ case DESCRIPTOR_TAG: {
+ int size = readRawVarint32(stream);
+ byte[] bytes = new byte[size];
+ ByteStreams.readFully(stream, bytes);
+ descriptor = FlightDescriptor.parseFrom(bytes);
+ break;
+ }
+ case HEADER_TAG: {
+ int size = readRawVarint32(stream);
+ byte[] bytes = new byte[size];
+ ByteStreams.readFully(stream, bytes);
+ header = MessageMetadataResult.create(ByteBuffer.wrap(bytes), size);
+ break;
+ }
+ case APP_METADATA_TAG: {
+ int size = readRawVarint32(stream);
+ appMetadata = allocator.buffer(size);
+ GetReadableBuffer.readIntoBuffer(stream, appMetadata, size, ENABLE_ZERO_COPY_READ);
+ break;
+ }
+ case BODY_TAG:
+ if (body != null) {
+ // only read last body.
+ body.getReferenceManager().release();
+ body = null;
+ }
+ int size = readRawVarint32(stream);
+ body = allocator.buffer(size);
+ GetReadableBuffer.readIntoBuffer(stream, body, size, ENABLE_ZERO_COPY_READ);
+ break;
+
+ default:
+ // ignore unknown fields.
+ }
+ }
+ // Protobuf implementations can omit empty fields, such as body; for some message types, like RecordBatch,
+ // this will fail later as we still expect an empty buffer. In those cases only, fill in an empty buffer here -
+ // in other cases, like Schema, having an unexpected empty buffer will also cause failures.
+ // We don't fill in defaults for fields like header, for which there is no reasonable default, or for appMetadata
+ // or descriptor, which are intended to be empty in some cases.
+ if (header != null) {
+ switch (HeaderType.getHeader(header.headerType())) {
+ case SCHEMA:
+ // Ignore 0-length buffers in case a Protobuf implementation wrote it out
+ if (body != null && body.capacity() == 0) {
+ body.close();
+ body = null;
+ }
+ break;
+ case DICTIONARY_BATCH:
+ case RECORD_BATCH:
+ // A Protobuf implementation can skip 0-length bodies, so ensure we fill it in here
+ if (body == null) {
+ body = allocator.getEmpty();
+ }
+ break;
+ case NONE:
+ case TENSOR:
+ default:
+ // Do nothing
+ break;
+ }
+ }
+ return new ArrowMessage(descriptor, header, appMetadata, body);
+ } catch (Exception ioe) {
+ throw new RuntimeException(ioe);
+ }
+
+ }
+
+ private static int readRawVarint32(InputStream is) throws IOException {
+ int firstByte = is.read();
+ return CodedInputStream.readRawVarint32(firstByte, is);
+ }
+
+ /**
+ * Convert the ArrowMessage to an InputStream.
+ *
+ * <p>Implicitly, this transfers ownership of the contained buffers to the InputStream.
+ *
+ * @return InputStream
+ */
+ private InputStream asInputStream(BufferAllocator allocator) {
+ if (message == null) {
+ // If we have no IPC message, it's a pure-metadata message
+ final FlightData.Builder builder = FlightData.newBuilder();
+ if (descriptor != null) {
+ builder.setFlightDescriptor(descriptor);
+ }
+ if (appMetadata != null) {
+ builder.setAppMetadata(ByteString.copyFrom(appMetadata.nioBuffer()));
+ }
+ return NO_BODY_MARSHALLER.stream(builder.build());
+ }
+
+ try {
+ final ByteString bytes = ByteString.copyFrom(message.getMessageBuffer(),
+ message.bytesAfterMessage());
+
+ if (getMessageType() == HeaderType.SCHEMA) {
+
+ final FlightData.Builder builder = FlightData.newBuilder()
+ .setDataHeader(bytes);
+
+ if (descriptor != null) {
+ builder.setFlightDescriptor(descriptor);
+ }
+
+ Preconditions.checkArgument(bufs.isEmpty());
+ return NO_BODY_MARSHALLER.stream(builder.build());
+ }
+
+ Preconditions.checkArgument(getMessageType() == HeaderType.RECORD_BATCH ||
+ getMessageType() == HeaderType.DICTIONARY_BATCH);
+ // There may be no buffers in the case that we write only a null array
+ Preconditions.checkArgument(descriptor == null, "Descriptor should only be included in the schema message.");
+
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ CodedOutputStream cos = CodedOutputStream.newInstance(baos);
+ cos.writeBytes(FlightData.DATA_HEADER_FIELD_NUMBER, bytes);
+
+ if (appMetadata != null && appMetadata.capacity() > 0) {
+ // Must call slice() as CodedOutputStream#writeByteBuffer writes -capacity- bytes, not -limit- bytes
+ cos.writeByteBuffer(FlightData.APP_METADATA_FIELD_NUMBER, appMetadata.nioBuffer().slice());
+ }
+
+ cos.writeTag(FlightData.DATA_BODY_FIELD_NUMBER, WireFormat.WIRETYPE_LENGTH_DELIMITED);
+ int size = 0;
+ List<ByteBuf> allBufs = new ArrayList<>();
+ for (ArrowBuf b : bufs) {
+ // [ARROW-11066] This creates a Netty buffer whose refcnt is INDEPENDENT of the backing
+ // Arrow buffer. This is susceptible to use-after-free, so we subclass CompositeByteBuf
+ // below to tie the Arrow buffer refcnt to the Netty buffer refcnt
+ allBufs.add(Unpooled.wrappedBuffer(b.nioBuffer()).retain());
+ size += b.readableBytes();
+ // [ARROW-4213] These buffers must be aligned to an 8-byte boundary in order to be readable from C++.
+ if (b.readableBytes() % 8 != 0) {
+ int paddingBytes = (int) (8 - (b.readableBytes() % 8));
+ assert paddingBytes > 0 && paddingBytes < 8;
+ size += paddingBytes;
+ allBufs.add(PADDING_BUFFERS.get(paddingBytes).retain());
+ }
+ }
+ // rawvarint is used for length definition.
+ cos.writeUInt32NoTag(size);
+ cos.flush();
+
+ ByteBuf initialBuf = Unpooled.buffer(baos.size());
+ initialBuf.writeBytes(baos.toByteArray());
+ final CompositeByteBuf bb;
+ final int maxNumComponents = Math.max(2, bufs.size() + 1);
+ final ImmutableList<ByteBuf> byteBufs = ImmutableList.<ByteBuf>builder()
+ .add(initialBuf)
+ .addAll(allBufs)
+ .build();
+ if (tryZeroCopyWrite) {
+ bb = new ArrowBufRetainingCompositeByteBuf(maxNumComponents, byteBufs, bufs);
+ } else {
+ // Don't retain the buffers in the non-zero-copy path since we're copying them
+ bb = new CompositeByteBuf(UnpooledByteBufAllocator.DEFAULT, /* direct */ true, maxNumComponents, byteBufs);
+ }
+ return new DrainableByteBufInputStream(bb, tryZeroCopyWrite);
+ } catch (Exception ex) {
+ throw new RuntimeException("Unexpected IO Exception", ex);
+ }
+
+ }
+
+ /**
+ * ARROW-11066: enable the zero-copy optimization and protect against use-after-free.
+ *
+ * When you send a message through gRPC, the following happens:
+ * 1. gRPC immediately serializes the message, eventually calling asInputStream above.
+ * 2. gRPC buffers the serialized message for sending.
+ * 3. Later, gRPC will actually write out the message.
+ *
+ * The problem with this is that when the zero-copy optimization is enabled, Flight
+ * "serializes" the message by handing gRPC references to Arrow data. That means we need
+ * a way to keep the Arrow buffers valid until gRPC actually writes them, else, we'll read
+ * invalid data or segfault. gRPC doesn't know anything about Arrow buffers, either.
+ *
+ * This class solves that issue by bridging Arrow and Netty/gRPC. We increment the refcnt
+ * on a set of Arrow backing buffers and decrement them once the Netty buffers are freed
+ * by gRPC.
+ */
+ private static final class ArrowBufRetainingCompositeByteBuf extends CompositeByteBuf {
+ // Arrow buffers that back the Netty ByteBufs here; ByteBufs held by this class are
+ // either slices of one of the ArrowBufs or independently allocated.
+ final List<ArrowBuf> backingBuffers;
+ boolean freed;
+
+ ArrowBufRetainingCompositeByteBuf(int maxNumComponents, Iterable<ByteBuf> buffers, List<ArrowBuf> backingBuffers) {
+ super(UnpooledByteBufAllocator.DEFAULT, /* direct */ true, maxNumComponents, buffers);
+ this.backingBuffers = backingBuffers;
+ this.freed = false;
+ // N.B. the Netty superclass avoids enhanced-for to reduce GC pressure, so follow that here
+ for (int i = 0; i < backingBuffers.size(); i++) {
+ backingBuffers.get(i).getReferenceManager().retain();
+ }
+ }
+
+ @Override
+ protected void deallocate() {
+ super.deallocate();
+ if (freed) {
+ return;
+ }
+ freed = true;
+ for (int i = 0; i < backingBuffers.size(); i++) {
+ backingBuffers.get(i).getReferenceManager().release();
+ }
+ }
+ }
+
+ private static class DrainableByteBufInputStream extends ByteBufInputStream implements Drainable {
+
+ private final CompositeByteBuf buf;
+ private final boolean isZeroCopy;
+
+ public DrainableByteBufInputStream(CompositeByteBuf buffer, boolean isZeroCopy) {
+ super(buffer, buffer.readableBytes(), true);
+ this.buf = buffer;
+ this.isZeroCopy = isZeroCopy;
+ }
+
+ @Override
+ public int drainTo(OutputStream target) throws IOException {
+ int size = buf.readableBytes();
+ AddWritableBuffer.add(buf, target, isZeroCopy);
+ return size;
+ }
+
+ @Override
+ public void close() {
+ buf.release();
+ }
+
+
+
+ }
+
+ public static Marshaller<ArrowMessage> createMarshaller(BufferAllocator allocator) {
+ return new ArrowMessageHolderMarshaller(allocator);
+ }
+
+ private static class ArrowMessageHolderMarshaller implements Marshaller<ArrowMessage> {
+
+ private final BufferAllocator allocator;
+
+ public ArrowMessageHolderMarshaller(BufferAllocator allocator) {
+ this.allocator = allocator;
+ }
+
+ @Override
+ public InputStream stream(ArrowMessage value) {
+ return value.asInputStream(allocator);
+ }
+
+ @Override
+ public ArrowMessage parse(InputStream stream) {
+ return ArrowMessage.frame(allocator, stream);
+ }
+
+ }
+
+ @Override
+ public void close() throws Exception {
+ AutoCloseables.close(Iterables.concat(bufs, Collections.singletonList(appMetadata)));
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/AsyncPutListener.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/AsyncPutListener.java
new file mode 100644
index 000000000..a45463225
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/AsyncPutListener.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+
+import org.apache.arrow.flight.grpc.StatusUtils;
+
+/**
+ * A handler for server-sent application metadata messages during a Flight DoPut operation.
+ *
+ * <p>To handle messages, create an instance of this class overriding {@link #onNext(PutResult)}. The other methods
+ * should not be overridden.
+ */
+public class AsyncPutListener implements FlightClient.PutListener {
+
+ private CompletableFuture<Void> completed;
+
+ public AsyncPutListener() {
+ completed = new CompletableFuture<>();
+ }
+
+ /**
+ * Wait for the stream to finish on the server side. You must call this to be notified of any errors that may have
+ * happened during the upload.
+ */
+ @Override
+ public final void getResult() {
+ try {
+ completed.get();
+ } catch (ExecutionException e) {
+ throw StatusUtils.fromThrowable(e.getCause());
+ } catch (InterruptedException e) {
+ throw StatusUtils.fromThrowable(e);
+ }
+ }
+
+ @Override
+ public void onNext(PutResult val) {
+ }
+
+ @Override
+ public final void onError(Throwable t) {
+ completed.completeExceptionally(StatusUtils.fromThrowable(t));
+ }
+
+ @Override
+ public final void onCompleted() {
+ completed.complete(null);
+ }
+
+ @Override
+ public boolean isCancelled() {
+ return completed.isDone();
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/BackpressureStrategy.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/BackpressureStrategy.java
new file mode 100644
index 000000000..de34643a7
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/BackpressureStrategy.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import org.apache.arrow.vector.VectorSchemaRoot;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Helper interface to dynamically handle backpressure when implementing FlightProducers.
+ * This must only be used in FlightProducer implementations that are non-blocking.
+ */
+public interface BackpressureStrategy {
+ /**
+ * The state of the client after a call to waitForListener.
+ */
+ enum WaitResult {
+ /**
+ * Listener is ready.
+ */
+ READY,
+
+ /**
+ * Listener was cancelled by the client.
+ */
+ CANCELLED,
+
+ /**
+ * Timed out waiting for the listener to change state.
+ */
+ TIMEOUT,
+
+ /**
+ * Indicates that the wait was interrupted for a reason
+ * unrelated to the listener itself.
+ */
+ OTHER
+ }
+
+ /**
+ * Set up operations to work against the given listener.
+ *
+ * This must be called exactly once and before any calls to {@link #waitForListener(long)} and
+ * {@link OutboundStreamListener#start(VectorSchemaRoot)}
+ * @param listener The listener this strategy applies to.
+ */
+ void register(FlightProducer.ServerStreamListener listener);
+
+ /**
+ * Waits for the listener to be ready or cancelled up to the given timeout.
+ *
+ * @param timeout The timeout in milliseconds. Infinite if timeout is <= 0.
+ * @return The result of the wait.
+ */
+ WaitResult waitForListener(long timeout);
+
+ /**
+ * A back pressure strategy that uses callbacks to notify when the client is ready or cancelled.
+ */
+ class CallbackBackpressureStrategy implements BackpressureStrategy {
+ private final Object lock = new Object();
+ private FlightProducer.ServerStreamListener listener;
+
+ @Override
+ public void register(FlightProducer.ServerStreamListener listener) {
+ this.listener = listener;
+ listener.setOnReadyHandler(this::onReady);
+ listener.setOnCancelHandler(this::onCancel);
+ }
+
+ @Override
+ public WaitResult waitForListener(long timeout) {
+ Preconditions.checkNotNull(listener);
+ long remainingTimeout = timeout;
+ final long startTime = System.currentTimeMillis();
+ synchronized (lock) {
+ while (!listener.isReady() && !listener.isCancelled()) {
+ try {
+ lock.wait(remainingTimeout);
+ if (timeout != 0) { // If timeout was zero explicitly, we should never report timeout.
+ remainingTimeout = startTime + timeout - System.currentTimeMillis();
+ if (remainingTimeout <= 0) {
+ return WaitResult.TIMEOUT;
+ }
+ }
+ if (!shouldContinueWaiting(listener, remainingTimeout)) {
+ return WaitResult.OTHER;
+ }
+ } catch (InterruptedException ex) {
+ Thread.currentThread().interrupt();
+ return WaitResult.OTHER;
+ }
+ }
+
+ if (listener.isReady()) {
+ return WaitResult.READY;
+ } else if (listener.isCancelled()) {
+ return WaitResult.CANCELLED;
+ } else if (System.currentTimeMillis() > startTime + timeout) {
+ return WaitResult.TIMEOUT;
+ }
+ throw new RuntimeException("Invalid state when waiting for listener.");
+ }
+ }
+
+ /**
+ * Interrupt waiting on the listener to change state.
+ *
+ * This method can be used in conjunction with
+ * {@link #shouldContinueWaiting(FlightProducer.ServerStreamListener, long)} to allow FlightProducers to
+ * terminate streams internally and notify clients.
+ */
+ public void interruptWait() {
+ synchronized (lock) {
+ lock.notifyAll();
+ }
+ }
+
+ /**
+ * Callback function to run to check if the listener should continue
+ * to be waited on if it leaves the waiting state without being cancelled,
+ * ready, or timed out.
+ *
+ * This method should be used to determine if the wait on the listener was interrupted explicitly using a
+ * call to {@link #interruptWait()} or if it was woken up due to a spurious wake.
+ */
+ protected boolean shouldContinueWaiting(FlightProducer.ServerStreamListener listener, long remainingTimeout) {
+ return true;
+ }
+
+ /**
+ * Callback to execute when the listener becomes ready.
+ */
+ protected void readyCallback() {
+ }
+
+ /**
+ * Callback to execute when the listener is cancelled.
+ */
+ protected void cancelCallback() {
+ }
+
+ private void onReady() {
+ synchronized (lock) {
+ readyCallback();
+ lock.notifyAll();
+ }
+ }
+
+ private void onCancel() {
+ synchronized (lock) {
+ cancelCallback();
+ lock.notifyAll();
+ }
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallHeaders.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallHeaders.java
new file mode 100644
index 000000000..32f9a8430
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallHeaders.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.Set;
+
+/**
+ * A set of metadata key value pairs for a call (request or response).
+ */
+public interface CallHeaders {
+ /**
+ * Get the value of a metadata key. If multiple values are present, then get the last one.
+ */
+ String get(String key);
+
+ /**
+ * Get the value of a metadata key. If multiple values are present, then get the last one.
+ */
+ byte[] getByte(String key);
+
+ /**
+ * Get all values present for the given metadata key.
+ */
+ Iterable<String> getAll(String key);
+
+ /**
+ * Get all values present for the given metadata key.
+ */
+ Iterable<byte[]> getAllByte(String key);
+
+ /**
+ * Insert a metadata pair with the given value.
+ *
+ * <p>Duplicate metadata are permitted.
+ */
+ void insert(String key, String value);
+
+ /**
+ * Insert a metadata pair with the given value.
+ *
+ * <p>Duplicate metadata are permitted.
+ */
+ void insert(String key, byte[] value);
+
+ /** Get a set of all the metadata keys. */
+ Set<String> keys();
+
+ /** Check whether the given metadata key is present. */
+ boolean containsKey(String key);
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallInfo.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallInfo.java
new file mode 100644
index 000000000..744584bdf
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallInfo.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+/**
+ * A description of a Flight call for middleware to inspect.
+ */
+public final class CallInfo {
+ private final FlightMethod method;
+
+ public CallInfo(FlightMethod method) {
+ this.method = method;
+ }
+
+ public FlightMethod method() {
+ return method;
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallOption.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallOption.java
new file mode 100644
index 000000000..d3ee3ab4c
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallOption.java
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+/**
+ * Per-call RPC options. These are hints to the underlying RPC layer and may not be respected.
+ */
+public interface CallOption {
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallOptions.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallOptions.java
new file mode 100644
index 000000000..bbb4edef9
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallOptions.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.concurrent.TimeUnit;
+
+import io.grpc.stub.AbstractStub;
+
+/**
+ * Common call options.
+ */
+public class CallOptions {
+ public static CallOption timeout(long duration, TimeUnit unit) {
+ return new Timeout(duration, unit);
+ }
+
+ static <T extends AbstractStub<T>> T wrapStub(T stub, CallOption[] options) {
+ for (CallOption option : options) {
+ if (option instanceof GrpcCallOption) {
+ stub = ((GrpcCallOption) option).wrapStub(stub);
+ }
+ }
+ return stub;
+ }
+
+ private static class Timeout implements GrpcCallOption {
+ long timeout;
+ TimeUnit timeoutUnit;
+
+ Timeout(long timeout, TimeUnit timeoutUnit) {
+ this.timeout = timeout;
+ this.timeoutUnit = timeoutUnit;
+ }
+
+ @Override
+ public <T extends AbstractStub<T>> T wrapStub(T stub) {
+ return stub.withDeadlineAfter(timeout, timeoutUnit);
+ }
+ }
+
+ /**
+ * CallOptions specific to GRPC stubs.
+ */
+ public interface GrpcCallOption extends CallOption {
+ <T extends AbstractStub<T>> T wrapStub(T stub);
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallStatus.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallStatus.java
new file mode 100644
index 000000000..991d0ed6a
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallStatus.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.Objects;
+
+import org.apache.arrow.flight.FlightProducer.ServerStreamListener;
+import org.apache.arrow.flight.FlightProducer.StreamListener;
+
+/**
+ * The result of a Flight RPC, consisting of a status code with an optional description and/or exception that led
+ * to the status.
+ *
+ * <p>If raised or sent through {@link StreamListener#onError(Throwable)} or
+ * {@link ServerStreamListener#error(Throwable)}, the client call will raise the same error (a
+ * {@link FlightRuntimeException} with the same {@link FlightStatusCode} and description). The exception within, if
+ * present, will not be sent to the client.
+ */
+public class CallStatus {
+
+ private final FlightStatusCode code;
+ private final Throwable cause;
+ private final String description;
+ private final ErrorFlightMetadata metadata;
+
+ public static final CallStatus UNKNOWN = FlightStatusCode.UNKNOWN.toStatus();
+ public static final CallStatus INTERNAL = FlightStatusCode.INTERNAL.toStatus();
+ public static final CallStatus INVALID_ARGUMENT = FlightStatusCode.INVALID_ARGUMENT.toStatus();
+ public static final CallStatus TIMED_OUT = FlightStatusCode.TIMED_OUT.toStatus();
+ public static final CallStatus NOT_FOUND = FlightStatusCode.NOT_FOUND.toStatus();
+ public static final CallStatus ALREADY_EXISTS = FlightStatusCode.ALREADY_EXISTS.toStatus();
+ public static final CallStatus CANCELLED = FlightStatusCode.CANCELLED.toStatus();
+ public static final CallStatus UNAUTHENTICATED = FlightStatusCode.UNAUTHENTICATED.toStatus();
+ public static final CallStatus UNAUTHORIZED = FlightStatusCode.UNAUTHORIZED.toStatus();
+ public static final CallStatus UNIMPLEMENTED = FlightStatusCode.UNIMPLEMENTED.toStatus();
+ public static final CallStatus UNAVAILABLE = FlightStatusCode.UNAVAILABLE.toStatus();
+
+ /**
+ * Create a new status.
+ *
+ * @param code The status code.
+ * @param cause An exception that resulted in this status (or null).
+ * @param description A description of the status (or null).
+ */
+ public CallStatus(FlightStatusCode code, Throwable cause, String description, ErrorFlightMetadata metadata) {
+ this.code = Objects.requireNonNull(code);
+ this.cause = cause;
+ this.description = description == null ? "" : description;
+ this.metadata = metadata == null ? new ErrorFlightMetadata() : metadata;
+ }
+
+ /**
+ * Create a new status with no cause or description.
+ *
+ * @param code The status code.
+ */
+ public CallStatus(FlightStatusCode code) {
+ this(code, /* no cause */ null, /* no description */ null, /* no metadata */ null);
+ }
+
+ /**
+ * The status code describing the result of the RPC.
+ */
+ public FlightStatusCode code() {
+ return code;
+ }
+
+ /**
+ * The exception that led to this result. May be null.
+ */
+ public Throwable cause() {
+ return cause;
+ }
+
+ /**
+ * A description of the result.
+ */
+ public String description() {
+ return description;
+ }
+
+ /**
+ * Metadata associated with the exception.
+ *
+ * May be null.
+ */
+ public ErrorFlightMetadata metadata() {
+ return metadata;
+ }
+
+ /**
+ * Return a copy of this status with an error message.
+ */
+ public CallStatus withDescription(String message) {
+ return new CallStatus(code, cause, message, metadata);
+ }
+
+ /**
+ * Return a copy of this status with the given exception as the cause. This will not be sent over the wire.
+ */
+ public CallStatus withCause(Throwable t) {
+ return new CallStatus(code, t, description, metadata);
+ }
+
+ /**
+ * Return a copy of this status with associated exception metadata.
+ */
+ public CallStatus withMetadata(ErrorFlightMetadata metadata) {
+ return new CallStatus(code, cause, description, metadata);
+ }
+
+ /**
+ * Convert the status to an equivalent exception.
+ */
+ public FlightRuntimeException toRuntimeException() {
+ return new FlightRuntimeException(this);
+ }
+
+ @Override
+ public String toString() {
+ return "CallStatus{" +
+ "code=" + code +
+ ", cause=" + cause +
+ ", description='" + description +
+ "', metadata='" + metadata + '\'' +
+ '}';
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Criteria.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Criteria.java
new file mode 100644
index 000000000..989cd6581
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Criteria.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import org.apache.arrow.flight.impl.Flight;
+
+import com.google.protobuf.ByteString;
+
+/**
+ * An opaque object that can be used to filter a list of streams available from a server.
+ *
+ * <p>This is a POJO wrapper around the protobuf Criteria message.
+ */
+public class Criteria {
+
+ public static Criteria ALL = new Criteria((byte[]) null);
+
+ private final byte[] bytes;
+
+ public Criteria(byte[] bytes) {
+ this.bytes = bytes;
+ }
+
+ Criteria(Flight.Criteria criteria) {
+ this.bytes = criteria.getExpression().toByteArray();
+ }
+
+ /**
+ * Get the contained filter criteria.
+ */
+ public byte[] getExpression() {
+ return bytes;
+ }
+
+ Flight.Criteria asCriteria() {
+ Flight.Criteria.Builder b = Flight.Criteria.newBuilder();
+ if (bytes != null) {
+ b.setExpression(ByteString.copyFrom(bytes));
+ }
+
+ return b.build();
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java
new file mode 100644
index 000000000..516dab01d
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
+
+import org.apache.arrow.flight.impl.Flight;
+import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.VectorUnloader;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
+import org.apache.arrow.vector.ipc.message.IpcOption;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.arrow.vector.util.DictionaryUtility;
+import org.apache.arrow.vector.validate.MetadataV4UnionChecker;
+
+/**
+ * Utilities to work with dictionaries in Flight.
+ */
+final class DictionaryUtils {
+
+ private DictionaryUtils() {
+ throw new UnsupportedOperationException("Do not instantiate this class.");
+ }
+
+ /**
+ * Generate all the necessary Flight messages to send a schema and associated dictionaries.
+ *
+ * @throws Exception if there was an error closing {@link ArrowMessage} objects. This is not generally expected.
+ */
+ static Schema generateSchemaMessages(final Schema originalSchema, final FlightDescriptor descriptor,
+ final DictionaryProvider provider, final IpcOption option,
+ final Consumer<ArrowMessage> messageCallback) throws Exception {
+ final Set<Long> dictionaryIds = new HashSet<>();
+ final Schema schema = generateSchema(originalSchema, provider, dictionaryIds);
+ MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option.metadataVersion);
+ // Send the schema message
+ final Flight.FlightDescriptor protoDescriptor = descriptor == null ? null : descriptor.toProtocol();
+ try (final ArrowMessage message = new ArrowMessage(protoDescriptor, schema, option)) {
+ messageCallback.accept(message);
+ }
+ // Create and write dictionary batches
+ for (Long id : dictionaryIds) {
+ final Dictionary dictionary = provider.lookup(id);
+ final FieldVector vector = dictionary.getVector();
+ final int count = vector.getValueCount();
+ // Do NOT close this root, as it does not actually own the vector.
+ final VectorSchemaRoot dictRoot = new VectorSchemaRoot(
+ Collections.singletonList(vector.getField()),
+ Collections.singletonList(vector),
+ count);
+ final VectorUnloader unloader = new VectorUnloader(dictRoot);
+ try (final ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch(
+ id, unloader.getRecordBatch());
+ final ArrowMessage message = new ArrowMessage(dictionaryBatch, option)) {
+ messageCallback.accept(message);
+ }
+ }
+ return schema;
+ }
+
+ static void closeDictionaries(final Schema schema, final DictionaryProvider provider) throws Exception {
+ // Close dictionaries
+ final Set<Long> dictionaryIds = new HashSet<>();
+ schema.getFields().forEach(field -> DictionaryUtility.toMessageFormat(field, provider, dictionaryIds));
+
+ final List<AutoCloseable> dictionaryVectors = dictionaryIds.stream()
+ .map(id -> (AutoCloseable) provider.lookup(id).getVector()).collect(Collectors.toList());
+ AutoCloseables.close(dictionaryVectors);
+ }
+
+ /**
+ * Generates the schema to send with flight messages.
+ * If the schema contains no field with a dictionary, it will return the schema as is.
+ * Otherwise, it will return a newly created a new schema after converting the fields.
+ * @param originalSchema the original schema.
+ * @param provider the dictionary provider.
+ * @param dictionaryIds dictionary IDs that are used.
+ * @return the schema to send with the flight messages.
+ */
+ static Schema generateSchema(
+ final Schema originalSchema, final DictionaryProvider provider, Set<Long> dictionaryIds) {
+ // first determine if a new schema needs to be created.
+ boolean createSchema = false;
+ for (Field field : originalSchema.getFields()) {
+ if (DictionaryUtility.needConvertToMessageFormat(field)) {
+ createSchema = true;
+ break;
+ }
+ }
+
+ if (!createSchema) {
+ return originalSchema;
+ } else {
+ final List<Field> fields = new ArrayList<>(originalSchema.getFields().size());
+ for (final Field field : originalSchema.getFields()) {
+ fields.add(DictionaryUtility.toMessageFormat(field, provider, dictionaryIds));
+ }
+ return new Schema(fields, originalSchema.getCustomMetadata());
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ErrorFlightMetadata.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ErrorFlightMetadata.java
new file mode 100644
index 000000000..6669ce465
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ErrorFlightMetadata.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.nio.charset.StandardCharsets;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.StreamSupport;
+
+import com.google.common.collect.Iterables;
+import com.google.common.collect.LinkedListMultimap;
+import com.google.common.collect.Multimap;
+
+/**
+ * metadata container specific to the binary metadata held in the grpc trailer.
+ */
+public class ErrorFlightMetadata implements CallHeaders {
+ private final Multimap<String, byte[]> metadata = LinkedListMultimap.create();
+
+ public ErrorFlightMetadata() {
+ }
+
+
+ @Override
+ public String get(String key) {
+ return new String(getByte(key), StandardCharsets.US_ASCII);
+ }
+
+ @Override
+ public byte[] getByte(String key) {
+ return Iterables.getLast(metadata.get(key));
+ }
+
+ @Override
+ public Iterable<String> getAll(String key) {
+ return StreamSupport.stream(
+ getAllByte(key).spliterator(), false)
+ .map(b -> new String(b, StandardCharsets.US_ASCII))
+ .collect(Collectors.toList());
+ }
+
+ @Override
+ public Iterable<byte[]> getAllByte(String key) {
+ return metadata.get(key);
+ }
+
+ @Override
+ public void insert(String key, String value) {
+ metadata.put(key, value.getBytes());
+ }
+
+ @Override
+ public void insert(String key, byte[] value) {
+ metadata.put(key, value);
+ }
+
+ @Override
+ public Set<String> keys() {
+ return metadata.keySet();
+ }
+
+ @Override
+ public boolean containsKey(String key) {
+ return metadata.containsKey(key);
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightBindingService.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightBindingService.java
new file mode 100644
index 000000000..ba5249b4a
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightBindingService.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.Set;
+import java.util.concurrent.ExecutorService;
+
+import org.apache.arrow.flight.auth.ServerAuthHandler;
+import org.apache.arrow.flight.impl.Flight;
+import org.apache.arrow.flight.impl.Flight.PutResult;
+import org.apache.arrow.flight.impl.FlightServiceGrpc;
+import org.apache.arrow.memory.BufferAllocator;
+
+import com.google.common.collect.ImmutableSet;
+
+import io.grpc.BindableService;
+import io.grpc.MethodDescriptor;
+import io.grpc.MethodDescriptor.MethodType;
+import io.grpc.ServerMethodDefinition;
+import io.grpc.ServerServiceDefinition;
+import io.grpc.ServiceDescriptor;
+import io.grpc.protobuf.ProtoUtils;
+import io.grpc.stub.ServerCalls;
+import io.grpc.stub.StreamObserver;
+
+/**
+ * Extends the basic flight service to override some methods for more efficient implementations.
+ */
+class FlightBindingService implements BindableService {
+
+ private static final String DO_GET = MethodDescriptor.generateFullMethodName(FlightConstants.SERVICE, "DoGet");
+ private static final String DO_PUT = MethodDescriptor.generateFullMethodName(FlightConstants.SERVICE, "DoPut");
+ private static final String DO_EXCHANGE = MethodDescriptor.generateFullMethodName(
+ FlightConstants.SERVICE, "DoExchange");
+ private static final Set<String> OVERRIDE_METHODS = ImmutableSet.of(DO_GET, DO_PUT, DO_EXCHANGE);
+
+ private final FlightService delegate;
+ private final BufferAllocator allocator;
+
+ public FlightBindingService(BufferAllocator allocator, FlightProducer producer,
+ ServerAuthHandler authHandler, ExecutorService executor) {
+ this.allocator = allocator;
+ this.delegate = new FlightService(allocator, producer, authHandler, executor);
+ }
+
+ public static MethodDescriptor<Flight.Ticket, ArrowMessage> getDoGetDescriptor(BufferAllocator allocator) {
+ return MethodDescriptor.<Flight.Ticket, ArrowMessage>newBuilder()
+ .setType(io.grpc.MethodDescriptor.MethodType.SERVER_STREAMING)
+ .setFullMethodName(DO_GET)
+ .setSampledToLocalTracing(false)
+ .setRequestMarshaller(ProtoUtils.marshaller(Flight.Ticket.getDefaultInstance()))
+ .setResponseMarshaller(ArrowMessage.createMarshaller(allocator))
+ .setSchemaDescriptor(FlightServiceGrpc.getDoGetMethod().getSchemaDescriptor())
+ .build();
+ }
+
+ public static MethodDescriptor<ArrowMessage, Flight.PutResult> getDoPutDescriptor(BufferAllocator allocator) {
+ return MethodDescriptor.<ArrowMessage, Flight.PutResult>newBuilder()
+ .setType(MethodType.BIDI_STREAMING)
+ .setFullMethodName(DO_PUT)
+ .setSampledToLocalTracing(false)
+ .setRequestMarshaller(ArrowMessage.createMarshaller(allocator))
+ .setResponseMarshaller(ProtoUtils.marshaller(Flight.PutResult.getDefaultInstance()))
+ .setSchemaDescriptor(FlightServiceGrpc.getDoPutMethod().getSchemaDescriptor())
+ .build();
+ }
+
+ public static MethodDescriptor<ArrowMessage, ArrowMessage> getDoExchangeDescriptor(BufferAllocator allocator) {
+ return MethodDescriptor.<ArrowMessage, ArrowMessage>newBuilder()
+ .setType(MethodType.BIDI_STREAMING)
+ .setFullMethodName(DO_EXCHANGE)
+ .setSampledToLocalTracing(false)
+ .setRequestMarshaller(ArrowMessage.createMarshaller(allocator))
+ .setResponseMarshaller(ArrowMessage.createMarshaller(allocator))
+ .setSchemaDescriptor(FlightServiceGrpc.getDoExchangeMethod().getSchemaDescriptor())
+ .build();
+ }
+
+ @Override
+ public ServerServiceDefinition bindService() {
+ final ServerServiceDefinition baseDefinition = delegate.bindService();
+
+ final MethodDescriptor<Flight.Ticket, ArrowMessage> doGetDescriptor = getDoGetDescriptor(allocator);
+ final MethodDescriptor<ArrowMessage, Flight.PutResult> doPutDescriptor = getDoPutDescriptor(allocator);
+ final MethodDescriptor<ArrowMessage, ArrowMessage> doExchangeDescriptor = getDoExchangeDescriptor(allocator);
+
+ // Make sure we preserve SchemaDescriptor fields on methods so that gRPC reflection still works.
+ final ServiceDescriptor.Builder serviceDescriptorBuilder = ServiceDescriptor.newBuilder(FlightConstants.SERVICE)
+ .setSchemaDescriptor(baseDefinition.getServiceDescriptor().getSchemaDescriptor());
+ serviceDescriptorBuilder.addMethod(doGetDescriptor);
+ serviceDescriptorBuilder.addMethod(doPutDescriptor);
+ serviceDescriptorBuilder.addMethod(doExchangeDescriptor);
+ for (MethodDescriptor<?, ?> definition : baseDefinition.getServiceDescriptor().getMethods()) {
+ if (OVERRIDE_METHODS.contains(definition.getFullMethodName())) {
+ continue;
+ }
+
+ serviceDescriptorBuilder.addMethod(definition);
+ }
+
+ final ServiceDescriptor serviceDescriptor = serviceDescriptorBuilder.build();
+ ServerServiceDefinition.Builder serviceBuilder = ServerServiceDefinition.builder(serviceDescriptor);
+ serviceBuilder.addMethod(doGetDescriptor, ServerCalls.asyncServerStreamingCall(new DoGetMethod(delegate)));
+ serviceBuilder.addMethod(doPutDescriptor, ServerCalls.asyncBidiStreamingCall(new DoPutMethod(delegate)));
+ serviceBuilder.addMethod(doExchangeDescriptor, ServerCalls.asyncBidiStreamingCall(new DoExchangeMethod(delegate)));
+
+ // copy over not-overridden methods.
+ for (ServerMethodDefinition<?, ?> definition : baseDefinition.getMethods()) {
+ if (OVERRIDE_METHODS.contains(definition.getMethodDescriptor().getFullMethodName())) {
+ continue;
+ }
+
+ serviceBuilder.addMethod(definition);
+ }
+
+ return serviceBuilder.build();
+ }
+
+ private static class DoGetMethod implements ServerCalls.ServerStreamingMethod<Flight.Ticket, ArrowMessage> {
+
+ private final FlightService delegate;
+
+ public DoGetMethod(FlightService delegate) {
+ this.delegate = delegate;
+ }
+
+ @Override
+ public void invoke(Flight.Ticket request, StreamObserver<ArrowMessage> responseObserver) {
+ delegate.doGetCustom(request, responseObserver);
+ }
+ }
+
+ private static class DoPutMethod implements ServerCalls.BidiStreamingMethod<ArrowMessage, PutResult> {
+ private final FlightService delegate;
+
+ public DoPutMethod(FlightService delegate) {
+ this.delegate = delegate;
+ }
+
+ @Override
+ public StreamObserver<ArrowMessage> invoke(StreamObserver<PutResult> responseObserver) {
+ return delegate.doPutCustom(responseObserver);
+ }
+ }
+
+ private static class DoExchangeMethod implements ServerCalls.BidiStreamingMethod<ArrowMessage, ArrowMessage> {
+ private final FlightService delegate;
+
+ public DoExchangeMethod(FlightService delegate) {
+ this.delegate = delegate;
+ }
+
+ @Override
+ public StreamObserver<ArrowMessage> invoke(StreamObserver<ArrowMessage> responseObserver) {
+ return delegate.doExchangeCustom(responseObserver);
+ }
+ }
+
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightCallHeaders.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightCallHeaders.java
new file mode 100644
index 000000000..dd26d1908
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightCallHeaders.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.Collection;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ArrayListMultimap;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Multimap;
+
+import io.grpc.Metadata;
+
+/**
+ * An implementation of the Flight headers interface for headers.
+ */
+public class FlightCallHeaders implements CallHeaders {
+ private final Multimap<String, Object> keysAndValues;
+
+ public FlightCallHeaders() {
+ this.keysAndValues = ArrayListMultimap.create();
+ }
+
+ @Override
+ public String get(String key) {
+ final Collection<Object> values = this.keysAndValues.get(key);
+ if (values.isEmpty()) {
+ return null;
+ }
+
+ if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
+ return new String((byte[]) Iterables.get(values, 0));
+ }
+
+ return (String) Iterables.get(values, 0);
+ }
+
+ @Override
+ public byte[] getByte(String key) {
+ final Collection<Object> values = this.keysAndValues.get(key);
+ if (values.isEmpty()) {
+ return null;
+ }
+
+ if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
+ return (byte[]) Iterables.get(values, 0);
+ }
+
+ return ((String) Iterables.get(values, 0)).getBytes();
+ }
+
+ @Override
+ public Iterable<String> getAll(String key) {
+ if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
+ return this.keysAndValues.get(key).stream().map(o -> new String((byte[]) o)).collect(Collectors.toList());
+ }
+ return (Collection<String>) (Collection<?>) this.keysAndValues.get(key);
+ }
+
+ @Override
+ public Iterable<byte[]> getAllByte(String key) {
+ if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
+ return (Collection<byte[]>) (Collection<?>) this.keysAndValues.get(key);
+ }
+ return this.keysAndValues.get(key).stream().map(o -> ((String) o).getBytes()).collect(Collectors.toList());
+ }
+
+ @Override
+ public void insert(String key, String value) {
+ this.keysAndValues.put(key, value);
+ }
+
+ @Override
+ public void insert(String key, byte[] value) {
+ Preconditions.checkArgument(key.endsWith("-bin"), "Binary header is named %s. It must end with %s", key, "-bin");
+ Preconditions.checkArgument(key.length() > "-bin".length(), "empty key name");
+
+ this.keysAndValues.put(key, value);
+ }
+
+ @Override
+ public Set<String> keys() {
+ return this.keysAndValues.keySet();
+ }
+
+ @Override
+ public boolean containsKey(String key) {
+ return this.keysAndValues.containsKey(key);
+ }
+
+ public String toString() {
+ return this.keysAndValues.toString();
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java
new file mode 100644
index 000000000..762b37859
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java
@@ -0,0 +1,721 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.io.InputStream;
+import java.net.URISyntaxException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.function.BooleanSupplier;
+
+import javax.net.ssl.SSLException;
+
+import org.apache.arrow.flight.FlightProducer.StreamListener;
+import org.apache.arrow.flight.auth.BasicClientAuthHandler;
+import org.apache.arrow.flight.auth.ClientAuthHandler;
+import org.apache.arrow.flight.auth.ClientAuthInterceptor;
+import org.apache.arrow.flight.auth.ClientAuthWrapper;
+import org.apache.arrow.flight.auth2.BasicAuthCredentialWriter;
+import org.apache.arrow.flight.auth2.ClientBearerHeaderHandler;
+import org.apache.arrow.flight.auth2.ClientHandshakeWrapper;
+import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware;
+import org.apache.arrow.flight.grpc.ClientInterceptorAdapter;
+import org.apache.arrow.flight.grpc.CredentialCallOption;
+import org.apache.arrow.flight.grpc.StatusUtils;
+import org.apache.arrow.flight.impl.Flight;
+import org.apache.arrow.flight.impl.Flight.Empty;
+import org.apache.arrow.flight.impl.FlightServiceGrpc;
+import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceBlockingStub;
+import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceStub;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.util.Preconditions;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider;
+
+import io.grpc.Channel;
+import io.grpc.ClientCall;
+import io.grpc.ClientInterceptor;
+import io.grpc.ClientInterceptors;
+import io.grpc.ManagedChannel;
+import io.grpc.MethodDescriptor;
+import io.grpc.StatusRuntimeException;
+import io.grpc.netty.GrpcSslContexts;
+import io.grpc.netty.NettyChannelBuilder;
+import io.grpc.stub.ClientCallStreamObserver;
+import io.grpc.stub.ClientCalls;
+import io.grpc.stub.ClientResponseObserver;
+import io.grpc.stub.StreamObserver;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.ServerChannel;
+import io.netty.handler.ssl.SslContextBuilder;
+import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
+
+/**
+ * Client for Flight services.
+ */
+public class FlightClient implements AutoCloseable {
+ private static final int PENDING_REQUESTS = 5;
+ /** The maximum number of trace events to keep on the gRPC Channel. This value disables channel tracing. */
+ private static final int MAX_CHANNEL_TRACE_EVENTS = 0;
+ private final BufferAllocator allocator;
+ private final ManagedChannel channel;
+ private final Channel interceptedChannel;
+ private final FlightServiceBlockingStub blockingStub;
+ private final FlightServiceStub asyncStub;
+ private final ClientAuthInterceptor authInterceptor = new ClientAuthInterceptor();
+ private final MethodDescriptor<Flight.Ticket, ArrowMessage> doGetDescriptor;
+ private final MethodDescriptor<ArrowMessage, Flight.PutResult> doPutDescriptor;
+ private final MethodDescriptor<ArrowMessage, ArrowMessage> doExchangeDescriptor;
+ private final List<FlightClientMiddleware.Factory> middleware;
+
+ /**
+ * Create a Flight client from an allocator and a gRPC channel.
+ */
+ FlightClient(BufferAllocator incomingAllocator, ManagedChannel channel,
+ List<FlightClientMiddleware.Factory> middleware) {
+ this.allocator = incomingAllocator.newChildAllocator("flight-client", 0, Long.MAX_VALUE);
+ this.channel = channel;
+ this.middleware = middleware;
+
+ final ClientInterceptor[] interceptors;
+ interceptors = new ClientInterceptor[]{authInterceptor, new ClientInterceptorAdapter(middleware)};
+
+ // Create a channel with interceptors pre-applied for DoGet and DoPut
+ this.interceptedChannel = ClientInterceptors.intercept(channel, interceptors);
+
+ blockingStub = FlightServiceGrpc.newBlockingStub(interceptedChannel);
+ asyncStub = FlightServiceGrpc.newStub(interceptedChannel);
+ doGetDescriptor = FlightBindingService.getDoGetDescriptor(allocator);
+ doPutDescriptor = FlightBindingService.getDoPutDescriptor(allocator);
+ doExchangeDescriptor = FlightBindingService.getDoExchangeDescriptor(allocator);
+ }
+
+ /**
+ * Get a list of available flights.
+ *
+ * @param criteria Criteria for selecting flights
+ * @param options RPC-layer hints for the call.
+ * @return FlightInfo Iterable
+ */
+ public Iterable<FlightInfo> listFlights(Criteria criteria, CallOption... options) {
+ final Iterator<Flight.FlightInfo> flights;
+ try {
+ flights = CallOptions.wrapStub(blockingStub, options)
+ .listFlights(criteria.asCriteria());
+ } catch (StatusRuntimeException sre) {
+ throw StatusUtils.fromGrpcRuntimeException(sre);
+ }
+ return () -> StatusUtils.wrapIterator(flights, t -> {
+ try {
+ return new FlightInfo(t);
+ } catch (URISyntaxException e) {
+ // We don't expect this will happen for conforming Flight implementations. For instance, a Java server
+ // itself wouldn't be able to construct an invalid Location.
+ throw new RuntimeException(e);
+ }
+ });
+ }
+
+ /**
+ * Lists actions available on the Flight service.
+ *
+ * @param options RPC-layer hints for the call.
+ */
+ public Iterable<ActionType> listActions(CallOption... options) {
+ final Iterator<Flight.ActionType> actions;
+ try {
+ actions = CallOptions.wrapStub(blockingStub, options)
+ .listActions(Empty.getDefaultInstance());
+ } catch (StatusRuntimeException sre) {
+ throw StatusUtils.fromGrpcRuntimeException(sre);
+ }
+ return () -> StatusUtils.wrapIterator(actions, ActionType::new);
+ }
+
+ /**
+ * Performs an action on the Flight service.
+ *
+ * @param action The action to perform.
+ * @param options RPC-layer hints for this call.
+ * @return An iterator of results.
+ */
+ public Iterator<Result> doAction(Action action, CallOption... options) {
+ return StatusUtils
+ .wrapIterator(CallOptions.wrapStub(blockingStub, options).doAction(action.toProtocol()), Result::new);
+ }
+
+ /**
+ * Authenticates with a username and password.
+ */
+ public void authenticateBasic(String username, String password) {
+ BasicClientAuthHandler basicClient = new BasicClientAuthHandler(username, password);
+ authenticate(basicClient);
+ }
+
+ /**
+ * Authenticates against the Flight service.
+ *
+ * @param options RPC-layer hints for this call.
+ * @param handler The auth mechanism to use.
+ */
+ public void authenticate(ClientAuthHandler handler, CallOption... options) {
+ Preconditions.checkArgument(!authInterceptor.hasAuthHandler(), "Auth already completed.");
+ ClientAuthWrapper.doClientAuth(handler, CallOptions.wrapStub(asyncStub, options));
+ authInterceptor.setAuthHandler(handler);
+ }
+
+ /**
+ * Authenticates with a username and password.
+ *
+ * @param username the username.
+ * @param password the password.
+ * @return a CredentialCallOption containing a bearer token if the server emitted one, or
+ * empty if no bearer token was returned. This can be used in subsequent API calls.
+ */
+ public Optional<CredentialCallOption> authenticateBasicToken(String username, String password) {
+ final ClientIncomingAuthHeaderMiddleware.Factory clientAuthMiddleware =
+ new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler());
+ middleware.add(clientAuthMiddleware);
+ handshake(new CredentialCallOption(new BasicAuthCredentialWriter(username, password)));
+
+ return Optional.ofNullable(clientAuthMiddleware.getCredentialCallOption());
+ }
+
+ /**
+ * Executes the handshake against the Flight service.
+ *
+ * @param options RPC-layer hints for this call.
+ */
+ public void handshake(CallOption... options) {
+ ClientHandshakeWrapper.doClientHandshake(CallOptions.wrapStub(asyncStub, options));
+ }
+
+ /**
+ * Create or append a descriptor with another stream.
+ *
+ * @param descriptor FlightDescriptor the descriptor for the data
+ * @param root VectorSchemaRoot the root containing data
+ * @param metadataListener A handler for metadata messages from the server. This will be passed buffers that will be
+ * freed after {@link StreamListener#onNext(Object)} is called!
+ * @param options RPC-layer hints for this call.
+ * @return ClientStreamListener an interface to control uploading data
+ */
+ public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRoot root,
+ PutListener metadataListener, CallOption... options) {
+ return startPut(descriptor, root, new MapDictionaryProvider(), metadataListener, options);
+ }
+
+ /**
+ * Create or append a descriptor with another stream.
+ * @param descriptor FlightDescriptor the descriptor for the data
+ * @param root VectorSchemaRoot the root containing data
+ * @param metadataListener A handler for metadata messages from the server.
+ * @param options RPC-layer hints for this call.
+ * @return ClientStreamListener an interface to control uploading data.
+ * {@link ClientStreamListener#start(VectorSchemaRoot, DictionaryProvider)} will already have been called.
+ */
+ public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRoot root, DictionaryProvider provider,
+ PutListener metadataListener, CallOption... options) {
+ Preconditions.checkNotNull(root, "root must not be null");
+ Preconditions.checkNotNull(provider, "provider must not be null");
+ final ClientStreamListener writer = startPut(descriptor, metadataListener, options);
+ writer.start(root, provider);
+ return writer;
+ }
+
+ /**
+ * Create or append a descriptor with another stream.
+ * @param descriptor FlightDescriptor the descriptor for the data
+ * @param metadataListener A handler for metadata messages from the server.
+ * @param options RPC-layer hints for this call.
+ * @return ClientStreamListener an interface to control uploading data.
+ * {@link ClientStreamListener#start(VectorSchemaRoot, DictionaryProvider)} will NOT already have been called.
+ */
+ public ClientStreamListener startPut(FlightDescriptor descriptor, PutListener metadataListener,
+ CallOption... options) {
+ Preconditions.checkNotNull(descriptor, "descriptor must not be null");
+ Preconditions.checkNotNull(metadataListener, "metadataListener must not be null");
+ final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, options).getCallOptions();
+
+ try {
+ final SetStreamObserver resultObserver = new SetStreamObserver(allocator, metadataListener);
+ ClientCallStreamObserver<ArrowMessage> observer = (ClientCallStreamObserver<ArrowMessage>)
+ ClientCalls.asyncBidiStreamingCall(
+ interceptedChannel.newCall(doPutDescriptor, callOptions), resultObserver);
+ return new PutObserver(
+ descriptor, observer, metadataListener::isCancelled, metadataListener::getResult);
+ } catch (StatusRuntimeException sre) {
+ throw StatusUtils.fromGrpcRuntimeException(sre);
+ }
+ }
+
+ /**
+ * Get info on a stream.
+ * @param descriptor The descriptor for the stream.
+ * @param options RPC-layer hints for this call.
+ */
+ public FlightInfo getInfo(FlightDescriptor descriptor, CallOption... options) {
+ try {
+ return new FlightInfo(CallOptions.wrapStub(blockingStub, options).getFlightInfo(descriptor.toProtocol()));
+ } catch (URISyntaxException e) {
+ // We don't expect this will happen for conforming Flight implementations. For instance, a Java server
+ // itself wouldn't be able to construct an invalid Location.
+ throw new RuntimeException(e);
+ } catch (StatusRuntimeException sre) {
+ throw StatusUtils.fromGrpcRuntimeException(sre);
+ }
+ }
+
+ /**
+ * Get schema for a stream.
+ * @param descriptor The descriptor for the stream.
+ * @param options RPC-layer hints for this call.
+ */
+ public SchemaResult getSchema(FlightDescriptor descriptor, CallOption... options) {
+ return SchemaResult.fromProtocol(CallOptions.wrapStub(blockingStub, options).getSchema(descriptor.toProtocol()));
+ }
+
+ /**
+ * Retrieve a stream from the server.
+ * @param ticket The ticket granting access to the data stream.
+ * @param options RPC-layer hints for this call.
+ */
+ public FlightStream getStream(Ticket ticket, CallOption... options) {
+ final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, options).getCallOptions();
+ ClientCall<Flight.Ticket, ArrowMessage> call = interceptedChannel.newCall(doGetDescriptor, callOptions);
+ FlightStream stream = new FlightStream(
+ allocator,
+ PENDING_REQUESTS,
+ (String message, Throwable cause) -> call.cancel(message, cause),
+ (count) -> call.request(count));
+
+ final StreamObserver<ArrowMessage> delegate = stream.asObserver();
+ ClientResponseObserver<Flight.Ticket, ArrowMessage> clientResponseObserver =
+ new ClientResponseObserver<Flight.Ticket, ArrowMessage>() {
+
+ @Override
+ public void beforeStart(ClientCallStreamObserver<org.apache.arrow.flight.impl.Flight.Ticket> requestStream) {
+ requestStream.disableAutoInboundFlowControl();
+ }
+
+ @Override
+ public void onNext(ArrowMessage value) {
+ delegate.onNext(value);
+ }
+
+ @Override
+ public void onError(Throwable t) {
+ delegate.onError(StatusUtils.toGrpcException(t));
+ }
+
+ @Override
+ public void onCompleted() {
+ delegate.onCompleted();
+ }
+
+ };
+
+ ClientCalls.asyncServerStreamingCall(call, ticket.toProtocol(), clientResponseObserver);
+ return stream;
+ }
+
+ /**
+ * Initiate a bidirectional data exchange with the server.
+ *
+ * @param descriptor A descriptor for the data stream.
+ * @param options RPC call options.
+ * @return A pair of a readable stream and a writable stream.
+ */
+ public ExchangeReaderWriter doExchange(FlightDescriptor descriptor, CallOption... options) {
+ Preconditions.checkNotNull(descriptor, "descriptor must not be null");
+ final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, options).getCallOptions();
+
+ try {
+ final ClientCall<ArrowMessage, ArrowMessage> call = interceptedChannel.newCall(doExchangeDescriptor, callOptions);
+ final FlightStream stream = new FlightStream(allocator, PENDING_REQUESTS, call::cancel, call::request);
+ final ClientCallStreamObserver<ArrowMessage> observer = (ClientCallStreamObserver<ArrowMessage>)
+ ClientCalls.asyncBidiStreamingCall(call, stream.asObserver());
+ final ClientStreamListener writer = new PutObserver(
+ descriptor, observer, stream.cancelled::isDone,
+ () -> {
+ try {
+ stream.completed.get();
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw CallStatus.INTERNAL
+ .withDescription("Client error: interrupted while completing call")
+ .withCause(e)
+ .toRuntimeException();
+ } catch (ExecutionException e) {
+ throw CallStatus.INTERNAL
+ .withDescription("Client error: internal while completing call")
+ .withCause(e)
+ .toRuntimeException();
+ }
+ });
+ // Send the descriptor to start.
+ try (final ArrowMessage message = new ArrowMessage(descriptor.toProtocol())) {
+ observer.onNext(message);
+ } catch (Exception e) {
+ throw CallStatus.INTERNAL
+ .withCause(e)
+ .withDescription("Could not write descriptor " + descriptor)
+ .toRuntimeException();
+ }
+ return new ExchangeReaderWriter(stream, writer);
+ } catch (StatusRuntimeException sre) {
+ throw StatusUtils.fromGrpcRuntimeException(sre);
+ }
+ }
+
+ /** A pair of a reader and a writer for a DoExchange call. */
+ public static class ExchangeReaderWriter implements AutoCloseable {
+ private final FlightStream reader;
+ private final ClientStreamListener writer;
+
+ ExchangeReaderWriter(FlightStream reader, ClientStreamListener writer) {
+ this.reader = reader;
+ this.writer = writer;
+ }
+
+ /** Get the reader for the call. */
+ public FlightStream getReader() {
+ return reader;
+ }
+
+ /** Get the writer for the call. */
+ public ClientStreamListener getWriter() {
+ return writer;
+ }
+
+ /** Shut down the streams in this call. */
+ @Override
+ public void close() throws Exception {
+ reader.close();
+ }
+ }
+
+ /**
+ * A stream observer for Flight.PutResult
+ */
+ private static class SetStreamObserver implements StreamObserver<Flight.PutResult> {
+ private final BufferAllocator allocator;
+ private final StreamListener<PutResult> listener;
+
+ SetStreamObserver(BufferAllocator allocator, StreamListener<PutResult> listener) {
+ super();
+ this.allocator = allocator;
+ this.listener = listener == null ? NoOpStreamListener.getInstance() : listener;
+ }
+
+ @Override
+ public void onNext(Flight.PutResult value) {
+ try (final PutResult message = PutResult.fromProtocol(allocator, value)) {
+ listener.onNext(message);
+ }
+ }
+
+ @Override
+ public void onError(Throwable t) {
+ listener.onError(StatusUtils.fromThrowable(t));
+ }
+
+ @Override
+ public void onCompleted() {
+ listener.onCompleted();
+ }
+ }
+
+ /**
+ * The implementation of a {@link ClientStreamListener} for writing data to a Flight server.
+ */
+ static class PutObserver extends OutboundStreamListenerImpl implements ClientStreamListener {
+ private final BooleanSupplier isCancelled;
+ private final Runnable getResult;
+
+ /**
+ * Create a new client stream listener.
+ *
+ * @param descriptor The descriptor for the stream.
+ * @param observer The write-side gRPC StreamObserver.
+ * @param isCancelled A flag to check if the call has been cancelled.
+ * @param getResult A flag that blocks until the overall call completes.
+ */
+ PutObserver(FlightDescriptor descriptor, ClientCallStreamObserver<ArrowMessage> observer,
+ BooleanSupplier isCancelled, Runnable getResult) {
+ super(descriptor, observer);
+ Preconditions.checkNotNull(descriptor, "descriptor must be provided");
+ Preconditions.checkNotNull(isCancelled, "isCancelled must be provided");
+ Preconditions.checkNotNull(getResult, "getResult must be provided");
+ this.isCancelled = isCancelled;
+ this.getResult = getResult;
+ this.unloader = null;
+ }
+
+ @Override
+ protected void waitUntilStreamReady() {
+ // Check isCancelled as well to avoid inadvertently blocking forever
+ // (so long as PutListener properly implements it)
+ while (!responseObserver.isReady() && !isCancelled.getAsBoolean()) {
+ /* busy wait */
+ }
+ }
+
+ @Override
+ public void getResult() {
+ getResult.run();
+ }
+ }
+
+ /**
+ * Interface for writers to an Arrow data stream.
+ */
+ public interface ClientStreamListener extends OutboundStreamListener {
+
+ /**
+ * Wait for the stream to finish on the server side. You must call this to be notified of any errors that may have
+ * happened during the upload.
+ */
+ void getResult();
+ }
+
+ /**
+ * A handler for server-sent application metadata messages during a Flight DoPut operation.
+ *
+ * <p>Generally, instead of implementing this yourself, you should use {@link AsyncPutListener} or {@link
+ * SyncPutListener}.
+ */
+ public interface PutListener extends StreamListener<PutResult> {
+
+ /**
+ * Wait for the stream to finish on the server side. You must call this to be notified of any errors that may have
+ * happened during the upload.
+ */
+ void getResult();
+
+ /**
+ * Called when a message from the server is received.
+ *
+ * @param val The application metadata. This buffer will be reclaimed once onNext returns; you must retain a
+ * reference to use it outside this method.
+ */
+ @Override
+ void onNext(PutResult val);
+
+ /**
+ * Check if the call has been cancelled.
+ *
+ * <p>By default, this always returns false. Implementations should provide an appropriate implementation, as
+ * otherwise, a DoPut operation may inadvertently block forever.
+ */
+ default boolean isCancelled() {
+ return false;
+ }
+ }
+
+ /**
+ * Shut down this client.
+ */
+ public void close() throws InterruptedException {
+ channel.shutdown().awaitTermination(5, TimeUnit.SECONDS);
+ allocator.close();
+ }
+
+ /**
+ * Create a builder for a Flight client.
+ */
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ /**
+ * Create a builder for a Flight client.
+ * @param allocator The allocator to use for the client.
+ * @param location The location to connect to.
+ */
+ public static Builder builder(BufferAllocator allocator, Location location) {
+ return new Builder(allocator, location);
+ }
+
+ /**
+ * A builder for Flight clients.
+ */
+ public static final class Builder {
+ private BufferAllocator allocator;
+ private Location location;
+ private boolean forceTls = false;
+ private int maxInboundMessageSize = FlightServer.MAX_GRPC_MESSAGE_SIZE;
+ private InputStream trustedCertificates = null;
+ private InputStream clientCertificate = null;
+ private InputStream clientKey = null;
+ private String overrideHostname = null;
+ private List<FlightClientMiddleware.Factory> middleware = new ArrayList<>();
+ private boolean verifyServer = true;
+
+ private Builder() {
+ }
+
+ private Builder(BufferAllocator allocator, Location location) {
+ this.allocator = Preconditions.checkNotNull(allocator);
+ this.location = Preconditions.checkNotNull(location);
+ }
+
+ /**
+ * Force the client to connect over TLS.
+ */
+ public Builder useTls() {
+ this.forceTls = true;
+ return this;
+ }
+
+ /** Override the hostname checked for TLS. Use with caution in production. */
+ public Builder overrideHostname(final String hostname) {
+ this.overrideHostname = hostname;
+ return this;
+ }
+
+ /** Set the maximum inbound message size. */
+ public Builder maxInboundMessageSize(int maxSize) {
+ Preconditions.checkArgument(maxSize > 0);
+ this.maxInboundMessageSize = maxSize;
+ return this;
+ }
+
+ /** Set the trusted TLS certificates. */
+ public Builder trustedCertificates(final InputStream stream) {
+ this.trustedCertificates = Preconditions.checkNotNull(stream);
+ return this;
+ }
+
+ /** Set the trusted TLS certificates. */
+ public Builder clientCertificate(final InputStream clientCertificate, final InputStream clientKey) {
+ Preconditions.checkNotNull(clientKey);
+ this.clientCertificate = Preconditions.checkNotNull(clientCertificate);
+ this.clientKey = Preconditions.checkNotNull(clientKey);
+ return this;
+ }
+
+ public Builder allocator(BufferAllocator allocator) {
+ this.allocator = Preconditions.checkNotNull(allocator);
+ return this;
+ }
+
+ public Builder location(Location location) {
+ this.location = Preconditions.checkNotNull(location);
+ return this;
+ }
+
+ public Builder intercept(FlightClientMiddleware.Factory factory) {
+ middleware.add(factory);
+ return this;
+ }
+
+ public Builder verifyServer(boolean verifyServer) {
+ this.verifyServer = verifyServer;
+ return this;
+ }
+
+ /**
+ * Create the client from this builder.
+ */
+ public FlightClient build() {
+ final NettyChannelBuilder builder;
+
+ switch (location.getUri().getScheme()) {
+ case LocationSchemes.GRPC:
+ case LocationSchemes.GRPC_INSECURE:
+ case LocationSchemes.GRPC_TLS: {
+ builder = NettyChannelBuilder.forAddress(location.toSocketAddress());
+ break;
+ }
+ case LocationSchemes.GRPC_DOMAIN_SOCKET: {
+ // The implementation is platform-specific, so we have to find the classes at runtime
+ builder = NettyChannelBuilder.forAddress(location.toSocketAddress());
+ try {
+ try {
+ // Linux
+ builder.channelType(
+ (Class<? extends ServerChannel>) Class.forName("io.netty.channel.epoll.EpollDomainSocketChannel"));
+ final EventLoopGroup elg = (EventLoopGroup) Class.forName("io.netty.channel.epoll.EpollEventLoopGroup")
+ .newInstance();
+ builder.eventLoopGroup(elg);
+ } catch (ClassNotFoundException e) {
+ // BSD
+ builder.channelType(
+ (Class<? extends ServerChannel>) Class.forName("io.netty.channel.kqueue.KQueueDomainSocketChannel"));
+ final EventLoopGroup elg = (EventLoopGroup) Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup")
+ .newInstance();
+ builder.eventLoopGroup(elg);
+ }
+ } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) {
+ throw new UnsupportedOperationException(
+ "Could not find suitable Netty native transport implementation for domain socket address.");
+ }
+ break;
+ }
+ default:
+ throw new IllegalArgumentException("Scheme is not supported: " + location.getUri().getScheme());
+ }
+
+ if (this.forceTls || LocationSchemes.GRPC_TLS.equals(location.getUri().getScheme())) {
+ builder.useTransportSecurity();
+
+ final boolean hasTrustedCerts = this.trustedCertificates != null;
+ final boolean hasKeyCertPair = this.clientCertificate != null && this.clientKey != null;
+ if (!this.verifyServer && (hasTrustedCerts || hasKeyCertPair)) {
+ throw new IllegalArgumentException("FlightClient has been configured to disable server verification, " +
+ "but certificate options have been specified.");
+ }
+
+ final SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();
+
+ if (!this.verifyServer) {
+ sslContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE);
+ } else if (this.trustedCertificates != null || this.clientCertificate != null || this.clientKey != null) {
+ if (this.trustedCertificates != null) {
+ sslContextBuilder.trustManager(this.trustedCertificates);
+ }
+ if (this.clientCertificate != null && this.clientKey != null) {
+ sslContextBuilder.keyManager(this.clientCertificate, this.clientKey);
+ }
+ }
+ try {
+ builder.sslContext(sslContextBuilder.build());
+ } catch (SSLException e) {
+ throw new RuntimeException(e);
+ }
+
+ if (this.overrideHostname != null) {
+ builder.overrideAuthority(this.overrideHostname);
+ }
+ } else {
+ builder.usePlaintext();
+ }
+
+ builder
+ .maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS)
+ .maxInboundMessageSize(maxInboundMessageSize);
+ return new FlightClient(allocator, builder.build(), middleware);
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClientMiddleware.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClientMiddleware.java
new file mode 100644
index 000000000..1528ca6c6
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClientMiddleware.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+/**
+ * Client-side middleware for Flight.
+ *
+ * <p>Middleware are instantiated per-call and should store state in the middleware instance.
+ */
+public interface FlightClientMiddleware {
+ /**
+ * A callback used before request headers are sent. The headers may be manipulated.
+ */
+ void onBeforeSendingHeaders(CallHeaders outgoingHeaders);
+
+ /**
+ * A callback called after response headers are received. The headers may be manipulated.
+ */
+ void onHeadersReceived(CallHeaders incomingHeaders);
+
+ /**
+ * A callback called after the call completes.
+ */
+ void onCallCompleted(CallStatus status);
+
+ /**
+ * A factory for client middleware instances.
+ */
+ interface Factory {
+ /**
+ * Create a new middleware instance for the given call.
+ *
+ * @throws FlightRuntimeException if the middleware wants to reject the call with the given status
+ */
+ FlightClientMiddleware onCallStarted(CallInfo info);
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightConstants.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightConstants.java
new file mode 100644
index 000000000..2d039c9d2
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightConstants.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+/**
+ * String constants relevant to flight implementations.
+ */
+public interface FlightConstants {
+
+ String SERVICE = "arrow.flight.protocol.FlightService";
+
+ FlightServerMiddleware.Key<ServerHeaderMiddleware> HEADER_KEY =
+ FlightServerMiddleware.Key.of("org.apache.arrow.flight.ServerHeaderMiddleware");
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDescriptor.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDescriptor.java
new file mode 100644
index 000000000..3eff011d9
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightDescriptor.java
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.arrow.flight.impl.Flight;
+import org.apache.arrow.flight.impl.Flight.FlightDescriptor.DescriptorType;
+import org.apache.arrow.util.Preconditions;
+
+import com.google.common.base.Joiner;
+import com.google.common.collect.ImmutableList;
+import com.google.protobuf.ByteString;
+
+/**
+ * An identifier for a particular set of data. This can either be an opaque command that generates
+ * the data or a static "path" to the data. This is a POJO wrapper around the protobuf message with
+ * the same name.
+ */
+public class FlightDescriptor {
+
+ private boolean isCmd;
+ private List<String> path;
+ private byte[] cmd;
+
+ private FlightDescriptor(boolean isCmd, List<String> path, byte[] cmd) {
+ super();
+ this.isCmd = isCmd;
+ this.path = path;
+ this.cmd = cmd;
+ }
+
+ public static FlightDescriptor command(byte[] cmd) {
+ return new FlightDescriptor(true, null, cmd);
+ }
+
+ public static FlightDescriptor path(Iterable<String> path) {
+ return new FlightDescriptor(false, ImmutableList.copyOf(path), null);
+ }
+
+ public static FlightDescriptor path(String...path) {
+ return new FlightDescriptor(false, ImmutableList.copyOf(path), null);
+ }
+
+ FlightDescriptor(Flight.FlightDescriptor descriptor) {
+ if (descriptor.getType() == DescriptorType.CMD) {
+ isCmd = true;
+ cmd = descriptor.getCmd().toByteArray();
+ } else if (descriptor.getType() == DescriptorType.PATH) {
+ isCmd = false;
+ path = descriptor.getPathList();
+ } else {
+ throw new UnsupportedOperationException();
+ }
+ }
+
+ public boolean isCommand() {
+ return isCmd;
+ }
+
+ public List<String> getPath() {
+ Preconditions.checkArgument(!isCmd);
+ return path;
+ }
+
+ public byte[] getCommand() {
+ Preconditions.checkArgument(isCmd);
+ return cmd;
+ }
+
+ Flight.FlightDescriptor toProtocol() {
+ Flight.FlightDescriptor.Builder b = Flight.FlightDescriptor.newBuilder();
+
+ if (isCmd) {
+ return b.setType(DescriptorType.CMD).setCmd(ByteString.copyFrom(cmd)).build();
+ }
+ return b.setType(DescriptorType.PATH).addAllPath(path).build();
+ }
+
+ /**
+ * Get the serialized form of this protocol message.
+ *
+ * <p>Intended to help interoperability by allowing non-Flight services to still return Flight types.
+ */
+ public ByteBuffer serialize() {
+ return ByteBuffer.wrap(toProtocol().toByteArray());
+ }
+
+ /**
+ * Parse the serialized form of this protocol message.
+ *
+ * <p>Intended to help interoperability by allowing Flight clients to obtain stream info from non-Flight services.
+ *
+ * @param serialized The serialized form of the FlightDescriptor, as returned by {@link #serialize()}.
+ * @return The deserialized FlightDescriptor.
+ * @throws IOException if the serialized form is invalid.
+ */
+ public static FlightDescriptor deserialize(ByteBuffer serialized) throws IOException {
+ return new FlightDescriptor(Flight.FlightDescriptor.parseFrom(serialized));
+ }
+
+ @Override
+ public String toString() {
+ if (isCmd) {
+ return toHex(cmd);
+ } else {
+ return Joiner.on('.').join(path);
+ }
+ }
+
+ private String toHex(byte[] bytes) {
+ StringBuilder sb = new StringBuilder();
+ for (byte b : bytes) {
+ sb.append(String.format("%02X ", b));
+ }
+ return sb.toString();
+ }
+
+ @Override
+ public int hashCode() {
+ final int prime = 31;
+ int result = 1;
+ result = prime * result + ((cmd == null) ? 0 : Arrays.hashCode(cmd));
+ result = prime * result + (isCmd ? 1231 : 1237);
+ result = prime * result + ((path == null) ? 0 : path.hashCode());
+ return result;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (obj == null) {
+ return false;
+ }
+ if (getClass() != obj.getClass()) {
+ return false;
+ }
+ FlightDescriptor other = (FlightDescriptor) obj;
+ if (cmd == null) {
+ if (other.cmd != null) {
+ return false;
+ }
+ } else if (!Arrays.equals(cmd, other.cmd)) {
+ return false;
+ }
+ if (isCmd != other.isCmd) {
+ return false;
+ }
+ if (path == null) {
+ if (other.path != null) {
+ return false;
+ }
+ } else if (!path.equals(other.path)) {
+ return false;
+ }
+ return true;
+ }
+
+
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightEndpoint.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightEndpoint.java
new file mode 100644
index 000000000..2e46b694d
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightEndpoint.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.net.URISyntaxException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+
+import org.apache.arrow.flight.impl.Flight;
+
+import com.google.common.collect.ImmutableList;
+
+/**
+ * POJO to convert to/from the underlying protobuf FlightEndpoint.
+ */
+public class FlightEndpoint {
+ private List<Location> locations;
+ private Ticket ticket;
+
+ /**
+ * Constructs a new instance.
+ *
+ * @param ticket A ticket that describe the key of a data stream.
+ * @param locations The possible locations the stream can be retrieved from.
+ */
+ public FlightEndpoint(Ticket ticket, Location... locations) {
+ super();
+ Objects.requireNonNull(ticket);
+ this.locations = ImmutableList.copyOf(locations);
+ this.ticket = ticket;
+ }
+
+ /**
+ * Constructs from the protocol buffer representation.
+ */
+ FlightEndpoint(Flight.FlightEndpoint flt) throws URISyntaxException {
+ locations = new ArrayList<>();
+ for (final Flight.Location location : flt.getLocationList()) {
+ locations.add(new Location(location.getUri()));
+ }
+ ticket = new Ticket(flt.getTicket());
+ }
+
+ public List<Location> getLocations() {
+ return locations;
+ }
+
+ public Ticket getTicket() {
+ return ticket;
+ }
+
+ /**
+ * Converts to the protocol buffer representation.
+ */
+ Flight.FlightEndpoint toProtocol() {
+ Flight.FlightEndpoint.Builder b = Flight.FlightEndpoint.newBuilder()
+ .setTicket(ticket.toProtocol());
+
+ for (Location l : locations) {
+ b.addLocation(l.toProtocol());
+ }
+ return b.build();
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ FlightEndpoint that = (FlightEndpoint) o;
+ return locations.equals(that.locations) &&
+ ticket.equals(that.ticket);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(locations, ticket);
+ }
+
+ @Override
+ public String toString() {
+ return "FlightEndpoint{" +
+ "locations=" + locations +
+ ", ticket=" + ticket +
+ '}';
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java
new file mode 100644
index 000000000..e57b311c2
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.net.URISyntaxException;
+import java.nio.ByteBuffer;
+import java.nio.channels.Channels;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
+import org.apache.arrow.flight.impl.Flight;
+import org.apache.arrow.vector.ipc.ReadChannel;
+import org.apache.arrow.vector.ipc.WriteChannel;
+import org.apache.arrow.vector.ipc.message.IpcOption;
+import org.apache.arrow.vector.ipc.message.MessageSerializer;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.arrow.vector.validate.MetadataV4UnionChecker;
+
+import com.fasterxml.jackson.databind.util.ByteBufferBackedInputStream;
+import com.google.common.collect.ImmutableList;
+import com.google.protobuf.ByteString;
+
+/**
+ * A POJO representation of a FlightInfo, metadata associated with a set of data records.
+ */
+public class FlightInfo {
+ private final Schema schema;
+ private final FlightDescriptor descriptor;
+ private final List<FlightEndpoint> endpoints;
+ private final long bytes;
+ private final long records;
+ private final IpcOption option;
+
+ /**
+ * Constructs a new instance.
+ *
+ * @param schema The schema of the Flight
+ * @param descriptor An identifier for the Flight.
+ * @param endpoints A list of endpoints that have the flight available.
+ * @param bytes The number of bytes in the flight
+ * @param records The number of records in the flight.
+ */
+ public FlightInfo(Schema schema, FlightDescriptor descriptor, List<FlightEndpoint> endpoints, long bytes,
+ long records) {
+ this(schema, descriptor, endpoints, bytes, records, IpcOption.DEFAULT);
+ }
+
+ /**
+ * Constructs a new instance.
+ *
+ * @param schema The schema of the Flight
+ * @param descriptor An identifier for the Flight.
+ * @param endpoints A list of endpoints that have the flight available.
+ * @param bytes The number of bytes in the flight
+ * @param records The number of records in the flight.
+ * @param option IPC write options.
+ */
+ public FlightInfo(Schema schema, FlightDescriptor descriptor, List<FlightEndpoint> endpoints, long bytes,
+ long records, IpcOption option) {
+ Objects.requireNonNull(schema);
+ Objects.requireNonNull(descriptor);
+ Objects.requireNonNull(endpoints);
+ MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option.metadataVersion);
+ this.schema = schema;
+ this.descriptor = descriptor;
+ this.endpoints = endpoints;
+ this.bytes = bytes;
+ this.records = records;
+ this.option = option;
+ }
+
+ /**
+ * Constructs from the protocol buffer representation.
+ */
+ FlightInfo(Flight.FlightInfo pbFlightInfo) throws URISyntaxException {
+ try {
+ final ByteBuffer schemaBuf = pbFlightInfo.getSchema().asReadOnlyByteBuffer();
+ schema = pbFlightInfo.getSchema().size() > 0 ?
+ MessageSerializer.deserializeSchema(
+ new ReadChannel(Channels.newChannel(new ByteBufferBackedInputStream(schemaBuf))))
+ : new Schema(ImmutableList.of());
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ descriptor = new FlightDescriptor(pbFlightInfo.getFlightDescriptor());
+ endpoints = new ArrayList<>();
+ for (final Flight.FlightEndpoint endpoint : pbFlightInfo.getEndpointList()) {
+ endpoints.add(new FlightEndpoint(endpoint));
+ }
+ bytes = pbFlightInfo.getTotalBytes();
+ records = pbFlightInfo.getTotalRecords();
+ option = IpcOption.DEFAULT;
+ }
+
+ public Schema getSchema() {
+ return schema;
+ }
+
+ public long getBytes() {
+ return bytes;
+ }
+
+ public long getRecords() {
+ return records;
+ }
+
+ public FlightDescriptor getDescriptor() {
+ return descriptor;
+ }
+
+ public List<FlightEndpoint> getEndpoints() {
+ return endpoints;
+ }
+
+ /**
+ * Converts to the protocol buffer representation.
+ */
+ Flight.FlightInfo toProtocol() {
+ // Encode schema in a Message payload
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ try {
+ MessageSerializer.serialize(new WriteChannel(Channels.newChannel(baos)), schema, option);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ return Flight.FlightInfo.newBuilder()
+ .addAllEndpoint(endpoints.stream().map(t -> t.toProtocol()).collect(Collectors.toList()))
+ .setSchema(ByteString.copyFrom(baos.toByteArray()))
+ .setFlightDescriptor(descriptor.toProtocol())
+ .setTotalBytes(FlightInfo.this.bytes)
+ .setTotalRecords(records)
+ .build();
+ }
+
+ /**
+ * Get the serialized form of this protocol message.
+ *
+ * <p>Intended to help interoperability by allowing non-Flight services to still return Flight types.
+ */
+ public ByteBuffer serialize() {
+ return ByteBuffer.wrap(toProtocol().toByteArray());
+ }
+
+ /**
+ * Parse the serialized form of this protocol message.
+ *
+ * <p>Intended to help interoperability by allowing Flight clients to obtain stream info from non-Flight services.
+ *
+ * @param serialized The serialized form of the FlightInfo, as returned by {@link #serialize()}.
+ * @return The deserialized FlightInfo.
+ * @throws IOException if the serialized form is invalid.
+ * @throws URISyntaxException if the serialized form contains an unsupported URI format.
+ */
+ public static FlightInfo deserialize(ByteBuffer serialized) throws IOException, URISyntaxException {
+ return new FlightInfo(Flight.FlightInfo.parseFrom(serialized));
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ FlightInfo that = (FlightInfo) o;
+ return bytes == that.bytes &&
+ records == that.records &&
+ schema.equals(that.schema) &&
+ descriptor.equals(that.descriptor) &&
+ endpoints.equals(that.endpoints);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(schema, descriptor, endpoints, bytes, records);
+ }
+
+ @Override
+ public String toString() {
+ return "FlightInfo{" +
+ "schema=" + schema +
+ ", descriptor=" + descriptor +
+ ", endpoints=" + endpoints +
+ ", bytes=" + bytes +
+ ", records=" + records +
+ '}';
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightMethod.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightMethod.java
new file mode 100644
index 000000000..5d2915bb6
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightMethod.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import org.apache.arrow.flight.impl.FlightServiceGrpc;
+
+/**
+ * All the RPC methods available in Flight.
+ */
+public enum FlightMethod {
+ HANDSHAKE,
+ LIST_FLIGHTS,
+ GET_FLIGHT_INFO,
+ GET_SCHEMA,
+ DO_GET,
+ DO_PUT,
+ DO_ACTION,
+ LIST_ACTIONS,
+ DO_EXCHANGE,
+ ;
+
+ /**
+ * Convert a method name string into a {@link FlightMethod}.
+ *
+ * @throws IllegalArgumentException if the method name is not valid.
+ */
+ public static FlightMethod fromProtocol(final String methodName) {
+ if (FlightServiceGrpc.getHandshakeMethod().getFullMethodName().equals(methodName)) {
+ return HANDSHAKE;
+ } else if (FlightServiceGrpc.getListFlightsMethod().getFullMethodName().equals(methodName)) {
+ return LIST_FLIGHTS;
+ } else if (FlightServiceGrpc.getGetFlightInfoMethod().getFullMethodName().equals(methodName)) {
+ return GET_FLIGHT_INFO;
+ } else if (FlightServiceGrpc.getGetSchemaMethod().getFullMethodName().equals(methodName)) {
+ return GET_SCHEMA;
+ } else if (FlightServiceGrpc.getDoGetMethod().getFullMethodName().equals(methodName)) {
+ return DO_GET;
+ } else if (FlightServiceGrpc.getDoPutMethod().getFullMethodName().equals(methodName)) {
+ return DO_PUT;
+ } else if (FlightServiceGrpc.getDoActionMethod().getFullMethodName().equals(methodName)) {
+ return DO_ACTION;
+ } else if (FlightServiceGrpc.getListActionsMethod().getFullMethodName().equals(methodName)) {
+ return LIST_ACTIONS;
+ } else if (FlightServiceGrpc.getDoExchangeMethod().getFullMethodName().equals(methodName)) {
+ return DO_EXCHANGE;
+ }
+ throw new IllegalArgumentException("Not a Flight method name in gRPC: " + methodName);
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java
new file mode 100644
index 000000000..5e5b26505
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.Map;
+
+/**
+ * API to Implement an Arrow Flight producer.
+ */
+public interface FlightProducer {
+
+ /**
+ * Return data for a stream.
+ *
+ * @param context Per-call context.
+ * @param ticket The application-defined ticket identifying this stream.
+ * @param listener An interface for sending data back to the client.
+ */
+ void getStream(CallContext context, Ticket ticket, ServerStreamListener listener);
+
+ /**
+ * List available data streams on this service.
+ *
+ * @param context Per-call context.
+ * @param criteria Application-defined criteria for filtering streams.
+ * @param listener An interface for sending data back to the client.
+ */
+ void listFlights(CallContext context, Criteria criteria,
+ StreamListener<FlightInfo> listener);
+
+ /**
+ * Get information about a particular data stream.
+ *
+ * @param context Per-call context.
+ * @param descriptor The descriptor identifying the data stream.
+ * @return Metadata about the stream.
+ */
+ FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor);
+
+ /**
+ * Get schema for a particular data stream.
+ *
+ * @param context Per-call context.
+ * @param descriptor The descriptor identifying the data stream.
+ * @return Schema for the stream.
+ */
+ default SchemaResult getSchema(CallContext context, FlightDescriptor descriptor) {
+ FlightInfo info = getFlightInfo(context, descriptor);
+ return new SchemaResult(info.getSchema());
+ }
+
+
+ /**
+ * Accept uploaded data for a particular stream.
+ *
+ * @param context Per-call context.
+ * @param flightStream The data stream being uploaded.
+ */
+ Runnable acceptPut(CallContext context,
+ FlightStream flightStream, StreamListener<PutResult> ackStream);
+
+ default void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) {
+ throw CallStatus.UNIMPLEMENTED.withDescription("DoExchange is unimplemented").toRuntimeException();
+ }
+
+ /**
+ * Generic handler for application-defined RPCs.
+ *
+ * @param context Per-call context.
+ * @param action Client-supplied parameters.
+ * @param listener A stream of responses.
+ */
+ void doAction(CallContext context, Action action,
+ StreamListener<Result> listener);
+
+ /**
+ * List available application-defined RPCs.
+ * @param context Per-call context.
+ * @param listener An interface for sending data back to the client.
+ */
+ void listActions(CallContext context, StreamListener<ActionType> listener);
+
+ /**
+ * An interface for sending Arrow data back to a client.
+ */
+ interface ServerStreamListener extends OutboundStreamListener {
+
+ /**
+ * Check whether the call has been cancelled. If so, stop sending data.
+ */
+ boolean isCancelled();
+
+ /**
+ * Set a callback for when the client cancels a call, i.e. {@link #isCancelled()} has become true.
+ *
+ * <p>Note that this callback may only be called some time after {@link #isCancelled()} becomes true, and may never
+ * be called if all executor threads on the server are busy, or the RPC method body is implemented in a blocking
+ * fashion.
+ */
+ void setOnCancelHandler(Runnable handler);
+ }
+
+ /**
+ * Callbacks for pushing objects to a receiver.
+ *
+ * @param <T> Type of the values in the stream.
+ */
+ interface StreamListener<T> {
+
+ /**
+ * Send the next value to the client.
+ */
+ void onNext(T val);
+
+ /**
+ * Indicate an error to the client.
+ *
+ * <p>Terminates the stream; do not call {@link #onCompleted()}.
+ */
+ void onError(Throwable t);
+
+ /**
+ * Indicate that the transmission is finished.
+ */
+ void onCompleted();
+
+ }
+
+ /**
+ * Call-specific context.
+ */
+ interface CallContext {
+ /** The identity of the authenticated peer. May be the empty string if unknown. */
+ String peerIdentity();
+
+ /** Whether the call has been cancelled by the client. */
+ boolean isCancelled();
+
+ /**
+ * Get the middleware instance of the given type for this call.
+ *
+ * <p>Returns null if not found.
+ */
+ <T extends FlightServerMiddleware> T getMiddleware(FlightServerMiddleware.Key<T> key);
+
+ /** Get an immutable map of middleware for this call. */
+ Map<FlightServerMiddleware.Key<?>, FlightServerMiddleware> getMiddleware();
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightRuntimeException.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightRuntimeException.java
new file mode 100644
index 000000000..76d3349a2
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightRuntimeException.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+/**
+ * An exception raised from a Flight RPC.
+ *
+ * <p>In service implementations, raising an instance of this exception will provide clients with a more detailed
+ * message and error code.
+ */
+public class FlightRuntimeException extends RuntimeException {
+ private final CallStatus status;
+
+ /**
+ * Create a new exception from the given status.
+ */
+ FlightRuntimeException(CallStatus status) {
+ super(status.description(), status.cause());
+ this.status = status;
+ }
+
+ public CallStatus status() {
+ return status;
+ }
+
+ @Override
+ public String toString() {
+ String s = getClass().getName();
+ return String.format("%s: %s: %s", s, status.code(), status.description());
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java
new file mode 100644
index 000000000..d59480bfb
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java
@@ -0,0 +1,399 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+
+import org.apache.arrow.flight.auth.ServerAuthHandler;
+import org.apache.arrow.flight.auth.ServerAuthInterceptor;
+import org.apache.arrow.flight.auth2.Auth2Constants;
+import org.apache.arrow.flight.auth2.CallHeaderAuthenticator;
+import org.apache.arrow.flight.auth2.ServerCallHeaderAuthMiddleware;
+import org.apache.arrow.flight.grpc.ServerInterceptorAdapter;
+import org.apache.arrow.flight.grpc.ServerInterceptorAdapter.KeyFactory;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.util.Preconditions;
+import org.apache.arrow.util.VisibleForTesting;
+
+import com.google.common.util.concurrent.ThreadFactoryBuilder;
+
+import io.grpc.Server;
+import io.grpc.ServerInterceptors;
+import io.grpc.netty.NettyServerBuilder;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.ServerChannel;
+
+/**
+ * Generic server of flight data that is customized via construction with delegate classes for the
+ * actual logic. The server currently uses GRPC as its transport mechanism.
+ */
+public class FlightServer implements AutoCloseable {
+
+ private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(FlightServer.class);
+
+ private final Location location;
+ private final Server server;
+ // The executor used by the gRPC server. We don't use it here, but we do need to clean it up with the server.
+ // May be null, if a user-supplied executor was provided (as we do not want to clean that up)
+ @VisibleForTesting
+ final ExecutorService grpcExecutor;
+
+ /** The maximum size of an individual gRPC message. This effectively disables the limit. */
+ static final int MAX_GRPC_MESSAGE_SIZE = Integer.MAX_VALUE;
+
+ /** Create a new instance from a gRPC server. For internal use only. */
+ private FlightServer(Location location, Server server, ExecutorService grpcExecutor) {
+ this.location = location;
+ this.server = server;
+ this.grpcExecutor = grpcExecutor;
+ }
+
+ /** Start the server. */
+ public FlightServer start() throws IOException {
+ server.start();
+ return this;
+ }
+
+ /** Get the port the server is running on (if applicable). */
+ public int getPort() {
+ return server.getPort();
+ }
+
+ /** Get the location for this server. */
+ public Location getLocation() {
+ if (location.getUri().getPort() == 0) {
+ // If the server was bound to port 0, replace the port in the location with the real port.
+ final URI uri = location.getUri();
+ try {
+ return new Location(new URI(uri.getScheme(), uri.getUserInfo(), uri.getHost(), getPort(),
+ uri.getPath(), uri.getQuery(), uri.getFragment()));
+ } catch (URISyntaxException e) {
+ // We don't expect this to happen
+ throw new RuntimeException(e);
+ }
+ }
+ return location;
+ }
+
+ /** Block until the server shuts down. */
+ public void awaitTermination() throws InterruptedException {
+ server.awaitTermination();
+ }
+
+ /** Request that the server shut down. */
+ public void shutdown() {
+ server.shutdown();
+ if (grpcExecutor != null) {
+ grpcExecutor.shutdown();
+ }
+ }
+
+ /**
+ * Wait for the server to shut down with a timeout.
+ * @return true if the server shut down successfully.
+ */
+ public boolean awaitTermination(final long timeout, final TimeUnit unit) throws InterruptedException {
+ return server.awaitTermination(timeout, unit);
+ }
+
+ /** Shutdown the server, waits for up to 6 seconds for successful shutdown before returning. */
+ public void close() throws InterruptedException {
+ shutdown();
+ final boolean terminated = awaitTermination(3000, TimeUnit.MILLISECONDS);
+ if (terminated) {
+ logger.debug("Server was terminated within 3s");
+ return;
+ }
+
+ // get more aggressive in termination.
+ server.shutdownNow();
+
+ int count = 0;
+ while (!server.isTerminated() & count < 30) {
+ count++;
+ logger.debug("Waiting for termination");
+ Thread.sleep(100);
+ }
+
+ if (!server.isTerminated()) {
+ logger.warn("Couldn't shutdown server, resources likely will be leaked.");
+ }
+ }
+
+ /** Create a builder for a Flight server. */
+ public static Builder builder() {
+ return new Builder();
+ }
+
+ /** Create a builder for a Flight server. */
+ public static Builder builder(BufferAllocator allocator, Location location, FlightProducer producer) {
+ return new Builder(allocator, location, producer);
+ }
+
+ /** A builder for Flight servers. */
+ public static final class Builder {
+ private BufferAllocator allocator;
+ private Location location;
+ private FlightProducer producer;
+ private final Map<String, Object> builderOptions;
+ private ServerAuthHandler authHandler = ServerAuthHandler.NO_OP;
+ private CallHeaderAuthenticator headerAuthenticator = CallHeaderAuthenticator.NO_OP;
+ private ExecutorService executor = null;
+ private int maxInboundMessageSize = MAX_GRPC_MESSAGE_SIZE;
+ private InputStream certChain;
+ private InputStream key;
+ private final List<KeyFactory<?>> interceptors;
+ // Keep track of inserted interceptors
+ private final Set<String> interceptorKeys;
+
+ Builder() {
+ builderOptions = new HashMap<>();
+ interceptors = new ArrayList<>();
+ interceptorKeys = new HashSet<>();
+ }
+
+ Builder(BufferAllocator allocator, Location location, FlightProducer producer) {
+ this();
+ this.allocator = Preconditions.checkNotNull(allocator);
+ this.location = Preconditions.checkNotNull(location);
+ this.producer = Preconditions.checkNotNull(producer);
+ }
+
+ /** Create the server for this builder. */
+ public FlightServer build() {
+ // Add the auth middleware if applicable.
+ if (headerAuthenticator != CallHeaderAuthenticator.NO_OP) {
+ this.middleware(FlightServerMiddleware.Key.of(Auth2Constants.AUTHORIZATION_HEADER),
+ new ServerCallHeaderAuthMiddleware.Factory(headerAuthenticator));
+ }
+
+ this.middleware(FlightConstants.HEADER_KEY, new ServerHeaderMiddleware.Factory());
+
+ final NettyServerBuilder builder;
+ switch (location.getUri().getScheme()) {
+ case LocationSchemes.GRPC_DOMAIN_SOCKET: {
+ // The implementation is platform-specific, so we have to find the classes at runtime
+ builder = NettyServerBuilder.forAddress(location.toSocketAddress());
+ try {
+ try {
+ // Linux
+ builder.channelType(
+ (Class<? extends ServerChannel>) Class
+ .forName("io.netty.channel.epoll.EpollServerDomainSocketChannel"));
+ final EventLoopGroup elg = (EventLoopGroup) Class.forName("io.netty.channel.epoll.EpollEventLoopGroup")
+ .newInstance();
+ builder.bossEventLoopGroup(elg).workerEventLoopGroup(elg);
+ } catch (ClassNotFoundException e) {
+ // BSD
+ builder.channelType(
+ (Class<? extends ServerChannel>) Class
+ .forName("io.netty.channel.kqueue.KQueueServerDomainSocketChannel"));
+ final EventLoopGroup elg = (EventLoopGroup) Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup")
+ .newInstance();
+ builder.bossEventLoopGroup(elg).workerEventLoopGroup(elg);
+ }
+ } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) {
+ throw new UnsupportedOperationException(
+ "Could not find suitable Netty native transport implementation for domain socket address.");
+ }
+ break;
+ }
+ case LocationSchemes.GRPC:
+ case LocationSchemes.GRPC_INSECURE: {
+ builder = NettyServerBuilder.forAddress(location.toSocketAddress());
+ break;
+ }
+ case LocationSchemes.GRPC_TLS: {
+ if (certChain == null) {
+ throw new IllegalArgumentException("Must provide a certificate and key to serve gRPC over TLS");
+ }
+ builder = NettyServerBuilder.forAddress(location.toSocketAddress());
+ break;
+ }
+ default:
+ throw new IllegalArgumentException("Scheme is not supported: " + location.getUri().getScheme());
+ }
+
+ if (certChain != null) {
+ builder.useTransportSecurity(certChain, key);
+ }
+
+ // Share one executor between the gRPC service, DoPut, and Handshake
+ final ExecutorService exec;
+ // We only want to have FlightServer close the gRPC executor if we created it here. We should not close
+ // user-supplied executors.
+ final ExecutorService grpcExecutor;
+ if (executor != null) {
+ exec = executor;
+ grpcExecutor = null;
+ } else {
+ exec = Executors.newCachedThreadPool(
+ // Name threads for better debuggability
+ new ThreadFactoryBuilder().setNameFormat("flight-server-default-executor-%d").build());
+ grpcExecutor = exec;
+ }
+ final FlightBindingService flightService = new FlightBindingService(allocator, producer, authHandler, exec);
+ builder
+ .executor(exec)
+ .maxInboundMessageSize(maxInboundMessageSize)
+ .addService(
+ ServerInterceptors.intercept(
+ flightService,
+ new ServerAuthInterceptor(authHandler)));
+
+ // Allow hooking into the gRPC builder. This is not guaranteed to be available on all Arrow versions or
+ // Flight implementations.
+ builderOptions.computeIfPresent("grpc.builderConsumer", (key, builderConsumer) -> {
+ final Consumer<NettyServerBuilder> consumer = (Consumer<NettyServerBuilder>) builderConsumer;
+ consumer.accept(builder);
+ return null;
+ });
+
+ // Allow explicitly setting some Netty-specific options
+ builderOptions.computeIfPresent("netty.channelType", (key, channelType) -> {
+ builder.channelType((Class<? extends ServerChannel>) channelType);
+ return null;
+ });
+ builderOptions.computeIfPresent("netty.bossEventLoopGroup", (key, elg) -> {
+ builder.bossEventLoopGroup((EventLoopGroup) elg);
+ return null;
+ });
+ builderOptions.computeIfPresent("netty.workerEventLoopGroup", (key, elg) -> {
+ builder.workerEventLoopGroup((EventLoopGroup) elg);
+ return null;
+ });
+
+ builder.intercept(new ServerInterceptorAdapter(interceptors));
+ return new FlightServer(location, builder.build(), grpcExecutor);
+ }
+
+ /**
+ * Set the maximum size of a message. Defaults to "unlimited", depending on the underlying transport.
+ */
+ public Builder maxInboundMessageSize(int maxMessageSize) {
+ this.maxInboundMessageSize = maxMessageSize;
+ return this;
+ }
+
+ /**
+ * Enable TLS on the server.
+ * @param certChain The certificate chain to use.
+ * @param key The private key to use.
+ */
+ public Builder useTls(final File certChain, final File key) throws IOException {
+ this.certChain = new FileInputStream(certChain);
+ this.key = new FileInputStream(key);
+ return this;
+ }
+
+ /**
+ * Enable TLS on the server.
+ * @param certChain The certificate chain to use.
+ * @param key The private key to use.
+ */
+ public Builder useTls(final InputStream certChain, final InputStream key) {
+ this.certChain = certChain;
+ this.key = key;
+ return this;
+ }
+
+ /**
+ * Set the executor used by the server.
+ *
+ * <p>Flight will NOT take ownership of the executor. The application must clean it up if one is provided. (If not
+ * provided, Flight will use a default executor which it will clean up.)
+ */
+ public Builder executor(ExecutorService executor) {
+ this.executor = executor;
+ return this;
+ }
+
+ /**
+ * Set the authentication handler.
+ */
+ public Builder authHandler(ServerAuthHandler authHandler) {
+ this.authHandler = authHandler;
+ return this;
+ }
+
+ /**
+ * Set the header-based authentication mechanism.
+ */
+ public Builder headerAuthenticator(CallHeaderAuthenticator headerAuthenticator) {
+ this.headerAuthenticator = headerAuthenticator;
+ return this;
+ }
+
+ /**
+ * Provide a transport-specific option. Not guaranteed to have any effect.
+ */
+ public Builder transportHint(final String key, Object option) {
+ builderOptions.put(key, option);
+ return this;
+ }
+
+ /**
+ * Add a Flight middleware component to inspect and modify requests to this service.
+ *
+ * @param key An identifier for this middleware component. Service implementations can retrieve the middleware
+ * instance for the current call using {@link org.apache.arrow.flight.FlightProducer.CallContext}.
+ * @param factory A factory for the middleware.
+ * @param <T> The middleware type.
+ * @throws IllegalArgumentException if the key already exists
+ */
+ public <T extends FlightServerMiddleware> Builder middleware(final FlightServerMiddleware.Key<T> key,
+ final FlightServerMiddleware.Factory<T> factory) {
+ if (interceptorKeys.contains(key.key)) {
+ throw new IllegalArgumentException("Key already exists: " + key.key);
+ }
+ interceptors.add(new KeyFactory<>(key, factory));
+ interceptorKeys.add(key.key);
+ return this;
+ }
+
+ public Builder allocator(BufferAllocator allocator) {
+ this.allocator = Preconditions.checkNotNull(allocator);
+ return this;
+ }
+
+ public Builder location(Location location) {
+ this.location = Preconditions.checkNotNull(location);
+ return this;
+ }
+
+ public Builder producer(FlightProducer producer) {
+ this.producer = Preconditions.checkNotNull(producer);
+ return this;
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServerMiddleware.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServerMiddleware.java
new file mode 100644
index 000000000..9bc8bbfe7
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServerMiddleware.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.Objects;
+
+/**
+ * Server-side middleware for Flight calls.
+ *
+ * <p>Middleware are instantiated per-call.
+ *
+ * <p>Methods are not guaranteed to be called on any particular thread, relative to the thread that Flight requests are
+ * executed on. Do not depend on thread-local storage; instead, use state on the middleware instance. Service
+ * implementations may communicate with middleware implementations through
+ * {@link org.apache.arrow.flight.FlightProducer.CallContext#getMiddleware(Key)}. Methods on the middleware instance
+ * are non-reentrant, that is, a particular RPC will not make multiple concurrent calls to methods on a single
+ * middleware instance. However, methods on the factory instance are expected to be thread-safe, and if the factory
+ * instance returns the same middleware object more than once, then that middleware object must be thread-safe.
+ */
+public interface FlightServerMiddleware {
+
+ /**
+ * A factory for Flight server middleware.
+ * @param <T> The middleware type.
+ */
+ interface Factory<T extends FlightServerMiddleware> {
+ /**
+ * A callback for when the call starts.
+ *
+ * @param info Details about the call.
+ * @param incomingHeaders A mutable set of request headers.
+ * @param context Context about the current request.
+ *
+ * @throws FlightRuntimeException if the middleware wants to reject the call with the given status
+ */
+ T onCallStarted(CallInfo info, CallHeaders incomingHeaders, RequestContext context);
+ }
+
+ /**
+ * A key for Flight server middleware. On a server, middleware instances are identified by this key.
+ *
+ * <p>Keys use reference equality, so instances should be shared.
+ *
+ * @param <T> The middleware class stored in this key. This provides a compile-time check when retrieving instances.
+ */
+ class Key<T extends FlightServerMiddleware> {
+ final String key;
+
+ Key(String key) {
+ this.key = Objects.requireNonNull(key, "Key must not be null.");
+ }
+
+ /**
+ * Create a new key for the given type.
+ */
+ public static <T extends FlightServerMiddleware> Key<T> of(String key) {
+ return new Key<>(key);
+ }
+ }
+
+ /**
+ * Callback for when the underlying transport is about to send response headers.
+ *
+ * @param outgoingHeaders A mutable set of response headers. These can be manipulated to send different headers to the
+ * client.
+ */
+ void onBeforeSendingHeaders(CallHeaders outgoingHeaders);
+
+ /**
+ * Callback for when the underlying transport has completed a call.
+ * @param status Whether the call completed successfully or not.
+ */
+ void onCallCompleted(CallStatus status);
+
+ /**
+ * Callback for when an RPC method implementation throws an uncaught exception.
+ *
+ * <p>May be called multiple times, and may be called before or after {@link #onCallCompleted(CallStatus)}.
+ * Generally, an uncaught exception will end the call with a error {@link CallStatus}, and will be reported to {@link
+ * #onCallCompleted(CallStatus)}, but not necessarily this method.
+ *
+ * @param err The exception that was thrown.
+ */
+ void onCallErrored(Throwable err);
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java
new file mode 100644
index 000000000..4fb0dea2c
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java
@@ -0,0 +1,427 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.concurrent.ExecutorService;
+import java.util.function.BooleanSupplier;
+import java.util.function.Consumer;
+
+import org.apache.arrow.flight.FlightProducer.ServerStreamListener;
+import org.apache.arrow.flight.FlightServerMiddleware.Key;
+import org.apache.arrow.flight.auth.AuthConstants;
+import org.apache.arrow.flight.auth.ServerAuthHandler;
+import org.apache.arrow.flight.auth.ServerAuthWrapper;
+import org.apache.arrow.flight.auth2.Auth2Constants;
+import org.apache.arrow.flight.grpc.ContextPropagatingExecutorService;
+import org.apache.arrow.flight.grpc.RequestContextAdapter;
+import org.apache.arrow.flight.grpc.ServerInterceptorAdapter;
+import org.apache.arrow.flight.grpc.StatusUtils;
+import org.apache.arrow.flight.impl.Flight;
+import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceImplBase;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Strings;
+
+import io.grpc.stub.ServerCallStreamObserver;
+import io.grpc.stub.StreamObserver;
+
+/**
+ * GRPC service implementation for a flight server.
+ */
+class FlightService extends FlightServiceImplBase {
+
+ private static final Logger logger = LoggerFactory.getLogger(FlightService.class);
+ private static final int PENDING_REQUESTS = 5;
+
+ private final BufferAllocator allocator;
+ private final FlightProducer producer;
+ private final ServerAuthHandler authHandler;
+ private final ExecutorService executors;
+
+ FlightService(BufferAllocator allocator, FlightProducer producer, ServerAuthHandler authHandler,
+ ExecutorService executors) {
+ this.allocator = allocator;
+ this.producer = producer;
+ this.authHandler = authHandler;
+ this.executors = new ContextPropagatingExecutorService(executors);
+ }
+
+ private CallContext makeContext(ServerCallStreamObserver<?> responseObserver) {
+ // Try to get the peer identity from middleware first (using the auth2 interfaces).
+ final RequestContext context = RequestContextAdapter.REQUEST_CONTEXT_KEY.get();
+ String peerIdentity = null;
+ if (context != null) {
+ peerIdentity = context.get(Auth2Constants.PEER_IDENTITY_KEY);
+ }
+
+ if (Strings.isNullOrEmpty(peerIdentity)) {
+ // Try the legacy auth interface, which defaults to empty string.
+ peerIdentity = AuthConstants.PEER_IDENTITY_KEY.get();
+ }
+
+ return new CallContext(peerIdentity, responseObserver::isCancelled);
+ }
+
+ @Override
+ public StreamObserver<Flight.HandshakeRequest> handshake(StreamObserver<Flight.HandshakeResponse> responseObserver) {
+ // This method is not meaningful with the auth2 interfaces. Authentication would already
+ // have happened by header/middleware with the auth2 classes.
+ return ServerAuthWrapper.wrapHandshake(authHandler, responseObserver, executors);
+ }
+
+ @Override
+ public void listFlights(Flight.Criteria criteria, StreamObserver<Flight.FlightInfo> responseObserver) {
+ final StreamPipe<FlightInfo, Flight.FlightInfo> listener = StreamPipe
+ .wrap(responseObserver, FlightInfo::toProtocol, this::handleExceptionWithMiddleware);
+ try {
+ final CallContext context = makeContext((ServerCallStreamObserver<?>) responseObserver);
+ producer.listFlights(context, new Criteria(criteria), listener);
+ } catch (Exception ex) {
+ listener.onError(ex);
+ }
+ // Do NOT call StreamPipe#onCompleted, as the FlightProducer implementation may be asynchronous
+ }
+
+ public void doGetCustom(Flight.Ticket ticket, StreamObserver<ArrowMessage> responseObserverSimple) {
+ final ServerCallStreamObserver<ArrowMessage> responseObserver =
+ (ServerCallStreamObserver<ArrowMessage>) responseObserverSimple;
+
+ final GetListener listener = new GetListener(responseObserver, this::handleExceptionWithMiddleware);
+ try {
+ producer.getStream(makeContext(responseObserver), new Ticket(ticket), listener);
+ } catch (Exception ex) {
+ listener.error(ex);
+ }
+ // Do NOT call GetListener#completed, as the implementation of getStream may be asynchronous
+ }
+
+ @Override
+ public void doAction(Flight.Action request, StreamObserver<Flight.Result> responseObserver) {
+ final StreamPipe<Result, Flight.Result> listener = StreamPipe
+ .wrap(responseObserver, Result::toProtocol, this::handleExceptionWithMiddleware);
+ try {
+ final CallContext context = makeContext((ServerCallStreamObserver<?>) responseObserver);
+ producer.doAction(context, new Action(request), listener);
+ } catch (Exception ex) {
+ listener.onError(ex);
+ }
+ // Do NOT call StreamPipe#onCompleted, as the FlightProducer implementation may be asynchronous
+ }
+
+ @Override
+ public void listActions(Flight.Empty request, StreamObserver<Flight.ActionType> responseObserver) {
+ final StreamPipe<org.apache.arrow.flight.ActionType, Flight.ActionType> listener = StreamPipe
+ .wrap(responseObserver, ActionType::toProtocol, this::handleExceptionWithMiddleware);
+ try {
+ final CallContext context = makeContext((ServerCallStreamObserver<?>) responseObserver);
+ producer.listActions(context, listener);
+ } catch (Exception ex) {
+ listener.onError(ex);
+ }
+ // Do NOT call StreamPipe#onCompleted, as the FlightProducer implementation may be asynchronous
+ }
+
+ private static class GetListener extends OutboundStreamListenerImpl implements ServerStreamListener {
+ private ServerCallStreamObserver<ArrowMessage> responseObserver;
+ private final Consumer<Throwable> errorHandler;
+ private Runnable onCancelHandler = null;
+ private Runnable onReadyHandler = null;
+ private boolean completed;
+
+ public GetListener(ServerCallStreamObserver<ArrowMessage> responseObserver, Consumer<Throwable> errorHandler) {
+ super(null, responseObserver);
+ this.errorHandler = errorHandler;
+ this.completed = false;
+ this.responseObserver = responseObserver;
+ this.responseObserver.setOnCancelHandler(this::onCancel);
+ this.responseObserver.setOnReadyHandler(this::onReady);
+ this.responseObserver.disableAutoInboundFlowControl();
+ }
+
+ private void onCancel() {
+ logger.debug("Stream cancelled by client.");
+ if (onCancelHandler != null) {
+ onCancelHandler.run();
+ }
+ }
+
+ private void onReady() {
+ if (onReadyHandler != null) {
+ onReadyHandler.run();
+ }
+ }
+
+ @Override
+ public void setOnCancelHandler(Runnable handler) {
+ this.onCancelHandler = handler;
+ }
+
+ @Override
+ public void setOnReadyHandler(Runnable handler) {
+ this.onReadyHandler = handler;
+ }
+
+ @Override
+ public boolean isCancelled() {
+ return responseObserver.isCancelled();
+ }
+
+ @Override
+ protected void waitUntilStreamReady() {
+ // Don't do anything - service implementations are expected to manage backpressure themselves
+ }
+
+ @Override
+ public void error(Throwable ex) {
+ if (!completed) {
+ completed = true;
+ super.error(ex);
+ } else {
+ errorHandler.accept(ex);
+ }
+ }
+
+ @Override
+ public void completed() {
+ if (!completed) {
+ completed = true;
+ super.completed();
+ } else {
+ errorHandler.accept(new IllegalStateException("Tried to complete already-completed call"));
+ }
+ }
+ }
+
+ public StreamObserver<ArrowMessage> doPutCustom(final StreamObserver<Flight.PutResult> responseObserverSimple) {
+ ServerCallStreamObserver<Flight.PutResult> responseObserver =
+ (ServerCallStreamObserver<Flight.PutResult>) responseObserverSimple;
+ responseObserver.disableAutoInboundFlowControl();
+ responseObserver.request(1);
+
+ final StreamPipe<PutResult, Flight.PutResult> ackStream = StreamPipe
+ .wrap(responseObserver, PutResult::toProtocol, this::handleExceptionWithMiddleware);
+ final FlightStream fs = new FlightStream(
+ allocator,
+ PENDING_REQUESTS,
+ /* server-upload streams are not cancellable */null,
+ responseObserver::request);
+ // When the ackStream is completed, the FlightStream will be closed with it
+ ackStream.setAutoCloseable(fs);
+ final StreamObserver<ArrowMessage> observer = fs.asObserver();
+ executors.submit(() -> {
+ try {
+ producer.acceptPut(makeContext(responseObserver), fs, ackStream).run();
+ } catch (Exception ex) {
+ ackStream.onError(ex);
+ } finally {
+ // ARROW-6136: Close the stream if and only if acceptPut hasn't closed it itself
+ // We don't do this for other streams since the implementation may be asynchronous
+ ackStream.ensureCompleted();
+ }
+ });
+
+ return observer;
+ }
+
+ @Override
+ public void getFlightInfo(Flight.FlightDescriptor request, StreamObserver<Flight.FlightInfo> responseObserver) {
+ final FlightInfo info;
+ try {
+ info = producer
+ .getFlightInfo(makeContext((ServerCallStreamObserver<?>) responseObserver), new FlightDescriptor(request));
+ } catch (Exception ex) {
+ // Don't capture exceptions from onNext or onCompleted with this block - because then we can't call onError
+ responseObserver.onError(StatusUtils.toGrpcException(ex));
+ return;
+ }
+ responseObserver.onNext(info.toProtocol());
+ responseObserver.onCompleted();
+ }
+
+ /**
+ * Broadcast the given exception to all registered middleware.
+ */
+ private void handleExceptionWithMiddleware(Throwable t) {
+ final Map<Key<?>, FlightServerMiddleware> middleware = ServerInterceptorAdapter.SERVER_MIDDLEWARE_KEY.get();
+ if (middleware == null || middleware.isEmpty()) {
+ logger.error("Uncaught exception in Flight method body", t);
+ return;
+ }
+ middleware.forEach((k, v) -> v.onCallErrored(t));
+ }
+
+ @Override
+ public void getSchema(Flight.FlightDescriptor request, StreamObserver<Flight.SchemaResult> responseObserver) {
+ try {
+ SchemaResult result = producer
+ .getSchema(makeContext((ServerCallStreamObserver<?>) responseObserver),
+ new FlightDescriptor(request));
+ responseObserver.onNext(result.toProtocol());
+ responseObserver.onCompleted();
+ } catch (Exception ex) {
+ responseObserver.onError(StatusUtils.toGrpcException(ex));
+ }
+ }
+
+ /** Ensures that other resources are cleaned up when the service finishes its call. */
+ private static class ExchangeListener extends GetListener {
+
+ private AutoCloseable resource;
+ private boolean closed = false;
+ private Runnable onCancelHandler = null;
+
+ public ExchangeListener(ServerCallStreamObserver<ArrowMessage> responseObserver, Consumer<Throwable> errorHandler) {
+ super(responseObserver, errorHandler);
+ this.resource = null;
+ super.setOnCancelHandler(() -> {
+ try {
+ if (onCancelHandler != null) {
+ onCancelHandler.run();
+ }
+ } finally {
+ cleanup();
+ }
+ });
+ }
+
+ private void cleanup() {
+ if (closed) {
+ // Prevent double-free. gRPC will call the OnCancelHandler even on a normal call end, which means that
+ // we'll double-free without this guard.
+ return;
+ }
+ closed = true;
+ try {
+ AutoCloseables.close(resource);
+ } catch (Exception e) {
+ throw CallStatus.INTERNAL
+ .withCause(e)
+ .withDescription("Server internal error cleaning up resources")
+ .toRuntimeException();
+ }
+ }
+
+ @Override
+ public void error(Throwable ex) {
+ try {
+ this.cleanup();
+ } finally {
+ super.error(ex);
+ }
+ }
+
+ @Override
+ public void completed() {
+ try {
+ this.cleanup();
+ } finally {
+ super.completed();
+ }
+ }
+
+ @Override
+ public void setOnCancelHandler(Runnable handler) {
+ onCancelHandler = handler;
+ }
+ }
+
+ public StreamObserver<ArrowMessage> doExchangeCustom(StreamObserver<ArrowMessage> responseObserverSimple) {
+ final ServerCallStreamObserver<ArrowMessage> responseObserver =
+ (ServerCallStreamObserver<ArrowMessage>) responseObserverSimple;
+ final ExchangeListener listener = new ExchangeListener(
+ responseObserver,
+ this::handleExceptionWithMiddleware);
+ final FlightStream fs = new FlightStream(
+ allocator,
+ PENDING_REQUESTS,
+ /* server-upload streams are not cancellable */null,
+ responseObserver::request);
+ // When service completes the call, this cleans up the FlightStream
+ listener.resource = fs;
+ responseObserver.disableAutoInboundFlowControl();
+ responseObserver.request(1);
+ final StreamObserver<ArrowMessage> observer = fs.asObserver();
+ try {
+ executors.submit(() -> {
+ try {
+ producer.doExchange(makeContext(responseObserver), fs, listener);
+ } catch (Exception ex) {
+ listener.error(ex);
+ }
+ // We do not clean up or close anything here, to allow long-running asynchronous implementations.
+ // It is the service's responsibility to call completed() or error(), which will then clean up the FlightStream.
+ });
+ } catch (Exception ex) {
+ listener.error(ex);
+ }
+ return observer;
+ }
+
+ /**
+ * Call context for the service.
+ */
+ static class CallContext implements FlightProducer.CallContext {
+
+ private final String peerIdentity;
+ private final BooleanSupplier isCancelled;
+
+ CallContext(final String peerIdentity, BooleanSupplier isCancelled) {
+ this.peerIdentity = peerIdentity;
+ this.isCancelled = isCancelled;
+ }
+
+ @Override
+ public String peerIdentity() {
+ return peerIdentity;
+ }
+
+ @Override
+ public boolean isCancelled() {
+ return this.isCancelled.getAsBoolean();
+ }
+
+ @Override
+ public <T extends FlightServerMiddleware> T getMiddleware(Key<T> key) {
+ final Map<Key<?>, FlightServerMiddleware> middleware = ServerInterceptorAdapter.SERVER_MIDDLEWARE_KEY.get();
+ if (middleware == null) {
+ return null;
+ }
+ final FlightServerMiddleware m = middleware.get(key);
+ if (m == null) {
+ return null;
+ }
+ @SuppressWarnings("unchecked") final T result = (T) m;
+ return result;
+ }
+
+ @Override
+ public Map<Key<?>, FlightServerMiddleware> getMiddleware() {
+ final Map<Key<?>, FlightServerMiddleware> middleware = ServerInterceptorAdapter.SERVER_MIDDLEWARE_KEY.get();
+ if (middleware == null) {
+ return Collections.emptyMap();
+ }
+ // This is an unmodifiable map
+ return middleware;
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStatusCode.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStatusCode.java
new file mode 100644
index 000000000..3d96877ba
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStatusCode.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+/**
+ * A status code describing the result of a Flight call.
+ */
+public enum FlightStatusCode {
+ /**
+ * The call completed successfully. Generally clients will not see this, but middleware may.
+ */
+ OK,
+ /**
+ * An unknown error occurred. This may also be the result of an implementation error on the server-side; by default,
+ * unhandled server exceptions result in this code.
+ */
+ UNKNOWN,
+ /**
+ * An internal/implementation error occurred.
+ */
+ INTERNAL,
+ /**
+ * One or more of the given arguments was invalid.
+ */
+ INVALID_ARGUMENT,
+ /**
+ * The operation timed out.
+ */
+ TIMED_OUT,
+ /**
+ * The operation describes a resource that does not exist.
+ */
+ NOT_FOUND,
+ /**
+ * The operation creates a resource that already exists.
+ */
+ ALREADY_EXISTS,
+ /**
+ * The operation was cancelled.
+ */
+ CANCELLED,
+ /**
+ * The client was not authenticated.
+ */
+ UNAUTHENTICATED,
+ /**
+ * The client did not have permission to make the call.
+ */
+ UNAUTHORIZED,
+ /**
+ * The requested operation is not implemented.
+ */
+ UNIMPLEMENTED,
+ /**
+ * The server cannot currently handle the request. This should be used for retriable requests, i.e. the server
+ * should send this code only if it has not done any work.
+ */
+ UNAVAILABLE,
+ ;
+
+ /**
+ * Create a blank {@link CallStatus} with this code.
+ */
+ public CallStatus toStatus() {
+ return new CallStatus(this);
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java
new file mode 100644
index 000000000..03ce13c97
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java
@@ -0,0 +1,505 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.LinkedBlockingQueue;
+
+import org.apache.arrow.flight.ArrowMessage.HeaderType;
+import org.apache.arrow.flight.grpc.StatusUtils;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.util.VisibleForTesting;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.VectorLoader;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.MetadataVersion;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.arrow.vector.util.DictionaryUtility;
+import org.apache.arrow.vector.validate.MetadataV4UnionChecker;
+
+import com.google.common.util.concurrent.SettableFuture;
+
+import io.grpc.stub.StreamObserver;
+
+/**
+ * An adaptor between protobuf streams and flight data streams.
+ */
+public class FlightStream implements AutoCloseable {
+ // Use AutoCloseable sentinel objects to simplify logic in #close
+ private final AutoCloseable DONE = () -> {
+ };
+ private final AutoCloseable DONE_EX = () -> {
+ };
+
+ private final BufferAllocator allocator;
+ private final Cancellable cancellable;
+ private final LinkedBlockingQueue<AutoCloseable> queue = new LinkedBlockingQueue<>();
+ private final SettableFuture<VectorSchemaRoot> root = SettableFuture.create();
+ private final SettableFuture<FlightDescriptor> descriptor = SettableFuture.create();
+ private final int pendingTarget;
+ private final Requestor requestor;
+ // The completion flags.
+ // This flag is only updated as the user iterates through the data, i.e. it tracks whether the user has read all the
+ // data and closed the stream
+ final CompletableFuture<Void> completed;
+ // This flag is immediately updated when gRPC signals that the server has ended the call. This is used to make sure
+ // we don't block forever trying to write to a server that has rejected a call.
+ final CompletableFuture<Void> cancelled;
+
+ private volatile int pending = 1;
+ private volatile VectorSchemaRoot fulfilledRoot;
+ private DictionaryProvider.MapDictionaryProvider dictionaries;
+ private volatile VectorLoader loader;
+ private volatile Throwable ex;
+ private volatile ArrowBuf applicationMetadata = null;
+ @VisibleForTesting
+ volatile MetadataVersion metadataVersion = null;
+
+ /**
+ * Constructs a new instance.
+ *
+ * @param allocator The allocator to use for creating/reallocating buffers for Vectors.
+ * @param pendingTarget Target number of messages to receive.
+ * @param cancellable Used to cancel mid-stream requests.
+ * @param requestor A callback to determine how many pending items there are.
+ */
+ public FlightStream(BufferAllocator allocator, int pendingTarget, Cancellable cancellable, Requestor requestor) {
+ Objects.requireNonNull(allocator);
+ Objects.requireNonNull(requestor);
+ this.allocator = allocator;
+ this.pendingTarget = pendingTarget;
+ this.cancellable = cancellable;
+ this.requestor = requestor;
+ this.dictionaries = new DictionaryProvider.MapDictionaryProvider();
+ this.completed = new CompletableFuture<>();
+ this.cancelled = new CompletableFuture<>();
+ }
+
+ /**
+ * Get the schema for this stream. Blocks until the schema is available.
+ */
+ public Schema getSchema() {
+ return getRoot().getSchema();
+ }
+
+ /**
+ * Get the provider for dictionaries in this stream.
+ *
+ * <p>Does NOT retain a reference to the underlying dictionaries. Dictionaries may be updated as the stream is read.
+ * This method is intended for stream processing, where the application code will not retain references to values
+ * after the stream is closed.
+ *
+ * @throws IllegalStateException if {@link #takeDictionaryOwnership()} was called
+ * @see #takeDictionaryOwnership()
+ */
+ public DictionaryProvider getDictionaryProvider() {
+ if (dictionaries == null) {
+ throw new IllegalStateException("Dictionary ownership was claimed by the application.");
+ }
+ return dictionaries;
+ }
+
+ /**
+ * Get an owned reference to the dictionaries in this stream. Should be called after finishing reading the stream,
+ * but before closing.
+ *
+ * <p>If called, the client is responsible for closing the dictionaries in this provider. Can only be called once.
+ *
+ * @return The dictionary provider for the stream.
+ * @throws IllegalStateException if called more than once.
+ */
+ public DictionaryProvider takeDictionaryOwnership() {
+ if (dictionaries == null) {
+ throw new IllegalStateException("Dictionary ownership was claimed by the application.");
+ }
+ // Swap out the provider so it is not closed
+ final DictionaryProvider provider = dictionaries;
+ dictionaries = null;
+ return provider;
+ }
+
+ /**
+ * Get the descriptor for this stream. Only applicable on the server side of a DoPut operation. Will block until the
+ * client sends the descriptor.
+ */
+ public FlightDescriptor getDescriptor() {
+ // This blocks until the first message from the client is received.
+ try {
+ return descriptor.get();
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw CallStatus.INTERNAL.withCause(e).withDescription("Interrupted").toRuntimeException();
+ } catch (ExecutionException e) {
+ throw CallStatus.INTERNAL.withCause(e).withDescription("Error getting descriptor").toRuntimeException();
+ }
+ }
+
+ /**
+ * Closes the stream (freeing any existing resources).
+ *
+ * <p>If the stream isn't complete and is cancellable, this method will cancel and drain the stream first.
+ */
+ public void close() throws Exception {
+ final List<AutoCloseable> closeables = new ArrayList<>();
+ Throwable suppressor = null;
+ if (cancellable != null) {
+ // Client-side stream. Cancel the call, to help ensure gRPC doesn't deliver a message after close() ends.
+ // On the server side, we can't rely on draining the stream , because this gRPC bug means the completion callback
+ // may never run https://github.com/grpc/grpc-java/issues/5882
+ try {
+ synchronized (cancellable) {
+ if (!cancelled.isDone()) {
+ // Only cancel if the call is not done on the gRPC side
+ cancellable.cancel("Stream closed before end", /* no exception to report */null);
+ }
+ }
+ // Drain the stream without the lock (as next() implicitly needs the lock)
+ while (next()) { }
+ } catch (FlightRuntimeException e) {
+ suppressor = e;
+ }
+ }
+ // Perform these operations under a lock. This way the observer can't enqueue new messages while we're in the
+ // middle of cleanup. This should only be a concern for server-side streams since client-side streams are drained
+ // by the lambda above.
+ synchronized (completed) {
+ try {
+ if (fulfilledRoot != null) {
+ closeables.add(fulfilledRoot);
+ }
+ closeables.add(applicationMetadata);
+ closeables.addAll(queue);
+ if (dictionaries != null) {
+ dictionaries.getDictionaryIds().forEach(id -> closeables.add(dictionaries.lookup(id).getVector()));
+ }
+ if (suppressor != null) {
+ AutoCloseables.close(suppressor, closeables);
+ } else {
+ AutoCloseables.close(closeables);
+ }
+ } finally {
+ // The value of this CompletableFuture is meaningless, only whether it's completed (or has an exception)
+ // No-op if already complete
+ completed.complete(null);
+ }
+ }
+ }
+
+ /**
+ * Blocking request to load next item into list.
+ * @return Whether or not more data was found.
+ */
+ public boolean next() {
+ try {
+ if (completed.isDone() && queue.isEmpty()) {
+ return false;
+ }
+
+ pending--;
+ requestOutstanding();
+
+ Object data = queue.take();
+ if (DONE == data) {
+ queue.put(DONE);
+ // Other code ignores the value of this CompletableFuture, only whether it's completed (or has an exception)
+ completed.complete(null);
+ return false;
+ } else if (DONE_EX == data) {
+ queue.put(DONE_EX);
+ if (ex instanceof Exception) {
+ throw (Exception) ex;
+ } else {
+ throw new Exception(ex);
+ }
+ } else {
+ try (ArrowMessage msg = ((ArrowMessage) data)) {
+ if (msg.getMessageType() == HeaderType.NONE) {
+ updateMetadata(msg);
+ // We received a message without data, so erase any leftover data
+ if (fulfilledRoot != null) {
+ fulfilledRoot.clear();
+ }
+ } else if (msg.getMessageType() == HeaderType.RECORD_BATCH) {
+ checkMetadataVersion(msg);
+ // Ensure we have the root
+ root.get().clear();
+ try (ArrowRecordBatch arb = msg.asRecordBatch()) {
+ loader.load(arb);
+ }
+ updateMetadata(msg);
+ } else if (msg.getMessageType() == HeaderType.DICTIONARY_BATCH) {
+ checkMetadataVersion(msg);
+ // Ensure we have the root
+ root.get().clear();
+ try (ArrowDictionaryBatch arb = msg.asDictionaryBatch()) {
+ final long id = arb.getDictionaryId();
+ if (dictionaries == null) {
+ throw new IllegalStateException("Dictionary ownership was claimed by the application.");
+ }
+ final Dictionary dictionary = dictionaries.lookup(id);
+ if (dictionary == null) {
+ throw new IllegalArgumentException("Dictionary not defined in schema: ID " + id);
+ }
+
+ final FieldVector vector = dictionary.getVector();
+ final VectorSchemaRoot dictionaryRoot = new VectorSchemaRoot(Collections.singletonList(vector.getField()),
+ Collections.singletonList(vector), 0);
+ final VectorLoader dictionaryLoader = new VectorLoader(dictionaryRoot);
+ dictionaryLoader.load(arb.getDictionary());
+ }
+ return next();
+ } else {
+ throw new UnsupportedOperationException("Message type is unsupported: " + msg.getMessageType());
+ }
+ return true;
+ }
+ }
+ } catch (RuntimeException e) {
+ throw e;
+ } catch (ExecutionException e) {
+ throw StatusUtils.fromThrowable(e.getCause());
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ /** Update our metadata reference with a new one from this message. */
+ private void updateMetadata(ArrowMessage msg) {
+ if (this.applicationMetadata != null) {
+ this.applicationMetadata.close();
+ }
+ this.applicationMetadata = msg.getApplicationMetadata();
+ if (this.applicationMetadata != null) {
+ this.applicationMetadata.getReferenceManager().retain();
+ }
+ }
+
+ /** Ensure the Arrow metadata version doesn't change mid-stream. */
+ private void checkMetadataVersion(ArrowMessage msg) {
+ if (msg.asSchemaMessage() == null) {
+ return;
+ }
+ MetadataVersion receivedVersion = MetadataVersion.fromFlatbufID(msg.asSchemaMessage().getMessage().version());
+ if (this.metadataVersion != receivedVersion) {
+ throw new IllegalStateException("Metadata version mismatch: stream started as " +
+ this.metadataVersion + " but got message with version " + receivedVersion);
+ }
+ }
+
+ /**
+ * Get the current vector data from the stream.
+ *
+ * <p>The data in the root may change at any time. Clients should NOT modify the root, but instead unload the data
+ * into their own root.
+ *
+ * @throws FlightRuntimeException if there was an error reading the schema from the stream.
+ */
+ public VectorSchemaRoot getRoot() {
+ try {
+ return root.get();
+ } catch (InterruptedException e) {
+ throw CallStatus.INTERNAL.withCause(e).toRuntimeException();
+ } catch (ExecutionException e) {
+ throw StatusUtils.fromThrowable(e.getCause());
+ }
+ }
+
+ /**
+ * Check if there is a root (i.e. whether the other end has started sending data).
+ *
+ * Updated by calls to {@link #next()}.
+ *
+ * @return true if and only if the other end has started sending data.
+ */
+ public boolean hasRoot() {
+ return root.isDone();
+ }
+
+ /**
+ * Get the most recent metadata sent from the server. This may be cleared by calls to {@link #next()} if the server
+ * sends a message without metadata. This does NOT take ownership of the buffer - call retain() to create a reference
+ * if you need the buffer after a call to {@link #next()}.
+ *
+ * @return the application metadata. May be null.
+ */
+ public ArrowBuf getLatestMetadata() {
+ return applicationMetadata;
+ }
+
+ private synchronized void requestOutstanding() {
+ if (pending < pendingTarget) {
+ requestor.request(pendingTarget - pending);
+ pending = pendingTarget;
+ }
+ }
+
+ private class Observer implements StreamObserver<ArrowMessage> {
+
+ Observer() {
+ super();
+ }
+
+ /** Helper to add an item to the queue under the appropriate lock. */
+ private void enqueue(AutoCloseable message) {
+ synchronized (completed) {
+ if (completed.isDone()) {
+ // The stream is already closed (RPC ended), discard the message
+ AutoCloseables.closeNoChecked(message);
+ } else {
+ queue.add(message);
+ }
+ }
+ }
+
+ @Override
+ public void onNext(ArrowMessage msg) {
+ // Operations here have to be under a lock so that we don't add a message to the queue while in the middle of
+ // close().
+ requestOutstanding();
+ switch (msg.getMessageType()) {
+ case NONE: {
+ // No IPC message - pure metadata or descriptor
+ if (msg.getDescriptor() != null) {
+ descriptor.set(new FlightDescriptor(msg.getDescriptor()));
+ }
+ if (msg.getApplicationMetadata() != null) {
+ enqueue(msg);
+ }
+ break;
+ }
+ case SCHEMA: {
+ Schema schema = msg.asSchema();
+
+ // if there is app metadata in the schema message, make sure
+ // that we don't leak it.
+ ArrowBuf meta = msg.getApplicationMetadata();
+ if (meta != null) {
+ meta.close();
+ }
+
+ final List<Field> fields = new ArrayList<>();
+ final Map<Long, Dictionary> dictionaryMap = new HashMap<>();
+ for (final Field originalField : schema.getFields()) {
+ final Field updatedField = DictionaryUtility.toMemoryFormat(originalField, allocator, dictionaryMap);
+ fields.add(updatedField);
+ }
+ for (final Map.Entry<Long, Dictionary> entry : dictionaryMap.entrySet()) {
+ dictionaries.put(entry.getValue());
+ }
+ schema = new Schema(fields, schema.getCustomMetadata());
+ metadataVersion = MetadataVersion.fromFlatbufID(msg.asSchemaMessage().getMessage().version());
+ try {
+ MetadataV4UnionChecker.checkRead(schema, metadataVersion);
+ } catch (IOException e) {
+ ex = e;
+ enqueue(DONE_EX);
+ break;
+ }
+
+ synchronized (completed) {
+ if (!completed.isDone()) {
+ fulfilledRoot = VectorSchemaRoot.create(schema, allocator);
+ loader = new VectorLoader(fulfilledRoot);
+ if (msg.getDescriptor() != null) {
+ descriptor.set(new FlightDescriptor(msg.getDescriptor()));
+ }
+ root.set(fulfilledRoot);
+ }
+ }
+ break;
+ }
+ case RECORD_BATCH:
+ case DICTIONARY_BATCH:
+ enqueue(msg);
+ break;
+ case TENSOR:
+ default:
+ ex = new UnsupportedOperationException("Unable to handle message of type: " + msg.getMessageType());
+ enqueue(DONE_EX);
+ }
+ }
+
+ @Override
+ public void onError(Throwable t) {
+ ex = StatusUtils.fromThrowable(t);
+ queue.add(DONE_EX);
+ cancelled.complete(null);
+ root.setException(ex);
+ }
+
+ @Override
+ public void onCompleted() {
+ // Depends on gRPC calling onNext and onCompleted non-concurrently
+ cancelled.complete(null);
+ queue.add(DONE);
+ }
+ }
+
+ /**
+ * Cancels sending the stream to a client.
+ *
+ * <p>Callers should drain the stream (with {@link #next()}) to ensure all messages sent before cancellation are
+ * received and to wait for the underlying transport to acknowledge cancellation.
+ */
+ public void cancel(String message, Throwable exception) {
+ if (cancellable == null) {
+ throw new UnsupportedOperationException("Streams cannot be cancelled that are produced by client. " +
+ "Instead, server should reject incoming messages.");
+ }
+ cancellable.cancel(message, exception);
+ // Do not mark the stream as completed, as gRPC may still be delivering messages.
+ }
+
+ StreamObserver<ArrowMessage> asObserver() {
+ return new Observer();
+ }
+
+ /**
+ * Provides a callback to cancel a process that is in progress.
+ */
+ @FunctionalInterface
+ public interface Cancellable {
+ void cancel(String message, Throwable exception);
+ }
+
+ /**
+ * Provides a interface to request more items from a stream producer.
+ */
+ @FunctionalInterface
+ public interface Requestor {
+ /**
+ * Requests <code>count</code> more messages from the instance of this object.
+ */
+ void request(int count);
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/HeaderCallOption.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/HeaderCallOption.java
new file mode 100644
index 000000000..e2fad1a40
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/HeaderCallOption.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import io.grpc.Metadata;
+import io.grpc.stub.AbstractStub;
+import io.grpc.stub.MetadataUtils;
+
+/**
+ * Method option for supplying headers to method calls.
+ */
+public class HeaderCallOption implements CallOptions.GrpcCallOption {
+ private final Metadata propertiesMetadata = new Metadata();
+
+ /**
+ * Header property constructor.
+ *
+ * @param headers the headers that should be sent across. If a header is a string, it should only be valid ASCII
+ * characters. Binary headers should end in "-bin".
+ */
+ public HeaderCallOption(CallHeaders headers) {
+ for (String key : headers.keys()) {
+ if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
+ final Metadata.Key<byte[]> metaKey = Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER);
+ headers.getAllByte(key).forEach(v -> propertiesMetadata.put(metaKey, v));
+ } else {
+ final Metadata.Key<String> metaKey = Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER);
+ headers.getAll(key).forEach(v -> propertiesMetadata.put(metaKey, v));
+ }
+ }
+ }
+
+ @Override
+ public <T extends AbstractStub<T>> T wrapStub(T stub) {
+ return MetadataUtils.attachHeaders(stub, propertiesMetadata);
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java
new file mode 100644
index 000000000..1fbec7b5a
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.lang.reflect.InvocationTargetException;
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.util.Objects;
+
+import org.apache.arrow.flight.impl.Flight;
+
+/** A URI where a Flight stream is available. */
+public class Location {
+ private final URI uri;
+
+ /**
+ * Constructs a new instance.
+ *
+ * @param uri the URI of the Flight service
+ * @throws IllegalArgumentException if the URI scheme is unsupported
+ */
+ public Location(String uri) throws URISyntaxException {
+ this(new URI(uri));
+ }
+
+ /**
+ * Construct a new instance from an existing URI.
+ *
+ * @param uri the URI of the Flight service
+ */
+ public Location(URI uri) {
+ super();
+ Objects.requireNonNull(uri);
+ this.uri = uri;
+ }
+
+ public URI getUri() {
+ return uri;
+ }
+
+ /**
+ * Helper method to turn this Location into a SocketAddress.
+ *
+ * @return null if could not be converted
+ */
+ SocketAddress toSocketAddress() {
+ switch (uri.getScheme()) {
+ case LocationSchemes.GRPC:
+ case LocationSchemes.GRPC_TLS:
+ case LocationSchemes.GRPC_INSECURE: {
+ return new InetSocketAddress(uri.getHost(), uri.getPort());
+ }
+
+ case LocationSchemes.GRPC_DOMAIN_SOCKET: {
+ try {
+ // This dependency is not available on non-Unix platforms.
+ return (SocketAddress) Class.forName("io.netty.channel.unix.DomainSocketAddress")
+ .getConstructor(String.class)
+ .newInstance(uri.getPath());
+ } catch (InstantiationException | ClassNotFoundException | InvocationTargetException |
+ NoSuchMethodException | IllegalAccessException e) {
+ return null;
+ }
+ }
+
+ default: {
+ return null;
+ }
+ }
+ }
+
+ /**
+ * Convert this Location into its protocol-level representation.
+ */
+ Flight.Location toProtocol() {
+ return Flight.Location.newBuilder().setUri(uri.toString()).build();
+ }
+
+ /**
+ * Construct a URI for a Flight+gRPC server without transport security.
+ *
+ * @throws IllegalArgumentException if the constructed URI is invalid.
+ */
+ public static Location forGrpcInsecure(String host, int port) {
+ try {
+ return new Location(new URI(LocationSchemes.GRPC_INSECURE, null, host, port, null, null, null));
+ } catch (URISyntaxException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
+
+ /**
+ * Construct a URI for a Flight+gRPC server with transport security.
+ *
+ * @throws IllegalArgumentException if the constructed URI is invalid.
+ */
+ public static Location forGrpcTls(String host, int port) {
+ try {
+ return new Location(new URI(LocationSchemes.GRPC_TLS, null, host, port, null, null, null));
+ } catch (URISyntaxException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
+
+ /**
+ * Construct a URI for a Flight+gRPC server over a Unix domain socket.
+ *
+ * @throws IllegalArgumentException if the constructed URI is invalid.
+ */
+ public static Location forGrpcDomainSocket(String path) {
+ try {
+ return new Location(new URI(LocationSchemes.GRPC_DOMAIN_SOCKET, null, path, null));
+ } catch (URISyntaxException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
+
+ @Override
+ public String toString() {
+ return "Location{" +
+ "uri=" + uri +
+ '}';
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ Location location = (Location) o;
+ return uri.equals(location.uri);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(uri);
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/LocationSchemes.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/LocationSchemes.java
new file mode 100644
index 000000000..872e5b1c2
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/LocationSchemes.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+/**
+ * Constants representing well-known URI schemes for Flight services.
+ */
+public final class LocationSchemes {
+ public static final String GRPC = "grpc";
+ public static final String GRPC_INSECURE = "grpc+tcp";
+ public static final String GRPC_DOMAIN_SOCKET = "grpc+unix";
+ public static final String GRPC_TLS = "grpc+tls";
+
+ private LocationSchemes() {
+ throw new AssertionError("Do not instantiate this class.");
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/NoOpFlightProducer.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/NoOpFlightProducer.java
new file mode 100644
index 000000000..d1432f514
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/NoOpFlightProducer.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+/**
+ * A {@link FlightProducer} that throws on all operations.
+ */
+public class NoOpFlightProducer implements FlightProducer {
+
+ @Override
+ public void getStream(CallContext context, Ticket ticket,
+ ServerStreamListener listener) {
+ listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException());
+ }
+
+ @Override
+ public void listFlights(CallContext context, Criteria criteria,
+ StreamListener<FlightInfo> listener) {
+ listener.onError(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException());
+ }
+
+ @Override
+ public FlightInfo getFlightInfo(CallContext context,
+ FlightDescriptor descriptor) {
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException();
+ }
+
+ @Override
+ public Runnable acceptPut(CallContext context,
+ FlightStream flightStream, StreamListener<PutResult> ackStream) {
+ throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException();
+ }
+
+ @Override
+ public void doAction(CallContext context, Action action,
+ StreamListener<Result> listener) {
+ listener.onError(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException());
+ }
+
+ @Override
+ public void listActions(CallContext context,
+ StreamListener<ActionType> listener) {
+ listener.onError(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException());
+ }
+
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/NoOpStreamListener.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/NoOpStreamListener.java
new file mode 100644
index 000000000..e06af1a10
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/NoOpStreamListener.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import org.apache.arrow.flight.FlightProducer.StreamListener;
+
+/**
+ * A {@link StreamListener} that does nothing for all callbacks.
+ * @param <T> The type of the callback object.
+ */
+public class NoOpStreamListener<T> implements StreamListener<T> {
+ private static NoOpStreamListener INSTANCE = new NoOpStreamListener();
+
+ /** Ignores the value received. */
+ @Override
+ public void onNext(T val) {
+ }
+
+ /** Ignores the error received. */
+ @Override
+ public void onError(Throwable t) {
+ }
+
+ /** Ignores the stream completion event. */
+ @Override
+ public void onCompleted() {
+ }
+
+ @SuppressWarnings("unchecked")
+ public static <T> StreamListener<T> getInstance() {
+ // Safe because we never use T
+ return (StreamListener<T>) INSTANCE;
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListener.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListener.java
new file mode 100644
index 000000000..38a44d0e5
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListener.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.ipc.message.IpcOption;
+
+/**
+ * An interface for writing data to a peer, client or server.
+ */
+public interface OutboundStreamListener {
+
+ /**
+ * A hint indicating whether the client is ready to receive data without excessive buffering.
+ *
+ * <p>Writers should poll this flag before sending data to respect backpressure from the client and
+ * avoid sending data faster than the client can handle. Ignoring this flag may mean that the server
+ * will start consuming excessive amounts of memory, as it may buffer messages in memory.
+ */
+ boolean isReady();
+
+ /**
+ * Set a callback for when the listener is ready for new calls to putNext(), i.e. {@link #isReady()}
+ * has become true.
+ *
+ * <p>Note that this callback may only be called some time after {@link #isReady()} becomes true, and may never
+ * be called if all executor threads on the server are busy, or the RPC method body is implemented in a blocking
+ * fashion. Note that isReady() must still be checked after the callback is run as it may have been run
+ * spuriously.
+ */
+ default void setOnReadyHandler(Runnable handler) {
+ throw new UnsupportedOperationException("Not yet implemented.");
+ }
+
+ /**
+ * Start sending data, using the schema of the given {@link VectorSchemaRoot}.
+ *
+ * <p>This method must be called before all others, except {@link #putMetadata(ArrowBuf)}.
+ */
+ default void start(VectorSchemaRoot root) {
+ start(root, null, IpcOption.DEFAULT);
+ }
+
+ /**
+ * Start sending data, using the schema of the given {@link VectorSchemaRoot}.
+ *
+ * <p>This method must be called before all others, except {@link #putMetadata(ArrowBuf)}.
+ */
+ default void start(VectorSchemaRoot root, DictionaryProvider dictionaries) {
+ start(root, dictionaries, IpcOption.DEFAULT);
+ }
+
+ /**
+ * Start sending data, using the schema of the given {@link VectorSchemaRoot}.
+ *
+ * <p>This method must be called before all others, except {@link #putMetadata(ArrowBuf)}.
+ */
+ void start(VectorSchemaRoot root, DictionaryProvider dictionaries, IpcOption option);
+
+ /**
+ * Send the current contents of the associated {@link VectorSchemaRoot}.
+ *
+ * <p>This will not necessarily block until the message is actually sent; it may buffer messages
+ * in memory. Use {@link #isReady()} to check if there is backpressure and avoid excessive buffering.
+ */
+ void putNext();
+
+ /**
+ * Send the current contents of the associated {@link VectorSchemaRoot} alongside application-defined metadata.
+ * @param metadata The metadata to send. Ownership of the buffer is transferred to the Flight implementation.
+ */
+ void putNext(ArrowBuf metadata);
+
+ /**
+ * Send a pure metadata message without any associated data.
+ *
+ * <p>This may be called without starting the stream.
+ */
+ void putMetadata(ArrowBuf metadata);
+
+ /**
+ * Indicate an error to the client. Terminates the stream; do not call {@link #completed()} afterwards.
+ */
+ void error(Throwable ex);
+
+ /**
+ * Indicate that transmission is finished.
+ */
+ void completed();
+
+ /**
+ * Toggle whether to ues the zero-copy write optimization.
+ *
+ * <p>By default or when disabled, Arrow may copy data into a buffer for the underlying implementation to
+ * send. When enabled, Arrow will instead try to directly enqueue the Arrow buffer for sending. Not all
+ * implementations support this optimization, so even if enabled, you may not see a difference.
+ *
+ * <p>In this mode, buffers must not be reused after they are written with {@link #putNext()}. For example,
+ * you would have to call {@link VectorSchemaRoot#allocateNew()} after every call to {@link #putNext()}.
+ * Hence, this is not enabled by default.
+ *
+ * <p>The default value can be toggled globally by setting the JVM property arrow.flight.enable_zero_copy_write
+ * or the environment variable ARROW_FLIGHT_ENABLE_ZERO_COPY_WRITE.
+ */
+ default void setUseZeroCopy(boolean enabled) {}
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListenerImpl.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListenerImpl.java
new file mode 100644
index 000000000..8c1cfde3a
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListenerImpl.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import org.apache.arrow.flight.grpc.StatusUtils;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.util.Preconditions;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.VectorUnloader;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.ipc.message.IpcOption;
+
+import io.grpc.stub.CallStreamObserver;
+
+/**
+ * A base class for writing Arrow data to a Flight stream.
+ */
+abstract class OutboundStreamListenerImpl implements OutboundStreamListener {
+ private final FlightDescriptor descriptor; // nullable
+ protected final CallStreamObserver<ArrowMessage> responseObserver;
+ protected volatile VectorUnloader unloader; // null until stream started
+ protected IpcOption option; // null until stream started
+ protected boolean tryZeroCopy = ArrowMessage.ENABLE_ZERO_COPY_WRITE;
+
+ OutboundStreamListenerImpl(FlightDescriptor descriptor, CallStreamObserver<ArrowMessage> responseObserver) {
+ Preconditions.checkNotNull(responseObserver, "responseObserver must be provided");
+ this.descriptor = descriptor;
+ this.responseObserver = responseObserver;
+ this.unloader = null;
+ }
+
+ @Override
+ public boolean isReady() {
+ return responseObserver.isReady();
+ }
+
+ @Override
+ public void setOnReadyHandler(Runnable handler) {
+ responseObserver.setOnReadyHandler(handler);
+ }
+
+ @Override
+ public void start(VectorSchemaRoot root, DictionaryProvider dictionaries, IpcOption option) {
+ this.option = option;
+ try {
+ DictionaryUtils.generateSchemaMessages(root.getSchema(), descriptor, dictionaries, option,
+ responseObserver::onNext);
+ } catch (RuntimeException e) {
+ // Propagate runtime exceptions, like those raised when trying to write unions with V4 metadata
+ throw e;
+ } catch (Exception e) {
+ // Only happens if closing buffers somehow fails - indicates application is an unknown state so propagate
+ // the exception
+ throw new RuntimeException("Could not generate and send all schema messages", e);
+ }
+ // We include the null count and align buffers to be compatible with Flight/C++
+ unloader = new VectorUnloader(root, /* includeNullCount */ true, /* alignBuffers */ true);
+ }
+
+ @Override
+ public void putNext() {
+ putNext(null);
+ }
+
+ /**
+ * Busy-wait until the stream is ready.
+ *
+ * <p>This is overridable as client/server have different behavior.
+ */
+ protected abstract void waitUntilStreamReady();
+
+ @Override
+ public void putNext(ArrowBuf metadata) {
+ if (unloader == null) {
+ throw CallStatus.INTERNAL.withDescription("Stream was not started, call start()").toRuntimeException();
+ }
+
+ waitUntilStreamReady();
+ // close is a no-op if the message has been written to gRPC, otherwise frees the associated buffers
+ // in some code paths (e.g. if the call is cancelled), gRPC does not write the message, so we need to clean up
+ // ourselves. Normally, writing the ArrowMessage will transfer ownership of the data to gRPC/Netty.
+ try (final ArrowMessage message = new ArrowMessage(unloader.getRecordBatch(), metadata, tryZeroCopy, option)) {
+ responseObserver.onNext(message);
+ } catch (Exception e) {
+ // This exception comes from ArrowMessage#close, not responseObserver#onNext.
+ // Generally this should not happen - ArrowMessage's implementation only closes non-throwing things.
+ // The user can't reasonably do anything about this, but if something does throw, we shouldn't let
+ // execution continue since other state (e.g. allocators) may be in an odd state.
+ throw new RuntimeException("Could not free ArrowMessage", e);
+ }
+ }
+
+ @Override
+ public void putMetadata(ArrowBuf metadata) {
+ waitUntilStreamReady();
+ try (final ArrowMessage message = new ArrowMessage(metadata)) {
+ responseObserver.onNext(message);
+ } catch (Exception e) {
+ throw StatusUtils.fromThrowable(e);
+ }
+ }
+
+ @Override
+ public void error(Throwable ex) {
+ responseObserver.onError(StatusUtils.toGrpcException(ex));
+ }
+
+ @Override
+ public void completed() {
+ responseObserver.onCompleted();
+ }
+
+ @Override
+ public void setUseZeroCopy(boolean enabled) {
+ tryZeroCopy = enabled;
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/PutResult.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/PutResult.java
new file mode 100644
index 000000000..862401312
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/PutResult.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import org.apache.arrow.flight.impl.Flight;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.ReferenceManager;
+
+import com.google.protobuf.ByteString;
+
+/**
+ * A message from the server during a DoPut operation.
+ *
+ * <p>This object owns an {@link ArrowBuf} and should be closed when you are done with it.
+ */
+public class PutResult implements AutoCloseable {
+
+ private ArrowBuf applicationMetadata;
+
+ private PutResult(ArrowBuf metadata) {
+ applicationMetadata = metadata;
+ }
+
+ /**
+ * Create a PutResult with application-specific metadata.
+ *
+ * <p>This method assumes ownership of the {@link ArrowBuf}.
+ */
+ public static PutResult metadata(ArrowBuf metadata) {
+ if (metadata == null) {
+ return empty();
+ }
+ return new PutResult(metadata);
+ }
+
+ /** Create an empty PutResult. */
+ public static PutResult empty() {
+ return new PutResult(null);
+ }
+
+ /**
+ * Get the metadata in this message. May be null.
+ *
+ * <p>Ownership of the {@link ArrowBuf} is retained by this object. Call {@link ReferenceManager#retain()} to preserve
+ * a reference.
+ */
+ public ArrowBuf getApplicationMetadata() {
+ return applicationMetadata;
+ }
+
+ Flight.PutResult toProtocol() {
+ if (applicationMetadata == null) {
+ return Flight.PutResult.getDefaultInstance();
+ }
+ return Flight.PutResult.newBuilder().setAppMetadata(ByteString.copyFrom(applicationMetadata.nioBuffer())).build();
+ }
+
+ /**
+ * Construct a PutResult from a Protobuf message.
+ *
+ * @param allocator The allocator to use for allocating application metadata memory. The result object owns the
+ * allocated buffer, if any.
+ * @param message The gRPC/Protobuf message.
+ */
+ static PutResult fromProtocol(BufferAllocator allocator, Flight.PutResult message) {
+ final ArrowBuf buf = allocator.buffer(message.getAppMetadata().size());
+ message.getAppMetadata().asReadOnlyByteBufferList().forEach(bb -> {
+ buf.setBytes(buf.writerIndex(), bb);
+ buf.writerIndex(buf.writerIndex() + bb.limit());
+ });
+ return new PutResult(buf);
+ }
+
+ @Override
+ public void close() {
+ if (applicationMetadata != null) {
+ applicationMetadata.close();
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/RequestContext.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/RequestContext.java
new file mode 100644
index 000000000..5117d05c2
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/RequestContext.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.Set;
+
+/**
+ * Tracks variables about the current request.
+ */
+public interface RequestContext {
+ /**
+ * Register a variable and a value.
+ * @param key the variable name.
+ * @param value the value.
+ */
+ void put(String key, String value);
+
+ /**
+ * Retrieve a registered variable.
+ * @param key the variable name.
+ * @return the value, or null if not found.
+ */
+ String get(String key);
+
+ /**
+ * Retrieves the keys that have been registered to this context.
+ * @return the keys used in this context.
+ */
+ Set<String> keySet();
+
+ /**
+ * Deletes a registered variable.
+ * @return the value associated with the deleted variable, or null if the key doesn't exist.
+ */
+ String remove(String key);
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Result.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Result.java
new file mode 100644
index 000000000..5d6ce485d
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Result.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import org.apache.arrow.flight.impl.Flight;
+
+import com.google.protobuf.ByteString;
+
+/**
+ * Opaque result returned after executing an action.
+ *
+ * <p>POJO wrapper around the Flight protocol buffer message sharing the same name.
+ */
+public class Result {
+
+ private final byte[] body;
+
+ public Result(byte[] body) {
+ this.body = body;
+ }
+
+ Result(Flight.Result result) {
+ this.body = result.getBody().toByteArray();
+ }
+
+ public byte[] getBody() {
+ return body;
+ }
+
+ Flight.Result toProtocol() {
+ return Flight.Result.newBuilder()
+ .setBody(ByteString.copyFrom(body))
+ .build();
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java
new file mode 100644
index 000000000..8a5e7d9a4
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.Channels;
+
+import org.apache.arrow.flight.impl.Flight;
+import org.apache.arrow.vector.ipc.ReadChannel;
+import org.apache.arrow.vector.ipc.WriteChannel;
+import org.apache.arrow.vector.ipc.message.IpcOption;
+import org.apache.arrow.vector.ipc.message.MessageSerializer;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.arrow.vector.validate.MetadataV4UnionChecker;
+
+import com.fasterxml.jackson.databind.util.ByteBufferBackedInputStream;
+import com.google.common.collect.ImmutableList;
+import com.google.protobuf.ByteString;
+
+/**
+ * Opaque result returned after executing a getSchema request.
+ *
+ * <p>POJO wrapper around the Flight protocol buffer message sharing the same name.
+ */
+public class SchemaResult {
+
+ private final Schema schema;
+ private final IpcOption option;
+
+ public SchemaResult(Schema schema) {
+ this(schema, IpcOption.DEFAULT);
+ }
+
+ /**
+ * Create a schema result with specific IPC options for serialization.
+ */
+ public SchemaResult(Schema schema, IpcOption option) {
+ MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option.metadataVersion);
+ this.schema = schema;
+ this.option = option;
+ }
+
+ public Schema getSchema() {
+ return schema;
+ }
+
+ /**
+ * Converts to the protocol buffer representation.
+ */
+ Flight.SchemaResult toProtocol() {
+ // Encode schema in a Message payload
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ try {
+ MessageSerializer.serialize(new WriteChannel(Channels.newChannel(baos)), schema, option);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ return Flight.SchemaResult.newBuilder()
+ .setSchema(ByteString.copyFrom(baos.toByteArray()))
+ .build();
+
+ }
+
+ /**
+ * Converts from the protocol buffer representation.
+ */
+ static SchemaResult fromProtocol(Flight.SchemaResult pbSchemaResult) {
+ try {
+ final ByteBuffer schemaBuf = pbSchemaResult.getSchema().asReadOnlyByteBuffer();
+ Schema schema = pbSchemaResult.getSchema().size() > 0 ?
+ MessageSerializer.deserializeSchema(
+ new ReadChannel(Channels.newChannel(new ByteBufferBackedInputStream(schemaBuf))))
+ : new Schema(ImmutableList.of());
+ return new SchemaResult(schema);
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ServerHeaderMiddleware.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ServerHeaderMiddleware.java
new file mode 100644
index 000000000..527c3128c
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ServerHeaderMiddleware.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+/**
+ * Middleware that's used to extract and pass headers to the server during requests.
+ */
+public class ServerHeaderMiddleware implements FlightServerMiddleware {
+ /**
+ * Factory for accessing ServerHeaderMiddleware.
+ */
+ public static class Factory implements FlightServerMiddleware.Factory<ServerHeaderMiddleware> {
+ /**
+ * Construct a factory for receiving call headers.
+ */
+ public Factory() {
+ }
+
+ @Override
+ public ServerHeaderMiddleware onCallStarted(CallInfo callInfo, CallHeaders incomingHeaders,
+ RequestContext context) {
+ return new ServerHeaderMiddleware(incomingHeaders);
+ }
+ }
+
+ private final CallHeaders headers;
+
+ private ServerHeaderMiddleware(CallHeaders incomingHeaders) {
+ this.headers = incomingHeaders;
+ }
+
+ /**
+ * Retrieve the headers for this call.
+ */
+ public CallHeaders headers() {
+ return headers;
+ }
+
+ @Override
+ public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
+ }
+
+ @Override
+ public void onCallCompleted(CallStatus status) {
+ }
+
+ @Override
+ public void onCallErrored(Throwable err) {
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/StreamPipe.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/StreamPipe.java
new file mode 100644
index 000000000..d506914d5
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/StreamPipe.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.function.Consumer;
+import java.util.function.Function;
+
+import org.apache.arrow.flight.FlightProducer.StreamListener;
+import org.apache.arrow.flight.grpc.StatusUtils;
+import org.apache.arrow.util.AutoCloseables;
+
+import io.grpc.stub.StreamObserver;
+
+/**
+ * Shim listener to avoid exposing GRPC internals.
+
+ * @param <FROM> From Type
+ * @param <TO> To Type
+ */
+class StreamPipe<FROM, TO> implements StreamListener<FROM> {
+
+ private final StreamObserver<TO> delegate;
+ private final Function<FROM, TO> mapFunction;
+ private final Consumer<Throwable> errorHandler;
+ private AutoCloseable resource;
+ private boolean closed = false;
+
+ /**
+ * Wrap the given gRPC StreamObserver with a transformation function.
+ *
+ * @param delegate The {@link StreamObserver} to wrap.
+ * @param func The transformation function.
+ * @param errorHandler A handler for uncaught exceptions (e.g. if something tries to double-close this stream).
+ * @param <FROM> The source type.
+ * @param <TO> The output type.
+ * @return A wrapped listener.
+ */
+ public static <FROM, TO> StreamPipe<FROM, TO> wrap(StreamObserver<TO> delegate, Function<FROM, TO> func,
+ Consumer<Throwable> errorHandler) {
+ return new StreamPipe<>(delegate, func, errorHandler);
+ }
+
+ public StreamPipe(StreamObserver<TO> delegate, Function<FROM, TO> func, Consumer<Throwable> errorHandler) {
+ super();
+ this.delegate = delegate;
+ this.mapFunction = func;
+ this.errorHandler = errorHandler;
+ this.resource = null;
+ }
+
+ /** Set an AutoCloseable resource to be cleaned up when the gRPC observer is to be completed. */
+ void setAutoCloseable(AutoCloseable ac) {
+ resource = ac;
+ }
+
+ @Override
+ public void onNext(FROM val) {
+ delegate.onNext(mapFunction.apply(val));
+ }
+
+ @Override
+ public void onError(Throwable t) {
+ if (closed) {
+ errorHandler.accept(t);
+ return;
+ }
+ try {
+ AutoCloseables.close(resource);
+ } catch (Exception e) {
+ errorHandler.accept(e);
+ } finally {
+ // Set closed to true in case onError throws, so that we don't try to close again
+ closed = true;
+ delegate.onError(StatusUtils.toGrpcException(t));
+ }
+ }
+
+ @Override
+ public void onCompleted() {
+ if (closed) {
+ errorHandler.accept(new IllegalStateException("Tried to complete already-completed call"));
+ return;
+ }
+ try {
+ AutoCloseables.close(resource);
+ } catch (Exception e) {
+ errorHandler.accept(e);
+ } finally {
+ // Set closed to true in case onCompleted throws, so that we don't try to close again
+ closed = true;
+ delegate.onCompleted();
+ }
+ }
+
+ /**
+ * Ensure this stream has been completed.
+ */
+ void ensureCompleted() {
+ if (!closed) {
+ onCompleted();
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SyncPutListener.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SyncPutListener.java
new file mode 100644
index 000000000..730cf4924
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SyncPutListener.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.arrow.flight.grpc.StatusUtils;
+import org.apache.arrow.memory.ArrowBuf;
+
+/**
+ * A listener for server-sent application metadata messages during a Flight DoPut. This class wraps the messages in a
+ * synchronous interface.
+ */
+public final class SyncPutListener implements FlightClient.PutListener, AutoCloseable {
+
+ private final LinkedBlockingQueue<Object> queue;
+ private final CompletableFuture<Void> completed;
+ private static final Object DONE = new Object();
+ private static final Object DONE_WITH_EXCEPTION = new Object();
+
+ public SyncPutListener() {
+ queue = new LinkedBlockingQueue<>();
+ completed = new CompletableFuture<>();
+ }
+
+ private PutResult unwrap(Object queueItem) throws InterruptedException, ExecutionException {
+ if (queueItem == DONE) {
+ queue.put(queueItem);
+ return null;
+ } else if (queueItem == DONE_WITH_EXCEPTION) {
+ queue.put(queueItem);
+ completed.get();
+ }
+ return (PutResult) queueItem;
+ }
+
+ /**
+ * Get the next message from the server, blocking until it is available.
+ *
+ * @return The next message, or null if the server is done sending messages. The caller assumes ownership of the
+ * metadata and must remember to close it.
+ * @throws InterruptedException if interrupted while waiting.
+ * @throws ExecutionException if the server sent an error, or if there was an internal error.
+ */
+ public PutResult read() throws InterruptedException, ExecutionException {
+ return unwrap(queue.take());
+ }
+
+ /**
+ * Get the next message from the server, blocking for the specified amount of time until it is available.
+ *
+ * @return The next message, or null if the server is done sending messages or no message arrived before the timeout.
+ * The caller assumes ownership of the metadata and must remember to close it.
+ * @throws InterruptedException if interrupted while waiting.
+ * @throws ExecutionException if the server sent an error, or if there was an internal error.
+ */
+ public PutResult poll(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException {
+ return unwrap(queue.poll(timeout, unit));
+ }
+
+ @Override
+ public void getResult() {
+ try {
+ completed.get();
+ } catch (ExecutionException e) {
+ throw StatusUtils.fromThrowable(e.getCause());
+ } catch (InterruptedException e) {
+ throw StatusUtils.fromThrowable(e);
+ }
+ }
+
+ @Override
+ public void onNext(PutResult val) {
+ final ArrowBuf metadata = val.getApplicationMetadata();
+ metadata.getReferenceManager().retain();
+ queue.add(PutResult.metadata(metadata));
+ }
+
+ @Override
+ public void onError(Throwable t) {
+ completed.completeExceptionally(StatusUtils.fromThrowable(t));
+ queue.add(DONE_WITH_EXCEPTION);
+ }
+
+ @Override
+ public void onCompleted() {
+ completed.complete(null);
+ queue.add(DONE);
+ }
+
+ @Override
+ public void close() {
+ queue.forEach(o -> {
+ if (o instanceof PutResult) {
+ ((PutResult) o).close();
+ }
+ });
+ }
+
+ @Override
+ public boolean isCancelled() {
+ return completed.isDone();
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Ticket.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Ticket.java
new file mode 100644
index 000000000..a93cd0879
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Ticket.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+
+import org.apache.arrow.flight.impl.Flight;
+
+import com.google.protobuf.ByteString;
+
+/**
+ * Endpoint for a particular stream.
+ */
+public class Ticket {
+ private final byte[] bytes;
+
+ public Ticket(byte[] bytes) {
+ super();
+ this.bytes = bytes;
+ }
+
+ public byte[] getBytes() {
+ return bytes;
+ }
+
+ Ticket(org.apache.arrow.flight.impl.Flight.Ticket ticket) {
+ this.bytes = ticket.getTicket().toByteArray();
+ }
+
+ Flight.Ticket toProtocol() {
+ return Flight.Ticket.newBuilder()
+ .setTicket(ByteString.copyFrom(bytes))
+ .build();
+ }
+
+ /**
+ * Get the serialized form of this protocol message.
+ *
+ * <p>Intended to help interoperability by allowing non-Flight services to still return Flight types.
+ */
+ public ByteBuffer serialize() {
+ return ByteBuffer.wrap(toProtocol().toByteArray());
+ }
+
+ /**
+ * Parse the serialized form of this protocol message.
+ *
+ * <p>Intended to help interoperability by allowing Flight clients to obtain stream info from non-Flight services.
+ *
+ * @param serialized The serialized form of the Ticket, as returned by {@link #serialize()}.
+ * @return The deserialized Ticket.
+ * @throws IOException if the serialized form is invalid.
+ */
+ public static Ticket deserialize(ByteBuffer serialized) throws IOException {
+ return new Ticket(Flight.Ticket.parseFrom(serialized));
+ }
+
+ @Override
+ public int hashCode() {
+ final int prime = 31;
+ int result = 1;
+ result = prime * result + Arrays.hashCode(bytes);
+ return result;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (obj == null) {
+ return false;
+ }
+ if (getClass() != obj.getClass()) {
+ return false;
+ }
+ Ticket other = (Ticket) obj;
+ if (!Arrays.equals(bytes, other.bytes)) {
+ return false;
+ }
+ return true;
+ }
+
+
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/AuthConstants.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/AuthConstants.java
new file mode 100644
index 000000000..ac55872e5
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/AuthConstants.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth;
+
+import org.apache.arrow.flight.FlightConstants;
+
+import io.grpc.Context;
+import io.grpc.Metadata.BinaryMarshaller;
+import io.grpc.Metadata.Key;
+import io.grpc.MethodDescriptor;
+
+/**
+ * Constants used in authorization of flight connections.
+ */
+public final class AuthConstants {
+
+ public static final String HANDSHAKE_DESCRIPTOR_NAME = MethodDescriptor
+ .generateFullMethodName(FlightConstants.SERVICE, "Handshake");
+ public static final String TOKEN_NAME = "Auth-Token-bin";
+ public static final Key<byte[]> TOKEN_KEY = Key.of(TOKEN_NAME, new BinaryMarshaller<byte[]>() {
+
+ @Override
+ public byte[] toBytes(byte[] value) {
+ return value;
+ }
+
+ @Override
+ public byte[] parseBytes(byte[] serialized) {
+ return serialized;
+ }
+ });
+
+ public static final Context.Key<String> PEER_IDENTITY_KEY = Context.keyWithDefault("arrow-flight-peer-identity", "");
+
+ private AuthConstants() {}
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/BasicClientAuthHandler.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/BasicClientAuthHandler.java
new file mode 100644
index 000000000..c6dca97fb
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/BasicClientAuthHandler.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth;
+
+import java.util.Iterator;
+
+import org.apache.arrow.flight.impl.Flight.BasicAuth;
+
+/**
+ * A client auth handler that supports username and password.
+ */
+public class BasicClientAuthHandler implements ClientAuthHandler {
+
+ private final String name;
+ private final String password;
+ private byte[] token = null;
+
+ public BasicClientAuthHandler(String name, String password) {
+ this.name = name;
+ this.password = password;
+ }
+
+ @Override
+ public void authenticate(ClientAuthSender outgoing, Iterator<byte[]> incoming) {
+ BasicAuth.Builder builder = BasicAuth.newBuilder();
+ if (name != null) {
+ builder.setUsername(name);
+ }
+
+ if (password != null) {
+ builder.setPassword(password);
+ }
+
+ outgoing.send(builder.build().toByteArray());
+ this.token = incoming.next();
+ }
+
+ @Override
+ public byte[] getCallToken() {
+ return token;
+ }
+
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/BasicServerAuthHandler.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/BasicServerAuthHandler.java
new file mode 100644
index 000000000..34e3efc0d
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/BasicServerAuthHandler.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth;
+
+import java.util.Iterator;
+import java.util.Optional;
+
+import org.apache.arrow.flight.impl.Flight.BasicAuth;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.protobuf.InvalidProtocolBufferException;
+
+/**
+ * A ServerAuthHandler for username/password authentication.
+ */
+public class BasicServerAuthHandler implements ServerAuthHandler {
+
+ private static final Logger logger = LoggerFactory.getLogger(BasicServerAuthHandler.class);
+ private final BasicAuthValidator authValidator;
+
+ public BasicServerAuthHandler(BasicAuthValidator authValidator) {
+ super();
+ this.authValidator = authValidator;
+ }
+
+ /**
+ * Interface that this handler delegates for determining if credentials are valid.
+ */
+ public interface BasicAuthValidator {
+
+ byte[] getToken(String username, String password) throws Exception;
+
+ Optional<String> isValid(byte[] token);
+
+ }
+
+ @Override
+ public boolean authenticate(ServerAuthSender outgoing, Iterator<byte[]> incoming) {
+ byte[] bytes = incoming.next();
+ try {
+ BasicAuth auth = BasicAuth.parseFrom(bytes);
+ byte[] token = authValidator.getToken(auth.getUsername(), auth.getPassword());
+ outgoing.send(token);
+ return true;
+ } catch (InvalidProtocolBufferException e) {
+ logger.debug("Failure parsing auth message.", e);
+ } catch (Exception e) {
+ logger.debug("Unknown error during authorization.", e);
+ }
+
+ return false;
+ }
+
+ @Override
+ public Optional<String> isValid(byte[] token) {
+ return authValidator.isValid(token);
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ClientAuthHandler.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ClientAuthHandler.java
new file mode 100644
index 000000000..985e10aa4
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ClientAuthHandler.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth;
+
+import java.util.Iterator;
+
+/**
+ * Implement authentication for Flight on the client side.
+ */
+public interface ClientAuthHandler {
+ /**
+ * Handle the initial handshake with the server.
+ * @param outgoing A channel to send data to the server.
+ * @param incoming An iterator of incoming data from the server.
+ */
+ void authenticate(ClientAuthSender outgoing, Iterator<byte[]> incoming);
+
+ /**
+ * Get the per-call authentication token.
+ */
+ byte[] getCallToken();
+
+ /**
+ * A communication channel to the server during initial connection.
+ */
+ interface ClientAuthSender {
+
+ /**
+ * Send the server a message.
+ */
+ void send(byte[] payload);
+
+ /**
+ * Signal an error to the server and abort the authentication attempt.
+ */
+ void onError(Throwable cause);
+
+ }
+
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ClientAuthInterceptor.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ClientAuthInterceptor.java
new file mode 100644
index 000000000..3d28b7ba7
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ClientAuthInterceptor.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth;
+
+import io.grpc.CallOptions;
+import io.grpc.Channel;
+import io.grpc.ClientCall;
+import io.grpc.ClientInterceptor;
+import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
+import io.grpc.Metadata;
+import io.grpc.MethodDescriptor;
+
+/**
+ * GRPC client intercepter that handles authentication with the server.
+ */
+public class ClientAuthInterceptor implements ClientInterceptor {
+ private volatile ClientAuthHandler authHandler = null;
+
+ public void setAuthHandler(ClientAuthHandler authHandler) {
+ this.authHandler = authHandler;
+ }
+
+ public ClientAuthInterceptor() {
+ }
+
+ public boolean hasAuthHandler() {
+ return authHandler != null;
+ }
+
+ @Override
+ public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> methodDescriptor,
+ CallOptions callOptions, Channel next) {
+ ClientCall<ReqT, RespT> call = next.newCall(methodDescriptor, callOptions);
+
+ // once we have an auth header, add that to the calls.
+ if (authHandler != null) {
+ call = new HeaderAttachingClientCall<>(call);
+ }
+
+ return call;
+ }
+
+ private final class HeaderAttachingClientCall<ReqT, RespT> extends SimpleForwardingClientCall<ReqT, RespT> {
+
+ private HeaderAttachingClientCall(ClientCall<ReqT, RespT> call) {
+ super(call);
+ }
+
+ @Override
+ public void start(Listener<RespT> responseListener, Metadata headers) {
+ final Metadata authHeaders = new Metadata();
+ authHeaders.put(AuthConstants.TOKEN_KEY, authHandler.getCallToken());
+ headers.merge(authHeaders);
+ super.start(responseListener, headers);
+ }
+ }
+
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java
new file mode 100644
index 000000000..e86dc163c
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth;
+
+import java.util.Iterator;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.LinkedBlockingQueue;
+
+import org.apache.arrow.flight.auth.ClientAuthHandler.ClientAuthSender;
+import org.apache.arrow.flight.grpc.StatusUtils;
+import org.apache.arrow.flight.impl.Flight.HandshakeRequest;
+import org.apache.arrow.flight.impl.Flight.HandshakeResponse;
+import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceStub;
+
+import com.google.protobuf.ByteString;
+
+import io.grpc.StatusRuntimeException;
+import io.grpc.stub.StreamObserver;
+
+/**
+ * Utility class for performing authorization over using a GRPC stub.
+ */
+public class ClientAuthWrapper {
+
+ /**
+ * Do client auth for a client. The stub will be authenticated after this method returns.
+ *
+ * @param authHandler The handler to use.
+ * @param stub The service stub.
+ */
+ public static void doClientAuth(ClientAuthHandler authHandler, FlightServiceStub stub) {
+ AuthObserver observer = new AuthObserver();
+ try {
+ observer.responseObserver = stub.handshake(observer);
+ authHandler.authenticate(observer.sender, observer.iter);
+ if (!observer.sender.errored) {
+ observer.responseObserver.onCompleted();
+ }
+ } catch (StatusRuntimeException sre) {
+ throw StatusUtils.fromGrpcRuntimeException(sre);
+ }
+ try {
+ if (!observer.completed.get()) {
+ // TODO: ARROW-5681
+ throw new RuntimeException("Unauthenticated");
+ }
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ } catch (ExecutionException e) {
+ throw StatusUtils.fromThrowable(e.getCause());
+ }
+ }
+
+ private static class AuthObserver implements StreamObserver<HandshakeResponse> {
+
+ private volatile StreamObserver<HandshakeRequest> responseObserver;
+ private final LinkedBlockingQueue<byte[]> messages = new LinkedBlockingQueue<>();
+ private final AuthSender sender = new AuthSender();
+ private CompletableFuture<Boolean> completed;
+
+ public AuthObserver() {
+ super();
+ completed = new CompletableFuture<>();
+ }
+
+ @Override
+ public void onNext(HandshakeResponse value) {
+ ByteString payload = value.getPayload();
+ if (payload != null) {
+ messages.add(payload.toByteArray());
+ }
+ }
+
+ private Iterator<byte[]> iter = new Iterator<byte[]>() {
+
+ @Override
+ public byte[] next() {
+ while (!completed.isDone() || !messages.isEmpty()) {
+ byte[] bytes = messages.poll();
+ if (bytes == null) {
+ // busy wait.
+ continue;
+ } else {
+ return bytes;
+ }
+ }
+
+ if (completed.isCompletedExceptionally()) {
+ // Preserve prior exception behavior
+ // TODO: with ARROW-5681, throw an appropriate Flight exception if gRPC raised an exception
+ try {
+ completed.get();
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ } catch (ExecutionException e) {
+ if (e.getCause() instanceof StatusRuntimeException) {
+ throw (StatusRuntimeException) e.getCause();
+ }
+ throw new RuntimeException(e);
+ }
+ }
+
+ throw new IllegalStateException("You attempted to retrieve messages after there were none.");
+ }
+
+ @Override
+ public boolean hasNext() {
+ return !messages.isEmpty();
+ }
+ };
+
+ @Override
+ public void onError(Throwable t) {
+ completed.completeExceptionally(t);
+ }
+
+ private class AuthSender implements ClientAuthSender {
+
+ private boolean errored = false;
+
+ @Override
+ public void send(byte[] payload) {
+ try {
+ responseObserver.onNext(HandshakeRequest.newBuilder()
+ .setPayload(ByteString.copyFrom(payload))
+ .build());
+ } catch (StatusRuntimeException sre) {
+ throw StatusUtils.fromGrpcRuntimeException(sre);
+ }
+ }
+
+ @Override
+ public void onError(Throwable cause) {
+ this.errored = true;
+ responseObserver.onError(StatusUtils.toGrpcException(cause));
+ }
+
+ }
+
+ @Override
+ public void onCompleted() {
+ completed.complete(true);
+ }
+ }
+
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthHandler.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthHandler.java
new file mode 100644
index 000000000..3a978b131
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthHandler.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth;
+
+import java.util.Iterator;
+import java.util.Optional;
+
+/**
+ * Interface for Server side authentication handlers.
+ */
+public interface ServerAuthHandler {
+
+ /**
+ * Validate the client token provided on each call.
+ *
+ * @return An empty optional if the client is not authenticated; the peer identity otherwise (may be the empty
+ * string).
+ */
+ Optional<String> isValid(byte[] token);
+
+ /**
+ * Handle the initial handshake with the client.
+ *
+ * @param outgoing A writer to send messages to the client.
+ * @param incoming An iterator of messages from the client.
+ * @return true if client is authenticated, false otherwise.
+ */
+ boolean authenticate(ServerAuthSender outgoing, Iterator<byte[]> incoming);
+
+ /**
+ * Interface for a server implementations to send back authentication messages
+ * back to the client.
+ */
+ interface ServerAuthSender {
+
+ void send(byte[] payload);
+
+ void onError(Throwable cause);
+
+ }
+
+ /**
+ * An auth handler that does nothing.
+ */
+ ServerAuthHandler NO_OP = new ServerAuthHandler() {
+
+ @Override
+ public Optional<String> isValid(byte[] token) {
+ return Optional.of("");
+ }
+
+ @Override
+ public boolean authenticate(ServerAuthSender outgoing, Iterator<byte[]> incoming) {
+ return true;
+ }
+ };
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java
new file mode 100644
index 000000000..5bff3784e
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthInterceptor.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth;
+
+import java.util.Optional;
+
+import org.apache.arrow.flight.FlightRuntimeException;
+import org.apache.arrow.flight.grpc.StatusUtils;
+
+import io.grpc.Context;
+import io.grpc.Contexts;
+import io.grpc.Metadata;
+import io.grpc.ServerCall;
+import io.grpc.ServerCall.Listener;
+import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptor;
+import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
+
+/**
+ * GRPC Interceptor for performing authentication.
+ */
+public class ServerAuthInterceptor implements ServerInterceptor {
+
+ private final ServerAuthHandler authHandler;
+
+ public ServerAuthInterceptor(ServerAuthHandler authHandler) {
+ this.authHandler = authHandler;
+ }
+
+ @Override
+ public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers,
+ ServerCallHandler<ReqT, RespT> next) {
+ if (!call.getMethodDescriptor().getFullMethodName().equals(AuthConstants.HANDSHAKE_DESCRIPTOR_NAME)) {
+ final Optional<String> peerIdentity;
+
+ // Allow customizing the response code by throwing FlightRuntimeException
+ try {
+ peerIdentity = isValid(headers);
+ } catch (FlightRuntimeException e) {
+ final Status grpcStatus = StatusUtils.toGrpcStatus(e.status());
+ call.close(grpcStatus, new Metadata());
+ return new NoopServerCallListener<>();
+ } catch (StatusRuntimeException e) {
+ Metadata trailers = e.getTrailers();
+ call.close(e.getStatus(), trailers == null ? new Metadata() : trailers);
+ return new NoopServerCallListener<>();
+ }
+
+ if (!peerIdentity.isPresent()) {
+ // Send back a description along with the status code
+ call.close(Status.UNAUTHENTICATED
+ .withDescription("Unauthenticated (invalid or missing auth token)"), new Metadata());
+ return new NoopServerCallListener<>();
+ }
+ return Contexts.interceptCall(Context.current().withValue(AuthConstants.PEER_IDENTITY_KEY, peerIdentity.get()),
+ call, headers, next);
+ }
+
+ return next.startCall(call, headers);
+ }
+
+ private Optional<String> isValid(Metadata headers) {
+ byte[] token = headers.get(AuthConstants.TOKEN_KEY);
+ return authHandler.isValid(token);
+ }
+
+ private static class NoopServerCallListener<T> extends ServerCall.Listener<T> {
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java
new file mode 100644
index 000000000..ad1a36a93
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth;
+
+import java.util.Iterator;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+import java.util.concurrent.LinkedBlockingQueue;
+
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.auth.ServerAuthHandler.ServerAuthSender;
+import org.apache.arrow.flight.grpc.StatusUtils;
+import org.apache.arrow.flight.impl.Flight.HandshakeRequest;
+import org.apache.arrow.flight.impl.Flight.HandshakeResponse;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.protobuf.ByteString;
+
+import io.grpc.stub.StreamObserver;
+
+/**
+ * Contains utility methods for integrating authorization into a GRPC stream.
+ */
+public class ServerAuthWrapper {
+ private static final Logger LOGGER = LoggerFactory.getLogger(ServerAuthWrapper.class);
+
+ /**
+ * Wrap the auth handler for handshake purposes.
+ *
+ * @param authHandler Authentication handler
+ * @param responseObserver Observer for handshake response
+ * @param executors ExecutorService
+ * @return AuthObserver
+ */
+ public static StreamObserver<HandshakeRequest> wrapHandshake(ServerAuthHandler authHandler,
+ StreamObserver<HandshakeResponse> responseObserver, ExecutorService executors) {
+
+ // stream started.
+ AuthObserver observer = new AuthObserver(responseObserver);
+ final Runnable r = () -> {
+ try {
+ if (authHandler.authenticate(observer.sender, observer.iter)) {
+ responseObserver.onCompleted();
+ return;
+ }
+
+ responseObserver.onError(StatusUtils.toGrpcException(CallStatus.UNAUTHENTICATED.toRuntimeException()));
+ } catch (Exception ex) {
+ LOGGER.error("Error during authentication", ex);
+ responseObserver.onError(StatusUtils.toGrpcException(ex));
+ }
+ };
+ observer.future = executors.submit(r);
+ return observer;
+ }
+
+ private static class AuthObserver implements StreamObserver<HandshakeRequest> {
+
+ private final StreamObserver<HandshakeResponse> responseObserver;
+ private volatile Future<?> future;
+ private volatile boolean completed = false;
+ private final LinkedBlockingQueue<byte[]> messages = new LinkedBlockingQueue<>();
+ private final AuthSender sender = new AuthSender();
+
+ public AuthObserver(StreamObserver<HandshakeResponse> responseObserver) {
+ super();
+ this.responseObserver = responseObserver;
+ }
+
+ @Override
+ public void onNext(HandshakeRequest value) {
+ ByteString payload = value.getPayload();
+ if (payload != null) {
+ messages.add(payload.toByteArray());
+ }
+ }
+
+ private Iterator<byte[]> iter = new Iterator<byte[]>() {
+
+ @Override
+ public byte[] next() {
+ while (!completed || !messages.isEmpty()) {
+ byte[] bytes = messages.poll();
+ if (bytes == null) {
+ //busy wait.
+ continue;
+ }
+ return bytes;
+ }
+ throw new IllegalStateException("Requesting more messages than client sent.");
+ }
+
+ @Override
+ public boolean hasNext() {
+ return !messages.isEmpty();
+ }
+ };
+
+ @Override
+ public void onError(Throwable t) {
+ completed = true;
+ while (future == null) {/* busy wait */}
+ future.cancel(true);
+ }
+
+ @Override
+ public void onCompleted() {
+ completed = true;
+ }
+
+ private class AuthSender implements ServerAuthSender {
+
+ @Override
+ public void send(byte[] payload) {
+ responseObserver.onNext(HandshakeResponse.newBuilder()
+ .setPayload(ByteString.copyFrom(payload))
+ .build());
+ }
+
+ @Override
+ public void onError(Throwable cause) {
+ responseObserver.onError(StatusUtils.toGrpcException(cause));
+ }
+
+ }
+ }
+
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/Auth2Constants.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/Auth2Constants.java
new file mode 100644
index 000000000..624d7d5ff
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/Auth2Constants.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth2;
+
+/**
+ * Constants used in authorization of flight connections.
+ */
+public final class Auth2Constants {
+ public static final String PEER_IDENTITY_KEY = "arrow-flight-peer-identity";
+ public static final String BEARER_PREFIX = "Bearer ";
+ public static final String BASIC_PREFIX = "Basic ";
+ public static final String AUTHORIZATION_HEADER = "Authorization";
+
+ private Auth2Constants() {
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/AuthUtilities.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/AuthUtilities.java
new file mode 100644
index 000000000..c73b7cf1a
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/AuthUtilities.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth2;
+
+import org.apache.arrow.flight.CallHeaders;
+
+/**
+ * Utility class for completing the auth process.
+ */
+public final class AuthUtilities {
+ private AuthUtilities() {
+
+ }
+
+ /**
+ * Helper method for retrieving a value from the Authorization header.
+ *
+ * @param headers The headers to inspect.
+ * @param valuePrefix The prefix within the value portion of the header to extract away.
+ * @return The header value.
+ */
+ public static String getValueFromAuthHeader(CallHeaders headers, String valuePrefix) {
+ final String authHeaderValue = headers.get(Auth2Constants.AUTHORIZATION_HEADER);
+ if (authHeaderValue != null) {
+ if (authHeaderValue.regionMatches(true, 0, valuePrefix, 0, valuePrefix.length())) {
+ return authHeaderValue.substring(valuePrefix.length());
+ }
+ }
+ return null;
+ }
+
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BasicAuthCredentialWriter.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BasicAuthCredentialWriter.java
new file mode 100644
index 000000000..698287e88
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BasicAuthCredentialWriter.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth2;
+
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
+import java.util.function.Consumer;
+
+import org.apache.arrow.flight.CallHeaders;
+
+/**
+ * Client credentials that use a username and password.
+ */
+public final class BasicAuthCredentialWriter implements Consumer<CallHeaders> {
+
+ private final String name;
+ private final String password;
+
+ public BasicAuthCredentialWriter(String name, String password) {
+ this.name = name;
+ this.password = password;
+ }
+
+ @Override
+ public void accept(CallHeaders outputHeaders) {
+ outputHeaders.insert(Auth2Constants.AUTHORIZATION_HEADER, Auth2Constants.BASIC_PREFIX +
+ Base64.getEncoder().encodeToString(String.format("%s:%s", name, password).getBytes(StandardCharsets.UTF_8)));
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BasicCallHeaderAuthenticator.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BasicCallHeaderAuthenticator.java
new file mode 100644
index 000000000..fff7b4690
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BasicCallHeaderAuthenticator.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth2;
+
+import java.io.UnsupportedEncodingException;
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
+
+import org.apache.arrow.flight.CallHeaders;
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.FlightRuntimeException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A ServerAuthHandler for username/password authentication.
+ */
+public class BasicCallHeaderAuthenticator implements CallHeaderAuthenticator {
+
+ private static final Logger logger = LoggerFactory.getLogger(BasicCallHeaderAuthenticator.class);
+
+ private final CredentialValidator authValidator;
+
+ public BasicCallHeaderAuthenticator(CredentialValidator authValidator) {
+ this.authValidator = authValidator;
+ }
+
+ @Override
+ public AuthResult authenticate(CallHeaders incomingHeaders) {
+ try {
+ final String authEncoded = AuthUtilities.getValueFromAuthHeader(
+ incomingHeaders, Auth2Constants.BASIC_PREFIX);
+ if (authEncoded == null) {
+ throw CallStatus.UNAUTHENTICATED.toRuntimeException();
+ }
+ // The value has the format Base64(<username>:<password>)
+ final String authDecoded = new String(Base64.getDecoder().decode(authEncoded), StandardCharsets.UTF_8);
+ final int colonPos = authDecoded.indexOf(':');
+ if (colonPos == -1) {
+ throw CallStatus.UNAUTHENTICATED.toRuntimeException();
+ }
+
+ final String user = authDecoded.substring(0, colonPos);
+ final String password = authDecoded.substring(colonPos + 1);
+ return authValidator.validate(user, password);
+ } catch (UnsupportedEncodingException ex) {
+ // Note: Intentionally discarding the exception cause when reporting back to the client for security purposes.
+ logger.error("Authentication failed due to missing encoding.", ex);
+ throw CallStatus.INTERNAL.toRuntimeException();
+ } catch (FlightRuntimeException ex) {
+ throw ex;
+ } catch (Exception ex) {
+ // Note: Intentionally discarding the exception cause when reporting back to the client for security purposes.
+ logger.error("Authentication failed.", ex);
+ throw CallStatus.UNAUTHENTICATED.toRuntimeException();
+ }
+ }
+
+ /**
+ * Interface that this handler delegates to for validating the incoming headers.
+ */
+ public interface CredentialValidator {
+ /**
+ * Validate the supplied credentials (username/password) and return the peer identity.
+ *
+ * @param username The username to validate.
+ * @param password The password to validate.
+ * @return The peer identity if the supplied credentials are valid.
+ * @throws Exception If the supplied credentials are not valid.
+ */
+ AuthResult validate(String username, String password) throws Exception;
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BearerCredentialWriter.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BearerCredentialWriter.java
new file mode 100644
index 000000000..715ee502b
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BearerCredentialWriter.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth2;
+
+import java.util.function.Consumer;
+
+import org.apache.arrow.flight.CallHeaders;
+
+/**
+ * Client credentials that use a bearer token.
+ */
+public final class BearerCredentialWriter implements Consumer<CallHeaders> {
+
+ private final String bearer;
+
+ public BearerCredentialWriter(String bearer) {
+ this.bearer = bearer;
+ }
+
+ @Override
+ public void accept(CallHeaders outputHeaders) {
+ outputHeaders.insert(Auth2Constants.AUTHORIZATION_HEADER, Auth2Constants.BEARER_PREFIX + bearer);
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BearerTokenAuthenticator.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BearerTokenAuthenticator.java
new file mode 100644
index 000000000..2006e0a2b
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BearerTokenAuthenticator.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth2;
+
+import org.apache.arrow.flight.CallHeaders;
+
+/**
+ * Partial implementation of {@link CallHeaderAuthenticator} for bearer-token based authentication.
+ */
+public abstract class BearerTokenAuthenticator implements CallHeaderAuthenticator {
+
+ final CallHeaderAuthenticator initialAuthenticator;
+
+ public BearerTokenAuthenticator(CallHeaderAuthenticator initialAuthenticator) {
+ this.initialAuthenticator = initialAuthenticator;
+ }
+
+ @Override
+ public AuthResult authenticate(CallHeaders incomingHeaders) {
+ // Check if headers contain a bearer token and if so, validate the token.
+ final String bearerToken =
+ AuthUtilities.getValueFromAuthHeader(incomingHeaders, Auth2Constants.BEARER_PREFIX);
+ if (bearerToken != null) {
+ return validateBearer(bearerToken);
+ }
+
+ // Delegate to the basic auth handler to do the validation.
+ final CallHeaderAuthenticator.AuthResult result = initialAuthenticator.authenticate(incomingHeaders);
+ return getAuthResultWithBearerToken(result);
+ }
+
+ /**
+ * Callback to run when the initial authenticator succeeds.
+ * @param authResult A successful initial authentication result.
+ * @return an alternate AuthResult based on the original AuthResult that will write a bearer token to output headers.
+ */
+ protected abstract AuthResult getAuthResultWithBearerToken(AuthResult authResult);
+
+ /**
+ * Validate the bearer token.
+ * @param bearerToken The bearer token to validate.
+ * @return A successful AuthResult if validation succeeded.
+ * @throws Exception If the token validation fails.
+ */
+ protected abstract AuthResult validateBearer(String bearerToken);
+
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/CallHeaderAuthenticator.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/CallHeaderAuthenticator.java
new file mode 100644
index 000000000..87e60f1fa
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/CallHeaderAuthenticator.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth2;
+
+import org.apache.arrow.flight.CallHeaders;
+import org.apache.arrow.flight.FlightRuntimeException;
+
+/**
+ * Interface for Server side authentication handlers.
+ *
+ * A CallHeaderAuthenticator is used by {@link ServerCallHeaderAuthMiddleware} to validate headers sent by a Flight
+ * client for authentication purposes. The headers validated do not necessarily have to be Authorization headers.
+ *
+ * The workflow is that the FlightServer will intercept headers on a request, validate the headers, and
+ * either send back an UNAUTHENTICATED error, or succeed and potentially send back additional headers to the client.
+ *
+ * Implementations of CallHeaderAuthenticator should take care not to provide leak confidential details (such as
+ * indicating if usernames are valid or not) for security reasons when reporting errors back to clients.
+ *
+ * Example CallHeaderAuthenticators provided include:
+ * The {@link BasicCallHeaderAuthenticator} will authenticate basic HTTP credentials.
+ *
+ * The {@link BearerTokenAuthenticator} will authenticate basic HTTP credentials initially, then also send back a
+ * bearer token that the client can use for subsequent requests. The {@link GeneratedBearerTokenAuthenticator} will
+ * provide internally generated bearer tokens and maintain a cache of them.
+ */
+public interface CallHeaderAuthenticator {
+
+ /**
+ * Encapsulates the result of the {@link CallHeaderAuthenticator} analysis of headers.
+ *
+ * This includes the identity of the incoming user and any outbound headers to send as a response to the client.
+ */
+ interface AuthResult {
+ /**
+ * The peer identity that was determined by the handshake process based on the
+ * authentication credentials supplied by the client.
+ *
+ * @return The peer identity.
+ */
+ String getPeerIdentity();
+
+ /**
+ * Appends a header to the outgoing call headers.
+ * @param outgoingHeaders The outgoing headers.
+ */
+ default void appendToOutgoingHeaders(CallHeaders outgoingHeaders) {
+
+ }
+ }
+
+ /**
+ * Validate the auth headers sent by the client.
+ *
+ * @param incomingHeaders The incoming headers to authenticate.
+ * @return an auth result containing a peer identity and optionally a bearer token.
+ * @throws FlightRuntimeException with CallStatus.UNAUTHENTICATED if credentials were not supplied
+ * or if credentials were supplied but were not valid.
+ */
+ AuthResult authenticate(CallHeaders incomingHeaders);
+
+ /**
+ * An auth handler that does nothing.
+ */
+ CallHeaderAuthenticator NO_OP = new CallHeaderAuthenticator() {
+ @Override
+ public AuthResult authenticate(CallHeaders incomingHeaders) {
+ return () -> "";
+ }
+ };
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientBearerHeaderHandler.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientBearerHeaderHandler.java
new file mode 100644
index 000000000..45bdb6d95
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientBearerHeaderHandler.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth2;
+
+import org.apache.arrow.flight.CallHeaders;
+import org.apache.arrow.flight.grpc.CredentialCallOption;
+
+/**
+ * A client header handler that parses the incoming headers for a bearer token.
+ */
+public class ClientBearerHeaderHandler implements ClientHeaderHandler {
+
+ @Override
+ public CredentialCallOption getCredentialCallOptionFromIncomingHeaders(CallHeaders incomingHeaders) {
+ final String bearerValue = AuthUtilities.getValueFromAuthHeader(incomingHeaders, Auth2Constants.BEARER_PREFIX);
+ if (bearerValue != null) {
+ return new CredentialCallOption(new BearerCredentialWriter(bearerValue));
+ }
+ return null;
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientHandshakeWrapper.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientHandshakeWrapper.java
new file mode 100644
index 000000000..16a514250
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientHandshakeWrapper.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth2;
+
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.FlightRuntimeException;
+import org.apache.arrow.flight.grpc.StatusUtils;
+import org.apache.arrow.flight.impl.Flight.HandshakeRequest;
+import org.apache.arrow.flight.impl.Flight.HandshakeResponse;
+import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceStub;
+
+import io.grpc.StatusRuntimeException;
+import io.grpc.stub.StreamObserver;
+
+/**
+ * Utility class for executing a handshake with a FlightServer.
+ */
+public class ClientHandshakeWrapper {
+ private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(ClientHandshakeWrapper.class);
+
+ /**
+ * Do handshake for a client. The stub will be authenticated after this method returns.
+ *
+ * @param stub The service stub.
+ */
+ public static void doClientHandshake(FlightServiceStub stub) {
+ final HandshakeObserver observer = new HandshakeObserver();
+ try {
+ observer.requestObserver = stub.handshake(observer);
+ observer.requestObserver.onNext(HandshakeRequest.newBuilder().build());
+ observer.requestObserver.onCompleted();
+ try {
+ if (!observer.completed.get()) {
+ // TODO: ARROW-5681
+ throw CallStatus.UNAUTHENTICATED.toRuntimeException();
+ }
+ } catch (InterruptedException ex) {
+ Thread.currentThread().interrupt();
+ throw ex;
+ } catch (ExecutionException ex) {
+ final FlightRuntimeException wrappedException = StatusUtils.fromThrowable(ex.getCause());
+ logger.error("Failed on completing future", wrappedException);
+ throw wrappedException;
+ }
+ } catch (StatusRuntimeException sre) {
+ logger.error("Failed with SREe", sre);
+ throw StatusUtils.fromGrpcRuntimeException(sre);
+ } catch (Throwable ex) {
+ logger.error("Failed with unknown", ex);
+ if (ex instanceof FlightRuntimeException) {
+ throw (FlightRuntimeException) ex;
+ }
+ throw StatusUtils.fromThrowable(ex);
+ }
+ }
+
+ private static class HandshakeObserver implements StreamObserver<HandshakeResponse> {
+
+ private volatile StreamObserver<HandshakeRequest> requestObserver;
+ private final CompletableFuture<Boolean> completed;
+
+ public HandshakeObserver() {
+ super();
+ completed = new CompletableFuture<>();
+ }
+
+ @Override
+ public void onNext(HandshakeResponse value) {
+ }
+
+ @Override
+ public void onError(Throwable t) {
+ completed.completeExceptionally(t);
+ }
+
+ @Override
+ public void onCompleted() {
+ completed.complete(true);
+ }
+ }
+
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientHeaderHandler.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientHeaderHandler.java
new file mode 100644
index 000000000..514189f9b
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientHeaderHandler.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth2;
+
+import org.apache.arrow.flight.CallHeaders;
+import org.apache.arrow.flight.grpc.CredentialCallOption;
+
+/**
+ * Interface for client side header parsing and conversion to CredentialCallOption.
+ */
+public interface ClientHeaderHandler {
+ /**
+ * Parses the incoming headers and converts them into a CredentialCallOption.
+ * @param incomingHeaders Incoming headers to parse.
+ * @return An instance of CredentialCallOption.
+ */
+ CredentialCallOption getCredentialCallOptionFromIncomingHeaders(CallHeaders incomingHeaders);
+
+ /**
+ * An client header handler that does nothing.
+ */
+ ClientHeaderHandler NO_OP = new ClientHeaderHandler() {
+ @Override
+ public CredentialCallOption getCredentialCallOptionFromIncomingHeaders(CallHeaders incomingHeaders) {
+ return null;
+ }
+ };
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientIncomingAuthHeaderMiddleware.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientIncomingAuthHeaderMiddleware.java
new file mode 100644
index 000000000..be5f3f54d
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientIncomingAuthHeaderMiddleware.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth2;
+
+import org.apache.arrow.flight.CallHeaders;
+import org.apache.arrow.flight.CallInfo;
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.FlightClientMiddleware;
+import org.apache.arrow.flight.grpc.CredentialCallOption;
+
+/**
+ * Middleware for capturing bearer tokens sent back from the Flight server.
+ */
+public class ClientIncomingAuthHeaderMiddleware implements FlightClientMiddleware {
+ private final Factory factory;
+
+ /**
+ * Factory used within FlightClient.
+ */
+ public static class Factory implements FlightClientMiddleware.Factory {
+ private final ClientHeaderHandler headerHandler;
+ private CredentialCallOption credentialCallOption;
+
+ /**
+ * Construct a factory with the given header handler.
+ * @param headerHandler The header handler that will be used for handling incoming headers from the flight server.
+ */
+ public Factory(ClientHeaderHandler headerHandler) {
+ this.headerHandler = headerHandler;
+ }
+
+ @Override
+ public FlightClientMiddleware onCallStarted(CallInfo info) {
+ return new ClientIncomingAuthHeaderMiddleware(this);
+ }
+
+ void setCredentialCallOption(CredentialCallOption callOption) {
+ this.credentialCallOption = callOption;
+ }
+
+ public CredentialCallOption getCredentialCallOption() {
+ return credentialCallOption;
+ }
+ }
+
+ private ClientIncomingAuthHeaderMiddleware(Factory factory) {
+ this.factory = factory;
+ }
+
+ @Override
+ public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
+ }
+
+ @Override
+ public void onHeadersReceived(CallHeaders incomingHeaders) {
+ factory.setCredentialCallOption(
+ factory.headerHandler.getCredentialCallOptionFromIncomingHeaders(incomingHeaders));
+ }
+
+ @Override
+ public void onCallCompleted(CallStatus status) {
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/GeneratedBearerTokenAuthenticator.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/GeneratedBearerTokenAuthenticator.java
new file mode 100644
index 000000000..8b312b6b7
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/GeneratedBearerTokenAuthenticator.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth2;
+
+import java.nio.ByteBuffer;
+import java.util.Base64;
+import java.util.UUID;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.arrow.flight.CallHeaders;
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.grpc.MetadataAdapter;
+
+import com.google.common.base.Strings;
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
+
+import io.grpc.Metadata;
+
+/**
+ * Generates and caches bearer tokens from user credentials.
+ */
+public class GeneratedBearerTokenAuthenticator extends BearerTokenAuthenticator {
+ private final Cache<String, String> bearerToIdentityCache;
+
+ /**
+ * Generate bearer tokens for the given basic call authenticator.
+ * @param authenticator The authenticator to initial validate inputs with.
+ */
+ public GeneratedBearerTokenAuthenticator(CallHeaderAuthenticator authenticator) {
+ this(authenticator, CacheBuilder.newBuilder().expireAfterAccess(2, TimeUnit.HOURS));
+ }
+
+ /**
+ * Generate bearer tokens for the given basic call authenticator.
+ * @param authenticator The authenticator to initial validate inputs with.
+ * @param timeoutMinutes The time before tokens expire after being accessed.
+ */
+ public GeneratedBearerTokenAuthenticator(CallHeaderAuthenticator authenticator, int timeoutMinutes) {
+ this(authenticator, CacheBuilder.newBuilder().expireAfterAccess(timeoutMinutes, TimeUnit.MINUTES));
+ }
+
+ /**
+ * Generate bearer tokens for the given basic call authenticator.
+ * @param authenticator The authenticator to initial validate inputs with.
+ * @param cacheBuilder The configuration of the cache of bearer tokens.
+ */
+ public GeneratedBearerTokenAuthenticator(CallHeaderAuthenticator authenticator,
+ CacheBuilder<Object, Object> cacheBuilder) {
+ super(authenticator);
+ bearerToIdentityCache = cacheBuilder.build();
+ }
+
+ @Override
+ protected AuthResult validateBearer(String bearerToken) {
+ final String peerIdentity = bearerToIdentityCache.getIfPresent(bearerToken);
+ if (peerIdentity == null) {
+ throw CallStatus.UNAUTHENTICATED.toRuntimeException();
+ }
+
+ return new AuthResult() {
+ @Override
+ public String getPeerIdentity() {
+ return peerIdentity;
+ }
+
+ @Override
+ public void appendToOutgoingHeaders(CallHeaders outgoingHeaders) {
+ if (null == AuthUtilities.getValueFromAuthHeader(outgoingHeaders, Auth2Constants.BEARER_PREFIX)) {
+ outgoingHeaders.insert(Auth2Constants.AUTHORIZATION_HEADER, Auth2Constants.BEARER_PREFIX + bearerToken);
+ }
+ }
+ };
+ }
+
+ @Override
+ protected AuthResult getAuthResultWithBearerToken(AuthResult authResult) {
+ // We generate a dummy header and call appendToOutgoingHeaders with it.
+ // We then inspect the dummy header and parse the bearer token if present in the header
+ // and generate a new bearer token if a bearer token is not present in the header.
+ final CallHeaders dummyHeaders = new MetadataAdapter(new Metadata());
+ authResult.appendToOutgoingHeaders(dummyHeaders);
+ String bearerToken =
+ AuthUtilities.getValueFromAuthHeader(dummyHeaders, Auth2Constants.BEARER_PREFIX);
+ final AuthResult authResultWithBearerToken;
+ if (Strings.isNullOrEmpty(bearerToken)) {
+ // Generate a new bearer token and return an AuthResult that can write it.
+ final UUID uuid = UUID.randomUUID();
+ final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[16]);
+ byteBuffer.putLong(uuid.getMostSignificantBits());
+ byteBuffer.putLong(uuid.getLeastSignificantBits());
+ final String newToken = Base64.getEncoder().encodeToString(byteBuffer.array());
+ bearerToken = newToken;
+ authResultWithBearerToken = new AuthResult() {
+ @Override
+ public String getPeerIdentity() {
+ return authResult.getPeerIdentity();
+ }
+
+ @Override
+ public void appendToOutgoingHeaders(CallHeaders outgoingHeaders) {
+ authResult.appendToOutgoingHeaders(outgoingHeaders);
+ outgoingHeaders.insert(Auth2Constants.AUTHORIZATION_HEADER, Auth2Constants.BEARER_PREFIX + newToken);
+ }
+ };
+ } else {
+ // Use the bearer token supplied by the original auth result.
+ authResultWithBearerToken = authResult;
+ }
+ bearerToIdentityCache.put(bearerToken, authResult.getPeerIdentity());
+ return authResultWithBearerToken;
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ServerCallHeaderAuthMiddleware.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ServerCallHeaderAuthMiddleware.java
new file mode 100644
index 000000000..9bfa73818
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ServerCallHeaderAuthMiddleware.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth2;
+
+import static org.apache.arrow.flight.auth2.CallHeaderAuthenticator.AuthResult;
+
+import org.apache.arrow.flight.CallHeaders;
+import org.apache.arrow.flight.CallInfo;
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.FlightServerMiddleware;
+import org.apache.arrow.flight.RequestContext;
+
+/**
+ * Middleware that's used to validate credentials during the handshake and verify
+ * the bearer token in subsequent requests.
+ */
+public class ServerCallHeaderAuthMiddleware implements FlightServerMiddleware {
+ /**
+ * Factory for accessing ServerAuthMiddleware.
+ */
+ public static class Factory implements FlightServerMiddleware.Factory<ServerCallHeaderAuthMiddleware> {
+ private final CallHeaderAuthenticator authHandler;
+
+ /**
+ * Construct a factory with the given auth handler.
+ * @param authHandler The auth handler what will be used for authenticating requests.
+ */
+ public Factory(CallHeaderAuthenticator authHandler) {
+ this.authHandler = authHandler;
+ }
+
+ @Override
+ public ServerCallHeaderAuthMiddleware onCallStarted(CallInfo callInfo, CallHeaders incomingHeaders,
+ RequestContext context) {
+ final AuthResult result = authHandler.authenticate(incomingHeaders);
+ context.put(Auth2Constants.PEER_IDENTITY_KEY, result.getPeerIdentity());
+ return new ServerCallHeaderAuthMiddleware(result);
+ }
+ }
+
+ private final AuthResult authResult;
+
+ public ServerCallHeaderAuthMiddleware(AuthResult authResult) {
+ this.authResult = authResult;
+ }
+
+ @Override
+ public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
+ authResult.appendToOutgoingHeaders(outgoingHeaders);
+ }
+
+ @Override
+ public void onCallCompleted(CallStatus status) {
+ }
+
+ @Override
+ public void onCallErrored(Throwable err) {
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/client/ClientCookieMiddleware.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/client/ClientCookieMiddleware.java
new file mode 100644
index 000000000..56f24e101
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/client/ClientCookieMiddleware.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.client;
+
+import java.net.HttpCookie;
+import java.util.List;
+import java.util.Locale;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.stream.Collectors;
+
+import org.apache.arrow.flight.CallHeaders;
+import org.apache.arrow.flight.CallInfo;
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.FlightClientMiddleware;
+import org.apache.arrow.util.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A client middleware for receiving and sending cookie information.
+ * Note that this class will not persist permanent cookies beyond the lifetime
+ * of this session.
+ *
+ * This middleware will automatically remove cookies that have expired.
+ * <b>Note</b>: Negative max-age values currently do not get marked as expired due to
+ * a JDK issue. Use max-age=0 to explicitly remove an existing cookie.
+ */
+public class ClientCookieMiddleware implements FlightClientMiddleware {
+ private static final Logger LOGGER = LoggerFactory.getLogger(ClientCookieMiddleware.class);
+
+ private static final String SET_COOKIE_HEADER = "Set-Cookie";
+ private static final String COOKIE_HEADER = "Cookie";
+
+ private final Factory factory;
+
+ @VisibleForTesting
+ ClientCookieMiddleware(Factory factory) {
+ this.factory = factory;
+ }
+
+ /**
+ * Factory used within FlightClient.
+ */
+ public static class Factory implements FlightClientMiddleware.Factory {
+ // Use a map to track the most recent version of a cookie from the server.
+ // Note that cookie names are case-sensitive (but header names aren't).
+ private ConcurrentMap<String, HttpCookie> cookies = new ConcurrentHashMap<>();
+
+ @Override
+ public ClientCookieMiddleware onCallStarted(CallInfo info) {
+ return new ClientCookieMiddleware(this);
+ }
+
+ private void updateCookies(Iterable<String> newCookieHeaderValues) {
+ // Note: Intentionally overwrite existing cookie values.
+ // A cookie defined once will continue to be used in all subsequent
+ // requests on the client instance. The server can send the same cookie again
+ // with a different value and the client will use the new value in future requests.
+ // The server can also update a cookie to have an Expiry in the past or negative age
+ // to signal that the client should stop using the cookie immediately.
+ newCookieHeaderValues.forEach(headerValue -> {
+ try {
+ final List<HttpCookie> parsedCookies = HttpCookie.parse(headerValue);
+ parsedCookies.forEach(parsedCookie -> {
+ final String cookieNameLc = parsedCookie.getName().toLowerCase(Locale.ENGLISH);
+ if (parsedCookie.hasExpired()) {
+ cookies.remove(cookieNameLc);
+ } else {
+ cookies.put(parsedCookie.getName().toLowerCase(Locale.ENGLISH), parsedCookie);
+ }
+ });
+ } catch (IllegalArgumentException ex) {
+ LOGGER.warn("Skipping incorrectly formatted Set-Cookie header with value '{}'.", headerValue);
+ }
+ });
+ }
+ }
+
+ @Override
+ public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
+ final String cookieValue = getValidCookiesAsString();
+ if (!cookieValue.isEmpty()) {
+ outgoingHeaders.insert(COOKIE_HEADER, cookieValue);
+ }
+ }
+
+ @Override
+ public void onHeadersReceived(CallHeaders incomingHeaders) {
+ final Iterable<String> setCookieHeaders = incomingHeaders.getAll(SET_COOKIE_HEADER);
+ if (setCookieHeaders != null) {
+ factory.updateCookies(setCookieHeaders);
+ }
+ }
+
+ @Override
+ public void onCallCompleted(CallStatus status) {
+
+ }
+
+ /**
+ * Discards expired cookies and returns the valid cookies as a String delimited by ';'.
+ */
+ @VisibleForTesting
+ String getValidCookiesAsString() {
+ // Discard expired cookies.
+ factory.cookies.entrySet().removeIf(cookieEntry -> cookieEntry.getValue().hasExpired());
+
+ // Cookie header value format:
+ // [<cookie-name1>=<cookie-value1>; <cookie-name2>=<cookie-value2; ...]
+ return factory.cookies.entrySet().stream()
+ .map(cookie -> cookie.getValue().toString())
+ .collect(Collectors.joining("; "));
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java
new file mode 100644
index 000000000..528c227df
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.example;
+
+import java.io.IOException;
+
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.util.AutoCloseables;
+
+/**
+ * An Example Flight Server that provides access to the InMemoryStore. Used for integration testing.
+ */
+public class ExampleFlightServer implements AutoCloseable {
+
+ private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(ExampleFlightServer.class);
+
+ private final FlightServer flightServer;
+ private final Location location;
+ private final BufferAllocator allocator;
+ private final InMemoryStore mem;
+
+ /**
+ * Constructs a new instance using Allocator for allocating buffer storage that binds
+ * to the given location.
+ */
+ public ExampleFlightServer(BufferAllocator allocator, Location location) {
+ this.allocator = allocator.newChildAllocator("flight-server", 0, Long.MAX_VALUE);
+ this.location = location;
+ this.mem = new InMemoryStore(this.allocator, location);
+ this.flightServer = FlightServer.builder(allocator, location, mem).build();
+ }
+
+ public Location getLocation() {
+ return location;
+ }
+
+ public int getPort() {
+ return this.flightServer.getPort();
+ }
+
+ public void start() throws IOException {
+ flightServer.start();
+ }
+
+ public void awaitTermination() throws InterruptedException {
+ flightServer.awaitTermination();
+ }
+
+ public InMemoryStore getStore() {
+ return mem;
+ }
+
+ @Override
+ public void close() throws Exception {
+ AutoCloseables.close(mem, flightServer, allocator);
+ }
+
+ /**
+ * Main method starts the server listening to localhost:12233.
+ */
+ public static void main(String[] args) throws Exception {
+ final BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
+ final ExampleFlightServer efs = new ExampleFlightServer(a, Location.forGrpcInsecure("localhost", 12233));
+ efs.start();
+ Runtime.getRuntime().addShutdownHook(new Thread(() -> {
+ try {
+ System.out.println("\nExiting...");
+ AutoCloseables.close(efs, a);
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }));
+ efs.awaitTermination();
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/ExampleTicket.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/ExampleTicket.java
new file mode 100644
index 000000000..e15ecd034
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/ExampleTicket.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.example;
+
+import java.io.IOException;
+import java.util.List;
+
+import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.util.Preconditions;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.annotation.JsonSerialize;
+import com.google.common.base.Throwables;
+
+/**
+ * POJO object used to demonstrate how an opaque ticket can be generated.
+ */
+@JsonSerialize
+public class ExampleTicket {
+
+ private static final ObjectMapper MAPPER = new ObjectMapper();
+
+ private final List<String> path;
+ private final int ordinal;
+
+ // uuid to ensure that a stream from one node is not recreated on another node and mixed up.
+ private final String uuid;
+
+ /**
+ * Constructs a new instance.
+ *
+ * @param path Path to data
+ * @param ordinal A counter for the stream.
+ * @param uuid A unique identifier for this particular stream.
+ */
+ @JsonCreator
+ public ExampleTicket(@JsonProperty("path") List<String> path, @JsonProperty("ordinal") int ordinal,
+ @JsonProperty("uuid") String uuid) {
+ super();
+ Preconditions.checkArgument(ordinal >= 0);
+ this.path = path;
+ this.ordinal = ordinal;
+ this.uuid = uuid;
+ }
+
+ public List<String> getPath() {
+ return path;
+ }
+
+ public int getOrdinal() {
+ return ordinal;
+ }
+
+ public String getUuid() {
+ return uuid;
+ }
+
+ /**
+ * Deserializes a new instance from the protocol buffer ticket.
+ */
+ public static ExampleTicket from(Ticket ticket) {
+ try {
+ return MAPPER.readValue(ticket.getBytes(), ExampleTicket.class);
+ } catch (IOException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ /**
+ * Creates a new protocol buffer Ticket by serializing to JSON.
+ */
+ public Ticket toTicket() {
+ try {
+ return new Ticket(MAPPER.writeValueAsBytes(this));
+ } catch (JsonProcessingException e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ @Override
+ public int hashCode() {
+ final int prime = 31;
+ int result = 1;
+ result = prime * result + ordinal;
+ result = prime * result + ((path == null) ? 0 : path.hashCode());
+ result = prime * result + ((uuid == null) ? 0 : uuid.hashCode());
+ return result;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (obj == null) {
+ return false;
+ }
+ if (getClass() != obj.getClass()) {
+ return false;
+ }
+ ExampleTicket other = (ExampleTicket) obj;
+ if (ordinal != other.ordinal) {
+ return false;
+ }
+ if (path == null) {
+ if (other.path != null) {
+ return false;
+ }
+ } else if (!path.equals(other.path)) {
+ return false;
+ }
+ if (uuid == null) {
+ if (other.uuid != null) {
+ return false;
+ }
+ } else if (!uuid.equals(other.uuid)) {
+ return false;
+ }
+ return true;
+ }
+
+
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/FlightHolder.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/FlightHolder.java
new file mode 100644
index 000000000..f6295211e
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/FlightHolder.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.example;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.stream.Collectors;
+
+import org.apache.arrow.flight.FlightDescriptor;
+import org.apache.arrow.flight.FlightEndpoint;
+import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.util.Preconditions;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.arrow.vector.util.DictionaryUtility;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Iterables;
+
+/**
+ * A logical collection of streams sharing the same schema.
+ */
+public class FlightHolder implements AutoCloseable {
+
+ private final BufferAllocator allocator;
+ private final FlightDescriptor descriptor;
+ private final Schema schema;
+ private final List<Stream> streams = new CopyOnWriteArrayList<>();
+ private final DictionaryProvider dictionaryProvider;
+
+ /**
+ * Creates a new instance.
+ * @param allocator The allocator to use for allocating buffers to store data.
+ * @param descriptor The descriptor for the streams.
+ * @param schema The schema for the stream.
+ * @param dictionaryProvider The dictionary provider for the stream.
+ */
+ public FlightHolder(BufferAllocator allocator, FlightDescriptor descriptor, Schema schema,
+ DictionaryProvider dictionaryProvider) {
+ Preconditions.checkArgument(!descriptor.isCommand());
+ this.allocator = allocator.newChildAllocator(descriptor.toString(), 0, Long.MAX_VALUE);
+ this.descriptor = descriptor;
+ this.schema = schema;
+ this.dictionaryProvider = dictionaryProvider;
+ }
+
+ /**
+ * Returns the stream based on the ordinal of ExampleTicket.
+ */
+ public Stream getStream(ExampleTicket ticket) {
+ Preconditions.checkArgument(ticket.getOrdinal() < streams.size(), "Unknown stream.");
+ Stream stream = streams.get(ticket.getOrdinal());
+ stream.verify(ticket);
+ return stream;
+ }
+
+ /**
+ * Adds a new streams which clients can populate via the returned object.
+ */
+ public Stream.StreamCreator addStream(Schema schema) {
+ Preconditions.checkArgument(this.schema.equals(schema), "Stream schema inconsistent with existing schema.");
+ return new Stream.StreamCreator(schema, dictionaryProvider, allocator, t -> {
+ synchronized (streams) {
+ streams.add(t);
+ }
+ });
+ }
+
+ /**
+ * List all available streams as being available at <code>l</code>.
+ */
+ public FlightInfo getFlightInfo(final Location l) {
+ final long bytes = allocator.getAllocatedMemory();
+ final long records = streams.stream().collect(Collectors.summingLong(t -> t.getRecordCount()));
+
+ final List<FlightEndpoint> endpoints = new ArrayList<>();
+ int i = 0;
+ for (Stream s : streams) {
+ endpoints.add(
+ new FlightEndpoint(
+ new ExampleTicket(descriptor.getPath(), i, s.getUuid())
+ .toTicket(),
+ l));
+ i++;
+ }
+ return new FlightInfo(messageFormatSchema(), descriptor, endpoints, bytes, records);
+ }
+
+ private Schema messageFormatSchema() {
+ Set<Long> dictionaryIdsUsed = new HashSet<>();
+ List<Field> messageFormatFields = schema.getFields()
+ .stream()
+ .map(f -> DictionaryUtility.toMessageFormat(f, dictionaryProvider, dictionaryIdsUsed))
+ .collect(Collectors.toList());
+ return new Schema(messageFormatFields, schema.getCustomMetadata());
+ }
+
+ @Override
+ public void close() throws Exception {
+ // Close dictionaries
+ final Set<Long> dictionaryIds = new HashSet<>();
+ schema.getFields().forEach(field -> DictionaryUtility.toMessageFormat(field, dictionaryProvider, dictionaryIds));
+
+ final Iterable<AutoCloseable> dictionaries = dictionaryIds.stream()
+ .map(id -> (AutoCloseable) dictionaryProvider.lookup(id).getVector())::iterator;
+
+ AutoCloseables.close(Iterables.concat(streams, ImmutableList.of(allocator), dictionaries));
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java
new file mode 100644
index 000000000..ff796718d
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.example;
+
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+
+import org.apache.arrow.flight.Action;
+import org.apache.arrow.flight.ActionType;
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.Criteria;
+import org.apache.arrow.flight.FlightDescriptor;
+import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.FlightProducer;
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.flight.PutResult;
+import org.apache.arrow.flight.Result;
+import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.flight.example.Stream.StreamCreator;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.VectorUnloader;
+
+/**
+ * A FlightProducer that hosts an in memory store of Arrow buffers. Used for integration testing.
+ */
+public class InMemoryStore implements FlightProducer, AutoCloseable {
+
+ private final ConcurrentMap<FlightDescriptor, FlightHolder> holders = new ConcurrentHashMap<>();
+ private final BufferAllocator allocator;
+ private Location location;
+
+ /**
+ * Constructs a new instance.
+ *
+ * @param allocator The allocator for creating new Arrow buffers.
+ * @param location The location of the storage.
+ */
+ public InMemoryStore(BufferAllocator allocator, Location location) {
+ super();
+ this.allocator = allocator;
+ this.location = location;
+ }
+
+ /**
+ * Update the location after server start.
+ *
+ * <p>Useful for binding to port 0 to get a free port.
+ */
+ public void setLocation(Location location) {
+ this.location = location;
+ }
+
+ @Override
+ public void getStream(CallContext context, Ticket ticket,
+ ServerStreamListener listener) {
+ getStream(ticket).sendTo(allocator, listener);
+ }
+
+ /**
+ * Returns the appropriate stream given the ticket (streams are indexed by path and an ordinal).
+ */
+ public Stream getStream(Ticket t) {
+ ExampleTicket example = ExampleTicket.from(t);
+ FlightDescriptor d = FlightDescriptor.path(example.getPath());
+ FlightHolder h = holders.get(d);
+ if (h == null) {
+ throw new IllegalStateException("Unknown ticket.");
+ }
+
+ return h.getStream(example);
+ }
+
+ @Override
+ public void listFlights(CallContext context, Criteria criteria, StreamListener<FlightInfo> listener) {
+ try {
+ for (FlightHolder h : holders.values()) {
+ listener.onNext(h.getFlightInfo(location));
+ }
+ listener.onCompleted();
+ } catch (Exception ex) {
+ listener.onError(ex);
+ }
+ }
+
+ @Override
+ public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
+ FlightHolder h = holders.get(descriptor);
+ if (h == null) {
+ throw new IllegalStateException("Unknown descriptor.");
+ }
+
+ return h.getFlightInfo(location);
+ }
+
+ @Override
+ public Runnable acceptPut(CallContext context,
+ final FlightStream flightStream, final StreamListener<PutResult> ackStream) {
+ return () -> {
+ StreamCreator creator = null;
+ boolean success = false;
+ try (VectorSchemaRoot root = flightStream.getRoot()) {
+ final FlightHolder h = holders.computeIfAbsent(
+ flightStream.getDescriptor(),
+ t -> new FlightHolder(allocator, t, flightStream.getSchema(), flightStream.getDictionaryProvider()));
+
+ creator = h.addStream(flightStream.getSchema());
+
+ VectorUnloader unloader = new VectorUnloader(root);
+ while (flightStream.next()) {
+ ackStream.onNext(PutResult.metadata(flightStream.getLatestMetadata()));
+ creator.add(unloader.getRecordBatch());
+ }
+ // Closing the stream will release the dictionaries
+ flightStream.takeDictionaryOwnership();
+ creator.complete();
+ success = true;
+ } finally {
+ if (!success) {
+ creator.drop();
+ }
+ }
+
+ };
+
+ }
+
+ @Override
+ public void doAction(CallContext context, Action action,
+ StreamListener<Result> listener) {
+ switch (action.getType()) {
+ case "drop": {
+ // not implemented.
+ listener.onNext(new Result(new byte[0]));
+ listener.onCompleted();
+ break;
+ }
+ default: {
+ listener.onError(CallStatus.UNIMPLEMENTED.toRuntimeException());
+ }
+ }
+ }
+
+ @Override
+ public void listActions(CallContext context,
+ StreamListener<ActionType> listener) {
+ listener.onNext(new ActionType("get", "pull a stream. Action must be done via standard get mechanism"));
+ listener.onNext(new ActionType("put", "push a stream. Action must be done via standard put mechanism"));
+ listener.onNext(new ActionType("drop", "delete a flight. Action body is a JSON encoded path."));
+ listener.onCompleted();
+ }
+
+ @Override
+ public void close() throws Exception {
+ AutoCloseables.close(holders.values());
+ holders.clear();
+ }
+
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/Stream.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/Stream.java
new file mode 100644
index 000000000..0bc35798d
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/Stream.java
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.example;
+
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.UUID;
+import java.util.function.Consumer;
+
+import org.apache.arrow.flight.FlightProducer.ServerStreamListener;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.vector.VectorLoader;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.pojo.Schema;
+
+import com.google.common.base.Throwables;
+import com.google.common.collect.ImmutableList;
+
+/**
+ * A collection of Arrow record batches.
+ */
+public class Stream implements AutoCloseable, Iterable<ArrowRecordBatch> {
+
+ private final String uuid = UUID.randomUUID().toString();
+ private final DictionaryProvider dictionaryProvider;
+ private final List<ArrowRecordBatch> batches;
+ private final Schema schema;
+ private final long recordCount;
+
+ /**
+ * Create a new instance.
+ *
+ * @param schema The schema for the record batches.
+ * @param batches The data associated with the stream.
+ * @param recordCount The total record count across all batches.
+ */
+ public Stream(
+ final Schema schema,
+ final DictionaryProvider dictionaryProvider,
+ List<ArrowRecordBatch> batches,
+ long recordCount) {
+ this.schema = schema;
+ this.dictionaryProvider = dictionaryProvider;
+ this.batches = ImmutableList.copyOf(batches);
+ this.recordCount = recordCount;
+ }
+
+ public Schema getSchema() {
+ return schema;
+ }
+
+ @Override
+ public Iterator<ArrowRecordBatch> iterator() {
+ return batches.iterator();
+ }
+
+ public long getRecordCount() {
+ return recordCount;
+ }
+
+ public String getUuid() {
+ return uuid;
+ }
+
+ /**
+ * Sends that data from this object to the given listener.
+ */
+ public void sendTo(BufferAllocator allocator, ServerStreamListener listener) {
+ try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ listener.start(root, dictionaryProvider);
+ final VectorLoader loader = new VectorLoader(root);
+ int counter = 0;
+ for (ArrowRecordBatch batch : batches) {
+ final byte[] rawMetadata = Integer.toString(counter).getBytes(StandardCharsets.UTF_8);
+ final ArrowBuf metadata = allocator.buffer(rawMetadata.length);
+ metadata.writeBytes(rawMetadata);
+ loader.load(batch);
+ // Transfers ownership of the buffer - do not free buffer ourselves
+ listener.putNext(metadata);
+ counter++;
+ }
+ listener.completed();
+ } catch (Exception ex) {
+ listener.error(ex);
+ }
+ }
+
+ /**
+ * Throws an IllegalStateException if the given ticket doesn't correspond to this stream.
+ */
+ public void verify(ExampleTicket ticket) {
+ if (!uuid.equals(ticket.getUuid())) {
+ throw new IllegalStateException("Ticket doesn't match.");
+ }
+ }
+
+ @Override
+ public void close() throws Exception {
+ AutoCloseables.close(batches);
+ }
+
+ /**
+ * Provides the functionality to create a new stream by adding batches serially.
+ */
+ public static class StreamCreator {
+
+ private final Schema schema;
+ private final BufferAllocator allocator;
+ private final List<ArrowRecordBatch> batches = new ArrayList<>();
+ private final Consumer<Stream> committer;
+ private long recordCount = 0;
+ private DictionaryProvider dictionaryProvider;
+
+ /**
+ * Creates a new instance.
+ *
+ * @param schema The schema for batches in the stream.
+ * @param dictionaryProvider The dictionary provider for the stream.
+ * @param allocator The allocator used to copy data permanently into the stream.
+ * @param committer A callback for when the stream is ready to be finalized (no more batches).
+ */
+ public StreamCreator(Schema schema, DictionaryProvider dictionaryProvider,
+ BufferAllocator allocator, Consumer<Stream> committer) {
+ this.allocator = allocator;
+ this.committer = committer;
+ this.schema = schema;
+ this.dictionaryProvider = dictionaryProvider;
+ }
+
+ /**
+ * Abandon creation of the stream.
+ */
+ public void drop() {
+ try {
+ AutoCloseables.close(batches);
+ } catch (Exception ex) {
+ throw Throwables.propagate(ex);
+ }
+ }
+
+ public void add(ArrowRecordBatch batch) {
+ batches.add(batch.cloneWithTransfer(allocator));
+ recordCount += batch.getLength();
+ }
+
+ /**
+ * Complete building the stream (no more batches can be added).
+ */
+ public void complete() {
+ Stream stream = new Stream(schema, dictionaryProvider, batches, recordCount);
+ committer.accept(stream);
+ }
+
+ }
+
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java
new file mode 100644
index 000000000..3955d7d21
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.example.integration;
+
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.Optional;
+
+import org.apache.arrow.flight.Action;
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightProducer;
+import org.apache.arrow.flight.FlightRuntimeException;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.FlightStatusCode;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.flight.NoOpFlightProducer;
+import org.apache.arrow.flight.Result;
+import org.apache.arrow.flight.auth.BasicClientAuthHandler;
+import org.apache.arrow.flight.auth.BasicServerAuthHandler;
+import org.apache.arrow.memory.BufferAllocator;
+
+/**
+ * A scenario testing the built-in basic authentication Protobuf.
+ */
+final class AuthBasicProtoScenario implements Scenario {
+
+ static final String USERNAME = "arrow";
+ static final String PASSWORD = "flight";
+
+ @Override
+ public FlightProducer producer(BufferAllocator allocator, Location location) {
+ return new NoOpFlightProducer() {
+ @Override
+ public void doAction(CallContext context, Action action, StreamListener<Result> listener) {
+ listener.onNext(new Result(context.peerIdentity().getBytes(StandardCharsets.UTF_8)));
+ listener.onCompleted();
+ }
+ };
+ }
+
+ @Override
+ public void buildServer(FlightServer.Builder builder) {
+ builder.authHandler(new BasicServerAuthHandler(new BasicServerAuthHandler.BasicAuthValidator() {
+ @Override
+ public byte[] getToken(String username, String password) throws Exception {
+ if (!USERNAME.equals(username) || !PASSWORD.equals(password)) {
+ throw CallStatus.UNAUTHENTICATED.withDescription("Username or password is invalid.").toRuntimeException();
+ }
+ return ("valid:" + username).getBytes(StandardCharsets.UTF_8);
+ }
+
+ @Override
+ public Optional<String> isValid(byte[] token) {
+ if (token != null) {
+ final String credential = new String(token, StandardCharsets.UTF_8);
+ if (credential.startsWith("valid:")) {
+ return Optional.of(credential.substring(6));
+ }
+ }
+ return Optional.empty();
+ }
+ }));
+ }
+
+ @Override
+ public void client(BufferAllocator allocator, Location location, FlightClient client) {
+ final FlightRuntimeException e = IntegrationAssertions.assertThrows(FlightRuntimeException.class, () -> {
+ client.listActions().forEach(act -> {
+ });
+ });
+ if (!FlightStatusCode.UNAUTHENTICATED.equals(e.status().code())) {
+ throw new AssertionError("Expected UNAUTHENTICATED but found " + e.status().code(), e);
+ }
+
+ client.authenticate(new BasicClientAuthHandler(USERNAME, PASSWORD));
+ final Result result = client.doAction(new Action("")).next();
+ if (!USERNAME.equals(new String(result.getBody(), StandardCharsets.UTF_8))) {
+ throw new AssertionError("Expected " + USERNAME + " but got " + Arrays.toString(result.getBody()));
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java
new file mode 100644
index 000000000..576d1887f
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.example.integration;
+
+import java.util.Objects;
+
+/**
+ * Utility methods to implement integration tests without using JUnit assertions.
+ */
+final class IntegrationAssertions {
+
+ /**
+ * Assert that the given code throws the given exception or subclass thereof.
+ *
+ * @param clazz The exception type.
+ * @param body The code to run.
+ * @param <T> The exception type.
+ * @return The thrown exception.
+ */
+ @SuppressWarnings("unchecked")
+ static <T extends Throwable> T assertThrows(Class<T> clazz, AssertThrows body) {
+ try {
+ body.run();
+ } catch (Throwable t) {
+ if (clazz.isInstance(t)) {
+ return (T) t;
+ }
+ throw new AssertionError("Expected exception of class " + clazz + " but got " + t.getClass(), t);
+ }
+ throw new AssertionError("Expected exception of class " + clazz + " but did not throw.");
+ }
+
+ /**
+ * Assert that the two (non-array) objects are equal.
+ */
+ static void assertEquals(Object expected, Object actual) {
+ if (!Objects.equals(expected, actual)) {
+ throw new AssertionError("Expected:\n" + expected + "\nbut got:\n" + actual);
+ }
+ }
+
+ /**
+ * Assert that the value is false, using the given message as an error otherwise.
+ */
+ static void assertFalse(String message, boolean value) {
+ if (value) {
+ throw new AssertionError("Expected false: " + message);
+ }
+ }
+
+ /**
+ * An interface used with {@link #assertThrows(Class, AssertThrows)}.
+ */
+ @FunctionalInterface
+ interface AssertThrows {
+
+ void run() throws Throwable;
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java
new file mode 100644
index 000000000..27a545f84
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.example.integration;
+
+import static org.apache.arrow.memory.util.LargeMemoryUtil.checkedCastToInt;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.List;
+
+import org.apache.arrow.flight.AsyncPutListener;
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightDescriptor;
+import org.apache.arrow.flight.FlightEndpoint;
+import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.flight.PutResult;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.VectorLoader;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.VectorUnloader;
+import org.apache.arrow.vector.ipc.JsonFileReader;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.arrow.vector.util.Validator;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.CommandLineParser;
+import org.apache.commons.cli.DefaultParser;
+import org.apache.commons.cli.Options;
+import org.apache.commons.cli.ParseException;
+
+/**
+ * A Flight client for integration testing.
+ */
+class IntegrationTestClient {
+ private static final org.slf4j.Logger LOGGER = org.slf4j.LoggerFactory.getLogger(IntegrationTestClient.class);
+ private final Options options;
+
+ private IntegrationTestClient() {
+ options = new Options();
+ options.addOption("j", "json", true, "json file");
+ options.addOption("scenario", true, "The integration test scenario.");
+ options.addOption("host", true, "The host to connect to.");
+ options.addOption("port", true, "The port to connect to.");
+ }
+
+ public static void main(String[] args) {
+ try {
+ new IntegrationTestClient().run(args);
+ } catch (ParseException e) {
+ fatalError("Invalid parameters", e);
+ } catch (IOException e) {
+ fatalError("Error accessing files", e);
+ } catch (Exception e) {
+ fatalError("Unknown error", e);
+ }
+ }
+
+ private static void fatalError(String message, Throwable e) {
+ System.err.println(message);
+ System.err.println(e.getMessage());
+ LOGGER.error(message, e);
+ System.exit(1);
+ }
+
+ private void run(String[] args) throws Exception {
+ final CommandLineParser parser = new DefaultParser();
+ final CommandLine cmd = parser.parse(options, args, false);
+
+ final String host = cmd.getOptionValue("host", "localhost");
+ final int port = Integer.parseInt(cmd.getOptionValue("port", "31337"));
+
+ final Location defaultLocation = Location.forGrpcInsecure(host, port);
+ try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+ final FlightClient client = FlightClient.builder(allocator, defaultLocation).build()) {
+
+ if (cmd.hasOption("scenario")) {
+ Scenarios.getScenario(cmd.getOptionValue("scenario")).client(allocator, defaultLocation, client);
+ } else {
+ final String inputPath = cmd.getOptionValue("j");
+ testStream(allocator, defaultLocation, client, inputPath);
+ }
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private static void testStream(BufferAllocator allocator, Location server, FlightClient client, String inputPath)
+ throws IOException {
+ // 1. Read data from JSON and upload to server.
+ FlightDescriptor descriptor = FlightDescriptor.path(inputPath);
+ try (JsonFileReader reader = new JsonFileReader(new File(inputPath), allocator);
+ VectorSchemaRoot root = VectorSchemaRoot.create(reader.start(), allocator)) {
+ FlightClient.ClientStreamListener stream = client.startPut(descriptor, root, reader,
+ new AsyncPutListener() {
+ int counter = 0;
+
+ @Override
+ public void onNext(PutResult val) {
+ final byte[] metadataRaw = new byte[checkedCastToInt(val.getApplicationMetadata().readableBytes())];
+ val.getApplicationMetadata().readBytes(metadataRaw);
+ final String metadata = new String(metadataRaw, StandardCharsets.UTF_8);
+ if (!Integer.toString(counter).equals(metadata)) {
+ throw new RuntimeException(
+ String.format("Invalid ACK from server. Expected '%d' but got '%s'.", counter, metadata));
+ }
+ counter++;
+ }
+ });
+ int counter = 0;
+ while (reader.read(root)) {
+ final byte[] rawMetadata = Integer.toString(counter).getBytes(StandardCharsets.UTF_8);
+ final ArrowBuf metadata = allocator.buffer(rawMetadata.length);
+ metadata.writeBytes(rawMetadata);
+ // Transfers ownership of the buffer, so do not release it ourselves
+ stream.putNext(metadata);
+ root.clear();
+ counter++;
+ }
+ stream.completed();
+ // Need to call this, or exceptions from the server get swallowed
+ stream.getResult();
+ }
+
+ // 2. Get the ticket for the data.
+ FlightInfo info = client.getInfo(descriptor);
+ List<FlightEndpoint> endpoints = info.getEndpoints();
+ if (endpoints.isEmpty()) {
+ throw new RuntimeException("No endpoints returned from Flight server.");
+ }
+
+ for (FlightEndpoint endpoint : info.getEndpoints()) {
+ // 3. Download the data from the server.
+ List<Location> locations = endpoint.getLocations();
+ if (locations.isEmpty()) {
+ throw new RuntimeException("No locations returned from Flight server.");
+ }
+ for (Location location : locations) {
+ System.out.println("Verifying location " + location.getUri());
+ try (FlightClient readClient = FlightClient.builder(allocator, location).build();
+ FlightStream stream = readClient.getStream(endpoint.getTicket());
+ VectorSchemaRoot root = stream.getRoot();
+ VectorSchemaRoot downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator);
+ JsonFileReader reader = new JsonFileReader(new File(inputPath), allocator)) {
+ VectorLoader loader = new VectorLoader(downloadedRoot);
+ VectorUnloader unloader = new VectorUnloader(root);
+
+ Schema jsonSchema = reader.start();
+ Validator.compareSchemas(root.getSchema(), jsonSchema);
+ try (VectorSchemaRoot jsonRoot = VectorSchemaRoot.create(jsonSchema, allocator)) {
+
+ while (stream.next()) {
+ try (final ArrowRecordBatch arb = unloader.getRecordBatch()) {
+ loader.load(arb);
+ if (reader.read(jsonRoot)) {
+
+ // 4. Validate the data.
+ Validator.compareVectorSchemaRoot(jsonRoot, downloadedRoot);
+ jsonRoot.clear();
+ } else {
+ throw new RuntimeException("Flight stream has more batches than JSON");
+ }
+ }
+ }
+
+ // Verify no more batches with data in JSON
+ // NOTE: Currently the C++ Flight server skips empty batches at end of the stream
+ if (reader.read(jsonRoot) && jsonRoot.getRowCount() > 0) {
+ throw new RuntimeException("JSON has more batches with than Flight stream");
+ }
+ }
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java
new file mode 100644
index 000000000..da336c502
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.example.integration;
+
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.flight.example.InMemoryStore;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.CommandLineParser;
+import org.apache.commons.cli.DefaultParser;
+import org.apache.commons.cli.Options;
+import org.apache.commons.cli.ParseException;
+
+/**
+ * Flight server for integration testing.
+ */
+class IntegrationTestServer {
+ private static final org.slf4j.Logger LOGGER = org.slf4j.LoggerFactory.getLogger(IntegrationTestServer.class);
+ private final Options options;
+
+ private IntegrationTestServer() {
+ options = new Options();
+ options.addOption("port", true, "The port to serve on.");
+ options.addOption("scenario", true, "The integration test scenario.");
+ }
+
+ private void run(String[] args) throws Exception {
+ CommandLineParser parser = new DefaultParser();
+ CommandLine cmd = parser.parse(options, args, false);
+ final int port = Integer.parseInt(cmd.getOptionValue("port", "31337"));
+ final Location location = Location.forGrpcInsecure("localhost", port);
+
+ final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ final FlightServer.Builder builder = FlightServer.builder().allocator(allocator).location(location);
+
+ final FlightServer server;
+ if (cmd.hasOption("scenario")) {
+ final Scenario scenario = Scenarios.getScenario(cmd.getOptionValue("scenario"));
+ scenario.buildServer(builder);
+ server = builder.producer(scenario.producer(allocator, location)).build();
+ server.start();
+ } else {
+ final InMemoryStore store = new InMemoryStore(allocator, location);
+ server = FlightServer.builder(allocator, location, store).build().start();
+ store.setLocation(Location.forGrpcInsecure("localhost", server.getPort()));
+ }
+ // Print out message for integration test script
+ System.out.println("Server listening on localhost:" + server.getPort());
+
+ Runtime.getRuntime().addShutdownHook(new Thread(() -> {
+ try {
+ System.out.println("\nExiting...");
+ AutoCloseables.close(server, allocator);
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }));
+
+ server.awaitTermination();
+ }
+
+ public static void main(String[] args) {
+ try {
+ new IntegrationTestServer().run(args);
+ } catch (ParseException e) {
+ fatalError("Error parsing arguments", e);
+ } catch (Exception e) {
+ fatalError("Runtime error", e);
+ }
+ }
+
+ private static void fatalError(String message, Throwable e) {
+ System.err.println(message);
+ System.err.println(e.getMessage());
+ LOGGER.error(message, e);
+ System.exit(1);
+ }
+
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/MiddlewareScenario.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/MiddlewareScenario.java
new file mode 100644
index 000000000..c710ce98b
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/MiddlewareScenario.java
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.example.integration;
+
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.Collections;
+
+import org.apache.arrow.flight.CallHeaders;
+import org.apache.arrow.flight.CallInfo;
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightClientMiddleware;
+import org.apache.arrow.flight.FlightDescriptor;
+import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.FlightProducer;
+import org.apache.arrow.flight.FlightRuntimeException;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.FlightServerMiddleware;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.flight.NoOpFlightProducer;
+import org.apache.arrow.flight.RequestContext;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.types.pojo.Schema;
+
+/**
+ * Test an edge case in middleware: gRPC-Java consolidates headers and trailers if a call fails immediately. On the
+ * gRPC implementation side, we need to watch for this, or else we'll have a call with "no headers" if we only look
+ * for headers.
+ */
+final class MiddlewareScenario implements Scenario {
+
+ private static final String HEADER = "x-middleware";
+ private static final String EXPECTED_HEADER_VALUE = "expected value";
+ private static final byte[] COMMAND_SUCCESS = "success".getBytes(StandardCharsets.UTF_8);
+
+ @Override
+ public FlightProducer producer(BufferAllocator allocator, Location location) {
+ return new NoOpFlightProducer() {
+ @Override
+ public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
+ if (descriptor.isCommand()) {
+ if (Arrays.equals(COMMAND_SUCCESS, descriptor.getCommand())) {
+ return new FlightInfo(new Schema(Collections.emptyList()), descriptor, Collections.emptyList(), -1, -1);
+ }
+ }
+ throw CallStatus.UNIMPLEMENTED.toRuntimeException();
+ }
+ };
+ }
+
+ @Override
+ public void buildServer(FlightServer.Builder builder) {
+ builder.middleware(FlightServerMiddleware.Key.of("test"), new InjectingServerMiddleware.Factory());
+ }
+
+ @Override
+ public void client(BufferAllocator allocator, Location location, FlightClient ignored) throws Exception {
+ final ExtractingClientMiddleware.Factory factory = new ExtractingClientMiddleware.Factory();
+ try (final FlightClient client = FlightClient.builder(allocator, location).intercept(factory).build()) {
+ // Should fail immediately
+ IntegrationAssertions.assertThrows(FlightRuntimeException.class,
+ () -> client.getInfo(FlightDescriptor.command(new byte[0])));
+ if (!EXPECTED_HEADER_VALUE.equals(factory.extractedHeader)) {
+ throw new AssertionError(
+ "Expected to extract the header value '" +
+ EXPECTED_HEADER_VALUE +
+ "', but found: " +
+ factory.extractedHeader);
+ }
+
+ // Should not fail
+ factory.extractedHeader = "";
+ client.getInfo(FlightDescriptor.command(COMMAND_SUCCESS));
+ if (!EXPECTED_HEADER_VALUE.equals(factory.extractedHeader)) {
+ throw new AssertionError(
+ "Expected to extract the header value '" +
+ EXPECTED_HEADER_VALUE +
+ "', but found: " +
+ factory.extractedHeader);
+ }
+ }
+ }
+
+ /** Middleware that inserts a constant value in outgoing requests. */
+ static class InjectingServerMiddleware implements FlightServerMiddleware {
+
+ private final String headerValue;
+
+ InjectingServerMiddleware(String incoming) {
+ this.headerValue = incoming;
+ }
+
+ @Override
+ public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
+ outgoingHeaders.insert("x-middleware", headerValue);
+ }
+
+ @Override
+ public void onCallCompleted(CallStatus status) {
+ }
+
+ @Override
+ public void onCallErrored(Throwable err) {
+ }
+
+ /** The factory for the server middleware. */
+ static class Factory implements FlightServerMiddleware.Factory<InjectingServerMiddleware> {
+
+ @Override
+ public InjectingServerMiddleware onCallStarted(CallInfo info, CallHeaders incomingHeaders,
+ RequestContext context) {
+ String incoming = incomingHeaders.get(HEADER);
+ return new InjectingServerMiddleware(incoming == null ? "" : incoming);
+ }
+ }
+ }
+
+ /** Middleware that pulls a value out of incoming responses. */
+ static class ExtractingClientMiddleware implements FlightClientMiddleware {
+
+ private final ExtractingClientMiddleware.Factory factory;
+
+ public ExtractingClientMiddleware(ExtractingClientMiddleware.Factory factory) {
+ this.factory = factory;
+ }
+
+ @Override
+ public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
+ outgoingHeaders.insert(HEADER, EXPECTED_HEADER_VALUE);
+ }
+
+ @Override
+ public void onHeadersReceived(CallHeaders incomingHeaders) {
+ this.factory.extractedHeader = incomingHeaders.get(HEADER);
+ }
+
+ @Override
+ public void onCallCompleted(CallStatus status) {
+ }
+
+ /** The factory for the client middleware. */
+ static class Factory implements FlightClientMiddleware.Factory {
+
+ String extractedHeader = null;
+
+ @Override
+ public FlightClientMiddleware onCallStarted(CallInfo info) {
+ return new ExtractingClientMiddleware(this);
+ }
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java
new file mode 100644
index 000000000..b3b962d2e
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.example.integration;
+
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightProducer;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.memory.BufferAllocator;
+
+/**
+ * A particular scenario in integration testing.
+ */
+interface Scenario {
+
+ /**
+ * Construct the FlightProducer for a server in this scenario.
+ */
+ FlightProducer producer(BufferAllocator allocator, Location location) throws Exception;
+
+ /**
+ * Set any other server options.
+ */
+ void buildServer(FlightServer.Builder builder) throws Exception;
+
+ /**
+ * Run as the client in the scenario.
+ */
+ void client(BufferAllocator allocator, Location location, FlightClient client) throws Exception;
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java
new file mode 100644
index 000000000..cd9859b4f
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.example.integration;
+
+import java.util.Map;
+import java.util.TreeMap;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Supplier;
+
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+
+/**
+ * Scenarios for integration testing.
+ */
+final class Scenarios {
+
+ private static Scenarios INSTANCE;
+
+ private final Map<String, Supplier<Scenario>> scenarios;
+
+ private Scenarios() {
+ scenarios = new TreeMap<>();
+ scenarios.put("auth:basic_proto", AuthBasicProtoScenario::new);
+ scenarios.put("middleware", MiddlewareScenario::new);
+ }
+
+ private static Scenarios getInstance() {
+ if (INSTANCE == null) {
+ INSTANCE = new Scenarios();
+ }
+ return INSTANCE;
+ }
+
+ static Scenario getScenario(String scenario) {
+ final Supplier<Scenario> ctor = getInstance().scenarios.get(scenario);
+ if (ctor == null) {
+ throw new IllegalArgumentException("Unknown integration test scenario: " + scenario);
+ }
+ return ctor.get();
+ }
+
+ // Utility methods for implementing tests.
+
+ public static void main(String[] args) {
+ // Run scenarios one after the other
+ final Location location = Location.forGrpcInsecure("localhost", 31337);
+ for (final Map.Entry<String, Supplier<Scenario>> entry : getInstance().scenarios.entrySet()) {
+ System.out.println("Running test scenario: " + entry.getKey());
+ final Scenario scenario = entry.getValue().get();
+ try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
+ final FlightServer.Builder builder = FlightServer
+ .builder(allocator, location, scenario.producer(allocator, location));
+ scenario.buildServer(builder);
+ try (final FlightServer server = builder.build()) {
+ server.start();
+
+ try (final FlightClient client = FlightClient.builder(allocator, location).build()) {
+ scenario.client(allocator, location, client);
+ }
+
+ server.shutdown();
+ server.awaitTermination(1, TimeUnit.SECONDS);
+ System.out.println("Ran scenario " + entry.getKey());
+ }
+ } catch (Exception e) {
+ System.out.println("Exception while running scenario " + entry.getKey());
+ e.printStackTrace();
+ }
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/AddWritableBuffer.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/AddWritableBuffer.java
new file mode 100644
index 000000000..26e0274fa
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/AddWritableBuffer.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.grpc;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.Field;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Allow a user to add a ByteBuf based InputStream directly into GRPC WritableBuffer to avoid an
+ * extra copy. This could be solved in GRPC by adding a ByteBufListable interface on InputStream and
+ * letting BufferChainOutputStream take advantage of it.
+ */
+public class AddWritableBuffer {
+
+ private static final Constructor<?> bufConstruct;
+ private static final Field bufferList;
+ private static final Field current;
+ private static final Method listAdd;
+ private static final Class<?> bufChainOut;
+
+ static {
+
+ Constructor<?> tmpConstruct = null;
+ Field tmpBufferList = null;
+ Field tmpCurrent = null;
+ Class<?> tmpBufChainOut = null;
+ Method tmpListAdd = null;
+
+ try {
+ Class<?> nwb = Class.forName("io.grpc.netty.NettyWritableBuffer");
+
+ Constructor<?> tmpConstruct2 = nwb.getDeclaredConstructor(ByteBuf.class);
+ tmpConstruct2.setAccessible(true);
+
+ Class<?> tmpBufChainOut2 = Class.forName("io.grpc.internal.MessageFramer$BufferChainOutputStream");
+
+ Field tmpBufferList2 = tmpBufChainOut2.getDeclaredField("bufferList");
+ tmpBufferList2.setAccessible(true);
+
+ Field tmpCurrent2 = tmpBufChainOut2.getDeclaredField("current");
+ tmpCurrent2.setAccessible(true);
+
+ Method tmpListAdd2 = List.class.getDeclaredMethod("add", Object.class);
+
+ // output fields last.
+ tmpConstruct = tmpConstruct2;
+ tmpBufferList = tmpBufferList2;
+ tmpCurrent = tmpCurrent2;
+ tmpListAdd = tmpListAdd2;
+ tmpBufChainOut = tmpBufChainOut2;
+
+ } catch (Exception ex) {
+ ex.printStackTrace();
+ }
+
+ bufConstruct = tmpConstruct;
+ bufferList = tmpBufferList;
+ current = tmpCurrent;
+ listAdd = tmpListAdd;
+ bufChainOut = tmpBufChainOut;
+
+ }
+
+ /**
+ * Add the provided ByteBuf to the gRPC BufferChainOutputStream if possible, else copy the buffer to the stream.
+ * @param buf The buffer to add.
+ * @param stream The Candidate OutputStream to add to.
+ * @param tryZeroCopy If true, try to zero-copy append the buffer to the stream. This may not succeed.
+ * @return True if buffer was zero-copy added to the stream. False if the buffer was copied.
+ * @throws IOException if the fast path is not enabled and there was an error copying the buffer to the stream.
+ */
+ public static boolean add(ByteBuf buf, OutputStream stream, boolean tryZeroCopy) throws IOException {
+ if (!tryZeroCopy || !tryAddBuffer(buf, stream)) {
+ buf.getBytes(0, stream, buf.readableBytes());
+ return false;
+ }
+ return true;
+ }
+
+ private static boolean tryAddBuffer(ByteBuf buf, OutputStream stream) throws IOException {
+
+ if (bufChainOut == null) {
+ return false;
+ }
+
+ if (!stream.getClass().equals(bufChainOut)) {
+ return false;
+ }
+
+ try {
+ if (current.get(stream) != null) {
+ return false;
+ }
+
+ buf.retain();
+ Object obj = bufConstruct.newInstance(buf);
+ Object list = bufferList.get(stream);
+ listAdd.invoke(list, obj);
+ current.set(stream, obj);
+ return true;
+ } catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException | InstantiationException e) {
+ e.printStackTrace();
+ return false;
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/CallCredentialAdapter.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/CallCredentialAdapter.java
new file mode 100644
index 000000000..285ddb9ba
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/CallCredentialAdapter.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.grpc;
+
+import java.util.concurrent.Executor;
+import java.util.function.Consumer;
+
+import org.apache.arrow.flight.CallHeaders;
+
+import io.grpc.CallCredentials;
+import io.grpc.Metadata;
+
+/**
+ * Adapter class to utilize a CredentialWriter to implement Grpc CallCredentials.
+ */
+public class CallCredentialAdapter extends CallCredentials {
+
+ private final Consumer<CallHeaders> credentialWriter;
+
+ public CallCredentialAdapter(Consumer<CallHeaders> credentialWriter) {
+ this.credentialWriter = credentialWriter;
+ }
+
+ @Override
+ public void applyRequestMetadata(RequestInfo requestInfo, Executor executor, MetadataApplier metadataApplier) {
+ executor.execute(() ->
+ {
+ final Metadata headers = new Metadata();
+ credentialWriter.accept(new MetadataAdapter(headers));
+ metadataApplier.apply(headers);
+ });
+ }
+
+ @Override
+ public void thisUsesUnstableApi() {
+ // Mandatory to override this to acknowledge that CallCredentials is Experimental.
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ClientInterceptorAdapter.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ClientInterceptorAdapter.java
new file mode 100644
index 000000000..ae11e5260
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ClientInterceptorAdapter.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.grpc;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.arrow.flight.CallInfo;
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.FlightClientMiddleware;
+import org.apache.arrow.flight.FlightClientMiddleware.Factory;
+import org.apache.arrow.flight.FlightMethod;
+import org.apache.arrow.flight.FlightRuntimeException;
+import org.apache.arrow.flight.FlightStatusCode;
+
+import io.grpc.CallOptions;
+import io.grpc.Channel;
+import io.grpc.ClientCall;
+import io.grpc.ClientInterceptor;
+import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
+import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
+import io.grpc.Metadata;
+import io.grpc.MethodDescriptor;
+import io.grpc.Status;
+import io.grpc.StatusRuntimeException;
+
+/**
+ * An adapter between Flight client middleware and gRPC interceptors.
+ *
+ * <p>This is implemented as a single gRPC interceptor that runs all Flight client middleware sequentially.
+ */
+public class ClientInterceptorAdapter implements ClientInterceptor {
+
+ private final List<Factory> factories;
+
+ public ClientInterceptorAdapter(List<Factory> factories) {
+ this.factories = factories;
+ }
+
+ @Override
+ public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method,
+ CallOptions callOptions, Channel next) {
+ final List<FlightClientMiddleware> middleware = new ArrayList<>();
+ final CallInfo info = new CallInfo(FlightMethod.fromProtocol(method.getFullMethodName()));
+
+ try {
+ for (final Factory factory : factories) {
+ middleware.add(factory.onCallStarted(info));
+ }
+ } catch (FlightRuntimeException e) {
+ // Explicitly propagate
+ throw e;
+ } catch (StatusRuntimeException e) {
+ throw StatusUtils.fromGrpcRuntimeException(e);
+ } catch (RuntimeException e) {
+ throw StatusUtils.fromThrowable(e);
+ }
+ return new FlightClientCall<>(next.newCall(method, callOptions), middleware);
+ }
+
+ /**
+ * The ClientCallListener which hooks into the gRPC request cycle and actually runs middleware at certain points.
+ */
+ private static class FlightClientCallListener<RespT> extends SimpleForwardingClientCallListener<RespT> {
+
+ private final List<FlightClientMiddleware> middleware;
+ boolean receivedHeaders;
+
+ public FlightClientCallListener(ClientCall.Listener<RespT> responseListener,
+ List<FlightClientMiddleware> middleware) {
+ super(responseListener);
+ this.middleware = middleware;
+ receivedHeaders = false;
+ }
+
+ @Override
+ public void onHeaders(Metadata headers) {
+ receivedHeaders = true;
+ final MetadataAdapter adapter = new MetadataAdapter(headers);
+ try {
+ middleware.forEach(m -> m.onHeadersReceived(adapter));
+ } finally {
+ // Make sure to always call the gRPC callback to avoid interrupting the gRPC request cycle
+ super.onHeaders(headers);
+ }
+ }
+
+ @Override
+ public void onClose(Status status, Metadata trailers) {
+ try {
+ if (!receivedHeaders) {
+ // gRPC doesn't always send response headers if the call errors or completes immediately, but instead
+ // consolidates them with the trailers. If we never got headers, assume this happened and run the header
+ // callback with the trailers.
+ final MetadataAdapter adapter = new MetadataAdapter(trailers);
+ middleware.forEach(m -> m.onHeadersReceived(adapter));
+ }
+ final CallStatus flightStatus = StatusUtils.fromGrpcStatusAndTrailers(status, trailers);
+ middleware.forEach(m -> m.onCallCompleted(flightStatus));
+ } finally {
+ // Make sure to always call the gRPC callback to avoid interrupting the gRPC request cycle
+ super.onClose(status, trailers);
+ }
+ }
+ }
+
+ /**
+ * The gRPC ClientCall which hooks into the gRPC request cycle and injects our ClientCallListener.
+ */
+ private static class FlightClientCall<ReqT, RespT> extends SimpleForwardingClientCall<ReqT, RespT> {
+
+ private final List<FlightClientMiddleware> middleware;
+
+ public FlightClientCall(ClientCall<ReqT, RespT> clientCall, List<FlightClientMiddleware> middleware) {
+ super(clientCall);
+ this.middleware = middleware;
+ }
+
+ @Override
+ public void start(Listener<RespT> responseListener, Metadata headers) {
+ final MetadataAdapter metadataAdapter = new MetadataAdapter(headers);
+ middleware.forEach(m -> m.onBeforeSendingHeaders(metadataAdapter));
+
+ super.start(new FlightClientCallListener<>(responseListener, middleware), headers);
+ }
+
+ @Override
+ public void cancel(String message, Throwable cause) {
+ final CallStatus flightStatus = new CallStatus(FlightStatusCode.CANCELLED, cause, message, null);
+ middleware.forEach(m -> m.onCallCompleted(flightStatus));
+ super.cancel(message, cause);
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ContextPropagatingExecutorService.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ContextPropagatingExecutorService.java
new file mode 100644
index 000000000..8f6bb6db2
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ContextPropagatingExecutorService.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.grpc;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.stream.Collectors;
+
+import io.grpc.Context;
+
+/**
+ * An {@link ExecutorService} that propagates the {@link Context}.
+ *
+ * <p>Context is used to propagate per-call state, like the authenticated user, between threads (as gRPC makes no
+ * guarantees about what thread things execute on). This wrapper makes it easy to preserve this when using an Executor.
+ * The Context itself is immutable, so it is thread-safe.
+ */
+public class ContextPropagatingExecutorService implements ExecutorService {
+
+ private final ExecutorService delegate;
+
+ public ContextPropagatingExecutorService(ExecutorService delegate) {
+ this.delegate = delegate;
+ }
+
+ // These are just delegate methods.
+
+ @Override
+ public void shutdown() {
+ delegate.shutdown();
+ }
+
+ @Override
+ public List<Runnable> shutdownNow() {
+ return delegate.shutdownNow();
+ }
+
+ @Override
+ public boolean isShutdown() {
+ return delegate.isShutdown();
+ }
+
+ @Override
+ public boolean isTerminated() {
+ return delegate.isTerminated();
+ }
+
+ @Override
+ public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException {
+ return delegate.awaitTermination(timeout, unit);
+ }
+
+ // These are delegate methods that wrap the submitted task in the current gRPC Context.
+
+ @Override
+ public <T> Future<T> submit(Callable<T> task) {
+ return delegate.submit(Context.current().wrap(task));
+ }
+
+ @Override
+ public <T> Future<T> submit(Runnable task, T result) {
+ return delegate.submit(Context.current().wrap(task), result);
+ }
+
+ @Override
+ public Future<?> submit(Runnable task) {
+ return delegate.submit(Context.current().wrap(task));
+ }
+
+ @Override
+ public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks) throws InterruptedException {
+ return delegate.invokeAll(tasks.stream().map(Context.current()::wrap).collect(Collectors.toList()));
+ }
+
+ @Override
+ public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks, long timeout,
+ TimeUnit unit) throws InterruptedException {
+ return delegate.invokeAll(tasks.stream().map(Context.current()::wrap).collect(Collectors.toList()), timeout, unit);
+ }
+
+ @Override
+ public <T> T invokeAny(Collection<? extends Callable<T>> tasks) throws InterruptedException, ExecutionException {
+ return delegate.invokeAny(tasks.stream().map(Context.current()::wrap).collect(Collectors.toList()));
+ }
+
+ @Override
+ public <T> T invokeAny(Collection<? extends Callable<T>> tasks, long timeout, TimeUnit unit)
+ throws InterruptedException, ExecutionException, TimeoutException {
+ return delegate.invokeAny(tasks.stream().map(Context.current()::wrap).collect(Collectors.toList()), timeout, unit);
+ }
+
+ @Override
+ public void execute(Runnable command) {
+ delegate.execute(Context.current().wrap(command));
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/CredentialCallOption.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/CredentialCallOption.java
new file mode 100644
index 000000000..3bde7a835
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/CredentialCallOption.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.grpc;
+
+import java.util.function.Consumer;
+
+import org.apache.arrow.flight.CallHeaders;
+import org.apache.arrow.flight.CallOptions;
+
+import io.grpc.stub.AbstractStub;
+
+/**
+ * Method option for supplying credentials to method calls.
+ */
+public class CredentialCallOption implements CallOptions.GrpcCallOption {
+ private final Consumer<CallHeaders> credentialWriter;
+
+ public CredentialCallOption(Consumer<CallHeaders> credentialWriter) {
+ this.credentialWriter = credentialWriter;
+ }
+
+ @Override
+ public <T extends AbstractStub<T>> T wrapStub(T stub) {
+ return stub.withCallCredentials(new CallCredentialAdapter(credentialWriter));
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java
new file mode 100644
index 000000000..5f8a71576
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.grpc;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.lang.reflect.Field;
+
+import org.apache.arrow.memory.ArrowBuf;
+
+import com.google.common.base.Throwables;
+import com.google.common.io.ByteStreams;
+
+import io.grpc.internal.ReadableBuffer;
+
+/**
+ * Enable access to ReadableBuffer directly to copy data from a BufferInputStream into a target
+ * ByteBuffer/ByteBuf.
+ *
+ * <p>This could be solved by BufferInputStream exposing Drainable.
+ */
+public class GetReadableBuffer {
+
+ private static final Field READABLE_BUFFER;
+ private static final Class<?> BUFFER_INPUT_STREAM;
+
+ static {
+ Field tmpField = null;
+ Class<?> tmpClazz = null;
+ try {
+ Class<?> clazz = Class.forName("io.grpc.internal.ReadableBuffers$BufferInputStream");
+
+ Field f = clazz.getDeclaredField("buffer");
+ f.setAccessible(true);
+ // don't set until we've gotten past all exception cases.
+ tmpField = f;
+ tmpClazz = clazz;
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ READABLE_BUFFER = tmpField;
+ BUFFER_INPUT_STREAM = tmpClazz;
+ }
+
+ /**
+ * Extracts the ReadableBuffer for the given input stream.
+ *
+ * @param is Must be an instance of io.grpc.internal.ReadableBuffers$BufferInputStream or
+ * null will be returned.
+ */
+ public static ReadableBuffer getReadableBuffer(InputStream is) {
+
+ if (BUFFER_INPUT_STREAM == null || !is.getClass().equals(BUFFER_INPUT_STREAM)) {
+ return null;
+ }
+
+ try {
+ return (ReadableBuffer) READABLE_BUFFER.get(is);
+ } catch (Exception ex) {
+ throw Throwables.propagate(ex);
+ }
+ }
+
+ /**
+ * Helper method to read a gRPC-provided InputStream into an ArrowBuf.
+ * @param stream The stream to read from. Should be an instance of {@link #BUFFER_INPUT_STREAM}.
+ * @param buf The buffer to read into.
+ * @param size The number of bytes to read.
+ * @param fastPath Whether to enable the fast path (i.e. detect whether the stream is a {@link #BUFFER_INPUT_STREAM}).
+ * @throws IOException if there is an error reading form the stream
+ */
+ public static void readIntoBuffer(final InputStream stream, final ArrowBuf buf, final int size,
+ final boolean fastPath) throws IOException {
+ ReadableBuffer readableBuffer = fastPath ? getReadableBuffer(stream) : null;
+ if (readableBuffer != null) {
+ readableBuffer.readBytes(buf.nioBuffer(0, size));
+ } else {
+ byte[] heapBytes = new byte[size];
+ ByteStreams.readFully(stream, heapBytes);
+ buf.writeBytes(heapBytes);
+ }
+ buf.writerIndex(size);
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/MetadataAdapter.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/MetadataAdapter.java
new file mode 100644
index 000000000..4327f0ca8
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/MetadataAdapter.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.grpc;
+
+import java.util.HashSet;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.StreamSupport;
+
+import org.apache.arrow.flight.CallHeaders;
+
+import io.grpc.Metadata;
+import io.grpc.Metadata.Key;
+
+/**
+ * A mutable adapter between the gRPC Metadata object and the Flight headers interface.
+ *
+ * <p>This allows us to present the headers (metadata) from gRPC without copying to/from our own object.
+ */
+public class MetadataAdapter implements CallHeaders {
+
+ private final Metadata metadata;
+
+ public MetadataAdapter(Metadata metadata) {
+ this.metadata = metadata;
+ }
+
+ @Override
+ public String get(String key) {
+ return this.metadata.get(Key.of(key, Metadata.ASCII_STRING_MARSHALLER));
+ }
+
+ @Override
+ public byte[] getByte(String key) {
+ if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
+ return this.metadata.get(Key.of(key, Metadata.BINARY_BYTE_MARSHALLER));
+ }
+ return get(key).getBytes();
+ }
+
+ @Override
+ public Iterable<String> getAll(String key) {
+ return this.metadata.getAll(Key.of(key, Metadata.ASCII_STRING_MARSHALLER));
+ }
+
+ @Override
+ public Iterable<byte[]> getAllByte(String key) {
+ if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
+ return this.metadata.getAll(Key.of(key, Metadata.BINARY_BYTE_MARSHALLER));
+ }
+ return StreamSupport.stream(getAll(key).spliterator(), false)
+ .map(String::getBytes).collect(Collectors.toList());
+ }
+
+ @Override
+ public void insert(String key, String value) {
+ this.metadata.put(Key.of(key, Metadata.ASCII_STRING_MARSHALLER), value);
+ }
+
+ @Override
+ public void insert(String key, byte[] value) {
+ this.metadata.put(Key.of(key, Metadata.BINARY_BYTE_MARSHALLER), value);
+ }
+
+ @Override
+ public Set<String> keys() {
+ return new HashSet<>(this.metadata.keys());
+ }
+
+ @Override
+ public boolean containsKey(String key) {
+ if (key.endsWith("-bin")) {
+ final Key<?> grpcKey = Key.of(key, Metadata.BINARY_BYTE_MARSHALLER);
+ return this.metadata.containsKey(grpcKey);
+ }
+ final Key<?> grpcKey = Key.of(key, Metadata.ASCII_STRING_MARSHALLER);
+ return this.metadata.containsKey(grpcKey);
+ }
+
+ public String toString() {
+ return this.metadata.toString();
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/RequestContextAdapter.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/RequestContextAdapter.java
new file mode 100644
index 000000000..9be4d12b9
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/RequestContextAdapter.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.grpc;
+
+import java.util.HashMap;
+import java.util.Set;
+
+import org.apache.arrow.flight.RequestContext;
+
+import io.grpc.Context;
+
+
+/**
+ * Adapter for holding key value pairs.
+ */
+public class RequestContextAdapter implements RequestContext {
+ public static final Context.Key<RequestContext> REQUEST_CONTEXT_KEY =
+ Context.key("arrow-flight-request-context");
+ private final HashMap<String, String> map = new HashMap<>();
+
+ @Override
+ public void put(String key, String value) {
+ if (map.putIfAbsent(key, value) != null) {
+ throw new IllegalArgumentException("Duplicate write to a RequestContext at key " + key + " not allowed.");
+ }
+ }
+
+ @Override
+ public String get(String key) {
+ return map.get(key);
+ }
+
+ @Override
+ public Set<String> keySet() {
+ return map.keySet();
+ }
+
+ @Override
+ public String remove(String key) {
+ return map.remove(key);
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ServerInterceptorAdapter.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ServerInterceptorAdapter.java
new file mode 100644
index 000000000..ddf43ff84
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ServerInterceptorAdapter.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.grpc;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.arrow.flight.CallInfo;
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.FlightMethod;
+import org.apache.arrow.flight.FlightProducer.CallContext;
+import org.apache.arrow.flight.FlightRuntimeException;
+import org.apache.arrow.flight.FlightServerMiddleware;
+import org.apache.arrow.flight.FlightServerMiddleware.Factory;
+import org.apache.arrow.flight.FlightServerMiddleware.Key;
+
+import io.grpc.Context;
+import io.grpc.Contexts;
+import io.grpc.ForwardingServerCall.SimpleForwardingServerCall;
+import io.grpc.Metadata;
+import io.grpc.ServerCall;
+import io.grpc.ServerCall.Listener;
+import io.grpc.ServerCallHandler;
+import io.grpc.ServerInterceptor;
+import io.grpc.Status;
+
+/**
+ * An adapter between Flight middleware and a gRPC interceptor.
+ *
+ * <p>This is implemented as a single gRPC interceptor that runs all Flight server middleware sequentially. Flight
+ * middleware instances are stored in the gRPC Context so their state is accessible later.
+ */
+public class ServerInterceptorAdapter implements ServerInterceptor {
+
+ /**
+ * A combination of a middleware Key and factory.
+ *
+ * @param <T> The middleware type.
+ */
+ public static class KeyFactory<T extends FlightServerMiddleware> {
+
+ private final FlightServerMiddleware.Key<T> key;
+ private final FlightServerMiddleware.Factory<T> factory;
+
+ public KeyFactory(Key<T> key, Factory<T> factory) {
+ this.key = key;
+ this.factory = factory;
+ }
+ }
+
+ /**
+ * The {@link Context.Key} that stores the Flight middleware active for a particular call.
+ *
+ * <p>Applications should not use this directly. Instead, see {@link CallContext#getMiddleware(Key)}.
+ */
+ public static final Context.Key<Map<FlightServerMiddleware.Key<?>, FlightServerMiddleware>> SERVER_MIDDLEWARE_KEY =
+ Context.key("arrow.flight.server_middleware");
+ private final List<KeyFactory<?>> factories;
+
+ public ServerInterceptorAdapter(List<KeyFactory<?>> factories) {
+ this.factories = factories;
+ }
+
+ @Override
+ public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers,
+ ServerCallHandler<ReqT, RespT> next) {
+ final CallInfo info = new CallInfo(FlightMethod.fromProtocol(call.getMethodDescriptor().getFullMethodName()));
+ final List<FlightServerMiddleware> middleware = new ArrayList<>();
+ // Use LinkedHashMap to preserve insertion order
+ final Map<FlightServerMiddleware.Key<?>, FlightServerMiddleware> middlewareMap = new LinkedHashMap<>();
+ final MetadataAdapter headerAdapter = new MetadataAdapter(headers);
+ final RequestContextAdapter requestContextAdapter = new RequestContextAdapter();
+ for (final KeyFactory<?> factory : factories) {
+ final FlightServerMiddleware m;
+ try {
+ m = factory.factory.onCallStarted(info, headerAdapter, requestContextAdapter);
+ } catch (FlightRuntimeException e) {
+ // Cancel call
+ call.close(StatusUtils.toGrpcStatus(e.status()), new Metadata());
+ return new Listener<ReqT>() {};
+ }
+ middleware.add(m);
+ middlewareMap.put(factory.key, m);
+ }
+
+ // Inject the middleware into the context so RPC method implementations can communicate with middleware instances
+ final Context contextWithMiddlewareAndRequestsOptions = Context.current()
+ .withValue(SERVER_MIDDLEWARE_KEY, Collections.unmodifiableMap(middlewareMap))
+ .withValue(RequestContextAdapter.REQUEST_CONTEXT_KEY, requestContextAdapter);
+
+ final SimpleForwardingServerCall<ReqT, RespT> forwardingServerCall = new SimpleForwardingServerCall<ReqT, RespT>(
+ call) {
+ boolean sentHeaders = false;
+
+ @Override
+ public void sendHeaders(Metadata headers) {
+ sentHeaders = true;
+ try {
+ final MetadataAdapter headerAdapter = new MetadataAdapter(headers);
+ middleware.forEach(m -> m.onBeforeSendingHeaders(headerAdapter));
+ } finally {
+ // Make sure to always call the gRPC callback to avoid interrupting the gRPC request cycle
+ super.sendHeaders(headers);
+ }
+ }
+
+ @Override
+ public void close(Status status, Metadata trailers) {
+ try {
+ if (!sentHeaders) {
+ // gRPC doesn't always send response headers if the call errors or completes immediately
+ final MetadataAdapter headerAdapter = new MetadataAdapter(trailers);
+ middleware.forEach(m -> m.onBeforeSendingHeaders(headerAdapter));
+ }
+ } finally {
+ // Make sure to always call the gRPC callback to avoid interrupting the gRPC request cycle
+ super.close(status, trailers);
+ }
+
+ final CallStatus flightStatus = StatusUtils.fromGrpcStatus(status);
+ middleware.forEach(m -> m.onCallCompleted(flightStatus));
+ }
+ };
+ return Contexts.interceptCall(contextWithMiddlewareAndRequestsOptions, forwardingServerCall, headers, next);
+
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/StatusUtils.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/StatusUtils.java
new file mode 100644
index 000000000..55e841864
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/StatusUtils.java
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.grpc;
+
+import java.util.Iterator;
+import java.util.Objects;
+import java.util.function.Function;
+
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.ErrorFlightMetadata;
+import org.apache.arrow.flight.FlightRuntimeException;
+import org.apache.arrow.flight.FlightStatusCode;
+
+import io.grpc.InternalMetadata;
+import io.grpc.Metadata;
+import io.grpc.Status;
+import io.grpc.Status.Code;
+import io.grpc.StatusException;
+import io.grpc.StatusRuntimeException;
+
+/**
+ * Utilities to adapt gRPC and Flight status objects.
+ *
+ * <p>NOT A PUBLIC CLASS, interface is not guaranteed to remain stable.
+ */
+public class StatusUtils {
+
+ private StatusUtils() {
+ throw new AssertionError("Do not instantiate this class.");
+ }
+
+ /**
+ * Convert from a Flight status code to a gRPC status code.
+ */
+ public static Status.Code toGrpcStatusCode(FlightStatusCode code) {
+ switch (code) {
+ case OK:
+ return Code.OK;
+ case UNKNOWN:
+ return Code.UNKNOWN;
+ case INTERNAL:
+ return Code.INTERNAL;
+ case INVALID_ARGUMENT:
+ return Code.INVALID_ARGUMENT;
+ case TIMED_OUT:
+ return Code.DEADLINE_EXCEEDED;
+ case NOT_FOUND:
+ return Code.NOT_FOUND;
+ case ALREADY_EXISTS:
+ return Code.ALREADY_EXISTS;
+ case CANCELLED:
+ return Code.CANCELLED;
+ case UNAUTHENTICATED:
+ return Code.UNAUTHENTICATED;
+ case UNAUTHORIZED:
+ return Code.PERMISSION_DENIED;
+ case UNIMPLEMENTED:
+ return Code.UNIMPLEMENTED;
+ case UNAVAILABLE:
+ return Code.UNAVAILABLE;
+ default:
+ return Code.UNKNOWN;
+ }
+ }
+
+ /**
+ * Convert from a gRPC status code to a Flight status code.
+ */
+ public static FlightStatusCode fromGrpcStatusCode(Status.Code code) {
+ switch (code) {
+ case OK:
+ return FlightStatusCode.OK;
+ case CANCELLED:
+ return FlightStatusCode.CANCELLED;
+ case UNKNOWN:
+ return FlightStatusCode.UNKNOWN;
+ case INVALID_ARGUMENT:
+ return FlightStatusCode.INVALID_ARGUMENT;
+ case DEADLINE_EXCEEDED:
+ return FlightStatusCode.TIMED_OUT;
+ case NOT_FOUND:
+ return FlightStatusCode.NOT_FOUND;
+ case ALREADY_EXISTS:
+ return FlightStatusCode.ALREADY_EXISTS;
+ case PERMISSION_DENIED:
+ return FlightStatusCode.UNAUTHORIZED;
+ case RESOURCE_EXHAUSTED:
+ return FlightStatusCode.INVALID_ARGUMENT;
+ case FAILED_PRECONDITION:
+ return FlightStatusCode.INVALID_ARGUMENT;
+ case ABORTED:
+ return FlightStatusCode.INTERNAL;
+ case OUT_OF_RANGE:
+ return FlightStatusCode.INVALID_ARGUMENT;
+ case UNIMPLEMENTED:
+ return FlightStatusCode.UNIMPLEMENTED;
+ case INTERNAL:
+ return FlightStatusCode.INTERNAL;
+ case UNAVAILABLE:
+ return FlightStatusCode.UNAVAILABLE;
+ case DATA_LOSS:
+ return FlightStatusCode.INTERNAL;
+ case UNAUTHENTICATED:
+ return FlightStatusCode.UNAUTHENTICATED;
+ default:
+ return FlightStatusCode.UNKNOWN;
+ }
+ }
+
+ /** Create Metadata Key for binary metadata. */
+ static Metadata.Key<byte[]> keyOfBinary(String name) {
+ return Metadata.Key.of(name, Metadata.BINARY_BYTE_MARSHALLER);
+ }
+
+ /** Create Metadata Key for ascii metadata. */
+ static Metadata.Key<String> keyOfAscii(String name) {
+ // Use InternalMetadata for keys that start with ":", e.g. ":status". See ARROW-14014.
+ return InternalMetadata.keyOf(name, Metadata.ASCII_STRING_MARSHALLER);
+ }
+
+ /** Convert from a gRPC Status & trailers to a Flight status. */
+ public static CallStatus fromGrpcStatusAndTrailers(Status status, Metadata trailers) {
+ // gRPC may not always have trailers - this happens when the server internally generates an error, which is rare,
+ // but can happen.
+ final ErrorFlightMetadata errorMetadata = trailers == null ? null : parseTrailers(trailers);
+ return new CallStatus(
+ fromGrpcStatusCode(status.getCode()),
+ status.getCause(),
+ status.getDescription(),
+ errorMetadata);
+ }
+
+ /** Convert from a gRPC status to a Flight status. */
+ public static CallStatus fromGrpcStatus(Status status) {
+ return new CallStatus(
+ fromGrpcStatusCode(status.getCode()),
+ status.getCause(),
+ status.getDescription(),
+ null);
+ }
+
+ /** Convert from a Flight status to a gRPC status. */
+ public static Status toGrpcStatus(CallStatus status) {
+ return toGrpcStatusCode(status.code()).toStatus().withDescription(status.description()).withCause(status.cause());
+ }
+
+ /** Convert from a gRPC exception to a Flight exception. */
+ public static FlightRuntimeException fromGrpcRuntimeException(StatusRuntimeException sre) {
+ return fromGrpcStatusAndTrailers(sre.getStatus(), sre.getTrailers()).toRuntimeException();
+ }
+
+ /** Convert gRPC trailers into Flight error metadata. */
+ private static ErrorFlightMetadata parseTrailers(Metadata trailers) {
+ ErrorFlightMetadata metadata = new ErrorFlightMetadata();
+ for (String key : trailers.keys()) {
+ if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
+ metadata.insert(key, trailers.get(keyOfBinary(key)));
+ } else {
+ metadata.insert(key, Objects.requireNonNull(trailers.get(keyOfAscii(key))).getBytes());
+ }
+ }
+ return metadata;
+ }
+
+ /**
+ * Convert arbitrary exceptions to a {@link FlightRuntimeException}.
+ */
+ public static FlightRuntimeException fromThrowable(Throwable t) {
+ if (t instanceof StatusRuntimeException) {
+ return fromGrpcRuntimeException((StatusRuntimeException) t);
+ } else if (t instanceof FlightRuntimeException) {
+ return (FlightRuntimeException) t;
+ }
+ return CallStatus.UNKNOWN.withCause(t).withDescription(t.getMessage()).toRuntimeException();
+ }
+
+ /**
+ * Convert arbitrary exceptions to a {@link StatusRuntimeException} or {@link StatusException}.
+ *
+ * <p>Such exceptions can be passed to {@link io.grpc.stub.StreamObserver#onError(Throwable)} and will give the client
+ * a reasonable error message.
+ */
+ public static Throwable toGrpcException(Throwable ex) {
+ if (ex instanceof StatusRuntimeException) {
+ return ex;
+ } else if (ex instanceof StatusException) {
+ return ex;
+ } else if (ex instanceof FlightRuntimeException) {
+ final FlightRuntimeException fre = (FlightRuntimeException) ex;
+ if (fre.status().metadata() != null) {
+ Metadata trailers = toGrpcMetadata(fre.status().metadata());
+ return new StatusRuntimeException(toGrpcStatus(fre.status()), trailers);
+ }
+ return toGrpcStatus(fre.status()).asRuntimeException();
+ }
+ return Status.INTERNAL.withCause(ex).withDescription("There was an error servicing your request.")
+ .asRuntimeException();
+ }
+
+ private static Metadata toGrpcMetadata(ErrorFlightMetadata metadata) {
+ final Metadata trailers = new Metadata();
+ for (final String key : metadata.keys()) {
+ if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
+ trailers.put(keyOfBinary(key), metadata.getByte(key));
+ } else {
+ trailers.put(keyOfAscii(key), metadata.get(key));
+ }
+ }
+ return trailers;
+ }
+
+ /**
+ * Maps a transformation function to the elements of an iterator, while wrapping exceptions in {@link
+ * FlightRuntimeException}.
+ */
+ public static <FROM, TO> Iterator<TO> wrapIterator(Iterator<FROM> fromIterator,
+ Function<? super FROM, ? extends TO> transformer) {
+ Objects.requireNonNull(fromIterator);
+ Objects.requireNonNull(transformer);
+ return new Iterator<TO>() {
+ @Override
+ public boolean hasNext() {
+ try {
+ return fromIterator.hasNext();
+ } catch (StatusRuntimeException e) {
+ throw fromGrpcRuntimeException(e);
+ }
+ }
+
+ @Override
+ public TO next() {
+ try {
+ return transformer.apply(fromIterator.next());
+ } catch (StatusRuntimeException e) {
+ throw fromGrpcRuntimeException(e);
+ }
+ }
+ };
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/FlightTestUtil.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/FlightTestUtil.java
new file mode 100644
index 000000000..cd043b639
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/FlightTestUtil.java
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.io.File;
+import java.io.IOException;
+import java.lang.reflect.InvocationTargetException;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+import java.util.Random;
+import java.util.function.Function;
+
+import org.junit.Assert;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.function.Executable;
+
+/**
+ * Utility methods and constants for testing flight servers.
+ */
+public class FlightTestUtil {
+
+ private static final Random RANDOM = new Random();
+
+ public static final String LOCALHOST = "localhost";
+ public static final String TEST_DATA_ENV_VAR = "ARROW_TEST_DATA";
+ public static final String TEST_DATA_PROPERTY = "arrow.test.dataRoot";
+
+ /**
+ * Returns a a FlightServer (actually anything that is startable)
+ * that has been started bound to a random port.
+ */
+ public static <T> T getStartedServer(Function<Location, T> newServerFromLocation) throws IOException {
+ IOException lastThrown = null;
+ T server = null;
+ for (int x = 0; x < 3; x++) {
+ final int port = 49152 + RANDOM.nextInt(5000);
+ final Location location = Location.forGrpcInsecure(LOCALHOST, port);
+ lastThrown = null;
+ try {
+ server = newServerFromLocation.apply(location);
+ try {
+ server.getClass().getMethod("start").invoke(server);
+ } catch (NoSuchMethodException | IllegalAccessException e) {
+ throw new IllegalArgumentException("Couldn't call start method on object.", e);
+ }
+ break;
+ } catch (InvocationTargetException e) {
+ if (e.getTargetException() instanceof IOException) {
+ lastThrown = (IOException) e.getTargetException();
+ } else {
+ throw (RuntimeException) e.getTargetException();
+ }
+ }
+ }
+ if (lastThrown != null) {
+ throw lastThrown;
+ }
+ return server;
+ }
+
+ static Path getTestDataRoot() {
+ String path = System.getenv(TEST_DATA_ENV_VAR);
+ if (path == null) {
+ path = System.getProperty(TEST_DATA_PROPERTY);
+ }
+ return Paths.get(Objects.requireNonNull(path,
+ String.format("Could not find test data path. Set the environment variable %s or the JVM property %s.",
+ TEST_DATA_ENV_VAR, TEST_DATA_PROPERTY)));
+ }
+
+ static Path getFlightTestDataRoot() {
+ return getTestDataRoot().resolve("flight");
+ }
+
+ static Path exampleTlsRootCert() {
+ return getFlightTestDataRoot().resolve("root-ca.pem");
+ }
+
+ static List<CertKeyPair> exampleTlsCerts() {
+ final Path root = getFlightTestDataRoot();
+ return Arrays.asList(new CertKeyPair(root.resolve("cert0.pem").toFile(), root.resolve("cert0.pkcs1").toFile()),
+ new CertKeyPair(root.resolve("cert1.pem").toFile(), root.resolve("cert1.pkcs1").toFile()));
+ }
+
+ static boolean isEpollAvailable() {
+ try {
+ Class<?> epoll = Class.forName("io.netty.channel.epoll.Epoll");
+ return (Boolean) epoll.getMethod("isAvailable").invoke(null);
+ } catch (ClassNotFoundException | NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
+ return false;
+ }
+ }
+
+ static boolean isKqueueAvailable() {
+ try {
+ Class<?> kqueue = Class.forName("io.netty.channel.kqueue.KQueue");
+ return (Boolean) kqueue.getMethod("isAvailable").invoke(null);
+ } catch (ClassNotFoundException | NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
+ return false;
+ }
+ }
+
+ static boolean isNativeTransportAvailable() {
+ return isEpollAvailable() || isKqueueAvailable();
+ }
+
+ /**
+ * Assert that the given runnable fails with a Flight exception of the given code.
+ * @param code The expected Flight status code.
+ * @param r The code to run.
+ * @return The thrown status.
+ */
+ public static CallStatus assertCode(FlightStatusCode code, Executable r) {
+ final FlightRuntimeException ex = Assertions.assertThrows(FlightRuntimeException.class, r);
+ Assert.assertEquals(code, ex.status().code());
+ return ex.status();
+ }
+
+ public static class CertKeyPair {
+
+ public final File cert;
+ public final File key;
+
+ public CertKeyPair(File cert, File key) {
+ this.cert = cert;
+ this.key = key;
+ }
+ }
+
+ private FlightTestUtil() {
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java
new file mode 100644
index 000000000..c7b3321af
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java
@@ -0,0 +1,329 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.function.BiConsumer;
+
+import org.apache.arrow.flight.FlightClient.PutListener;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+
+/**
+ * Tests for application-specific metadata support in Flight.
+ */
+public class TestApplicationMetadata {
+
+ // The command used to trigger the test for ARROW-6136.
+ private static final byte[] COMMAND_ARROW_6136 = "ARROW-6136".getBytes();
+ // The expected error message.
+ private static final String MESSAGE_ARROW_6136 = "The stream should not be double-closed.";
+
+ /**
+ * Ensure that a client can read the metadata sent from the server.
+ */
+ @Test
+ // This test is consistently flaky on CI, unfortunately.
+ @Ignore
+ public void retrieveMetadata() {
+ test((allocator, client) -> {
+ try (final FlightStream stream = client.getStream(new Ticket(new byte[0]))) {
+ byte i = 0;
+ while (stream.next()) {
+ final IntVector vector = (IntVector) stream.getRoot().getVector("a");
+ Assert.assertEquals(1, vector.getValueCount());
+ Assert.assertEquals(10, vector.get(0));
+ Assert.assertEquals(i, stream.getLatestMetadata().getByte(0));
+ i++;
+ }
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+
+ /** ARROW-6136: make sure that the Flight implementation doesn't double-close the server-to-client stream. */
+ @Test
+ public void arrow6136() {
+ final Schema schema = new Schema(Collections.emptyList());
+ test((allocator, client) -> {
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ final FlightDescriptor descriptor = FlightDescriptor.command(COMMAND_ARROW_6136);
+
+ final PutListener listener = new SyncPutListener();
+ final FlightClient.ClientStreamListener writer = client.startPut(descriptor, root, listener);
+ // Must attempt to retrieve the result to get any server-side errors.
+ final CallStatus status = FlightTestUtil.assertCode(FlightStatusCode.INTERNAL, writer::getResult);
+ Assert.assertEquals(MESSAGE_ARROW_6136, status.description());
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+
+ /**
+ * Ensure that a client can send metadata to the server.
+ */
+ @Test
+ @Ignore
+ public void uploadMetadataAsync() {
+ final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true))));
+ test((allocator, client) -> {
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ final FlightDescriptor descriptor = FlightDescriptor.path("test");
+
+ final PutListener listener = new AsyncPutListener() {
+ int counter = 0;
+
+ @Override
+ public void onNext(PutResult val) {
+ Assert.assertNotNull(val);
+ Assert.assertEquals(counter, val.getApplicationMetadata().getByte(0));
+ counter++;
+ }
+ };
+ final FlightClient.ClientStreamListener writer = client.startPut(descriptor, root, listener);
+
+ root.allocateNew();
+ for (byte i = 0; i < 10; i++) {
+ final IntVector vector = (IntVector) root.getVector("a");
+ final ArrowBuf metadata = allocator.buffer(1);
+ metadata.writeByte(i);
+ vector.set(0, 10);
+ vector.setValueCount(1);
+ root.setRowCount(1);
+ writer.putNext(metadata);
+ }
+ writer.completed();
+ // Must attempt to retrieve the result to get any server-side errors.
+ writer.getResult();
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+
+ /**
+ * Ensure that a client can send metadata to the server. Uses the synchronous API.
+ */
+ @Test
+ @Ignore
+ public void uploadMetadataSync() {
+ final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true))));
+ test((allocator, client) -> {
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator);
+ final SyncPutListener listener = new SyncPutListener()) {
+ final FlightDescriptor descriptor = FlightDescriptor.path("test");
+ final FlightClient.ClientStreamListener writer = client.startPut(descriptor, root, listener);
+
+ root.allocateNew();
+ for (byte i = 0; i < 10; i++) {
+ final IntVector vector = (IntVector) root.getVector("a");
+ final ArrowBuf metadata = allocator.buffer(1);
+ metadata.writeByte(i);
+ vector.set(0, 10);
+ vector.setValueCount(1);
+ root.setRowCount(1);
+ writer.putNext(metadata);
+ try (final PutResult message = listener.poll(5000, TimeUnit.SECONDS)) {
+ Assert.assertNotNull(message);
+ Assert.assertEquals(i, message.getApplicationMetadata().getByte(0));
+ } catch (InterruptedException | ExecutionException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ writer.completed();
+ // Must attempt to retrieve the result to get any server-side errors.
+ writer.getResult();
+ }
+ });
+ }
+
+ /**
+ * Make sure that a {@link SyncPutListener} properly reclaims memory if ignored.
+ */
+ @Test
+ @Ignore
+ public void syncMemoryReclaimed() {
+ final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true))));
+ test((allocator, client) -> {
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator);
+ final SyncPutListener listener = new SyncPutListener()) {
+ final FlightDescriptor descriptor = FlightDescriptor.path("test");
+ final FlightClient.ClientStreamListener writer = client.startPut(descriptor, root, listener);
+
+ root.allocateNew();
+ for (byte i = 0; i < 10; i++) {
+ final IntVector vector = (IntVector) root.getVector("a");
+ final ArrowBuf metadata = allocator.buffer(1);
+ metadata.writeByte(i);
+ vector.set(0, 10);
+ vector.setValueCount(1);
+ root.setRowCount(1);
+ writer.putNext(metadata);
+ }
+ writer.completed();
+ // Must attempt to retrieve the result to get any server-side errors.
+ writer.getResult();
+ }
+ });
+ }
+
+ /**
+ * ARROW-9221: Flight copies metadata from the byte buffer of a Protobuf ByteString,
+ * which is in big-endian by default, thus mangling metadata.
+ */
+ @Test
+ public void testMetadataEndianness() throws Exception {
+ try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ final BufferAllocator serverAllocator = allocator.newChildAllocator("flight-server", 0, Long.MAX_VALUE);
+ final FlightServer server = FlightTestUtil.getStartedServer(
+ (location) -> FlightServer
+ .builder(serverAllocator, location, new EndianFlightProducer(serverAllocator))
+ .build());
+ final FlightClient client = FlightClient.builder(allocator, server.getLocation()).build()) {
+ final Schema schema = new Schema(Collections.emptyList());
+ final FlightDescriptor descriptor = FlightDescriptor.command(new byte[0]);
+ try (final SyncPutListener reader = new SyncPutListener();
+ final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ final FlightClient.ClientStreamListener writer = client.startPut(descriptor, root, reader);
+ writer.completed();
+ try (final PutResult metadata = reader.read()) {
+ Assert.assertEquals(16, metadata.getApplicationMetadata().readableBytes());
+ byte[] bytes = new byte[16];
+ metadata.getApplicationMetadata().readBytes(bytes);
+ Assert.assertArrayEquals(EndianFlightProducer.EXPECTED_BYTES, bytes);
+ }
+ writer.getResult();
+ }
+ }
+ }
+
+ private void test(BiConsumer<BufferAllocator, FlightClient> fun) {
+ try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ final FlightServer s =
+ FlightTestUtil.getStartedServer(
+ (location) -> FlightServer.builder(allocator, location, new MetadataFlightProducer(allocator)).build());
+ final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) {
+ fun.accept(allocator, client);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ /**
+ * A FlightProducer that always produces a fixed data stream with metadata on the side.
+ */
+ private static class MetadataFlightProducer extends NoOpFlightProducer {
+
+ private final BufferAllocator allocator;
+
+ public MetadataFlightProducer(BufferAllocator allocator) {
+ this.allocator = allocator;
+ }
+
+ @Override
+ public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
+ final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true))));
+ try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ root.allocateNew();
+ listener.start(root);
+ for (byte i = 0; i < 10; i++) {
+ final IntVector vector = (IntVector) root.getVector("a");
+ vector.set(0, 10);
+ vector.setValueCount(1);
+ root.setRowCount(1);
+ final ArrowBuf metadata = allocator.buffer(1);
+ metadata.writeByte(i);
+ listener.putNext(metadata);
+ }
+ listener.completed();
+ }
+ }
+
+ @Override
+ public Runnable acceptPut(CallContext context, FlightStream stream, StreamListener<PutResult> ackStream) {
+ return () -> {
+ // Wait for the descriptor to be sent
+ stream.getRoot();
+ if (stream.getDescriptor().isCommand() &&
+ Arrays.equals(stream.getDescriptor().getCommand(), COMMAND_ARROW_6136)) {
+ // ARROW-6136: Try closing the stream
+ ackStream.onError(
+ CallStatus.INTERNAL.withDescription(MESSAGE_ARROW_6136).toRuntimeException());
+ return;
+ }
+ try {
+ byte current = 0;
+ while (stream.next()) {
+ final ArrowBuf metadata = stream.getLatestMetadata();
+ if (current != metadata.getByte(0)) {
+ ackStream.onError(CallStatus.INVALID_ARGUMENT.withDescription(String
+ .format("Metadata does not match expected value; got %d but expected %d.", metadata.getByte(0),
+ current)).toRuntimeException());
+ return;
+ }
+ ackStream.onNext(PutResult.metadata(metadata));
+ current++;
+ }
+ if (current != 10) {
+ throw CallStatus.INVALID_ARGUMENT.withDescription("Wrong number of messages sent.").toRuntimeException();
+ }
+ } catch (Exception e) {
+ throw CallStatus.INTERNAL.withCause(e).withDescription(e.toString()).toRuntimeException();
+ }
+ };
+ }
+ }
+
+ private static class EndianFlightProducer extends NoOpFlightProducer {
+ static final byte[] EXPECTED_BYTES = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+ private final BufferAllocator allocator;
+
+ private EndianFlightProducer(BufferAllocator allocator) {
+ this.allocator = allocator;
+ }
+
+ @Override
+ public Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener<PutResult> ackStream) {
+ return () -> {
+ while (flightStream.next()) {
+ // Ignore any data
+ }
+
+ try (final ArrowBuf buf = allocator.buffer(16)) {
+ buf.writeBytes(EXPECTED_BYTES);
+ ackStream.onNext(PutResult.metadata(buf));
+ }
+ ackStream.onCompleted();
+ };
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestAuth.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestAuth.java
new file mode 100644
index 000000000..6f0ec9f02
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestAuth.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.Iterator;
+import java.util.Optional;
+
+import org.apache.arrow.flight.auth.ClientAuthHandler;
+import org.apache.arrow.flight.auth.ServerAuthHandler;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.junit.Test;
+
+public class TestAuth {
+
+ /** An auth handler that does not send messages should not block the server forever. */
+ @Test(expected = RuntimeException.class)
+ public void noMessages() throws Exception {
+ try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+ final FlightServer s = FlightTestUtil
+ .getStartedServer(
+ location -> FlightServer.builder(allocator, location, new NoOpFlightProducer()).authHandler(
+ new OneshotAuthHandler()).build());
+ final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) {
+ client.authenticate(new ClientAuthHandler() {
+ @Override
+ public void authenticate(ClientAuthSender outgoing, Iterator<byte[]> incoming) {
+ }
+
+ @Override
+ public byte[] getCallToken() {
+ return new byte[0];
+ }
+ });
+ }
+ }
+
+ /** An auth handler that sends an error should not block the server forever. */
+ @Test(expected = RuntimeException.class)
+ public void clientError() throws Exception {
+ try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+ final FlightServer s = FlightTestUtil
+ .getStartedServer(
+ location -> FlightServer.builder(allocator, location, new NoOpFlightProducer()).authHandler(
+ new OneshotAuthHandler()).build());
+ final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) {
+ client.authenticate(new ClientAuthHandler() {
+ @Override
+ public void authenticate(ClientAuthSender outgoing, Iterator<byte[]> incoming) {
+ outgoing.send(new byte[0]);
+ // Ensure the server-side runs
+ incoming.next();
+ outgoing.onError(new RuntimeException("test"));
+ }
+
+ @Override
+ public byte[] getCallToken() {
+ return new byte[0];
+ }
+ });
+ }
+ }
+
+ private static class OneshotAuthHandler implements ServerAuthHandler {
+
+ @Override
+ public Optional<String> isValid(byte[] token) {
+ return Optional.of("test");
+ }
+
+ @Override
+ public boolean authenticate(ServerAuthSender outgoing, Iterator<byte[]> incoming) {
+ incoming.next();
+ outgoing.send(new byte[0]);
+ return false;
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBackPressure.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBackPressure.java
new file mode 100644
index 000000000..1a71c363e
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBackPressure.java
@@ -0,0 +1,262 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.Function;
+
+import org.apache.arrow.flight.perf.PerformanceTestServer;
+import org.apache.arrow.flight.perf.TestPerf;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.Types.MinorType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import com.google.common.collect.ImmutableList;
+
+public class TestBackPressure {
+
+ private static final int BATCH_SIZE = 4095;
+
+ /**
+ * Make sure that failing to consume one stream doesn't block other streams.
+ */
+ @Ignore
+ @Test
+ public void ensureIndependentSteams() throws Exception {
+ ensureIndependentSteams((b) -> (location -> new PerformanceTestServer(b, location)));
+ }
+
+ /**
+ * Make sure that failing to consume one stream doesn't block other streams.
+ */
+ @Ignore
+ @Test
+ public void ensureIndependentSteamsWithCallbacks() throws Exception {
+ ensureIndependentSteams((b) -> (location -> new PerformanceTestServer(b, location,
+ new BackpressureStrategy.CallbackBackpressureStrategy(), true)));
+ }
+
+ /**
+ * Test to make sure stream doesn't go faster than the consumer is consuming.
+ */
+ @Ignore
+ @Test
+ public void ensureWaitUntilProceed() throws Exception {
+ ensureWaitUntilProceed(new PollingBackpressureStrategy(), false);
+ }
+
+ /**
+ * Test to make sure stream doesn't go faster than the consumer is consuming using a callback-based
+ * backpressure strategy.
+ */
+ @Ignore
+ @Test
+ public void ensureWaitUntilProceedWithCallbacks() throws Exception {
+ ensureWaitUntilProceed(new RecordingCallbackBackpressureStrategy(), true);
+ }
+
+ /**
+ * Make sure that failing to consume one stream doesn't block other streams.
+ */
+ private static void ensureIndependentSteams(Function<BufferAllocator, Function<Location, PerformanceTestServer>>
+ serverConstructor) throws Exception {
+ try (
+ final BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
+ final PerformanceTestServer server = FlightTestUtil.getStartedServer(
+ (location) -> (serverConstructor.apply(a).apply(location)));
+ final FlightClient client = FlightClient.builder(a, server.getLocation()).build()
+ ) {
+ try (FlightStream fs1 = client.getStream(client.getInfo(
+ TestPerf.getPerfFlightDescriptor(110L * BATCH_SIZE, BATCH_SIZE, 1))
+ .getEndpoints().get(0).getTicket())) {
+ consume(fs1, 10);
+
+ // stop consuming fs1 but make sure we can consume a large amount of fs2.
+ try (FlightStream fs2 = client.getStream(client.getInfo(
+ TestPerf.getPerfFlightDescriptor(200L * BATCH_SIZE, BATCH_SIZE, 1))
+ .getEndpoints().get(0).getTicket())) {
+ consume(fs2, 100);
+
+ consume(fs1, 100);
+ consume(fs2, 100);
+
+ consume(fs1);
+ consume(fs2);
+ }
+ }
+ }
+ }
+
+ /**
+ * Make sure that a stream doesn't go faster than the consumer is consuming.
+ */
+ private static void ensureWaitUntilProceed(SleepTimeRecordingBackpressureStrategy bpStrategy, boolean isNonBlocking)
+ throws Exception {
+ // request some values.
+ final long wait = 3000;
+ final long epsilon = 1000;
+
+ try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
+
+ final FlightProducer producer = new NoOpFlightProducer() {
+
+ @Override
+ public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
+ bpStrategy.register(listener);
+ final Runnable loadData = () -> {
+ int batches = 0;
+ final Schema pojoSchema = new Schema(ImmutableList.of(Field.nullable("a", MinorType.BIGINT.getType())));
+ try (VectorSchemaRoot root = VectorSchemaRoot.create(pojoSchema, allocator)) {
+ listener.start(root);
+ while (true) {
+ bpStrategy.waitForListener(0);
+ if (batches > 100) {
+ root.clear();
+ listener.completed();
+ return;
+ }
+
+ root.allocateNew();
+ root.setRowCount(4095);
+ listener.putNext();
+ batches++;
+ }
+ }
+ };
+
+ if (!isNonBlocking) {
+ loadData.run();
+ } else {
+ final ExecutorService service = Executors.newSingleThreadExecutor();
+ service.submit(loadData);
+ service.shutdown();
+ }
+ }
+ };
+
+
+ try (
+ BufferAllocator serverAllocator = allocator.newChildAllocator("server", 0, Long.MAX_VALUE);
+ FlightServer server =
+ FlightTestUtil.getStartedServer((location) -> FlightServer.builder(serverAllocator, location, producer)
+ .build());
+ BufferAllocator clientAllocator = allocator.newChildAllocator("client", 0, Long.MAX_VALUE);
+ FlightClient client =
+ FlightClient
+ .builder(clientAllocator, server.getLocation())
+ .build();
+ FlightStream stream = client.getStream(new Ticket(new byte[1]))
+ ) {
+ VectorSchemaRoot root = stream.getRoot();
+ root.clear();
+ Thread.sleep(wait);
+ while (stream.next()) {
+ root.clear();
+ }
+ long expected = wait - epsilon;
+ Assert.assertTrue(
+ String.format("Expected a sleep of at least %dms but only slept for %d", expected,
+ bpStrategy.getSleepTime()), bpStrategy.getSleepTime() > expected);
+
+ }
+ }
+ }
+
+ private static void consume(FlightStream stream) {
+ VectorSchemaRoot root = stream.getRoot();
+ while (stream.next()) {
+ root.clear();
+ }
+ }
+
+ private static void consume(FlightStream stream, int batches) {
+ VectorSchemaRoot root = stream.getRoot();
+ while (batches > 0 && stream.next()) {
+ root.clear();
+ batches--;
+ }
+ }
+
+ private interface SleepTimeRecordingBackpressureStrategy extends BackpressureStrategy {
+ /**
+ * Returns the total time spent waiting on the listener to be ready.
+ * @return the total time spent waiting on the listener to be ready.
+ */
+ long getSleepTime();
+ }
+
+ /**
+ * Implementation of a backpressure strategy that polls on isReady and records amount of time spent in Thread.sleep().
+ */
+ private static class PollingBackpressureStrategy implements SleepTimeRecordingBackpressureStrategy {
+ private final AtomicLong sleepTime = new AtomicLong(0);
+ private FlightProducer.ServerStreamListener listener;
+
+ @Override
+ public long getSleepTime() {
+ return sleepTime.get();
+ }
+
+ @Override
+ public void register(FlightProducer.ServerStreamListener listener) {
+ this.listener = listener;
+ }
+
+ @Override
+ public WaitResult waitForListener(long timeout) {
+ while (!listener.isReady()) {
+ try {
+ Thread.sleep(1);
+ sleepTime.addAndGet(1L);
+ } catch (InterruptedException ignore) {
+ }
+ }
+ return WaitResult.READY;
+ }
+ }
+
+ /**
+ * Implementation of a backpressure strategy that uses callbacks to detect changes in client readiness state
+ * and records spent time waiting.
+ */
+ private static class RecordingCallbackBackpressureStrategy extends BackpressureStrategy.CallbackBackpressureStrategy
+ implements SleepTimeRecordingBackpressureStrategy {
+ private final AtomicLong sleepTime = new AtomicLong(0);
+
+ @Override
+ public long getSleepTime() {
+ return sleepTime.get();
+ }
+
+ @Override
+ public WaitResult waitForListener(long timeout) {
+ final long startTime = System.currentTimeMillis();
+ final WaitResult result = super.waitForListener(timeout);
+ sleepTime.addAndGet(System.currentTimeMillis() - startTime);
+ return result;
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
new file mode 100644
index 000000000..e29cd07ce
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java
@@ -0,0 +1,567 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.function.BiConsumer;
+import java.util.function.Consumer;
+
+import org.apache.arrow.flight.FlightClient.ClientStreamListener;
+import org.apache.arrow.flight.impl.Flight;
+import org.apache.arrow.flight.impl.Flight.FlightDescriptor.DescriptorType;
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.VectorUnloader;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.ipc.message.IpcOption;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.google.common.base.Charsets;
+import com.google.protobuf.ByteString;
+
+import io.grpc.MethodDescriptor;
+
+/**
+ * Test the operations of a basic flight service.
+ */
+public class TestBasicOperation {
+
+ @Test
+ public void fastPathDefaults() {
+ Assert.assertTrue(ArrowMessage.ENABLE_ZERO_COPY_READ);
+ Assert.assertFalse(ArrowMessage.ENABLE_ZERO_COPY_WRITE);
+ }
+
+ /**
+ * ARROW-6017: we should be able to construct locations for unknown schemes.
+ */
+ @Test
+ public void unknownScheme() throws URISyntaxException {
+ final Location location = new Location("s3://unknown");
+ Assert.assertEquals("s3", location.getUri().getScheme());
+ }
+
+ @Test
+ public void unknownSchemeRemote() throws Exception {
+ test(c -> {
+ try {
+ final FlightInfo info = c.getInfo(FlightDescriptor.path("test"));
+ Assert.assertEquals(new URI("https://example.com"), info.getEndpoints().get(0).getLocations().get(0).getUri());
+ } catch (URISyntaxException e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+
+ @Test
+ public void roundTripTicket() throws Exception {
+ final Ticket ticket = new Ticket(new byte[]{0, 1, 2, 3, 4, 5});
+ Assert.assertEquals(ticket, Ticket.deserialize(ticket.serialize()));
+ }
+
+ @Test
+ public void roundTripInfo() throws Exception {
+ final Map<String, String> metadata = new HashMap<>();
+ metadata.put("foo", "bar");
+ final Schema schema = new Schema(Arrays.asList(
+ Field.nullable("a", new ArrowType.Int(32, true)),
+ Field.nullable("b", new ArrowType.FixedSizeBinary(32))
+ ), metadata);
+ final FlightInfo info1 = new FlightInfo(schema, FlightDescriptor.path(), Collections.emptyList(), -1, -1);
+ final FlightInfo info2 = new FlightInfo(schema, FlightDescriptor.command(new byte[2]),
+ Collections.singletonList(new FlightEndpoint(
+ new Ticket(new byte[10]), Location.forGrpcDomainSocket("/tmp/test.sock"))), 200, 500);
+ final FlightInfo info3 = new FlightInfo(schema, FlightDescriptor.path("a", "b"),
+ Arrays.asList(new FlightEndpoint(
+ new Ticket(new byte[10]), Location.forGrpcDomainSocket("/tmp/test.sock")),
+ new FlightEndpoint(
+ new Ticket(new byte[10]), Location.forGrpcDomainSocket("/tmp/test.sock"),
+ Location.forGrpcInsecure("localhost", 50051))
+ ), 200, 500);
+
+ Assert.assertEquals(info1, FlightInfo.deserialize(info1.serialize()));
+ Assert.assertEquals(info2, FlightInfo.deserialize(info2.serialize()));
+ Assert.assertEquals(info3, FlightInfo.deserialize(info3.serialize()));
+ }
+
+ @Test
+ public void roundTripDescriptor() throws Exception {
+ final FlightDescriptor cmd = FlightDescriptor.command("test command".getBytes(StandardCharsets.UTF_8));
+ Assert.assertEquals(cmd, FlightDescriptor.deserialize(cmd.serialize()));
+ final FlightDescriptor path = FlightDescriptor.path("foo", "bar", "test.arrow");
+ Assert.assertEquals(path, FlightDescriptor.deserialize(path.serialize()));
+ }
+
+ @Test
+ public void getDescriptors() throws Exception {
+ test(c -> {
+ int count = 0;
+ for (FlightInfo i : c.listFlights(Criteria.ALL)) {
+ count += 1;
+ }
+ Assert.assertEquals(1, count);
+ });
+ }
+
+ @Test
+ public void getDescriptorsWithCriteria() throws Exception {
+ test(c -> {
+ int count = 0;
+ for (FlightInfo i : c.listFlights(new Criteria(new byte[]{1}))) {
+ count += 1;
+ }
+ Assert.assertEquals(0, count);
+ });
+ }
+
+ @Test
+ public void getDescriptor() throws Exception {
+ test(c -> {
+ System.out.println(c.getInfo(FlightDescriptor.path("hello")).getDescriptor());
+ });
+ }
+
+ @Test
+ public void getSchema() throws Exception {
+ test(c -> {
+ System.out.println(c.getSchema(FlightDescriptor.path("hello")).getSchema());
+ });
+ }
+
+
+ @Test
+ public void listActions() throws Exception {
+ test(c -> {
+ for (ActionType at : c.listActions()) {
+ System.out.println(at.getType());
+ }
+ });
+ }
+
+ @Test
+ public void doAction() throws Exception {
+ test(c -> {
+ Iterator<Result> stream = c.doAction(new Action("hello"));
+
+ Assert.assertTrue(stream.hasNext());
+ Result r = stream.next();
+ Assert.assertArrayEquals("world".getBytes(Charsets.UTF_8), r.getBody());
+ });
+ test(c -> {
+ Iterator<Result> stream = c.doAction(new Action("hellooo"));
+
+ Assert.assertTrue(stream.hasNext());
+ Result r = stream.next();
+ Assert.assertArrayEquals("world".getBytes(Charsets.UTF_8), r.getBody());
+
+ Assert.assertTrue(stream.hasNext());
+ r = stream.next();
+ Assert.assertArrayEquals("!".getBytes(Charsets.UTF_8), r.getBody());
+ Assert.assertFalse(stream.hasNext());
+ });
+ }
+
+ @Test
+ public void putStream() throws Exception {
+ test((c, a) -> {
+ final int size = 10;
+
+ IntVector iv = new IntVector("c1", a);
+
+ try (VectorSchemaRoot root = VectorSchemaRoot.of(iv)) {
+ ClientStreamListener listener = c
+ .startPut(FlightDescriptor.path("hello"), root, new AsyncPutListener());
+
+ //batch 1
+ root.allocateNew();
+ for (int i = 0; i < size; i++) {
+ iv.set(i, i);
+ }
+ iv.setValueCount(size);
+ root.setRowCount(size);
+ listener.putNext();
+
+ // batch 2
+
+ root.allocateNew();
+ for (int i = 0; i < size; i++) {
+ iv.set(i, i + size);
+ }
+ iv.setValueCount(size);
+ root.setRowCount(size);
+ listener.putNext();
+ root.clear();
+ listener.completed();
+
+ // wait for ack to avoid memory leaks.
+ listener.getResult();
+ }
+ });
+ }
+
+ @Test
+ public void propagateErrors() throws Exception {
+ test(client -> {
+ FlightTestUtil.assertCode(FlightStatusCode.UNIMPLEMENTED, () -> {
+ client.doAction(new Action("invalid-action")).forEachRemaining(action -> Assert.fail());
+ });
+ });
+ }
+
+ @Test
+ public void getStream() throws Exception {
+ test(c -> {
+ try (final FlightStream stream = c.getStream(new Ticket(new byte[0]))) {
+ VectorSchemaRoot root = stream.getRoot();
+ IntVector iv = (IntVector) root.getVector("c1");
+ int value = 0;
+ while (stream.next()) {
+ for (int i = 0; i < root.getRowCount(); i++) {
+ Assert.assertEquals(value, iv.get(i));
+ value++;
+ }
+ }
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+
+ /** Ensure the client is configured to accept large messages. */
+ @Test
+ public void getStreamLargeBatch() throws Exception {
+ test(c -> {
+ try (final FlightStream stream = c.getStream(new Ticket(Producer.TICKET_LARGE_BATCH))) {
+ Assert.assertEquals(128, stream.getRoot().getFieldVectors().size());
+ Assert.assertTrue(stream.next());
+ Assert.assertEquals(65536, stream.getRoot().getRowCount());
+ Assert.assertTrue(stream.next());
+ Assert.assertEquals(65536, stream.getRoot().getRowCount());
+ Assert.assertFalse(stream.next());
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+
+ /** Ensure the server is configured to accept large messages. */
+ @Test
+ public void startPutLargeBatch() throws Exception {
+ try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
+ final List<FieldVector> vectors = new ArrayList<>();
+ for (int col = 0; col < 128; col++) {
+ final BigIntVector vector = new BigIntVector("f" + col, allocator);
+ for (int row = 0; row < 65536; row++) {
+ vector.setSafe(row, row);
+ }
+ vectors.add(vector);
+ }
+ test(c -> {
+ try (final VectorSchemaRoot root = new VectorSchemaRoot(vectors)) {
+ root.setRowCount(65536);
+ final ClientStreamListener stream = c.startPut(FlightDescriptor.path(""), root, new SyncPutListener());
+ stream.putNext();
+ stream.putNext();
+ stream.completed();
+ stream.getResult();
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+ }
+
+ private void test(Consumer<FlightClient> consumer) throws Exception {
+ test((c, a) -> {
+ consumer.accept(c);
+ });
+ }
+
+ private void test(BiConsumer<FlightClient, BufferAllocator> consumer) throws Exception {
+ try (
+ BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
+ Producer producer = new Producer(a);
+ FlightServer s =
+ FlightTestUtil.getStartedServer(
+ (location) -> FlightServer.builder(a, location, producer).build()
+ )) {
+
+ try (
+ FlightClient c = FlightClient.builder(a, s.getLocation()).build()
+ ) {
+ try (BufferAllocator testAllocator = a.newChildAllocator("testcase", 0, Long.MAX_VALUE)) {
+ consumer.accept(c, testAllocator);
+ }
+ }
+ }
+ }
+
+ /** Helper method to convert an ArrowMessage into a Protobuf message. */
+ private Flight.FlightData arrowMessageToProtobuf(
+ MethodDescriptor.Marshaller<ArrowMessage> marshaller, ArrowMessage message) throws IOException {
+ final ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ try (final InputStream serialized = marshaller.stream(message)) {
+ final byte[] buf = new byte[1024];
+ while (true) {
+ int read = serialized.read(buf);
+ if (read < 0) {
+ break;
+ }
+ baos.write(buf, 0, read);
+ }
+ }
+ final byte[] serializedMessage = baos.toByteArray();
+ return Flight.FlightData.parseFrom(serializedMessage);
+ }
+
+ /** ARROW-10962: accept FlightData messages generated by Protobuf (which can omit empty fields). */
+ @Test
+ public void testProtobufRecordBatchCompatibility() throws Exception {
+ final Schema schema = new Schema(Collections.singletonList(Field.nullable("foo", new ArrowType.Int(32, true))));
+ try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+ final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ final VectorUnloader unloader = new VectorUnloader(root);
+ root.setRowCount(0);
+ final MethodDescriptor.Marshaller<ArrowMessage> marshaller = ArrowMessage.createMarshaller(allocator);
+ try (final ArrowMessage message = new ArrowMessage(
+ unloader.getRecordBatch(), /* appMetadata */ null, /* tryZeroCopy */ false, IpcOption.DEFAULT)) {
+ Assert.assertEquals(ArrowMessage.HeaderType.RECORD_BATCH, message.getMessageType());
+ // Should have at least one empty body buffer (there may be multiple for e.g. data and validity)
+ Iterator<ArrowBuf> iterator = message.getBufs().iterator();
+ Assert.assertTrue(iterator.hasNext());
+ while (iterator.hasNext()) {
+ Assert.assertEquals(0, iterator.next().capacity());
+ }
+ final Flight.FlightData protobufData = arrowMessageToProtobuf(marshaller, message)
+ .toBuilder()
+ .clearDataBody()
+ .build();
+ Assert.assertEquals(0, protobufData.getDataBody().size());
+ ArrowMessage parsedMessage = marshaller.parse(new ByteArrayInputStream(protobufData.toByteArray()));
+ // Should have an empty body buffer
+ Iterator<ArrowBuf> parsedIterator = parsedMessage.getBufs().iterator();
+ Assert.assertTrue(parsedIterator.hasNext());
+ Assert.assertEquals(0, parsedIterator.next().capacity());
+ // Should have only one (the parser synthesizes exactly one); in the case of empty buffers, this is equivalent
+ Assert.assertFalse(parsedIterator.hasNext());
+ // Should not throw
+ final ArrowRecordBatch rb = parsedMessage.asRecordBatch();
+ Assert.assertEquals(rb.computeBodyLength(), 0);
+ }
+ }
+ }
+
+ /** ARROW-10962: accept FlightData messages generated by Protobuf (which can omit empty fields). */
+ @Test
+ public void testProtobufSchemaCompatibility() throws Exception {
+ final Schema schema = new Schema(Collections.singletonList(Field.nullable("foo", new ArrowType.Int(32, true))));
+ try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
+ final MethodDescriptor.Marshaller<ArrowMessage> marshaller = ArrowMessage.createMarshaller(allocator);
+ Flight.FlightDescriptor descriptor = FlightDescriptor.command(new byte[0]).toProtocol();
+ try (final ArrowMessage message = new ArrowMessage(descriptor, schema, IpcOption.DEFAULT)) {
+ Assert.assertEquals(ArrowMessage.HeaderType.SCHEMA, message.getMessageType());
+ // Should have no body buffers
+ Assert.assertFalse(message.getBufs().iterator().hasNext());
+ final Flight.FlightData protobufData = arrowMessageToProtobuf(marshaller, message)
+ .toBuilder()
+ .setDataBody(ByteString.EMPTY)
+ .build();
+ Assert.assertEquals(0, protobufData.getDataBody().size());
+ final ArrowMessage parsedMessage = marshaller.parse(new ByteArrayInputStream(protobufData.toByteArray()));
+ // Should have no body buffers
+ Assert.assertFalse(parsedMessage.getBufs().iterator().hasNext());
+ // Should not throw
+ parsedMessage.asSchema();
+ }
+ }
+ }
+
+ /**
+ * An example FlightProducer for test purposes.
+ */
+ public static class Producer implements FlightProducer, AutoCloseable {
+ static final byte[] TICKET_LARGE_BATCH = "large-batch".getBytes(StandardCharsets.UTF_8);
+
+ private final BufferAllocator allocator;
+
+ public Producer(BufferAllocator allocator) {
+ super();
+ this.allocator = allocator;
+ }
+
+ @Override
+ public void listFlights(CallContext context, Criteria criteria,
+ StreamListener<FlightInfo> listener) {
+ if (criteria.getExpression().length > 0) {
+ // Don't send anything if criteria are set
+ listener.onCompleted();
+ }
+
+ Flight.FlightInfo getInfo = Flight.FlightInfo.newBuilder()
+ .setFlightDescriptor(Flight.FlightDescriptor.newBuilder()
+ .setType(DescriptorType.CMD)
+ .setCmd(ByteString.copyFrom("cool thing", Charsets.UTF_8)))
+ .build();
+ try {
+ listener.onNext(new FlightInfo(getInfo));
+ } catch (URISyntaxException e) {
+ listener.onError(e);
+ return;
+ }
+ listener.onCompleted();
+ }
+
+ @Override
+ public Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener<PutResult> ackStream) {
+ return () -> {
+ while (flightStream.next()) {
+ // Drain the stream
+ }
+ };
+ }
+
+ @Override
+ public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
+ if (Arrays.equals(TICKET_LARGE_BATCH, ticket.getBytes())) {
+ getLargeBatch(listener);
+ return;
+ }
+ final int size = 10;
+
+ IntVector iv = new IntVector("c1", allocator);
+ VectorSchemaRoot root = VectorSchemaRoot.of(iv);
+ listener.start(root);
+
+ //batch 1
+ root.allocateNew();
+ for (int i = 0; i < size; i++) {
+ iv.set(i, i);
+ }
+ iv.setValueCount(size);
+ root.setRowCount(size);
+ listener.putNext();
+
+ // batch 2
+
+ root.allocateNew();
+ for (int i = 0; i < size; i++) {
+ iv.set(i, i + size);
+ }
+ iv.setValueCount(size);
+ root.setRowCount(size);
+ listener.putNext();
+ root.clear();
+ listener.completed();
+ }
+
+ private void getLargeBatch(ServerStreamListener listener) {
+ final List<FieldVector> vectors = new ArrayList<>();
+ for (int col = 0; col < 128; col++) {
+ final BigIntVector vector = new BigIntVector("f" + col, allocator);
+ for (int row = 0; row < 65536; row++) {
+ vector.setSafe(row, row);
+ }
+ vectors.add(vector);
+ }
+ try (final VectorSchemaRoot root = new VectorSchemaRoot(vectors)) {
+ root.setRowCount(65536);
+ listener.start(root);
+ listener.putNext();
+ listener.putNext();
+ listener.completed();
+ }
+ }
+
+ @Override
+ public void close() throws Exception {
+ allocator.close();
+ }
+
+ @Override
+ public FlightInfo getFlightInfo(CallContext context,
+ FlightDescriptor descriptor) {
+ try {
+ Flight.FlightInfo getInfo = Flight.FlightInfo.newBuilder()
+ .setFlightDescriptor(Flight.FlightDescriptor.newBuilder()
+ .setType(DescriptorType.CMD)
+ .setCmd(ByteString.copyFrom("cool thing", Charsets.UTF_8)))
+ .addEndpoint(
+ Flight.FlightEndpoint.newBuilder().addLocation(new Location("https://example.com").toProtocol()))
+ .build();
+ return new FlightInfo(getInfo);
+ } catch (URISyntaxException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public void doAction(CallContext context, Action action,
+ StreamListener<Result> listener) {
+ switch (action.getType()) {
+ case "hello": {
+ listener.onNext(new Result("world".getBytes(Charsets.UTF_8)));
+ listener.onCompleted();
+ break;
+ }
+ case "hellooo": {
+ listener.onNext(new Result("world".getBytes(Charsets.UTF_8)));
+ listener.onNext(new Result("!".getBytes(Charsets.UTF_8)));
+ listener.onCompleted();
+ break;
+ }
+ default:
+ listener.onError(CallStatus.UNIMPLEMENTED.withDescription("Action not implemented: " + action.getType())
+ .toRuntimeException());
+ }
+ }
+
+ @Override
+ public void listActions(CallContext context,
+ StreamListener<ActionType> listener) {
+ listener.onNext(new ActionType("get", ""));
+ listener.onNext(new ActionType("put", ""));
+ listener.onNext(new ActionType("hello", ""));
+ listener.onCompleted();
+ }
+
+ }
+
+
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestCallOptions.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestCallOptions.java
new file mode 100644
index 000000000..45e3e4960
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestCallOptions.java
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Iterator;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import io.grpc.Metadata;
+
+public class TestCallOptions {
+
+ @Test
+ @Ignore
+ public void timeoutFires() {
+ // Ignored due to CI flakiness
+ test((client) -> {
+ Instant start = Instant.now();
+ Iterator<Result> results = client.doAction(new Action("hang"), CallOptions.timeout(1, TimeUnit.SECONDS));
+ try {
+ results.next();
+ Assert.fail("Call should have failed");
+ } catch (RuntimeException e) {
+ Assert.assertTrue(e.getMessage(), e.getMessage().contains("deadline exceeded"));
+ }
+ Instant end = Instant.now();
+ Assert.assertTrue("Call took over 1500 ms despite timeout", Duration.between(start, end).toMillis() < 1500);
+ });
+ }
+
+ @Test
+ @Ignore
+ public void underTimeout() {
+ // Ignored due to CI flakiness
+ test((client) -> {
+ Instant start = Instant.now();
+ // This shouldn't fail and it should complete within the timeout
+ Iterator<Result> results = client.doAction(new Action("fast"), CallOptions.timeout(2, TimeUnit.SECONDS));
+ Assert.assertArrayEquals(new byte[]{42, 42}, results.next().getBody());
+ Instant end = Instant.now();
+ Assert.assertTrue("Call took over 2500 ms despite timeout", Duration.between(start, end).toMillis() < 2500);
+ });
+ }
+
+ @Test
+ public void singleProperty() {
+ final FlightCallHeaders headers = new FlightCallHeaders();
+ headers.insert("key", "value");
+ testHeaders(headers);
+ }
+
+ @Test
+ public void multipleProperties() {
+ final FlightCallHeaders headers = new FlightCallHeaders();
+ headers.insert("key", "value");
+ headers.insert("key2", "value2");
+ testHeaders(headers);
+ }
+
+ @Test
+ public void binaryProperties() {
+ final FlightCallHeaders headers = new FlightCallHeaders();
+ headers.insert("key-bin", "value".getBytes());
+ headers.insert("key3-bin", "ëfßæ".getBytes());
+ testHeaders(headers);
+ }
+
+ @Test
+ public void mixedProperties() {
+ final FlightCallHeaders headers = new FlightCallHeaders();
+ headers.insert("key", "value");
+ headers.insert("key3-bin", "ëfßæ".getBytes());
+ testHeaders(headers);
+ }
+
+ private void testHeaders(CallHeaders headers) {
+ try (
+ BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
+ HeaderProducer producer = new HeaderProducer();
+ FlightServer s =
+ FlightTestUtil.getStartedServer((location) -> FlightServer.builder(a, location, producer).build());
+ FlightClient client = FlightClient.builder(a, s.getLocation()).build()) {
+ client.doAction(new Action(""), new HeaderCallOption(headers)).hasNext();
+
+ final CallHeaders incomingHeaders = producer.headers();
+ for (String key : headers.keys()) {
+ if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
+ Assert.assertArrayEquals(headers.getByte(key), incomingHeaders.getByte(key));
+ } else {
+ Assert.assertEquals(headers.get(key), incomingHeaders.get(key));
+ }
+ }
+ } catch (InterruptedException | IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ void test(Consumer<FlightClient> testFn) {
+ try (
+ BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
+ Producer producer = new Producer();
+ FlightServer s =
+ FlightTestUtil.getStartedServer((location) -> FlightServer.builder(a, location, producer).build());
+ FlightClient client = FlightClient.builder(a, s.getLocation()).build()) {
+ testFn.accept(client);
+ } catch (InterruptedException | IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ static class HeaderProducer extends NoOpFlightProducer implements AutoCloseable {
+ CallHeaders headers;
+
+ @Override
+ public void close() {
+ }
+
+ public CallHeaders headers() {
+ return headers;
+ }
+
+ @Override
+ public void doAction(CallContext context, Action action, StreamListener<Result> listener) {
+ this.headers = context.getMiddleware(FlightConstants.HEADER_KEY).headers();
+ listener.onCompleted();
+ }
+ }
+
+ static class Producer extends NoOpFlightProducer implements AutoCloseable {
+
+ Producer() {
+ }
+
+ @Override
+ public void close() {
+ }
+
+ @Override
+ public void doAction(CallContext context, Action action, StreamListener<Result> listener) {
+ switch (action.getType()) {
+ case "hang": {
+ try {
+ Thread.sleep(25000);
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ listener.onNext(new Result(new byte[]{}));
+ listener.onCompleted();
+ return;
+ }
+ case "fast": {
+ try {
+ Thread.sleep(500);
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ listener.onNext(new Result(new byte[]{42, 42}));
+ listener.onCompleted();
+ return;
+ }
+ default: {
+ throw new UnsupportedOperationException(action.getType());
+ }
+ }
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestClientMiddleware.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestClientMiddleware.java
new file mode 100644
index 000000000..ccfc9f2d1
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestClientMiddleware.java
@@ -0,0 +1,359 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.BiConsumer;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/**
+ * A basic test of client middleware using a simplified OpenTracing-like example.
+ */
+@RunWith(JUnit4.class)
+public class TestClientMiddleware {
+
+ /**
+ * Test that a client middleware can fail a call before it starts by throwing a {@link FlightRuntimeException}.
+ */
+ @Test
+ public void clientMiddleware_failCallBeforeSending() {
+ test(new NoOpFlightProducer(), null, Collections.singletonList(new CallRejector.Factory()),
+ (allocator, client) -> {
+ FlightTestUtil.assertCode(FlightStatusCode.UNAVAILABLE, client::listActions);
+ });
+ }
+
+ /**
+ * Test an OpenTracing-like scenario where client and server middleware work together to propagate a request ID
+ * without explicit intervention from the service implementation.
+ */
+ @Test
+ public void middleware_propagateHeader() {
+ final Context context = new Context("span id");
+ test(new NoOpFlightProducer(),
+ new TestServerMiddleware.ServerMiddlewarePair<>(
+ FlightServerMiddleware.Key.of("test"), new ServerSpanInjector.Factory()),
+ Collections.singletonList(new ClientSpanInjector.Factory(context)),
+ (allocator, client) -> {
+ FlightTestUtil.assertCode(FlightStatusCode.UNIMPLEMENTED, () -> client.listActions().forEach(actionType -> {
+ }));
+ });
+ Assert.assertEquals(context.outgoingSpanId, context.incomingSpanId);
+ Assert.assertNotNull(context.finalStatus);
+ Assert.assertEquals(FlightStatusCode.UNIMPLEMENTED, context.finalStatus.code());
+ }
+
+ /** Ensure both server and client can send and receive multi-valued headers (both binary and text values). */
+ @Test
+ public void testMultiValuedHeaders() {
+ final MultiHeaderClientMiddlewareFactory clientFactory = new MultiHeaderClientMiddlewareFactory();
+ test(new NoOpFlightProducer(),
+ new TestServerMiddleware.ServerMiddlewarePair<>(
+ FlightServerMiddleware.Key.of("test"), new MultiHeaderServerMiddlewareFactory()),
+ Collections.singletonList(clientFactory),
+ (allocator, client) -> {
+ FlightTestUtil.assertCode(FlightStatusCode.UNIMPLEMENTED, () -> client.listActions().forEach(actionType -> {
+ }));
+ });
+ // The server echoes the headers we send back to us, so ensure all the ones we sent are present with the correct
+ // values in the correct order.
+ for (final Map.Entry<String, List<byte[]>> entry : EXPECTED_BINARY_HEADERS.entrySet()) {
+ // Compare header values entry-by-entry because byte arrays don't compare via equals
+ final List<byte[]> receivedValues = clientFactory.lastBinaryHeaders.get(entry.getKey());
+ Assert.assertNotNull("Missing for header: " + entry.getKey(), receivedValues);
+ Assert.assertEquals(
+ "Missing or wrong value for header: " + entry.getKey(),
+ entry.getValue().size(), receivedValues.size());
+ for (int i = 0; i < entry.getValue().size(); i++) {
+ Assert.assertArrayEquals(entry.getValue().get(i), receivedValues.get(i));
+ }
+ }
+ for (final Map.Entry<String, List<String>> entry : EXPECTED_TEXT_HEADERS.entrySet()) {
+ Assert.assertEquals(
+ "Missing or wrong value for header: " + entry.getKey(),
+ entry.getValue(), clientFactory.lastTextHeaders.get(entry.getKey()));
+ }
+ }
+
+ private static <T extends FlightServerMiddleware> void test(FlightProducer producer,
+ TestServerMiddleware.ServerMiddlewarePair<T> serverMiddleware,
+ List<FlightClientMiddleware.Factory> clientMiddleware,
+ BiConsumer<BufferAllocator, FlightClient> body) {
+ try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
+ final FlightServer server = FlightTestUtil
+ .getStartedServer(location -> {
+ final FlightServer.Builder builder = FlightServer.builder(allocator, location, producer);
+ if (serverMiddleware != null) {
+ builder.middleware(serverMiddleware.key, serverMiddleware.factory);
+ }
+ return builder.build();
+ });
+ FlightClient.Builder builder = FlightClient.builder(allocator, server.getLocation());
+ clientMiddleware.forEach(builder::intercept);
+ try (final FlightServer ignored = server;
+ final FlightClient client = builder.build()
+ ) {
+ body.accept(allocator, client);
+ }
+ } catch (InterruptedException | IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ /**
+ * A server middleware component that reads a request ID from incoming headers and sends the request ID back on
+ * outgoing headers.
+ */
+ static class ServerSpanInjector implements FlightServerMiddleware {
+
+ private final String spanId;
+
+ public ServerSpanInjector(String spanId) {
+ this.spanId = spanId;
+ }
+
+ @Override
+ public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
+ outgoingHeaders.insert("x-span", spanId);
+ }
+
+ @Override
+ public void onCallCompleted(CallStatus status) {
+
+ }
+
+ @Override
+ public void onCallErrored(Throwable err) {
+
+ }
+
+ static class Factory implements FlightServerMiddleware.Factory<ServerSpanInjector> {
+
+ @Override
+ public ServerSpanInjector onCallStarted(CallInfo info, CallHeaders incomingHeaders, RequestContext context) {
+ return new ServerSpanInjector(incomingHeaders.get("x-span"));
+ }
+ }
+ }
+
+ /**
+ * A client middleware component that, given a mock OpenTracing-like "request context", sends the request ID in the
+ * context on outgoing headers and reads it from incoming headers.
+ */
+ static class ClientSpanInjector implements FlightClientMiddleware {
+
+ private final Context context;
+
+ public ClientSpanInjector(Context context) {
+ this.context = context;
+ }
+
+ @Override
+ public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
+ outgoingHeaders.insert("x-span", context.outgoingSpanId);
+ }
+
+ @Override
+ public void onHeadersReceived(CallHeaders incomingHeaders) {
+ context.incomingSpanId = incomingHeaders.get("x-span");
+ }
+
+ @Override
+ public void onCallCompleted(CallStatus status) {
+ context.finalStatus = status;
+ }
+
+ static class Factory implements FlightClientMiddleware.Factory {
+
+ private final Context context;
+
+ Factory(Context context) {
+ this.context = context;
+ }
+
+ @Override
+ public FlightClientMiddleware onCallStarted(CallInfo info) {
+ return new ClientSpanInjector(context);
+ }
+ }
+ }
+
+ /**
+ * A mock OpenTracing-like "request context".
+ */
+ static class Context {
+
+ final String outgoingSpanId;
+ String incomingSpanId;
+ CallStatus finalStatus;
+
+ Context(String spanId) {
+ this.outgoingSpanId = spanId;
+ }
+ }
+
+ /**
+ * A client middleware that fails outgoing calls.
+ */
+ static class CallRejector implements FlightClientMiddleware {
+
+ @Override
+ public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
+ }
+
+ @Override
+ public void onHeadersReceived(CallHeaders incomingHeaders) {
+ }
+
+ @Override
+ public void onCallCompleted(CallStatus status) {
+ }
+
+ static class Factory implements FlightClientMiddleware.Factory {
+
+ @Override
+ public FlightClientMiddleware onCallStarted(CallInfo info) {
+ throw CallStatus.UNAVAILABLE.withDescription("Rejecting call.").toRuntimeException();
+ }
+ }
+ }
+
+ // Used to test that middleware can send and receive multi-valued text and binary headers.
+ static final Map<String, List<byte[]>> EXPECTED_BINARY_HEADERS = new HashMap<String, List<byte[]>>() {{
+ put("x-binary-bin", Arrays.asList(new byte[] {0}, new byte[]{1}));
+ }};
+ static final Map<String, List<String>> EXPECTED_TEXT_HEADERS = new HashMap<String, List<String>>() {{
+ put("x-text", Arrays.asList("foo", "bar"));
+ }};
+
+ static class MultiHeaderServerMiddlewareFactory implements
+ FlightServerMiddleware.Factory<MultiHeaderServerMiddleware> {
+ @Override
+ public MultiHeaderServerMiddleware onCallStarted(CallInfo info, CallHeaders incomingHeaders,
+ RequestContext context) {
+ // Echo the headers back to the client. Copy values out of CallHeaders since the underlying gRPC metadata
+ // object isn't safe to use after this function returns.
+ Map<String, List<byte[]>> binaryHeaders = new HashMap<>();
+ Map<String, List<String>> textHeaders = new HashMap<>();
+ for (final String key : incomingHeaders.keys()) {
+ if (key.endsWith("-bin")) {
+ binaryHeaders.compute(key, (ignored, values) -> {
+ if (values == null) {
+ values = new ArrayList<>();
+ }
+ incomingHeaders.getAllByte(key).forEach(values::add);
+ return values;
+ });
+ } else {
+ textHeaders.compute(key, (ignored, values) -> {
+ if (values == null) {
+ values = new ArrayList<>();
+ }
+ incomingHeaders.getAll(key).forEach(values::add);
+ return values;
+ });
+ }
+ }
+ return new MultiHeaderServerMiddleware(binaryHeaders, textHeaders);
+ }
+ }
+
+ static class MultiHeaderServerMiddleware implements FlightServerMiddleware {
+ private final Map<String, List<byte[]>> binaryHeaders;
+ private final Map<String, List<String>> textHeaders;
+
+ MultiHeaderServerMiddleware(Map<String, List<byte[]>> binaryHeaders, Map<String, List<String>> textHeaders) {
+ this.binaryHeaders = binaryHeaders;
+ this.textHeaders = textHeaders;
+ }
+
+ @Override
+ public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
+ binaryHeaders.forEach((key, values) -> values.forEach(value -> outgoingHeaders.insert(key, value)));
+ textHeaders.forEach((key, values) -> values.forEach(value -> outgoingHeaders.insert(key, value)));
+ }
+
+ @Override
+ public void onCallCompleted(CallStatus status) {}
+
+ @Override
+ public void onCallErrored(Throwable err) {}
+ }
+
+ static class MultiHeaderClientMiddlewareFactory implements FlightClientMiddleware.Factory {
+ Map<String, List<byte[]>> lastBinaryHeaders = null;
+ Map<String, List<String>> lastTextHeaders = null;
+
+ @Override
+ public FlightClientMiddleware onCallStarted(CallInfo info) {
+ return new MultiHeaderClientMiddleware(this);
+ }
+ }
+
+ static class MultiHeaderClientMiddleware implements FlightClientMiddleware {
+ private final MultiHeaderClientMiddlewareFactory factory;
+
+ public MultiHeaderClientMiddleware(MultiHeaderClientMiddlewareFactory factory) {
+ this.factory = factory;
+ }
+
+ @Override
+ public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
+ for (final Map.Entry<String, List<byte[]>> entry : EXPECTED_BINARY_HEADERS.entrySet()) {
+ entry.getValue().forEach((value) -> outgoingHeaders.insert(entry.getKey(), value));
+ Assert.assertTrue(outgoingHeaders.containsKey(entry.getKey()));
+ }
+ for (final Map.Entry<String, List<String>> entry : EXPECTED_TEXT_HEADERS.entrySet()) {
+ entry.getValue().forEach((value) -> outgoingHeaders.insert(entry.getKey(), value));
+ Assert.assertTrue(outgoingHeaders.containsKey(entry.getKey()));
+ }
+ }
+
+ @Override
+ public void onHeadersReceived(CallHeaders incomingHeaders) {
+ factory.lastBinaryHeaders = new HashMap<>();
+ factory.lastTextHeaders = new HashMap<>();
+ incomingHeaders.keys().forEach(header -> {
+ if (header.endsWith("-bin")) {
+ final List<byte[]> values = new ArrayList<>();
+ incomingHeaders.getAllByte(header).forEach(values::add);
+ factory.lastBinaryHeaders.put(header, values);
+ } else {
+ final List<String> values = new ArrayList<>();
+ incomingHeaders.getAll(header).forEach(values::add);
+ factory.lastTextHeaders.put(header, values);
+ }
+ });
+ }
+
+ @Override
+ public void onCallCompleted(CallStatus status) {}
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDictionaryUtils.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDictionaryUtils.java
new file mode 100644
index 000000000..b5bf117c6
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDictionaryUtils.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.util.TreeSet;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.Test;
+
+import com.google.common.collect.ImmutableList;
+
+/**
+ * Test cases for {@link DictionaryUtils}.
+ */
+public class TestDictionaryUtils {
+
+ @Test
+ public void testReuseSchema() {
+ FieldType varcharType = new FieldType(true, new ArrowType.Utf8(), null);
+ FieldType intType = new FieldType(true, new ArrowType.Int(32, true), null);
+
+ ImmutableList<Field> build = ImmutableList.of(
+ new Field("stringCol", varcharType, null),
+ new Field("intCol", intType, null));
+
+ Schema schema = new Schema(build);
+ Schema newSchema = DictionaryUtils.generateSchema(schema, null, new TreeSet<>());
+
+ // assert that no new schema is created.
+ assertTrue(schema == newSchema);
+ }
+
+ @Test
+ public void testCreateSchema() {
+ try (BufferAllocator allocator = new RootAllocator(1024)) {
+ DictionaryEncoding dictionaryEncoding =
+ new DictionaryEncoding(0, true, new ArrowType.Int(8, true));
+ VarCharVector dictVec = new VarCharVector("dict vector", allocator);
+ Dictionary dictionary = new Dictionary(dictVec, dictionaryEncoding);
+ DictionaryProvider dictProvider = new DictionaryProvider.MapDictionaryProvider(dictionary);
+ TreeSet<Long> dictionaryUsed = new TreeSet<>();
+
+ FieldType encodedVarcharType = new FieldType(true, new ArrowType.Int(8, true), dictionaryEncoding);
+ FieldType intType = new FieldType(true, new ArrowType.Int(32, true), null);
+
+ ImmutableList<Field> build = ImmutableList.of(
+ new Field("stringCol", encodedVarcharType, null),
+ new Field("intCol", intType, null));
+
+ Schema schema = new Schema(build);
+ Schema newSchema = DictionaryUtils.generateSchema(schema, dictProvider, dictionaryUsed);
+
+ // assert that a new schema is created.
+ assertTrue(schema != newSchema);
+
+ // assert the column is converted as expected
+ ArrowType newColType = newSchema.getFields().get(0).getType();
+ assertEquals(new ArrowType.Utf8(), newColType);
+
+ assertEquals(1, dictionaryUsed.size());
+ assertEquals(0, dictionaryUsed.first());
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java
new file mode 100644
index 000000000..70394e11e
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java
@@ -0,0 +1,536 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.stream.IntStream;
+
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VectorLoader;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.VectorUnloader;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.testing.ValueVectorDataPopulator;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+public class TestDoExchange {
+ static byte[] EXCHANGE_DO_GET = "do-get".getBytes(StandardCharsets.UTF_8);
+ static byte[] EXCHANGE_DO_PUT = "do-put".getBytes(StandardCharsets.UTF_8);
+ static byte[] EXCHANGE_ECHO = "echo".getBytes(StandardCharsets.UTF_8);
+ static byte[] EXCHANGE_METADATA_ONLY = "only-metadata".getBytes(StandardCharsets.UTF_8);
+ static byte[] EXCHANGE_TRANSFORM = "transform".getBytes(StandardCharsets.UTF_8);
+ static byte[] EXCHANGE_CANCEL = "cancel".getBytes(StandardCharsets.UTF_8);
+
+ private BufferAllocator allocator;
+ private FlightServer server;
+ private FlightClient client;
+
+ @Before
+ public void setUp() throws Exception {
+ allocator = new RootAllocator(Integer.MAX_VALUE);
+ final Location serverLocation = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, 0);
+ server = FlightServer.builder(allocator, serverLocation, new Producer(allocator)).build();
+ server.start();
+ final Location clientLocation = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort());
+ client = FlightClient.builder(allocator, clientLocation).build();
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ AutoCloseables.close(client, server, allocator);
+ }
+
+ /** Test a pure-metadata flow. */
+ @Test
+ public void testDoExchangeOnlyMetadata() throws Exception {
+ // Send a particular descriptor to the server and check for a particular response pattern.
+ try (final FlightClient.ExchangeReaderWriter stream =
+ client.doExchange(FlightDescriptor.command(EXCHANGE_METADATA_ONLY))) {
+ final FlightStream reader = stream.getReader();
+
+ // Server starts by sending a message without data (hence no VectorSchemaRoot should be present)
+ assertTrue(reader.next());
+ assertFalse(reader.hasRoot());
+ assertEquals(42, reader.getLatestMetadata().getInt(0));
+
+ // Write a metadata message to the server (without sending any data)
+ ArrowBuf buf = allocator.buffer(4);
+ buf.writeInt(84);
+ stream.getWriter().putMetadata(buf);
+
+ // Check that the server echoed the metadata back to us
+ assertTrue(reader.next());
+ assertFalse(reader.hasRoot());
+ assertEquals(84, reader.getLatestMetadata().getInt(0));
+
+ // Close our write channel and ensure the server also closes theirs
+ stream.getWriter().completed();
+ assertFalse(reader.next());
+ }
+ }
+
+ /** Emulate a DoGet with a DoExchange. */
+ @Test
+ public void testDoExchangeDoGet() throws Exception {
+ try (final FlightClient.ExchangeReaderWriter stream =
+ client.doExchange(FlightDescriptor.command(EXCHANGE_DO_GET))) {
+ final FlightStream reader = stream.getReader();
+ VectorSchemaRoot root = reader.getRoot();
+ IntVector iv = (IntVector) root.getVector("a");
+ int value = 0;
+ while (reader.next()) {
+ for (int i = 0; i < root.getRowCount(); i++) {
+ assertFalse(String.format("Row %d should not be null", value), iv.isNull(i));
+ assertEquals(value, iv.get(i));
+ value++;
+ }
+ }
+ assertEquals(100, value);
+ }
+ }
+
+ /** Emulate a DoPut with a DoExchange. */
+ @Test
+ public void testDoExchangeDoPut() throws Exception {
+ final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true))));
+ try (final FlightClient.ExchangeReaderWriter stream =
+ client.doExchange(FlightDescriptor.command(EXCHANGE_DO_PUT));
+ final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ IntVector iv = (IntVector) root.getVector("a");
+ iv.allocateNew();
+
+ stream.getWriter().start(root);
+ int counter = 0;
+ for (int i = 0; i < 10; i++) {
+ ValueVectorDataPopulator.setVector(iv, IntStream.range(0, i).boxed().toArray(Integer[]::new));
+ root.setRowCount(i);
+ counter += i;
+ stream.getWriter().putNext();
+
+ assertTrue(stream.getReader().next());
+ assertFalse(stream.getReader().hasRoot());
+ // For each write, the server sends back a metadata message containing the index of the last written batch
+ final ArrowBuf metadata = stream.getReader().getLatestMetadata();
+ assertEquals(counter, metadata.getInt(0));
+ }
+ stream.getWriter().completed();
+
+ while (stream.getReader().next()) {
+ // Drain the stream. Otherwise closing the stream sends a CANCEL which seriously screws with the server.
+ // CANCEL -> runs onCancel handler -> closes the FlightStream early
+ }
+ }
+ }
+
+ /** Test a DoExchange that echoes the client message. */
+ @Test
+ public void testDoExchangeEcho() throws Exception {
+ final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true))));
+ try (final FlightClient.ExchangeReaderWriter stream = client.doExchange(FlightDescriptor.command(EXCHANGE_ECHO));
+ final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ final FlightStream reader = stream.getReader();
+
+ // First try writing metadata without starting the Arrow data stream
+ ArrowBuf buf = allocator.buffer(4);
+ buf.writeInt(42);
+ stream.getWriter().putMetadata(buf);
+ buf = allocator.buffer(4);
+ buf.writeInt(84);
+ stream.getWriter().putMetadata(buf);
+
+ // Ensure that the server echoes the metadata back, also without starting its data stream
+ assertTrue(reader.next());
+ assertFalse(reader.hasRoot());
+ assertEquals(42, reader.getLatestMetadata().getInt(0));
+ assertTrue(reader.next());
+ assertFalse(reader.hasRoot());
+ assertEquals(84, reader.getLatestMetadata().getInt(0));
+
+ // Write data and check that it gets echoed back.
+ IntVector iv = (IntVector) root.getVector("a");
+ iv.allocateNew();
+ stream.getWriter().start(root);
+ for (int i = 0; i < 10; i++) {
+ iv.setSafe(0, i);
+ root.setRowCount(1);
+ stream.getWriter().putNext();
+
+ assertTrue(reader.next());
+ assertNull(reader.getLatestMetadata());
+ assertEquals(root.getSchema(), reader.getSchema());
+ assertEquals(i, ((IntVector) reader.getRoot().getVector("a")).get(0));
+ }
+
+ // Complete the stream so that the server knows not to expect any more messages from us.
+ stream.getWriter().completed();
+ // The server will end its side of the call, so this shouldn't block or indicate that
+ // there is more data.
+ assertFalse("We should not be waiting for any messages", reader.next());
+ }
+ }
+
+ /** Write some data, have it transformed, then read it back. */
+ @Test
+ public void testTransform() throws Exception {
+ final Schema schema = new Schema(Arrays.asList(
+ Field.nullable("a", new ArrowType.Int(32, true)),
+ Field.nullable("b", new ArrowType.Int(32, true))));
+ try (final FlightClient.ExchangeReaderWriter stream =
+ client.doExchange(FlightDescriptor.command(EXCHANGE_TRANSFORM))) {
+ // Write ten batches of data to the stream, where batch N contains N rows of data (N in [0, 10))
+ final FlightStream reader = stream.getReader();
+ final FlightClient.ClientStreamListener writer = stream.getWriter();
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ writer.start(root);
+ for (int batchIndex = 0; batchIndex < 10; batchIndex++) {
+ for (final FieldVector rawVec : root.getFieldVectors()) {
+ final IntVector vec = (IntVector) rawVec;
+ ValueVectorDataPopulator.setVector(vec, IntStream.range(0, batchIndex).boxed().toArray(Integer[]::new));
+ }
+ root.setRowCount(batchIndex);
+ writer.putNext();
+ }
+ }
+ // Indicate that we're done writing so that the server does not expect more data.
+ writer.completed();
+
+ // Read back data. We expect the server to double each value in each row of each batch.
+ assertEquals(schema, reader.getSchema());
+ final VectorSchemaRoot root = reader.getRoot();
+ for (int batchIndex = 0; batchIndex < 10; batchIndex++) {
+ assertTrue("Didn't receive batch #" + batchIndex, reader.next());
+ assertEquals(batchIndex, root.getRowCount());
+ for (final FieldVector rawVec : root.getFieldVectors()) {
+ final IntVector vec = (IntVector) rawVec;
+ for (int row = 0; row < batchIndex; row++) {
+ assertEquals(2 * row, vec.get(row));
+ }
+ }
+ }
+
+ // The server also sends back a metadata-only message containing the message count
+ assertTrue("There should be one extra message", reader.next());
+ assertEquals(10, reader.getLatestMetadata().getInt(0));
+ assertFalse("There should be no more data", reader.next());
+ }
+ }
+
+ /** Write some data, have it transformed, then read it back. Use the zero-copy optimization. */
+ @Test
+ public void testTransformZeroCopy() throws Exception {
+ final int rowsPerBatch = 4096;
+ final Schema schema = new Schema(Arrays.asList(
+ Field.nullable("a", new ArrowType.Int(32, true)),
+ Field.nullable("b", new ArrowType.Int(32, true))));
+ try (final FlightClient.ExchangeReaderWriter stream =
+ client.doExchange(FlightDescriptor.command(EXCHANGE_TRANSFORM))) {
+ // Write ten batches of data to the stream, where batch N contains 1024 rows of data (N in [0, 10))
+ final FlightStream reader = stream.getReader();
+ final FlightClient.ClientStreamListener writer = stream.getWriter();
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ writer.start(root);
+ // Enable the zero-copy optimization
+ writer.setUseZeroCopy(true);
+ for (int batchIndex = 0; batchIndex < 100; batchIndex++) {
+ for (final FieldVector rawVec : root.getFieldVectors()) {
+ final IntVector vec = (IntVector) rawVec;
+ for (int row = 0; row < rowsPerBatch; row++) {
+ // Use a value that'll be different per batch, so we can detect if we accidentally
+ // reuse a buffer (and overwrite a buffer that hasn't yet been sent over the network)
+ vec.setSafe(row, batchIndex + row);
+ }
+ }
+ root.setRowCount(rowsPerBatch);
+ writer.putNext();
+ // Allocate new buffers every time since we don't know if gRPC has written the buffer
+ // to the network yet
+ root.allocateNew();
+ }
+ }
+ // Indicate that we're done writing so that the server does not expect more data.
+ writer.completed();
+
+ // Read back data. We expect the server to double each value in each row of each batch.
+ assertEquals(schema, reader.getSchema());
+ final VectorSchemaRoot root = reader.getRoot();
+ for (int batchIndex = 0; batchIndex < 100; batchIndex++) {
+ assertTrue("Didn't receive batch #" + batchIndex, reader.next());
+ assertEquals(rowsPerBatch, root.getRowCount());
+ for (final FieldVector rawVec : root.getFieldVectors()) {
+ final IntVector vec = (IntVector) rawVec;
+ for (int row = 0; row < rowsPerBatch; row++) {
+ assertEquals(2 * (batchIndex + row), vec.get(row));
+ }
+ }
+ }
+
+ // The server also sends back a metadata-only message containing the message count
+ assertTrue("There should be one extra message", reader.next());
+ assertEquals(100, reader.getLatestMetadata().getInt(0));
+ assertFalse("There should be no more data", reader.next());
+ }
+ }
+
+ /** Have the server immediately cancel; ensure the client doesn't hang. */
+ @Test
+ public void testServerCancel() throws Exception {
+ try (final FlightClient.ExchangeReaderWriter stream =
+ client.doExchange(FlightDescriptor.command(EXCHANGE_CANCEL))) {
+ final FlightStream reader = stream.getReader();
+ final FlightClient.ClientStreamListener writer = stream.getWriter();
+
+ final FlightRuntimeException fre = assertThrows(FlightRuntimeException.class, reader::next);
+ assertEquals(FlightStatusCode.CANCELLED, fre.status().code());
+ assertEquals("expected", fre.status().description());
+
+ // Before, this would hang forever, because the writer checks if the stream is ready and not cancelled.
+ // However, the cancellation flag (was) only updated by reading, and the stream is never ready once the call ends.
+ // The test looks weird since normally, an application shouldn't try to write after the read fails. However,
+ // an application that isn't reading data wouldn't notice, and would instead get stuck on the write.
+ // Here, we read first to avoid a race condition in the test itself.
+ writer.putMetadata(allocator.getEmpty());
+ }
+ }
+
+ /** Have the server immediately cancel; ensure the server cleans up the FlightStream. */
+ @Test
+ public void testServerCancelLeak() throws Exception {
+ try (final FlightClient.ExchangeReaderWriter stream =
+ client.doExchange(FlightDescriptor.command(EXCHANGE_CANCEL))) {
+ final FlightStream reader = stream.getReader();
+ final FlightClient.ClientStreamListener writer = stream.getWriter();
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(Producer.SCHEMA, allocator)) {
+ writer.start(root);
+ final IntVector ints = (IntVector) root.getVector("a");
+ for (int i = 0; i < 128; i++) {
+ for (int row = 0; row < 1024; row++) {
+ ints.setSafe(row, row);
+ }
+ root.setRowCount(1024);
+ writer.putNext();
+ }
+ }
+
+ final FlightRuntimeException fre = assertThrows(FlightRuntimeException.class, reader::next);
+ assertEquals(FlightStatusCode.CANCELLED, fre.status().code());
+ assertEquals("expected", fre.status().description());
+ }
+ }
+
+ /** Have the client cancel without reading; ensure memory is not leaked. */
+ @Test
+ public void testClientCancel() throws Exception {
+ try (final FlightClient.ExchangeReaderWriter stream =
+ client.doExchange(FlightDescriptor.command(EXCHANGE_DO_GET))) {
+ final FlightStream reader = stream.getReader();
+ reader.cancel("", null);
+ // Cancel should be idempotent
+ reader.cancel("", null);
+ }
+ }
+
+ /** Have the client close the stream without reading; ensure memory is not leaked. */
+ @Test
+ public void testClientClose() throws Exception {
+ try (final FlightClient.ExchangeReaderWriter stream =
+ client.doExchange(FlightDescriptor.command(EXCHANGE_DO_GET))) {
+ assertEquals(Producer.SCHEMA, stream.getReader().getSchema());
+ }
+ // Intentionally leak the allocator in this test. gRPC has a bug where it does not wait for all calls to complete
+ // when shutting down the server, so this test will fail otherwise because it closes the allocator while the
+ // server-side call still has memory allocated.
+ // TODO(ARROW-9586): fix this once we track outstanding RPCs outside of gRPC.
+ // https://stackoverflow.com/questions/46716024/
+ allocator = null;
+ client = null;
+ }
+
+ static class Producer extends NoOpFlightProducer {
+ static final Schema SCHEMA = new Schema(
+ Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true))));
+ private final BufferAllocator allocator;
+
+ Producer(BufferAllocator allocator) {
+ this.allocator = allocator;
+ }
+
+ @Override
+ public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) {
+ if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_METADATA_ONLY)) {
+ metadataOnly(context, reader, writer);
+ } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_DO_GET)) {
+ doGet(context, reader, writer);
+ } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_DO_PUT)) {
+ doPut(context, reader, writer);
+ } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_ECHO)) {
+ echo(context, reader, writer);
+ } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_TRANSFORM)) {
+ transform(context, reader, writer);
+ } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_CANCEL)) {
+ cancel(context, reader, writer);
+ } else {
+ writer.error(CallStatus.UNIMPLEMENTED.withDescription("Command not implemented").toRuntimeException());
+ }
+ }
+
+ /** Emulate DoGet. */
+ private void doGet(CallContext context, FlightStream reader, ServerStreamListener writer) {
+ try (VectorSchemaRoot root = VectorSchemaRoot.create(SCHEMA, allocator)) {
+ writer.start(root);
+ root.allocateNew();
+ IntVector iv = (IntVector) root.getVector("a");
+
+ for (int i = 0; i < 100; i += 2) {
+ iv.set(0, i);
+ iv.set(1, i + 1);
+ root.setRowCount(2);
+ writer.putNext();
+ }
+ }
+ writer.completed();
+ }
+
+ /** Emulate DoPut. */
+ private void doPut(CallContext context, FlightStream reader, ServerStreamListener writer) {
+ int counter = 0;
+ while (reader.next()) {
+ if (!reader.hasRoot()) {
+ writer.error(CallStatus.INVALID_ARGUMENT.withDescription("Message has no data").toRuntimeException());
+ return;
+ }
+ counter += reader.getRoot().getRowCount();
+
+ final ArrowBuf pong = allocator.buffer(4);
+ pong.writeInt(counter);
+ writer.putMetadata(pong);
+ }
+ writer.completed();
+ }
+
+ /** Exchange metadata without ever exchanging data. */
+ private void metadataOnly(CallContext context, FlightStream reader, ServerStreamListener writer) {
+ final ArrowBuf buf = allocator.buffer(4);
+ buf.writeInt(42);
+ writer.putMetadata(buf);
+ assertTrue(reader.next());
+ assertNotNull(reader.getLatestMetadata());
+ reader.getLatestMetadata().getReferenceManager().retain();
+ writer.putMetadata(reader.getLatestMetadata());
+ writer.completed();
+ }
+
+ /** Echo the client's response back to it. */
+ private void echo(CallContext context, FlightStream reader, ServerStreamListener writer) {
+ VectorSchemaRoot root = null;
+ VectorLoader loader = null;
+ while (reader.next()) {
+ if (reader.hasRoot()) {
+ if (root == null) {
+ root = VectorSchemaRoot.create(reader.getSchema(), allocator);
+ loader = new VectorLoader(root);
+ writer.start(root);
+ }
+ VectorUnloader unloader = new VectorUnloader(reader.getRoot());
+ try (final ArrowRecordBatch arb = unloader.getRecordBatch()) {
+ loader.load(arb);
+ }
+ if (reader.getLatestMetadata() != null) {
+ reader.getLatestMetadata().getReferenceManager().retain();
+ writer.putNext(reader.getLatestMetadata());
+ } else {
+ writer.putNext();
+ }
+ } else {
+ // Pure metadata
+ reader.getLatestMetadata().getReferenceManager().retain();
+ writer.putMetadata(reader.getLatestMetadata());
+ }
+ }
+ if (root != null) {
+ root.close();
+ }
+ writer.completed();
+ }
+
+ /** Accept a set of messages, then return some result. */
+ private void transform(CallContext context, FlightStream reader, ServerStreamListener writer) {
+ final Schema schema = reader.getSchema();
+ for (final Field field : schema.getFields()) {
+ if (!(field.getType() instanceof ArrowType.Int)) {
+ writer.error(CallStatus.INVALID_ARGUMENT.withDescription("Invalid type: " + field).toRuntimeException());
+ return;
+ }
+ final ArrowType.Int intType = (ArrowType.Int) field.getType();
+ if (!intType.getIsSigned() || intType.getBitWidth() != 32) {
+ writer.error(CallStatus.INVALID_ARGUMENT.withDescription("Must be i32: " + field).toRuntimeException());
+ return;
+ }
+ }
+ int batches = 0;
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ writer.start(root);
+ writer.setUseZeroCopy(true);
+ final VectorLoader loader = new VectorLoader(root);
+ final VectorUnloader unloader = new VectorUnloader(reader.getRoot());
+ while (reader.next()) {
+ try (final ArrowRecordBatch batch = unloader.getRecordBatch()) {
+ loader.load(batch);
+ }
+ batches++;
+ for (final FieldVector rawVec : root.getFieldVectors()) {
+ final IntVector vec = (IntVector) rawVec;
+ for (int i = 0; i < root.getRowCount(); i++) {
+ if (!vec.isNull(i)) {
+ vec.set(i, vec.get(i) * 2);
+ }
+ }
+ }
+ writer.putNext();
+ }
+ }
+ final ArrowBuf count = allocator.buffer(4);
+ count.writeInt(batches);
+ writer.putMetadata(count);
+ writer.completed();
+ }
+
+ /** Immediately cancel the call. */
+ private void cancel(CallContext context, FlightStream reader, ServerStreamListener writer) {
+ writer.error(CallStatus.CANCELLED.withDescription("expected").toRuntimeException());
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestErrorMetadata.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestErrorMetadata.java
new file mode 100644
index 000000000..2c62bc7fa
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestErrorMetadata.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import org.apache.arrow.flight.perf.impl.PerfOuterClass;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.junit.Assert;
+import org.junit.Test;
+
+import com.google.protobuf.Any;
+import com.google.protobuf.InvalidProtocolBufferException;
+import com.google.rpc.Status;
+
+import io.grpc.Metadata;
+import io.grpc.StatusRuntimeException;
+import io.grpc.protobuf.ProtoUtils;
+import io.grpc.protobuf.StatusProto;
+
+public class TestErrorMetadata {
+ private static final Metadata.BinaryMarshaller<Status> marshaller =
+ ProtoUtils.metadataMarshaller(Status.getDefaultInstance());
+
+ /** Ensure metadata attached to a gRPC error is propagated. */
+ @Test
+ public void testGrpcMetadata() throws Exception {
+ PerfOuterClass.Perf perf = PerfOuterClass.Perf.newBuilder()
+ .setStreamCount(12)
+ .setRecordsPerBatch(1000)
+ .setRecordsPerStream(1000000L)
+ .build();
+ StatusRuntimeExceptionProducer producer = new StatusRuntimeExceptionProducer(perf);
+ try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ final FlightServer s =
+ FlightTestUtil.getStartedServer(
+ (location) -> {
+ return FlightServer.builder(allocator, location, producer).build();
+ });
+ final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) {
+ final CallStatus flightStatus = FlightTestUtil.assertCode(FlightStatusCode.CANCELLED, () -> {
+ FlightStream stream = client.getStream(new Ticket("abs".getBytes()));
+ stream.next();
+ });
+ PerfOuterClass.Perf newPerf = null;
+ ErrorFlightMetadata metadata = flightStatus.metadata();
+ Assert.assertNotNull(metadata);
+ Assert.assertEquals(2, metadata.keys().size());
+ Assert.assertTrue(metadata.containsKey("grpc-status-details-bin"));
+ Status status = marshaller.parseBytes(metadata.getByte("grpc-status-details-bin"));
+ for (Any details : status.getDetailsList()) {
+ if (details.is(PerfOuterClass.Perf.class)) {
+ try {
+ newPerf = details.unpack(PerfOuterClass.Perf.class);
+ } catch (InvalidProtocolBufferException e) {
+ Assert.fail();
+ }
+ }
+ }
+ Assert.assertNotNull(newPerf);
+ Assert.assertEquals(perf, newPerf);
+ }
+ }
+
+ /** Ensure metadata attached to a Flight error is propagated. */
+ @Test
+ public void testFlightMetadata() throws Exception {
+ try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ final FlightServer s =
+ FlightTestUtil.getStartedServer(
+ (location) -> FlightServer.builder(allocator, location, new CallStatusProducer()).build());
+ final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) {
+ CallStatus flightStatus = FlightTestUtil.assertCode(FlightStatusCode.INVALID_ARGUMENT, () -> {
+ FlightStream stream = client.getStream(new Ticket(new byte[0]));
+ stream.next();
+ });
+ ErrorFlightMetadata metadata = flightStatus.metadata();
+ Assert.assertNotNull(metadata);
+ Assert.assertEquals("foo", metadata.get("x-foo"));
+ Assert.assertArrayEquals(new byte[]{1}, metadata.getByte("x-bar-bin"));
+
+ flightStatus = FlightTestUtil.assertCode(FlightStatusCode.INVALID_ARGUMENT, () -> {
+ client.getInfo(FlightDescriptor.command(new byte[0]));
+ });
+ metadata = flightStatus.metadata();
+ Assert.assertNotNull(metadata);
+ Assert.assertEquals("foo", metadata.get("x-foo"));
+ Assert.assertArrayEquals(new byte[]{1}, metadata.getByte("x-bar-bin"));
+ }
+ }
+
+ private static class StatusRuntimeExceptionProducer extends NoOpFlightProducer {
+ private final PerfOuterClass.Perf perf;
+
+ private StatusRuntimeExceptionProducer(PerfOuterClass.Perf perf) {
+ this.perf = perf;
+ }
+
+ @Override
+ public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
+ StatusRuntimeException sre = StatusProto.toStatusRuntimeException(Status.newBuilder()
+ .setCode(1)
+ .setMessage("Testing 1 2 3")
+ .addDetails(Any.pack(perf, "arrow/meta/types"))
+ .build());
+ listener.error(sre);
+ }
+ }
+
+ private static class CallStatusProducer extends NoOpFlightProducer {
+ ErrorFlightMetadata metadata;
+
+ CallStatusProducer() {
+ this.metadata = new ErrorFlightMetadata();
+ metadata.insert("x-foo", "foo");
+ metadata.insert("x-bar-bin", new byte[]{1});
+ }
+
+ @Override
+ public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
+ listener.error(CallStatus.INVALID_ARGUMENT.withDescription("Failed").withMetadata(metadata).toRuntimeException());
+ }
+
+ @Override
+ public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
+ throw CallStatus.INVALID_ARGUMENT.withDescription("Failed").withMetadata(metadata).toRuntimeException();
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightClient.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightClient.java
new file mode 100644
index 000000000..30e351e94
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightClient.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.nio.charset.StandardCharsets;
+import java.util.Collections;
+import java.util.List;
+
+import org.apache.arrow.flight.FlightClient.ClientStreamListener;
+import org.apache.arrow.flight.TestBasicOperation.Producer;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.ValueVector;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.VectorLoader;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.VectorUnloader;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryEncoder;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.jupiter.api.Assertions;
+
+public class TestFlightClient {
+ /**
+ * ARROW-5063: make sure two clients to the same location can be closed independently.
+ */
+ @Test
+ public void independentShutdown() throws Exception {
+ try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+ final FlightServer server = FlightTestUtil.getStartedServer(
+ location -> FlightServer.builder(allocator, location,
+ new Producer(allocator)).build())) {
+ final Location location = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort());
+ final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true))));
+ try (final FlightClient client1 = FlightClient.builder(allocator, location).build();
+ final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ // Use startPut as this ensures the RPC won't finish until we want it to
+ final ClientStreamListener listener = client1.startPut(FlightDescriptor.path("test"), root,
+ new AsyncPutListener());
+ try (final FlightClient client2 = FlightClient.builder(allocator, location).build()) {
+ client2.listActions().forEach(actionType -> Assert.assertNotNull(actionType.getType()));
+ }
+ listener.completed();
+ listener.getResult();
+ }
+ }
+ }
+
+ /**
+ * ARROW-5978: make sure that we can properly close a client/stream after requesting dictionaries.
+ */
+ @Ignore // Unfortunately this test is flaky in CI.
+ @Test
+ public void freeDictionaries() throws Exception {
+ final Schema expectedSchema = new Schema(Collections
+ .singletonList(new Field("encoded",
+ new FieldType(true, new ArrowType.Int(32, true), new DictionaryEncoding(1L, false, null)), null)));
+ try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+ final BufferAllocator serverAllocator = allocator.newChildAllocator("flight-server", 0, Integer.MAX_VALUE);
+ final FlightServer server = FlightTestUtil.getStartedServer(
+ location -> FlightServer.builder(serverAllocator, location,
+ new DictionaryProducer(serverAllocator)).build())) {
+ final Location location = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort());
+ try (final FlightClient client = FlightClient.builder(allocator, location).build()) {
+ try (final FlightStream stream = client.getStream(new Ticket(new byte[0]))) {
+ Assert.assertTrue(stream.next());
+ Assert.assertNotNull(stream.getDictionaryProvider().lookup(1));
+ final VectorSchemaRoot root = stream.getRoot();
+ Assert.assertEquals(expectedSchema, root.getSchema());
+ Assert.assertEquals(6, root.getVector("encoded").getValueCount());
+ try (final ValueVector decoded = DictionaryEncoder
+ .decode(root.getVector("encoded"), stream.getDictionaryProvider().lookup(1))) {
+ Assert.assertFalse(decoded.isNull(1));
+ Assert.assertTrue(decoded instanceof VarCharVector);
+ Assert.assertArrayEquals("one".getBytes(StandardCharsets.UTF_8), ((VarCharVector) decoded).get(1));
+ }
+ Assert.assertFalse(stream.next());
+ }
+ // Closing stream fails if it doesn't free dictionaries; closing dictionaries fails (refcount goes negative)
+ // if reference isn't retained in ArrowMessage
+ }
+ }
+ }
+
+ /**
+ * ARROW-5978: make sure that dictionary ownership can't be claimed twice.
+ */
+ @Ignore // Unfortunately this test is flaky in CI.
+ @Test
+ public void ownDictionaries() throws Exception {
+ try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+ final BufferAllocator serverAllocator = allocator.newChildAllocator("flight-server", 0, Integer.MAX_VALUE);
+ final FlightServer server = FlightTestUtil.getStartedServer(
+ location -> FlightServer.builder(serverAllocator, location,
+ new DictionaryProducer(serverAllocator)).build())) {
+ final Location location = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort());
+ try (final FlightClient client = FlightClient.builder(allocator, location).build()) {
+ try (final FlightStream stream = client.getStream(new Ticket(new byte[0]))) {
+ Assert.assertTrue(stream.next());
+ Assert.assertFalse(stream.next());
+ final DictionaryProvider provider = stream.takeDictionaryOwnership();
+ Assertions.assertThrows(IllegalStateException.class, stream::takeDictionaryOwnership);
+ Assertions.assertThrows(IllegalStateException.class, stream::getDictionaryProvider);
+ DictionaryUtils.closeDictionaries(stream.getSchema(), provider);
+ }
+ }
+ }
+ }
+
+ /**
+ * ARROW-5978: make sure that dictionaries can be used after closing the stream.
+ */
+ @Ignore // Unfortunately this test is flaky in CI.
+ @Test
+ public void useDictionariesAfterClose() throws Exception {
+ try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+ final BufferAllocator serverAllocator = allocator.newChildAllocator("flight-server", 0, Integer.MAX_VALUE);
+ final FlightServer server = FlightTestUtil.getStartedServer(
+ location -> FlightServer.builder(serverAllocator, location, new DictionaryProducer(serverAllocator))
+ .build())) {
+ final Location location = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort());
+ try (final FlightClient client = FlightClient.builder(allocator, location).build()) {
+ final VectorSchemaRoot root;
+ final DictionaryProvider provider;
+ try (final FlightStream stream = client.getStream(new Ticket(new byte[0]))) {
+ final VectorUnloader unloader = new VectorUnloader(stream.getRoot());
+ root = VectorSchemaRoot.create(stream.getSchema(), allocator);
+ final VectorLoader loader = new VectorLoader(root);
+ while (stream.next()) {
+ try (final ArrowRecordBatch arb = unloader.getRecordBatch()) {
+ loader.load(arb);
+ }
+ }
+ provider = stream.takeDictionaryOwnership();
+ }
+ try (final ValueVector decoded = DictionaryEncoder
+ .decode(root.getVector("encoded"), provider.lookup(1))) {
+ Assert.assertFalse(decoded.isNull(1));
+ Assert.assertTrue(decoded instanceof VarCharVector);
+ Assert.assertArrayEquals("one".getBytes(StandardCharsets.UTF_8), ((VarCharVector) decoded).get(1));
+ }
+ root.close();
+ DictionaryUtils.closeDictionaries(root.getSchema(), provider);
+ }
+ }
+ }
+
+ static class DictionaryProducer extends NoOpFlightProducer {
+
+ private final BufferAllocator allocator;
+
+ public DictionaryProducer(BufferAllocator allocator) {
+ this.allocator = allocator;
+ }
+
+ @Override
+ public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
+ final byte[] zero = "zero".getBytes(StandardCharsets.UTF_8);
+ final byte[] one = "one".getBytes(StandardCharsets.UTF_8);
+ final byte[] two = "two".getBytes(StandardCharsets.UTF_8);
+ try (final VarCharVector dictionaryVector = newVarCharVector("dictionary", allocator)) {
+ final DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider();
+
+ dictionaryVector.allocateNew(512, 3);
+ dictionaryVector.setSafe(0, zero, 0, zero.length);
+ dictionaryVector.setSafe(1, one, 0, one.length);
+ dictionaryVector.setSafe(2, two, 0, two.length);
+ dictionaryVector.setValueCount(3);
+
+ final Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null));
+ provider.put(dictionary);
+
+ final FieldVector encodedVector;
+ try (final VarCharVector unencoded = newVarCharVector("encoded", allocator)) {
+ unencoded.allocateNewSafe();
+ unencoded.set(1, one);
+ unencoded.set(2, two);
+ unencoded.set(3, zero);
+ unencoded.set(4, two);
+ unencoded.setValueCount(6);
+ encodedVector = (FieldVector) DictionaryEncoder.encode(unencoded, dictionary);
+ }
+
+ final List<Field> fields = Collections.singletonList(encodedVector.getField());
+ final List<FieldVector> vectors = Collections.singletonList(encodedVector);
+
+ try (final VectorSchemaRoot root = new VectorSchemaRoot(fields, vectors, encodedVector.getValueCount())) {
+ listener.start(root, provider);
+ listener.putNext();
+ listener.completed();
+ }
+ }
+ }
+
+ private static VarCharVector newVarCharVector(String name, BufferAllocator allocator) {
+ return (VarCharVector)
+ FieldType.nullable(new ArrowType.Utf8()).createNewSingleVector(name, allocator, null);
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java
new file mode 100644
index 000000000..65ef12a8a
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import static org.junit.jupiter.api.Assertions.fail;
+
+import org.apache.arrow.flight.impl.Flight;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import io.grpc.stub.ServerCallStreamObserver;
+
+public class TestFlightService {
+
+ private BufferAllocator allocator;
+
+ @Before
+ public void setup() {
+ allocator = new RootAllocator(Long.MAX_VALUE);
+ }
+
+ @After
+ public void cleanup() throws Exception {
+ AutoCloseables.close(allocator);
+ }
+
+ @Test
+ public void testFlightServiceWithNoAuthHandlerOrInterceptors() {
+ // This test is for ARROW-10491. There was a bug where FlightService would try to access the RequestContext,
+ // but the RequestContext was getting set to null because no interceptors were active to initialize it
+ // when using FlightService directly rather than starting up a FlightServer.
+
+ // Arrange
+ final FlightProducer producer = new NoOpFlightProducer() {
+ @Override
+ public void getStream(CallContext context, Ticket ticket,
+ ServerStreamListener listener) {
+ listener.completed();
+ }
+ };
+
+ // This response observer notifies that the test failed if onError() is called.
+ final ServerCallStreamObserver<ArrowMessage> observer = new ServerCallStreamObserver<ArrowMessage>() {
+ @Override
+ public boolean isCancelled() {
+ return false;
+ }
+
+ @Override
+ public void setOnCancelHandler(Runnable runnable) {
+
+ }
+
+ @Override
+ public void setCompression(String s) {
+
+ }
+
+ @Override
+ public boolean isReady() {
+ return false;
+ }
+
+ @Override
+ public void setOnReadyHandler(Runnable runnable) {
+
+ }
+
+ @Override
+ public void disableAutoInboundFlowControl() {
+
+ }
+
+ @Override
+ public void request(int i) {
+
+ }
+
+ @Override
+ public void setMessageCompression(boolean b) {
+
+ }
+
+ @Override
+ public void onNext(ArrowMessage arrowMessage) {
+
+ }
+
+ @Override
+ public void onError(Throwable throwable) {
+ fail(throwable);
+ }
+
+ @Override
+ public void onCompleted() {
+
+ }
+ };
+ final FlightService flightService = new FlightService(allocator, producer, null, null);
+
+ // Act
+ flightService.doGetCustom(Flight.Ticket.newBuilder().build(), observer);
+
+ // fail() would have been called if an error happened during doGetCustom(), so this test passed.
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLargeMessage.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLargeMessage.java
new file mode 100644
index 000000000..629b6f5eb
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLargeMessage.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Stream;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class TestLargeMessage {
+ /**
+ * Make sure a Flight client accepts large message payloads by default.
+ */
+ @Test
+ public void getLargeMessage() throws Exception {
+ try (final BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
+ final Producer producer = new Producer(a);
+ final FlightServer s =
+ FlightTestUtil.getStartedServer((location) -> FlightServer.builder(a, location, producer).build())) {
+
+ try (FlightClient client = FlightClient.builder(a, s.getLocation()).build()) {
+ try (FlightStream stream = client.getStream(new Ticket(new byte[]{}));
+ VectorSchemaRoot root = stream.getRoot()) {
+ while (stream.next()) {
+ for (final Field field : root.getSchema().getFields()) {
+ int value = 0;
+ final IntVector iv = (IntVector) root.getVector(field.getName());
+ for (int i = 0; i < root.getRowCount(); i++) {
+ Assert.assertEquals(value, iv.get(i));
+ value++;
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Make sure a Flight server accepts large message payloads by default.
+ */
+ @Test
+ public void putLargeMessage() throws Exception {
+ try (final BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
+ final Producer producer = new Producer(a);
+ final FlightServer s =
+ FlightTestUtil.getStartedServer((location) -> FlightServer.builder(a, location, producer).build()
+ )) {
+
+ try (FlightClient client = FlightClient.builder(a, s.getLocation()).build();
+ BufferAllocator testAllocator = a.newChildAllocator("testcase", 0, Long.MAX_VALUE);
+ VectorSchemaRoot root = generateData(testAllocator)) {
+ final FlightClient.ClientStreamListener listener = client.startPut(FlightDescriptor.path("hello"), root,
+ new AsyncPutListener());
+ listener.putNext();
+ listener.completed();
+ listener.getResult();
+ }
+ }
+ }
+
+ private static VectorSchemaRoot generateData(BufferAllocator allocator) {
+ final int size = 128 * 1024;
+ final List<String> fieldNames = Arrays.asList("c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10");
+ final Stream<Field> fields = fieldNames
+ .stream()
+ .map(fieldName -> new Field(fieldName, FieldType.nullable(new ArrowType.Int(32, true)), null));
+ final Schema schema = new Schema(fields::iterator, null);
+
+ final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator);
+ root.allocateNew();
+ for (final String fieldName : fieldNames) {
+ final IntVector iv = (IntVector) root.getVector(fieldName);
+ iv.setValueCount(size);
+ for (int i = 0; i < size; i++) {
+ iv.set(i, i);
+ }
+ }
+ root.setRowCount(size);
+ return root;
+ }
+
+ private static class Producer implements FlightProducer, AutoCloseable {
+ private final BufferAllocator allocator;
+
+ Producer(BufferAllocator allocator) {
+ this.allocator = allocator;
+ }
+
+ @Override
+ public void getStream(CallContext context, Ticket ticket,
+ ServerStreamListener listener) {
+ try (VectorSchemaRoot root = generateData(allocator)) {
+ listener.start(root);
+ listener.putNext();
+ listener.completed();
+ }
+ }
+
+ @Override
+ public void listFlights(CallContext context, Criteria criteria,
+ StreamListener<FlightInfo> listener) {
+
+ }
+
+ @Override
+ public FlightInfo getFlightInfo(CallContext context,
+ FlightDescriptor descriptor) {
+ return null;
+ }
+
+ @Override
+ public Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener<PutResult> ackStream) {
+ return () -> {
+ try (VectorSchemaRoot root = flightStream.getRoot()) {
+ while (flightStream.next()) {
+ ;
+ }
+ }
+ };
+ }
+
+ @Override
+ public void doAction(CallContext context, Action action,
+ StreamListener<Result> listener) {
+ listener.onCompleted();
+ }
+
+ @Override
+ public void listActions(CallContext context,
+ StreamListener<ActionType> listener) {
+
+ }
+
+ @Override
+ public void close() throws Exception {
+ allocator.close();
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLeak.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLeak.java
new file mode 100644
index 000000000..6e2870499
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLeak.java
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.Arrays;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.Float8Vector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.FloatingPointPrecision;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.Test;
+
+/**
+ * Tests for scenarios where Flight could leak memory.
+ */
+public class TestLeak {
+
+ private static final int ROWS = 2048;
+
+ private static Schema getSchema() {
+ return new Schema(Arrays.asList(
+ Field.nullable("0", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)),
+ Field.nullable("1", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)),
+ Field.nullable("2", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)),
+ Field.nullable("3", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)),
+ Field.nullable("4", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)),
+ Field.nullable("5", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)),
+ Field.nullable("6", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)),
+ Field.nullable("7", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)),
+ Field.nullable("8", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)),
+ Field.nullable("9", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)),
+ Field.nullable("10", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))
+ ));
+ }
+
+ /**
+ * Ensure that if the client cancels, the server does not leak memory.
+ *
+ * <p>In gRPC, canceling the stream from the client sends an event to the server. Once processed, gRPC will start
+ * silently rejecting messages sent by the server. However, Flight depends on gRPC processing these messages in order
+ * to free the associated memory.
+ */
+ @Test
+ public void testCancelingDoGetDoesNotLeak() throws Exception {
+ final CountDownLatch callFinished = new CountDownLatch(1);
+ try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ final FlightServer s =
+ FlightTestUtil.getStartedServer(
+ (location) -> FlightServer.builder(allocator, location, new LeakFlightProducer(allocator, callFinished))
+ .build());
+ final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) {
+
+ final FlightStream stream = client.getStream(new Ticket(new byte[0]));
+ stream.getRoot();
+ stream.cancel("Cancel", null);
+
+ // Wait for the call to finish. (Closing the allocator while a call is ongoing is a guaranteed leak.)
+ callFinished.await(60, TimeUnit.SECONDS);
+
+ s.shutdown();
+ s.awaitTermination();
+ }
+ }
+
+ @Test
+ public void testCancelingDoPutDoesNotBlock() throws Exception {
+ final CountDownLatch callFinished = new CountDownLatch(1);
+ try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ final FlightServer s =
+ FlightTestUtil.getStartedServer(
+ (location) -> FlightServer.builder(allocator, location, new LeakFlightProducer(allocator, callFinished))
+ .build());
+ final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) {
+
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(getSchema(), allocator)) {
+ final FlightDescriptor descriptor = FlightDescriptor.command(new byte[0]);
+ final SyncPutListener listener = new SyncPutListener();
+ final FlightClient.ClientStreamListener stream = client.startPut(descriptor, root, listener);
+ // Wait for the server to cancel
+ callFinished.await(60, TimeUnit.SECONDS);
+
+ for (int col = 0; col < 11; col++) {
+ final Float8Vector vector = (Float8Vector) root.getVector(Integer.toString(col));
+ vector.allocateNew();
+ for (int row = 0; row < ROWS; row++) {
+ vector.setSafe(row, 10.);
+ }
+ }
+ root.setRowCount(ROWS);
+ // Unlike DoGet, this method fairly reliably will write the message to the stream, so even without the fix
+ // for ARROW-7343, this won't leak memory.
+ // However, it will block if FlightClient doesn't check for cancellation.
+ stream.putNext();
+ stream.completed();
+ }
+
+ s.shutdown();
+ s.awaitTermination();
+ }
+ }
+
+ /**
+ * A FlightProducer that always produces a fixed data stream with metadata on the side.
+ */
+ private static class LeakFlightProducer extends NoOpFlightProducer {
+
+ private final BufferAllocator allocator;
+ private final CountDownLatch callFinished;
+
+ public LeakFlightProducer(BufferAllocator allocator, CountDownLatch callFinished) {
+ this.allocator = allocator;
+ this.callFinished = callFinished;
+ }
+
+ @Override
+ public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
+ BufferAllocator childAllocator = allocator.newChildAllocator("foo", 0, Long.MAX_VALUE);
+ VectorSchemaRoot root = VectorSchemaRoot.create(TestLeak.getSchema(), childAllocator);
+ root.allocateNew();
+ listener.start(root);
+
+ // We can't poll listener#isCancelled since gRPC has two distinct "is cancelled" flags.
+ // TODO: should we continue leaking gRPC semantics? Can we even avoid this?
+ listener.setOnCancelHandler(() -> {
+ try {
+ for (int col = 0; col < 11; col++) {
+ final Float8Vector vector = (Float8Vector) root.getVector(Integer.toString(col));
+ vector.allocateNew();
+ for (int row = 0; row < ROWS; row++) {
+ vector.setSafe(row, 10.);
+ }
+ }
+ root.setRowCount(ROWS);
+ // Once the call is "really cancelled" (setOnCancelListener has run/is running), this call is actually a
+ // no-op on the gRPC side and will leak the ArrowMessage unless Flight checks for this.
+ listener.putNext();
+ listener.completed();
+ } finally {
+ try {
+ root.close();
+ childAllocator.close();
+ } finally {
+ // Don't let the test hang if we throw above
+ callFinished.countDown();
+ }
+ }
+ });
+ }
+
+ @Override
+ public Runnable acceptPut(CallContext context,
+ FlightStream flightStream, StreamListener<PutResult> ackStream) {
+ return () -> {
+ flightStream.getRoot();
+ ackStream.onError(CallStatus.CANCELLED.withDescription("CANCELLED").toRuntimeException());
+ callFinished.countDown();
+ ackStream.onCompleted();
+ };
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java
new file mode 100644
index 000000000..83a694bf3
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java
@@ -0,0 +1,319 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.Collections;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.ipc.message.IpcOption;
+import org.apache.arrow.vector.types.MetadataVersion;
+import org.apache.arrow.vector.types.UnionMode;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+/**
+ * Test clients/servers with different metadata versions.
+ */
+public class TestMetadataVersion {
+ private static BufferAllocator allocator;
+ private static Schema schema;
+ private static IpcOption optionV4;
+ private static IpcOption optionV5;
+ private static Schema unionSchema;
+
+ @BeforeClass
+ public static void setUpClass() {
+ allocator = new RootAllocator(Integer.MAX_VALUE);
+ schema = new Schema(Collections.singletonList(Field.nullable("foo", new ArrowType.Int(32, true))));
+ unionSchema = new Schema(
+ Collections.singletonList(Field.nullable("union", new ArrowType.Union(UnionMode.Dense, new int[]{0}))));
+
+ // avoid writing legacy ipc format by default
+ optionV4 = new IpcOption(false, MetadataVersion.V4);
+ optionV5 = IpcOption.DEFAULT;
+ }
+
+ @AfterClass
+ public static void tearDownClass() {
+ allocator.close();
+ }
+
+ @Test
+ public void testGetFlightInfoV4() throws Exception {
+ try (final FlightServer server = startServer(optionV4);
+ final FlightClient client = connect(server)) {
+ final FlightInfo result = client.getInfo(FlightDescriptor.command(new byte[0]));
+ assertEquals(schema, result.getSchema());
+ }
+ }
+
+ @Test
+ public void testGetSchemaV4() throws Exception {
+ try (final FlightServer server = startServer(optionV4);
+ final FlightClient client = connect(server)) {
+ final SchemaResult result = client.getSchema(FlightDescriptor.command(new byte[0]));
+ assertEquals(schema, result.getSchema());
+ }
+ }
+
+ @Test
+ public void testUnionCheck() throws Exception {
+ assertThrows(IllegalArgumentException.class, () -> new SchemaResult(unionSchema, optionV4));
+ assertThrows(IllegalArgumentException.class, () ->
+ new FlightInfo(unionSchema, FlightDescriptor.command(new byte[0]), Collections.emptyList(), -1, -1, optionV4));
+ try (final FlightServer server = startServer(optionV4);
+ final FlightClient client = connect(server);
+ final FlightStream stream = client.getStream(new Ticket("union".getBytes(StandardCharsets.UTF_8)))) {
+ final FlightRuntimeException err = assertThrows(FlightRuntimeException.class, stream::next);
+ assertTrue(err.getMessage(), err.getMessage().contains("Cannot write union with V4 metadata"));
+ }
+
+ try (final FlightServer server = startServer(optionV4);
+ final FlightClient client = connect(server);
+ final VectorSchemaRoot root = VectorSchemaRoot.create(unionSchema, allocator)) {
+ final FlightDescriptor descriptor = FlightDescriptor.command(new byte[0]);
+ final SyncPutListener reader = new SyncPutListener();
+ final FlightClient.ClientStreamListener listener = client.startPut(descriptor, reader);
+ final IllegalArgumentException err = assertThrows(IllegalArgumentException.class,
+ () -> listener.start(root, null, optionV4));
+ assertTrue(err.getMessage(), err.getMessage().contains("Cannot write union with V4 metadata"));
+ }
+ }
+
+ @Test
+ public void testPutV4() throws Exception {
+ try (final FlightServer server = startServer(optionV4);
+ final FlightClient client = connect(server);
+ final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ generateData(root);
+ final FlightDescriptor descriptor = FlightDescriptor.command(new byte[0]);
+ final SyncPutListener reader = new SyncPutListener();
+ final FlightClient.ClientStreamListener listener = client.startPut(descriptor, reader);
+ listener.start(root, null, optionV4);
+ listener.putNext();
+ listener.completed();
+ listener.getResult();
+ }
+ }
+
+ @Test
+ public void testGetV4() throws Exception {
+ try (final FlightServer server = startServer(optionV4);
+ final FlightClient client = connect(server);
+ final FlightStream stream = client.getStream(new Ticket(new byte[0]))) {
+ assertTrue(stream.next());
+ assertEquals(optionV4.metadataVersion, stream.metadataVersion);
+ validateRoot(stream.getRoot());
+ assertFalse(stream.next());
+ }
+ }
+
+ @Test
+ public void testExchangeV4ToV5() throws Exception {
+ try (final FlightServer server = startServer(optionV5);
+ final FlightClient client = connect(server);
+ final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator);
+ final FlightClient.ExchangeReaderWriter stream = client.doExchange(FlightDescriptor.command(new byte[0]))) {
+ stream.getWriter().start(root, null, optionV4);
+ generateData(root);
+ stream.getWriter().putNext();
+ stream.getWriter().completed();
+ assertTrue(stream.getReader().next());
+ assertEquals(optionV5.metadataVersion, stream.getReader().metadataVersion);
+ validateRoot(stream.getReader().getRoot());
+ assertFalse(stream.getReader().next());
+ }
+ }
+
+ @Test
+ public void testExchangeV5ToV4() throws Exception {
+ try (final FlightServer server = startServer(optionV4);
+ final FlightClient client = connect(server);
+ final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator);
+ final FlightClient.ExchangeReaderWriter stream = client.doExchange(FlightDescriptor.command(new byte[0]))) {
+ stream.getWriter().start(root, null, optionV5);
+ generateData(root);
+ stream.getWriter().putNext();
+ stream.getWriter().completed();
+ assertTrue(stream.getReader().next());
+ assertEquals(optionV4.metadataVersion, stream.getReader().metadataVersion);
+ validateRoot(stream.getReader().getRoot());
+ assertFalse(stream.getReader().next());
+ }
+ }
+
+ @Test
+ public void testExchangeV4ToV4() throws Exception {
+ try (final FlightServer server = startServer(optionV4);
+ final FlightClient client = connect(server);
+ final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator);
+ final FlightClient.ExchangeReaderWriter stream = client.doExchange(FlightDescriptor.command(new byte[0]))) {
+ stream.getWriter().start(root, null, optionV4);
+ generateData(root);
+ stream.getWriter().putNext();
+ stream.getWriter().completed();
+ assertTrue(stream.getReader().next());
+ assertEquals(optionV4.metadataVersion, stream.getReader().metadataVersion);
+ validateRoot(stream.getReader().getRoot());
+ assertFalse(stream.getReader().next());
+ }
+ }
+
+ private static void generateData(VectorSchemaRoot root) {
+ assertEquals(schema, root.getSchema());
+ final IntVector vector = (IntVector) root.getVector("foo");
+ vector.setSafe(0, 0);
+ vector.setSafe(1, 1);
+ vector.setSafe(2, 4);
+ root.setRowCount(3);
+ }
+
+ private static void validateRoot(VectorSchemaRoot root) {
+ assertEquals(schema, root.getSchema());
+ assertEquals(3, root.getRowCount());
+ final IntVector vector = (IntVector) root.getVector("foo");
+ assertEquals(0, vector.get(0));
+ assertEquals(1, vector.get(1));
+ assertEquals(4, vector.get(2));
+ }
+
+ FlightServer startServer(IpcOption option) throws Exception {
+ Location location = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, 0);
+ VersionFlightProducer producer = new VersionFlightProducer(allocator, option);
+ final FlightServer server = FlightServer.builder(allocator, location, producer).build();
+ server.start();
+ return server;
+ }
+
+ FlightClient connect(FlightServer server) {
+ Location location = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort());
+ return FlightClient.builder(allocator, location).build();
+ }
+
+ static final class VersionFlightProducer extends NoOpFlightProducer {
+ private final BufferAllocator allocator;
+ private final IpcOption option;
+
+ VersionFlightProducer(BufferAllocator allocator, IpcOption option) {
+ this.allocator = allocator;
+ this.option = option;
+ }
+
+ @Override
+ public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
+ return new FlightInfo(schema, descriptor, Collections.emptyList(), -1, -1, option);
+ }
+
+ @Override
+ public SchemaResult getSchema(CallContext context, FlightDescriptor descriptor) {
+ return new SchemaResult(schema, option);
+ }
+
+ @Override
+ public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
+ if (Arrays.equals("union".getBytes(StandardCharsets.UTF_8), ticket.getBytes())) {
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(unionSchema, allocator)) {
+ listener.start(root, null, option);
+ } catch (IllegalArgumentException e) {
+ listener.error(CallStatus.INTERNAL.withCause(e).withDescription(e.getMessage()).toRuntimeException());
+ return;
+ }
+ listener.error(CallStatus.INTERNAL.withDescription("Expected exception not raised").toRuntimeException());
+ return;
+ }
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ listener.start(root, null, option);
+ generateData(root);
+ listener.putNext();
+ listener.completed();
+ }
+ }
+
+ @Override
+ public Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener<PutResult> ackStream) {
+ return () -> {
+ try {
+ assertTrue(flightStream.next());
+ assertEquals(option.metadataVersion, flightStream.metadataVersion);
+ validateRoot(flightStream.getRoot());
+ } catch (AssertionError err) {
+ // gRPC doesn't propagate stack traces across the wire.
+ err.printStackTrace();
+ ackStream.onError(CallStatus.INVALID_ARGUMENT
+ .withCause(err)
+ .withDescription("Server assertion failed: " + err)
+ .toRuntimeException());
+ return;
+ } catch (RuntimeException err) {
+ err.printStackTrace();
+ ackStream.onError(CallStatus.INTERNAL
+ .withCause(err)
+ .withDescription("Server assertion failed: " + err)
+ .toRuntimeException());
+ return;
+ }
+ ackStream.onCompleted();
+ };
+ }
+
+ @Override
+ public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) {
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ try {
+ assertTrue(reader.next());
+ validateRoot(reader.getRoot());
+ assertFalse(reader.next());
+ } catch (AssertionError err) {
+ // gRPC doesn't propagate stack traces across the wire.
+ err.printStackTrace();
+ writer.error(CallStatus.INVALID_ARGUMENT
+ .withCause(err)
+ .withDescription("Server assertion failed: " + err)
+ .toRuntimeException());
+ return;
+ } catch (RuntimeException err) {
+ err.printStackTrace();
+ writer.error(CallStatus.INTERNAL
+ .withCause(err)
+ .withDescription("Server assertion failed: " + err)
+ .toRuntimeException());
+ return;
+ }
+
+ writer.start(root, null, option);
+ generateData(root);
+ writer.putNext();
+ writer.completed();
+ }
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerMiddleware.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerMiddleware.java
new file mode 100644
index 000000000..1f3e35ca3
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerMiddleware.java
@@ -0,0 +1,360 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.BiConsumer;
+
+import org.apache.arrow.flight.FlightClient.ClientStreamListener;
+import org.apache.arrow.flight.FlightServerMiddleware.Factory;
+import org.apache.arrow.flight.FlightServerMiddleware.Key;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+@RunWith(JUnit4.class)
+public class TestServerMiddleware {
+
+ private static final RuntimeException EXPECTED_EXCEPTION = new RuntimeException("test");
+
+ /**
+ * Make sure errors in DoPut are intercepted.
+ */
+ @Test
+ public void doPutErrors() {
+ test(
+ new ErrorProducer(EXPECTED_EXCEPTION),
+ (allocator, client) -> {
+ final FlightDescriptor descriptor = FlightDescriptor.path("test");
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(new Schema(Collections.emptyList()), allocator)) {
+ final ClientStreamListener listener = client.startPut(descriptor, root, new SyncPutListener());
+ listener.completed();
+ FlightTestUtil.assertCode(FlightStatusCode.INTERNAL, listener::getResult);
+ }
+ }, (recorder) -> {
+ final CallStatus status = recorder.statusFuture.get();
+ Assert.assertNotNull(status);
+ Assert.assertNotNull(status.cause());
+ Assert.assertEquals(FlightStatusCode.INTERNAL, status.code());
+ });
+ // Check the status after server shutdown (to make sure gRPC finishes pending calls on the server side)
+ }
+
+ /**
+ * Make sure custom error codes in DoPut are intercepted.
+ */
+ @Test
+ public void doPutCustomCode() {
+ test(
+ new ErrorProducer(CallStatus.UNAVAILABLE.withDescription("description").toRuntimeException()),
+ (allocator, client) -> {
+ final FlightDescriptor descriptor = FlightDescriptor.path("test");
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(new Schema(Collections.emptyList()), allocator)) {
+ final ClientStreamListener listener = client.startPut(descriptor, root, new SyncPutListener());
+ listener.completed();
+ FlightTestUtil.assertCode(FlightStatusCode.UNAVAILABLE, listener::getResult);
+ }
+ }, (recorder) -> {
+ final CallStatus status = recorder.statusFuture.get();
+ Assert.assertNotNull(status);
+ Assert.assertNull(status.cause());
+ Assert.assertEquals(FlightStatusCode.UNAVAILABLE, status.code());
+ Assert.assertEquals("description", status.description());
+ });
+ }
+
+ /**
+ * Make sure uncaught exceptions in DoPut are intercepted.
+ */
+ @Test
+ public void doPutUncaught() {
+ test(new ServerErrorProducer(EXPECTED_EXCEPTION),
+ (allocator, client) -> {
+ final FlightDescriptor descriptor = FlightDescriptor.path("test");
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(new Schema(Collections.emptyList()), allocator)) {
+ final ClientStreamListener listener = client.startPut(descriptor, root, new SyncPutListener());
+ listener.completed();
+ listener.getResult();
+ }
+ }, (recorder) -> {
+ final CallStatus status = recorder.statusFuture.get();
+ final Throwable err = recorder.errFuture.get();
+ Assert.assertNotNull(status);
+ Assert.assertEquals(FlightStatusCode.OK, status.code());
+ Assert.assertNull(status.cause());
+ Assert.assertNotNull(err);
+ Assert.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage());
+ });
+ }
+
+ @Test
+ public void listFlightsUncaught() {
+ test(new ServerErrorProducer(EXPECTED_EXCEPTION),
+ (allocator, client) -> client.listFlights(new Criteria(new byte[0])).forEach((action) -> {
+ }), (recorder) -> {
+ final CallStatus status = recorder.statusFuture.get();
+ final Throwable err = recorder.errFuture.get();
+ Assert.assertNotNull(status);
+ Assert.assertEquals(FlightStatusCode.OK, status.code());
+ Assert.assertNull(status.cause());
+ Assert.assertNotNull(err);
+ Assert.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage());
+ });
+ }
+
+ @Test
+ public void doActionUncaught() {
+ test(new ServerErrorProducer(EXPECTED_EXCEPTION),
+ (allocator, client) -> client.doAction(new Action("test")).forEachRemaining(result -> {
+ }), (recorder) -> {
+ final CallStatus status = recorder.statusFuture.get();
+ final Throwable err = recorder.errFuture.get();
+ Assert.assertNotNull(status);
+ Assert.assertEquals(FlightStatusCode.OK, status.code());
+ Assert.assertNull(status.cause());
+ Assert.assertNotNull(err);
+ Assert.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage());
+ });
+ }
+
+ @Test
+ public void listActionsUncaught() {
+ test(new ServerErrorProducer(EXPECTED_EXCEPTION),
+ (allocator, client) -> client.listActions().forEach(result -> {
+ }), (recorder) -> {
+ final CallStatus status = recorder.statusFuture.get();
+ final Throwable err = recorder.errFuture.get();
+ Assert.assertNotNull(status);
+ Assert.assertEquals(FlightStatusCode.OK, status.code());
+ Assert.assertNull(status.cause());
+ Assert.assertNotNull(err);
+ Assert.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage());
+ });
+ }
+
+ @Test
+ public void getFlightInfoUncaught() {
+ test(new ServerErrorProducer(EXPECTED_EXCEPTION),
+ (allocator, client) -> {
+ FlightTestUtil.assertCode(FlightStatusCode.INTERNAL, () -> client.getInfo(FlightDescriptor.path("test")));
+ }, (recorder) -> {
+ final CallStatus status = recorder.statusFuture.get();
+ Assert.assertNotNull(status);
+ Assert.assertEquals(FlightStatusCode.INTERNAL, status.code());
+ Assert.assertNotNull(status.cause());
+ Assert.assertEquals(EXPECTED_EXCEPTION.getMessage(), status.cause().getMessage());
+ });
+ }
+
+ @Test
+ public void doGetUncaught() {
+ test(new ServerErrorProducer(EXPECTED_EXCEPTION),
+ (allocator, client) -> {
+ try (final FlightStream stream = client.getStream(new Ticket(new byte[0]))) {
+ while (stream.next()) {
+ }
+ } catch (Exception e) {
+ Assert.fail(e.toString());
+ }
+ }, (recorder) -> {
+ final CallStatus status = recorder.statusFuture.get();
+ final Throwable err = recorder.errFuture.get();
+ Assert.assertNotNull(status);
+ Assert.assertEquals(FlightStatusCode.OK, status.code());
+ Assert.assertNull(status.cause());
+ Assert.assertNotNull(err);
+ Assert.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage());
+ });
+ }
+
+ /**
+ * A middleware that records the last error on any call.
+ */
+ static class ErrorRecorder implements FlightServerMiddleware {
+
+ CompletableFuture<CallStatus> statusFuture = new CompletableFuture<>();
+ CompletableFuture<Throwable> errFuture = new CompletableFuture<>();
+
+ @Override
+ public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
+ }
+
+ @Override
+ public void onCallCompleted(CallStatus status) {
+ statusFuture.complete(status);
+ }
+
+ @Override
+ public void onCallErrored(Throwable err) {
+ errFuture.complete(err);
+ }
+
+ static class Factory implements FlightServerMiddleware.Factory<ErrorRecorder> {
+
+ ErrorRecorder instance = new ErrorRecorder();
+
+ @Override
+ public ErrorRecorder onCallStarted(CallInfo info, CallHeaders incomingHeaders, RequestContext context) {
+ return instance;
+ }
+ }
+ }
+
+ /**
+ * A producer that throws the given exception on a call.
+ */
+ static class ErrorProducer extends NoOpFlightProducer {
+
+ final RuntimeException error;
+
+ ErrorProducer(RuntimeException t) {
+ error = t;
+ }
+
+ @Override
+ public Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener<PutResult> ackStream) {
+ return () -> {
+ // Drain queue to avoid FlightStream#close cancelling the call
+ while (flightStream.next()) {
+ }
+ throw error;
+ };
+ }
+ }
+
+ /**
+ * A producer that throws the given exception on a call, but only after sending a success to the client.
+ */
+ static class ServerErrorProducer extends NoOpFlightProducer {
+
+ final RuntimeException error;
+
+ ServerErrorProducer(RuntimeException t) {
+ error = t;
+ }
+
+ @Override
+ public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
+ try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+ final VectorSchemaRoot root = VectorSchemaRoot.create(new Schema(Collections.emptyList()), allocator)) {
+ listener.start(root);
+ listener.completed();
+ }
+ throw error;
+ }
+
+ @Override
+ public void listFlights(CallContext context, Criteria criteria, StreamListener<FlightInfo> listener) {
+ listener.onCompleted();
+ throw error;
+ }
+
+ @Override
+ public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) {
+ throw error;
+ }
+
+ @Override
+ public Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener<PutResult> ackStream) {
+ return () -> {
+ while (flightStream.next()) {
+ }
+ ackStream.onCompleted();
+ throw error;
+ };
+ }
+
+ @Override
+ public void doAction(CallContext context, Action action, StreamListener<Result> listener) {
+ listener.onCompleted();
+ throw error;
+ }
+
+ @Override
+ public void listActions(CallContext context, StreamListener<ActionType> listener) {
+ listener.onCompleted();
+ throw error;
+ }
+ }
+
+ static class ServerMiddlewarePair<T extends FlightServerMiddleware> {
+
+ final FlightServerMiddleware.Key<T> key;
+ final FlightServerMiddleware.Factory<T> factory;
+
+ ServerMiddlewarePair(Key<T> key, Factory<T> factory) {
+ this.key = key;
+ this.factory = factory;
+ }
+ }
+
+ /**
+ * Spin up a service with the given middleware and producer.
+ *
+ * @param producer The Flight producer to use.
+ * @param middleware A list of middleware to register.
+ * @param body A function to run as the body of the test.
+ * @param <T> The middleware type.
+ */
+ static <T extends FlightServerMiddleware> void test(FlightProducer producer, List<ServerMiddlewarePair<T>> middleware,
+ BiConsumer<BufferAllocator, FlightClient> body) {
+ try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
+ final FlightServer server = FlightTestUtil
+ .getStartedServer(location -> {
+ final FlightServer.Builder builder = FlightServer.builder(allocator, location, producer);
+ middleware.forEach(pair -> builder.middleware(pair.key, pair.factory));
+ return builder.build();
+ });
+ try (final FlightServer ignored = server;
+ final FlightClient client = FlightClient.builder(allocator, server.getLocation()).build()
+ ) {
+ body.accept(allocator, client);
+ }
+ } catch (InterruptedException | IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ static void test(FlightProducer producer, BiConsumer<BufferAllocator, FlightClient> body,
+ ErrorConsumer<ErrorRecorder> verify) {
+ final ErrorRecorder.Factory factory = new ErrorRecorder.Factory();
+ final List<ServerMiddlewarePair<ErrorRecorder>> middleware = Collections
+ .singletonList(new ServerMiddlewarePair<>(Key.of("m"), factory));
+ test(producer, middleware, (allocator, client) -> {
+ body.accept(allocator, client);
+ try {
+ verify.accept(factory.instance);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+
+ @FunctionalInterface
+ interface ErrorConsumer<T> {
+ void accept(T obj) throws Exception;
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerOptions.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerOptions.java
new file mode 100644
index 000000000..363ad443e
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerOptions.java
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+import java.io.File;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.function.Consumer;
+
+import org.apache.arrow.flight.TestBasicOperation.Producer;
+import org.apache.arrow.flight.auth.ServerAuthHandler;
+import org.apache.arrow.flight.impl.FlightServiceGrpc;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.junit.Assert;
+import org.junit.Assume;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import io.grpc.MethodDescriptor;
+import io.grpc.ServerServiceDefinition;
+import io.grpc.netty.NettyServerBuilder;
+
+@RunWith(JUnit4.class)
+public class TestServerOptions {
+
+ @Test
+ public void builderConsumer() throws Exception {
+ final AtomicBoolean consumerCalled = new AtomicBoolean();
+ final Consumer<NettyServerBuilder> consumer = (builder) -> consumerCalled.set(true);
+
+ try (
+ BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
+ Producer producer = new Producer(a);
+ FlightServer s =
+ FlightTestUtil.getStartedServer(
+ (location) -> FlightServer.builder(a, location, producer)
+ .transportHint("grpc.builderConsumer", consumer).build()
+ )) {
+ Assert.assertTrue(consumerCalled.get());
+ }
+ }
+
+ /**
+ * Make sure that if Flight supplies a default executor to gRPC, then it is closed along with the server.
+ */
+ @Test
+ public void defaultExecutorClosed() throws Exception {
+ final ExecutorService executor;
+ try (
+ BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
+ FlightServer server =
+ FlightTestUtil.getStartedServer(
+ (location) -> FlightServer.builder(a, location, new NoOpFlightProducer())
+ .build()
+ )) {
+ assertNotNull(server.grpcExecutor);
+ executor = server.grpcExecutor;
+ }
+ Assert.assertTrue(executor.isShutdown());
+ }
+
+ /**
+ * Make sure that if the user provides an executor to gRPC, then Flight does not close it.
+ */
+ @Test
+ public void suppliedExecutorNotClosed() throws Exception {
+ final ExecutorService executor = Executors.newSingleThreadExecutor();
+ try {
+ try (
+ BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
+ FlightServer server =
+ FlightTestUtil.getStartedServer(
+ (location) -> FlightServer.builder(a, location, new NoOpFlightProducer())
+ .executor(executor)
+ .build()
+ )) {
+ Assert.assertNull(server.grpcExecutor);
+ }
+ Assert.assertFalse(executor.isShutdown());
+ } finally {
+ executor.shutdown();
+ }
+ }
+
+ @Test
+ public void domainSocket() throws Exception {
+ Assume.assumeTrue("We have a native transport available", FlightTestUtil.isNativeTransportAvailable());
+ final File domainSocket = File.createTempFile("flight-unit-test-", ".sock");
+ Assert.assertTrue(domainSocket.delete());
+ // Domain socket paths have a platform-dependent limit. Set a conservative limit and skip the test if the temporary
+ // file name is too long. (We do not assume a particular platform-dependent temporary directory path.)
+ Assume.assumeTrue("The domain socket path is not too long", domainSocket.getAbsolutePath().length() < 100);
+ final Location location = Location.forGrpcDomainSocket(domainSocket.getAbsolutePath());
+ try (
+ BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
+ Producer producer = new Producer(a);
+ FlightServer s =
+ FlightTestUtil.getStartedServer(
+ (port) -> FlightServer.builder(a, location, producer).build()
+ )) {
+ try (FlightClient c = FlightClient.builder(a, location).build()) {
+ try (FlightStream stream = c.getStream(new Ticket(new byte[0]))) {
+ VectorSchemaRoot root = stream.getRoot();
+ IntVector iv = (IntVector) root.getVector("c1");
+ int value = 0;
+ while (stream.next()) {
+ for (int i = 0; i < root.getRowCount(); i++) {
+ Assert.assertEquals(value, iv.get(i));
+ value++;
+ }
+ }
+ }
+ }
+ }
+ }
+
+ @Test
+ public void checkReflectionMetadata() {
+ // This metadata is needed for gRPC reflection to work.
+ final ExecutorService executorService = Executors.newSingleThreadExecutor();
+ try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
+ final FlightBindingService service = new FlightBindingService(allocator, new NoOpFlightProducer(),
+ ServerAuthHandler.NO_OP, executorService);
+ final ServerServiceDefinition definition = service.bindService();
+ assertEquals(FlightServiceGrpc.getServiceDescriptor().getSchemaDescriptor(),
+ definition.getServiceDescriptor().getSchemaDescriptor());
+
+ final Map<String, MethodDescriptor<?, ?>> definedMethods = new HashMap<>();
+ final Map<String, MethodDescriptor<?, ?>> serviceMethods = new HashMap<>();
+
+ // Make sure that the reflection metadata object is identical across all the places where it's accessible
+ definition.getMethods().forEach(
+ method -> definedMethods.put(method.getMethodDescriptor().getFullMethodName(), method.getMethodDescriptor()));
+ definition.getServiceDescriptor().getMethods().forEach(
+ method -> serviceMethods.put(method.getFullMethodName(), method));
+
+ for (final MethodDescriptor<?, ?> descriptor : FlightServiceGrpc.getServiceDescriptor().getMethods()) {
+ final String methodName = descriptor.getFullMethodName();
+ Assert.assertTrue("Method is missing from ServerServiceDefinition: " + methodName,
+ definedMethods.containsKey(methodName));
+ Assert.assertTrue("Method is missing from ServiceDescriptor: " + methodName,
+ definedMethods.containsKey(methodName));
+
+ assertEquals(descriptor.getSchemaDescriptor(), definedMethods.get(methodName).getSchemaDescriptor());
+ assertEquals(descriptor.getSchemaDescriptor(), serviceMethods.get(methodName).getSchemaDescriptor());
+ }
+ } finally {
+ executorService.shutdown();
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestTls.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestTls.java
new file mode 100644
index 000000000..c5cd871e2
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestTls.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.Iterator;
+import java.util.function.Consumer;
+
+import org.apache.arrow.flight.FlightClient.Builder;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests for TLS in Flight.
+ */
+public class TestTls {
+
+ /**
+ * Test a basic request over TLS.
+ */
+ @Test
+ public void connectTls() {
+ test((builder) -> {
+ try (final InputStream roots = new FileInputStream(FlightTestUtil.exampleTlsRootCert().toFile());
+ final FlightClient client = builder.trustedCertificates(roots).build()) {
+ final Iterator<Result> responses = client.doAction(new Action("hello-world"));
+ final byte[] response = responses.next().getBody();
+ Assert.assertEquals("Hello, world!", new String(response, StandardCharsets.UTF_8));
+ Assert.assertFalse(responses.hasNext());
+ } catch (InterruptedException | IOException e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+
+ /**
+ * Make sure that connections are rejected when the root certificate isn't trusted.
+ */
+ @Test
+ public void rejectInvalidCert() {
+ test((builder) -> {
+ try (final FlightClient client = builder.build()) {
+ final Iterator<Result> responses = client.doAction(new Action("hello-world"));
+ FlightTestUtil.assertCode(FlightStatusCode.UNAVAILABLE, () -> responses.next().getBody());
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+
+ /**
+ * Make sure that connections are rejected when the hostname doesn't match.
+ */
+ @Test
+ public void rejectHostname() {
+ test((builder) -> {
+ try (final InputStream roots = new FileInputStream(FlightTestUtil.exampleTlsRootCert().toFile());
+ final FlightClient client = builder.trustedCertificates(roots).overrideHostname("fakehostname")
+ .build()) {
+ final Iterator<Result> responses = client.doAction(new Action("hello-world"));
+ FlightTestUtil.assertCode(FlightStatusCode.UNAVAILABLE, () -> responses.next().getBody());
+ } catch (InterruptedException | IOException e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+
+ /**
+ * Test a basic request over TLS.
+ */
+ @Test
+ public void connectTlsDisableServerVerification() {
+ test((builder) -> {
+ try (final FlightClient client = builder.verifyServer(false).build()) {
+ final Iterator<Result> responses = client.doAction(new Action("hello-world"));
+ final byte[] response = responses.next().getBody();
+ Assert.assertEquals("Hello, world!", new String(response, StandardCharsets.UTF_8));
+ Assert.assertFalse(responses.hasNext());
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ });
+ }
+
+ void test(Consumer<Builder> testFn) {
+ final FlightTestUtil.CertKeyPair certKey = FlightTestUtil.exampleTlsCerts().get(0);
+ try (
+ BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
+ Producer producer = new Producer();
+ FlightServer s =
+ FlightTestUtil.getStartedServer(
+ (location) -> {
+ try {
+ return FlightServer.builder(a, location, producer)
+ .useTls(certKey.cert, certKey.key)
+ .build();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ })) {
+ final Builder builder = FlightClient.builder(a, Location.forGrpcTls(FlightTestUtil.LOCALHOST, s.getPort()));
+ testFn.accept(builder);
+ } catch (InterruptedException | IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ static class Producer extends NoOpFlightProducer implements AutoCloseable {
+
+ @Override
+ public void doAction(CallContext context, Action action, StreamListener<Result> listener) {
+ if (action.getType().equals("hello-world")) {
+ listener.onNext(new Result("Hello, world!".getBytes(StandardCharsets.UTF_8)));
+ listener.onCompleted();
+ return;
+ }
+ listener
+ .onError(CallStatus.UNIMPLEMENTED.withDescription("Invalid action " + action.getType()).toRuntimeException());
+ }
+
+ @Override
+ public void close() {
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java
new file mode 100644
index 000000000..c18f5709b
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.Optional;
+
+import org.apache.arrow.flight.Criteria;
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.FlightStatusCode;
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.flight.FlightTestUtil;
+import org.apache.arrow.flight.NoOpFlightProducer;
+import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.Types;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import com.google.common.collect.ImmutableList;
+
+public class TestBasicAuth {
+
+ private static final String USERNAME = "flight";
+ private static final String PASSWORD = "woohoo";
+ private static final byte[] VALID_TOKEN = "my_token".getBytes(StandardCharsets.UTF_8);
+
+ private FlightClient client;
+ private FlightServer server;
+ private BufferAllocator allocator;
+
+ @Test
+ public void validAuth() {
+ client.authenticateBasic(USERNAME, PASSWORD);
+ Assert.assertTrue(ImmutableList.copyOf(client.listFlights(Criteria.ALL)).size() == 0);
+ }
+
+ // ARROW-7722: this test occasionally leaks memory
+ @Ignore
+ @Test
+ public void asyncCall() throws Exception {
+ client.authenticateBasic(USERNAME, PASSWORD);
+ client.listFlights(Criteria.ALL);
+ try (final FlightStream s = client.getStream(new Ticket(new byte[1]))) {
+ while (s.next()) {
+ Assert.assertEquals(4095, s.getRoot().getRowCount());
+ }
+ }
+ }
+
+ @Test
+ public void invalidAuth() {
+ FlightTestUtil.assertCode(FlightStatusCode.UNAUTHENTICATED, () -> {
+ client.authenticateBasic(USERNAME, "WRONG");
+ });
+
+ FlightTestUtil.assertCode(FlightStatusCode.UNAUTHENTICATED, () -> {
+ client.listFlights(Criteria.ALL).forEach(action -> Assert.fail());
+ });
+ }
+
+ @Test
+ public void didntAuth() {
+ FlightTestUtil.assertCode(FlightStatusCode.UNAUTHENTICATED, () -> {
+ client.listFlights(Criteria.ALL).forEach(action -> Assert.fail());
+ });
+ }
+
+ @Before
+ public void setup() throws IOException {
+ allocator = new RootAllocator(Long.MAX_VALUE);
+ final BasicServerAuthHandler.BasicAuthValidator validator = new BasicServerAuthHandler.BasicAuthValidator() {
+
+ @Override
+ public Optional<String> isValid(byte[] token) {
+ if (Arrays.equals(token, VALID_TOKEN)) {
+ return Optional.of(USERNAME);
+ }
+ return Optional.empty();
+ }
+
+ @Override
+ public byte[] getToken(String username, String password) {
+ if (USERNAME.equals(username) && PASSWORD.equals(password)) {
+ return VALID_TOKEN;
+ } else {
+ throw new IllegalArgumentException("invalid credentials");
+ }
+ }
+ };
+
+ server = FlightTestUtil.getStartedServer((location) -> FlightServer.builder(
+ allocator,
+ location,
+ new NoOpFlightProducer() {
+ @Override
+ public void listFlights(CallContext context, Criteria criteria,
+ StreamListener<FlightInfo> listener) {
+ if (!context.peerIdentity().equals(USERNAME)) {
+ listener.onError(new IllegalArgumentException("Invalid username"));
+ return;
+ }
+ listener.onCompleted();
+ }
+
+ @Override
+ public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
+ if (!context.peerIdentity().equals(USERNAME)) {
+ listener.error(new IllegalArgumentException("Invalid username"));
+ return;
+ }
+ final Schema pojoSchema = new Schema(ImmutableList.of(Field.nullable("a",
+ Types.MinorType.BIGINT.getType())));
+ try (VectorSchemaRoot root = VectorSchemaRoot.create(pojoSchema, allocator)) {
+ listener.start(root);
+ root.allocateNew();
+ root.setRowCount(4095);
+ listener.putNext();
+ listener.completed();
+ }
+ }
+ }).authHandler(new BasicServerAuthHandler(validator)).build());
+ client = FlightClient.builder(allocator, server.getLocation()).build();
+ }
+
+ @After
+ public void shutdown() throws Exception {
+ AutoCloseables.close(client, server, allocator);
+ }
+
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/auth2/TestBasicAuth2.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/auth2/TestBasicAuth2.java
new file mode 100644
index 000000000..9bec32f1b
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/auth2/TestBasicAuth2.java
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.auth2;
+
+import java.io.IOException;
+
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.Criteria;
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.FlightProducer;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.FlightStatusCode;
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.flight.FlightTestUtil;
+import org.apache.arrow.flight.NoOpFlightProducer;
+import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.flight.grpc.CredentialCallOption;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.Types;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import com.google.common.base.Strings;
+import com.google.common.collect.ImmutableList;
+
+public class TestBasicAuth2 {
+
+ private static final String USERNAME_1 = "flight1";
+ private static final String USERNAME_2 = "flight2";
+ private static final String NO_USERNAME = "";
+ private static final String PASSWORD_1 = "woohoo1";
+ private static final String PASSWORD_2 = "woohoo2";
+ private BufferAllocator allocator;
+ private FlightServer server;
+ private FlightClient client;
+ private FlightClient client2;
+
+ @Before
+ public void setup() throws Exception {
+ allocator = new RootAllocator(Long.MAX_VALUE);
+ startServerAndClient();
+ }
+
+ private FlightProducer getFlightProducer() {
+ return new NoOpFlightProducer() {
+ @Override
+ public void listFlights(CallContext context, Criteria criteria,
+ StreamListener<FlightInfo> listener) {
+ if (!context.peerIdentity().equals(USERNAME_1) && !context.peerIdentity().equals(USERNAME_2)) {
+ listener.onError(new IllegalArgumentException("Invalid username"));
+ return;
+ }
+ listener.onCompleted();
+ }
+
+ @Override
+ public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
+ if (!context.peerIdentity().equals(USERNAME_1) && !context.peerIdentity().equals(USERNAME_2)) {
+ listener.error(new IllegalArgumentException("Invalid username"));
+ return;
+ }
+ final Schema pojoSchema = new Schema(ImmutableList.of(Field.nullable("a",
+ Types.MinorType.BIGINT.getType())));
+ try (VectorSchemaRoot root = VectorSchemaRoot.create(pojoSchema, allocator)) {
+ listener.start(root);
+ root.allocateNew();
+ root.setRowCount(4095);
+ listener.putNext();
+ listener.completed();
+ }
+ }
+ };
+ }
+
+ private void startServerAndClient() throws IOException {
+ final FlightProducer flightProducer = getFlightProducer();
+ this.server = FlightTestUtil.getStartedServer((location) -> FlightServer
+ .builder(allocator, location, flightProducer)
+ .headerAuthenticator(new GeneratedBearerTokenAuthenticator(
+ new BasicCallHeaderAuthenticator(this::validate)))
+ .build());
+
+ this.client = FlightClient.builder(allocator, server.getLocation())
+ .build();
+ }
+
+ @After
+ public void shutdown() throws Exception {
+ AutoCloseables.close(client, client2, server, allocator);
+ client = null;
+ client2 = null;
+ server = null;
+ allocator = null;
+ }
+
+ private void startClient2() throws IOException {
+ client2 = FlightClient.builder(allocator, server.getLocation())
+ .build();
+ }
+
+ private CallHeaderAuthenticator.AuthResult validate(String username, String password) {
+ if (Strings.isNullOrEmpty(username)) {
+ throw CallStatus.UNAUTHENTICATED.withDescription("Credentials not supplied.").toRuntimeException();
+ }
+ final String identity;
+ if (USERNAME_1.equals(username) && PASSWORD_1.equals(password)) {
+ identity = USERNAME_1;
+ } else if (USERNAME_2.equals(username) && PASSWORD_2.equals(password)) {
+ identity = USERNAME_2;
+ } else {
+ throw CallStatus.UNAUTHENTICATED.withDescription("Username or password is invalid.").toRuntimeException();
+ }
+ return () -> identity;
+ }
+
+ @Test
+ public void validAuthWithBearerAuthServer() throws IOException {
+ testValidAuth(client);
+ }
+
+ @Test
+ public void validAuthWithMultipleClientsWithSameCredentialsWithBearerAuthServer() throws IOException {
+ startClient2();
+ testValidAuthWithMultipleClientsWithSameCredentials(client, client2);
+ }
+
+ @Test
+ public void validAuthWithMultipleClientsWithDifferentCredentialsWithBearerAuthServer() throws IOException {
+ startClient2();
+ testValidAuthWithMultipleClientsWithDifferentCredentials(client, client2);
+ }
+
+ // ARROW-7722: this test occasionally leaks memory
+ @Ignore
+ @Test
+ public void asyncCall() throws Exception {
+ final CredentialCallOption bearerToken = client
+ .authenticateBasicToken(USERNAME_1, PASSWORD_1).get();
+ client.listFlights(Criteria.ALL, bearerToken);
+ try (final FlightStream s = client.getStream(new Ticket(new byte[1]))) {
+ while (s.next()) {
+ Assert.assertEquals(4095, s.getRoot().getRowCount());
+ }
+ }
+ }
+
+ @Test
+ public void invalidAuthWithBearerAuthServer() throws IOException {
+ testInvalidAuth(client);
+ }
+
+ @Test
+ public void didntAuthWithBearerAuthServer() throws IOException {
+ didntAuth(client);
+ }
+
+ private void testValidAuth(FlightClient client) {
+ final CredentialCallOption bearerToken = client
+ .authenticateBasicToken(USERNAME_1, PASSWORD_1).get();
+ Assert.assertTrue(ImmutableList.copyOf(client
+ .listFlights(Criteria.ALL, bearerToken))
+ .isEmpty());
+ }
+
+ private void testValidAuthWithMultipleClientsWithSameCredentials(
+ FlightClient client1, FlightClient client2) {
+ final CredentialCallOption bearerToken1 = client1
+ .authenticateBasicToken(USERNAME_1, PASSWORD_1).get();
+ final CredentialCallOption bearerToken2 = client2
+ .authenticateBasicToken(USERNAME_1, PASSWORD_1).get();
+ Assert.assertTrue(ImmutableList.copyOf(client1
+ .listFlights(Criteria.ALL, bearerToken1))
+ .isEmpty());
+ Assert.assertTrue(ImmutableList.copyOf(client2
+ .listFlights(Criteria.ALL, bearerToken2))
+ .isEmpty());
+ }
+
+ private void testValidAuthWithMultipleClientsWithDifferentCredentials(
+ FlightClient client1, FlightClient client2) {
+ final CredentialCallOption bearerToken1 = client1
+ .authenticateBasicToken(USERNAME_1, PASSWORD_1).get();
+ final CredentialCallOption bearerToken2 = client2
+ .authenticateBasicToken(USERNAME_2, PASSWORD_2).get();
+ Assert.assertTrue(ImmutableList.copyOf(client1
+ .listFlights(Criteria.ALL, bearerToken1))
+ .isEmpty());
+ Assert.assertTrue(ImmutableList.copyOf(client2
+ .listFlights(Criteria.ALL, bearerToken2))
+ .isEmpty());
+ }
+
+ private void testInvalidAuth(FlightClient client) {
+ FlightTestUtil.assertCode(FlightStatusCode.UNAUTHENTICATED, () ->
+ client.authenticateBasicToken(USERNAME_1, "WRONG"));
+
+ FlightTestUtil.assertCode(FlightStatusCode.UNAUTHENTICATED, () ->
+ client.authenticateBasicToken(NO_USERNAME, PASSWORD_1));
+
+ FlightTestUtil.assertCode(FlightStatusCode.UNAUTHENTICATED, () ->
+ client.listFlights(Criteria.ALL).forEach(action -> Assert.fail()));
+ }
+
+ private void didntAuth(FlightClient client) {
+ FlightTestUtil.assertCode(FlightStatusCode.UNAUTHENTICATED, () ->
+ client.listFlights(Criteria.ALL).forEach(action -> Assert.fail()));
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/TestCookieHandling.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/TestCookieHandling.java
new file mode 100644
index 000000000..f205f9a3b
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/TestCookieHandling.java
@@ -0,0 +1,267 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.client;
+
+import java.io.IOException;
+
+import org.apache.arrow.flight.CallHeaders;
+import org.apache.arrow.flight.CallInfo;
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.Criteria;
+import org.apache.arrow.flight.ErrorFlightMetadata;
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.FlightMethod;
+import org.apache.arrow.flight.FlightProducer;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.FlightServerMiddleware;
+import org.apache.arrow.flight.FlightTestUtil;
+import org.apache.arrow.flight.NoOpFlightProducer;
+import org.apache.arrow.flight.RequestContext;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Ignore;
+import org.junit.Test;
+
+/**
+ * Tests for correct handling of cookies from the FlightClient using {@link ClientCookieMiddleware}.
+ */
+public class TestCookieHandling {
+ private static final String SET_COOKIE_HEADER = "Set-Cookie";
+ private static final String COOKIE_HEADER = "Cookie";
+ private BufferAllocator allocator;
+ private FlightServer server;
+ private FlightClient client;
+
+ private ClientCookieMiddlewareTestFactory testFactory = new ClientCookieMiddlewareTestFactory();
+ private ClientCookieMiddleware cookieMiddleware = new ClientCookieMiddleware(testFactory);
+
+ @Before
+ public void setup() throws Exception {
+ allocator = new RootAllocator(Long.MAX_VALUE);
+ startServerAndClient();
+ }
+
+ @After
+ public void cleanup() throws Exception {
+ testFactory = new ClientCookieMiddlewareTestFactory();
+ cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));
+ AutoCloseables.close(client, server, allocator);
+ client = null;
+ server = null;
+ allocator = null;
+ }
+
+ @Test
+ public void basicCookie() {
+ CallHeaders headersToSend = new ErrorFlightMetadata();
+ headersToSend.insert(SET_COOKIE_HEADER, "k=v");
+ cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));
+ cookieMiddleware.onHeadersReceived(headersToSend);
+ Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString());
+ }
+
+ @Test
+ public void cookieStaysAfterMultipleRequests() {
+ CallHeaders headersToSend = new ErrorFlightMetadata();
+ headersToSend.insert(SET_COOKIE_HEADER, "k=v");
+ cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));
+ cookieMiddleware.onHeadersReceived(headersToSend);
+ Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString());
+
+ headersToSend = new ErrorFlightMetadata();
+ cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));
+ cookieMiddleware.onHeadersReceived(headersToSend);
+ Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString());
+
+ headersToSend = new ErrorFlightMetadata();
+ cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));
+ cookieMiddleware.onHeadersReceived(headersToSend);
+ Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString());
+ }
+
+ @Ignore
+ @Test
+ public void cookieAutoExpires() {
+ CallHeaders headersToSend = new ErrorFlightMetadata();
+ headersToSend.insert(SET_COOKIE_HEADER, "k=v; Max-Age=2");
+ cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));
+ cookieMiddleware.onHeadersReceived(headersToSend);
+ // Note: using max-age changes cookie version from 0->1, which quotes values.
+ Assert.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString());
+
+ headersToSend = new ErrorFlightMetadata();
+ cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));
+ cookieMiddleware.onHeadersReceived(headersToSend);
+ Assert.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString());
+
+ try {
+ Thread.sleep(5000);
+ } catch (InterruptedException ignored) {
+ }
+
+ // Verify that the k cookie was discarded because it expired.
+ Assert.assertTrue(cookieMiddleware.getValidCookiesAsString().isEmpty());
+ }
+
+ @Test
+ public void cookieExplicitlyExpires() {
+ CallHeaders headersToSend = new ErrorFlightMetadata();
+ headersToSend.insert(SET_COOKIE_HEADER, "k=v; Max-Age=2");
+ cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));
+ cookieMiddleware.onHeadersReceived(headersToSend);
+ // Note: using max-age changes cookie version from 0->1, which quotes values.
+ Assert.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString());
+
+ // Note: The JDK treats Max-Age < 0 as not expired and treats 0 as expired.
+ // This violates the RFC, which states that less than zero and zero should both be expired.
+ headersToSend = new ErrorFlightMetadata();
+ headersToSend.insert(SET_COOKIE_HEADER, "k=v; Max-Age=0");
+ cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));
+ cookieMiddleware.onHeadersReceived(headersToSend);
+
+ // Verify that the k cookie was discarded because the server told the client it is expired.
+ Assert.assertTrue(cookieMiddleware.getValidCookiesAsString().isEmpty());
+ }
+
+ @Ignore
+ @Test
+ public void cookieExplicitlyExpiresWithMaxAgeMinusOne() {
+ CallHeaders headersToSend = new ErrorFlightMetadata();
+ headersToSend.insert(SET_COOKIE_HEADER, "k=v; Max-Age=2");
+ cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));
+ cookieMiddleware.onHeadersReceived(headersToSend);
+ // Note: using max-age changes cookie version from 0->1, which quotes values.
+ Assert.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString());
+
+ headersToSend = new ErrorFlightMetadata();
+
+ // The Java HttpCookie class has a bug where it uses a -1 maxAge to indicate
+ // a persistent cookie, when the RFC spec says this should mean the cookie expires immediately.
+ headersToSend.insert(SET_COOKIE_HEADER, "k=v; Max-Age=-1");
+ cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION));
+ cookieMiddleware.onHeadersReceived(headersToSend);
+
+ // Verify that the k cookie was discarded because the server told the client it is expired.
+ Assert.assertTrue(cookieMiddleware.getValidCookiesAsString().isEmpty());
+ }
+
+ @Test
+ public void changeCookieValue() {
+ CallHeaders headersToSend = new ErrorFlightMetadata();
+ headersToSend.insert(SET_COOKIE_HEADER, "k=v");
+ cookieMiddleware.onHeadersReceived(headersToSend);
+ Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString());
+
+ headersToSend = new ErrorFlightMetadata();
+ headersToSend.insert(SET_COOKIE_HEADER, "k=v2");
+ cookieMiddleware.onHeadersReceived(headersToSend);
+ Assert.assertEquals("k=v2", cookieMiddleware.getValidCookiesAsString());
+ }
+
+ @Test
+ public void multipleCookiesWithSetCookie() {
+ CallHeaders headersToSend = new ErrorFlightMetadata();
+ headersToSend.insert(SET_COOKIE_HEADER, "firstKey=firstVal");
+ headersToSend.insert(SET_COOKIE_HEADER, "secondKey=secondVal");
+ cookieMiddleware.onHeadersReceived(headersToSend);
+ Assert.assertEquals("firstKey=firstVal; secondKey=secondVal", cookieMiddleware.getValidCookiesAsString());
+ }
+
+ @Test
+ public void cookieStaysAfterMultipleRequestsEndToEnd() {
+ client.handshake();
+ Assert.assertEquals("k=v", testFactory.clientCookieMiddleware.getValidCookiesAsString());
+ client.handshake();
+ Assert.assertEquals("k=v", testFactory.clientCookieMiddleware.getValidCookiesAsString());
+ client.listFlights(Criteria.ALL);
+ Assert.assertEquals("k=v", testFactory.clientCookieMiddleware.getValidCookiesAsString());
+ }
+
+ /**
+ * A server middleware component that injects SET_COOKIE_HEADER into the outgoing headers.
+ */
+ static class SetCookieHeaderInjector implements FlightServerMiddleware {
+ private final Factory factory;
+
+ public SetCookieHeaderInjector(Factory factory) {
+ this.factory = factory;
+ }
+
+ @Override
+ public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
+ if (!factory.receivedCookieHeader) {
+ outgoingHeaders.insert(SET_COOKIE_HEADER, "k=v");
+ }
+ }
+
+ @Override
+ public void onCallCompleted(CallStatus status) {
+
+ }
+
+ @Override
+ public void onCallErrored(Throwable err) {
+
+ }
+
+ static class Factory implements FlightServerMiddleware.Factory<SetCookieHeaderInjector> {
+ private boolean receivedCookieHeader = false;
+
+ @Override
+ public SetCookieHeaderInjector onCallStarted(CallInfo info, CallHeaders incomingHeaders,
+ RequestContext context) {
+ receivedCookieHeader = null != incomingHeaders.get(COOKIE_HEADER);
+ return new SetCookieHeaderInjector(this);
+ }
+ }
+ }
+
+ public static class ClientCookieMiddlewareTestFactory extends ClientCookieMiddleware.Factory {
+
+ private ClientCookieMiddleware clientCookieMiddleware;
+
+ @Override
+ public ClientCookieMiddleware onCallStarted(CallInfo info) {
+ this.clientCookieMiddleware = new ClientCookieMiddleware(this);
+ return this.clientCookieMiddleware;
+ }
+ }
+
+ private void startServerAndClient() throws IOException {
+ final FlightProducer flightProducer = new NoOpFlightProducer() {
+ public void listFlights(CallContext context, Criteria criteria,
+ StreamListener<FlightInfo> listener) {
+ listener.onCompleted();
+ }
+ };
+
+ this.server = FlightTestUtil.getStartedServer((location) -> FlightServer
+ .builder(allocator, location, flightProducer)
+ .middleware(FlightServerMiddleware.Key.of("test"), new SetCookieHeaderInjector.Factory())
+ .build());
+
+ this.client = FlightClient.builder(allocator, server.getLocation())
+ .intercept(testFactory)
+ .build();
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java
new file mode 100644
index 000000000..fb157f45e
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.example;
+
+import java.io.IOException;
+
+import org.apache.arrow.flight.AsyncPutListener;
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightClient.ClientStreamListener;
+import org.apache.arrow.flight.FlightDescriptor;
+import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.flight.FlightTestUtil;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Ignore;
+import org.junit.Test;
+
+/**
+ * Ensure that example server supports get and put.
+ */
+public class TestExampleServer {
+
+ private BufferAllocator allocator;
+ private BufferAllocator caseAllocator;
+ private ExampleFlightServer server;
+ private FlightClient client;
+
+ @Before
+ public void start() throws IOException {
+ allocator = new RootAllocator(Long.MAX_VALUE);
+
+ Location l = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, 12233);
+ if (!Boolean.getBoolean("disableServer")) {
+ System.out.println("Starting server.");
+ server = new ExampleFlightServer(allocator, l);
+ server.start();
+ } else {
+ System.out.println("Skipping server startup.");
+ }
+ client = FlightClient.builder(allocator, l).build();
+ caseAllocator = allocator.newChildAllocator("test-case", 0, Long.MAX_VALUE);
+ }
+
+ @After
+ public void after() throws Exception {
+ AutoCloseables.close(server, client, caseAllocator, allocator);
+ }
+
+ @Test
+ @Ignore
+ public void putStream() {
+ BufferAllocator a = caseAllocator;
+ final int size = 10;
+
+ IntVector iv = new IntVector("c1", a);
+
+ VectorSchemaRoot root = VectorSchemaRoot.of(iv);
+ ClientStreamListener listener = client.startPut(FlightDescriptor.path("hello"), root,
+ new AsyncPutListener());
+
+ //batch 1
+ root.allocateNew();
+ for (int i = 0; i < size; i++) {
+ iv.set(i, i);
+ }
+ iv.setValueCount(size);
+ root.setRowCount(size);
+ listener.putNext();
+
+ // batch 2
+
+ root.allocateNew();
+ for (int i = 0; i < size; i++) {
+ iv.set(i, i + size);
+ }
+ iv.setValueCount(size);
+ root.setRowCount(size);
+ listener.putNext();
+ root.clear();
+ listener.completed();
+
+ // wait for ack to avoid memory leaks.
+ listener.getResult();
+
+ FlightInfo info = client.getInfo(FlightDescriptor.path("hello"));
+ try (final FlightStream stream = client.getStream(info.getEndpoints().get(0).getTicket())) {
+ VectorSchemaRoot newRoot = stream.getRoot();
+ while (stream.next()) {
+ newRoot.clear();
+ }
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/grpc/TestStatusUtils.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/grpc/TestStatusUtils.java
new file mode 100644
index 000000000..5d76e8ae1
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/grpc/TestStatusUtils.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.grpc;
+
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.FlightStatusCode;
+import org.junit.Assert;
+import org.junit.Test;
+
+import io.grpc.Metadata;
+import io.grpc.Status;
+
+public class TestStatusUtils {
+
+ @Test
+ public void testParseTrailers() {
+ Status status = Status.CANCELLED;
+ Metadata trailers = new Metadata();
+
+ // gRPC can have trailers with certain metadata keys beginning with ":", such as ":status".
+ // See https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
+ trailers.put(StatusUtils.keyOfAscii(":status"), "502");
+ trailers.put(StatusUtils.keyOfAscii("date"), "Fri, 13 Sep 2015 11:23:58 GMT");
+ trailers.put(StatusUtils.keyOfAscii("content-type"), "text/html");
+
+ CallStatus callStatus = StatusUtils.fromGrpcStatusAndTrailers(status, trailers);
+
+ Assert.assertEquals(FlightStatusCode.CANCELLED, callStatus.code());
+ Assert.assertTrue(callStatus.metadata().containsKey(":status"));
+ Assert.assertEquals("502", callStatus.metadata().get(":status"));
+ Assert.assertTrue(callStatus.metadata().containsKey("date"));
+ Assert.assertEquals("Fri, 13 Sep 2015 11:23:58 GMT", callStatus.metadata().get("date"));
+ Assert.assertTrue(callStatus.metadata().containsKey("content-type"));
+ Assert.assertEquals("text/html", callStatus.metadata().get("content-type"));
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java
new file mode 100644
index 000000000..7794ed748
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.perf;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+import org.apache.arrow.flight.BackpressureStrategy;
+import org.apache.arrow.flight.FlightDescriptor;
+import org.apache.arrow.flight.FlightEndpoint;
+import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.FlightProducer;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.flight.NoOpFlightProducer;
+import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.flight.perf.impl.PerfOuterClass.Perf;
+import org.apache.arrow.flight.perf.impl.PerfOuterClass.Token;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.util.Preconditions;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.Types.MinorType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+
+import com.google.common.collect.ImmutableList;
+import com.google.protobuf.InvalidProtocolBufferException;
+
+public class PerformanceTestServer implements AutoCloseable {
+
+ private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(PerformanceTestServer.class);
+
+ private final FlightServer flightServer;
+ private final Location location;
+ private final BufferAllocator allocator;
+ private final PerfProducer producer;
+ private final boolean isNonBlocking;
+
+ public PerformanceTestServer(BufferAllocator incomingAllocator, Location location) {
+ this(incomingAllocator, location, new BackpressureStrategy() {
+ private FlightProducer.ServerStreamListener listener;
+
+ @Override
+ public void register(FlightProducer.ServerStreamListener listener) {
+ this.listener = listener;
+ }
+
+ @Override
+ public WaitResult waitForListener(long timeout) {
+ while (!listener.isReady() && !listener.isCancelled()) {
+ // busy wait
+ }
+ return WaitResult.READY;
+ }
+ }, false);
+ }
+
+ public PerformanceTestServer(BufferAllocator incomingAllocator, Location location, BackpressureStrategy bpStrategy,
+ boolean isNonBlocking) {
+ this.allocator = incomingAllocator.newChildAllocator("perf-server", 0, Long.MAX_VALUE);
+ this.location = location;
+ this.producer = new PerfProducer(bpStrategy);
+ this.flightServer = FlightServer.builder(this.allocator, location, producer).build();
+ this.isNonBlocking = isNonBlocking;
+ }
+
+ public Location getLocation() {
+ return location;
+ }
+
+ public void start() throws IOException {
+ flightServer.start();
+ }
+
+ @Override
+ public void close() throws Exception {
+ AutoCloseables.close(flightServer, allocator);
+ }
+
+ private final class PerfProducer extends NoOpFlightProducer {
+ private final BackpressureStrategy bpStrategy;
+
+ private PerfProducer(BackpressureStrategy bpStrategy) {
+ this.bpStrategy = bpStrategy;
+ }
+
+ @Override
+ public void getStream(CallContext context, Ticket ticket,
+ ServerStreamListener listener) {
+ bpStrategy.register(listener);
+ final Runnable loadData = () -> {
+ VectorSchemaRoot root = null;
+ try {
+ Token token = Token.parseFrom(ticket.getBytes());
+ Perf perf = token.getDefinition();
+ Schema schema = Schema.deserialize(ByteBuffer.wrap(perf.getSchema().toByteArray()));
+ root = VectorSchemaRoot.create(schema, allocator);
+ BigIntVector a = (BigIntVector) root.getVector("a");
+ BigIntVector b = (BigIntVector) root.getVector("b");
+ BigIntVector c = (BigIntVector) root.getVector("c");
+ BigIntVector d = (BigIntVector) root.getVector("d");
+ listener.setUseZeroCopy(true);
+ listener.start(root);
+ root.allocateNew();
+
+ int current = 0;
+ long i = token.getStart();
+ while (i < token.getEnd()) {
+ if (listener.isCancelled()) {
+ root.clear();
+ return;
+ }
+
+ if (TestPerf.VALIDATE) {
+ a.setSafe(current, i);
+ }
+
+ i++;
+ current++;
+ if (i % perf.getRecordsPerBatch() == 0) {
+ root.setRowCount(current);
+
+ bpStrategy.waitForListener(0);
+ if (listener.isCancelled()) {
+ root.clear();
+ return;
+ }
+ listener.putNext();
+ current = 0;
+ root.allocateNew();
+ }
+ }
+
+ // send last partial batch.
+ if (current != 0) {
+ root.setRowCount(current);
+ listener.putNext();
+ }
+ listener.completed();
+ } catch (InvalidProtocolBufferException e) {
+ throw new RuntimeException(e);
+ } finally {
+ try {
+ AutoCloseables.close(root);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+ };
+
+ if (!isNonBlocking) {
+ loadData.run();
+ } else {
+ final ExecutorService service = Executors.newSingleThreadExecutor();
+ service.submit(loadData);
+ service.shutdown();
+ }
+ }
+
+ @Override
+ public FlightInfo getFlightInfo(CallContext context,
+ FlightDescriptor descriptor) {
+ try {
+ Preconditions.checkArgument(descriptor.isCommand());
+ Perf exec = Perf.parseFrom(descriptor.getCommand());
+
+ final Schema pojoSchema = new Schema(ImmutableList.of(
+ Field.nullable("a", MinorType.BIGINT.getType()),
+ Field.nullable("b", MinorType.BIGINT.getType()),
+ Field.nullable("c", MinorType.BIGINT.getType()),
+ Field.nullable("d", MinorType.BIGINT.getType())
+ ));
+
+ Token token = Token.newBuilder().setDefinition(exec)
+ .setStart(0)
+ .setEnd(exec.getRecordsPerStream())
+ .build();
+ final Ticket ticket = new Ticket(token.toByteArray());
+
+ List<FlightEndpoint> endpoints = new ArrayList<>();
+ for (int i = 0; i < exec.getStreamCount(); i++) {
+ endpoints.add(new FlightEndpoint(ticket, getLocation()));
+ }
+
+ return new FlightInfo(pojoSchema, descriptor, endpoints, -1,
+ exec.getRecordsPerStream() * exec.getStreamCount());
+ } catch (InvalidProtocolBufferException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ }
+}
+
+
+
diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/TestPerf.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/TestPerf.java
new file mode 100644
index 000000000..9e2d7cc54
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/TestPerf.java
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight.perf;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightDescriptor;
+import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.flight.FlightTestUtil;
+import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.flight.perf.impl.PerfOuterClass.Perf;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.Types.MinorType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.junit.Test;
+
+import com.google.common.base.MoreObjects;
+import com.google.common.base.Stopwatch;
+import com.google.common.collect.ImmutableList;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.ListeningExecutorService;
+import com.google.common.util.concurrent.MoreExecutors;
+import com.google.protobuf.ByteString;
+
+@org.junit.Ignore
+public class TestPerf {
+
+ public static final boolean VALIDATE = false;
+
+ public static FlightDescriptor getPerfFlightDescriptor(long recordCount, int recordsPerBatch, int streamCount) {
+ final Schema pojoSchema = new Schema(ImmutableList.of(
+ Field.nullable("a", MinorType.BIGINT.getType()),
+ Field.nullable("b", MinorType.BIGINT.getType()),
+ Field.nullable("c", MinorType.BIGINT.getType()),
+ Field.nullable("d", MinorType.BIGINT.getType())
+ ));
+
+ ByteString serializedSchema = ByteString.copyFrom(pojoSchema.toByteArray());
+
+ return FlightDescriptor.command(Perf.newBuilder()
+ .setRecordsPerStream(recordCount)
+ .setRecordsPerBatch(recordsPerBatch)
+ .setSchema(serializedSchema)
+ .setStreamCount(streamCount)
+ .build()
+ .toByteArray());
+ }
+
+ public static void main(String[] args) throws Exception {
+ new TestPerf().throughput();
+ }
+
+ @Test
+ public void throughput() throws Exception {
+ final int numRuns = 10;
+ ListeningExecutorService pool = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(4));
+ double [] throughPuts = new double[numRuns];
+
+ for (int i = 0; i < numRuns; i++) {
+ try (
+ final BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
+ final PerformanceTestServer server =
+ FlightTestUtil.getStartedServer((location) -> new PerformanceTestServer(a, location));
+ final FlightClient client = FlightClient.builder(a, server.getLocation()).build();
+ ) {
+ final FlightInfo info = client.getInfo(getPerfFlightDescriptor(50_000_000L, 4095, 2));
+ List<ListenableFuture<Result>> results = info.getEndpoints()
+ .stream()
+ .map(t -> new Consumer(client, t.getTicket()))
+ .map(t -> pool.submit(t))
+ .collect(Collectors.toList());
+
+ final Result r = Futures.whenAllSucceed(results).call(() -> {
+ Result res = new Result();
+ for (ListenableFuture<Result> f : results) {
+ res.add(f.get());
+ }
+ return res;
+ }, pool).get();
+
+ double seconds = r.nanos * 1.0d / 1000 / 1000 / 1000;
+ throughPuts[i] = (r.bytes * 1.0d / 1024 / 1024) / seconds;
+ System.out.println(String.format(
+ "Transferred %d records totaling %s bytes at %f MiB/s. %f record/s. %f batch/s.",
+ r.rows,
+ r.bytes,
+ throughPuts[i],
+ (r.rows * 1.0d) / seconds,
+ (r.batches * 1.0d) / seconds
+ ));
+ }
+ }
+ pool.shutdown();
+
+ System.out.println("Summary: ");
+ double average = Arrays.stream(throughPuts).sum() / numRuns;
+ double sqrSum = Arrays.stream(throughPuts).map(val -> val - average).map(val -> val * val).sum();
+ double stddev = Math.sqrt(sqrSum / numRuns);
+ System.out.println(String.format("Average throughput: %f MiB/s, standard deviation: %f MiB/s",
+ average, stddev));
+ }
+
+ private final class Consumer implements Callable<Result> {
+
+ private final FlightClient client;
+ private final Ticket ticket;
+
+ public Consumer(FlightClient client, Ticket ticket) {
+ super();
+ this.client = client;
+ this.ticket = ticket;
+ }
+
+ @Override
+ public Result call() throws Exception {
+ final Result r = new Result();
+ Stopwatch watch = Stopwatch.createStarted();
+ try (final FlightStream stream = client.getStream(ticket)) {
+ final VectorSchemaRoot root = stream.getRoot();
+ try {
+ BigIntVector a = (BigIntVector) root.getVector("a");
+ while (stream.next()) {
+ int rows = root.getRowCount();
+ long aSum = r.aSum;
+ for (int i = 0; i < rows; i++) {
+ if (VALIDATE) {
+ aSum += a.get(i);
+ }
+ }
+ r.bytes += rows * 32;
+ r.rows += rows;
+ r.aSum = aSum;
+ r.batches++;
+ }
+
+ r.nanos = watch.elapsed(TimeUnit.NANOSECONDS);
+ return r;
+ } finally {
+ root.clear();
+ }
+ }
+ }
+
+ }
+
+ private final class Result {
+ private long rows;
+ private long aSum;
+ private long bytes;
+ private long nanos;
+ private long batches;
+
+ public void add(Result r) {
+ rows += r.rows;
+ aSum += r.aSum;
+ bytes += r.bytes;
+ batches += r.batches;
+ nanos = Math.max(nanos, r.nanos);
+ }
+
+ @Override
+ public String toString() {
+ return MoreObjects.toStringHelper(this)
+ .add("rows", rows)
+ .add("aSum", aSum)
+ .add("batches", batches)
+ .add("bytes", bytes)
+ .add("nanos", nanos)
+ .toString();
+ }
+ }
+}
diff --git a/src/arrow/java/flight/flight-core/src/test/protobuf/perf.proto b/src/arrow/java/flight/flight-core/src/test/protobuf/perf.proto
new file mode 100644
index 000000000..99f35a9e6
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/protobuf/perf.proto
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ * <p>
+ * http://www.apache.org/licenses/LICENSE-2.0
+ * <p>
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+
+option java_package = "org.apache.arrow.flight.perf.impl";
+
+message Perf {
+ bytes schema = 1;
+ int32 stream_count = 2;
+ int64 records_per_stream = 3;
+ int32 records_per_batch = 4;
+}
+
+/*
+ * Payload of ticket
+ */
+message Token {
+
+ // definition of entire flight.
+ Perf definition = 1;
+
+ // inclusive start
+ int64 start = 2;
+
+ // exclusive end
+ int64 end = 3;
+
+}
+
diff --git a/src/arrow/java/flight/flight-core/src/test/resources/logback.xml b/src/arrow/java/flight/flight-core/src/test/resources/logback.xml
new file mode 100644
index 000000000..444b2ed6d
--- /dev/null
+++ b/src/arrow/java/flight/flight-core/src/test/resources/logback.xml
@@ -0,0 +1,28 @@
+<?xml version="1.0" encoding="UTF-8" ?>
+<!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor
+ license agreements. See the NOTICE file distributed with this work for additional
+ information regarding copyright ownership. The ASF licenses this file to
+ You under the Apache License, Version 2.0 (the "License"); you may not use
+ this file except in compliance with the License. You may obtain a copy of
+ the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required
+ by applicable law or agreed to in writing, software distributed under the
+ License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS
+ OF ANY KIND, either express or implied. See the License for the specific
+ language governing permissions and limitations under the License. -->
+
+<configuration>
+ <statusListener class="ch.qos.logback.core.status.NopStatusListener"/>
+ <appender name="SOCKET"
+ class="de.huxhorn.lilith.logback.appender.ClassicMultiplexSocketAppender">
+ <Compressing>true</Compressing>
+ <ReconnectionDelay>10000</ReconnectionDelay>
+ <IncludeCallerData>true</IncludeCallerData>
+ <RemoteHosts>${LILITH_HOSTNAME:-localhost}</RemoteHosts>
+ </appender>
+
+ <logger name="org.apache.arrow" additivity="false">
+ <level value="info" />
+ <appender-ref ref="FILE" />
+ </logger>
+
+</configuration>
diff --git a/src/arrow/java/flight/flight-grpc/pom.xml b/src/arrow/java/flight/flight-grpc/pom.xml
new file mode 100644
index 000000000..1968484a1
--- /dev/null
+++ b/src/arrow/java/flight/flight-grpc/pom.xml
@@ -0,0 +1,132 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor
+ license agreements. See the NOTICE file distributed with this work for additional
+ information regarding copyright ownership. The ASF licenses this file to
+ You under the Apache License, Version 2.0 (the "License"); you may not use
+ this file except in compliance with the License. You may obtain a copy of
+ the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required
+ by applicable law or agreed to in writing, software distributed under the
+ License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS
+ OF ANY KIND, either express or implied. See the License for the specific
+ language governing permissions and limitations under the License. -->
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+ <parent>
+ <artifactId>arrow-java-root</artifactId>
+ <groupId>org.apache.arrow</groupId>
+ <version>6.0.1</version>
+ <relativePath>../../pom.xml</relativePath>
+ </parent>
+ <modelVersion>4.0.0</modelVersion>
+
+ <artifactId>flight-grpc</artifactId>
+ <name>Arrow Flight GRPC</name>
+ <description>(Experimental)Contains utility class to expose Flight gRPC service and client</description>
+ <packaging>jar</packaging>
+
+ <properties>
+ <dep.grpc.version>1.41.0</dep.grpc.version>
+ <dep.protobuf.version>3.7.1</dep.protobuf.version>
+ <forkCount>1</forkCount>
+ </properties>
+
+ <dependencies>
+ <dependency>
+ <groupId>org.apache.arrow</groupId>
+ <artifactId>flight-core</artifactId>
+ <version>${project.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>io.netty</groupId>
+ <artifactId>netty-transport-native-unix-common</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>io.netty</groupId>
+ <artifactId>netty-transport-native-kqueue</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>io.netty</groupId>
+ <artifactId>netty-transport-native-epoll</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
+ <groupId>io.grpc</groupId>
+ <artifactId>grpc-core</artifactId>
+ <version>${dep.grpc.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>io.grpc</groupId>
+ <artifactId>grpc-stub</artifactId>
+ <version>${dep.grpc.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.arrow</groupId>
+ <artifactId>arrow-memory-core</artifactId>
+ <version>${project.version}</version>
+ <scope>compile</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.arrow</groupId>
+ <artifactId>arrow-memory-netty</artifactId>
+ <version>${project.version}</version>
+ <scope>runtime</scope>
+ </dependency>
+ <dependency>
+ <groupId>io.grpc</groupId>
+ <artifactId>grpc-protobuf</artifactId>
+ <version>${dep.grpc.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ </dependency>
+ <dependency>
+ <groupId>com.google.protobuf</groupId>
+ <artifactId>protobuf-java</artifactId>
+ <version>${dep.protobuf.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>io.grpc</groupId>
+ <artifactId>grpc-api</artifactId>
+ <version>${dep.grpc.version}</version>
+ </dependency>
+ </dependencies>
+
+ <build>
+ <extensions>
+ <!-- provides os.detected.classifier (i.e. linux-x86_64, osx-x86_64) property -->
+ <extension>
+ <groupId>kr.motd.maven</groupId>
+ <artifactId>os-maven-plugin</artifactId>
+ <version>1.5.0.Final</version>
+ </extension>
+ </extensions>
+ <plugins>
+ <plugin>
+ <groupId>org.xolstice.maven.plugins</groupId>
+ <artifactId>protobuf-maven-plugin</artifactId>
+ <version>0.5.0</version>
+ <configuration>
+ <protocArtifact>com.google.protobuf:protoc:${dep.protobuf.version}:exe:${os.detected.classifier}</protocArtifact>
+ <clearOutputDirectory>false</clearOutputDirectory>
+ <pluginId>grpc-java</pluginId>
+ <pluginArtifact>io.grpc:protoc-gen-grpc-java:${dep.grpc.version}:exe:${os.detected.classifier}</pluginArtifact>
+ </configuration>
+ <executions>
+ <execution>
+ <id>test</id>
+ <configuration>
+ <protoSourceRoot>${basedir}/src/test/protobuf</protoSourceRoot>
+ <outputDirectory>${project.build.directory}/generated-test-sources//protobuf</outputDirectory>
+ </configuration>
+ <goals>
+ <goal>compile</goal>
+ <goal>compile-custom</goal>
+ </goals>
+ </execution>
+ </executions>
+ </plugin>
+ </plugins>
+ </build>
+
+</project>
diff --git a/src/arrow/java/flight/flight-grpc/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java b/src/arrow/java/flight/flight-grpc/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java
new file mode 100644
index 000000000..eb5e492b4
--- /dev/null
+++ b/src/arrow/java/flight/flight-grpc/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import java.util.Collections;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.arrow.flight.auth.ServerAuthHandler;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.util.VisibleForTesting;
+
+import io.grpc.BindableService;
+import io.grpc.CallOptions;
+import io.grpc.ClientCall;
+import io.grpc.ConnectivityState;
+import io.grpc.ManagedChannel;
+import io.grpc.MethodDescriptor;
+
+/**
+ * Exposes Flight GRPC service & client.
+ */
+public class FlightGrpcUtils {
+ /**
+ * Proxy class for ManagedChannel that makes closure a no-op.
+ */
+ @VisibleForTesting
+ static class NonClosingProxyManagedChannel extends ManagedChannel {
+ private final ManagedChannel channel;
+ private boolean isShutdown;
+
+ NonClosingProxyManagedChannel(ManagedChannel channel) {
+ this.channel = channel;
+ this.isShutdown = channel.isShutdown();
+ }
+
+ @Override
+ public ManagedChannel shutdown() {
+ isShutdown = true;
+ return this;
+ }
+
+ @Override
+ public boolean isShutdown() {
+ if (this.channel.isShutdown()) {
+ // If the underlying channel is shut down, ensure we're updated to match.
+ shutdown();
+ }
+ return isShutdown;
+ }
+
+ @Override
+ public boolean isTerminated() {
+ return this.isShutdown();
+ }
+
+ @Override
+ public ManagedChannel shutdownNow() {
+ return shutdown();
+ }
+
+ @Override
+ public boolean awaitTermination(long l, TimeUnit timeUnit) {
+ // Don't actually await termination, since it'll be a no-op, so simply return whether or not
+ // the channel has been shut down already.
+ return this.isShutdown();
+ }
+
+ @Override
+ public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(
+ MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) {
+ if (this.isShutdown()) {
+ throw new IllegalStateException("Channel has been shut down.");
+ }
+
+ return this.channel.newCall(methodDescriptor, callOptions);
+ }
+
+ @Override
+ public String authority() {
+ return this.channel.authority();
+ }
+
+ @Override
+ public ConnectivityState getState(boolean requestConnection) {
+ if (this.isShutdown()) {
+ return ConnectivityState.SHUTDOWN;
+ }
+
+ return this.channel.getState(requestConnection);
+ }
+
+ @Override
+ public void notifyWhenStateChanged(ConnectivityState source, Runnable callback) {
+ // The proxy has no insight into the underlying channel state changes, so we'll have to leak the abstraction
+ // a bit here and simply pass to the underlying channel, even though it will never transition to shutdown via
+ // the proxy. This should be fine, since it's mainly targeted at the FlightClient and there's no getter for
+ // the channel.
+ this.channel.notifyWhenStateChanged(source, callback);
+ }
+
+ @Override
+ public void resetConnectBackoff() {
+ this.channel.resetConnectBackoff();
+ }
+
+ @Override
+ public void enterIdle() {
+ this.channel.enterIdle();
+ }
+ }
+
+ private FlightGrpcUtils() {}
+
+ /**
+ * Creates a Flight service.
+ * @param allocator Memory allocator
+ * @param producer Specifies the service api
+ * @param authHandler Authentication handler
+ * @param executor Executor service
+ * @return FlightBindingService
+ */
+ public static BindableService createFlightService(BufferAllocator allocator, FlightProducer producer,
+ ServerAuthHandler authHandler, ExecutorService executor) {
+ return new FlightBindingService(allocator, producer, authHandler, executor);
+ }
+
+ /**
+ * Creates a Flight client.
+ * @param incomingAllocator Memory allocator
+ * @param channel provides a connection to a gRPC server.
+ */
+ public static FlightClient createFlightClient(BufferAllocator incomingAllocator, ManagedChannel channel) {
+ return new FlightClient(incomingAllocator, channel, Collections.emptyList());
+ }
+
+ /**
+ * Creates a Flight client.
+ * @param incomingAllocator Memory allocator
+ * @param channel provides a connection to a gRPC server. Will not be closed on closure of the returned FlightClient.
+ */
+ public static FlightClient createFlightClientWithSharedChannel(
+ BufferAllocator incomingAllocator, ManagedChannel channel) {
+ return new FlightClient(incomingAllocator, new NonClosingProxyManagedChannel(channel), Collections.emptyList());
+ }
+}
diff --git a/src/arrow/java/flight/flight-grpc/src/test/java/org/apache/arrow/flight/TestFlightGrpcUtils.java b/src/arrow/java/flight/flight-grpc/src/test/java/org/apache/arrow/flight/TestFlightGrpcUtils.java
new file mode 100644
index 000000000..142a0f937
--- /dev/null
+++ b/src/arrow/java/flight/flight-grpc/src/test/java/org/apache/arrow/flight/TestFlightGrpcUtils.java
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.flight;
+
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+import java.io.IOException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+import org.apache.arrow.flight.auth.ServerAuthHandler;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import com.google.protobuf.Empty;
+
+import io.grpc.BindableService;
+import io.grpc.ConnectivityState;
+import io.grpc.ManagedChannel;
+import io.grpc.Server;
+import io.grpc.inprocess.InProcessChannelBuilder;
+import io.grpc.inprocess.InProcessServerBuilder;
+import io.grpc.stub.StreamObserver;
+
+/**
+ * Unit test which adds 2 services to same server end point.
+ */
+public class TestFlightGrpcUtils {
+ private Server server;
+ private BufferAllocator allocator;
+ private String serverName;
+
+ @Before
+ public void setup() throws IOException {
+ //Defines flight service
+ allocator = new RootAllocator(Integer.MAX_VALUE);
+ final NoOpFlightProducer producer = new NoOpFlightProducer();
+ final ServerAuthHandler authHandler = ServerAuthHandler.NO_OP;
+ final ExecutorService exec = Executors.newCachedThreadPool();
+ final BindableService flightBindingService = FlightGrpcUtils.createFlightService(allocator, producer,
+ authHandler, exec);
+
+ //initializes server with 2 services - FlightBindingService & TestService
+ serverName = InProcessServerBuilder.generateName();
+ server = InProcessServerBuilder.forName(serverName)
+ .directExecutor()
+ .addService(flightBindingService)
+ .addService(new TestServiceAdapter())
+ .build();
+ server.start();
+ }
+
+ @After
+ public void cleanup() {
+ server.shutdownNow();
+ }
+
+ /**
+ * This test checks if multiple gRPC services can be added to the same
+ * server endpoint and if they can be used by different clients via the same channel.
+ * @throws IOException If server fails to start.
+ */
+ @Test
+ public void testMultipleGrpcServices() throws IOException {
+ //Initializes channel so that multiple clients can communicate with server
+ final ManagedChannel managedChannel = InProcessChannelBuilder.forName(serverName)
+ .directExecutor()
+ .build();
+
+ //Defines flight client and calls service method. Since we use a NoOpFlightProducer we expect the service
+ //to throw a RunTimeException
+ final FlightClient flightClient = FlightGrpcUtils.createFlightClient(allocator, managedChannel);
+ final Iterable<ActionType> actionTypes = flightClient.listActions();
+ assertThrows(FlightRuntimeException.class, () -> actionTypes.forEach(
+ actionType -> System.out.println(actionType.toString())));
+
+ //Define Test client as a blocking stub and call test method which correctly returns an empty protobuf object
+ final TestServiceGrpc.TestServiceBlockingStub blockingStub = TestServiceGrpc.newBlockingStub(managedChannel);
+ Assert.assertEquals(Empty.newBuilder().build(), blockingStub.test(Empty.newBuilder().build()));
+ }
+
+ @Test
+ public void testShutdown() throws IOException, InterruptedException {
+ //Initializes channel so that multiple clients can communicate with server
+ final ManagedChannel managedChannel = InProcessChannelBuilder.forName(serverName)
+ .directExecutor()
+ .build();
+
+ //Defines flight client and calls service method. Since we use a NoOpFlightProducer we expect the service
+ //to throw a RunTimeException
+ final FlightClient flightClient = FlightGrpcUtils.createFlightClientWithSharedChannel(allocator, managedChannel);
+
+ // Should be a no-op.
+ flightClient.close();
+ Assert.assertFalse(managedChannel.isShutdown());
+ Assert.assertFalse(managedChannel.isTerminated());
+ Assert.assertEquals(ConnectivityState.IDLE, managedChannel.getState(false));
+ managedChannel.shutdownNow();
+ }
+
+ @Test
+ public void testProxyChannel() throws IOException, InterruptedException {
+ //Initializes channel so that multiple clients can communicate with server
+ final ManagedChannel managedChannel = InProcessChannelBuilder.forName(serverName)
+ .directExecutor()
+ .build();
+
+ final FlightGrpcUtils.NonClosingProxyManagedChannel proxyChannel =
+ new FlightGrpcUtils.NonClosingProxyManagedChannel(managedChannel);
+ Assert.assertFalse(proxyChannel.isShutdown());
+ Assert.assertFalse(proxyChannel.isTerminated());
+ proxyChannel.shutdown();
+ Assert.assertTrue(proxyChannel.isShutdown());
+ Assert.assertTrue(proxyChannel.isTerminated());
+ Assert.assertEquals(ConnectivityState.SHUTDOWN, proxyChannel.getState(false));
+ try {
+ proxyChannel.newCall(null, null);
+ Assert.fail();
+ } catch (IllegalStateException e) {
+ // This is expected, since the proxy channel is shut down.
+ }
+
+ Assert.assertFalse(managedChannel.isShutdown());
+ Assert.assertFalse(managedChannel.isTerminated());
+ Assert.assertEquals(ConnectivityState.IDLE, managedChannel.getState(false));
+
+ managedChannel.shutdownNow();
+ }
+
+ @Test
+ public void testProxyChannelWithClosedChannel() throws IOException, InterruptedException {
+ //Initializes channel so that multiple clients can communicate with server
+ final ManagedChannel managedChannel = InProcessChannelBuilder.forName(serverName)
+ .directExecutor()
+ .build();
+
+ final FlightGrpcUtils.NonClosingProxyManagedChannel proxyChannel =
+ new FlightGrpcUtils.NonClosingProxyManagedChannel(managedChannel);
+ Assert.assertFalse(proxyChannel.isShutdown());
+ Assert.assertFalse(proxyChannel.isTerminated());
+ managedChannel.shutdownNow();
+ Assert.assertTrue(proxyChannel.isShutdown());
+ Assert.assertTrue(proxyChannel.isTerminated());
+ Assert.assertEquals(ConnectivityState.SHUTDOWN, proxyChannel.getState(false));
+ try {
+ proxyChannel.newCall(null, null);
+ Assert.fail();
+ } catch (IllegalStateException e) {
+ // This is expected, since the proxy channel is shut down.
+ }
+
+ Assert.assertTrue(managedChannel.isShutdown());
+ Assert.assertTrue(managedChannel.isTerminated());
+ Assert.assertEquals(ConnectivityState.SHUTDOWN, managedChannel.getState(false));
+ }
+
+ /**
+ * Private class used for testing purposes that overrides service behavior.
+ */
+ private class TestServiceAdapter extends TestServiceGrpc.TestServiceImplBase {
+
+ /**
+ * gRPC service that receives an empty object & returns and empty protobuf object.
+ * @param request google.protobuf.Empty
+ * @param responseObserver google.protobuf.Empty
+ */
+ @Override
+ public void test(Empty request, StreamObserver<Empty> responseObserver) {
+ responseObserver.onNext(Empty.newBuilder().build());
+ responseObserver.onCompleted();
+ }
+ }
+}
+
diff --git a/src/arrow/java/flight/flight-grpc/src/test/protobuf/test.proto b/src/arrow/java/flight/flight-grpc/src/test/protobuf/test.proto
new file mode 100644
index 000000000..6fa1890b2
--- /dev/null
+++ b/src/arrow/java/flight/flight-grpc/src/test/protobuf/test.proto
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = "proto3";
+
+option java_package = "org.apache.arrow.flight";
+
+import "google/protobuf/empty.proto";
+
+service TestService {
+ rpc Test(google.protobuf.Empty) returns (google.protobuf.Empty) {}
+}