diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-21 11:54:28 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-21 11:54:28 +0000 |
commit | e6918187568dbd01842d8d1d2c808ce16a894239 (patch) | |
tree | 64f88b554b444a49f656b6c656111a145cbbaa28 /src/arrow/java/flight | |
parent | Initial commit. (diff) | |
download | ceph-e6918187568dbd01842d8d1d2c808ce16a894239.tar.xz ceph-e6918187568dbd01842d8d1d2c808ce16a894239.zip |
Adding upstream version 18.2.2. (ref: upstream/18.2.2)
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/arrow/java/flight')
122 files changed, 17321 insertions, 0 deletions
diff --git a/src/arrow/java/flight/flight-core/README.md b/src/arrow/java/flight/flight-core/README.md new file mode 100644 index 000000000..37b41ede2 --- /dev/null +++ b/src/arrow/java/flight/flight-core/README.md @@ -0,0 +1,95 @@ +<!--- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +--> + +# Arrow Flight Java Package + +Exposing Apache Arrow data on the wire. 
+ +[Protocol Description Slides](https://www.slideshare.net/JacquesNadeau5/apache-arrow-flight-overview) + +[GRPC Protocol Definition](https://github.com/apache/arrow/blob/master/format/Flight.proto) + +## Example usage + +* Compile the java tree: + + ``` + cd java + mvn clean install -DskipTests + ``` + +* Go Into the Flight tree + + ``` + cd flight/flight-core + ``` + + +* Start the ExampleFlightServer (supports get/put of streams and listing these streams) + + ``` + mvn exec:exec + ``` + +* In new terminal, run the TestExampleServer to populate the server with example data + + ``` + cd arrow/java/flight/flight-core + mvn surefire:test -DdisableServer=true -Dtest=TestExampleServer + ``` + +## Python Example Usage + +* Compile example python headers + + ``` + mkdir target/generated-python + pip install grpcio-tools # or conda install grpcio + python -m grpc_tools.protoc -I./src/main/protobuf/ --python_out=./target/generated-python --grpc_python_out=./target/generated-python ../../format/Flight.proto + ``` + +* Connect to the Flight Service + + ``` + cd target/generated-python + python + ``` + + + ``` + import grpc + import flight_pb2 + import flight_pb2_grpc as flightrpc + channel = grpc.insecure_channel('localhost:12233') + service = flightrpc.FlightServiceStub(channel) + ``` + +* List the Flight from Python + + ``` + for f in service.ListFlights(flight_pb2.Criteria()): f + ``` + +* Try to Drop + + ``` + action = flight_pb2.Action() + action.type="drop" + service.DoAction(action) + ``` diff --git a/src/arrow/java/flight/flight-core/pom.xml b/src/arrow/java/flight/flight-core/pom.xml new file mode 100644 index 000000000..669c6b744 --- /dev/null +++ b/src/arrow/java/flight/flight-core/pom.xml @@ -0,0 +1,392 @@ +<?xml version="1.0"?> +<!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor + license agreements. See the NOTICE file distributed with this work for additional + information regarding copyright ownership. 
The ASF licenses this file to + You under the Apache License, Version 2.0 (the "License"); you may not use + this file except in compliance with the License. You may obtain a copy of + the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required + by applicable law or agreed to in writing, software distributed under the + License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS + OF ANY KIND, either express or implied. See the License for the specific + language governing permissions and limitations under the License. --> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <modelVersion>4.0.0</modelVersion> + <parent> + <groupId>org.apache.arrow</groupId> + <artifactId>arrow-java-root</artifactId> + <version>6.0.1</version> + <relativePath>../../pom.xml</relativePath> + </parent> + + <artifactId>flight-core</artifactId> + <name>Arrow Flight Core</name> + <description>(Experimental)An RPC mechanism for transferring ValueVectors.</description> + <packaging>jar</packaging> + + <properties> + <dep.grpc.version>1.41.0</dep.grpc.version> + <dep.protobuf.version>3.7.1</dep.protobuf.version> + <forkCount>1</forkCount> + </properties> + + <dependencies> + <dependency> + <groupId>org.apache.arrow</groupId> + <artifactId>arrow-format</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> + <groupId>org.apache.arrow</groupId> + <artifactId>arrow-vector</artifactId> + <version>${project.version}</version> + <classifier>${arrow.vector.classifier}</classifier> + </dependency> + <dependency> + <groupId>org.apache.arrow</groupId> + <artifactId>arrow-memory-core</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> + <groupId>org.apache.arrow</groupId> + <artifactId>arrow-memory-netty</artifactId> + <version>${project.version}</version> + 
<scope>runtime</scope> + </dependency> + <dependency> + <groupId>io.grpc</groupId> + <artifactId>grpc-netty</artifactId> + <version>${dep.grpc.version}</version> + </dependency> + <dependency> + <groupId>io.grpc</groupId> + <artifactId>grpc-core</artifactId> + <version>${dep.grpc.version}</version> + </dependency> + <dependency> + <groupId>io.grpc</groupId> + <artifactId>grpc-context</artifactId> + <version>${dep.grpc.version}</version> + </dependency> + <dependency> + <groupId>io.grpc</groupId> + <artifactId>grpc-protobuf</artifactId> + <version>${dep.grpc.version}</version> + </dependency> + <dependency> + <groupId>io.netty</groupId> + <artifactId>netty-tcnative-boringssl-static</artifactId> + <version>2.0.43.Final</version> + </dependency> + <dependency> + <groupId>io.netty</groupId> + <artifactId>netty-buffer</artifactId> + </dependency> + <dependency> + <groupId>io.netty</groupId> + <artifactId>netty-handler</artifactId> + <version>${dep.netty.version}</version> + </dependency> + <dependency> + <groupId>io.netty</groupId> + <artifactId>netty-transport</artifactId> + <version>${dep.netty.version}</version> + </dependency> + <dependency> + <groupId>com.google.guava</groupId> + <artifactId>guava</artifactId> + </dependency> + <dependency> + <groupId>commons-cli</groupId> + <artifactId>commons-cli</artifactId> + <version>1.4</version> + </dependency> + <dependency> + <groupId>io.grpc</groupId> + <artifactId>grpc-stub</artifactId> + <version>${dep.grpc.version}</version> + </dependency> + <dependency> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + <version>${dep.protobuf.version}</version> + </dependency> + <dependency> + <groupId>io.grpc</groupId> + <artifactId>grpc-api</artifactId> + <version>${dep.grpc.version}</version> + </dependency> + + <dependency> + <groupId>com.fasterxml.jackson.core</groupId> + <artifactId>jackson-core</artifactId> + </dependency> + <dependency> + <groupId>com.fasterxml.jackson.core</groupId> + 
<artifactId>jackson-annotations</artifactId> + </dependency> + <dependency> + <groupId>com.fasterxml.jackson.core</groupId> + <artifactId>jackson-databind</artifactId> + </dependency> + <dependency> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-api</artifactId> + </dependency> + <dependency> + <groupId>javax.annotation</groupId> + <artifactId>javax.annotation-api</artifactId> + </dependency> + + <dependency> + <groupId>com.google.api.grpc</groupId> + <artifactId>proto-google-common-protos</artifactId> + <version>1.12.0</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.apache.arrow</groupId> + <artifactId>arrow-vector</artifactId> + <version>${project.version}</version> + <classifier>tests</classifier> + <type>test-jar</type> + <scope>test</scope> + </dependency> + </dependencies> + <build> + <extensions> + <extension> + <groupId>kr.motd.maven</groupId> + <artifactId>os-maven-plugin</artifactId> + <version>1.5.0.Final</version> + </extension> + </extensions> + <plugins> + <plugin> + <artifactId>maven-surefire-plugin</artifactId> + <configuration> + <enableAssertions>false</enableAssertions> + <systemPropertyVariables> + <arrow.test.dataRoot>${project.basedir}/../../../testing/data</arrow.test.dataRoot> + </systemPropertyVariables> + </configuration> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-shade-plugin</artifactId> + <version>3.1.1</version> + <executions> + <execution> + <id>shade-main</id> + <phase>package</phase> + <goals> + <goal>shade</goal> + </goals> + <configuration> + <shadedArtifactAttached>true</shadedArtifactAttached> + <shadedClassifierName>shaded</shadedClassifierName> + <artifactSet> + <includes> + <include>io.grpc:*</include> + <include>com.google.protobuf:*</include> + </includes> + </artifactSet> + <relocations> + <relocation> + <pattern>com.google.protobuf</pattern> + <shadedPattern>arrow.flight.com.google.protobuf</shadedPattern> + </relocation> + </relocations> + 
<transformers> + <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer" /> + </transformers> + </configuration> + </execution> + <execution> + <id>shade-ext</id> + <phase>package</phase> + <goals> + <goal>shade</goal> + </goals> + <configuration> + <shadedArtifactAttached>true</shadedArtifactAttached> + <shadedClassifierName>shaded-ext</shadedClassifierName> + <artifactSet> + <includes> + <include>io.grpc:*</include> + <include>com.google.protobuf:*</include> + <include>com.google.guava:*</include> + </includes> + </artifactSet> + <relocations> + <relocation> + <pattern>com.google.protobuf</pattern> + <shadedPattern>arrow.flight.com.google.protobuf</shadedPattern> + </relocation> + <relocation> + <pattern>com.google.common</pattern> + <shadedPattern>arrow.flight.com.google.common</shadedPattern> + </relocation> + </relocations> + <transformers> + <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer" /> + </transformers> + </configuration> + </execution> + </executions> + </plugin> + <plugin> + <groupId>org.xolstice.maven.plugins</groupId> + <artifactId>protobuf-maven-plugin</artifactId> + <version>0.5.0</version> + <configuration> + <protocArtifact>com.google.protobuf:protoc:${dep.protobuf.version}:exe:${os.detected.classifier}</protocArtifact> + <clearOutputDirectory>false</clearOutputDirectory> + <pluginId>grpc-java</pluginId> + <pluginArtifact>io.grpc:protoc-gen-grpc-java:${dep.grpc.version}:exe:${os.detected.classifier}</pluginArtifact> + </configuration> + <executions> + <execution> + <id>src</id> + <configuration> + <protoSourceRoot>${basedir}/../../../format/</protoSourceRoot> + <outputDirectory>${project.build.directory}/generated-sources/protobuf</outputDirectory> + </configuration> + <goals> + <goal>compile</goal> + <goal>compile-custom</goal> + </goals> + </execution> + <execution> + <id>test</id> + <configuration> + 
<protoSourceRoot>${basedir}/src/test/protobuf</protoSourceRoot> + <outputDirectory>${project.build.directory}/generated-test-sources//protobuf</outputDirectory> + </configuration> + <goals> + <goal>compile</goal> + <goal>compile-custom</goal> + </goals> + </execution> + </executions> + </plugin> + <plugin> + <groupId>org.codehaus.mojo</groupId> + <artifactId>exec-maven-plugin</artifactId> + <version>1.6.0</version> + <configuration> + <executable>java</executable> + <classpathScope>test</classpathScope> + <arguments> + <argument>-classpath</argument> + <classpath /> + <argument>-Xms64m</argument> + <argument>-Xmx64m</argument> + <argument>-XX:MaxDirectMemorySize=4g</argument> + <argument>org.apache.arrow.flight.example.ExampleFlightServer</argument> + </arguments> + </configuration> + </plugin> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-dependency-plugin</artifactId> + <executions> + <execution> + <id>analyze</id> + <phase>verify</phase> + <goals> + <goal>analyze-only</goal> + </goals> + <configuration> + <ignoredDependencies combine.children="append"> + <ignoredDependency>io.netty:netty-tcnative-boringssl-static:*</ignoredDependency> + </ignoredDependencies> + </configuration> + </execution> + </executions> + </plugin> + <plugin> <!-- add generated sources to classpath --> + <groupId>org.codehaus.mojo</groupId> + <artifactId>build-helper-maven-plugin</artifactId> + <version>1.9.1</version> + <executions> + <execution> + <id>add-generated-sources-to-classpath</id> + <phase>generate-sources</phase> + <goals> + <goal>add-source</goal> + </goals> + <configuration> + <sources> + <source>${project.build.directory}/generated-sources/protobuf</source> + </sources> + </configuration> + </execution> + </executions> + </plugin> + <plugin> + <artifactId>maven-assembly-plugin</artifactId> + <version>3.0.0</version> + <configuration> + <descriptorRefs> + <descriptorRef>jar-with-dependencies</descriptorRef> + </descriptorRefs> + </configuration> 
+ <executions> + <execution> + <id>make-assembly</id> + <phase>package</phase> + <goals> + <goal>single</goal> + </goals> + </execution> + </executions> + </plugin> + </plugins> + </build> + <profiles> + <profile> + <id>linux-netty-native</id> + <activation> + <os> + <family>linux</family> + </os> + </activation> + <dependencies> + <dependency> + <groupId>io.netty</groupId> + <artifactId>netty-transport-native-unix-common</artifactId> + <version>${dep.netty.version}</version> + <classifier>${os.detected.name}-${os.detected.arch}</classifier> + </dependency> + <dependency> + <groupId>io.netty</groupId> + <artifactId>netty-transport-native-epoll</artifactId> + <version>${dep.netty.version}</version> + <classifier>${os.detected.name}-${os.detected.arch}</classifier> + </dependency> + </dependencies> + </profile> + <profile> + <id>mac-netty-native</id> + <activation> + <os> + <family>mac</family> + </os> + </activation> + <dependencies> + <dependency> + <groupId>io.netty</groupId> + <artifactId>netty-transport-native-unix-common</artifactId> + <version>${dep.netty.version}</version> + <classifier>${os.detected.name}-${os.detected.arch}</classifier> + </dependency> + <dependency> + <groupId>io.netty</groupId> + <artifactId>netty-transport-native-kqueue</artifactId> + <version>${dep.netty.version}</version> + <classifier>${os.detected.name}-${os.detected.arch}</classifier> + </dependency> + </dependencies> + </profile> + </profiles> +</project> diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Action.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Action.java new file mode 100644 index 000000000..524ffcab9 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Action.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import org.apache.arrow.flight.impl.Flight; + +import com.google.protobuf.ByteString; + +/** + * An opaque action for the service to perform. + * + * <p>This is a POJO wrapper around the message of the same name in Flight.proto. + */ +public class Action { + + private final String type; + private final byte[] body; + + public Action(String type) { + this(type, null); + } + + public Action(String type, byte[] body) { + this.type = type; + this.body = body == null ? 
new byte[0] : body; + } + + Action(Flight.Action action) { + this(action.getType(), action.getBody().toByteArray()); + } + + public String getType() { + return type; + } + + public byte[] getBody() { + return body; + } + + Flight.Action toProtocol() { + return Flight.Action.newBuilder() + .setType(getType()) + .setBody(ByteString.copyFrom(getBody())) + .build(); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ActionType.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ActionType.java new file mode 100644 index 000000000..d89365612 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ActionType.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import org.apache.arrow.flight.impl.Flight; + +/** + * POJO wrapper around protocol specifics for Flight actions. + */ +public class ActionType { + private final String type; + private final String description; + + /** + * Construct a new instance. + * + * @param type The type of action to perform + * @param description The description of the type. 
+ */ + public ActionType(String type, String description) { + super(); + this.type = type; + this.description = description; + } + + /** + * Constructs a new instance from the corresponding protocol buffer object. + */ + ActionType(Flight.ActionType type) { + this.type = type.getType(); + this.description = type.getDescription(); + } + + public String getType() { + return type; + } + + /** + * Converts the POJO to the corresponding protocol buffer type. + */ + Flight.ActionType toProtocol() { + return Flight.ActionType.newBuilder() + .setType(type) + .setDescription(description) + .build(); + } + + @Override + public String toString() { + return "ActionType{" + + "type='" + type + '\'' + + ", description='" + description + '\'' + + '}'; + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java new file mode 100644 index 000000000..b4ee835de --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ArrowMessage.java @@ -0,0 +1,560 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import org.apache.arrow.flight.grpc.AddWritableBuffer; +import org.apache.arrow.flight.grpc.GetReadableBuffer; +import org.apache.arrow.flight.impl.Flight.FlightData; +import org.apache.arrow.flight.impl.Flight.FlightDescriptor; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.compression.NoCompressionCodec; +import org.apache.arrow.vector.ipc.message.ArrowBodyCompression; +import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.ipc.message.IpcOption; +import org.apache.arrow.vector.ipc.message.MessageMetadataResult; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.MetadataVersion; +import org.apache.arrow.vector.types.pojo.Schema; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.common.io.ByteStreams; +import com.google.protobuf.ByteString; +import com.google.protobuf.CodedInputStream; +import com.google.protobuf.CodedOutputStream; +import com.google.protobuf.WireFormat; + +import io.grpc.Drainable; +import io.grpc.MethodDescriptor.Marshaller; +import io.grpc.protobuf.ProtoUtils; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; + +/** + * The in-memory representation of FlightData used to manage a stream of Arrow messages. 
+ */ +class ArrowMessage implements AutoCloseable { + + // If true, deserialize Arrow data by giving Arrow a reference to the underlying gRPC buffer + // instead of copying the data. Defaults to true. + public static final boolean ENABLE_ZERO_COPY_READ; + // If true, serialize Arrow data by giving gRPC a reference to the underlying Arrow buffer + // instead of copying the data. Defaults to false. + public static final boolean ENABLE_ZERO_COPY_WRITE; + + static { + String zeroCopyReadFlag = System.getProperty("arrow.flight.enable_zero_copy_read"); + if (zeroCopyReadFlag == null) { + zeroCopyReadFlag = System.getenv("ARROW_FLIGHT_ENABLE_ZERO_COPY_READ"); + } + String zeroCopyWriteFlag = System.getProperty("arrow.flight.enable_zero_copy_write"); + if (zeroCopyWriteFlag == null) { + zeroCopyWriteFlag = System.getenv("ARROW_FLIGHT_ENABLE_ZERO_COPY_WRITE"); + } + ENABLE_ZERO_COPY_READ = !"false".equalsIgnoreCase(zeroCopyReadFlag); + ENABLE_ZERO_COPY_WRITE = "true".equalsIgnoreCase(zeroCopyWriteFlag); + } + + private static final int DESCRIPTOR_TAG = + (FlightData.FLIGHT_DESCRIPTOR_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; + private static final int BODY_TAG = + (FlightData.DATA_BODY_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; + private static final int HEADER_TAG = + (FlightData.DATA_HEADER_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; + private static final int APP_METADATA_TAG = + (FlightData.APP_METADATA_FIELD_NUMBER << 3) | WireFormat.WIRETYPE_LENGTH_DELIMITED; + + private static final Marshaller<FlightData> NO_BODY_MARSHALLER = + ProtoUtils.marshaller(FlightData.getDefaultInstance()); + + /** Get the application-specific metadata in this message. The ArrowMessage retains ownership of the buffer. */ + public ArrowBuf getApplicationMetadata() { + return appMetadata; + } + + /** Types of messages that can be sent. 
*/ + public enum HeaderType { + NONE, + SCHEMA, + DICTIONARY_BATCH, + RECORD_BATCH, + TENSOR + ; + + public static HeaderType getHeader(byte b) { + switch (b) { + case 0: return NONE; + case 1: return SCHEMA; + case 2: return DICTIONARY_BATCH; + case 3: return RECORD_BATCH; + case 4: return TENSOR; + default: + throw new UnsupportedOperationException("unknown type: " + b); + } + } + + } + + // Pre-allocated buffers for padding serialized ArrowMessages. + private static final List<ByteBuf> PADDING_BUFFERS = Arrays.asList( + null, + Unpooled.copiedBuffer(new byte[] { 0 }), + Unpooled.copiedBuffer(new byte[] { 0, 0 }), + Unpooled.copiedBuffer(new byte[] { 0, 0, 0 }), + Unpooled.copiedBuffer(new byte[] { 0, 0, 0, 0 }), + Unpooled.copiedBuffer(new byte[] { 0, 0, 0, 0, 0 }), + Unpooled.copiedBuffer(new byte[] { 0, 0, 0, 0, 0, 0 }), + Unpooled.copiedBuffer(new byte[] { 0, 0, 0, 0, 0, 0, 0 }) + ); + + private final IpcOption writeOption; + private final FlightDescriptor descriptor; + private final MessageMetadataResult message; + private final ArrowBuf appMetadata; + private final List<ArrowBuf> bufs; + private final ArrowBodyCompression bodyCompression; + private final boolean tryZeroCopyWrite; + + public ArrowMessage(FlightDescriptor descriptor, Schema schema, IpcOption option) { + this.writeOption = option; + ByteBuffer serializedMessage = MessageSerializer.serializeMetadata(schema, writeOption); + this.message = MessageMetadataResult.create(serializedMessage.slice(), + serializedMessage.remaining()); + bufs = ImmutableList.of(); + this.descriptor = descriptor; + this.appMetadata = null; + this.bodyCompression = NoCompressionCodec.DEFAULT_BODY_COMPRESSION; + this.tryZeroCopyWrite = false; + } + + /** + * Create an ArrowMessage from a record batch and app metadata. + * @param batch The record batch. + * @param appMetadata The app metadata. May be null. Takes ownership of the buffer otherwise. + * @param tryZeroCopy Whether to enable the zero-copy optimization. 
+ */ + public ArrowMessage(ArrowRecordBatch batch, ArrowBuf appMetadata, boolean tryZeroCopy, IpcOption option) { + this.writeOption = option; + ByteBuffer serializedMessage = MessageSerializer.serializeMetadata(batch, writeOption); + this.message = MessageMetadataResult.create(serializedMessage.slice(), serializedMessage.remaining()); + this.bufs = ImmutableList.copyOf(batch.getBuffers()); + this.descriptor = null; + this.appMetadata = appMetadata; + this.bodyCompression = batch.getBodyCompression(); + this.tryZeroCopyWrite = tryZeroCopy; + } + + public ArrowMessage(ArrowDictionaryBatch batch, IpcOption option) { + this.writeOption = option; + ByteBuffer serializedMessage = MessageSerializer.serializeMetadata(batch, writeOption); + serializedMessage = serializedMessage.slice(); + this.message = MessageMetadataResult.create(serializedMessage, serializedMessage.remaining()); + // asInputStream will free the buffers implicitly, so increment the reference count + batch.getDictionary().getBuffers().forEach(buf -> buf.getReferenceManager().retain()); + this.bufs = ImmutableList.copyOf(batch.getDictionary().getBuffers()); + this.descriptor = null; + this.appMetadata = null; + this.bodyCompression = batch.getDictionary().getBodyCompression(); + this.tryZeroCopyWrite = false; + } + + /** + * Create an ArrowMessage containing only application metadata. + * @param appMetadata The application-provided metadata buffer. + */ + public ArrowMessage(ArrowBuf appMetadata) { + // No need to take IpcOption as it's not used to serialize this kind of message. + this.writeOption = IpcOption.DEFAULT; + this.message = null; + this.bufs = ImmutableList.of(); + this.descriptor = null; + this.appMetadata = appMetadata; + this.bodyCompression = NoCompressionCodec.DEFAULT_BODY_COMPRESSION; + this.tryZeroCopyWrite = false; + } + + public ArrowMessage(FlightDescriptor descriptor) { + // No need to take IpcOption as it's not used to serialize this kind of message. 
+ this.writeOption = IpcOption.DEFAULT; + this.message = null; + this.bufs = ImmutableList.of(); + this.descriptor = descriptor; + this.appMetadata = null; + this.bodyCompression = NoCompressionCodec.DEFAULT_BODY_COMPRESSION; + this.tryZeroCopyWrite = false; + } + + private ArrowMessage(FlightDescriptor descriptor, MessageMetadataResult message, ArrowBuf appMetadata, + ArrowBuf buf) { + // No need to take IpcOption as this is used for deserialized ArrowMessage coming from the wire. + this.writeOption = message != null ? + // avoid writing legacy ipc format by default + new IpcOption(false, MetadataVersion.fromFlatbufID(message.getMessage().version())) : + IpcOption.DEFAULT; + this.message = message; + this.descriptor = descriptor; + this.appMetadata = appMetadata; + this.bufs = buf == null ? ImmutableList.of() : ImmutableList.of(buf); + this.bodyCompression = NoCompressionCodec.DEFAULT_BODY_COMPRESSION; + this.tryZeroCopyWrite = false; + } + + public MessageMetadataResult asSchemaMessage() { + return message; + } + + public FlightDescriptor getDescriptor() { + return descriptor; + } + + public HeaderType getMessageType() { + if (message == null) { + // Null message occurs for metadata-only messages (in DoExchange) + return HeaderType.NONE; + } + return HeaderType.getHeader(message.headerType()); + } + + public Schema asSchema() { + Preconditions.checkArgument(bufs.size() == 0); + Preconditions.checkArgument(getMessageType() == HeaderType.SCHEMA); + return MessageSerializer.deserializeSchema(message); + } + + public ArrowRecordBatch asRecordBatch() throws IOException { + Preconditions.checkArgument(bufs.size() == 1, "A batch can only be consumed if it contains a single ArrowBuf."); + Preconditions.checkArgument(getMessageType() == HeaderType.RECORD_BATCH); + + ArrowBuf underlying = bufs.get(0); + + underlying.getReferenceManager().retain(); + return MessageSerializer.deserializeRecordBatch(message, underlying); + } + + public ArrowDictionaryBatch asDictionaryBatch() 
throws IOException { + Preconditions.checkArgument(bufs.size() == 1, "A batch can only be consumed if it contains a single ArrowBuf."); + Preconditions.checkArgument(getMessageType() == HeaderType.DICTIONARY_BATCH); + ArrowBuf underlying = bufs.get(0); + // Retain a reference to keep the batch alive when the message is closed + underlying.getReferenceManager().retain(); + // Do not set drained - we still want to release our reference + return MessageSerializer.deserializeDictionaryBatch(message, underlying); + } + + public Iterable<ArrowBuf> getBufs() { + return Iterables.unmodifiableIterable(bufs); + } + + private static ArrowMessage frame(BufferAllocator allocator, final InputStream stream) { + + try { + FlightDescriptor descriptor = null; + MessageMetadataResult header = null; + ArrowBuf body = null; + ArrowBuf appMetadata = null; + while (stream.available() > 0) { + int tag = readRawVarint32(stream); + switch (tag) { + + case DESCRIPTOR_TAG: { + int size = readRawVarint32(stream); + byte[] bytes = new byte[size]; + ByteStreams.readFully(stream, bytes); + descriptor = FlightDescriptor.parseFrom(bytes); + break; + } + case HEADER_TAG: { + int size = readRawVarint32(stream); + byte[] bytes = new byte[size]; + ByteStreams.readFully(stream, bytes); + header = MessageMetadataResult.create(ByteBuffer.wrap(bytes), size); + break; + } + case APP_METADATA_TAG: { + int size = readRawVarint32(stream); + appMetadata = allocator.buffer(size); + GetReadableBuffer.readIntoBuffer(stream, appMetadata, size, ENABLE_ZERO_COPY_READ); + break; + } + case BODY_TAG: + if (body != null) { + // only read last body. + body.getReferenceManager().release(); + body = null; + } + int size = readRawVarint32(stream); + body = allocator.buffer(size); + GetReadableBuffer.readIntoBuffer(stream, body, size, ENABLE_ZERO_COPY_READ); + break; + + default: + // ignore unknown fields. 
+ } + } + // Protobuf implementations can omit empty fields, such as body; for some message types, like RecordBatch, + // this will fail later as we still expect an empty buffer. In those cases only, fill in an empty buffer here - + // in other cases, like Schema, having an unexpected empty buffer will also cause failures. + // We don't fill in defaults for fields like header, for which there is no reasonable default, or for appMetadata + // or descriptor, which are intended to be empty in some cases. + if (header != null) { + switch (HeaderType.getHeader(header.headerType())) { + case SCHEMA: + // Ignore 0-length buffers in case a Protobuf implementation wrote it out + if (body != null && body.capacity() == 0) { + body.close(); + body = null; + } + break; + case DICTIONARY_BATCH: + case RECORD_BATCH: + // A Protobuf implementation can skip 0-length bodies, so ensure we fill it in here + if (body == null) { + body = allocator.getEmpty(); + } + break; + case NONE: + case TENSOR: + default: + // Do nothing + break; + } + } + return new ArrowMessage(descriptor, header, appMetadata, body); + } catch (Exception ioe) { + throw new RuntimeException(ioe); + } + + } + + private static int readRawVarint32(InputStream is) throws IOException { + int firstByte = is.read(); + return CodedInputStream.readRawVarint32(firstByte, is); + } + + /** + * Convert the ArrowMessage to an InputStream. + * + * <p>Implicitly, this transfers ownership of the contained buffers to the InputStream. 
+ * + * @return InputStream + */ + private InputStream asInputStream(BufferAllocator allocator) { + if (message == null) { + // If we have no IPC message, it's a pure-metadata message + final FlightData.Builder builder = FlightData.newBuilder(); + if (descriptor != null) { + builder.setFlightDescriptor(descriptor); + } + if (appMetadata != null) { + builder.setAppMetadata(ByteString.copyFrom(appMetadata.nioBuffer())); + } + return NO_BODY_MARSHALLER.stream(builder.build()); + } + + try { + final ByteString bytes = ByteString.copyFrom(message.getMessageBuffer(), + message.bytesAfterMessage()); + + if (getMessageType() == HeaderType.SCHEMA) { + + final FlightData.Builder builder = FlightData.newBuilder() + .setDataHeader(bytes); + + if (descriptor != null) { + builder.setFlightDescriptor(descriptor); + } + + Preconditions.checkArgument(bufs.isEmpty()); + return NO_BODY_MARSHALLER.stream(builder.build()); + } + + Preconditions.checkArgument(getMessageType() == HeaderType.RECORD_BATCH || + getMessageType() == HeaderType.DICTIONARY_BATCH); + // There may be no buffers in the case that we write only a null array + Preconditions.checkArgument(descriptor == null, "Descriptor should only be included in the schema message."); + + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + CodedOutputStream cos = CodedOutputStream.newInstance(baos); + cos.writeBytes(FlightData.DATA_HEADER_FIELD_NUMBER, bytes); + + if (appMetadata != null && appMetadata.capacity() > 0) { + // Must call slice() as CodedOutputStream#writeByteBuffer writes -capacity- bytes, not -limit- bytes + cos.writeByteBuffer(FlightData.APP_METADATA_FIELD_NUMBER, appMetadata.nioBuffer().slice()); + } + + cos.writeTag(FlightData.DATA_BODY_FIELD_NUMBER, WireFormat.WIRETYPE_LENGTH_DELIMITED); + int size = 0; + List<ByteBuf> allBufs = new ArrayList<>(); + for (ArrowBuf b : bufs) { + // [ARROW-11066] This creates a Netty buffer whose refcnt is INDEPENDENT of the backing + // Arrow buffer. 
This is susceptible to use-after-free, so we subclass CompositeByteBuf + // below to tie the Arrow buffer refcnt to the Netty buffer refcnt + allBufs.add(Unpooled.wrappedBuffer(b.nioBuffer()).retain()); + size += b.readableBytes(); + // [ARROW-4213] These buffers must be aligned to an 8-byte boundary in order to be readable from C++. + if (b.readableBytes() % 8 != 0) { + int paddingBytes = (int) (8 - (b.readableBytes() % 8)); + assert paddingBytes > 0 && paddingBytes < 8; + size += paddingBytes; + allBufs.add(PADDING_BUFFERS.get(paddingBytes).retain()); + } + } + // rawvarint is used for length definition. + cos.writeUInt32NoTag(size); + cos.flush(); + + ByteBuf initialBuf = Unpooled.buffer(baos.size()); + initialBuf.writeBytes(baos.toByteArray()); + final CompositeByteBuf bb; + final int maxNumComponents = Math.max(2, bufs.size() + 1); + final ImmutableList<ByteBuf> byteBufs = ImmutableList.<ByteBuf>builder() + .add(initialBuf) + .addAll(allBufs) + .build(); + if (tryZeroCopyWrite) { + bb = new ArrowBufRetainingCompositeByteBuf(maxNumComponents, byteBufs, bufs); + } else { + // Don't retain the buffers in the non-zero-copy path since we're copying them + bb = new CompositeByteBuf(UnpooledByteBufAllocator.DEFAULT, /* direct */ true, maxNumComponents, byteBufs); + } + return new DrainableByteBufInputStream(bb, tryZeroCopyWrite); + } catch (Exception ex) { + throw new RuntimeException("Unexpected IO Exception", ex); + } + + } + + /** + * ARROW-11066: enable the zero-copy optimization and protect against use-after-free. + * + * When you send a message through gRPC, the following happens: + * 1. gRPC immediately serializes the message, eventually calling asInputStream above. + * 2. gRPC buffers the serialized message for sending. + * 3. Later, gRPC will actually write out the message. + * + * The problem with this is that when the zero-copy optimization is enabled, Flight + * "serializes" the message by handing gRPC references to Arrow data. 
That means we need + * a way to keep the Arrow buffers valid until gRPC actually writes them, else, we'll read + * invalid data or segfault. gRPC doesn't know anything about Arrow buffers, either. + * + * This class solves that issue by bridging Arrow and Netty/gRPC. We increment the refcnt + * on a set of Arrow backing buffers and decrement them once the Netty buffers are freed + * by gRPC. + */ + private static final class ArrowBufRetainingCompositeByteBuf extends CompositeByteBuf { + // Arrow buffers that back the Netty ByteBufs here; ByteBufs held by this class are + // either slices of one of the ArrowBufs or independently allocated. + final List<ArrowBuf> backingBuffers; + boolean freed; + + ArrowBufRetainingCompositeByteBuf(int maxNumComponents, Iterable<ByteBuf> buffers, List<ArrowBuf> backingBuffers) { + super(UnpooledByteBufAllocator.DEFAULT, /* direct */ true, maxNumComponents, buffers); + this.backingBuffers = backingBuffers; + this.freed = false; + // N.B. the Netty superclass avoids enhanced-for to reduce GC pressure, so follow that here + for (int i = 0; i < backingBuffers.size(); i++) { + backingBuffers.get(i).getReferenceManager().retain(); + } + } + + @Override + protected void deallocate() { + super.deallocate(); + if (freed) { + return; + } + freed = true; + for (int i = 0; i < backingBuffers.size(); i++) { + backingBuffers.get(i).getReferenceManager().release(); + } + } + } + + private static class DrainableByteBufInputStream extends ByteBufInputStream implements Drainable { + + private final CompositeByteBuf buf; + private final boolean isZeroCopy; + + public DrainableByteBufInputStream(CompositeByteBuf buffer, boolean isZeroCopy) { + super(buffer, buffer.readableBytes(), true); + this.buf = buffer; + this.isZeroCopy = isZeroCopy; + } + + @Override + public int drainTo(OutputStream target) throws IOException { + int size = buf.readableBytes(); + AddWritableBuffer.add(buf, target, isZeroCopy); + return size; + } + + @Override + public void 
close() { + buf.release(); + } + + + + } + + public static Marshaller<ArrowMessage> createMarshaller(BufferAllocator allocator) { + return new ArrowMessageHolderMarshaller(allocator); + } + + private static class ArrowMessageHolderMarshaller implements Marshaller<ArrowMessage> { + + private final BufferAllocator allocator; + + public ArrowMessageHolderMarshaller(BufferAllocator allocator) { + this.allocator = allocator; + } + + @Override + public InputStream stream(ArrowMessage value) { + return value.asInputStream(allocator); + } + + @Override + public ArrowMessage parse(InputStream stream) { + return ArrowMessage.frame(allocator, stream); + } + + } + + @Override + public void close() throws Exception { + AutoCloseables.close(Iterables.concat(bufs, Collections.singletonList(appMetadata))); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/AsyncPutListener.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/AsyncPutListener.java new file mode 100644 index 000000000..a45463225 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/AsyncPutListener.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.apache.arrow.flight.grpc.StatusUtils; + +/** + * A handler for server-sent application metadata messages during a Flight DoPut operation. + * + * <p>To handle messages, create an instance of this class overriding {@link #onNext(PutResult)}. The other methods + * should not be overridden. + */ +public class AsyncPutListener implements FlightClient.PutListener { + + private CompletableFuture<Void> completed; + + public AsyncPutListener() { + completed = new CompletableFuture<>(); + } + + /** + * Wait for the stream to finish on the server side. You must call this to be notified of any errors that may have + * happened during the upload. + */ + @Override + public final void getResult() { + try { + completed.get(); + } catch (ExecutionException e) { + throw StatusUtils.fromThrowable(e.getCause()); + } catch (InterruptedException e) { + throw StatusUtils.fromThrowable(e); + } + } + + @Override + public void onNext(PutResult val) { + } + + @Override + public final void onError(Throwable t) { + completed.completeExceptionally(StatusUtils.fromThrowable(t)); + } + + @Override + public final void onCompleted() { + completed.complete(null); + } + + @Override + public boolean isCancelled() { + return completed.isDone(); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/BackpressureStrategy.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/BackpressureStrategy.java new file mode 100644 index 000000000..de34643a7 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/BackpressureStrategy.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import org.apache.arrow.vector.VectorSchemaRoot; + +import com.google.common.base.Preconditions; + +/** + * Helper interface to dynamically handle backpressure when implementing FlightProducers. + * This must only be used in FlightProducer implementations that are non-blocking. + */ +public interface BackpressureStrategy { + /** + * The state of the client after a call to waitForListener. + */ + enum WaitResult { + /** + * Listener is ready. + */ + READY, + + /** + * Listener was cancelled by the client. + */ + CANCELLED, + + /** + * Timed out waiting for the listener to change state. + */ + TIMEOUT, + + /** + * Indicates that the wait was interrupted for a reason + * unrelated to the listener itself. + */ + OTHER + } + + /** + * Set up operations to work against the given listener. + * + * This must be called exactly once and before any calls to {@link #waitForListener(long)} and + * {@link OutboundStreamListener#start(VectorSchemaRoot)} + * @param listener The listener this strategy applies to. + */ + void register(FlightProducer.ServerStreamListener listener); + + /** + * Waits for the listener to be ready or cancelled up to the given timeout. + * + * @param timeout The timeout in milliseconds. Infinite if timeout is <= 0. 
+ * @return The result of the wait. + */ + WaitResult waitForListener(long timeout); + + /** + * A back pressure strategy that uses callbacks to notify when the client is ready or cancelled. + */ + class CallbackBackpressureStrategy implements BackpressureStrategy { + private final Object lock = new Object(); + private FlightProducer.ServerStreamListener listener; + + @Override + public void register(FlightProducer.ServerStreamListener listener) { + this.listener = listener; + listener.setOnReadyHandler(this::onReady); + listener.setOnCancelHandler(this::onCancel); + } + + @Override + public WaitResult waitForListener(long timeout) { + Preconditions.checkNotNull(listener); + long remainingTimeout = timeout; + final long startTime = System.currentTimeMillis(); + synchronized (lock) { + while (!listener.isReady() && !listener.isCancelled()) { + try { + lock.wait(remainingTimeout); + if (timeout != 0) { // If timeout was zero explicitly, we should never report timeout. + remainingTimeout = startTime + timeout - System.currentTimeMillis(); + if (remainingTimeout <= 0) { + return WaitResult.TIMEOUT; + } + } + if (!shouldContinueWaiting(listener, remainingTimeout)) { + return WaitResult.OTHER; + } + } catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + return WaitResult.OTHER; + } + } + + if (listener.isReady()) { + return WaitResult.READY; + } else if (listener.isCancelled()) { + return WaitResult.CANCELLED; + } else if (System.currentTimeMillis() > startTime + timeout) { + return WaitResult.TIMEOUT; + } + throw new RuntimeException("Invalid state when waiting for listener."); + } + } + + /** + * Interrupt waiting on the listener to change state. + * + * This method can be used in conjunction with + * {@link #shouldContinueWaiting(FlightProducer.ServerStreamListener, long)} to allow FlightProducers to + * terminate streams internally and notify clients. 
+ */ + public void interruptWait() { + synchronized (lock) { + lock.notifyAll(); + } + } + + /** + * Callback function to run to check if the listener should continue + * to be waited on if it leaves the waiting state without being cancelled, + * ready, or timed out. + * + * This method should be used to determine if the wait on the listener was interrupted explicitly using a + * call to {@link #interruptWait()} or if it was woken up due to a spurious wake. + */ + protected boolean shouldContinueWaiting(FlightProducer.ServerStreamListener listener, long remainingTimeout) { + return true; + } + + /** + * Callback to execute when the listener becomes ready. + */ + protected void readyCallback() { + } + + /** + * Callback to execute when the listener is cancelled. + */ + protected void cancelCallback() { + } + + private void onReady() { + synchronized (lock) { + readyCallback(); + lock.notifyAll(); + } + } + + private void onCancel() { + synchronized (lock) { + cancelCallback(); + lock.notifyAll(); + } + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallHeaders.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallHeaders.java new file mode 100644 index 000000000..32f9a8430 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallHeaders.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.Set; + +/** + * A set of metadata key value pairs for a call (request or response). + */ +public interface CallHeaders { + /** + * Get the value of a metadata key. If multiple values are present, then get the last one. + */ + String get(String key); + + /** + * Get the value of a metadata key. If multiple values are present, then get the last one. + */ + byte[] getByte(String key); + + /** + * Get all values present for the given metadata key. + */ + Iterable<String> getAll(String key); + + /** + * Get all values present for the given metadata key. + */ + Iterable<byte[]> getAllByte(String key); + + /** + * Insert a metadata pair with the given value. + * + * <p>Duplicate metadata are permitted. + */ + void insert(String key, String value); + + /** + * Insert a metadata pair with the given value. + * + * <p>Duplicate metadata are permitted. + */ + void insert(String key, byte[] value); + + /** Get a set of all the metadata keys. */ + Set<String> keys(); + + /** Check whether the given metadata key is present. 
*/ + boolean containsKey(String key); +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallInfo.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallInfo.java new file mode 100644 index 000000000..744584bdf --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallInfo.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +/** + * A description of a Flight call for middleware to inspect. + */ +public final class CallInfo { + private final FlightMethod method; + + public CallInfo(FlightMethod method) { + this.method = method; + } + + public FlightMethod method() { + return method; + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallOption.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallOption.java new file mode 100644 index 000000000..d3ee3ab4c --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallOption.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
/**
 * Marker interface for per-call RPC options.
 *
 * <p>Options are hints to the underlying RPC layer and may not be respected.
 */
public interface CallOption {}
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.concurrent.TimeUnit; + +import io.grpc.stub.AbstractStub; + +/** + * Common call options. + */ +public class CallOptions { + public static CallOption timeout(long duration, TimeUnit unit) { + return new Timeout(duration, unit); + } + + static <T extends AbstractStub<T>> T wrapStub(T stub, CallOption[] options) { + for (CallOption option : options) { + if (option instanceof GrpcCallOption) { + stub = ((GrpcCallOption) option).wrapStub(stub); + } + } + return stub; + } + + private static class Timeout implements GrpcCallOption { + long timeout; + TimeUnit timeoutUnit; + + Timeout(long timeout, TimeUnit timeoutUnit) { + this.timeout = timeout; + this.timeoutUnit = timeoutUnit; + } + + @Override + public <T extends AbstractStub<T>> T wrapStub(T stub) { + return stub.withDeadlineAfter(timeout, timeoutUnit); + } + } + + /** + * CallOptions specific to GRPC stubs. + */ + public interface GrpcCallOption extends CallOption { + <T extends AbstractStub<T>> T wrapStub(T stub); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallStatus.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallStatus.java new file mode 100644 index 000000000..991d0ed6a --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/CallStatus.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.Objects; + +import org.apache.arrow.flight.FlightProducer.ServerStreamListener; +import org.apache.arrow.flight.FlightProducer.StreamListener; + +/** + * The result of a Flight RPC, consisting of a status code with an optional description and/or exception that led + * to the status. + * + * <p>If raised or sent through {@link StreamListener#onError(Throwable)} or + * {@link ServerStreamListener#error(Throwable)}, the client call will raise the same error (a + * {@link FlightRuntimeException} with the same {@link FlightStatusCode} and description). The exception within, if + * present, will not be sent to the client. 
+ */ +public class CallStatus { + + private final FlightStatusCode code; + private final Throwable cause; + private final String description; + private final ErrorFlightMetadata metadata; + + public static final CallStatus UNKNOWN = FlightStatusCode.UNKNOWN.toStatus(); + public static final CallStatus INTERNAL = FlightStatusCode.INTERNAL.toStatus(); + public static final CallStatus INVALID_ARGUMENT = FlightStatusCode.INVALID_ARGUMENT.toStatus(); + public static final CallStatus TIMED_OUT = FlightStatusCode.TIMED_OUT.toStatus(); + public static final CallStatus NOT_FOUND = FlightStatusCode.NOT_FOUND.toStatus(); + public static final CallStatus ALREADY_EXISTS = FlightStatusCode.ALREADY_EXISTS.toStatus(); + public static final CallStatus CANCELLED = FlightStatusCode.CANCELLED.toStatus(); + public static final CallStatus UNAUTHENTICATED = FlightStatusCode.UNAUTHENTICATED.toStatus(); + public static final CallStatus UNAUTHORIZED = FlightStatusCode.UNAUTHORIZED.toStatus(); + public static final CallStatus UNIMPLEMENTED = FlightStatusCode.UNIMPLEMENTED.toStatus(); + public static final CallStatus UNAVAILABLE = FlightStatusCode.UNAVAILABLE.toStatus(); + + /** + * Create a new status. + * + * @param code The status code. + * @param cause An exception that resulted in this status (or null). + * @param description A description of the status (or null). + */ + public CallStatus(FlightStatusCode code, Throwable cause, String description, ErrorFlightMetadata metadata) { + this.code = Objects.requireNonNull(code); + this.cause = cause; + this.description = description == null ? "" : description; + this.metadata = metadata == null ? new ErrorFlightMetadata() : metadata; + } + + /** + * Create a new status with no cause or description. + * + * @param code The status code. + */ + public CallStatus(FlightStatusCode code) { + this(code, /* no cause */ null, /* no description */ null, /* no metadata */ null); + } + + /** + * The status code describing the result of the RPC. 
+ */ + public FlightStatusCode code() { + return code; + } + + /** + * The exception that led to this result. May be null. + */ + public Throwable cause() { + return cause; + } + + /** + * A description of the result. + */ + public String description() { + return description; + } + + /** + * Metadata associated with the exception. + * + * May be null. + */ + public ErrorFlightMetadata metadata() { + return metadata; + } + + /** + * Return a copy of this status with an error message. + */ + public CallStatus withDescription(String message) { + return new CallStatus(code, cause, message, metadata); + } + + /** + * Return a copy of this status with the given exception as the cause. This will not be sent over the wire. + */ + public CallStatus withCause(Throwable t) { + return new CallStatus(code, t, description, metadata); + } + + /** + * Return a copy of this status with associated exception metadata. + */ + public CallStatus withMetadata(ErrorFlightMetadata metadata) { + return new CallStatus(code, cause, description, metadata); + } + + /** + * Convert the status to an equivalent exception. + */ + public FlightRuntimeException toRuntimeException() { + return new FlightRuntimeException(this); + } + + @Override + public String toString() { + return "CallStatus{" + + "code=" + code + + ", cause=" + cause + + ", description='" + description + + "', metadata='" + metadata + '\'' + + '}'; + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Criteria.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Criteria.java new file mode 100644 index 000000000..989cd6581 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Criteria.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import org.apache.arrow.flight.impl.Flight; + +import com.google.protobuf.ByteString; + +/** + * An opaque object that can be used to filter a list of streams available from a server. + * + * <p>This is a POJO wrapper around the protobuf Criteria message. + */ +public class Criteria { + + public static Criteria ALL = new Criteria((byte[]) null); + + private final byte[] bytes; + + public Criteria(byte[] bytes) { + this.bytes = bytes; + } + + Criteria(Flight.Criteria criteria) { + this.bytes = criteria.getExpression().toByteArray(); + } + + /** + * Get the contained filter criteria. + */ + public byte[] getExpression() { + return bytes; + } + + Flight.Criteria asCriteria() { + Flight.Criteria.Builder b = Flight.Criteria.newBuilder(); + if (bytes != null) { + b.setExpression(ByteString.copyFrom(bytes)); + } + + return b.build(); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java new file mode 100644 index 000000000..516dab01d --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/DictionaryUtils.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.function.Consumer; +import java.util.stream.Collectors; + +import org.apache.arrow.flight.impl.Flight; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; +import org.apache.arrow.vector.ipc.message.IpcOption; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.DictionaryUtility; +import org.apache.arrow.vector.validate.MetadataV4UnionChecker; + +/** + * Utilities to work with dictionaries in Flight. + */ +final class DictionaryUtils { + + private DictionaryUtils() { + throw new UnsupportedOperationException("Do not instantiate this class."); + } + + /** + * Generate all the necessary Flight messages to send a schema and associated dictionaries. 
+ * + * @throws Exception if there was an error closing {@link ArrowMessage} objects. This is not generally expected. + */ + static Schema generateSchemaMessages(final Schema originalSchema, final FlightDescriptor descriptor, + final DictionaryProvider provider, final IpcOption option, + final Consumer<ArrowMessage> messageCallback) throws Exception { + final Set<Long> dictionaryIds = new HashSet<>(); + final Schema schema = generateSchema(originalSchema, provider, dictionaryIds); + MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option.metadataVersion); + // Send the schema message + final Flight.FlightDescriptor protoDescriptor = descriptor == null ? null : descriptor.toProtocol(); + try (final ArrowMessage message = new ArrowMessage(protoDescriptor, schema, option)) { + messageCallback.accept(message); + } + // Create and write dictionary batches + for (Long id : dictionaryIds) { + final Dictionary dictionary = provider.lookup(id); + final FieldVector vector = dictionary.getVector(); + final int count = vector.getValueCount(); + // Do NOT close this root, as it does not actually own the vector. 
+ final VectorSchemaRoot dictRoot = new VectorSchemaRoot( + Collections.singletonList(vector.getField()), + Collections.singletonList(vector), + count); + final VectorUnloader unloader = new VectorUnloader(dictRoot); + try (final ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch( + id, unloader.getRecordBatch()); + final ArrowMessage message = new ArrowMessage(dictionaryBatch, option)) { + messageCallback.accept(message); + } + } + return schema; + } + + static void closeDictionaries(final Schema schema, final DictionaryProvider provider) throws Exception { + // Close dictionaries + final Set<Long> dictionaryIds = new HashSet<>(); + schema.getFields().forEach(field -> DictionaryUtility.toMessageFormat(field, provider, dictionaryIds)); + + final List<AutoCloseable> dictionaryVectors = dictionaryIds.stream() + .map(id -> (AutoCloseable) provider.lookup(id).getVector()).collect(Collectors.toList()); + AutoCloseables.close(dictionaryVectors); + } + + /** + * Generates the schema to send with flight messages. + * If the schema contains no field with a dictionary, it will return the schema as is. + * Otherwise, it will return a newly created a new schema after converting the fields. + * @param originalSchema the original schema. + * @param provider the dictionary provider. + * @param dictionaryIds dictionary IDs that are used. + * @return the schema to send with the flight messages. + */ + static Schema generateSchema( + final Schema originalSchema, final DictionaryProvider provider, Set<Long> dictionaryIds) { + // first determine if a new schema needs to be created. 
+ boolean createSchema = false; + for (Field field : originalSchema.getFields()) { + if (DictionaryUtility.needConvertToMessageFormat(field)) { + createSchema = true; + break; + } + } + + if (!createSchema) { + return originalSchema; + } else { + final List<Field> fields = new ArrayList<>(originalSchema.getFields().size()); + for (final Field field : originalSchema.getFields()) { + fields.add(DictionaryUtility.toMessageFormat(field, provider, dictionaryIds)); + } + return new Schema(fields, originalSchema.getCustomMetadata()); + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ErrorFlightMetadata.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ErrorFlightMetadata.java new file mode 100644 index 000000000..6669ce465 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ErrorFlightMetadata.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight; + +import java.nio.charset.StandardCharsets; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +import com.google.common.collect.Iterables; +import com.google.common.collect.LinkedListMultimap; +import com.google.common.collect.Multimap; + +/** + * metadata container specific to the binary metadata held in the grpc trailer. + */ +public class ErrorFlightMetadata implements CallHeaders { + private final Multimap<String, byte[]> metadata = LinkedListMultimap.create(); + + public ErrorFlightMetadata() { + } + + + @Override + public String get(String key) { + return new String(getByte(key), StandardCharsets.US_ASCII); + } + + @Override + public byte[] getByte(String key) { + return Iterables.getLast(metadata.get(key)); + } + + @Override + public Iterable<String> getAll(String key) { + return StreamSupport.stream( + getAllByte(key).spliterator(), false) + .map(b -> new String(b, StandardCharsets.US_ASCII)) + .collect(Collectors.toList()); + } + + @Override + public Iterable<byte[]> getAllByte(String key) { + return metadata.get(key); + } + + @Override + public void insert(String key, String value) { + metadata.put(key, value.getBytes()); + } + + @Override + public void insert(String key, byte[] value) { + metadata.put(key, value); + } + + @Override + public Set<String> keys() { + return metadata.keySet(); + } + + @Override + public boolean containsKey(String key) { + return metadata.containsKey(key); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightBindingService.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightBindingService.java new file mode 100644 index 000000000..ba5249b4a --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightBindingService.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * 
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.Set; +import java.util.concurrent.ExecutorService; + +import org.apache.arrow.flight.auth.ServerAuthHandler; +import org.apache.arrow.flight.impl.Flight; +import org.apache.arrow.flight.impl.Flight.PutResult; +import org.apache.arrow.flight.impl.FlightServiceGrpc; +import org.apache.arrow.memory.BufferAllocator; + +import com.google.common.collect.ImmutableSet; + +import io.grpc.BindableService; +import io.grpc.MethodDescriptor; +import io.grpc.MethodDescriptor.MethodType; +import io.grpc.ServerMethodDefinition; +import io.grpc.ServerServiceDefinition; +import io.grpc.ServiceDescriptor; +import io.grpc.protobuf.ProtoUtils; +import io.grpc.stub.ServerCalls; +import io.grpc.stub.StreamObserver; + +/** + * Extends the basic flight service to override some methods for more efficient implementations. 
+ */ +class FlightBindingService implements BindableService { + + private static final String DO_GET = MethodDescriptor.generateFullMethodName(FlightConstants.SERVICE, "DoGet"); + private static final String DO_PUT = MethodDescriptor.generateFullMethodName(FlightConstants.SERVICE, "DoPut"); + private static final String DO_EXCHANGE = MethodDescriptor.generateFullMethodName( + FlightConstants.SERVICE, "DoExchange"); + private static final Set<String> OVERRIDE_METHODS = ImmutableSet.of(DO_GET, DO_PUT, DO_EXCHANGE); + + private final FlightService delegate; + private final BufferAllocator allocator; + + public FlightBindingService(BufferAllocator allocator, FlightProducer producer, + ServerAuthHandler authHandler, ExecutorService executor) { + this.allocator = allocator; + this.delegate = new FlightService(allocator, producer, authHandler, executor); + } + + public static MethodDescriptor<Flight.Ticket, ArrowMessage> getDoGetDescriptor(BufferAllocator allocator) { + return MethodDescriptor.<Flight.Ticket, ArrowMessage>newBuilder() + .setType(io.grpc.MethodDescriptor.MethodType.SERVER_STREAMING) + .setFullMethodName(DO_GET) + .setSampledToLocalTracing(false) + .setRequestMarshaller(ProtoUtils.marshaller(Flight.Ticket.getDefaultInstance())) + .setResponseMarshaller(ArrowMessage.createMarshaller(allocator)) + .setSchemaDescriptor(FlightServiceGrpc.getDoGetMethod().getSchemaDescriptor()) + .build(); + } + + public static MethodDescriptor<ArrowMessage, Flight.PutResult> getDoPutDescriptor(BufferAllocator allocator) { + return MethodDescriptor.<ArrowMessage, Flight.PutResult>newBuilder() + .setType(MethodType.BIDI_STREAMING) + .setFullMethodName(DO_PUT) + .setSampledToLocalTracing(false) + .setRequestMarshaller(ArrowMessage.createMarshaller(allocator)) + .setResponseMarshaller(ProtoUtils.marshaller(Flight.PutResult.getDefaultInstance())) + .setSchemaDescriptor(FlightServiceGrpc.getDoPutMethod().getSchemaDescriptor()) + .build(); + } + + public static 
MethodDescriptor<ArrowMessage, ArrowMessage> getDoExchangeDescriptor(BufferAllocator allocator) { + return MethodDescriptor.<ArrowMessage, ArrowMessage>newBuilder() + .setType(MethodType.BIDI_STREAMING) + .setFullMethodName(DO_EXCHANGE) + .setSampledToLocalTracing(false) + .setRequestMarshaller(ArrowMessage.createMarshaller(allocator)) + .setResponseMarshaller(ArrowMessage.createMarshaller(allocator)) + .setSchemaDescriptor(FlightServiceGrpc.getDoExchangeMethod().getSchemaDescriptor()) + .build(); + } + + @Override + public ServerServiceDefinition bindService() { + final ServerServiceDefinition baseDefinition = delegate.bindService(); + + final MethodDescriptor<Flight.Ticket, ArrowMessage> doGetDescriptor = getDoGetDescriptor(allocator); + final MethodDescriptor<ArrowMessage, Flight.PutResult> doPutDescriptor = getDoPutDescriptor(allocator); + final MethodDescriptor<ArrowMessage, ArrowMessage> doExchangeDescriptor = getDoExchangeDescriptor(allocator); + + // Make sure we preserve SchemaDescriptor fields on methods so that gRPC reflection still works. 
+ final ServiceDescriptor.Builder serviceDescriptorBuilder = ServiceDescriptor.newBuilder(FlightConstants.SERVICE) + .setSchemaDescriptor(baseDefinition.getServiceDescriptor().getSchemaDescriptor()); + serviceDescriptorBuilder.addMethod(doGetDescriptor); + serviceDescriptorBuilder.addMethod(doPutDescriptor); + serviceDescriptorBuilder.addMethod(doExchangeDescriptor); + for (MethodDescriptor<?, ?> definition : baseDefinition.getServiceDescriptor().getMethods()) { + if (OVERRIDE_METHODS.contains(definition.getFullMethodName())) { + continue; + } + + serviceDescriptorBuilder.addMethod(definition); + } + + final ServiceDescriptor serviceDescriptor = serviceDescriptorBuilder.build(); + ServerServiceDefinition.Builder serviceBuilder = ServerServiceDefinition.builder(serviceDescriptor); + serviceBuilder.addMethod(doGetDescriptor, ServerCalls.asyncServerStreamingCall(new DoGetMethod(delegate))); + serviceBuilder.addMethod(doPutDescriptor, ServerCalls.asyncBidiStreamingCall(new DoPutMethod(delegate))); + serviceBuilder.addMethod(doExchangeDescriptor, ServerCalls.asyncBidiStreamingCall(new DoExchangeMethod(delegate))); + + // copy over not-overridden methods. 
+ for (ServerMethodDefinition<?, ?> definition : baseDefinition.getMethods()) { + if (OVERRIDE_METHODS.contains(definition.getMethodDescriptor().getFullMethodName())) { + continue; + } + + serviceBuilder.addMethod(definition); + } + + return serviceBuilder.build(); + } + + private static class DoGetMethod implements ServerCalls.ServerStreamingMethod<Flight.Ticket, ArrowMessage> { + + private final FlightService delegate; + + public DoGetMethod(FlightService delegate) { + this.delegate = delegate; + } + + @Override + public void invoke(Flight.Ticket request, StreamObserver<ArrowMessage> responseObserver) { + delegate.doGetCustom(request, responseObserver); + } + } + + private static class DoPutMethod implements ServerCalls.BidiStreamingMethod<ArrowMessage, PutResult> { + private final FlightService delegate; + + public DoPutMethod(FlightService delegate) { + this.delegate = delegate; + } + + @Override + public StreamObserver<ArrowMessage> invoke(StreamObserver<PutResult> responseObserver) { + return delegate.doPutCustom(responseObserver); + } + } + + private static class DoExchangeMethod implements ServerCalls.BidiStreamingMethod<ArrowMessage, ArrowMessage> { + private final FlightService delegate; + + public DoExchangeMethod(FlightService delegate) { + this.delegate = delegate; + } + + @Override + public StreamObserver<ArrowMessage> invoke(StreamObserver<ArrowMessage> responseObserver) { + return delegate.doExchangeCustom(responseObserver); + } + } + +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightCallHeaders.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightCallHeaders.java new file mode 100644 index 000000000..dd26d1908 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightCallHeaders.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.Collection; +import java.util.Set; +import java.util.stream.Collectors; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.Iterables; +import com.google.common.collect.Multimap; + +import io.grpc.Metadata; + +/** + * An implementation of the Flight headers interface for headers. 
+ */ +public class FlightCallHeaders implements CallHeaders { + private final Multimap<String, Object> keysAndValues; + + public FlightCallHeaders() { + this.keysAndValues = ArrayListMultimap.create(); + } + + @Override + public String get(String key) { + final Collection<Object> values = this.keysAndValues.get(key); + if (values.isEmpty()) { + return null; + } + + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + return new String((byte[]) Iterables.get(values, 0)); + } + + return (String) Iterables.get(values, 0); + } + + @Override + public byte[] getByte(String key) { + final Collection<Object> values = this.keysAndValues.get(key); + if (values.isEmpty()) { + return null; + } + + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + return (byte[]) Iterables.get(values, 0); + } + + return ((String) Iterables.get(values, 0)).getBytes(); + } + + @Override + public Iterable<String> getAll(String key) { + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + return this.keysAndValues.get(key).stream().map(o -> new String((byte[]) o)).collect(Collectors.toList()); + } + return (Collection<String>) (Collection<?>) this.keysAndValues.get(key); + } + + @Override + public Iterable<byte[]> getAllByte(String key) { + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + return (Collection<byte[]>) (Collection<?>) this.keysAndValues.get(key); + } + return this.keysAndValues.get(key).stream().map(o -> ((String) o).getBytes()).collect(Collectors.toList()); + } + + @Override + public void insert(String key, String value) { + this.keysAndValues.put(key, value); + } + + @Override + public void insert(String key, byte[] value) { + Preconditions.checkArgument(key.endsWith("-bin"), "Binary header is named %s. 
It must end with %s", key, "-bin"); + Preconditions.checkArgument(key.length() > "-bin".length(), "empty key name"); + + this.keysAndValues.put(key, value); + } + + @Override + public Set<String> keys() { + return this.keysAndValues.keySet(); + } + + @Override + public boolean containsKey(String key) { + return this.keysAndValues.containsKey(key); + } + + public String toString() { + return this.keysAndValues.toString(); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java new file mode 100644 index 000000000..762b37859 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -0,0 +1,721 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight; + +import java.io.InputStream; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.function.BooleanSupplier; + +import javax.net.ssl.SSLException; + +import org.apache.arrow.flight.FlightProducer.StreamListener; +import org.apache.arrow.flight.auth.BasicClientAuthHandler; +import org.apache.arrow.flight.auth.ClientAuthHandler; +import org.apache.arrow.flight.auth.ClientAuthInterceptor; +import org.apache.arrow.flight.auth.ClientAuthWrapper; +import org.apache.arrow.flight.auth2.BasicAuthCredentialWriter; +import org.apache.arrow.flight.auth2.ClientBearerHeaderHandler; +import org.apache.arrow.flight.auth2.ClientHandshakeWrapper; +import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware; +import org.apache.arrow.flight.grpc.ClientInterceptorAdapter; +import org.apache.arrow.flight.grpc.CredentialCallOption; +import org.apache.arrow.flight.grpc.StatusUtils; +import org.apache.arrow.flight.impl.Flight; +import org.apache.arrow.flight.impl.Flight.Empty; +import org.apache.arrow.flight.impl.FlightServiceGrpc; +import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceBlockingStub; +import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceStub; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider; + +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ClientInterceptors; +import io.grpc.ManagedChannel; +import io.grpc.MethodDescriptor; +import io.grpc.StatusRuntimeException; +import io.grpc.netty.GrpcSslContexts; 
+import io.grpc.netty.NettyChannelBuilder; +import io.grpc.stub.ClientCallStreamObserver; +import io.grpc.stub.ClientCalls; +import io.grpc.stub.ClientResponseObserver; +import io.grpc.stub.StreamObserver; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.ServerChannel; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; + +/** + * Client for Flight services. + */ +public class FlightClient implements AutoCloseable { + private static final int PENDING_REQUESTS = 5; + /** The maximum number of trace events to keep on the gRPC Channel. This value disables channel tracing. */ + private static final int MAX_CHANNEL_TRACE_EVENTS = 0; + private final BufferAllocator allocator; + private final ManagedChannel channel; + private final Channel interceptedChannel; + private final FlightServiceBlockingStub blockingStub; + private final FlightServiceStub asyncStub; + private final ClientAuthInterceptor authInterceptor = new ClientAuthInterceptor(); + private final MethodDescriptor<Flight.Ticket, ArrowMessage> doGetDescriptor; + private final MethodDescriptor<ArrowMessage, Flight.PutResult> doPutDescriptor; + private final MethodDescriptor<ArrowMessage, ArrowMessage> doExchangeDescriptor; + private final List<FlightClientMiddleware.Factory> middleware; + + /** + * Create a Flight client from an allocator and a gRPC channel. 
+ */ + FlightClient(BufferAllocator incomingAllocator, ManagedChannel channel, + List<FlightClientMiddleware.Factory> middleware) { + this.allocator = incomingAllocator.newChildAllocator("flight-client", 0, Long.MAX_VALUE); + this.channel = channel; + this.middleware = middleware; + + final ClientInterceptor[] interceptors; + interceptors = new ClientInterceptor[]{authInterceptor, new ClientInterceptorAdapter(middleware)}; + + // Create a channel with interceptors pre-applied for DoGet and DoPut + this.interceptedChannel = ClientInterceptors.intercept(channel, interceptors); + + blockingStub = FlightServiceGrpc.newBlockingStub(interceptedChannel); + asyncStub = FlightServiceGrpc.newStub(interceptedChannel); + doGetDescriptor = FlightBindingService.getDoGetDescriptor(allocator); + doPutDescriptor = FlightBindingService.getDoPutDescriptor(allocator); + doExchangeDescriptor = FlightBindingService.getDoExchangeDescriptor(allocator); + } + + /** + * Get a list of available flights. + * + * @param criteria Criteria for selecting flights + * @param options RPC-layer hints for the call. + * @return FlightInfo Iterable + */ + public Iterable<FlightInfo> listFlights(Criteria criteria, CallOption... options) { + final Iterator<Flight.FlightInfo> flights; + try { + flights = CallOptions.wrapStub(blockingStub, options) + .listFlights(criteria.asCriteria()); + } catch (StatusRuntimeException sre) { + throw StatusUtils.fromGrpcRuntimeException(sre); + } + return () -> StatusUtils.wrapIterator(flights, t -> { + try { + return new FlightInfo(t); + } catch (URISyntaxException e) { + // We don't expect this will happen for conforming Flight implementations. For instance, a Java server + // itself wouldn't be able to construct an invalid Location. + throw new RuntimeException(e); + } + }); + } + + /** + * Lists actions available on the Flight service. + * + * @param options RPC-layer hints for the call. + */ + public Iterable<ActionType> listActions(CallOption... 
options) { + final Iterator<Flight.ActionType> actions; + try { + actions = CallOptions.wrapStub(blockingStub, options) + .listActions(Empty.getDefaultInstance()); + } catch (StatusRuntimeException sre) { + throw StatusUtils.fromGrpcRuntimeException(sre); + } + return () -> StatusUtils.wrapIterator(actions, ActionType::new); + } + + /** + * Performs an action on the Flight service. + * + * @param action The action to perform. + * @param options RPC-layer hints for this call. + * @return An iterator of results. + */ + public Iterator<Result> doAction(Action action, CallOption... options) { + return StatusUtils + .wrapIterator(CallOptions.wrapStub(blockingStub, options).doAction(action.toProtocol()), Result::new); + } + + /** + * Authenticates with a username and password. + */ + public void authenticateBasic(String username, String password) { + BasicClientAuthHandler basicClient = new BasicClientAuthHandler(username, password); + authenticate(basicClient); + } + + /** + * Authenticates against the Flight service. + * + * @param options RPC-layer hints for this call. + * @param handler The auth mechanism to use. + */ + public void authenticate(ClientAuthHandler handler, CallOption... options) { + Preconditions.checkArgument(!authInterceptor.hasAuthHandler(), "Auth already completed."); + ClientAuthWrapper.doClientAuth(handler, CallOptions.wrapStub(asyncStub, options)); + authInterceptor.setAuthHandler(handler); + } + + /** + * Authenticates with a username and password. + * + * @param username the username. + * @param password the password. + * @return a CredentialCallOption containing a bearer token if the server emitted one, or + * empty if no bearer token was returned. This can be used in subsequent API calls. 
+ */ + public Optional<CredentialCallOption> authenticateBasicToken(String username, String password) { + final ClientIncomingAuthHeaderMiddleware.Factory clientAuthMiddleware = + new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler()); + middleware.add(clientAuthMiddleware); + handshake(new CredentialCallOption(new BasicAuthCredentialWriter(username, password))); + + return Optional.ofNullable(clientAuthMiddleware.getCredentialCallOption()); + } + + /** + * Executes the handshake against the Flight service. + * + * @param options RPC-layer hints for this call. + */ + public void handshake(CallOption... options) { + ClientHandshakeWrapper.doClientHandshake(CallOptions.wrapStub(asyncStub, options)); + } + + /** + * Create or append a descriptor with another stream. + * + * @param descriptor FlightDescriptor the descriptor for the data + * @param root VectorSchemaRoot the root containing data + * @param metadataListener A handler for metadata messages from the server. This will be passed buffers that will be + * freed after {@link StreamListener#onNext(Object)} is called! + * @param options RPC-layer hints for this call. + * @return ClientStreamListener an interface to control uploading data + */ + public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRoot root, + PutListener metadataListener, CallOption... options) { + return startPut(descriptor, root, new MapDictionaryProvider(), metadataListener, options); + } + + /** + * Create or append a descriptor with another stream. + * @param descriptor FlightDescriptor the descriptor for the data + * @param root VectorSchemaRoot the root containing data + * @param metadataListener A handler for metadata messages from the server. + * @param options RPC-layer hints for this call. + * @return ClientStreamListener an interface to control uploading data. + * {@link ClientStreamListener#start(VectorSchemaRoot, DictionaryProvider)} will already have been called. 
+ */ + public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRoot root, DictionaryProvider provider, + PutListener metadataListener, CallOption... options) { + Preconditions.checkNotNull(root, "root must not be null"); + Preconditions.checkNotNull(provider, "provider must not be null"); + final ClientStreamListener writer = startPut(descriptor, metadataListener, options); + writer.start(root, provider); + return writer; + } + + /** + * Create or append a descriptor with another stream. + * @param descriptor FlightDescriptor the descriptor for the data + * @param metadataListener A handler for metadata messages from the server. + * @param options RPC-layer hints for this call. + * @return ClientStreamListener an interface to control uploading data. + * {@link ClientStreamListener#start(VectorSchemaRoot, DictionaryProvider)} will NOT already have been called. + */ + public ClientStreamListener startPut(FlightDescriptor descriptor, PutListener metadataListener, + CallOption... options) { + Preconditions.checkNotNull(descriptor, "descriptor must not be null"); + Preconditions.checkNotNull(metadataListener, "metadataListener must not be null"); + final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, options).getCallOptions(); + + try { + final SetStreamObserver resultObserver = new SetStreamObserver(allocator, metadataListener); + ClientCallStreamObserver<ArrowMessage> observer = (ClientCallStreamObserver<ArrowMessage>) + ClientCalls.asyncBidiStreamingCall( + interceptedChannel.newCall(doPutDescriptor, callOptions), resultObserver); + return new PutObserver( + descriptor, observer, metadataListener::isCancelled, metadataListener::getResult); + } catch (StatusRuntimeException sre) { + throw StatusUtils.fromGrpcRuntimeException(sre); + } + } + + /** + * Get info on a stream. + * @param descriptor The descriptor for the stream. + * @param options RPC-layer hints for this call. 
+ */ + public FlightInfo getInfo(FlightDescriptor descriptor, CallOption... options) { + try { + return new FlightInfo(CallOptions.wrapStub(blockingStub, options).getFlightInfo(descriptor.toProtocol())); + } catch (URISyntaxException e) { + // We don't expect this will happen for conforming Flight implementations. For instance, a Java server + // itself wouldn't be able to construct an invalid Location. + throw new RuntimeException(e); + } catch (StatusRuntimeException sre) { + throw StatusUtils.fromGrpcRuntimeException(sre); + } + } + + /** + * Get schema for a stream. + * @param descriptor The descriptor for the stream. + * @param options RPC-layer hints for this call. + */ + public SchemaResult getSchema(FlightDescriptor descriptor, CallOption... options) { + return SchemaResult.fromProtocol(CallOptions.wrapStub(blockingStub, options).getSchema(descriptor.toProtocol())); + } + + /** + * Retrieve a stream from the server. + * @param ticket The ticket granting access to the data stream. + * @param options RPC-layer hints for this call. + */ + public FlightStream getStream(Ticket ticket, CallOption... 
options) { + final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, options).getCallOptions(); + ClientCall<Flight.Ticket, ArrowMessage> call = interceptedChannel.newCall(doGetDescriptor, callOptions); + FlightStream stream = new FlightStream( + allocator, + PENDING_REQUESTS, + (String message, Throwable cause) -> call.cancel(message, cause), + (count) -> call.request(count)); + + final StreamObserver<ArrowMessage> delegate = stream.asObserver(); + ClientResponseObserver<Flight.Ticket, ArrowMessage> clientResponseObserver = + new ClientResponseObserver<Flight.Ticket, ArrowMessage>() { + + @Override + public void beforeStart(ClientCallStreamObserver<org.apache.arrow.flight.impl.Flight.Ticket> requestStream) { + requestStream.disableAutoInboundFlowControl(); + } + + @Override + public void onNext(ArrowMessage value) { + delegate.onNext(value); + } + + @Override + public void onError(Throwable t) { + delegate.onError(StatusUtils.toGrpcException(t)); + } + + @Override + public void onCompleted() { + delegate.onCompleted(); + } + + }; + + ClientCalls.asyncServerStreamingCall(call, ticket.toProtocol(), clientResponseObserver); + return stream; + } + + /** + * Initiate a bidirectional data exchange with the server. + * + * @param descriptor A descriptor for the data stream. + * @param options RPC call options. + * @return A pair of a readable stream and a writable stream. + */ + public ExchangeReaderWriter doExchange(FlightDescriptor descriptor, CallOption... 
options) { + Preconditions.checkNotNull(descriptor, "descriptor must not be null"); + final io.grpc.CallOptions callOptions = CallOptions.wrapStub(asyncStub, options).getCallOptions(); + + try { + final ClientCall<ArrowMessage, ArrowMessage> call = interceptedChannel.newCall(doExchangeDescriptor, callOptions); + final FlightStream stream = new FlightStream(allocator, PENDING_REQUESTS, call::cancel, call::request); + final ClientCallStreamObserver<ArrowMessage> observer = (ClientCallStreamObserver<ArrowMessage>) + ClientCalls.asyncBidiStreamingCall(call, stream.asObserver()); + final ClientStreamListener writer = new PutObserver( + descriptor, observer, stream.cancelled::isDone, + () -> { + try { + stream.completed.get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw CallStatus.INTERNAL + .withDescription("Client error: interrupted while completing call") + .withCause(e) + .toRuntimeException(); + } catch (ExecutionException e) { + throw CallStatus.INTERNAL + .withDescription("Client error: internal while completing call") + .withCause(e) + .toRuntimeException(); + } + }); + // Send the descriptor to start. + try (final ArrowMessage message = new ArrowMessage(descriptor.toProtocol())) { + observer.onNext(message); + } catch (Exception e) { + throw CallStatus.INTERNAL + .withCause(e) + .withDescription("Could not write descriptor " + descriptor) + .toRuntimeException(); + } + return new ExchangeReaderWriter(stream, writer); + } catch (StatusRuntimeException sre) { + throw StatusUtils.fromGrpcRuntimeException(sre); + } + } + + /** A pair of a reader and a writer for a DoExchange call. */ + public static class ExchangeReaderWriter implements AutoCloseable { + private final FlightStream reader; + private final ClientStreamListener writer; + + ExchangeReaderWriter(FlightStream reader, ClientStreamListener writer) { + this.reader = reader; + this.writer = writer; + } + + /** Get the reader for the call. 
*/ + public FlightStream getReader() { + return reader; + } + + /** Get the writer for the call. */ + public ClientStreamListener getWriter() { + return writer; + } + + /** Shut down the streams in this call. */ + @Override + public void close() throws Exception { + reader.close(); + } + } + + /** + * A stream observer for Flight.PutResult + */ + private static class SetStreamObserver implements StreamObserver<Flight.PutResult> { + private final BufferAllocator allocator; + private final StreamListener<PutResult> listener; + + SetStreamObserver(BufferAllocator allocator, StreamListener<PutResult> listener) { + super(); + this.allocator = allocator; + this.listener = listener == null ? NoOpStreamListener.getInstance() : listener; + } + + @Override + public void onNext(Flight.PutResult value) { + try (final PutResult message = PutResult.fromProtocol(allocator, value)) { + listener.onNext(message); + } + } + + @Override + public void onError(Throwable t) { + listener.onError(StatusUtils.fromThrowable(t)); + } + + @Override + public void onCompleted() { + listener.onCompleted(); + } + } + + /** + * The implementation of a {@link ClientStreamListener} for writing data to a Flight server. + */ + static class PutObserver extends OutboundStreamListenerImpl implements ClientStreamListener { + private final BooleanSupplier isCancelled; + private final Runnable getResult; + + /** + * Create a new client stream listener. + * + * @param descriptor The descriptor for the stream. + * @param observer The write-side gRPC StreamObserver. + * @param isCancelled A flag to check if the call has been cancelled. + * @param getResult A flag that blocks until the overall call completes. 
+ */ + PutObserver(FlightDescriptor descriptor, ClientCallStreamObserver<ArrowMessage> observer, + BooleanSupplier isCancelled, Runnable getResult) { + super(descriptor, observer); + Preconditions.checkNotNull(descriptor, "descriptor must be provided"); + Preconditions.checkNotNull(isCancelled, "isCancelled must be provided"); + Preconditions.checkNotNull(getResult, "getResult must be provided"); + this.isCancelled = isCancelled; + this.getResult = getResult; + this.unloader = null; + } + + @Override + protected void waitUntilStreamReady() { + // Check isCancelled as well to avoid inadvertently blocking forever + // (so long as PutListener properly implements it) + while (!responseObserver.isReady() && !isCancelled.getAsBoolean()) { + /* busy wait */ + } + } + + @Override + public void getResult() { + getResult.run(); + } + } + + /** + * Interface for writers to an Arrow data stream. + */ + public interface ClientStreamListener extends OutboundStreamListener { + + /** + * Wait for the stream to finish on the server side. You must call this to be notified of any errors that may have + * happened during the upload. + */ + void getResult(); + } + + /** + * A handler for server-sent application metadata messages during a Flight DoPut operation. + * + * <p>Generally, instead of implementing this yourself, you should use {@link AsyncPutListener} or {@link + * SyncPutListener}. + */ + public interface PutListener extends StreamListener<PutResult> { + + /** + * Wait for the stream to finish on the server side. You must call this to be notified of any errors that may have + * happened during the upload. + */ + void getResult(); + + /** + * Called when a message from the server is received. + * + * @param val The application metadata. This buffer will be reclaimed once onNext returns; you must retain a + * reference to use it outside this method. + */ + @Override + void onNext(PutResult val); + + /** + * Check if the call has been cancelled. 
+ * + * <p>By default, this always returns false. Implementations should provide an appropriate implementation, as + * otherwise, a DoPut operation may inadvertently block forever. + */ + default boolean isCancelled() { + return false; + } + } + + /** + * Shut down this client. + */ + public void close() throws InterruptedException { + channel.shutdown().awaitTermination(5, TimeUnit.SECONDS); + allocator.close(); + } + + /** + * Create a builder for a Flight client. + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Create a builder for a Flight client. + * @param allocator The allocator to use for the client. + * @param location The location to connect to. + */ + public static Builder builder(BufferAllocator allocator, Location location) { + return new Builder(allocator, location); + } + + /** + * A builder for Flight clients. + */ + public static final class Builder { + private BufferAllocator allocator; + private Location location; + private boolean forceTls = false; + private int maxInboundMessageSize = FlightServer.MAX_GRPC_MESSAGE_SIZE; + private InputStream trustedCertificates = null; + private InputStream clientCertificate = null; + private InputStream clientKey = null; + private String overrideHostname = null; + private List<FlightClientMiddleware.Factory> middleware = new ArrayList<>(); + private boolean verifyServer = true; + + private Builder() { + } + + private Builder(BufferAllocator allocator, Location location) { + this.allocator = Preconditions.checkNotNull(allocator); + this.location = Preconditions.checkNotNull(location); + } + + /** + * Force the client to connect over TLS. + */ + public Builder useTls() { + this.forceTls = true; + return this; + } + + /** Override the hostname checked for TLS. Use with caution in production. */ + public Builder overrideHostname(final String hostname) { + this.overrideHostname = hostname; + return this; + } + + /** Set the maximum inbound message size. 
*/ + public Builder maxInboundMessageSize(int maxSize) { + Preconditions.checkArgument(maxSize > 0); + this.maxInboundMessageSize = maxSize; + return this; + } + + /** Set the trusted TLS certificates. */ + public Builder trustedCertificates(final InputStream stream) { + this.trustedCertificates = Preconditions.checkNotNull(stream); + return this; + } + + /** Set the trusted TLS certificates. */ + public Builder clientCertificate(final InputStream clientCertificate, final InputStream clientKey) { + Preconditions.checkNotNull(clientKey); + this.clientCertificate = Preconditions.checkNotNull(clientCertificate); + this.clientKey = Preconditions.checkNotNull(clientKey); + return this; + } + + public Builder allocator(BufferAllocator allocator) { + this.allocator = Preconditions.checkNotNull(allocator); + return this; + } + + public Builder location(Location location) { + this.location = Preconditions.checkNotNull(location); + return this; + } + + public Builder intercept(FlightClientMiddleware.Factory factory) { + middleware.add(factory); + return this; + } + + public Builder verifyServer(boolean verifyServer) { + this.verifyServer = verifyServer; + return this; + } + + /** + * Create the client from this builder. + */ + public FlightClient build() { + final NettyChannelBuilder builder; + + switch (location.getUri().getScheme()) { + case LocationSchemes.GRPC: + case LocationSchemes.GRPC_INSECURE: + case LocationSchemes.GRPC_TLS: { + builder = NettyChannelBuilder.forAddress(location.toSocketAddress()); + break; + } + case LocationSchemes.GRPC_DOMAIN_SOCKET: { + // The implementation is platform-specific, so we have to find the classes at runtime + builder = NettyChannelBuilder.forAddress(location.toSocketAddress()); + try { + try { + // Linux + builder.channelType( + (Class<? 
extends ServerChannel>) Class.forName("io.netty.channel.epoll.EpollDomainSocketChannel")); + final EventLoopGroup elg = (EventLoopGroup) Class.forName("io.netty.channel.epoll.EpollEventLoopGroup") + .newInstance(); + builder.eventLoopGroup(elg); + } catch (ClassNotFoundException e) { + // BSD + builder.channelType( + (Class<? extends ServerChannel>) Class.forName("io.netty.channel.kqueue.KQueueDomainSocketChannel")); + final EventLoopGroup elg = (EventLoopGroup) Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup") + .newInstance(); + builder.eventLoopGroup(elg); + } + } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) { + throw new UnsupportedOperationException( + "Could not find suitable Netty native transport implementation for domain socket address."); + } + break; + } + default: + throw new IllegalArgumentException("Scheme is not supported: " + location.getUri().getScheme()); + } + + if (this.forceTls || LocationSchemes.GRPC_TLS.equals(location.getUri().getScheme())) { + builder.useTransportSecurity(); + + final boolean hasTrustedCerts = this.trustedCertificates != null; + final boolean hasKeyCertPair = this.clientCertificate != null && this.clientKey != null; + if (!this.verifyServer && (hasTrustedCerts || hasKeyCertPair)) { + throw new IllegalArgumentException("FlightClient has been configured to disable server verification, " + + "but certificate options have been specified."); + } + + final SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient(); + + if (!this.verifyServer) { + sslContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE); + } else if (this.trustedCertificates != null || this.clientCertificate != null || this.clientKey != null) { + if (this.trustedCertificates != null) { + sslContextBuilder.trustManager(this.trustedCertificates); + } + if (this.clientCertificate != null && this.clientKey != null) { + sslContextBuilder.keyManager(this.clientCertificate, this.clientKey); + } + } + 
try { + builder.sslContext(sslContextBuilder.build()); + } catch (SSLException e) { + throw new RuntimeException(e); + } + + if (this.overrideHostname != null) { + builder.overrideAuthority(this.overrideHostname); + } + } else { + builder.usePlaintext(); + } + + builder + .maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS) + .maxInboundMessageSize(maxInboundMessageSize); + return new FlightClient(allocator, builder.build(), middleware); + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClientMiddleware.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClientMiddleware.java new file mode 100644 index 000000000..1528ca6c6 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClientMiddleware.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +/** + * Client-side middleware for Flight. + * + * <p>Middleware are instantiated per-call and should store state in the middleware instance. + */ +public interface FlightClientMiddleware { + /** + * A callback used before request headers are sent. The headers may be manipulated. 
/**
 * Client-side middleware for Flight.
 *
 * <p>Middleware are instantiated per-call and should store state in the middleware instance.
 */
public interface FlightClientMiddleware {
  /**
   * A callback used before request headers are sent. The headers may be manipulated.
   *
   * @param outgoingHeaders The headers about to be sent with the request.
   */
  void onBeforeSendingHeaders(CallHeaders outgoingHeaders);

  /**
   * A callback called after response headers are received. The headers may be manipulated.
   *
   * @param incomingHeaders The headers received with the response.
   */
  void onHeadersReceived(CallHeaders incomingHeaders);

  /**
   * A callback called after the call completes.
   *
   * @param status The status the call completed with.
   */
  void onCallCompleted(CallStatus status);

  /**
   * A factory for client middleware instances.
   */
  interface Factory {
    /**
     * Create a new middleware instance for the given call.
     *
     * @param info Information about the call being started.
     * @return The middleware instance to use for this call.
     * @throws FlightRuntimeException if the middleware wants to reject the call with the given status
     */
    FlightClientMiddleware onCallStarted(CallInfo info);
  }
}
/**
 * Constants relevant to flight implementations.
 */
public interface FlightConstants {

  // Name of the Flight service; presumably the fully-qualified gRPC service name (confirm against Flight.proto).
  String SERVICE = "arrow.flight.protocol.FlightService";

  // Middleware key typed for ServerHeaderMiddleware; the key string is that class's fully-qualified name.
  FlightServerMiddleware.Key<ServerHeaderMiddleware> HEADER_KEY =
      FlightServerMiddleware.Key.of("org.apache.arrow.flight.ServerHeaderMiddleware");
}
This can either be an opaque command that generates + * the data or a static "path" to the data. This is a POJO wrapper around the protobuf message with + * the same name. + */ +public class FlightDescriptor { + + private boolean isCmd; + private List<String> path; + private byte[] cmd; + + private FlightDescriptor(boolean isCmd, List<String> path, byte[] cmd) { + super(); + this.isCmd = isCmd; + this.path = path; + this.cmd = cmd; + } + + public static FlightDescriptor command(byte[] cmd) { + return new FlightDescriptor(true, null, cmd); + } + + public static FlightDescriptor path(Iterable<String> path) { + return new FlightDescriptor(false, ImmutableList.copyOf(path), null); + } + + public static FlightDescriptor path(String...path) { + return new FlightDescriptor(false, ImmutableList.copyOf(path), null); + } + + FlightDescriptor(Flight.FlightDescriptor descriptor) { + if (descriptor.getType() == DescriptorType.CMD) { + isCmd = true; + cmd = descriptor.getCmd().toByteArray(); + } else if (descriptor.getType() == DescriptorType.PATH) { + isCmd = false; + path = descriptor.getPathList(); + } else { + throw new UnsupportedOperationException(); + } + } + + public boolean isCommand() { + return isCmd; + } + + public List<String> getPath() { + Preconditions.checkArgument(!isCmd); + return path; + } + + public byte[] getCommand() { + Preconditions.checkArgument(isCmd); + return cmd; + } + + Flight.FlightDescriptor toProtocol() { + Flight.FlightDescriptor.Builder b = Flight.FlightDescriptor.newBuilder(); + + if (isCmd) { + return b.setType(DescriptorType.CMD).setCmd(ByteString.copyFrom(cmd)).build(); + } + return b.setType(DescriptorType.PATH).addAllPath(path).build(); + } + + /** + * Get the serialized form of this protocol message. + * + * <p>Intended to help interoperability by allowing non-Flight services to still return Flight types. 
+ */ + public ByteBuffer serialize() { + return ByteBuffer.wrap(toProtocol().toByteArray()); + } + + /** + * Parse the serialized form of this protocol message. + * + * <p>Intended to help interoperability by allowing Flight clients to obtain stream info from non-Flight services. + * + * @param serialized The serialized form of the FlightDescriptor, as returned by {@link #serialize()}. + * @return The deserialized FlightDescriptor. + * @throws IOException if the serialized form is invalid. + */ + public static FlightDescriptor deserialize(ByteBuffer serialized) throws IOException { + return new FlightDescriptor(Flight.FlightDescriptor.parseFrom(serialized)); + } + + @Override + public String toString() { + if (isCmd) { + return toHex(cmd); + } else { + return Joiner.on('.').join(path); + } + } + + private String toHex(byte[] bytes) { + StringBuilder sb = new StringBuilder(); + for (byte b : bytes) { + sb.append(String.format("%02X ", b)); + } + return sb.toString(); + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((cmd == null) ? 0 : Arrays.hashCode(cmd)); + result = prime * result + (isCmd ? 1231 : 1237); + result = prime * result + ((path == null) ? 
0 : path.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + FlightDescriptor other = (FlightDescriptor) obj; + if (cmd == null) { + if (other.cmd != null) { + return false; + } + } else if (!Arrays.equals(cmd, other.cmd)) { + return false; + } + if (isCmd != other.isCmd) { + return false; + } + if (path == null) { + if (other.path != null) { + return false; + } + } else if (!path.equals(other.path)) { + return false; + } + return true; + } + + +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightEndpoint.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightEndpoint.java new file mode 100644 index 000000000..2e46b694d --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightEndpoint.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight; + +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import org.apache.arrow.flight.impl.Flight; + +import com.google.common.collect.ImmutableList; + +/** + * POJO to convert to/from the underlying protobuf FlightEndpoint. + */ +public class FlightEndpoint { + private List<Location> locations; + private Ticket ticket; + + /** + * Constructs a new instance. + * + * @param ticket A ticket that describe the key of a data stream. + * @param locations The possible locations the stream can be retrieved from. + */ + public FlightEndpoint(Ticket ticket, Location... locations) { + super(); + Objects.requireNonNull(ticket); + this.locations = ImmutableList.copyOf(locations); + this.ticket = ticket; + } + + /** + * Constructs from the protocol buffer representation. + */ + FlightEndpoint(Flight.FlightEndpoint flt) throws URISyntaxException { + locations = new ArrayList<>(); + for (final Flight.Location location : flt.getLocationList()) { + locations.add(new Location(location.getUri())); + } + ticket = new Ticket(flt.getTicket()); + } + + public List<Location> getLocations() { + return locations; + } + + public Ticket getTicket() { + return ticket; + } + + /** + * Converts to the protocol buffer representation. 
+ */ + Flight.FlightEndpoint toProtocol() { + Flight.FlightEndpoint.Builder b = Flight.FlightEndpoint.newBuilder() + .setTicket(ticket.toProtocol()); + + for (Location l : locations) { + b.addLocation(l.toProtocol()); + } + return b.build(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + FlightEndpoint that = (FlightEndpoint) o; + return locations.equals(that.locations) && + ticket.equals(that.ticket); + } + + @Override + public int hashCode() { + return Objects.hash(locations, ticket); + } + + @Override + public String toString() { + return "FlightEndpoint{" + + "locations=" + locations + + ", ticket=" + ticket + + '}'; + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java new file mode 100644 index 000000000..e57b311c2 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightInfo.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +import org.apache.arrow.flight.impl.Flight; +import org.apache.arrow.vector.ipc.ReadChannel; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.IpcOption; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.validate.MetadataV4UnionChecker; + +import com.fasterxml.jackson.databind.util.ByteBufferBackedInputStream; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ByteString; + +/** + * A POJO representation of a FlightInfo, metadata associated with a set of data records. + */ +public class FlightInfo { + private final Schema schema; + private final FlightDescriptor descriptor; + private final List<FlightEndpoint> endpoints; + private final long bytes; + private final long records; + private final IpcOption option; + + /** + * Constructs a new instance. + * + * @param schema The schema of the Flight + * @param descriptor An identifier for the Flight. + * @param endpoints A list of endpoints that have the flight available. + * @param bytes The number of bytes in the flight + * @param records The number of records in the flight. + */ + public FlightInfo(Schema schema, FlightDescriptor descriptor, List<FlightEndpoint> endpoints, long bytes, + long records) { + this(schema, descriptor, endpoints, bytes, records, IpcOption.DEFAULT); + } + + /** + * Constructs a new instance. + * + * @param schema The schema of the Flight + * @param descriptor An identifier for the Flight. + * @param endpoints A list of endpoints that have the flight available. 
+ * @param bytes The number of bytes in the flight + * @param records The number of records in the flight. + * @param option IPC write options. + */ + public FlightInfo(Schema schema, FlightDescriptor descriptor, List<FlightEndpoint> endpoints, long bytes, + long records, IpcOption option) { + Objects.requireNonNull(schema); + Objects.requireNonNull(descriptor); + Objects.requireNonNull(endpoints); + MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option.metadataVersion); + this.schema = schema; + this.descriptor = descriptor; + this.endpoints = endpoints; + this.bytes = bytes; + this.records = records; + this.option = option; + } + + /** + * Constructs from the protocol buffer representation. + */ + FlightInfo(Flight.FlightInfo pbFlightInfo) throws URISyntaxException { + try { + final ByteBuffer schemaBuf = pbFlightInfo.getSchema().asReadOnlyByteBuffer(); + schema = pbFlightInfo.getSchema().size() > 0 ? + MessageSerializer.deserializeSchema( + new ReadChannel(Channels.newChannel(new ByteBufferBackedInputStream(schemaBuf)))) + : new Schema(ImmutableList.of()); + } catch (IOException e) { + throw new RuntimeException(e); + } + descriptor = new FlightDescriptor(pbFlightInfo.getFlightDescriptor()); + endpoints = new ArrayList<>(); + for (final Flight.FlightEndpoint endpoint : pbFlightInfo.getEndpointList()) { + endpoints.add(new FlightEndpoint(endpoint)); + } + bytes = pbFlightInfo.getTotalBytes(); + records = pbFlightInfo.getTotalRecords(); + option = IpcOption.DEFAULT; + } + + public Schema getSchema() { + return schema; + } + + public long getBytes() { + return bytes; + } + + public long getRecords() { + return records; + } + + public FlightDescriptor getDescriptor() { + return descriptor; + } + + public List<FlightEndpoint> getEndpoints() { + return endpoints; + } + + /** + * Converts to the protocol buffer representation. 
+ */ + Flight.FlightInfo toProtocol() { + // Encode schema in a Message payload + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try { + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(baos)), schema, option); + } catch (IOException e) { + throw new RuntimeException(e); + } + return Flight.FlightInfo.newBuilder() + .addAllEndpoint(endpoints.stream().map(t -> t.toProtocol()).collect(Collectors.toList())) + .setSchema(ByteString.copyFrom(baos.toByteArray())) + .setFlightDescriptor(descriptor.toProtocol()) + .setTotalBytes(FlightInfo.this.bytes) + .setTotalRecords(records) + .build(); + } + + /** + * Get the serialized form of this protocol message. + * + * <p>Intended to help interoperability by allowing non-Flight services to still return Flight types. + */ + public ByteBuffer serialize() { + return ByteBuffer.wrap(toProtocol().toByteArray()); + } + + /** + * Parse the serialized form of this protocol message. + * + * <p>Intended to help interoperability by allowing Flight clients to obtain stream info from non-Flight services. + * + * @param serialized The serialized form of the FlightInfo, as returned by {@link #serialize()}. + * @return The deserialized FlightInfo. + * @throws IOException if the serialized form is invalid. + * @throws URISyntaxException if the serialized form contains an unsupported URI format. 
+ */ + public static FlightInfo deserialize(ByteBuffer serialized) throws IOException, URISyntaxException { + return new FlightInfo(Flight.FlightInfo.parseFrom(serialized)); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + FlightInfo that = (FlightInfo) o; + return bytes == that.bytes && + records == that.records && + schema.equals(that.schema) && + descriptor.equals(that.descriptor) && + endpoints.equals(that.endpoints); + } + + @Override + public int hashCode() { + return Objects.hash(schema, descriptor, endpoints, bytes, records); + } + + @Override + public String toString() { + return "FlightInfo{" + + "schema=" + schema + + ", descriptor=" + descriptor + + ", endpoints=" + endpoints + + ", bytes=" + bytes + + ", records=" + records + + '}'; + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightMethod.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightMethod.java new file mode 100644 index 000000000..5d2915bb6 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightMethod.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import org.apache.arrow.flight.impl.FlightServiceGrpc; + +/** + * All the RPC methods available in Flight. + */ +public enum FlightMethod { + HANDSHAKE, + LIST_FLIGHTS, + GET_FLIGHT_INFO, + GET_SCHEMA, + DO_GET, + DO_PUT, + DO_ACTION, + LIST_ACTIONS, + DO_EXCHANGE, + ; + + /** + * Convert a method name string into a {@link FlightMethod}. + * + * @throws IllegalArgumentException if the method name is not valid. + */ + public static FlightMethod fromProtocol(final String methodName) { + if (FlightServiceGrpc.getHandshakeMethod().getFullMethodName().equals(methodName)) { + return HANDSHAKE; + } else if (FlightServiceGrpc.getListFlightsMethod().getFullMethodName().equals(methodName)) { + return LIST_FLIGHTS; + } else if (FlightServiceGrpc.getGetFlightInfoMethod().getFullMethodName().equals(methodName)) { + return GET_FLIGHT_INFO; + } else if (FlightServiceGrpc.getGetSchemaMethod().getFullMethodName().equals(methodName)) { + return GET_SCHEMA; + } else if (FlightServiceGrpc.getDoGetMethod().getFullMethodName().equals(methodName)) { + return DO_GET; + } else if (FlightServiceGrpc.getDoPutMethod().getFullMethodName().equals(methodName)) { + return DO_PUT; + } else if (FlightServiceGrpc.getDoActionMethod().getFullMethodName().equals(methodName)) { + return DO_ACTION; + } else if (FlightServiceGrpc.getListActionsMethod().getFullMethodName().equals(methodName)) { + return LIST_ACTIONS; + } else if (FlightServiceGrpc.getDoExchangeMethod().getFullMethodName().equals(methodName)) { + return DO_EXCHANGE; + } + throw new IllegalArgumentException("Not a Flight method name in gRPC: " + methodName); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java new file mode 100644 index 
000000000..5e5b26505 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightProducer.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.Map; + +/** + * API to Implement an Arrow Flight producer. + */ +public interface FlightProducer { + + /** + * Return data for a stream. + * + * @param context Per-call context. + * @param ticket The application-defined ticket identifying this stream. + * @param listener An interface for sending data back to the client. + */ + void getStream(CallContext context, Ticket ticket, ServerStreamListener listener); + + /** + * List available data streams on this service. + * + * @param context Per-call context. + * @param criteria Application-defined criteria for filtering streams. + * @param listener An interface for sending data back to the client. + */ + void listFlights(CallContext context, Criteria criteria, + StreamListener<FlightInfo> listener); + + /** + * Get information about a particular data stream. + * + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return Metadata about the stream. 
+ */ + FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor); + + /** + * Get schema for a particular data stream. + * + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return Schema for the stream. + */ + default SchemaResult getSchema(CallContext context, FlightDescriptor descriptor) { + FlightInfo info = getFlightInfo(context, descriptor); + return new SchemaResult(info.getSchema()); + } + + + /** + * Accept uploaded data for a particular stream. + * + * @param context Per-call context. + * @param flightStream The data stream being uploaded. + */ + Runnable acceptPut(CallContext context, + FlightStream flightStream, StreamListener<PutResult> ackStream); + + default void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) { + throw CallStatus.UNIMPLEMENTED.withDescription("DoExchange is unimplemented").toRuntimeException(); + } + + /** + * Generic handler for application-defined RPCs. + * + * @param context Per-call context. + * @param action Client-supplied parameters. + * @param listener A stream of responses. + */ + void doAction(CallContext context, Action action, + StreamListener<Result> listener); + + /** + * List available application-defined RPCs. + * @param context Per-call context. + * @param listener An interface for sending data back to the client. + */ + void listActions(CallContext context, StreamListener<ActionType> listener); + + /** + * An interface for sending Arrow data back to a client. + */ + interface ServerStreamListener extends OutboundStreamListener { + + /** + * Check whether the call has been cancelled. If so, stop sending data. + */ + boolean isCancelled(); + + /** + * Set a callback for when the client cancels a call, i.e. {@link #isCancelled()} has become true. 
+ * + * <p>Note that this callback may only be called some time after {@link #isCancelled()} becomes true, and may never + * be called if all executor threads on the server are busy, or the RPC method body is implemented in a blocking + * fashion. + */ + void setOnCancelHandler(Runnable handler); + } + + /** + * Callbacks for pushing objects to a receiver. + * + * @param <T> Type of the values in the stream. + */ + interface StreamListener<T> { + + /** + * Send the next value to the client. + */ + void onNext(T val); + + /** + * Indicate an error to the client. + * + * <p>Terminates the stream; do not call {@link #onCompleted()}. + */ + void onError(Throwable t); + + /** + * Indicate that the transmission is finished. + */ + void onCompleted(); + + } + + /** + * Call-specific context. + */ + interface CallContext { + /** The identity of the authenticated peer. May be the empty string if unknown. */ + String peerIdentity(); + + /** Whether the call has been cancelled by the client. */ + boolean isCancelled(); + + /** + * Get the middleware instance of the given type for this call. + * + * <p>Returns null if not found. + */ + <T extends FlightServerMiddleware> T getMiddleware(FlightServerMiddleware.Key<T> key); + + /** Get an immutable map of middleware for this call. */ + Map<FlightServerMiddleware.Key<?>, FlightServerMiddleware> getMiddleware(); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightRuntimeException.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightRuntimeException.java new file mode 100644 index 000000000..76d3349a2 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightRuntimeException.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +/** + * An exception raised from a Flight RPC. + * + * <p>In service implementations, raising an instance of this exception will provide clients with a more detailed + * message and error code. + */ +public class FlightRuntimeException extends RuntimeException { + private final CallStatus status; + + /** + * Create a new exception from the given status. + */ + FlightRuntimeException(CallStatus status) { + super(status.description(), status.cause()); + this.status = status; + } + + public CallStatus status() { + return status; + } + + @Override + public String toString() { + String s = getClass().getName(); + return String.format("%s: %s: %s", s, status.code(), status.description()); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java new file mode 100644 index 000000000..d59480bfb --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java @@ -0,0 +1,399 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; + +import org.apache.arrow.flight.auth.ServerAuthHandler; +import org.apache.arrow.flight.auth.ServerAuthInterceptor; +import org.apache.arrow.flight.auth2.Auth2Constants; +import org.apache.arrow.flight.auth2.CallHeaderAuthenticator; +import org.apache.arrow.flight.auth2.ServerCallHeaderAuthMiddleware; +import org.apache.arrow.flight.grpc.ServerInterceptorAdapter; +import org.apache.arrow.flight.grpc.ServerInterceptorAdapter.KeyFactory; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.util.VisibleForTesting; + +import com.google.common.util.concurrent.ThreadFactoryBuilder; + +import io.grpc.Server; +import io.grpc.ServerInterceptors; +import io.grpc.netty.NettyServerBuilder; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.ServerChannel; + +/** + * Generic server of flight data that 
is customized via construction with delegate classes for the + * actual logic. The server currently uses GRPC as its transport mechanism. + */ +public class FlightServer implements AutoCloseable { + + private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(FlightServer.class); + + private final Location location; + private final Server server; + // The executor used by the gRPC server. We don't use it here, but we do need to clean it up with the server. + // May be null, if a user-supplied executor was provided (as we do not want to clean that up) + @VisibleForTesting + final ExecutorService grpcExecutor; + + /** The maximum size of an individual gRPC message. This effectively disables the limit. */ + static final int MAX_GRPC_MESSAGE_SIZE = Integer.MAX_VALUE; + + /** Create a new instance from a gRPC server. For internal use only. */ + private FlightServer(Location location, Server server, ExecutorService grpcExecutor) { + this.location = location; + this.server = server; + this.grpcExecutor = grpcExecutor; + } + + /** Start the server. */ + public FlightServer start() throws IOException { + server.start(); + return this; + } + + /** Get the port the server is running on (if applicable). */ + public int getPort() { + return server.getPort(); + } + + /** Get the location for this server. */ + public Location getLocation() { + if (location.getUri().getPort() == 0) { + // If the server was bound to port 0, replace the port in the location with the real port. + final URI uri = location.getUri(); + try { + return new Location(new URI(uri.getScheme(), uri.getUserInfo(), uri.getHost(), getPort(), + uri.getPath(), uri.getQuery(), uri.getFragment())); + } catch (URISyntaxException e) { + // We don't expect this to happen + throw new RuntimeException(e); + } + } + return location; + } + + /** Block until the server shuts down. 
*/ + public void awaitTermination() throws InterruptedException { + server.awaitTermination(); + } + + /** Request that the server shut down. */ + public void shutdown() { + server.shutdown(); + if (grpcExecutor != null) { + grpcExecutor.shutdown(); + } + } + + /** + * Wait for the server to shut down with a timeout. + * @return true if the server shut down successfully. + */ + public boolean awaitTermination(final long timeout, final TimeUnit unit) throws InterruptedException { + return server.awaitTermination(timeout, unit); + } + + /** Shutdown the server, waits for up to 6 seconds for successful shutdown before returning. */ + public void close() throws InterruptedException { + shutdown(); + final boolean terminated = awaitTermination(3000, TimeUnit.MILLISECONDS); + if (terminated) { + logger.debug("Server was terminated within 3s"); + return; + } + + // get more aggressive in termination. + server.shutdownNow(); + + int count = 0; + while (!server.isTerminated() & count < 30) { + count++; + logger.debug("Waiting for termination"); + Thread.sleep(100); + } + + if (!server.isTerminated()) { + logger.warn("Couldn't shutdown server, resources likely will be leaked."); + } + } + + /** Create a builder for a Flight server. */ + public static Builder builder() { + return new Builder(); + } + + /** Create a builder for a Flight server. */ + public static Builder builder(BufferAllocator allocator, Location location, FlightProducer producer) { + return new Builder(allocator, location, producer); + } + + /** A builder for Flight servers. 
*/ + public static final class Builder { + private BufferAllocator allocator; + private Location location; + private FlightProducer producer; + private final Map<String, Object> builderOptions; + private ServerAuthHandler authHandler = ServerAuthHandler.NO_OP; + private CallHeaderAuthenticator headerAuthenticator = CallHeaderAuthenticator.NO_OP; + private ExecutorService executor = null; + private int maxInboundMessageSize = MAX_GRPC_MESSAGE_SIZE; + private InputStream certChain; + private InputStream key; + private final List<KeyFactory<?>> interceptors; + // Keep track of inserted interceptors + private final Set<String> interceptorKeys; + + Builder() { + builderOptions = new HashMap<>(); + interceptors = new ArrayList<>(); + interceptorKeys = new HashSet<>(); + } + + Builder(BufferAllocator allocator, Location location, FlightProducer producer) { + this(); + this.allocator = Preconditions.checkNotNull(allocator); + this.location = Preconditions.checkNotNull(location); + this.producer = Preconditions.checkNotNull(producer); + } + + /** Create the server for this builder. */ + public FlightServer build() { + // Add the auth middleware if applicable. + if (headerAuthenticator != CallHeaderAuthenticator.NO_OP) { + this.middleware(FlightServerMiddleware.Key.of(Auth2Constants.AUTHORIZATION_HEADER), + new ServerCallHeaderAuthMiddleware.Factory(headerAuthenticator)); + } + + this.middleware(FlightConstants.HEADER_KEY, new ServerHeaderMiddleware.Factory()); + + final NettyServerBuilder builder; + switch (location.getUri().getScheme()) { + case LocationSchemes.GRPC_DOMAIN_SOCKET: { + // The implementation is platform-specific, so we have to find the classes at runtime + builder = NettyServerBuilder.forAddress(location.toSocketAddress()); + try { + try { + // Linux + builder.channelType( + (Class<? 
extends ServerChannel>) Class + .forName("io.netty.channel.epoll.EpollServerDomainSocketChannel")); + final EventLoopGroup elg = (EventLoopGroup) Class.forName("io.netty.channel.epoll.EpollEventLoopGroup") + .newInstance(); + builder.bossEventLoopGroup(elg).workerEventLoopGroup(elg); + } catch (ClassNotFoundException e) { + // BSD + builder.channelType( + (Class<? extends ServerChannel>) Class + .forName("io.netty.channel.kqueue.KQueueServerDomainSocketChannel")); + final EventLoopGroup elg = (EventLoopGroup) Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup") + .newInstance(); + builder.bossEventLoopGroup(elg).workerEventLoopGroup(elg); + } + } catch (ClassNotFoundException | InstantiationException | IllegalAccessException e) { + throw new UnsupportedOperationException( + "Could not find suitable Netty native transport implementation for domain socket address."); + } + break; + } + case LocationSchemes.GRPC: + case LocationSchemes.GRPC_INSECURE: { + builder = NettyServerBuilder.forAddress(location.toSocketAddress()); + break; + } + case LocationSchemes.GRPC_TLS: { + if (certChain == null) { + throw new IllegalArgumentException("Must provide a certificate and key to serve gRPC over TLS"); + } + builder = NettyServerBuilder.forAddress(location.toSocketAddress()); + break; + } + default: + throw new IllegalArgumentException("Scheme is not supported: " + location.getUri().getScheme()); + } + + if (certChain != null) { + builder.useTransportSecurity(certChain, key); + } + + // Share one executor between the gRPC service, DoPut, and Handshake + final ExecutorService exec; + // We only want to have FlightServer close the gRPC executor if we created it here. We should not close + // user-supplied executors. 
+ final ExecutorService grpcExecutor; + if (executor != null) { + exec = executor; + grpcExecutor = null; + } else { + exec = Executors.newCachedThreadPool( + // Name threads for better debuggability + new ThreadFactoryBuilder().setNameFormat("flight-server-default-executor-%d").build()); + grpcExecutor = exec; + } + final FlightBindingService flightService = new FlightBindingService(allocator, producer, authHandler, exec); + builder + .executor(exec) + .maxInboundMessageSize(maxInboundMessageSize) + .addService( + ServerInterceptors.intercept( + flightService, + new ServerAuthInterceptor(authHandler))); + + // Allow hooking into the gRPC builder. This is not guaranteed to be available on all Arrow versions or + // Flight implementations. + builderOptions.computeIfPresent("grpc.builderConsumer", (key, builderConsumer) -> { + final Consumer<NettyServerBuilder> consumer = (Consumer<NettyServerBuilder>) builderConsumer; + consumer.accept(builder); + return null; + }); + + // Allow explicitly setting some Netty-specific options + builderOptions.computeIfPresent("netty.channelType", (key, channelType) -> { + builder.channelType((Class<? extends ServerChannel>) channelType); + return null; + }); + builderOptions.computeIfPresent("netty.bossEventLoopGroup", (key, elg) -> { + builder.bossEventLoopGroup((EventLoopGroup) elg); + return null; + }); + builderOptions.computeIfPresent("netty.workerEventLoopGroup", (key, elg) -> { + builder.workerEventLoopGroup((EventLoopGroup) elg); + return null; + }); + + builder.intercept(new ServerInterceptorAdapter(interceptors)); + return new FlightServer(location, builder.build(), grpcExecutor); + } + + /** + * Set the maximum size of a message. Defaults to "unlimited", depending on the underlying transport. + */ + public Builder maxInboundMessageSize(int maxMessageSize) { + this.maxInboundMessageSize = maxMessageSize; + return this; + } + + /** + * Enable TLS on the server. + * @param certChain The certificate chain to use. 
+ * @param key The private key to use. + */ + public Builder useTls(final File certChain, final File key) throws IOException { + this.certChain = new FileInputStream(certChain); + this.key = new FileInputStream(key); + return this; + } + + /** + * Enable TLS on the server. + * @param certChain The certificate chain to use. + * @param key The private key to use. + */ + public Builder useTls(final InputStream certChain, final InputStream key) { + this.certChain = certChain; + this.key = key; + return this; + } + + /** + * Set the executor used by the server. + * + * <p>Flight will NOT take ownership of the executor. The application must clean it up if one is provided. (If not + * provided, Flight will use a default executor which it will clean up.) + */ + public Builder executor(ExecutorService executor) { + this.executor = executor; + return this; + } + + /** + * Set the authentication handler. + */ + public Builder authHandler(ServerAuthHandler authHandler) { + this.authHandler = authHandler; + return this; + } + + /** + * Set the header-based authentication mechanism. + */ + public Builder headerAuthenticator(CallHeaderAuthenticator headerAuthenticator) { + this.headerAuthenticator = headerAuthenticator; + return this; + } + + /** + * Provide a transport-specific option. Not guaranteed to have any effect. + */ + public Builder transportHint(final String key, Object option) { + builderOptions.put(key, option); + return this; + } + + /** + * Add a Flight middleware component to inspect and modify requests to this service. + * + * @param key An identifier for this middleware component. Service implementations can retrieve the middleware + * instance for the current call using {@link org.apache.arrow.flight.FlightProducer.CallContext}. + * @param factory A factory for the middleware. + * @param <T> The middleware type. 
+ * @throws IllegalArgumentException if the key already exists + */ + public <T extends FlightServerMiddleware> Builder middleware(final FlightServerMiddleware.Key<T> key, + final FlightServerMiddleware.Factory<T> factory) { + if (interceptorKeys.contains(key.key)) { + throw new IllegalArgumentException("Key already exists: " + key.key); + } + interceptors.add(new KeyFactory<>(key, factory)); + interceptorKeys.add(key.key); + return this; + } + + public Builder allocator(BufferAllocator allocator) { + this.allocator = Preconditions.checkNotNull(allocator); + return this; + } + + public Builder location(Location location) { + this.location = Preconditions.checkNotNull(location); + return this; + } + + public Builder producer(FlightProducer producer) { + this.producer = Preconditions.checkNotNull(producer); + return this; + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServerMiddleware.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServerMiddleware.java new file mode 100644 index 000000000..9bc8bbfe7 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServerMiddleware.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.Objects; + +/** + * Server-side middleware for Flight calls. + * + * <p>Middleware are instantiated per-call. + * + * <p>Methods are not guaranteed to be called on any particular thread, relative to the thread that Flight requests are + * executed on. Do not depend on thread-local storage; instead, use state on the middleware instance. Service + * implementations may communicate with middleware implementations through + * {@link org.apache.arrow.flight.FlightProducer.CallContext#getMiddleware(Key)}. Methods on the middleware instance + * are non-reentrant, that is, a particular RPC will not make multiple concurrent calls to methods on a single + * middleware instance. However, methods on the factory instance are expected to be thread-safe, and if the factory + * instance returns the same middleware object more than once, then that middleware object must be thread-safe. + */ +public interface FlightServerMiddleware { + + /** + * A factory for Flight server middleware. + * @param <T> The middleware type. + */ + interface Factory<T extends FlightServerMiddleware> { + /** + * A callback for when the call starts. + * + * @param info Details about the call. + * @param incomingHeaders A mutable set of request headers. + * @param context Context about the current request. + * + * @throws FlightRuntimeException if the middleware wants to reject the call with the given status + */ + T onCallStarted(CallInfo info, CallHeaders incomingHeaders, RequestContext context); + } + + /** + * A key for Flight server middleware. On a server, middleware instances are identified by this key. + * + * <p>Keys use reference equality, so instances should be shared. + * + * @param <T> The middleware class stored in this key. This provides a compile-time check when retrieving instances. 
+ */ + class Key<T extends FlightServerMiddleware> { + final String key; + + Key(String key) { + this.key = Objects.requireNonNull(key, "Key must not be null."); + } + + /** + * Create a new key for the given type. + */ + public static <T extends FlightServerMiddleware> Key<T> of(String key) { + return new Key<>(key); + } + } + + /** + * Callback for when the underlying transport is about to send response headers. + * + * @param outgoingHeaders A mutable set of response headers. These can be manipulated to send different headers to the + * client. + */ + void onBeforeSendingHeaders(CallHeaders outgoingHeaders); + + /** + * Callback for when the underlying transport has completed a call. + * @param status Whether the call completed successfully or not. + */ + void onCallCompleted(CallStatus status); + + /** + * Callback for when an RPC method implementation throws an uncaught exception. + * + * <p>May be called multiple times, and may be called before or after {@link #onCallCompleted(CallStatus)}. + * Generally, an uncaught exception will end the call with a error {@link CallStatus}, and will be reported to {@link + * #onCallCompleted(CallStatus)}, but not necessarily this method. + * + * @param err The exception that was thrown. + */ + void onCallErrored(Throwable err); +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java new file mode 100644 index 000000000..4fb0dea2c --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java @@ -0,0 +1,427 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.function.BooleanSupplier; +import java.util.function.Consumer; + +import org.apache.arrow.flight.FlightProducer.ServerStreamListener; +import org.apache.arrow.flight.FlightServerMiddleware.Key; +import org.apache.arrow.flight.auth.AuthConstants; +import org.apache.arrow.flight.auth.ServerAuthHandler; +import org.apache.arrow.flight.auth.ServerAuthWrapper; +import org.apache.arrow.flight.auth2.Auth2Constants; +import org.apache.arrow.flight.grpc.ContextPropagatingExecutorService; +import org.apache.arrow.flight.grpc.RequestContextAdapter; +import org.apache.arrow.flight.grpc.ServerInterceptorAdapter; +import org.apache.arrow.flight.grpc.StatusUtils; +import org.apache.arrow.flight.impl.Flight; +import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceImplBase; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.base.Strings; + +import io.grpc.stub.ServerCallStreamObserver; +import io.grpc.stub.StreamObserver; + +/** + * GRPC service implementation for a flight server. 
+ */ +class FlightService extends FlightServiceImplBase { + + private static final Logger logger = LoggerFactory.getLogger(FlightService.class); + private static final int PENDING_REQUESTS = 5; + + private final BufferAllocator allocator; + private final FlightProducer producer; + private final ServerAuthHandler authHandler; + private final ExecutorService executors; + + FlightService(BufferAllocator allocator, FlightProducer producer, ServerAuthHandler authHandler, + ExecutorService executors) { + this.allocator = allocator; + this.producer = producer; + this.authHandler = authHandler; + this.executors = new ContextPropagatingExecutorService(executors); + } + + private CallContext makeContext(ServerCallStreamObserver<?> responseObserver) { + // Try to get the peer identity from middleware first (using the auth2 interfaces). + final RequestContext context = RequestContextAdapter.REQUEST_CONTEXT_KEY.get(); + String peerIdentity = null; + if (context != null) { + peerIdentity = context.get(Auth2Constants.PEER_IDENTITY_KEY); + } + + if (Strings.isNullOrEmpty(peerIdentity)) { + // Try the legacy auth interface, which defaults to empty string. + peerIdentity = AuthConstants.PEER_IDENTITY_KEY.get(); + } + + return new CallContext(peerIdentity, responseObserver::isCancelled); + } + + @Override + public StreamObserver<Flight.HandshakeRequest> handshake(StreamObserver<Flight.HandshakeResponse> responseObserver) { + // This method is not meaningful with the auth2 interfaces. Authentication would already + // have happened by header/middleware with the auth2 classes. 
+ return ServerAuthWrapper.wrapHandshake(authHandler, responseObserver, executors); + } + + @Override + public void listFlights(Flight.Criteria criteria, StreamObserver<Flight.FlightInfo> responseObserver) { + final StreamPipe<FlightInfo, Flight.FlightInfo> listener = StreamPipe + .wrap(responseObserver, FlightInfo::toProtocol, this::handleExceptionWithMiddleware); + try { + final CallContext context = makeContext((ServerCallStreamObserver<?>) responseObserver); + producer.listFlights(context, new Criteria(criteria), listener); + } catch (Exception ex) { + listener.onError(ex); + } + // Do NOT call StreamPipe#onCompleted, as the FlightProducer implementation may be asynchronous + } + + public void doGetCustom(Flight.Ticket ticket, StreamObserver<ArrowMessage> responseObserverSimple) { + final ServerCallStreamObserver<ArrowMessage> responseObserver = + (ServerCallStreamObserver<ArrowMessage>) responseObserverSimple; + + final GetListener listener = new GetListener(responseObserver, this::handleExceptionWithMiddleware); + try { + producer.getStream(makeContext(responseObserver), new Ticket(ticket), listener); + } catch (Exception ex) { + listener.error(ex); + } + // Do NOT call GetListener#completed, as the implementation of getStream may be asynchronous + } + + @Override + public void doAction(Flight.Action request, StreamObserver<Flight.Result> responseObserver) { + final StreamPipe<Result, Flight.Result> listener = StreamPipe + .wrap(responseObserver, Result::toProtocol, this::handleExceptionWithMiddleware); + try { + final CallContext context = makeContext((ServerCallStreamObserver<?>) responseObserver); + producer.doAction(context, new Action(request), listener); + } catch (Exception ex) { + listener.onError(ex); + } + // Do NOT call StreamPipe#onCompleted, as the FlightProducer implementation may be asynchronous + } + + @Override + public void listActions(Flight.Empty request, StreamObserver<Flight.ActionType> responseObserver) { + final 
StreamPipe<org.apache.arrow.flight.ActionType, Flight.ActionType> listener = StreamPipe + .wrap(responseObserver, ActionType::toProtocol, this::handleExceptionWithMiddleware); + try { + final CallContext context = makeContext((ServerCallStreamObserver<?>) responseObserver); + producer.listActions(context, listener); + } catch (Exception ex) { + listener.onError(ex); + } + // Do NOT call StreamPipe#onCompleted, as the FlightProducer implementation may be asynchronous + } + + private static class GetListener extends OutboundStreamListenerImpl implements ServerStreamListener { + private ServerCallStreamObserver<ArrowMessage> responseObserver; + private final Consumer<Throwable> errorHandler; + private Runnable onCancelHandler = null; + private Runnable onReadyHandler = null; + private boolean completed; + + public GetListener(ServerCallStreamObserver<ArrowMessage> responseObserver, Consumer<Throwable> errorHandler) { + super(null, responseObserver); + this.errorHandler = errorHandler; + this.completed = false; + this.responseObserver = responseObserver; + this.responseObserver.setOnCancelHandler(this::onCancel); + this.responseObserver.setOnReadyHandler(this::onReady); + this.responseObserver.disableAutoInboundFlowControl(); + } + + private void onCancel() { + logger.debug("Stream cancelled by client."); + if (onCancelHandler != null) { + onCancelHandler.run(); + } + } + + private void onReady() { + if (onReadyHandler != null) { + onReadyHandler.run(); + } + } + + @Override + public void setOnCancelHandler(Runnable handler) { + this.onCancelHandler = handler; + } + + @Override + public void setOnReadyHandler(Runnable handler) { + this.onReadyHandler = handler; + } + + @Override + public boolean isCancelled() { + return responseObserver.isCancelled(); + } + + @Override + protected void waitUntilStreamReady() { + // Don't do anything - service implementations are expected to manage backpressure themselves + } + + @Override + public void error(Throwable ex) { + if 
(!completed) { + completed = true; + super.error(ex); + } else { + errorHandler.accept(ex); + } + } + + @Override + public void completed() { + if (!completed) { + completed = true; + super.completed(); + } else { + errorHandler.accept(new IllegalStateException("Tried to complete already-completed call")); + } + } + } + + public StreamObserver<ArrowMessage> doPutCustom(final StreamObserver<Flight.PutResult> responseObserverSimple) { + ServerCallStreamObserver<Flight.PutResult> responseObserver = + (ServerCallStreamObserver<Flight.PutResult>) responseObserverSimple; + responseObserver.disableAutoInboundFlowControl(); + responseObserver.request(1); + + final StreamPipe<PutResult, Flight.PutResult> ackStream = StreamPipe + .wrap(responseObserver, PutResult::toProtocol, this::handleExceptionWithMiddleware); + final FlightStream fs = new FlightStream( + allocator, + PENDING_REQUESTS, + /* server-upload streams are not cancellable */null, + responseObserver::request); + // When the ackStream is completed, the FlightStream will be closed with it + ackStream.setAutoCloseable(fs); + final StreamObserver<ArrowMessage> observer = fs.asObserver(); + executors.submit(() -> { + try { + producer.acceptPut(makeContext(responseObserver), fs, ackStream).run(); + } catch (Exception ex) { + ackStream.onError(ex); + } finally { + // ARROW-6136: Close the stream if and only if acceptPut hasn't closed it itself + // We don't do this for other streams since the implementation may be asynchronous + ackStream.ensureCompleted(); + } + }); + + return observer; + } + + @Override + public void getFlightInfo(Flight.FlightDescriptor request, StreamObserver<Flight.FlightInfo> responseObserver) { + final FlightInfo info; + try { + info = producer + .getFlightInfo(makeContext((ServerCallStreamObserver<?>) responseObserver), new FlightDescriptor(request)); + } catch (Exception ex) { + // Don't capture exceptions from onNext or onCompleted with this block - because then we can't call onError + 
responseObserver.onError(StatusUtils.toGrpcException(ex)); + return; + } + responseObserver.onNext(info.toProtocol()); + responseObserver.onCompleted(); + } + + /** + * Broadcast the given exception to all registered middleware. + */ + private void handleExceptionWithMiddleware(Throwable t) { + final Map<Key<?>, FlightServerMiddleware> middleware = ServerInterceptorAdapter.SERVER_MIDDLEWARE_KEY.get(); + if (middleware == null || middleware.isEmpty()) { + logger.error("Uncaught exception in Flight method body", t); + return; + } + middleware.forEach((k, v) -> v.onCallErrored(t)); + } + + @Override + public void getSchema(Flight.FlightDescriptor request, StreamObserver<Flight.SchemaResult> responseObserver) { + try { + SchemaResult result = producer + .getSchema(makeContext((ServerCallStreamObserver<?>) responseObserver), + new FlightDescriptor(request)); + responseObserver.onNext(result.toProtocol()); + responseObserver.onCompleted(); + } catch (Exception ex) { + responseObserver.onError(StatusUtils.toGrpcException(ex)); + } + } + + /** Ensures that other resources are cleaned up when the service finishes its call. */ + private static class ExchangeListener extends GetListener { + + private AutoCloseable resource; + private boolean closed = false; + private Runnable onCancelHandler = null; + + public ExchangeListener(ServerCallStreamObserver<ArrowMessage> responseObserver, Consumer<Throwable> errorHandler) { + super(responseObserver, errorHandler); + this.resource = null; + super.setOnCancelHandler(() -> { + try { + if (onCancelHandler != null) { + onCancelHandler.run(); + } + } finally { + cleanup(); + } + }); + } + + private void cleanup() { + if (closed) { + // Prevent double-free. gRPC will call the OnCancelHandler even on a normal call end, which means that + // we'll double-free without this guard. 
+ return; + } + closed = true; + try { + AutoCloseables.close(resource); + } catch (Exception e) { + throw CallStatus.INTERNAL + .withCause(e) + .withDescription("Server internal error cleaning up resources") + .toRuntimeException(); + } + } + + @Override + public void error(Throwable ex) { + try { + this.cleanup(); + } finally { + super.error(ex); + } + } + + @Override + public void completed() { + try { + this.cleanup(); + } finally { + super.completed(); + } + } + + @Override + public void setOnCancelHandler(Runnable handler) { + onCancelHandler = handler; + } + } + + public StreamObserver<ArrowMessage> doExchangeCustom(StreamObserver<ArrowMessage> responseObserverSimple) { + final ServerCallStreamObserver<ArrowMessage> responseObserver = + (ServerCallStreamObserver<ArrowMessage>) responseObserverSimple; + final ExchangeListener listener = new ExchangeListener( + responseObserver, + this::handleExceptionWithMiddleware); + final FlightStream fs = new FlightStream( + allocator, + PENDING_REQUESTS, + /* server-upload streams are not cancellable */null, + responseObserver::request); + // When service completes the call, this cleans up the FlightStream + listener.resource = fs; + responseObserver.disableAutoInboundFlowControl(); + responseObserver.request(1); + final StreamObserver<ArrowMessage> observer = fs.asObserver(); + try { + executors.submit(() -> { + try { + producer.doExchange(makeContext(responseObserver), fs, listener); + } catch (Exception ex) { + listener.error(ex); + } + // We do not clean up or close anything here, to allow long-running asynchronous implementations. + // It is the service's responsibility to call completed() or error(), which will then clean up the FlightStream. + }); + } catch (Exception ex) { + listener.error(ex); + } + return observer; + } + + /** + * Call context for the service. 
+ */ + static class CallContext implements FlightProducer.CallContext { + + private final String peerIdentity; + private final BooleanSupplier isCancelled; + + CallContext(final String peerIdentity, BooleanSupplier isCancelled) { + this.peerIdentity = peerIdentity; + this.isCancelled = isCancelled; + } + + @Override + public String peerIdentity() { + return peerIdentity; + } + + @Override + public boolean isCancelled() { + return this.isCancelled.getAsBoolean(); + } + + @Override + public <T extends FlightServerMiddleware> T getMiddleware(Key<T> key) { + final Map<Key<?>, FlightServerMiddleware> middleware = ServerInterceptorAdapter.SERVER_MIDDLEWARE_KEY.get(); + if (middleware == null) { + return null; + } + final FlightServerMiddleware m = middleware.get(key); + if (m == null) { + return null; + } + @SuppressWarnings("unchecked") final T result = (T) m; + return result; + } + + @Override + public Map<Key<?>, FlightServerMiddleware> getMiddleware() { + final Map<Key<?>, FlightServerMiddleware> middleware = ServerInterceptorAdapter.SERVER_MIDDLEWARE_KEY.get(); + if (middleware == null) { + return Collections.emptyMap(); + } + // This is an unmodifiable map + return middleware; + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStatusCode.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStatusCode.java new file mode 100644 index 000000000..3d96877ba --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStatusCode.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +/** + * A status code describing the result of a Flight call. + */ +public enum FlightStatusCode { + /** + * The call completed successfully. Generally clients will not see this, but middleware may. + */ + OK, + /** + * An unknown error occurred. This may also be the result of an implementation error on the server-side; by default, + * unhandled server exceptions result in this code. + */ + UNKNOWN, + /** + * An internal/implementation error occurred. + */ + INTERNAL, + /** + * One or more of the given arguments was invalid. + */ + INVALID_ARGUMENT, + /** + * The operation timed out. + */ + TIMED_OUT, + /** + * The operation describes a resource that does not exist. + */ + NOT_FOUND, + /** + * The operation creates a resource that already exists. + */ + ALREADY_EXISTS, + /** + * The operation was cancelled. + */ + CANCELLED, + /** + * The client was not authenticated. + */ + UNAUTHENTICATED, + /** + * The client did not have permission to make the call. + */ + UNAUTHORIZED, + /** + * The requested operation is not implemented. + */ + UNIMPLEMENTED, + /** + * The server cannot currently handle the request. This should be used for retriable requests, i.e. the server + * should send this code only if it has not done any work. + */ + UNAVAILABLE, + ; + + /** + * Create a blank {@link CallStatus} with this code. 
+ */ + public CallStatus toStatus() { + return new CallStatus(this); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java new file mode 100644 index 000000000..03ce13c97 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java @@ -0,0 +1,505 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.LinkedBlockingQueue; + +import org.apache.arrow.flight.ArrowMessage.HeaderType; +import org.apache.arrow.flight.grpc.StatusUtils; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.util.VisibleForTesting; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.MetadataVersion; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.DictionaryUtility; +import org.apache.arrow.vector.validate.MetadataV4UnionChecker; + +import com.google.common.util.concurrent.SettableFuture; + +import io.grpc.stub.StreamObserver; + +/** + * An adaptor between protobuf streams and flight data streams. 
 */
public class FlightStream implements AutoCloseable {
  // Use AutoCloseable sentinel objects to simplify logic in #close
  // DONE is enqueued by Observer#onCompleted; next() returns false when it dequeues it.
  private final AutoCloseable DONE = () -> {
  };
  // DONE_EX is enqueued after the failure has been stored in `ex`; next() rethrows `ex` when it dequeues it.
  private final AutoCloseable DONE_EX = () -> {
  };

  private final BufferAllocator allocator;
  // Null for server-side (upload) streams; see #cancel and #close.
  private final Cancellable cancellable;
  private final LinkedBlockingQueue<AutoCloseable> queue = new LinkedBlockingQueue<>();
  private final SettableFuture<VectorSchemaRoot> root = SettableFuture.create();
  private final SettableFuture<FlightDescriptor> descriptor = SettableFuture.create();
  private final int pendingTarget;
  private final Requestor requestor;
  // The completion flags.
  // This flag is only updated as the user iterates through the data, i.e. it tracks whether the user has read all the
  // data and closed the stream
  final CompletableFuture<Void> completed;
  // This flag is immediately updated when gRPC signals that the server has ended the call. This is used to make sure
  // we don't block forever trying to write to a server that has rejected a call.
  final CompletableFuture<Void> cancelled;

  // Decremented by each next() call and reset to pendingTarget by requestOutstanding().
  private volatile int pending = 1;
  // Same object that `root` is fulfilled with; kept separately so close()/next() can reach it without blocking.
  private volatile VectorSchemaRoot fulfilledRoot;
  // Set to null once the application calls takeDictionaryOwnership().
  private DictionaryProvider.MapDictionaryProvider dictionaries;
  private volatile VectorLoader loader;
  // The failure to rethrow from next() when DONE_EX is dequeued.
  private volatile Throwable ex;
  private volatile ArrowBuf applicationMetadata = null;
  @VisibleForTesting
  volatile MetadataVersion metadataVersion = null;

  /**
   * Constructs a new instance.
   *
   * @param allocator The allocator to use for creating/reallocating buffers for Vectors.
   * @param pendingTarget Target number of messages to receive.
   * @param cancellable Used to cancel mid-stream requests. May be null, in which case the stream cannot be cancelled.
   * @param requestor A callback used to request more messages from the stream producer.
   */
  public FlightStream(BufferAllocator allocator, int pendingTarget, Cancellable cancellable, Requestor requestor) {
    Objects.requireNonNull(allocator);
    Objects.requireNonNull(requestor);
    this.allocator = allocator;
    this.pendingTarget = pendingTarget;
    this.cancellable = cancellable;
    this.requestor = requestor;
    this.dictionaries = new DictionaryProvider.MapDictionaryProvider();
    this.completed = new CompletableFuture<>();
    this.cancelled = new CompletableFuture<>();
  }

  /**
   * Get the schema for this stream. Blocks until the schema is available.
   */
  public Schema getSchema() {
    return getRoot().getSchema();
  }

  /**
   * Get the provider for dictionaries in this stream.
   *
   * <p>Does NOT retain a reference to the underlying dictionaries. Dictionaries may be updated as the stream is read.
   * This method is intended for stream processing, where the application code will not retain references to values
   * after the stream is closed.
   *
   * @throws IllegalStateException if {@link #takeDictionaryOwnership()} was called
   * @see #takeDictionaryOwnership()
   */
  public DictionaryProvider getDictionaryProvider() {
    if (dictionaries == null) {
      throw new IllegalStateException("Dictionary ownership was claimed by the application.");
    }
    return dictionaries;
  }

  /**
   * Get an owned reference to the dictionaries in this stream. Should be called after finishing reading the stream,
   * but before closing.
   *
   * <p>If called, the client is responsible for closing the dictionaries in this provider. Can only be called once.
   *
   * @return The dictionary provider for the stream.
   * @throws IllegalStateException if called more than once.
   */
  public DictionaryProvider takeDictionaryOwnership() {
    if (dictionaries == null) {
      throw new IllegalStateException("Dictionary ownership was claimed by the application.");
    }
    // Swap out the provider so it is not closed
    final DictionaryProvider provider = dictionaries;
    dictionaries = null;
    return provider;
  }

  /**
   * Get the descriptor for this stream. Only applicable on the server side of a DoPut operation. Will block until the
   * client sends the descriptor.
   *
   * @throws FlightRuntimeException (INTERNAL) if the wait is interrupted or the descriptor future fails.
   */
  public FlightDescriptor getDescriptor() {
    // This blocks until the first message from the client is received.
    try {
      return descriptor.get();
    } catch (InterruptedException e) {
      // Restore the interrupt flag before converting to an unchecked exception.
      Thread.currentThread().interrupt();
      throw CallStatus.INTERNAL.withCause(e).withDescription("Interrupted").toRuntimeException();
    } catch (ExecutionException e) {
      throw CallStatus.INTERNAL.withCause(e).withDescription("Error getting descriptor").toRuntimeException();
    }
  }

  /**
   * Closes the stream (freeing any existing resources).
   *
   * <p>If the stream isn't complete and is cancellable, this method will cancel and drain the stream first.
   */
  public void close() throws Exception {
    final List<AutoCloseable> closeables = new ArrayList<>();
    Throwable suppressor = null;
    if (cancellable != null) {
      // Client-side stream. Cancel the call, to help ensure gRPC doesn't deliver a message after close() ends.
      // On the server side, we can't rely on draining the stream, because this gRPC bug means the completion callback
      // may never run https://github.com/grpc/grpc-java/issues/5882
      try {
        synchronized (cancellable) {
          if (!cancelled.isDone()) {
            // Only cancel if the call is not done on the gRPC side
            cancellable.cancel("Stream closed before end", /* no exception to report */null);
          }
        }
        // Drain the stream without the lock (as next() implicitly needs the lock)
        while (next()) { }
      } catch (FlightRuntimeException e) {
        // Remember the failure; it is passed to AutoCloseables.close together with the resources below.
        suppressor = e;
      }
    }
    // Perform these operations under a lock. This way the observer can't enqueue new messages while we're in the
    // middle of cleanup. This should only be a concern for server-side streams since client-side streams are drained
    // by the loop above.
    synchronized (completed) {
      try {
        if (fulfilledRoot != null) {
          closeables.add(fulfilledRoot);
        }
        closeables.add(applicationMetadata);
        closeables.addAll(queue);
        if (dictionaries != null) {
          // Dictionaries are only closed here if the application has not taken ownership of them.
          dictionaries.getDictionaryIds().forEach(id -> closeables.add(dictionaries.lookup(id).getVector()));
        }
        if (suppressor != null) {
          AutoCloseables.close(suppressor, closeables);
        } else {
          AutoCloseables.close(closeables);
        }
      } finally {
        // The value of this CompletableFuture is meaningless, only whether it's completed (or has an exception)
        // No-op if already complete
        completed.complete(null);
      }
    }
  }

  /**
   * Blocking request to load the next message from the stream.
   *
   * <p>A record batch is loaded into the root; a metadata-only message updates the latest metadata and clears the
   * root; a dictionary batch is loaded into its dictionary and the method then continues with the following message.
   *
   * @return Whether or not more data was found.
   */
  public boolean next() {
    try {
      if (completed.isDone() && queue.isEmpty()) {
        return false;
      }

      pending--;
      requestOutstanding();

      Object data = queue.take();
      if (DONE == data) {
        // Re-enqueue the sentinel so that later calls also observe the end of the stream.
        queue.put(DONE);
        // Other code ignores the value of this CompletableFuture, only whether it's completed (or has an exception)
        completed.complete(null);
        return false;
      } else if (DONE_EX == data) {
        // Re-enqueue the sentinel so that later calls also observe the failure.
        queue.put(DONE_EX);
        if (ex instanceof Exception) {
          throw (Exception) ex;
        } else {
          throw new Exception(ex);
        }
      } else {
        try (ArrowMessage msg = ((ArrowMessage) data)) {
          if (msg.getMessageType() == HeaderType.NONE) {
            updateMetadata(msg);
            // We received a message without data, so erase any leftover data
            if (fulfilledRoot != null) {
              fulfilledRoot.clear();
            }
          } else if (msg.getMessageType() == HeaderType.RECORD_BATCH) {
            checkMetadataVersion(msg);
            // Ensure we have the root
            root.get().clear();
            try (ArrowRecordBatch arb = msg.asRecordBatch()) {
              loader.load(arb);
            }
            updateMetadata(msg);
          } else if (msg.getMessageType() == HeaderType.DICTIONARY_BATCH) {
            checkMetadataVersion(msg);
            // Ensure we have the root
            root.get().clear();
            try (ArrowDictionaryBatch arb = msg.asDictionaryBatch()) {
              final long id = arb.getDictionaryId();
              if (dictionaries == null) {
                throw new IllegalStateException("Dictionary ownership was claimed by the application.");
              }
              final Dictionary dictionary = dictionaries.lookup(id);
              if (dictionary == null) {
                throw new IllegalArgumentException("Dictionary not defined in schema: ID " + id);
              }

              final FieldVector vector = dictionary.getVector();
              final VectorSchemaRoot dictionaryRoot = new VectorSchemaRoot(Collections.singletonList(vector.getField()),
                  Collections.singletonList(vector), 0);
              final VectorLoader dictionaryLoader = new VectorLoader(dictionaryRoot);
              dictionaryLoader.load(arb.getDictionary());
            }
            // A dictionary batch is not surfaced to the caller; move on to the next message.
            return next();
          } else {
            throw new UnsupportedOperationException("Message type is unsupported: " + msg.getMessageType());
          }
          return true;
        }
      }
    } catch (RuntimeException e) {
      throw e;
    } catch (ExecutionException e) {
      throw StatusUtils.fromThrowable(e.getCause());
    } catch (Exception e) {
      throw new RuntimeException(e);
    }
  }

  /** Update our metadata reference with a new one from this message. */
  private void updateMetadata(ArrowMessage msg) {
    if (this.applicationMetadata != null) {
      this.applicationMetadata.close();
    }
    this.applicationMetadata = msg.getApplicationMetadata();
    if (this.applicationMetadata != null) {
      // Take our own reference, since the message is closed by next() once it has been processed.
      this.applicationMetadata.getReferenceManager().retain();
    }
  }

  /** Ensure the Arrow metadata version doesn't change mid-stream. */
  private void checkMetadataVersion(ArrowMessage msg) {
    if (msg.asSchemaMessage() == null) {
      return;
    }
    MetadataVersion receivedVersion = MetadataVersion.fromFlatbufID(msg.asSchemaMessage().getMessage().version());
    if (this.metadataVersion != receivedVersion) {
      throw new IllegalStateException("Metadata version mismatch: stream started as " +
          this.metadataVersion + " but got message with version " + receivedVersion);
    }
  }

  /**
   * Get the current vector data from the stream. Blocks until the root is available.
   *
   * <p>The data in the root may change at any time. Clients should NOT modify the root, but instead unload the data
   * into their own root.
   *
   * @throws FlightRuntimeException if there was an error reading the schema from the stream.
   */
  public VectorSchemaRoot getRoot() {
    try {
      return root.get();
    } catch (InterruptedException e) {
      // NOTE(review): unlike getDescriptor(), the thread's interrupt flag is not restored here. Restoring it
      // would make the queue.take() in a following close() -> next() drain on this thread fail immediately,
      // so confirm the intended behaviour before changing this.
      throw CallStatus.INTERNAL.withCause(e).toRuntimeException();
    } catch (ExecutionException e) {
      throw StatusUtils.fromThrowable(e.getCause());
    }
  }

  /**
   * Check if there is a root (i.e. whether the other end has started sending data).
   *
   * <p>Updated by calls to {@link #next()}.
   *
   * @return true if and only if the other end has started sending data.
   */
  public boolean hasRoot() {
    return root.isDone();
  }

  /**
   * Get the most recent metadata sent from the server. This may be cleared by calls to {@link #next()} if the server
   * sends a message without metadata. This does NOT take ownership of the buffer - call retain() to create a reference
   * if you need the buffer after a call to {@link #next()}.
   *
   * @return the application metadata. May be null.
   */
  public ArrowBuf getLatestMetadata() {
    return applicationMetadata;
  }

  /** Request enough messages from the requestor to bring the pending count back up to the target. */
  private synchronized void requestOutstanding() {
    if (pending < pendingTarget) {
      requestor.request(pendingTarget - pending);
      pending = pendingTarget;
    }
  }

  /** The gRPC-facing side of the stream: turns incoming messages into queue entries consumed by {@link #next()}. */
  private class Observer implements StreamObserver<ArrowMessage> {

    Observer() {
      super();
    }

    /** Helper to add an item to the queue under the appropriate lock. */
    private void enqueue(AutoCloseable message) {
      synchronized (completed) {
        if (completed.isDone()) {
          // The stream is already closed (RPC ended), discard the message
          AutoCloseables.closeNoChecked(message);
        } else {
          queue.add(message);
        }
      }
    }

    @Override
    public void onNext(ArrowMessage msg) {
      // Operations here have to be under a lock so that we don't add a message to the queue while in the middle of
      // close().
      requestOutstanding();
      switch (msg.getMessageType()) {
        case NONE: {
          // No IPC message - pure metadata or descriptor
          if (msg.getDescriptor() != null) {
            descriptor.set(new FlightDescriptor(msg.getDescriptor()));
          }
          if (msg.getApplicationMetadata() != null) {
            enqueue(msg);
          }
          break;
        }
        case SCHEMA: {
          Schema schema = msg.asSchema();

          // if there is app metadata in the schema message, make sure
          // that we don't leak it.
          ArrowBuf meta = msg.getApplicationMetadata();
          if (meta != null) {
            meta.close();
          }

          // Convert the wire schema to its in-memory form, collecting the dictionaries it declares.
          final List<Field> fields = new ArrayList<>();
          final Map<Long, Dictionary> dictionaryMap = new HashMap<>();
          for (final Field originalField : schema.getFields()) {
            final Field updatedField = DictionaryUtility.toMemoryFormat(originalField, allocator, dictionaryMap);
            fields.add(updatedField);
          }
          for (final Map.Entry<Long, Dictionary> entry : dictionaryMap.entrySet()) {
            dictionaries.put(entry.getValue());
          }
          schema = new Schema(fields, schema.getCustomMetadata());
          metadataVersion = MetadataVersion.fromFlatbufID(msg.asSchemaMessage().getMessage().version());
          try {
            MetadataV4UnionChecker.checkRead(schema, metadataVersion);
          } catch (IOException e) {
            ex = e;
            enqueue(DONE_EX);
            break;
          }

          // Only create the root if the stream has not been closed in the meantime.
          synchronized (completed) {
            if (!completed.isDone()) {
              fulfilledRoot = VectorSchemaRoot.create(schema, allocator);
              loader = new VectorLoader(fulfilledRoot);
              if (msg.getDescriptor() != null) {
                descriptor.set(new FlightDescriptor(msg.getDescriptor()));
              }
              root.set(fulfilledRoot);
            }
          }
          break;
        }
        case RECORD_BATCH:
        case DICTIONARY_BATCH:
          enqueue(msg);
          break;
        case TENSOR:
        default:
          ex = new UnsupportedOperationException("Unable to handle message of type: " + msg.getMessageType());
          enqueue(DONE_EX);
      }
    }

    @Override
    public void onError(Throwable t) {
      ex = StatusUtils.fromThrowable(t);
      queue.add(DONE_EX);
      cancelled.complete(null);
      // Also fail the root future so that callers blocked in getRoot()/getSchema() are released.
      root.setException(ex);
    }

    @Override
    public void onCompleted() {
      // Depends on gRPC calling onNext and onCompleted non-concurrently
      cancelled.complete(null);
      queue.add(DONE);
    }
  }

  /**
   * Cancels sending the stream to a client.
   *
   * <p>Callers should drain the stream (with {@link #next()}) to ensure all messages sent before cancellation are
   * received and to wait for the underlying transport to acknowledge cancellation.
   *
   * @throws UnsupportedOperationException if this stream was created without a {@link Cancellable}.
   */
  public void cancel(String message, Throwable exception) {
    if (cancellable == null) {
      throw new UnsupportedOperationException("Streams cannot be cancelled that are produced by client. " +
          "Instead, server should reject incoming messages.");
    }
    cancellable.cancel(message, exception);
    // Do not mark the stream as completed, as gRPC may still be delivering messages.
  }

  /** Create the observer that feeds incoming messages into this stream. */
  StreamObserver<ArrowMessage> asObserver() {
    return new Observer();
  }

  /**
   * Provides a callback to cancel a process that is in progress.
   */
  @FunctionalInterface
  public interface Cancellable {
    void cancel(String message, Throwable exception);
  }

  /**
   * Provides an interface to request more items from a stream producer.
   */
  @FunctionalInterface
  public interface Requestor {
    /**
     * Requests <code>count</code> more messages from the instance of this object.
     */
    void request(int count);
  }
}
diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/HeaderCallOption.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/HeaderCallOption.java new file mode 100644 index 000000000..e2fad1a40 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/HeaderCallOption.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import io.grpc.Metadata; +import io.grpc.stub.AbstractStub; +import io.grpc.stub.MetadataUtils; + +/** + * Method option for supplying headers to method calls. + */ +public class HeaderCallOption implements CallOptions.GrpcCallOption { + private final Metadata propertiesMetadata = new Metadata(); + + /** + * Header property constructor. + * + * @param headers the headers that should be sent across. If a header is a string, it should only be valid ASCII + * characters. Binary headers should end in "-bin". + */ + public HeaderCallOption(CallHeaders headers) { + for (String key : headers.keys()) { + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + final Metadata.Key<byte[]> metaKey = Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER); + headers.getAllByte(key).forEach(v -> propertiesMetadata.put(metaKey, v)); + } else { + final Metadata.Key<String> metaKey = Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER); + headers.getAll(key).forEach(v -> propertiesMetadata.put(metaKey, v)); + } + } + } + + @Override + public <T extends AbstractStub<T>> T wrapStub(T stub) { + return MetadataUtils.attachHeaders(stub, propertiesMetadata); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java new file mode 100644 index 000000000..1fbec7b5a --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Location.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.lang.reflect.InvocationTargetException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Objects; + +import org.apache.arrow.flight.impl.Flight; + +/** A URI where a Flight stream is available. */ +public class Location { + private final URI uri; + + /** + * Constructs a new instance. + * + * @param uri the URI of the Flight service + * @throws IllegalArgumentException if the URI scheme is unsupported + */ + public Location(String uri) throws URISyntaxException { + this(new URI(uri)); + } + + /** + * Construct a new instance from an existing URI. + * + * @param uri the URI of the Flight service + */ + public Location(URI uri) { + super(); + Objects.requireNonNull(uri); + this.uri = uri; + } + + public URI getUri() { + return uri; + } + + /** + * Helper method to turn this Location into a SocketAddress. + * + * @return null if could not be converted + */ + SocketAddress toSocketAddress() { + switch (uri.getScheme()) { + case LocationSchemes.GRPC: + case LocationSchemes.GRPC_TLS: + case LocationSchemes.GRPC_INSECURE: { + return new InetSocketAddress(uri.getHost(), uri.getPort()); + } + + case LocationSchemes.GRPC_DOMAIN_SOCKET: { + try { + // This dependency is not available on non-Unix platforms. 
+ return (SocketAddress) Class.forName("io.netty.channel.unix.DomainSocketAddress") + .getConstructor(String.class) + .newInstance(uri.getPath()); + } catch (InstantiationException | ClassNotFoundException | InvocationTargetException | + NoSuchMethodException | IllegalAccessException e) { + return null; + } + } + + default: { + return null; + } + } + } + + /** + * Convert this Location into its protocol-level representation. + */ + Flight.Location toProtocol() { + return Flight.Location.newBuilder().setUri(uri.toString()).build(); + } + + /** + * Construct a URI for a Flight+gRPC server without transport security. + * + * @throws IllegalArgumentException if the constructed URI is invalid. + */ + public static Location forGrpcInsecure(String host, int port) { + try { + return new Location(new URI(LocationSchemes.GRPC_INSECURE, null, host, port, null, null, null)); + } catch (URISyntaxException e) { + throw new IllegalArgumentException(e); + } + } + + /** + * Construct a URI for a Flight+gRPC server with transport security. + * + * @throws IllegalArgumentException if the constructed URI is invalid. + */ + public static Location forGrpcTls(String host, int port) { + try { + return new Location(new URI(LocationSchemes.GRPC_TLS, null, host, port, null, null, null)); + } catch (URISyntaxException e) { + throw new IllegalArgumentException(e); + } + } + + /** + * Construct a URI for a Flight+gRPC server over a Unix domain socket. + * + * @throws IllegalArgumentException if the constructed URI is invalid. 
+ */ + public static Location forGrpcDomainSocket(String path) { + try { + return new Location(new URI(LocationSchemes.GRPC_DOMAIN_SOCKET, null, path, null)); + } catch (URISyntaxException e) { + throw new IllegalArgumentException(e); + } + } + + @Override + public String toString() { + return "Location{" + + "uri=" + uri + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Location location = (Location) o; + return uri.equals(location.uri); + } + + @Override + public int hashCode() { + return Objects.hash(uri); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/LocationSchemes.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/LocationSchemes.java new file mode 100644 index 000000000..872e5b1c2 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/LocationSchemes.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +/** + * Constants representing well-known URI schemes for Flight services. 
/**
 * Constants representing well-known URI schemes for Flight services.
 *
 * <p>This is a constants holder only; it cannot be instantiated.
 */
public final class LocationSchemes {
  /** Plain gRPC scheme. */
  public static final String GRPC = "grpc";
  /** gRPC without transport security; the scheme produced by {@code Location.forGrpcInsecure}. */
  public static final String GRPC_INSECURE = "grpc+tcp";
  /** gRPC over a Unix domain socket; the scheme produced by {@code Location.forGrpcDomainSocket}. */
  public static final String GRPC_DOMAIN_SOCKET = "grpc+unix";
  /** gRPC with transport security; the scheme produced by {@code Location.forGrpcTls}. */
  public static final String GRPC_TLS = "grpc+tls";

  // Always throws so that even reflective instantiation fails loudly.
  private LocationSchemes() {
    throw new AssertionError("Do not instantiate this class.");
  }
}
+ */ +public class NoOpFlightProducer implements FlightProducer { + + @Override + public void getStream(CallContext context, Ticket ticket, + ServerStreamListener listener) { + listener.error(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException()); + } + + @Override + public void listFlights(CallContext context, Criteria criteria, + StreamListener<FlightInfo> listener) { + listener.onError(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException()); + } + + @Override + public FlightInfo getFlightInfo(CallContext context, + FlightDescriptor descriptor) { + throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); + } + + @Override + public Runnable acceptPut(CallContext context, + FlightStream flightStream, StreamListener<PutResult> ackStream) { + throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException(); + } + + @Override + public void doAction(CallContext context, Action action, + StreamListener<Result> listener) { + listener.onError(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException()); + } + + @Override + public void listActions(CallContext context, + StreamListener<ActionType> listener) { + listener.onError(CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException()); + } + +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/NoOpStreamListener.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/NoOpStreamListener.java new file mode 100644 index 000000000..e06af1a10 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/NoOpStreamListener.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import org.apache.arrow.flight.FlightProducer.StreamListener; + +/** + * A {@link StreamListener} that does nothing for all callbacks. + * @param <T> The type of the callback object. + */ +public class NoOpStreamListener<T> implements StreamListener<T> { + private static NoOpStreamListener INSTANCE = new NoOpStreamListener(); + + /** Ignores the value received. */ + @Override + public void onNext(T val) { + } + + /** Ignores the error received. */ + @Override + public void onError(Throwable t) { + } + + /** Ignores the stream completion event. */ + @Override + public void onCompleted() { + } + + @SuppressWarnings("unchecked") + public static <T> StreamListener<T> getInstance() { + // Safe because we never use T + return (StreamListener<T>) INSTANCE; + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListener.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListener.java new file mode 100644 index 000000000..38a44d0e5 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/OutboundStreamListener.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
/**
 * An interface for writing data to a peer, client or server.
 */
public interface OutboundStreamListener {

  /**
   * A hint indicating whether the client is ready to receive data without excessive buffering.
   *
   * <p>Writers should poll this flag before sending data to respect backpressure from the client and
   * avoid sending data faster than the client can handle. Ignoring this flag may mean that the server
   * will start consuming excessive amounts of memory, as it may buffer messages in memory.
   *
   * @return true if data can be sent without excessive buffering.
   */
  boolean isReady();

  /**
   * Set a callback for when the listener is ready for new calls to putNext(), i.e. {@link #isReady()}
   * has become true.
   *
   * <p>Note that this callback may only be called some time after {@link #isReady()} becomes true, and may never
   * be called if all executor threads on the server are busy, or the RPC method body is implemented in a blocking
   * fashion. Note that isReady() must still be checked after the callback is run as it may have been run
   * spuriously.
   *
   * @param handler the callback to run.
   * @throws UnsupportedOperationException if the implementation does not override this default.
   */
  default void setOnReadyHandler(Runnable handler) {
    throw new UnsupportedOperationException("Not yet implemented.");
  }

  /**
   * Start sending data, using the schema of the given {@link VectorSchemaRoot}.
   *
   * <p>This method must be called before all others, except {@link #putMetadata(ArrowBuf)}.
   *
   * <p>Equivalent to {@link #start(VectorSchemaRoot, DictionaryProvider, IpcOption)} with no dictionaries
   * and {@link IpcOption#DEFAULT}.
   *
   * @param root the root whose schema and contents will be sent.
   */
  default void start(VectorSchemaRoot root) {
    start(root, null, IpcOption.DEFAULT);
  }

  /**
   * Start sending data, using the schema of the given {@link VectorSchemaRoot}.
   *
   * <p>This method must be called before all others, except {@link #putMetadata(ArrowBuf)}.
   *
   * <p>Equivalent to {@link #start(VectorSchemaRoot, DictionaryProvider, IpcOption)} with
   * {@link IpcOption#DEFAULT}.
   *
   * @param root the root whose schema and contents will be sent.
   * @param dictionaries the dictionaries to use.
   */
  default void start(VectorSchemaRoot root, DictionaryProvider dictionaries) {
    start(root, dictionaries, IpcOption.DEFAULT);
  }

  /**
   * Start sending data, using the schema of the given {@link VectorSchemaRoot}.
   *
   * <p>This method must be called before all others, except {@link #putMetadata(ArrowBuf)}.
   *
   * @param root the root whose schema and contents will be sent.
   * @param dictionaries the dictionaries to use.
   * @param option the IPC options to use.
   */
  void start(VectorSchemaRoot root, DictionaryProvider dictionaries, IpcOption option);

  /**
   * Send the current contents of the associated {@link VectorSchemaRoot}.
   *
   * <p>This will not necessarily block until the message is actually sent; it may buffer messages
   * in memory. Use {@link #isReady()} to check if there is backpressure and avoid excessive buffering.
   */
  void putNext();

  /**
   * Send the current contents of the associated {@link VectorSchemaRoot} alongside application-defined metadata.
   * @param metadata The metadata to send. Ownership of the buffer is transferred to the Flight implementation.
   */
  void putNext(ArrowBuf metadata);

  /**
   * Send a pure metadata message without any associated data.
   *
   * <p>This may be called without starting the stream.
   *
   * @param metadata The metadata to send.
   */
  void putMetadata(ArrowBuf metadata);

  /**
   * Indicate an error to the client. Terminates the stream; do not call {@link #completed()} afterwards.
   *
   * @param ex the error to report.
   */
  void error(Throwable ex);

  /**
   * Indicate that transmission is finished.
   */
  void completed();

  /**
   * Toggle whether to use the zero-copy write optimization.
   *
   * <p>By default or when disabled, Arrow may copy data into a buffer for the underlying implementation to
   * send. When enabled, Arrow will instead try to directly enqueue the Arrow buffer for sending. Not all
   * implementations support this optimization, so even if enabled, you may not see a difference.
   *
   * <p>In this mode, buffers must not be reused after they are written with {@link #putNext()}. For example,
   * you would have to call {@link VectorSchemaRoot#allocateNew()} after every call to {@link #putNext()}.
   * Hence, this is not enabled by default.
   *
   * <p>The default value can be toggled globally by setting the JVM property arrow.flight.enable_zero_copy_write
   * or the environment variable ARROW_FLIGHT_ENABLE_ZERO_COPY_WRITE.
   *
   * @param enabled whether to try the zero-copy write path. The default implementation ignores this.
   */
  default void setUseZeroCopy(boolean enabled) {}
}
/**
 * A base class for writing Arrow data to a Flight stream.
 *
 * <p>Every message is forwarded to a gRPC {@link CallStreamObserver}. Subclasses supply
 * {@link #waitUntilStreamReady()}, which is invoked before each data or metadata message is written.
 */
abstract class OutboundStreamListenerImpl implements OutboundStreamListener {
  private final FlightDescriptor descriptor; // nullable
  protected final CallStreamObserver<ArrowMessage> responseObserver;
  protected volatile VectorUnloader unloader; // null until stream started
  protected IpcOption option; // null until stream started
  protected boolean tryZeroCopy = ArrowMessage.ENABLE_ZERO_COPY_WRITE;

  /**
   * Create a listener that writes to the given observer.
   *
   * @param descriptor the descriptor passed along when the schema messages are generated; may be null.
   * @param responseObserver the gRPC observer that receives every message; must not be null.
   */
  OutboundStreamListenerImpl(FlightDescriptor descriptor, CallStreamObserver<ArrowMessage> responseObserver) {
    Preconditions.checkNotNull(responseObserver, "responseObserver must be provided");
    this.descriptor = descriptor;
    this.responseObserver = responseObserver;
    this.unloader = null;
  }

  /** Delegates to the underlying gRPC observer. */
  @Override
  public boolean isReady() {
    return responseObserver.isReady();
  }

  /** Registers the handler on the underlying gRPC observer. */
  @Override
  public void setOnReadyHandler(Runnable handler) {
    responseObserver.setOnReadyHandler(handler);
  }

  /**
   * Sends the schema messages for the root and prepares the unloader used by {@link #putNext()}.
   *
   * <p>The unloader is only assigned after the schema messages were generated successfully, so a failed
   * start leaves the stream in the "not started" state as far as {@link #putNext(ArrowBuf)} is concerned.
   */
  @Override
  public void start(VectorSchemaRoot root, DictionaryProvider dictionaries, IpcOption option) {
    this.option = option;
    try {
      DictionaryUtils.generateSchemaMessages(root.getSchema(), descriptor, dictionaries, option,
          responseObserver::onNext);
    } catch (RuntimeException e) {
      // Propagate runtime exceptions, like those raised when trying to write unions with V4 metadata
      throw e;
    } catch (Exception e) {
      // Only happens if closing buffers somehow fails - indicates application is an unknown state so propagate
      // the exception
      throw new RuntimeException("Could not generate and send all schema messages", e);
    }
    // We include the null count and align buffers to be compatible with Flight/C++
    unloader = new VectorUnloader(root, /* includeNullCount */ true, /* alignBuffers */ true);
  }

  /** Sends the current record batch without application metadata. */
  @Override
  public void putNext() {
    putNext(null);
  }

  /**
   * Busy-wait until the stream is ready.
   *
   * <p>This is overridable as client/server have different behavior.
   */
  protected abstract void waitUntilStreamReady();

  /**
   * Sends the current record batch, optionally with application metadata.
   *
   * @param metadata the metadata to send alongside the batch; {@link #putNext()} passes null.
   */
  @Override
  public void putNext(ArrowBuf metadata) {
    if (unloader == null) {
      throw CallStatus.INTERNAL.withDescription("Stream was not started, call start()").toRuntimeException();
    }

    waitUntilStreamReady();
    // close is a no-op if the message has been written to gRPC, otherwise frees the associated buffers
    // in some code paths (e.g. if the call is cancelled), gRPC does not write the message, so we need to clean up
    // ourselves. Normally, writing the ArrowMessage will transfer ownership of the data to gRPC/Netty.
    try (final ArrowMessage message = new ArrowMessage(unloader.getRecordBatch(), metadata, tryZeroCopy, option)) {
      responseObserver.onNext(message);
    } catch (Exception e) {
      // This exception comes from ArrowMessage#close, not responseObserver#onNext.
      // Generally this should not happen - ArrowMessage's implementation only closes non-throwing things.
      // The user can't reasonably do anything about this, but if something does throw, we shouldn't let
      // execution continue since other state (e.g. allocators) may be in an odd state.
      throw new RuntimeException("Could not free ArrowMessage", e);
    }
  }

  /**
   * Sends a metadata-only message. Unlike {@link #putNext(ArrowBuf)}, this does not check that the
   * stream was started.
   */
  @Override
  public void putMetadata(ArrowBuf metadata) {
    waitUntilStreamReady();
    try (final ArrowMessage message = new ArrowMessage(metadata)) {
      responseObserver.onNext(message);
    } catch (Exception e) {
      // NOTE(review): unlike putNext, failures here are converted through StatusUtils rather than
      // wrapped in a plain RuntimeException - confirm the difference is intentional.
      throw StatusUtils.fromThrowable(e);
    }
  }

  /** Converts the error with {@link StatusUtils#toGrpcException(Throwable)} and passes it to the observer. */
  @Override
  public void error(Throwable ex) {
    responseObserver.onError(StatusUtils.toGrpcException(ex));
  }

  /** Signals completion to the underlying gRPC observer. */
  @Override
  public void completed() {
    responseObserver.onCompleted();
  }

  /** Records the flag that is passed to each {@link ArrowMessage} built by {@link #putNext(ArrowBuf)}. */
  @Override
  public void setUseZeroCopy(boolean enabled) {
    tryZeroCopy = enabled;
  }
}
+ */ + +package org.apache.arrow.flight; + +import org.apache.arrow.flight.impl.Flight; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.ReferenceManager; + +import com.google.protobuf.ByteString; + +/** + * A message from the server during a DoPut operation. + * + * <p>This object owns an {@link ArrowBuf} and should be closed when you are done with it. + */ +public class PutResult implements AutoCloseable { + + private ArrowBuf applicationMetadata; + + private PutResult(ArrowBuf metadata) { + applicationMetadata = metadata; + } + + /** + * Create a PutResult with application-specific metadata. + * + * <p>This method assumes ownership of the {@link ArrowBuf}. + */ + public static PutResult metadata(ArrowBuf metadata) { + if (metadata == null) { + return empty(); + } + return new PutResult(metadata); + } + + /** Create an empty PutResult. */ + public static PutResult empty() { + return new PutResult(null); + } + + /** + * Get the metadata in this message. May be null. + * + * <p>Ownership of the {@link ArrowBuf} is retained by this object. Call {@link ReferenceManager#retain()} to preserve + * a reference. + */ + public ArrowBuf getApplicationMetadata() { + return applicationMetadata; + } + + Flight.PutResult toProtocol() { + if (applicationMetadata == null) { + return Flight.PutResult.getDefaultInstance(); + } + return Flight.PutResult.newBuilder().setAppMetadata(ByteString.copyFrom(applicationMetadata.nioBuffer())).build(); + } + + /** + * Construct a PutResult from a Protobuf message. + * + * @param allocator The allocator to use for allocating application metadata memory. The result object owns the + * allocated buffer, if any. + * @param message The gRPC/Protobuf message. 
+ */ + static PutResult fromProtocol(BufferAllocator allocator, Flight.PutResult message) { + final ArrowBuf buf = allocator.buffer(message.getAppMetadata().size()); + message.getAppMetadata().asReadOnlyByteBufferList().forEach(bb -> { + buf.setBytes(buf.writerIndex(), bb); + buf.writerIndex(buf.writerIndex() + bb.limit()); + }); + return new PutResult(buf); + } + + @Override + public void close() { + if (applicationMetadata != null) { + applicationMetadata.close(); + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/RequestContext.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/RequestContext.java new file mode 100644 index 000000000..5117d05c2 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/RequestContext.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.Set; + +/** + * Tracks variables about the current request. + */ +public interface RequestContext { + /** + * Register a variable and a value. + * @param key the variable name. + * @param value the value. 
/**
 * Tracks variables about the current request.
 *
 * <p>Variables are string values registered under string names.
 */
public interface RequestContext {
  /**
   * Register a variable and a value.
   * @param key the variable name.
   * @param value the value.
   */
  void put(String key, String value);

  /**
   * Retrieve a registered variable.
   * @param key the variable name.
   * @return the value, or null if not found.
   */
  String get(String key);

  /**
   * Retrieves the keys that have been registered to this context.
   * @return the keys used in this context.
   */
  Set<String> keySet();

  /**
   * Deletes a registered variable.
   * @param key the variable name.
   * @return the value associated with the deleted variable, or null if the key doesn't exist.
   */
  String remove(String key);
}
+ */ +public class Result { + + private final byte[] body; + + public Result(byte[] body) { + this.body = body; + } + + Result(Flight.Result result) { + this.body = result.getBody().toByteArray(); + } + + public byte[] getBody() { + return body; + } + + Flight.Result toProtocol() { + return Flight.Result.newBuilder() + .setBody(ByteString.copyFrom(body)) + .build(); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java new file mode 100644 index 000000000..8a5e7d9a4 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SchemaResult.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; + +import org.apache.arrow.flight.impl.Flight; +import org.apache.arrow.vector.ipc.ReadChannel; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.IpcOption; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.validate.MetadataV4UnionChecker; + +import com.fasterxml.jackson.databind.util.ByteBufferBackedInputStream; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.ByteString; + +/** + * Opaque result returned after executing a getSchema request. + * + * <p>POJO wrapper around the Flight protocol buffer message sharing the same name. + */ +public class SchemaResult { + + private final Schema schema; + private final IpcOption option; + + public SchemaResult(Schema schema) { + this(schema, IpcOption.DEFAULT); + } + + /** + * Create a schema result with specific IPC options for serialization. + */ + public SchemaResult(Schema schema, IpcOption option) { + MetadataV4UnionChecker.checkForUnion(schema.getFields().iterator(), option.metadataVersion); + this.schema = schema; + this.option = option; + } + + public Schema getSchema() { + return schema; + } + + /** + * Converts to the protocol buffer representation. + */ + Flight.SchemaResult toProtocol() { + // Encode schema in a Message payload + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + try { + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(baos)), schema, option); + } catch (IOException e) { + throw new RuntimeException(e); + } + return Flight.SchemaResult.newBuilder() + .setSchema(ByteString.copyFrom(baos.toByteArray())) + .build(); + + } + + /** + * Converts from the protocol buffer representation. 
+ */ + static SchemaResult fromProtocol(Flight.SchemaResult pbSchemaResult) { + try { + final ByteBuffer schemaBuf = pbSchemaResult.getSchema().asReadOnlyByteBuffer(); + Schema schema = pbSchemaResult.getSchema().size() > 0 ? + MessageSerializer.deserializeSchema( + new ReadChannel(Channels.newChannel(new ByteBufferBackedInputStream(schemaBuf)))) + : new Schema(ImmutableList.of()); + return new SchemaResult(schema); + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ServerHeaderMiddleware.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ServerHeaderMiddleware.java new file mode 100644 index 000000000..527c3128c --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ServerHeaderMiddleware.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +/** + * Middleware that's used to extract and pass headers to the server during requests. + */ +public class ServerHeaderMiddleware implements FlightServerMiddleware { + /** + * Factory for accessing ServerHeaderMiddleware. 
+ */ + public static class Factory implements FlightServerMiddleware.Factory<ServerHeaderMiddleware> { + /** + * Construct a factory for receiving call headers. + */ + public Factory() { + } + + @Override + public ServerHeaderMiddleware onCallStarted(CallInfo callInfo, CallHeaders incomingHeaders, + RequestContext context) { + return new ServerHeaderMiddleware(incomingHeaders); + } + } + + private final CallHeaders headers; + + private ServerHeaderMiddleware(CallHeaders incomingHeaders) { + this.headers = incomingHeaders; + } + + /** + * Retrieve the headers for this call. + */ + public CallHeaders headers() { + return headers; + } + + @Override + public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { + } + + @Override + public void onCallCompleted(CallStatus status) { + } + + @Override + public void onCallErrored(Throwable err) { + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/StreamPipe.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/StreamPipe.java new file mode 100644 index 000000000..d506914d5 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/StreamPipe.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.function.Consumer; +import java.util.function.Function; + +import org.apache.arrow.flight.FlightProducer.StreamListener; +import org.apache.arrow.flight.grpc.StatusUtils; +import org.apache.arrow.util.AutoCloseables; + +import io.grpc.stub.StreamObserver; + +/** + * Shim listener to avoid exposing GRPC internals. + + * @param <FROM> From Type + * @param <TO> To Type + */ +class StreamPipe<FROM, TO> implements StreamListener<FROM> { + + private final StreamObserver<TO> delegate; + private final Function<FROM, TO> mapFunction; + private final Consumer<Throwable> errorHandler; + private AutoCloseable resource; + private boolean closed = false; + + /** + * Wrap the given gRPC StreamObserver with a transformation function. + * + * @param delegate The {@link StreamObserver} to wrap. + * @param func The transformation function. + * @param errorHandler A handler for uncaught exceptions (e.g. if something tries to double-close this stream). + * @param <FROM> The source type. + * @param <TO> The output type. + * @return A wrapped listener. + */ + public static <FROM, TO> StreamPipe<FROM, TO> wrap(StreamObserver<TO> delegate, Function<FROM, TO> func, + Consumer<Throwable> errorHandler) { + return new StreamPipe<>(delegate, func, errorHandler); + } + + public StreamPipe(StreamObserver<TO> delegate, Function<FROM, TO> func, Consumer<Throwable> errorHandler) { + super(); + this.delegate = delegate; + this.mapFunction = func; + this.errorHandler = errorHandler; + this.resource = null; + } + + /** Set an AutoCloseable resource to be cleaned up when the gRPC observer is to be completed. 
*/ + void setAutoCloseable(AutoCloseable ac) { + resource = ac; + } + + @Override + public void onNext(FROM val) { + delegate.onNext(mapFunction.apply(val)); + } + + @Override + public void onError(Throwable t) { + if (closed) { + errorHandler.accept(t); + return; + } + try { + AutoCloseables.close(resource); + } catch (Exception e) { + errorHandler.accept(e); + } finally { + // Set closed to true in case onError throws, so that we don't try to close again + closed = true; + delegate.onError(StatusUtils.toGrpcException(t)); + } + } + + @Override + public void onCompleted() { + if (closed) { + errorHandler.accept(new IllegalStateException("Tried to complete already-completed call")); + return; + } + try { + AutoCloseables.close(resource); + } catch (Exception e) { + errorHandler.accept(e); + } finally { + // Set closed to true in case onCompleted throws, so that we don't try to close again + closed = true; + delegate.onCompleted(); + } + } + + /** + * Ensure this stream has been completed. + */ + void ensureCompleted() { + if (!closed) { + onCompleted(); + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SyncPutListener.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SyncPutListener.java new file mode 100644 index 000000000..730cf4924 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/SyncPutListener.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import org.apache.arrow.flight.grpc.StatusUtils; +import org.apache.arrow.memory.ArrowBuf; + +/** + * A listener for server-sent application metadata messages during a Flight DoPut. This class wraps the messages in a + * synchronous interface. + */ +public final class SyncPutListener implements FlightClient.PutListener, AutoCloseable { + + private final LinkedBlockingQueue<Object> queue; + private final CompletableFuture<Void> completed; + private static final Object DONE = new Object(); + private static final Object DONE_WITH_EXCEPTION = new Object(); + + public SyncPutListener() { + queue = new LinkedBlockingQueue<>(); + completed = new CompletableFuture<>(); + } + + private PutResult unwrap(Object queueItem) throws InterruptedException, ExecutionException { + if (queueItem == DONE) { + queue.put(queueItem); + return null; + } else if (queueItem == DONE_WITH_EXCEPTION) { + queue.put(queueItem); + completed.get(); + } + return (PutResult) queueItem; + } + + /** + * Get the next message from the server, blocking until it is available. + * + * @return The next message, or null if the server is done sending messages. The caller assumes ownership of the + * metadata and must remember to close it. + * @throws InterruptedException if interrupted while waiting. 
+ * @throws ExecutionException if the server sent an error, or if there was an internal error. + */ + public PutResult read() throws InterruptedException, ExecutionException { + return unwrap(queue.take()); + } + + /** + * Get the next message from the server, blocking for the specified amount of time until it is available. + * + * @return The next message, or null if the server is done sending messages or no message arrived before the timeout. + * The caller assumes ownership of the metadata and must remember to close it. + * @throws InterruptedException if interrupted while waiting. + * @throws ExecutionException if the server sent an error, or if there was an internal error. + */ + public PutResult poll(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException { + return unwrap(queue.poll(timeout, unit)); + } + + @Override + public void getResult() { + try { + completed.get(); + } catch (ExecutionException e) { + throw StatusUtils.fromThrowable(e.getCause()); + } catch (InterruptedException e) { + throw StatusUtils.fromThrowable(e); + } + } + + @Override + public void onNext(PutResult val) { + final ArrowBuf metadata = val.getApplicationMetadata(); + metadata.getReferenceManager().retain(); + queue.add(PutResult.metadata(metadata)); + } + + @Override + public void onError(Throwable t) { + completed.completeExceptionally(StatusUtils.fromThrowable(t)); + queue.add(DONE_WITH_EXCEPTION); + } + + @Override + public void onCompleted() { + completed.complete(null); + queue.add(DONE); + } + + @Override + public void close() { + queue.forEach(o -> { + if (o instanceof PutResult) { + ((PutResult) o).close(); + } + }); + } + + @Override + public boolean isCancelled() { + return completed.isDone(); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Ticket.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Ticket.java new file mode 100644 index 000000000..a93cd0879 --- /dev/null +++ 
b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/Ticket.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Arrays; + +import org.apache.arrow.flight.impl.Flight; + +import com.google.protobuf.ByteString; + +/** + * Endpoint for a particular stream. + */ +public class Ticket { + private final byte[] bytes; + + public Ticket(byte[] bytes) { + super(); + this.bytes = bytes; + } + + public byte[] getBytes() { + return bytes; + } + + Ticket(org.apache.arrow.flight.impl.Flight.Ticket ticket) { + this.bytes = ticket.getTicket().toByteArray(); + } + + Flight.Ticket toProtocol() { + return Flight.Ticket.newBuilder() + .setTicket(ByteString.copyFrom(bytes)) + .build(); + } + + /** + * Get the serialized form of this protocol message. + * + * <p>Intended to help interoperability by allowing non-Flight services to still return Flight types. + */ + public ByteBuffer serialize() { + return ByteBuffer.wrap(toProtocol().toByteArray()); + } + + /** + * Parse the serialized form of this protocol message. 
+ * + * <p>Intended to help interoperability by allowing Flight clients to obtain stream info from non-Flight services. + * + * @param serialized The serialized form of the Ticket, as returned by {@link #serialize()}. + * @return The deserialized Ticket. + * @throws IOException if the serialized form is invalid. + */ + public static Ticket deserialize(ByteBuffer serialized) throws IOException { + return new Ticket(Flight.Ticket.parseFrom(serialized)); + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + Arrays.hashCode(bytes); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + Ticket other = (Ticket) obj; + if (!Arrays.equals(bytes, other.bytes)) { + return false; + } + return true; + } + + +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/AuthConstants.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/AuthConstants.java new file mode 100644 index 000000000..ac55872e5 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/AuthConstants.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.auth; + +import org.apache.arrow.flight.FlightConstants; + +import io.grpc.Context; +import io.grpc.Metadata.BinaryMarshaller; +import io.grpc.Metadata.Key; +import io.grpc.MethodDescriptor; + +/** + * Constants used in authorization of flight connections. + */ +public final class AuthConstants { + + public static final String HANDSHAKE_DESCRIPTOR_NAME = MethodDescriptor + .generateFullMethodName(FlightConstants.SERVICE, "Handshake"); + public static final String TOKEN_NAME = "Auth-Token-bin"; + public static final Key<byte[]> TOKEN_KEY = Key.of(TOKEN_NAME, new BinaryMarshaller<byte[]>() { + + @Override + public byte[] toBytes(byte[] value) { + return value; + } + + @Override + public byte[] parseBytes(byte[] serialized) { + return serialized; + } + }); + + public static final Context.Key<String> PEER_IDENTITY_KEY = Context.keyWithDefault("arrow-flight-peer-identity", ""); + + private AuthConstants() {} +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/BasicClientAuthHandler.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/BasicClientAuthHandler.java new file mode 100644 index 000000000..c6dca97fb --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/BasicClientAuthHandler.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.auth; + +import java.util.Iterator; + +import org.apache.arrow.flight.impl.Flight.BasicAuth; + +/** + * A client auth handler that supports username and password. + */ +public class BasicClientAuthHandler implements ClientAuthHandler { + + private final String name; + private final String password; + private byte[] token = null; + + public BasicClientAuthHandler(String name, String password) { + this.name = name; + this.password = password; + } + + @Override + public void authenticate(ClientAuthSender outgoing, Iterator<byte[]> incoming) { + BasicAuth.Builder builder = BasicAuth.newBuilder(); + if (name != null) { + builder.setUsername(name); + } + + if (password != null) { + builder.setPassword(password); + } + + outgoing.send(builder.build().toByteArray()); + this.token = incoming.next(); + } + + @Override + public byte[] getCallToken() { + return token; + } + +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/BasicServerAuthHandler.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/BasicServerAuthHandler.java new file mode 100644 index 000000000..34e3efc0d --- /dev/null +++ 
b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/BasicServerAuthHandler.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.auth; + +import java.util.Iterator; +import java.util.Optional; + +import org.apache.arrow.flight.impl.Flight.BasicAuth; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.protobuf.InvalidProtocolBufferException; + +/** + * A ServerAuthHandler for username/password authentication. + */ +public class BasicServerAuthHandler implements ServerAuthHandler { + + private static final Logger logger = LoggerFactory.getLogger(BasicServerAuthHandler.class); + private final BasicAuthValidator authValidator; + + public BasicServerAuthHandler(BasicAuthValidator authValidator) { + super(); + this.authValidator = authValidator; + } + + /** + * Interface that this handler delegates for determining if credentials are valid. 
+ */ + public interface BasicAuthValidator { + + byte[] getToken(String username, String password) throws Exception; + + Optional<String> isValid(byte[] token); + + } + + @Override + public boolean authenticate(ServerAuthSender outgoing, Iterator<byte[]> incoming) { + byte[] bytes = incoming.next(); + try { + BasicAuth auth = BasicAuth.parseFrom(bytes); + byte[] token = authValidator.getToken(auth.getUsername(), auth.getPassword()); + outgoing.send(token); + return true; + } catch (InvalidProtocolBufferException e) { + logger.debug("Failure parsing auth message.", e); + } catch (Exception e) { + logger.debug("Unknown error during authorization.", e); + } + + return false; + } + + @Override + public Optional<String> isValid(byte[] token) { + return authValidator.isValid(token); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ClientAuthHandler.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ClientAuthHandler.java new file mode 100644 index 000000000..985e10aa4 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ClientAuthHandler.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
/**
 * Implement authentication for Flight on the client side.
 *
 * <p>An implementation carries out the initial handshake in {@link #authenticate} and afterwards supplies
 * the token returned by {@link #getCallToken()}.
 */
public interface ClientAuthHandler {
  /**
   * Handle the initial handshake with the server.
   *
   * @param outgoing A channel to send data to the server.
   * @param incoming An iterator of incoming data from the server.
   */
  void authenticate(ClientAuthSender outgoing, Iterator<byte[]> incoming);

  /**
   * Get the per-call authentication token.
   *
   * @return the token bytes.
   */
  byte[] getCallToken();

  /**
   * A communication channel to the server during initial connection.
   */
  interface ClientAuthSender {

    /**
     * Send the server a message.
     *
     * @param payload The raw bytes of the message.
     */
    void send(byte[] payload);

    /**
     * Signal an error to the server and abort the authentication attempt.
     *
     * @param cause The reason for aborting.
     */
    void onError(Throwable cause);

  }

}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.auth; + +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; + +/** + * GRPC client intercepter that handles authentication with the server. + */ +public class ClientAuthInterceptor implements ClientInterceptor { + private volatile ClientAuthHandler authHandler = null; + + public void setAuthHandler(ClientAuthHandler authHandler) { + this.authHandler = authHandler; + } + + public ClientAuthInterceptor() { + } + + public boolean hasAuthHandler() { + return authHandler != null; + } + + @Override + public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> methodDescriptor, + CallOptions callOptions, Channel next) { + ClientCall<ReqT, RespT> call = next.newCall(methodDescriptor, callOptions); + + // once we have an auth header, add that to the calls. 
+ if (authHandler != null) { + call = new HeaderAttachingClientCall<>(call); + } + + return call; + } + + private final class HeaderAttachingClientCall<ReqT, RespT> extends SimpleForwardingClientCall<ReqT, RespT> { + + private HeaderAttachingClientCall(ClientCall<ReqT, RespT> call) { + super(call); + } + + @Override + public void start(Listener<RespT> responseListener, Metadata headers) { + final Metadata authHeaders = new Metadata(); + authHeaders.put(AuthConstants.TOKEN_KEY, authHandler.getCallToken()); + headers.merge(authHeaders); + super.start(responseListener, headers); + } + } + +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java new file mode 100644 index 000000000..e86dc163c --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ClientAuthWrapper.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.auth; + +import java.util.Iterator; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.LinkedBlockingQueue; + +import org.apache.arrow.flight.auth.ClientAuthHandler.ClientAuthSender; +import org.apache.arrow.flight.grpc.StatusUtils; +import org.apache.arrow.flight.impl.Flight.HandshakeRequest; +import org.apache.arrow.flight.impl.Flight.HandshakeResponse; +import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceStub; + +import com.google.protobuf.ByteString; + +import io.grpc.StatusRuntimeException; +import io.grpc.stub.StreamObserver; + +/** + * Utility class for performing authorization over using a GRPC stub. + */ +public class ClientAuthWrapper { + + /** + * Do client auth for a client. The stub will be authenticated after this method returns. + * + * @param authHandler The handler to use. + * @param stub The service stub. + */ + public static void doClientAuth(ClientAuthHandler authHandler, FlightServiceStub stub) { + AuthObserver observer = new AuthObserver(); + try { + observer.responseObserver = stub.handshake(observer); + authHandler.authenticate(observer.sender, observer.iter); + if (!observer.sender.errored) { + observer.responseObserver.onCompleted(); + } + } catch (StatusRuntimeException sre) { + throw StatusUtils.fromGrpcRuntimeException(sre); + } + try { + if (!observer.completed.get()) { + // TODO: ARROW-5681 + throw new RuntimeException("Unauthenticated"); + } + } catch (InterruptedException e) { + throw new RuntimeException(e); + } catch (ExecutionException e) { + throw StatusUtils.fromThrowable(e.getCause()); + } + } + + private static class AuthObserver implements StreamObserver<HandshakeResponse> { + + private volatile StreamObserver<HandshakeRequest> responseObserver; + private final LinkedBlockingQueue<byte[]> messages = new LinkedBlockingQueue<>(); + private final AuthSender sender = new AuthSender(); + private 
CompletableFuture<Boolean> completed; + + public AuthObserver() { + super(); + completed = new CompletableFuture<>(); + } + + @Override + public void onNext(HandshakeResponse value) { + ByteString payload = value.getPayload(); + if (payload != null) { + messages.add(payload.toByteArray()); + } + } + + private Iterator<byte[]> iter = new Iterator<byte[]>() { + + @Override + public byte[] next() { + while (!completed.isDone() || !messages.isEmpty()) { + byte[] bytes = messages.poll(); + if (bytes == null) { + // busy wait. + continue; + } else { + return bytes; + } + } + + if (completed.isCompletedExceptionally()) { + // Preserve prior exception behavior + // TODO: with ARROW-5681, throw an appropriate Flight exception if gRPC raised an exception + try { + completed.get(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } catch (ExecutionException e) { + if (e.getCause() instanceof StatusRuntimeException) { + throw (StatusRuntimeException) e.getCause(); + } + throw new RuntimeException(e); + } + } + + throw new IllegalStateException("You attempted to retrieve messages after there were none."); + } + + @Override + public boolean hasNext() { + return !messages.isEmpty(); + } + }; + + @Override + public void onError(Throwable t) { + completed.completeExceptionally(t); + } + + private class AuthSender implements ClientAuthSender { + + private boolean errored = false; + + @Override + public void send(byte[] payload) { + try { + responseObserver.onNext(HandshakeRequest.newBuilder() + .setPayload(ByteString.copyFrom(payload)) + .build()); + } catch (StatusRuntimeException sre) { + throw StatusUtils.fromGrpcRuntimeException(sre); + } + } + + @Override + public void onError(Throwable cause) { + this.errored = true; + responseObserver.onError(StatusUtils.toGrpcException(cause)); + } + + } + + @Override + public void onCompleted() { + completed.complete(true); + } + } + +} diff --git 
/**
 * Interface for server-side authentication handlers.
 */
public interface ServerAuthHandler {

  /**
   * Validate the client token provided on each call.
   *
   * @param token The token bytes supplied by the client.
   * @return An empty optional if the client is not authenticated; the peer identity otherwise (may be the empty
   *     string).
   */
  Optional<String> isValid(byte[] token);

  /**
   * Handle the initial handshake with the client.
   *
   * @param outgoing A writer to send messages to the client.
   * @param incoming An iterator of messages from the client.
   * @return true if client is authenticated, false otherwise.
   */
  boolean authenticate(ServerAuthSender outgoing, Iterator<byte[]> incoming);

  /**
   * Interface that lets a server implementation send authentication messages back to the client.
   */
  interface ServerAuthSender {

    /** Send the client a message. */
    void send(byte[] payload);

    /** Report an error to the client. */
    void onError(Throwable cause);

  }

  /**
   * An auth handler that does nothing: every token is accepted with the empty string as peer identity, and
   * every handshake succeeds without exchanging any messages.
   */
  ServerAuthHandler NO_OP = new ServerAuthHandler() {

    @Override
    public Optional<String> isValid(byte[] token) {
      return Optional.of("");
    }

    @Override
    public boolean authenticate(ServerAuthSender outgoing, Iterator<byte[]> incoming) {
      return true;
    }
  };
}
+ */ + +package org.apache.arrow.flight.auth; + +import java.util.Optional; + +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.grpc.StatusUtils; + +import io.grpc.Context; +import io.grpc.Contexts; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCall.Listener; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; + +/** + * GRPC Interceptor for performing authentication. + */ +public class ServerAuthInterceptor implements ServerInterceptor { + + private final ServerAuthHandler authHandler; + + public ServerAuthInterceptor(ServerAuthHandler authHandler) { + this.authHandler = authHandler; + } + + @Override + public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers, + ServerCallHandler<ReqT, RespT> next) { + if (!call.getMethodDescriptor().getFullMethodName().equals(AuthConstants.HANDSHAKE_DESCRIPTOR_NAME)) { + final Optional<String> peerIdentity; + + // Allow customizing the response code by throwing FlightRuntimeException + try { + peerIdentity = isValid(headers); + } catch (FlightRuntimeException e) { + final Status grpcStatus = StatusUtils.toGrpcStatus(e.status()); + call.close(grpcStatus, new Metadata()); + return new NoopServerCallListener<>(); + } catch (StatusRuntimeException e) { + Metadata trailers = e.getTrailers(); + call.close(e.getStatus(), trailers == null ? 
new Metadata() : trailers); + return new NoopServerCallListener<>(); + } + + if (!peerIdentity.isPresent()) { + // Send back a description along with the status code + call.close(Status.UNAUTHENTICATED + .withDescription("Unauthenticated (invalid or missing auth token)"), new Metadata()); + return new NoopServerCallListener<>(); + } + return Contexts.interceptCall(Context.current().withValue(AuthConstants.PEER_IDENTITY_KEY, peerIdentity.get()), + call, headers, next); + } + + return next.startCall(call, headers); + } + + private Optional<String> isValid(Metadata headers) { + byte[] token = headers.get(AuthConstants.TOKEN_KEY); + return authHandler.isValid(token); + } + + private static class NoopServerCallListener<T> extends ServerCall.Listener<T> { + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java new file mode 100644 index 000000000..ad1a36a93 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth/ServerAuthWrapper.java @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.auth; + +import java.util.Iterator; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.LinkedBlockingQueue; + +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.auth.ServerAuthHandler.ServerAuthSender; +import org.apache.arrow.flight.grpc.StatusUtils; +import org.apache.arrow.flight.impl.Flight.HandshakeRequest; +import org.apache.arrow.flight.impl.Flight.HandshakeResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.protobuf.ByteString; + +import io.grpc.stub.StreamObserver; + +/** + * Contains utility methods for integrating authorization into a GRPC stream. + */ +public class ServerAuthWrapper { + private static final Logger LOGGER = LoggerFactory.getLogger(ServerAuthWrapper.class); + + /** + * Wrap the auth handler for handshake purposes. + * + * @param authHandler Authentication handler + * @param responseObserver Observer for handshake response + * @param executors ExecutorService + * @return AuthObserver + */ + public static StreamObserver<HandshakeRequest> wrapHandshake(ServerAuthHandler authHandler, + StreamObserver<HandshakeResponse> responseObserver, ExecutorService executors) { + + // stream started. 
+ AuthObserver observer = new AuthObserver(responseObserver); + final Runnable r = () -> { + try { + if (authHandler.authenticate(observer.sender, observer.iter)) { + responseObserver.onCompleted(); + return; + } + + responseObserver.onError(StatusUtils.toGrpcException(CallStatus.UNAUTHENTICATED.toRuntimeException())); + } catch (Exception ex) { + LOGGER.error("Error during authentication", ex); + responseObserver.onError(StatusUtils.toGrpcException(ex)); + } + }; + observer.future = executors.submit(r); + return observer; + } + + private static class AuthObserver implements StreamObserver<HandshakeRequest> { + + private final StreamObserver<HandshakeResponse> responseObserver; + private volatile Future<?> future; + private volatile boolean completed = false; + private final LinkedBlockingQueue<byte[]> messages = new LinkedBlockingQueue<>(); + private final AuthSender sender = new AuthSender(); + + public AuthObserver(StreamObserver<HandshakeResponse> responseObserver) { + super(); + this.responseObserver = responseObserver; + } + + @Override + public void onNext(HandshakeRequest value) { + ByteString payload = value.getPayload(); + if (payload != null) { + messages.add(payload.toByteArray()); + } + } + + private Iterator<byte[]> iter = new Iterator<byte[]>() { + + @Override + public byte[] next() { + while (!completed || !messages.isEmpty()) { + byte[] bytes = messages.poll(); + if (bytes == null) { + //busy wait. 
+ continue; + } + return bytes; + } + throw new IllegalStateException("Requesting more messages than client sent."); + } + + @Override + public boolean hasNext() { + return !messages.isEmpty(); + } + }; + + @Override + public void onError(Throwable t) { + completed = true; + while (future == null) {/* busy wait */} + future.cancel(true); + } + + @Override + public void onCompleted() { + completed = true; + } + + private class AuthSender implements ServerAuthSender { + + @Override + public void send(byte[] payload) { + responseObserver.onNext(HandshakeResponse.newBuilder() + .setPayload(ByteString.copyFrom(payload)) + .build()); + } + + @Override + public void onError(Throwable cause) { + responseObserver.onError(StatusUtils.toGrpcException(cause)); + } + + } + } + +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/Auth2Constants.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/Auth2Constants.java new file mode 100644 index 000000000..624d7d5ff --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/Auth2Constants.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.auth2; + +/** + * Constants used in authorization of flight connections. + */ +public final class Auth2Constants { + public static final String PEER_IDENTITY_KEY = "arrow-flight-peer-identity"; + public static final String BEARER_PREFIX = "Bearer "; + public static final String BASIC_PREFIX = "Basic "; + public static final String AUTHORIZATION_HEADER = "Authorization"; + + private Auth2Constants() { + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/AuthUtilities.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/AuthUtilities.java new file mode 100644 index 000000000..c73b7cf1a --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/AuthUtilities.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.auth2; + +import org.apache.arrow.flight.CallHeaders; + +/** + * Utility class for completing the auth process. + */ +public final class AuthUtilities { + private AuthUtilities() { + + } + + /** + * Helper method for retrieving a value from the Authorization header. 
+ * + * @param headers The headers to inspect. + * @param valuePrefix The prefix within the value portion of the header to extract away. + * @return The header value. + */ + public static String getValueFromAuthHeader(CallHeaders headers, String valuePrefix) { + final String authHeaderValue = headers.get(Auth2Constants.AUTHORIZATION_HEADER); + if (authHeaderValue != null) { + if (authHeaderValue.regionMatches(true, 0, valuePrefix, 0, valuePrefix.length())) { + return authHeaderValue.substring(valuePrefix.length()); + } + } + return null; + } + +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BasicAuthCredentialWriter.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BasicAuthCredentialWriter.java new file mode 100644 index 000000000..698287e88 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BasicAuthCredentialWriter.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.auth2; + +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.function.Consumer; + +import org.apache.arrow.flight.CallHeaders; + +/** + * Client credentials that use a username and password. + */ +public final class BasicAuthCredentialWriter implements Consumer<CallHeaders> { + + private final String name; + private final String password; + + public BasicAuthCredentialWriter(String name, String password) { + this.name = name; + this.password = password; + } + + @Override + public void accept(CallHeaders outputHeaders) { + outputHeaders.insert(Auth2Constants.AUTHORIZATION_HEADER, Auth2Constants.BASIC_PREFIX + + Base64.getEncoder().encodeToString(String.format("%s:%s", name, password).getBytes(StandardCharsets.UTF_8))); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BasicCallHeaderAuthenticator.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BasicCallHeaderAuthenticator.java new file mode 100644 index 000000000..fff7b4690 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BasicCallHeaderAuthenticator.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.auth2; + +import java.io.UnsupportedEncodingException; +import java.nio.charset.StandardCharsets; +import java.util.Base64; + +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightRuntimeException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A CallHeaderAuthenticator for username/password authentication. + */ +public class BasicCallHeaderAuthenticator implements CallHeaderAuthenticator { + + private static final Logger logger = LoggerFactory.getLogger(BasicCallHeaderAuthenticator.class); + + private final CredentialValidator authValidator; + + public BasicCallHeaderAuthenticator(CredentialValidator authValidator) { + this.authValidator = authValidator; + } + + @Override + public AuthResult authenticate(CallHeaders incomingHeaders) { + try { + final String authEncoded = AuthUtilities.getValueFromAuthHeader( + incomingHeaders, Auth2Constants.BASIC_PREFIX); + if (authEncoded == null) { + throw CallStatus.UNAUTHENTICATED.toRuntimeException(); + } + // The value has the format Base64(<username>:<password>) + final String authDecoded = new String(Base64.getDecoder().decode(authEncoded), StandardCharsets.UTF_8); + final int colonPos = authDecoded.indexOf(':'); + if (colonPos == -1) { + throw CallStatus.UNAUTHENTICATED.toRuntimeException(); + } + + final String user = authDecoded.substring(0, colonPos); + final String password = authDecoded.substring(colonPos + 1); + return authValidator.validate(user, password); + } catch (UnsupportedEncodingException ex) { + // Note: Intentionally discarding the exception cause when reporting back to the client for security purposes. 
+ logger.error("Authentication failed due to missing encoding.", ex); + throw CallStatus.INTERNAL.toRuntimeException(); + } catch (FlightRuntimeException ex) { + throw ex; + } catch (Exception ex) { + // Note: Intentionally discarding the exception cause when reporting back to the client for security purposes. + logger.error("Authentication failed.", ex); + throw CallStatus.UNAUTHENTICATED.toRuntimeException(); + } + } + + /** + * Interface that this handler delegates to for validating the incoming headers. + */ + public interface CredentialValidator { + /** + * Validate the supplied credentials (username/password) and return the peer identity. + * + * @param username The username to validate. + * @param password The password to validate. + * @return The peer identity if the supplied credentials are valid. + * @throws Exception If the supplied credentials are not valid. + */ + AuthResult validate(String username, String password) throws Exception; + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BearerCredentialWriter.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BearerCredentialWriter.java new file mode 100644 index 000000000..715ee502b --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BearerCredentialWriter.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.auth2; + +import java.util.function.Consumer; + +import org.apache.arrow.flight.CallHeaders; + +/** + * Client credentials that use a bearer token. + */ +public final class BearerCredentialWriter implements Consumer<CallHeaders> { + + private final String bearer; + + public BearerCredentialWriter(String bearer) { + this.bearer = bearer; + } + + @Override + public void accept(CallHeaders outputHeaders) { + outputHeaders.insert(Auth2Constants.AUTHORIZATION_HEADER, Auth2Constants.BEARER_PREFIX + bearer); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BearerTokenAuthenticator.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BearerTokenAuthenticator.java new file mode 100644 index 000000000..2006e0a2b --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/BearerTokenAuthenticator.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.auth2; + +import org.apache.arrow.flight.CallHeaders; + +/** + * Partial implementation of {@link CallHeaderAuthenticator} for bearer-token based authentication. + */ +public abstract class BearerTokenAuthenticator implements CallHeaderAuthenticator { + + final CallHeaderAuthenticator initialAuthenticator; + + public BearerTokenAuthenticator(CallHeaderAuthenticator initialAuthenticator) { + this.initialAuthenticator = initialAuthenticator; + } + + @Override + public AuthResult authenticate(CallHeaders incomingHeaders) { + // Check if headers contain a bearer token and if so, validate the token. + final String bearerToken = + AuthUtilities.getValueFromAuthHeader(incomingHeaders, Auth2Constants.BEARER_PREFIX); + if (bearerToken != null) { + return validateBearer(bearerToken); + } + + // Delegate to the basic auth handler to do the validation. + final CallHeaderAuthenticator.AuthResult result = initialAuthenticator.authenticate(incomingHeaders); + return getAuthResultWithBearerToken(result); + } + + /** + * Callback to run when the initial authenticator succeeds. + * @param authResult A successful initial authentication result. + * @return an alternate AuthResult based on the original AuthResult that will write a bearer token to output headers. + */ + protected abstract AuthResult getAuthResultWithBearerToken(AuthResult authResult); + + /** + * Validate the bearer token. + * @param bearerToken The bearer token to validate. + * @return A successful AuthResult if validation succeeded. 
+ * @throws Exception If the token validation fails. + */ + protected abstract AuthResult validateBearer(String bearerToken); + +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/CallHeaderAuthenticator.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/CallHeaderAuthenticator.java new file mode 100644 index 000000000..87e60f1fa --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/CallHeaderAuthenticator.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.auth2; + +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.FlightRuntimeException; + +/** + * Interface for Server side authentication handlers. + * + * A CallHeaderAuthenticator is used by {@link ServerCallHeaderAuthMiddleware} to validate headers sent by a Flight + * client for authentication purposes. The headers validated do not necessarily have to be Authorization headers. 
+ * + * The workflow is that the FlightServer will intercept headers on a request, validate the headers, and + * either send back an UNAUTHENTICATED error, or succeed and potentially send back additional headers to the client. + * + * Implementations of CallHeaderAuthenticator should take care not to leak confidential details (such as + * indicating if usernames are valid or not) for security reasons when reporting errors back to clients. + * + * Example CallHeaderAuthenticators provided include: + * The {@link BasicCallHeaderAuthenticator} will authenticate basic HTTP credentials. + * + * The {@link BearerTokenAuthenticator} will authenticate basic HTTP credentials initially, then also send back a + * bearer token that the client can use for subsequent requests. The {@link GeneratedBearerTokenAuthenticator} will + * provide internally generated bearer tokens and maintain a cache of them. + */ +public interface CallHeaderAuthenticator { + + /** + * Encapsulates the result of the {@link CallHeaderAuthenticator} analysis of headers. + * + * This includes the identity of the incoming user and any outbound headers to send as a response to the client. + */ + interface AuthResult { + /** + * The peer identity that was determined by the handshake process based on the + * authentication credentials supplied by the client. + * + * @return The peer identity. + */ + String getPeerIdentity(); + + /** + * Appends a header to the outgoing call headers. + * @param outgoingHeaders The outgoing headers. + */ + default void appendToOutgoingHeaders(CallHeaders outgoingHeaders) { + + } + } + + /** + * Validate the auth headers sent by the client. + * + * @param incomingHeaders The incoming headers to authenticate. + * @return an auth result containing a peer identity and optionally a bearer token. + * @throws FlightRuntimeException with CallStatus.UNAUTHENTICATED if credentials were not supplied + * or if credentials were supplied but were not valid. 
+ */ + AuthResult authenticate(CallHeaders incomingHeaders); + + /** + * An auth handler that does nothing. + */ + CallHeaderAuthenticator NO_OP = new CallHeaderAuthenticator() { + @Override + public AuthResult authenticate(CallHeaders incomingHeaders) { + return () -> ""; + } + }; +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientBearerHeaderHandler.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientBearerHeaderHandler.java new file mode 100644 index 000000000..45bdb6d95 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientBearerHeaderHandler.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.auth2; + +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.grpc.CredentialCallOption; + +/** + * A client header handler that parses the incoming headers for a bearer token. 
+ */ +public class ClientBearerHeaderHandler implements ClientHeaderHandler { + + @Override + public CredentialCallOption getCredentialCallOptionFromIncomingHeaders(CallHeaders incomingHeaders) { + final String bearerValue = AuthUtilities.getValueFromAuthHeader(incomingHeaders, Auth2Constants.BEARER_PREFIX); + if (bearerValue != null) { + return new CredentialCallOption(new BearerCredentialWriter(bearerValue)); + } + return null; + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientHandshakeWrapper.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientHandshakeWrapper.java new file mode 100644 index 000000000..16a514250 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientHandshakeWrapper.java @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.auth2; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.grpc.StatusUtils; +import org.apache.arrow.flight.impl.Flight.HandshakeRequest; +import org.apache.arrow.flight.impl.Flight.HandshakeResponse; +import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceStub; + +import io.grpc.StatusRuntimeException; +import io.grpc.stub.StreamObserver; + +/** + * Utility class for executing a handshake with a FlightServer. + */ +public class ClientHandshakeWrapper { + private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(ClientHandshakeWrapper.class); + + /** + * Do handshake for a client. The stub will be authenticated after this method returns. + * + * @param stub The service stub. + */ + public static void doClientHandshake(FlightServiceStub stub) { + final HandshakeObserver observer = new HandshakeObserver(); + try { + observer.requestObserver = stub.handshake(observer); + observer.requestObserver.onNext(HandshakeRequest.newBuilder().build()); + observer.requestObserver.onCompleted(); + try { + if (!observer.completed.get()) { + // TODO: ARROW-5681 + throw CallStatus.UNAUTHENTICATED.toRuntimeException(); + } + } catch (InterruptedException ex) { + Thread.currentThread().interrupt(); + throw ex; + } catch (ExecutionException ex) { + final FlightRuntimeException wrappedException = StatusUtils.fromThrowable(ex.getCause()); + logger.error("Failed on completing future", wrappedException); + throw wrappedException; + } + } catch (StatusRuntimeException sre) { + logger.error("Failed with SREe", sre); + throw StatusUtils.fromGrpcRuntimeException(sre); + } catch (Throwable ex) { + logger.error("Failed with unknown", ex); + if (ex instanceof FlightRuntimeException) { + throw (FlightRuntimeException) ex; + } + throw 
StatusUtils.fromThrowable(ex); + } + } + + private static class HandshakeObserver implements StreamObserver<HandshakeResponse> { + + private volatile StreamObserver<HandshakeRequest> requestObserver; + private final CompletableFuture<Boolean> completed; + + public HandshakeObserver() { + super(); + completed = new CompletableFuture<>(); + } + + @Override + public void onNext(HandshakeResponse value) { + } + + @Override + public void onError(Throwable t) { + completed.completeExceptionally(t); + } + + @Override + public void onCompleted() { + completed.complete(true); + } + } + +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientHeaderHandler.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientHeaderHandler.java new file mode 100644 index 000000000..514189f9b --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientHeaderHandler.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.auth2; + +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.grpc.CredentialCallOption; + +/** + * Interface for client side header parsing and conversion to CredentialCallOption. + */ +public interface ClientHeaderHandler { + /** + * Parses the incoming headers and converts them into a CredentialCallOption. + * @param incomingHeaders Incoming headers to parse. + * @return An instance of CredentialCallOption. + */ + CredentialCallOption getCredentialCallOptionFromIncomingHeaders(CallHeaders incomingHeaders); + + /** + * An client header handler that does nothing. + */ + ClientHeaderHandler NO_OP = new ClientHeaderHandler() { + @Override + public CredentialCallOption getCredentialCallOptionFromIncomingHeaders(CallHeaders incomingHeaders) { + return null; + } + }; +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientIncomingAuthHeaderMiddleware.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientIncomingAuthHeaderMiddleware.java new file mode 100644 index 000000000..be5f3f54d --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ClientIncomingAuthHeaderMiddleware.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.auth2; + +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallInfo; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightClientMiddleware; +import org.apache.arrow.flight.grpc.CredentialCallOption; + +/** + * Middleware for capturing bearer tokens sent back from the Flight server. + */ +public class ClientIncomingAuthHeaderMiddleware implements FlightClientMiddleware { + private final Factory factory; + + /** + * Factory used within FlightClient. + */ + public static class Factory implements FlightClientMiddleware.Factory { + private final ClientHeaderHandler headerHandler; + private CredentialCallOption credentialCallOption; + + /** + * Construct a factory with the given header handler. + * @param headerHandler The header handler that will be used for handling incoming headers from the flight server. 
+ */ + public Factory(ClientHeaderHandler headerHandler) { + this.headerHandler = headerHandler; + } + + @Override + public FlightClientMiddleware onCallStarted(CallInfo info) { + return new ClientIncomingAuthHeaderMiddleware(this); + } + + void setCredentialCallOption(CredentialCallOption callOption) { + this.credentialCallOption = callOption; + } + + public CredentialCallOption getCredentialCallOption() { + return credentialCallOption; + } + } + + private ClientIncomingAuthHeaderMiddleware(Factory factory) { + this.factory = factory; + } + + @Override + public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { + } + + @Override + public void onHeadersReceived(CallHeaders incomingHeaders) { + factory.setCredentialCallOption( + factory.headerHandler.getCredentialCallOptionFromIncomingHeaders(incomingHeaders)); + } + + @Override + public void onCallCompleted(CallStatus status) { + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/GeneratedBearerTokenAuthenticator.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/GeneratedBearerTokenAuthenticator.java new file mode 100644 index 000000000..8b312b6b7 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/GeneratedBearerTokenAuthenticator.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.auth2; + +import java.nio.ByteBuffer; +import java.util.Base64; +import java.util.UUID; +import java.util.concurrent.TimeUnit; + +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.grpc.MetadataAdapter; + +import com.google.common.base.Strings; +import com.google.common.cache.Cache; +import com.google.common.cache.CacheBuilder; + +import io.grpc.Metadata; + +/** + * Generates and caches bearer tokens from user credentials. + */ +public class GeneratedBearerTokenAuthenticator extends BearerTokenAuthenticator { + private final Cache<String, String> bearerToIdentityCache; + + /** + * Generate bearer tokens for the given basic call authenticator. + * @param authenticator The authenticator to initial validate inputs with. + */ + public GeneratedBearerTokenAuthenticator(CallHeaderAuthenticator authenticator) { + this(authenticator, CacheBuilder.newBuilder().expireAfterAccess(2, TimeUnit.HOURS)); + } + + /** + * Generate bearer tokens for the given basic call authenticator. + * @param authenticator The authenticator to initial validate inputs with. + * @param timeoutMinutes The time before tokens expire after being accessed. + */ + public GeneratedBearerTokenAuthenticator(CallHeaderAuthenticator authenticator, int timeoutMinutes) { + this(authenticator, CacheBuilder.newBuilder().expireAfterAccess(timeoutMinutes, TimeUnit.MINUTES)); + } + + /** + * Generate bearer tokens for the given basic call authenticator. 
+ * @param authenticator The authenticator to initial validate inputs with. + * @param cacheBuilder The configuration of the cache of bearer tokens. + */ + public GeneratedBearerTokenAuthenticator(CallHeaderAuthenticator authenticator, + CacheBuilder<Object, Object> cacheBuilder) { + super(authenticator); + bearerToIdentityCache = cacheBuilder.build(); + } + + @Override + protected AuthResult validateBearer(String bearerToken) { + final String peerIdentity = bearerToIdentityCache.getIfPresent(bearerToken); + if (peerIdentity == null) { + throw CallStatus.UNAUTHENTICATED.toRuntimeException(); + } + + return new AuthResult() { + @Override + public String getPeerIdentity() { + return peerIdentity; + } + + @Override + public void appendToOutgoingHeaders(CallHeaders outgoingHeaders) { + if (null == AuthUtilities.getValueFromAuthHeader(outgoingHeaders, Auth2Constants.BEARER_PREFIX)) { + outgoingHeaders.insert(Auth2Constants.AUTHORIZATION_HEADER, Auth2Constants.BEARER_PREFIX + bearerToken); + } + } + }; + } + + @Override + protected AuthResult getAuthResultWithBearerToken(AuthResult authResult) { + // We generate a dummy header and call appendToOutgoingHeaders with it. + // We then inspect the dummy header and parse the bearer token if present in the header + // and generate a new bearer token if a bearer token is not present in the header. + final CallHeaders dummyHeaders = new MetadataAdapter(new Metadata()); + authResult.appendToOutgoingHeaders(dummyHeaders); + String bearerToken = + AuthUtilities.getValueFromAuthHeader(dummyHeaders, Auth2Constants.BEARER_PREFIX); + final AuthResult authResultWithBearerToken; + if (Strings.isNullOrEmpty(bearerToken)) { + // Generate a new bearer token and return an AuthResult that can write it. 
+ final UUID uuid = UUID.randomUUID(); + final ByteBuffer byteBuffer = ByteBuffer.wrap(new byte[16]); + byteBuffer.putLong(uuid.getMostSignificantBits()); + byteBuffer.putLong(uuid.getLeastSignificantBits()); + final String newToken = Base64.getEncoder().encodeToString(byteBuffer.array()); + bearerToken = newToken; + authResultWithBearerToken = new AuthResult() { + @Override + public String getPeerIdentity() { + return authResult.getPeerIdentity(); + } + + @Override + public void appendToOutgoingHeaders(CallHeaders outgoingHeaders) { + authResult.appendToOutgoingHeaders(outgoingHeaders); + outgoingHeaders.insert(Auth2Constants.AUTHORIZATION_HEADER, Auth2Constants.BEARER_PREFIX + newToken); + } + }; + } else { + // Use the bearer token supplied by the original auth result. + authResultWithBearerToken = authResult; + } + bearerToIdentityCache.put(bearerToken, authResult.getPeerIdentity()); + return authResultWithBearerToken; + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ServerCallHeaderAuthMiddleware.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ServerCallHeaderAuthMiddleware.java new file mode 100644 index 000000000..9bfa73818 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/auth2/ServerCallHeaderAuthMiddleware.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.auth2; + +import static org.apache.arrow.flight.auth2.CallHeaderAuthenticator.AuthResult; + +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallInfo; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightServerMiddleware; +import org.apache.arrow.flight.RequestContext; + +/** + * Middleware that's used to validate credentials during the handshake and verify + * the bearer token in subsequent requests. + */ +public class ServerCallHeaderAuthMiddleware implements FlightServerMiddleware { + /** + * Factory for accessing ServerAuthMiddleware. + */ + public static class Factory implements FlightServerMiddleware.Factory<ServerCallHeaderAuthMiddleware> { + private final CallHeaderAuthenticator authHandler; + + /** + * Construct a factory with the given auth handler. + * @param authHandler The auth handler what will be used for authenticating requests. 
+ */ + public Factory(CallHeaderAuthenticator authHandler) { + this.authHandler = authHandler; + } + + @Override + public ServerCallHeaderAuthMiddleware onCallStarted(CallInfo callInfo, CallHeaders incomingHeaders, + RequestContext context) { + final AuthResult result = authHandler.authenticate(incomingHeaders); + context.put(Auth2Constants.PEER_IDENTITY_KEY, result.getPeerIdentity()); + return new ServerCallHeaderAuthMiddleware(result); + } + } + + private final AuthResult authResult; + + public ServerCallHeaderAuthMiddleware(AuthResult authResult) { + this.authResult = authResult; + } + + @Override + public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { + authResult.appendToOutgoingHeaders(outgoingHeaders); + } + + @Override + public void onCallCompleted(CallStatus status) { + } + + @Override + public void onCallErrored(Throwable err) { + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/client/ClientCookieMiddleware.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/client/ClientCookieMiddleware.java new file mode 100644 index 000000000..56f24e101 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/client/ClientCookieMiddleware.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.client; + +import java.net.HttpCookie; +import java.util.List; +import java.util.Locale; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.stream.Collectors; + +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallInfo; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightClientMiddleware; +import org.apache.arrow.util.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A client middleware for receiving and sending cookie information. + * Note that this class will not persist permanent cookies beyond the lifetime + * of this session. + * + * This middleware will automatically remove cookies that have expired. + * <b>Note</b>: Negative max-age values currently do not get marked as expired due to + * a JDK issue. Use max-age=0 to explicitly remove an existing cookie. + */ +public class ClientCookieMiddleware implements FlightClientMiddleware { + private static final Logger LOGGER = LoggerFactory.getLogger(ClientCookieMiddleware.class); + + private static final String SET_COOKIE_HEADER = "Set-Cookie"; + private static final String COOKIE_HEADER = "Cookie"; + + private final Factory factory; + + @VisibleForTesting + ClientCookieMiddleware(Factory factory) { + this.factory = factory; + } + + /** + * Factory used within FlightClient. + */ + public static class Factory implements FlightClientMiddleware.Factory { + // Use a map to track the most recent version of a cookie from the server. + // Note that cookie names are case-sensitive (but header names aren't). 
+ private ConcurrentMap<String, HttpCookie> cookies = new ConcurrentHashMap<>(); + + @Override + public ClientCookieMiddleware onCallStarted(CallInfo info) { + return new ClientCookieMiddleware(this); + } + + private void updateCookies(Iterable<String> newCookieHeaderValues) { + // Note: Intentionally overwrite existing cookie values. + // A cookie defined once will continue to be used in all subsequent + // requests on the client instance. The server can send the same cookie again + // with a different value and the client will use the new value in future requests. + // The server can also update a cookie to have an Expiry in the past or negative age + // to signal that the client should stop using the cookie immediately. + newCookieHeaderValues.forEach(headerValue -> { + try { + final List<HttpCookie> parsedCookies = HttpCookie.parse(headerValue); + parsedCookies.forEach(parsedCookie -> { + final String cookieNameLc = parsedCookie.getName().toLowerCase(Locale.ENGLISH); + if (parsedCookie.hasExpired()) { + cookies.remove(cookieNameLc); + } else { + cookies.put(parsedCookie.getName().toLowerCase(Locale.ENGLISH), parsedCookie); + } + }); + } catch (IllegalArgumentException ex) { + LOGGER.warn("Skipping incorrectly formatted Set-Cookie header with value '{}'.", headerValue); + } + }); + } + } + + @Override + public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { + final String cookieValue = getValidCookiesAsString(); + if (!cookieValue.isEmpty()) { + outgoingHeaders.insert(COOKIE_HEADER, cookieValue); + } + } + + @Override + public void onHeadersReceived(CallHeaders incomingHeaders) { + final Iterable<String> setCookieHeaders = incomingHeaders.getAll(SET_COOKIE_HEADER); + if (setCookieHeaders != null) { + factory.updateCookies(setCookieHeaders); + } + } + + @Override + public void onCallCompleted(CallStatus status) { + + } + + /** + * Discards expired cookies and returns the valid cookies as a String delimited by ';'. 
+ */ + @VisibleForTesting + String getValidCookiesAsString() { + // Discard expired cookies. + factory.cookies.entrySet().removeIf(cookieEntry -> cookieEntry.getValue().hasExpired()); + + // Cookie header value format: + // [<cookie-name1>=<cookie-value1>; <cookie-name2>=<cookie-value2; ...] + return factory.cookies.entrySet().stream() + .map(cookie -> cookie.getValue().toString()) + .collect(Collectors.joining("; ")); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java new file mode 100644 index 000000000..528c227df --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/ExampleFlightServer.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.example; + +import java.io.IOException; + +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; + +/** + * An Example Flight Server that provides access to the InMemoryStore. Used for integration testing. + */ +public class ExampleFlightServer implements AutoCloseable { + + private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(ExampleFlightServer.class); + + private final FlightServer flightServer; + private final Location location; + private final BufferAllocator allocator; + private final InMemoryStore mem; + + /** + * Constructs a new instance using Allocator for allocating buffer storage that binds + * to the given location. + */ + public ExampleFlightServer(BufferAllocator allocator, Location location) { + this.allocator = allocator.newChildAllocator("flight-server", 0, Long.MAX_VALUE); + this.location = location; + this.mem = new InMemoryStore(this.allocator, location); + this.flightServer = FlightServer.builder(allocator, location, mem).build(); + } + + public Location getLocation() { + return location; + } + + public int getPort() { + return this.flightServer.getPort(); + } + + public void start() throws IOException { + flightServer.start(); + } + + public void awaitTermination() throws InterruptedException { + flightServer.awaitTermination(); + } + + public InMemoryStore getStore() { + return mem; + } + + @Override + public void close() throws Exception { + AutoCloseables.close(mem, flightServer, allocator); + } + + /** + * Main method starts the server listening to localhost:12233. 
+ */ + public static void main(String[] args) throws Exception { + final BufferAllocator a = new RootAllocator(Long.MAX_VALUE); + final ExampleFlightServer efs = new ExampleFlightServer(a, Location.forGrpcInsecure("localhost", 12233)); + efs.start(); + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + try { + System.out.println("\nExiting..."); + AutoCloseables.close(efs, a); + } catch (Exception e) { + e.printStackTrace(); + } + })); + efs.awaitTermination(); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/ExampleTicket.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/ExampleTicket.java new file mode 100644 index 000000000..e15ecd034 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/ExampleTicket.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.example; + +import java.io.IOException; +import java.util.List; + +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.util.Preconditions; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.annotation.JsonSerialize; +import com.google.common.base.Throwables; + +/** + * POJO object used to demonstrate how an opaque ticket can be generated. + */ +@JsonSerialize +public class ExampleTicket { + + private static final ObjectMapper MAPPER = new ObjectMapper(); + + private final List<String> path; + private final int ordinal; + + // uuid to ensure that a stream from one node is not recreated on another node and mixed up. + private final String uuid; + + /** + * Constructs a new instance. + * + * @param path Path to data + * @param ordinal A counter for the stream. + * @param uuid A unique identifier for this particular stream. + */ + @JsonCreator + public ExampleTicket(@JsonProperty("path") List<String> path, @JsonProperty("ordinal") int ordinal, + @JsonProperty("uuid") String uuid) { + super(); + Preconditions.checkArgument(ordinal >= 0); + this.path = path; + this.ordinal = ordinal; + this.uuid = uuid; + } + + public List<String> getPath() { + return path; + } + + public int getOrdinal() { + return ordinal; + } + + public String getUuid() { + return uuid; + } + + /** + * Deserializes a new instance from the protocol buffer ticket. + */ + public static ExampleTicket from(Ticket ticket) { + try { + return MAPPER.readValue(ticket.getBytes(), ExampleTicket.class); + } catch (IOException e) { + throw Throwables.propagate(e); + } + } + + /** + * Creates a new protocol buffer Ticket by serializing to JSON. 
+ */ + public Ticket toTicket() { + try { + return new Ticket(MAPPER.writeValueAsBytes(this)); + } catch (JsonProcessingException e) { + throw Throwables.propagate(e); + } + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ordinal; + result = prime * result + ((path == null) ? 0 : path.hashCode()); + result = prime * result + ((uuid == null) ? 0 : uuid.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + ExampleTicket other = (ExampleTicket) obj; + if (ordinal != other.ordinal) { + return false; + } + if (path == null) { + if (other.path != null) { + return false; + } + } else if (!path.equals(other.path)) { + return false; + } + if (uuid == null) { + if (other.uuid != null) { + return false; + } + } else if (!uuid.equals(other.uuid)) { + return false; + } + return true; + } + + +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/FlightHolder.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/FlightHolder.java new file mode 100644 index 000000000..f6295211e --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/FlightHolder.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.example; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.stream.Collectors; + +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.DictionaryUtility; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; + +/** + * A logical collection of streams sharing the same schema. + */ +public class FlightHolder implements AutoCloseable { + + private final BufferAllocator allocator; + private final FlightDescriptor descriptor; + private final Schema schema; + private final List<Stream> streams = new CopyOnWriteArrayList<>(); + private final DictionaryProvider dictionaryProvider; + + /** + * Creates a new instance. + * @param allocator The allocator to use for allocating buffers to store data. + * @param descriptor The descriptor for the streams. + * @param schema The schema for the stream. + * @param dictionaryProvider The dictionary provider for the stream. 
+ */ + public FlightHolder(BufferAllocator allocator, FlightDescriptor descriptor, Schema schema, + DictionaryProvider dictionaryProvider) { + Preconditions.checkArgument(!descriptor.isCommand()); + this.allocator = allocator.newChildAllocator(descriptor.toString(), 0, Long.MAX_VALUE); + this.descriptor = descriptor; + this.schema = schema; + this.dictionaryProvider = dictionaryProvider; + } + + /** + * Returns the stream based on the ordinal of ExampleTicket. + */ + public Stream getStream(ExampleTicket ticket) { + Preconditions.checkArgument(ticket.getOrdinal() < streams.size(), "Unknown stream."); + Stream stream = streams.get(ticket.getOrdinal()); + stream.verify(ticket); + return stream; + } + + /** + * Adds a new streams which clients can populate via the returned object. + */ + public Stream.StreamCreator addStream(Schema schema) { + Preconditions.checkArgument(this.schema.equals(schema), "Stream schema inconsistent with existing schema."); + return new Stream.StreamCreator(schema, dictionaryProvider, allocator, t -> { + synchronized (streams) { + streams.add(t); + } + }); + } + + /** + * List all available streams as being available at <code>l</code>. 
+ */ + public FlightInfo getFlightInfo(final Location l) { + final long bytes = allocator.getAllocatedMemory(); + final long records = streams.stream().collect(Collectors.summingLong(t -> t.getRecordCount())); + + final List<FlightEndpoint> endpoints = new ArrayList<>(); + int i = 0; + for (Stream s : streams) { + endpoints.add( + new FlightEndpoint( + new ExampleTicket(descriptor.getPath(), i, s.getUuid()) + .toTicket(), + l)); + i++; + } + return new FlightInfo(messageFormatSchema(), descriptor, endpoints, bytes, records); + } + + private Schema messageFormatSchema() { + Set<Long> dictionaryIdsUsed = new HashSet<>(); + List<Field> messageFormatFields = schema.getFields() + .stream() + .map(f -> DictionaryUtility.toMessageFormat(f, dictionaryProvider, dictionaryIdsUsed)) + .collect(Collectors.toList()); + return new Schema(messageFormatFields, schema.getCustomMetadata()); + } + + @Override + public void close() throws Exception { + // Close dictionaries + final Set<Long> dictionaryIds = new HashSet<>(); + schema.getFields().forEach(field -> DictionaryUtility.toMessageFormat(field, dictionaryProvider, dictionaryIds)); + + final Iterable<AutoCloseable> dictionaries = dictionaryIds.stream() + .map(id -> (AutoCloseable) dictionaryProvider.lookup(id).getVector())::iterator; + + AutoCloseables.close(Iterables.concat(streams, ImmutableList.of(allocator), dictionaries)); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java new file mode 100644 index 000000000..ff796718d --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.example; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import org.apache.arrow.flight.Action; +import org.apache.arrow.flight.ActionType; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.Criteria; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.PutResult; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.example.Stream.StreamCreator; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; + +/** + * A FlightProducer that hosts an in memory store of Arrow buffers. Used for integration testing. 
+ */ +public class InMemoryStore implements FlightProducer, AutoCloseable { + + private final ConcurrentMap<FlightDescriptor, FlightHolder> holders = new ConcurrentHashMap<>(); + private final BufferAllocator allocator; + private Location location; + + /** + * Constructs a new instance. + * + * @param allocator The allocator for creating new Arrow buffers. + * @param location The location of the storage. + */ + public InMemoryStore(BufferAllocator allocator, Location location) { + super(); + this.allocator = allocator; + this.location = location; + } + + /** + * Update the location after server start. + * + * <p>Useful for binding to port 0 to get a free port. + */ + public void setLocation(Location location) { + this.location = location; + } + + @Override + public void getStream(CallContext context, Ticket ticket, + ServerStreamListener listener) { + getStream(ticket).sendTo(allocator, listener); + } + + /** + * Returns the appropriate stream given the ticket (streams are indexed by path and an ordinal). 
+ */ + public Stream getStream(Ticket t) { + ExampleTicket example = ExampleTicket.from(t); + FlightDescriptor d = FlightDescriptor.path(example.getPath()); + FlightHolder h = holders.get(d); + if (h == null) { + throw new IllegalStateException("Unknown ticket."); + } + + return h.getStream(example); + } + + @Override + public void listFlights(CallContext context, Criteria criteria, StreamListener<FlightInfo> listener) { + try { + for (FlightHolder h : holders.values()) { + listener.onNext(h.getFlightInfo(location)); + } + listener.onCompleted(); + } catch (Exception ex) { + listener.onError(ex); + } + } + + @Override + public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { + FlightHolder h = holders.get(descriptor); + if (h == null) { + throw new IllegalStateException("Unknown descriptor."); + } + + return h.getFlightInfo(location); + } + + @Override + public Runnable acceptPut(CallContext context, + final FlightStream flightStream, final StreamListener<PutResult> ackStream) { + return () -> { + StreamCreator creator = null; + boolean success = false; + try (VectorSchemaRoot root = flightStream.getRoot()) { + final FlightHolder h = holders.computeIfAbsent( + flightStream.getDescriptor(), + t -> new FlightHolder(allocator, t, flightStream.getSchema(), flightStream.getDictionaryProvider())); + + creator = h.addStream(flightStream.getSchema()); + + VectorUnloader unloader = new VectorUnloader(root); + while (flightStream.next()) { + ackStream.onNext(PutResult.metadata(flightStream.getLatestMetadata())); + creator.add(unloader.getRecordBatch()); + } + // Closing the stream will release the dictionaries + flightStream.takeDictionaryOwnership(); + creator.complete(); + success = true; + } finally { + if (!success) { + creator.drop(); + } + } + + }; + + } + + @Override + public void doAction(CallContext context, Action action, + StreamListener<Result> listener) { + switch (action.getType()) { + case "drop": { + // not implemented. 
+ listener.onNext(new Result(new byte[0])); + listener.onCompleted(); + break; + } + default: { + listener.onError(CallStatus.UNIMPLEMENTED.toRuntimeException()); + } + } + } + + @Override + public void listActions(CallContext context, + StreamListener<ActionType> listener) { + listener.onNext(new ActionType("get", "pull a stream. Action must be done via standard get mechanism")); + listener.onNext(new ActionType("put", "push a stream. Action must be done via standard put mechanism")); + listener.onNext(new ActionType("drop", "delete a flight. Action body is a JSON encoded path.")); + listener.onCompleted(); + } + + @Override + public void close() throws Exception { + AutoCloseables.close(holders.values()); + holders.clear(); + } + +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/Stream.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/Stream.java new file mode 100644 index 000000000..0bc35798d --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/Stream.java @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.example; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.UUID; +import java.util.function.Consumer; + +import org.apache.arrow.flight.FlightProducer.ServerStreamListener; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Schema; + +import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableList; + +/** + * A collection of Arrow record batches. + */ +public class Stream implements AutoCloseable, Iterable<ArrowRecordBatch> { + + private final String uuid = UUID.randomUUID().toString(); + private final DictionaryProvider dictionaryProvider; + private final List<ArrowRecordBatch> batches; + private final Schema schema; + private final long recordCount; + + /** + * Create a new instance. + * + * @param schema The schema for the record batches. + * @param batches The data associated with the stream. + * @param recordCount The total record count across all batches. 
+ */ + public Stream( + final Schema schema, + final DictionaryProvider dictionaryProvider, + List<ArrowRecordBatch> batches, + long recordCount) { + this.schema = schema; + this.dictionaryProvider = dictionaryProvider; + this.batches = ImmutableList.copyOf(batches); + this.recordCount = recordCount; + } + + public Schema getSchema() { + return schema; + } + + @Override + public Iterator<ArrowRecordBatch> iterator() { + return batches.iterator(); + } + + public long getRecordCount() { + return recordCount; + } + + public String getUuid() { + return uuid; + } + + /** + * Sends that data from this object to the given listener. + */ + public void sendTo(BufferAllocator allocator, ServerStreamListener listener) { + try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + listener.start(root, dictionaryProvider); + final VectorLoader loader = new VectorLoader(root); + int counter = 0; + for (ArrowRecordBatch batch : batches) { + final byte[] rawMetadata = Integer.toString(counter).getBytes(StandardCharsets.UTF_8); + final ArrowBuf metadata = allocator.buffer(rawMetadata.length); + metadata.writeBytes(rawMetadata); + loader.load(batch); + // Transfers ownership of the buffer - do not free buffer ourselves + listener.putNext(metadata); + counter++; + } + listener.completed(); + } catch (Exception ex) { + listener.error(ex); + } + } + + /** + * Throws an IllegalStateException if the given ticket doesn't correspond to this stream. + */ + public void verify(ExampleTicket ticket) { + if (!uuid.equals(ticket.getUuid())) { + throw new IllegalStateException("Ticket doesn't match."); + } + } + + @Override + public void close() throws Exception { + AutoCloseables.close(batches); + } + + /** + * Provides the functionality to create a new stream by adding batches serially. 
+ */ + public static class StreamCreator { + + private final Schema schema; + private final BufferAllocator allocator; + private final List<ArrowRecordBatch> batches = new ArrayList<>(); + private final Consumer<Stream> committer; + private long recordCount = 0; + private DictionaryProvider dictionaryProvider; + + /** + * Creates a new instance. + * + * @param schema The schema for batches in the stream. + * @param dictionaryProvider The dictionary provider for the stream. + * @param allocator The allocator used to copy data permanently into the stream. + * @param committer A callback for when the stream is ready to be finalized (no more batches). + */ + public StreamCreator(Schema schema, DictionaryProvider dictionaryProvider, + BufferAllocator allocator, Consumer<Stream> committer) { + this.allocator = allocator; + this.committer = committer; + this.schema = schema; + this.dictionaryProvider = dictionaryProvider; + } + + /** + * Abandon creation of the stream. + */ + public void drop() { + try { + AutoCloseables.close(batches); + } catch (Exception ex) { + throw Throwables.propagate(ex); + } + } + + public void add(ArrowRecordBatch batch) { + batches.add(batch.cloneWithTransfer(allocator)); + recordCount += batch.getLength(); + } + + /** + * Complete building the stream (no more batches can be added). 
+ */ + public void complete() { + Stream stream = new Stream(schema, dictionaryProvider, batches, recordCount); + committer.accept(stream); + } + + } + +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java new file mode 100644 index 000000000..3955d7d21 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.example.integration; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Optional; + +import org.apache.arrow.flight.Action; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightStatusCode; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.NoOpFlightProducer; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.auth.BasicClientAuthHandler; +import org.apache.arrow.flight.auth.BasicServerAuthHandler; +import org.apache.arrow.memory.BufferAllocator; + +/** + * A scenario testing the built-in basic authentication Protobuf. + */ +final class AuthBasicProtoScenario implements Scenario { + + static final String USERNAME = "arrow"; + static final String PASSWORD = "flight"; + + @Override + public FlightProducer producer(BufferAllocator allocator, Location location) { + return new NoOpFlightProducer() { + @Override + public void doAction(CallContext context, Action action, StreamListener<Result> listener) { + listener.onNext(new Result(context.peerIdentity().getBytes(StandardCharsets.UTF_8))); + listener.onCompleted(); + } + }; + } + + @Override + public void buildServer(FlightServer.Builder builder) { + builder.authHandler(new BasicServerAuthHandler(new BasicServerAuthHandler.BasicAuthValidator() { + @Override + public byte[] getToken(String username, String password) throws Exception { + if (!USERNAME.equals(username) || !PASSWORD.equals(password)) { + throw CallStatus.UNAUTHENTICATED.withDescription("Username or password is invalid.").toRuntimeException(); + } + return ("valid:" + username).getBytes(StandardCharsets.UTF_8); + } + + @Override + public Optional<String> isValid(byte[] token) { + if (token != null) { + final String 
credential = new String(token, StandardCharsets.UTF_8); + if (credential.startsWith("valid:")) { + return Optional.of(credential.substring(6)); + } + } + return Optional.empty(); + } + })); + } + + @Override + public void client(BufferAllocator allocator, Location location, FlightClient client) { + final FlightRuntimeException e = IntegrationAssertions.assertThrows(FlightRuntimeException.class, () -> { + client.listActions().forEach(act -> { + }); + }); + if (!FlightStatusCode.UNAUTHENTICATED.equals(e.status().code())) { + throw new AssertionError("Expected UNAUTHENTICATED but found " + e.status().code(), e); + } + + client.authenticate(new BasicClientAuthHandler(USERNAME, PASSWORD)); + final Result result = client.doAction(new Action("")).next(); + if (!USERNAME.equals(new String(result.getBody(), StandardCharsets.UTF_8))) { + throw new AssertionError("Expected " + USERNAME + " but got " + Arrays.toString(result.getBody())); + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java new file mode 100644 index 000000000..576d1887f --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
/**
 * Utility methods to implement integration tests without using JUnit assertions.
 */
final class IntegrationAssertions {

  /** Static utility class; not meant to be instantiated. */
  private IntegrationAssertions() {
  }

  /**
   * Assert that the given code throws the given exception or subclass thereof.
   *
   * @param clazz The exception type.
   * @param body The code to run.
   * @param <T> The exception type.
   * @return The thrown exception.
   * @throws AssertionError if the body completes normally or throws something that is not an
   *     instance of {@code clazz} (the unexpected throwable is attached as the cause).
   */
  @SuppressWarnings("unchecked")
  static <T extends Throwable> T assertThrows(Class<T> clazz, AssertThrows body) {
    try {
      body.run();
    } catch (Throwable t) {
      if (clazz.isInstance(t)) {
        // Safe: isInstance just confirmed that t is a T.
        return (T) t;
      }
      throw new AssertionError("Expected exception of class " + clazz + " but got " + t.getClass(), t);
    }
    throw new AssertionError("Expected exception of class " + clazz + " but did not throw.");
  }

  /**
   * Assert that the two (non-array) objects are equal.
   *
   * @param expected The expected value; may be null.
   * @param actual The value to check; may be null.
   * @throws AssertionError if the values are not equal according to {@link Objects#equals(Object, Object)}.
   */
  static void assertEquals(Object expected, Object actual) {
    if (!Objects.equals(expected, actual)) {
      throw new AssertionError("Expected:\n" + expected + "\nbut got:\n" + actual);
    }
  }

  /**
   * Assert that the value is false, using the given message as an error otherwise.
   *
   * @param message The text appended to the error when the assertion fails.
   * @param value The value that must be false.
   * @throws AssertionError if {@code value} is true.
   */
  static void assertFalse(String message, boolean value) {
    if (value) {
      throw new AssertionError("Expected false: " + message);
    }
  }

  /**
   * An interface used with {@link #assertThrows(Class, AssertThrows)}.
   */
  @FunctionalInterface
  interface AssertThrows {

    void run() throws Throwable;
  }
}
+ */ + +package org.apache.arrow.flight.example.integration; + +import static org.apache.arrow.memory.util.LargeMemoryUtil.checkedCastToInt; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; + +import org.apache.arrow.flight.AsyncPutListener; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.PutResult; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.JsonFileReader; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Validator; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.CommandLineParser; +import org.apache.commons.cli.DefaultParser; +import org.apache.commons.cli.Options; +import org.apache.commons.cli.ParseException; + +/** + * A Flight client for integration testing. 
+ */ +class IntegrationTestClient { + private static final org.slf4j.Logger LOGGER = org.slf4j.LoggerFactory.getLogger(IntegrationTestClient.class); + private final Options options; + + private IntegrationTestClient() { + options = new Options(); + options.addOption("j", "json", true, "json file"); + options.addOption("scenario", true, "The integration test scenario."); + options.addOption("host", true, "The host to connect to."); + options.addOption("port", true, "The port to connect to."); + } + + public static void main(String[] args) { + try { + new IntegrationTestClient().run(args); + } catch (ParseException e) { + fatalError("Invalid parameters", e); + } catch (IOException e) { + fatalError("Error accessing files", e); + } catch (Exception e) { + fatalError("Unknown error", e); + } + } + + private static void fatalError(String message, Throwable e) { + System.err.println(message); + System.err.println(e.getMessage()); + LOGGER.error(message, e); + System.exit(1); + } + + private void run(String[] args) throws Exception { + final CommandLineParser parser = new DefaultParser(); + final CommandLine cmd = parser.parse(options, args, false); + + final String host = cmd.getOptionValue("host", "localhost"); + final int port = Integer.parseInt(cmd.getOptionValue("port", "31337")); + + final Location defaultLocation = Location.forGrpcInsecure(host, port); + try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); + final FlightClient client = FlightClient.builder(allocator, defaultLocation).build()) { + + if (cmd.hasOption("scenario")) { + Scenarios.getScenario(cmd.getOptionValue("scenario")).client(allocator, defaultLocation, client); + } else { + final String inputPath = cmd.getOptionValue("j"); + testStream(allocator, defaultLocation, client, inputPath); + } + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + private static void testStream(BufferAllocator allocator, Location server, FlightClient client, String 
inputPath) + throws IOException { + // 1. Read data from JSON and upload to server. + FlightDescriptor descriptor = FlightDescriptor.path(inputPath); + try (JsonFileReader reader = new JsonFileReader(new File(inputPath), allocator); + VectorSchemaRoot root = VectorSchemaRoot.create(reader.start(), allocator)) { + FlightClient.ClientStreamListener stream = client.startPut(descriptor, root, reader, + new AsyncPutListener() { + int counter = 0; + + @Override + public void onNext(PutResult val) { + final byte[] metadataRaw = new byte[checkedCastToInt(val.getApplicationMetadata().readableBytes())]; + val.getApplicationMetadata().readBytes(metadataRaw); + final String metadata = new String(metadataRaw, StandardCharsets.UTF_8); + if (!Integer.toString(counter).equals(metadata)) { + throw new RuntimeException( + String.format("Invalid ACK from server. Expected '%d' but got '%s'.", counter, metadata)); + } + counter++; + } + }); + int counter = 0; + while (reader.read(root)) { + final byte[] rawMetadata = Integer.toString(counter).getBytes(StandardCharsets.UTF_8); + final ArrowBuf metadata = allocator.buffer(rawMetadata.length); + metadata.writeBytes(rawMetadata); + // Transfers ownership of the buffer, so do not release it ourselves + stream.putNext(metadata); + root.clear(); + counter++; + } + stream.completed(); + // Need to call this, or exceptions from the server get swallowed + stream.getResult(); + } + + // 2. Get the ticket for the data. + FlightInfo info = client.getInfo(descriptor); + List<FlightEndpoint> endpoints = info.getEndpoints(); + if (endpoints.isEmpty()) { + throw new RuntimeException("No endpoints returned from Flight server."); + } + + for (FlightEndpoint endpoint : info.getEndpoints()) { + // 3. Download the data from the server. 
+ List<Location> locations = endpoint.getLocations(); + if (locations.isEmpty()) { + throw new RuntimeException("No locations returned from Flight server."); + } + for (Location location : locations) { + System.out.println("Verifying location " + location.getUri()); + try (FlightClient readClient = FlightClient.builder(allocator, location).build(); + FlightStream stream = readClient.getStream(endpoint.getTicket()); + VectorSchemaRoot root = stream.getRoot(); + VectorSchemaRoot downloadedRoot = VectorSchemaRoot.create(root.getSchema(), allocator); + JsonFileReader reader = new JsonFileReader(new File(inputPath), allocator)) { + VectorLoader loader = new VectorLoader(downloadedRoot); + VectorUnloader unloader = new VectorUnloader(root); + + Schema jsonSchema = reader.start(); + Validator.compareSchemas(root.getSchema(), jsonSchema); + try (VectorSchemaRoot jsonRoot = VectorSchemaRoot.create(jsonSchema, allocator)) { + + while (stream.next()) { + try (final ArrowRecordBatch arb = unloader.getRecordBatch()) { + loader.load(arb); + if (reader.read(jsonRoot)) { + + // 4. Validate the data. 
+ Validator.compareVectorSchemaRoot(jsonRoot, downloadedRoot); + jsonRoot.clear(); + } else { + throw new RuntimeException("Flight stream has more batches than JSON"); + } + } + } + + // Verify no more batches with data in JSON + // NOTE: Currently the C++ Flight server skips empty batches at end of the stream + if (reader.read(jsonRoot) && jsonRoot.getRowCount() > 0) { + throw new RuntimeException("JSON has more batches with than Flight stream"); + } + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java new file mode 100644 index 000000000..da336c502 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.example.integration; + +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.example.InMemoryStore; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.CommandLineParser; +import org.apache.commons.cli.DefaultParser; +import org.apache.commons.cli.Options; +import org.apache.commons.cli.ParseException; + +/** + * Flight server for integration testing. + */ +class IntegrationTestServer { + private static final org.slf4j.Logger LOGGER = org.slf4j.LoggerFactory.getLogger(IntegrationTestServer.class); + private final Options options; + + private IntegrationTestServer() { + options = new Options(); + options.addOption("port", true, "The port to serve on."); + options.addOption("scenario", true, "The integration test scenario."); + } + + private void run(String[] args) throws Exception { + CommandLineParser parser = new DefaultParser(); + CommandLine cmd = parser.parse(options, args, false); + final int port = Integer.parseInt(cmd.getOptionValue("port", "31337")); + final Location location = Location.forGrpcInsecure("localhost", port); + + final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final FlightServer.Builder builder = FlightServer.builder().allocator(allocator).location(location); + + final FlightServer server; + if (cmd.hasOption("scenario")) { + final Scenario scenario = Scenarios.getScenario(cmd.getOptionValue("scenario")); + scenario.buildServer(builder); + server = builder.producer(scenario.producer(allocator, location)).build(); + server.start(); + } else { + final InMemoryStore store = new InMemoryStore(allocator, location); + server = FlightServer.builder(allocator, location, store).build().start(); + store.setLocation(Location.forGrpcInsecure("localhost", 
server.getPort())); + } + // Print out message for integration test script + System.out.println("Server listening on localhost:" + server.getPort()); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + try { + System.out.println("\nExiting..."); + AutoCloseables.close(server, allocator); + } catch (Exception e) { + e.printStackTrace(); + } + })); + + server.awaitTermination(); + } + + public static void main(String[] args) { + try { + new IntegrationTestServer().run(args); + } catch (ParseException e) { + fatalError("Error parsing arguments", e); + } catch (Exception e) { + fatalError("Runtime error", e); + } + } + + private static void fatalError(String message, Throwable e) { + System.err.println(message); + System.err.println(e.getMessage()); + LOGGER.error(message, e); + System.exit(1); + } + +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/MiddlewareScenario.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/MiddlewareScenario.java new file mode 100644 index 000000000..c710ce98b --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/MiddlewareScenario.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.example.integration; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collections; + +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallInfo; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightClientMiddleware; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightServerMiddleware; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.NoOpFlightProducer; +import org.apache.arrow.flight.RequestContext; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.types.pojo.Schema; + +/** + * Test an edge case in middleware: gRPC-Java consolidates headers and trailers if a call fails immediately. On the + * gRPC implementation side, we need to watch for this, or else we'll have a call with "no headers" if we only look + * for headers. 
+ */ +final class MiddlewareScenario implements Scenario { + + private static final String HEADER = "x-middleware"; + private static final String EXPECTED_HEADER_VALUE = "expected value"; + private static final byte[] COMMAND_SUCCESS = "success".getBytes(StandardCharsets.UTF_8); + + @Override + public FlightProducer producer(BufferAllocator allocator, Location location) { + return new NoOpFlightProducer() { + @Override + public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { + if (descriptor.isCommand()) { + if (Arrays.equals(COMMAND_SUCCESS, descriptor.getCommand())) { + return new FlightInfo(new Schema(Collections.emptyList()), descriptor, Collections.emptyList(), -1, -1); + } + } + throw CallStatus.UNIMPLEMENTED.toRuntimeException(); + } + }; + } + + @Override + public void buildServer(FlightServer.Builder builder) { + builder.middleware(FlightServerMiddleware.Key.of("test"), new InjectingServerMiddleware.Factory()); + } + + @Override + public void client(BufferAllocator allocator, Location location, FlightClient ignored) throws Exception { + final ExtractingClientMiddleware.Factory factory = new ExtractingClientMiddleware.Factory(); + try (final FlightClient client = FlightClient.builder(allocator, location).intercept(factory).build()) { + // Should fail immediately + IntegrationAssertions.assertThrows(FlightRuntimeException.class, + () -> client.getInfo(FlightDescriptor.command(new byte[0]))); + if (!EXPECTED_HEADER_VALUE.equals(factory.extractedHeader)) { + throw new AssertionError( + "Expected to extract the header value '" + + EXPECTED_HEADER_VALUE + + "', but found: " + + factory.extractedHeader); + } + + // Should not fail + factory.extractedHeader = ""; + client.getInfo(FlightDescriptor.command(COMMAND_SUCCESS)); + if (!EXPECTED_HEADER_VALUE.equals(factory.extractedHeader)) { + throw new AssertionError( + "Expected to extract the header value '" + + EXPECTED_HEADER_VALUE + + "', but found: " + + factory.extractedHeader); + } 
+ } + } + + /** Middleware that inserts a constant value in outgoing requests. */ + static class InjectingServerMiddleware implements FlightServerMiddleware { + + private final String headerValue; + + InjectingServerMiddleware(String incoming) { + this.headerValue = incoming; + } + + @Override + public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { + outgoingHeaders.insert("x-middleware", headerValue); + } + + @Override + public void onCallCompleted(CallStatus status) { + } + + @Override + public void onCallErrored(Throwable err) { + } + + /** The factory for the server middleware. */ + static class Factory implements FlightServerMiddleware.Factory<InjectingServerMiddleware> { + + @Override + public InjectingServerMiddleware onCallStarted(CallInfo info, CallHeaders incomingHeaders, + RequestContext context) { + String incoming = incomingHeaders.get(HEADER); + return new InjectingServerMiddleware(incoming == null ? "" : incoming); + } + } + } + + /** Middleware that pulls a value out of incoming responses. */ + static class ExtractingClientMiddleware implements FlightClientMiddleware { + + private final ExtractingClientMiddleware.Factory factory; + + public ExtractingClientMiddleware(ExtractingClientMiddleware.Factory factory) { + this.factory = factory; + } + + @Override + public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { + outgoingHeaders.insert(HEADER, EXPECTED_HEADER_VALUE); + } + + @Override + public void onHeadersReceived(CallHeaders incomingHeaders) { + this.factory.extractedHeader = incomingHeaders.get(HEADER); + } + + @Override + public void onCallCompleted(CallStatus status) { + } + + /** The factory for the client middleware. 
*/ + static class Factory implements FlightClientMiddleware.Factory { + + String extractedHeader = null; + + @Override + public FlightClientMiddleware onCallStarted(CallInfo info) { + return new ExtractingClientMiddleware(this); + } + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java new file mode 100644 index 000000000..b3b962d2e --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.example.integration; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.BufferAllocator; + +/** + * A particular scenario in integration testing. + */ +interface Scenario { + + /** + * Construct the FlightProducer for a server in this scenario. 
+ */ + FlightProducer producer(BufferAllocator allocator, Location location) throws Exception; + + /** + * Set any other server options. + */ + void buildServer(FlightServer.Builder builder) throws Exception; + + /** + * Run as the client in the scenario. + */ + void client(BufferAllocator allocator, Location location, FlightClient client) throws Exception; +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java new file mode 100644 index 000000000..cd9859b4f --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.example.integration; + +import java.util.Map; +import java.util.TreeMap; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; + +/** + * Scenarios for integration testing. + */ +final class Scenarios { + + private static Scenarios INSTANCE; + + private final Map<String, Supplier<Scenario>> scenarios; + + private Scenarios() { + scenarios = new TreeMap<>(); + scenarios.put("auth:basic_proto", AuthBasicProtoScenario::new); + scenarios.put("middleware", MiddlewareScenario::new); + } + + private static Scenarios getInstance() { + if (INSTANCE == null) { + INSTANCE = new Scenarios(); + } + return INSTANCE; + } + + static Scenario getScenario(String scenario) { + final Supplier<Scenario> ctor = getInstance().scenarios.get(scenario); + if (ctor == null) { + throw new IllegalArgumentException("Unknown integration test scenario: " + scenario); + } + return ctor.get(); + } + + // Utility methods for implementing tests. 
+ + public static void main(String[] args) { + // Run scenarios one after the other + final Location location = Location.forGrpcInsecure("localhost", 31337); + for (final Map.Entry<String, Supplier<Scenario>> entry : getInstance().scenarios.entrySet()) { + System.out.println("Running test scenario: " + entry.getKey()); + final Scenario scenario = entry.getValue().get(); + try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) { + final FlightServer.Builder builder = FlightServer + .builder(allocator, location, scenario.producer(allocator, location)); + scenario.buildServer(builder); + try (final FlightServer server = builder.build()) { + server.start(); + + try (final FlightClient client = FlightClient.builder(allocator, location).build()) { + scenario.client(allocator, location, client); + } + + server.shutdown(); + server.awaitTermination(1, TimeUnit.SECONDS); + System.out.println("Ran scenario " + entry.getKey()); + } + } catch (Exception e) { + System.out.println("Exception while running scenario " + entry.getKey()); + e.printStackTrace(); + } + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/AddWritableBuffer.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/AddWritableBuffer.java new file mode 100644 index 000000000..26e0274fa --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/AddWritableBuffer.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.grpc; + +import java.io.IOException; +import java.io.OutputStream; +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.List; + +import io.netty.buffer.ByteBuf; + +/** + * Allow a user to add a ByteBuf based InputStream directly into GRPC WritableBuffer to avoid an + * extra copy. This could be solved in GRPC by adding a ByteBufListable interface on InputStream and + * letting BufferChainOutputStream take advantage of it. 
+ */ +public class AddWritableBuffer { + + private static final Constructor<?> bufConstruct; + private static final Field bufferList; + private static final Field current; + private static final Method listAdd; + private static final Class<?> bufChainOut; + + static { + + Constructor<?> tmpConstruct = null; + Field tmpBufferList = null; + Field tmpCurrent = null; + Class<?> tmpBufChainOut = null; + Method tmpListAdd = null; + + try { + Class<?> nwb = Class.forName("io.grpc.netty.NettyWritableBuffer"); + + Constructor<?> tmpConstruct2 = nwb.getDeclaredConstructor(ByteBuf.class); + tmpConstruct2.setAccessible(true); + + Class<?> tmpBufChainOut2 = Class.forName("io.grpc.internal.MessageFramer$BufferChainOutputStream"); + + Field tmpBufferList2 = tmpBufChainOut2.getDeclaredField("bufferList"); + tmpBufferList2.setAccessible(true); + + Field tmpCurrent2 = tmpBufChainOut2.getDeclaredField("current"); + tmpCurrent2.setAccessible(true); + + Method tmpListAdd2 = List.class.getDeclaredMethod("add", Object.class); + + // output fields last. + tmpConstruct = tmpConstruct2; + tmpBufferList = tmpBufferList2; + tmpCurrent = tmpCurrent2; + tmpListAdd = tmpListAdd2; + tmpBufChainOut = tmpBufChainOut2; + + } catch (Exception ex) { + ex.printStackTrace(); + } + + bufConstruct = tmpConstruct; + bufferList = tmpBufferList; + current = tmpCurrent; + listAdd = tmpListAdd; + bufChainOut = tmpBufChainOut; + + } + + /** + * Add the provided ByteBuf to the gRPC BufferChainOutputStream if possible, else copy the buffer to the stream. + * @param buf The buffer to add. + * @param stream The Candidate OutputStream to add to. + * @param tryZeroCopy If true, try to zero-copy append the buffer to the stream. This may not succeed. + * @return True if buffer was zero-copy added to the stream. False if the buffer was copied. + * @throws IOException if the fast path is not enabled and there was an error copying the buffer to the stream. 
+ */ + public static boolean add(ByteBuf buf, OutputStream stream, boolean tryZeroCopy) throws IOException { + if (!tryZeroCopy || !tryAddBuffer(buf, stream)) { + buf.getBytes(0, stream, buf.readableBytes()); + return false; + } + return true; + } + + private static boolean tryAddBuffer(ByteBuf buf, OutputStream stream) throws IOException { + + if (bufChainOut == null) { + return false; + } + + if (!stream.getClass().equals(bufChainOut)) { + return false; + } + + try { + if (current.get(stream) != null) { + return false; + } + + buf.retain(); + Object obj = bufConstruct.newInstance(buf); + Object list = bufferList.get(stream); + listAdd.invoke(list, obj); + current.set(stream, obj); + return true; + } catch (IllegalAccessException | IllegalArgumentException | InvocationTargetException | InstantiationException e) { + e.printStackTrace(); + return false; + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/CallCredentialAdapter.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/CallCredentialAdapter.java new file mode 100644 index 000000000..285ddb9ba --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/CallCredentialAdapter.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.grpc; + +import java.util.concurrent.Executor; +import java.util.function.Consumer; + +import org.apache.arrow.flight.CallHeaders; + +import io.grpc.CallCredentials; +import io.grpc.Metadata; + +/** + * Adapter class to utilize a CredentialWriter to implement Grpc CallCredentials. + */ +public class CallCredentialAdapter extends CallCredentials { + + private final Consumer<CallHeaders> credentialWriter; + + public CallCredentialAdapter(Consumer<CallHeaders> credentialWriter) { + this.credentialWriter = credentialWriter; + } + + @Override + public void applyRequestMetadata(RequestInfo requestInfo, Executor executor, MetadataApplier metadataApplier) { + executor.execute(() -> + { + final Metadata headers = new Metadata(); + credentialWriter.accept(new MetadataAdapter(headers)); + metadataApplier.apply(headers); + }); + } + + @Override + public void thisUsesUnstableApi() { + // Mandatory to override this to acknowledge that CallCredentials is Experimental. + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ClientInterceptorAdapter.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ClientInterceptorAdapter.java new file mode 100644 index 000000000..ae11e5260 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ClientInterceptorAdapter.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.grpc; + +import java.util.ArrayList; +import java.util.List; + +import org.apache.arrow.flight.CallInfo; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightClientMiddleware; +import org.apache.arrow.flight.FlightClientMiddleware.Factory; +import org.apache.arrow.flight.FlightMethod; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.FlightStatusCode; + +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; +import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; + +/** + * An adapter between Flight client middleware and gRPC interceptors. + * + * <p>This is implemented as a single gRPC interceptor that runs all Flight client middleware sequentially. 
+ */ +public class ClientInterceptorAdapter implements ClientInterceptor { + + private final List<Factory> factories; + + public ClientInterceptorAdapter(List<Factory> factories) { + this.factories = factories; + } + + @Override + public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(MethodDescriptor<ReqT, RespT> method, + CallOptions callOptions, Channel next) { + final List<FlightClientMiddleware> middleware = new ArrayList<>(); + final CallInfo info = new CallInfo(FlightMethod.fromProtocol(method.getFullMethodName())); + + try { + for (final Factory factory : factories) { + middleware.add(factory.onCallStarted(info)); + } + } catch (FlightRuntimeException e) { + // Explicitly propagate + throw e; + } catch (StatusRuntimeException e) { + throw StatusUtils.fromGrpcRuntimeException(e); + } catch (RuntimeException e) { + throw StatusUtils.fromThrowable(e); + } + return new FlightClientCall<>(next.newCall(method, callOptions), middleware); + } + + /** + * The ClientCallListener which hooks into the gRPC request cycle and actually runs middleware at certain points. 
+ */ + private static class FlightClientCallListener<RespT> extends SimpleForwardingClientCallListener<RespT> { + + private final List<FlightClientMiddleware> middleware; + boolean receivedHeaders; + + public FlightClientCallListener(ClientCall.Listener<RespT> responseListener, + List<FlightClientMiddleware> middleware) { + super(responseListener); + this.middleware = middleware; + receivedHeaders = false; + } + + @Override + public void onHeaders(Metadata headers) { + receivedHeaders = true; + final MetadataAdapter adapter = new MetadataAdapter(headers); + try { + middleware.forEach(m -> m.onHeadersReceived(adapter)); + } finally { + // Make sure to always call the gRPC callback to avoid interrupting the gRPC request cycle + super.onHeaders(headers); + } + } + + @Override + public void onClose(Status status, Metadata trailers) { + try { + if (!receivedHeaders) { + // gRPC doesn't always send response headers if the call errors or completes immediately, but instead + // consolidates them with the trailers. If we never got headers, assume this happened and run the header + // callback with the trailers. + final MetadataAdapter adapter = new MetadataAdapter(trailers); + middleware.forEach(m -> m.onHeadersReceived(adapter)); + } + final CallStatus flightStatus = StatusUtils.fromGrpcStatusAndTrailers(status, trailers); + middleware.forEach(m -> m.onCallCompleted(flightStatus)); + } finally { + // Make sure to always call the gRPC callback to avoid interrupting the gRPC request cycle + super.onClose(status, trailers); + } + } + } + + /** + * The gRPC ClientCall which hooks into the gRPC request cycle and injects our ClientCallListener. 
+ */ + private static class FlightClientCall<ReqT, RespT> extends SimpleForwardingClientCall<ReqT, RespT> { + + private final List<FlightClientMiddleware> middleware; + + public FlightClientCall(ClientCall<ReqT, RespT> clientCall, List<FlightClientMiddleware> middleware) { + super(clientCall); + this.middleware = middleware; + } + + @Override + public void start(Listener<RespT> responseListener, Metadata headers) { + final MetadataAdapter metadataAdapter = new MetadataAdapter(headers); + middleware.forEach(m -> m.onBeforeSendingHeaders(metadataAdapter)); + + super.start(new FlightClientCallListener<>(responseListener, middleware), headers); + } + + @Override + public void cancel(String message, Throwable cause) { + final CallStatus flightStatus = new CallStatus(FlightStatusCode.CANCELLED, cause, message, null); + middleware.forEach(m -> m.onCallCompleted(flightStatus)); + super.cancel(message, cause); + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ContextPropagatingExecutorService.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ContextPropagatingExecutorService.java new file mode 100644 index 000000000..8f6bb6db2 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ContextPropagatingExecutorService.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.grpc; + +import java.util.Collection; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.stream.Collectors; + +import io.grpc.Context; + +/** + * An {@link ExecutorService} that propagates the {@link Context}. + * + * <p>Context is used to propagate per-call state, like the authenticated user, between threads (as gRPC makes no + * guarantees about what thread things execute on). This wrapper makes it easy to preserve this when using an Executor. + * The Context itself is immutable, so it is thread-safe. + */ +public class ContextPropagatingExecutorService implements ExecutorService { + + private final ExecutorService delegate; + + public ContextPropagatingExecutorService(ExecutorService delegate) { + this.delegate = delegate; + } + + // These are just delegate methods. 
+ + @Override + public void shutdown() { + delegate.shutdown(); + } + + @Override + public List<Runnable> shutdownNow() { + return delegate.shutdownNow(); + } + + @Override + public boolean isShutdown() { + return delegate.isShutdown(); + } + + @Override + public boolean isTerminated() { + return delegate.isTerminated(); + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return delegate.awaitTermination(timeout, unit); + } + + // These are delegate methods that wrap the submitted task in the current gRPC Context. + + @Override + public <T> Future<T> submit(Callable<T> task) { + return delegate.submit(Context.current().wrap(task)); + } + + @Override + public <T> Future<T> submit(Runnable task, T result) { + return delegate.submit(Context.current().wrap(task), result); + } + + @Override + public Future<?> submit(Runnable task) { + return delegate.submit(Context.current().wrap(task)); + } + + @Override + public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks) throws InterruptedException { + return delegate.invokeAll(tasks.stream().map(Context.current()::wrap).collect(Collectors.toList())); + } + + @Override + public <T> List<Future<T>> invokeAll(Collection<? extends Callable<T>> tasks, long timeout, + TimeUnit unit) throws InterruptedException { + return delegate.invokeAll(tasks.stream().map(Context.current()::wrap).collect(Collectors.toList()), timeout, unit); + } + + @Override + public <T> T invokeAny(Collection<? extends Callable<T>> tasks) throws InterruptedException, ExecutionException { + return delegate.invokeAny(tasks.stream().map(Context.current()::wrap).collect(Collectors.toList())); + } + + @Override + public <T> T invokeAny(Collection<? 
extends Callable<T>> tasks, long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException { + return delegate.invokeAny(tasks.stream().map(Context.current()::wrap).collect(Collectors.toList()), timeout, unit); + } + + @Override + public void execute(Runnable command) { + delegate.execute(Context.current().wrap(command)); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/CredentialCallOption.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/CredentialCallOption.java new file mode 100644 index 000000000..3bde7a835 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/CredentialCallOption.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.grpc; + +import java.util.function.Consumer; + +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallOptions; + +import io.grpc.stub.AbstractStub; + +/** + * Method option for supplying credentials to method calls. 
+ */ +public class CredentialCallOption implements CallOptions.GrpcCallOption { + private final Consumer<CallHeaders> credentialWriter; + + public CredentialCallOption(Consumer<CallHeaders> credentialWriter) { + this.credentialWriter = credentialWriter; + } + + @Override + public <T extends AbstractStub<T>> T wrapStub(T stub) { + return stub.withCallCredentials(new CallCredentialAdapter(credentialWriter)); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java new file mode 100644 index 000000000..5f8a71576 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/GetReadableBuffer.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.grpc; + +import java.io.IOException; +import java.io.InputStream; +import java.lang.reflect.Field; + +import org.apache.arrow.memory.ArrowBuf; + +import com.google.common.base.Throwables; +import com.google.common.io.ByteStreams; + +import io.grpc.internal.ReadableBuffer; + +/** + * Enable access to ReadableBuffer directly to copy data from a BufferInputStream into a target + * ByteBuffer/ByteBuf. + * + * <p>This could be solved by BufferInputStream exposing Drainable. + */ +public class GetReadableBuffer { + + private static final Field READABLE_BUFFER; + private static final Class<?> BUFFER_INPUT_STREAM; + + static { + Field tmpField = null; + Class<?> tmpClazz = null; + try { + Class<?> clazz = Class.forName("io.grpc.internal.ReadableBuffers$BufferInputStream"); + + Field f = clazz.getDeclaredField("buffer"); + f.setAccessible(true); + // don't set until we've gotten past all exception cases. + tmpField = f; + tmpClazz = clazz; + } catch (Exception e) { + e.printStackTrace(); + } + READABLE_BUFFER = tmpField; + BUFFER_INPUT_STREAM = tmpClazz; + } + + /** + * Extracts the ReadableBuffer for the given input stream. + * + * @param is Must be an instance of io.grpc.internal.ReadableBuffers$BufferInputStream or + * null will be returned. + */ + public static ReadableBuffer getReadableBuffer(InputStream is) { + + if (BUFFER_INPUT_STREAM == null || !is.getClass().equals(BUFFER_INPUT_STREAM)) { + return null; + } + + try { + return (ReadableBuffer) READABLE_BUFFER.get(is); + } catch (Exception ex) { + throw Throwables.propagate(ex); + } + } + + /** + * Helper method to read a gRPC-provided InputStream into an ArrowBuf. + * @param stream The stream to read from. Should be an instance of {@link #BUFFER_INPUT_STREAM}. + * @param buf The buffer to read into. + * @param size The number of bytes to read. + * @param fastPath Whether to enable the fast path (i.e. detect whether the stream is a {@link #BUFFER_INPUT_STREAM}). 
+ * @throws IOException if there is an error reading form the stream + */ + public static void readIntoBuffer(final InputStream stream, final ArrowBuf buf, final int size, + final boolean fastPath) throws IOException { + ReadableBuffer readableBuffer = fastPath ? getReadableBuffer(stream) : null; + if (readableBuffer != null) { + readableBuffer.readBytes(buf.nioBuffer(0, size)); + } else { + byte[] heapBytes = new byte[size]; + ByteStreams.readFully(stream, heapBytes); + buf.writeBytes(heapBytes); + } + buf.writerIndex(size); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/MetadataAdapter.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/MetadataAdapter.java new file mode 100644 index 000000000..4327f0ca8 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/MetadataAdapter.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.grpc; + +import java.util.HashSet; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +import org.apache.arrow.flight.CallHeaders; + +import io.grpc.Metadata; +import io.grpc.Metadata.Key; + +/** + * A mutable adapter between the gRPC Metadata object and the Flight headers interface. + * + * <p>This allows us to present the headers (metadata) from gRPC without copying to/from our own object. + */ +public class MetadataAdapter implements CallHeaders { + + private final Metadata metadata; + + public MetadataAdapter(Metadata metadata) { + this.metadata = metadata; + } + + @Override + public String get(String key) { + return this.metadata.get(Key.of(key, Metadata.ASCII_STRING_MARSHALLER)); + } + + @Override + public byte[] getByte(String key) { + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + return this.metadata.get(Key.of(key, Metadata.BINARY_BYTE_MARSHALLER)); + } + return get(key).getBytes(); + } + + @Override + public Iterable<String> getAll(String key) { + return this.metadata.getAll(Key.of(key, Metadata.ASCII_STRING_MARSHALLER)); + } + + @Override + public Iterable<byte[]> getAllByte(String key) { + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + return this.metadata.getAll(Key.of(key, Metadata.BINARY_BYTE_MARSHALLER)); + } + return StreamSupport.stream(getAll(key).spliterator(), false) + .map(String::getBytes).collect(Collectors.toList()); + } + + @Override + public void insert(String key, String value) { + this.metadata.put(Key.of(key, Metadata.ASCII_STRING_MARSHALLER), value); + } + + @Override + public void insert(String key, byte[] value) { + this.metadata.put(Key.of(key, Metadata.BINARY_BYTE_MARSHALLER), value); + } + + @Override + public Set<String> keys() { + return new HashSet<>(this.metadata.keys()); + } + + @Override + public boolean containsKey(String key) { + if (key.endsWith("-bin")) { + final Key<?> grpcKey = Key.of(key, 
Metadata.BINARY_BYTE_MARSHALLER); + return this.metadata.containsKey(grpcKey); + } + final Key<?> grpcKey = Key.of(key, Metadata.ASCII_STRING_MARSHALLER); + return this.metadata.containsKey(grpcKey); + } + + public String toString() { + return this.metadata.toString(); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/RequestContextAdapter.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/RequestContextAdapter.java new file mode 100644 index 000000000..9be4d12b9 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/RequestContextAdapter.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.grpc; + +import java.util.HashMap; +import java.util.Set; + +import org.apache.arrow.flight.RequestContext; + +import io.grpc.Context; + + +/** + * Adapter for holding key value pairs. 
+ */ +public class RequestContextAdapter implements RequestContext { + public static final Context.Key<RequestContext> REQUEST_CONTEXT_KEY = + Context.key("arrow-flight-request-context"); + private final HashMap<String, String> map = new HashMap<>(); + + @Override + public void put(String key, String value) { + if (map.putIfAbsent(key, value) != null) { + throw new IllegalArgumentException("Duplicate write to a RequestContext at key " + key + " not allowed."); + } + } + + @Override + public String get(String key) { + return map.get(key); + } + + @Override + public Set<String> keySet() { + return map.keySet(); + } + + @Override + public String remove(String key) { + return map.remove(key); + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ServerInterceptorAdapter.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ServerInterceptorAdapter.java new file mode 100644 index 000000000..ddf43ff84 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/ServerInterceptorAdapter.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.grpc; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import org.apache.arrow.flight.CallInfo; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightMethod; +import org.apache.arrow.flight.FlightProducer.CallContext; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.FlightServerMiddleware; +import org.apache.arrow.flight.FlightServerMiddleware.Factory; +import org.apache.arrow.flight.FlightServerMiddleware.Key; + +import io.grpc.Context; +import io.grpc.Contexts; +import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCall.Listener; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; + +/** + * An adapter between Flight middleware and a gRPC interceptor. + * + * <p>This is implemented as a single gRPC interceptor that runs all Flight server middleware sequentially. Flight + * middleware instances are stored in the gRPC Context so their state is accessible later. + */ +public class ServerInterceptorAdapter implements ServerInterceptor { + + /** + * A combination of a middleware Key and factory. + * + * @param <T> The middleware type. + */ + public static class KeyFactory<T extends FlightServerMiddleware> { + + private final FlightServerMiddleware.Key<T> key; + private final FlightServerMiddleware.Factory<T> factory; + + public KeyFactory(Key<T> key, Factory<T> factory) { + this.key = key; + this.factory = factory; + } + } + + /** + * The {@link Context.Key} that stores the Flight middleware active for a particular call. + * + * <p>Applications should not use this directly. Instead, see {@link CallContext#getMiddleware(Key)}. 
+ */ + public static final Context.Key<Map<FlightServerMiddleware.Key<?>, FlightServerMiddleware>> SERVER_MIDDLEWARE_KEY = + Context.key("arrow.flight.server_middleware"); + private final List<KeyFactory<?>> factories; + + public ServerInterceptorAdapter(List<KeyFactory<?>> factories) { + this.factories = factories; + } + + @Override + public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call, Metadata headers, + ServerCallHandler<ReqT, RespT> next) { + final CallInfo info = new CallInfo(FlightMethod.fromProtocol(call.getMethodDescriptor().getFullMethodName())); + final List<FlightServerMiddleware> middleware = new ArrayList<>(); + // Use LinkedHashMap to preserve insertion order + final Map<FlightServerMiddleware.Key<?>, FlightServerMiddleware> middlewareMap = new LinkedHashMap<>(); + final MetadataAdapter headerAdapter = new MetadataAdapter(headers); + final RequestContextAdapter requestContextAdapter = new RequestContextAdapter(); + for (final KeyFactory<?> factory : factories) { + final FlightServerMiddleware m; + try { + m = factory.factory.onCallStarted(info, headerAdapter, requestContextAdapter); + } catch (FlightRuntimeException e) { + // Cancel call + call.close(StatusUtils.toGrpcStatus(e.status()), new Metadata()); + return new Listener<ReqT>() {}; + } + middleware.add(m); + middlewareMap.put(factory.key, m); + } + + // Inject the middleware into the context so RPC method implementations can communicate with middleware instances + final Context contextWithMiddlewareAndRequestsOptions = Context.current() + .withValue(SERVER_MIDDLEWARE_KEY, Collections.unmodifiableMap(middlewareMap)) + .withValue(RequestContextAdapter.REQUEST_CONTEXT_KEY, requestContextAdapter); + + final SimpleForwardingServerCall<ReqT, RespT> forwardingServerCall = new SimpleForwardingServerCall<ReqT, RespT>( + call) { + boolean sentHeaders = false; + + @Override + public void sendHeaders(Metadata headers) { + sentHeaders = true; + try { + final MetadataAdapter 
headerAdapter = new MetadataAdapter(headers); + middleware.forEach(m -> m.onBeforeSendingHeaders(headerAdapter)); + } finally { + // Make sure to always call the gRPC callback to avoid interrupting the gRPC request cycle + super.sendHeaders(headers); + } + } + + @Override + public void close(Status status, Metadata trailers) { + try { + if (!sentHeaders) { + // gRPC doesn't always send response headers if the call errors or completes immediately + final MetadataAdapter headerAdapter = new MetadataAdapter(trailers); + middleware.forEach(m -> m.onBeforeSendingHeaders(headerAdapter)); + } + } finally { + // Make sure to always call the gRPC callback to avoid interrupting the gRPC request cycle + super.close(status, trailers); + } + + final CallStatus flightStatus = StatusUtils.fromGrpcStatus(status); + middleware.forEach(m -> m.onCallCompleted(flightStatus)); + } + }; + return Contexts.interceptCall(contextWithMiddlewareAndRequestsOptions, forwardingServerCall, headers, next); + + } +} diff --git a/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/StatusUtils.java b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/StatusUtils.java new file mode 100644 index 000000000..55e841864 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/main/java/org/apache/arrow/flight/grpc/StatusUtils.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.grpc; + +import java.util.Iterator; +import java.util.Objects; +import java.util.function.Function; + +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.ErrorFlightMetadata; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.FlightStatusCode; + +import io.grpc.InternalMetadata; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.Status.Code; +import io.grpc.StatusException; +import io.grpc.StatusRuntimeException; + +/** + * Utilities to adapt gRPC and Flight status objects. + * + * <p>NOT A PUBLIC CLASS, interface is not guaranteed to remain stable. + */ +public class StatusUtils { + + private StatusUtils() { + throw new AssertionError("Do not instantiate this class."); + } + + /** + * Convert from a Flight status code to a gRPC status code. 
+ */ + public static Status.Code toGrpcStatusCode(FlightStatusCode code) { + switch (code) { + case OK: + return Code.OK; + case UNKNOWN: + return Code.UNKNOWN; + case INTERNAL: + return Code.INTERNAL; + case INVALID_ARGUMENT: + return Code.INVALID_ARGUMENT; + case TIMED_OUT: + return Code.DEADLINE_EXCEEDED; + case NOT_FOUND: + return Code.NOT_FOUND; + case ALREADY_EXISTS: + return Code.ALREADY_EXISTS; + case CANCELLED: + return Code.CANCELLED; + case UNAUTHENTICATED: + return Code.UNAUTHENTICATED; + case UNAUTHORIZED: + return Code.PERMISSION_DENIED; + case UNIMPLEMENTED: + return Code.UNIMPLEMENTED; + case UNAVAILABLE: + return Code.UNAVAILABLE; + default: + return Code.UNKNOWN; + } + } + + /** + * Convert from a gRPC status code to a Flight status code. + */ + public static FlightStatusCode fromGrpcStatusCode(Status.Code code) { + switch (code) { + case OK: + return FlightStatusCode.OK; + case CANCELLED: + return FlightStatusCode.CANCELLED; + case UNKNOWN: + return FlightStatusCode.UNKNOWN; + case INVALID_ARGUMENT: + return FlightStatusCode.INVALID_ARGUMENT; + case DEADLINE_EXCEEDED: + return FlightStatusCode.TIMED_OUT; + case NOT_FOUND: + return FlightStatusCode.NOT_FOUND; + case ALREADY_EXISTS: + return FlightStatusCode.ALREADY_EXISTS; + case PERMISSION_DENIED: + return FlightStatusCode.UNAUTHORIZED; + case RESOURCE_EXHAUSTED: + return FlightStatusCode.INVALID_ARGUMENT; + case FAILED_PRECONDITION: + return FlightStatusCode.INVALID_ARGUMENT; + case ABORTED: + return FlightStatusCode.INTERNAL; + case OUT_OF_RANGE: + return FlightStatusCode.INVALID_ARGUMENT; + case UNIMPLEMENTED: + return FlightStatusCode.UNIMPLEMENTED; + case INTERNAL: + return FlightStatusCode.INTERNAL; + case UNAVAILABLE: + return FlightStatusCode.UNAVAILABLE; + case DATA_LOSS: + return FlightStatusCode.INTERNAL; + case UNAUTHENTICATED: + return FlightStatusCode.UNAUTHENTICATED; + default: + return FlightStatusCode.UNKNOWN; + } + } + + /** Create Metadata Key for binary metadata. 
*/ + static Metadata.Key<byte[]> keyOfBinary(String name) { + return Metadata.Key.of(name, Metadata.BINARY_BYTE_MARSHALLER); + } + + /** Create Metadata Key for ascii metadata. */ + static Metadata.Key<String> keyOfAscii(String name) { + // Use InternalMetadata for keys that start with ":", e.g. ":status". See ARROW-14014. + return InternalMetadata.keyOf(name, Metadata.ASCII_STRING_MARSHALLER); + } + + /** Convert from a gRPC Status & trailers to a Flight status. */ + public static CallStatus fromGrpcStatusAndTrailers(Status status, Metadata trailers) { + // gRPC may not always have trailers - this happens when the server internally generates an error, which is rare, + // but can happen. + final ErrorFlightMetadata errorMetadata = trailers == null ? null : parseTrailers(trailers); + return new CallStatus( + fromGrpcStatusCode(status.getCode()), + status.getCause(), + status.getDescription(), + errorMetadata); + } + + /** Convert from a gRPC status to a Flight status. */ + public static CallStatus fromGrpcStatus(Status status) { + return new CallStatus( + fromGrpcStatusCode(status.getCode()), + status.getCause(), + status.getDescription(), + null); + } + + /** Convert from a Flight status to a gRPC status. */ + public static Status toGrpcStatus(CallStatus status) { + return toGrpcStatusCode(status.code()).toStatus().withDescription(status.description()).withCause(status.cause()); + } + + /** Convert from a gRPC exception to a Flight exception. */ + public static FlightRuntimeException fromGrpcRuntimeException(StatusRuntimeException sre) { + return fromGrpcStatusAndTrailers(sre.getStatus(), sre.getTrailers()).toRuntimeException(); + } + + /** Convert gRPC trailers into Flight error metadata. 
*/ + private static ErrorFlightMetadata parseTrailers(Metadata trailers) { + ErrorFlightMetadata metadata = new ErrorFlightMetadata(); + for (String key : trailers.keys()) { + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + metadata.insert(key, trailers.get(keyOfBinary(key))); + } else { + metadata.insert(key, Objects.requireNonNull(trailers.get(keyOfAscii(key))).getBytes()); + } + } + return metadata; + } + + /** + * Convert arbitrary exceptions to a {@link FlightRuntimeException}. + */ + public static FlightRuntimeException fromThrowable(Throwable t) { + if (t instanceof StatusRuntimeException) { + return fromGrpcRuntimeException((StatusRuntimeException) t); + } else if (t instanceof FlightRuntimeException) { + return (FlightRuntimeException) t; + } + return CallStatus.UNKNOWN.withCause(t).withDescription(t.getMessage()).toRuntimeException(); + } + + /** + * Convert arbitrary exceptions to a {@link StatusRuntimeException} or {@link StatusException}. + * + * <p>Such exceptions can be passed to {@link io.grpc.stub.StreamObserver#onError(Throwable)} and will give the client + * a reasonable error message. 
+ */ + public static Throwable toGrpcException(Throwable ex) { + if (ex instanceof StatusRuntimeException) { + return ex; + } else if (ex instanceof StatusException) { + return ex; + } else if (ex instanceof FlightRuntimeException) { + final FlightRuntimeException fre = (FlightRuntimeException) ex; + if (fre.status().metadata() != null) { + Metadata trailers = toGrpcMetadata(fre.status().metadata()); + return new StatusRuntimeException(toGrpcStatus(fre.status()), trailers); + } + return toGrpcStatus(fre.status()).asRuntimeException(); + } + return Status.INTERNAL.withCause(ex).withDescription("There was an error servicing your request.") + .asRuntimeException(); + } + + private static Metadata toGrpcMetadata(ErrorFlightMetadata metadata) { + final Metadata trailers = new Metadata(); + for (final String key : metadata.keys()) { + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + trailers.put(keyOfBinary(key), metadata.getByte(key)); + } else { + trailers.put(keyOfAscii(key), metadata.get(key)); + } + } + return trailers; + } + + /** + * Maps a transformation function to the elements of an iterator, while wrapping exceptions in {@link + * FlightRuntimeException}. + */ + public static <FROM, TO> Iterator<TO> wrapIterator(Iterator<FROM> fromIterator, + Function<? super FROM, ? 
extends TO> transformer) { + Objects.requireNonNull(fromIterator); + Objects.requireNonNull(transformer); + return new Iterator<TO>() { + @Override + public boolean hasNext() { + try { + return fromIterator.hasNext(); + } catch (StatusRuntimeException e) { + throw fromGrpcRuntimeException(e); + } + } + + @Override + public TO next() { + try { + return transformer.apply(fromIterator.next()); + } catch (StatusRuntimeException e) { + throw fromGrpcRuntimeException(e); + } + } + }; + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/FlightTestUtil.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/FlightTestUtil.java new file mode 100644 index 000000000..cd043b639 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/FlightTestUtil.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight; + +import java.io.File; +import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.Random; +import java.util.function.Function; + +import org.junit.Assert; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.function.Executable; + +/** + * Utility methods and constants for testing flight servers. + */ +public class FlightTestUtil { + + private static final Random RANDOM = new Random(); + + public static final String LOCALHOST = "localhost"; + public static final String TEST_DATA_ENV_VAR = "ARROW_TEST_DATA"; + public static final String TEST_DATA_PROPERTY = "arrow.test.dataRoot"; + + /** + * Returns a a FlightServer (actually anything that is startable) + * that has been started bound to a random port. + */ + public static <T> T getStartedServer(Function<Location, T> newServerFromLocation) throws IOException { + IOException lastThrown = null; + T server = null; + for (int x = 0; x < 3; x++) { + final int port = 49152 + RANDOM.nextInt(5000); + final Location location = Location.forGrpcInsecure(LOCALHOST, port); + lastThrown = null; + try { + server = newServerFromLocation.apply(location); + try { + server.getClass().getMethod("start").invoke(server); + } catch (NoSuchMethodException | IllegalAccessException e) { + throw new IllegalArgumentException("Couldn't call start method on object.", e); + } + break; + } catch (InvocationTargetException e) { + if (e.getTargetException() instanceof IOException) { + lastThrown = (IOException) e.getTargetException(); + } else { + throw (RuntimeException) e.getTargetException(); + } + } + } + if (lastThrown != null) { + throw lastThrown; + } + return server; + } + + static Path getTestDataRoot() { + String path = System.getenv(TEST_DATA_ENV_VAR); + if (path == null) { + path = 
System.getProperty(TEST_DATA_PROPERTY); + } + return Paths.get(Objects.requireNonNull(path, + String.format("Could not find test data path. Set the environment variable %s or the JVM property %s.", + TEST_DATA_ENV_VAR, TEST_DATA_PROPERTY))); + } + + static Path getFlightTestDataRoot() { + return getTestDataRoot().resolve("flight"); + } + + static Path exampleTlsRootCert() { + return getFlightTestDataRoot().resolve("root-ca.pem"); + } + + static List<CertKeyPair> exampleTlsCerts() { + final Path root = getFlightTestDataRoot(); + return Arrays.asList(new CertKeyPair(root.resolve("cert0.pem").toFile(), root.resolve("cert0.pkcs1").toFile()), + new CertKeyPair(root.resolve("cert1.pem").toFile(), root.resolve("cert1.pkcs1").toFile())); + } + + static boolean isEpollAvailable() { + try { + Class<?> epoll = Class.forName("io.netty.channel.epoll.Epoll"); + return (Boolean) epoll.getMethod("isAvailable").invoke(null); + } catch (ClassNotFoundException | NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + return false; + } + } + + static boolean isKqueueAvailable() { + try { + Class<?> kqueue = Class.forName("io.netty.channel.kqueue.KQueue"); + return (Boolean) kqueue.getMethod("isAvailable").invoke(null); + } catch (ClassNotFoundException | NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + return false; + } + } + + static boolean isNativeTransportAvailable() { + return isEpollAvailable() || isKqueueAvailable(); + } + + /** + * Assert that the given runnable fails with a Flight exception of the given code. + * @param code The expected Flight status code. + * @param r The code to run. + * @return The thrown status. 
+ */ + public static CallStatus assertCode(FlightStatusCode code, Executable r) { + final FlightRuntimeException ex = Assertions.assertThrows(FlightRuntimeException.class, r); + Assert.assertEquals(code, ex.status().code()); + return ex.status(); + } + + public static class CertKeyPair { + + public final File cert; + public final File key; + + public CertKeyPair(File cert, File key) { + this.cert = cert; + this.key = key; + } + } + + private FlightTestUtil() { + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java new file mode 100644 index 000000000..c7b3321af --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java @@ -0,0 +1,329 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight; + +import java.util.Arrays; +import java.util.Collections; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; + +import org.apache.arrow.flight.FlightClient.PutListener; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Assert; +import org.junit.Ignore; +import org.junit.Test; + +/** + * Tests for application-specific metadata support in Flight. + */ +public class TestApplicationMetadata { + + // The command used to trigger the test for ARROW-6136. + private static final byte[] COMMAND_ARROW_6136 = "ARROW-6136".getBytes(); + // The expected error message. + private static final String MESSAGE_ARROW_6136 = "The stream should not be double-closed."; + + /** + * Ensure that a client can read the metadata sent from the server. + */ + @Test + // This test is consistently flaky on CI, unfortunately. + @Ignore + public void retrieveMetadata() { + test((allocator, client) -> { + try (final FlightStream stream = client.getStream(new Ticket(new byte[0]))) { + byte i = 0; + while (stream.next()) { + final IntVector vector = (IntVector) stream.getRoot().getVector("a"); + Assert.assertEquals(1, vector.getValueCount()); + Assert.assertEquals(10, vector.get(0)); + Assert.assertEquals(i, stream.getLatestMetadata().getByte(0)); + i++; + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + + /** ARROW-6136: make sure that the Flight implementation doesn't double-close the server-to-client stream. 
*/ + @Test + public void arrow6136() { + final Schema schema = new Schema(Collections.emptyList()); + test((allocator, client) -> { + try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + final FlightDescriptor descriptor = FlightDescriptor.command(COMMAND_ARROW_6136); + + final PutListener listener = new SyncPutListener(); + final FlightClient.ClientStreamListener writer = client.startPut(descriptor, root, listener); + // Must attempt to retrieve the result to get any server-side errors. + final CallStatus status = FlightTestUtil.assertCode(FlightStatusCode.INTERNAL, writer::getResult); + Assert.assertEquals(MESSAGE_ARROW_6136, status.description()); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + + /** + * Ensure that a client can send metadata to the server. + */ + @Test + @Ignore + public void uploadMetadataAsync() { + final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true)))); + test((allocator, client) -> { + try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + final FlightDescriptor descriptor = FlightDescriptor.path("test"); + + final PutListener listener = new AsyncPutListener() { + int counter = 0; + + @Override + public void onNext(PutResult val) { + Assert.assertNotNull(val); + Assert.assertEquals(counter, val.getApplicationMetadata().getByte(0)); + counter++; + } + }; + final FlightClient.ClientStreamListener writer = client.startPut(descriptor, root, listener); + + root.allocateNew(); + for (byte i = 0; i < 10; i++) { + final IntVector vector = (IntVector) root.getVector("a"); + final ArrowBuf metadata = allocator.buffer(1); + metadata.writeByte(i); + vector.set(0, 10); + vector.setValueCount(1); + root.setRowCount(1); + writer.putNext(metadata); + } + writer.completed(); + // Must attempt to retrieve the result to get any server-side errors. 
+ writer.getResult(); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + + /** + * Ensure that a client can send metadata to the server. Uses the synchronous API. + */ + @Test + @Ignore + public void uploadMetadataSync() { + final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true)))); + test((allocator, client) -> { + try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + final SyncPutListener listener = new SyncPutListener()) { + final FlightDescriptor descriptor = FlightDescriptor.path("test"); + final FlightClient.ClientStreamListener writer = client.startPut(descriptor, root, listener); + + root.allocateNew(); + for (byte i = 0; i < 10; i++) { + final IntVector vector = (IntVector) root.getVector("a"); + final ArrowBuf metadata = allocator.buffer(1); + metadata.writeByte(i); + vector.set(0, 10); + vector.setValueCount(1); + root.setRowCount(1); + writer.putNext(metadata); + try (final PutResult message = listener.poll(5000, TimeUnit.SECONDS)) { + Assert.assertNotNull(message); + Assert.assertEquals(i, message.getApplicationMetadata().getByte(0)); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + } + writer.completed(); + // Must attempt to retrieve the result to get any server-side errors. + writer.getResult(); + } + }); + } + + /** + * Make sure that a {@link SyncPutListener} properly reclaims memory if ignored. 
+ */ + @Test + @Ignore + public void syncMemoryReclaimed() { + final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true)))); + test((allocator, client) -> { + try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + final SyncPutListener listener = new SyncPutListener()) { + final FlightDescriptor descriptor = FlightDescriptor.path("test"); + final FlightClient.ClientStreamListener writer = client.startPut(descriptor, root, listener); + + root.allocateNew(); + for (byte i = 0; i < 10; i++) { + final IntVector vector = (IntVector) root.getVector("a"); + final ArrowBuf metadata = allocator.buffer(1); + metadata.writeByte(i); + vector.set(0, 10); + vector.setValueCount(1); + root.setRowCount(1); + writer.putNext(metadata); + } + writer.completed(); + // Must attempt to retrieve the result to get any server-side errors. + writer.getResult(); + } + }); + } + + /** + * ARROW-9221: Flight copies metadata from the byte buffer of a Protobuf ByteString, + * which is in big-endian by default, thus mangling metadata. 
+ */ + @Test + public void testMetadataEndianness() throws Exception { + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final BufferAllocator serverAllocator = allocator.newChildAllocator("flight-server", 0, Long.MAX_VALUE); + final FlightServer server = FlightTestUtil.getStartedServer( + (location) -> FlightServer + .builder(serverAllocator, location, new EndianFlightProducer(serverAllocator)) + .build()); + final FlightClient client = FlightClient.builder(allocator, server.getLocation()).build()) { + final Schema schema = new Schema(Collections.emptyList()); + final FlightDescriptor descriptor = FlightDescriptor.command(new byte[0]); + try (final SyncPutListener reader = new SyncPutListener(); + final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + final FlightClient.ClientStreamListener writer = client.startPut(descriptor, root, reader); + writer.completed(); + try (final PutResult metadata = reader.read()) { + Assert.assertEquals(16, metadata.getApplicationMetadata().readableBytes()); + byte[] bytes = new byte[16]; + metadata.getApplicationMetadata().readBytes(bytes); + Assert.assertArrayEquals(EndianFlightProducer.EXPECTED_BYTES, bytes); + } + writer.getResult(); + } + } + } + + private void test(BiConsumer<BufferAllocator, FlightClient> fun) { + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final FlightServer s = + FlightTestUtil.getStartedServer( + (location) -> FlightServer.builder(allocator, location, new MetadataFlightProducer(allocator)).build()); + final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) { + fun.accept(allocator, client); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + /** + * A FlightProducer that always produces a fixed data stream with metadata on the side. 
+ */ + private static class MetadataFlightProducer extends NoOpFlightProducer { + + private final BufferAllocator allocator; + + public MetadataFlightProducer(BufferAllocator allocator) { + this.allocator = allocator; + } + + @Override + public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { + final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true)))); + try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + root.allocateNew(); + listener.start(root); + for (byte i = 0; i < 10; i++) { + final IntVector vector = (IntVector) root.getVector("a"); + vector.set(0, 10); + vector.setValueCount(1); + root.setRowCount(1); + final ArrowBuf metadata = allocator.buffer(1); + metadata.writeByte(i); + listener.putNext(metadata); + } + listener.completed(); + } + } + + @Override + public Runnable acceptPut(CallContext context, FlightStream stream, StreamListener<PutResult> ackStream) { + return () -> { + // Wait for the descriptor to be sent + stream.getRoot(); + if (stream.getDescriptor().isCommand() && + Arrays.equals(stream.getDescriptor().getCommand(), COMMAND_ARROW_6136)) { + // ARROW-6136: Try closing the stream + ackStream.onError( + CallStatus.INTERNAL.withDescription(MESSAGE_ARROW_6136).toRuntimeException()); + return; + } + try { + byte current = 0; + while (stream.next()) { + final ArrowBuf metadata = stream.getLatestMetadata(); + if (current != metadata.getByte(0)) { + ackStream.onError(CallStatus.INVALID_ARGUMENT.withDescription(String + .format("Metadata does not match expected value; got %d but expected %d.", metadata.getByte(0), + current)).toRuntimeException()); + return; + } + ackStream.onNext(PutResult.metadata(metadata)); + current++; + } + if (current != 10) { + throw CallStatus.INVALID_ARGUMENT.withDescription("Wrong number of messages sent.").toRuntimeException(); + } + } catch (Exception e) { + throw 
CallStatus.INTERNAL.withCause(e).withDescription(e.toString()).toRuntimeException(); + } + }; + } + } + + private static class EndianFlightProducer extends NoOpFlightProducer { + static final byte[] EXPECTED_BYTES = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + private final BufferAllocator allocator; + + private EndianFlightProducer(BufferAllocator allocator) { + this.allocator = allocator; + } + + @Override + public Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener<PutResult> ackStream) { + return () -> { + while (flightStream.next()) { + // Ignore any data + } + + try (final ArrowBuf buf = allocator.buffer(16)) { + buf.writeBytes(EXPECTED_BYTES); + ackStream.onNext(PutResult.metadata(buf)); + } + ackStream.onCompleted(); + }; + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestAuth.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestAuth.java new file mode 100644 index 000000000..6f0ec9f02 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestAuth.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight; + +import java.util.Iterator; +import java.util.Optional; + +import org.apache.arrow.flight.auth.ClientAuthHandler; +import org.apache.arrow.flight.auth.ServerAuthHandler; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.junit.Test; + +public class TestAuth { + + /** An auth handler that does not send messages should not block the server forever. */ + @Test(expected = RuntimeException.class) + public void noMessages() throws Exception { + try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); + final FlightServer s = FlightTestUtil + .getStartedServer( + location -> FlightServer.builder(allocator, location, new NoOpFlightProducer()).authHandler( + new OneshotAuthHandler()).build()); + final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) { + client.authenticate(new ClientAuthHandler() { + @Override + public void authenticate(ClientAuthSender outgoing, Iterator<byte[]> incoming) { + } + + @Override + public byte[] getCallToken() { + return new byte[0]; + } + }); + } + } + + /** An auth handler that sends an error should not block the server forever. 
*/ + @Test(expected = RuntimeException.class) + public void clientError() throws Exception { + try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); + final FlightServer s = FlightTestUtil + .getStartedServer( + location -> FlightServer.builder(allocator, location, new NoOpFlightProducer()).authHandler( + new OneshotAuthHandler()).build()); + final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) { + client.authenticate(new ClientAuthHandler() { + @Override + public void authenticate(ClientAuthSender outgoing, Iterator<byte[]> incoming) { + outgoing.send(new byte[0]); + // Ensure the server-side runs + incoming.next(); + outgoing.onError(new RuntimeException("test")); + } + + @Override + public byte[] getCallToken() { + return new byte[0]; + } + }); + } + } + + private static class OneshotAuthHandler implements ServerAuthHandler { + + @Override + public Optional<String> isValid(byte[] token) { + return Optional.of("test"); + } + + @Override + public boolean authenticate(ServerAuthSender outgoing, Iterator<byte[]> incoming) { + incoming.next(); + outgoing.send(new byte[0]); + return false; + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBackPressure.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBackPressure.java new file mode 100644 index 000000000..1a71c363e --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBackPressure.java @@ -0,0 +1,262 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Function; + +import org.apache.arrow.flight.perf.PerformanceTestServer; +import org.apache.arrow.flight.perf.TestPerf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Assert; +import org.junit.Ignore; +import org.junit.Test; + +import com.google.common.collect.ImmutableList; + +public class TestBackPressure { + + private static final int BATCH_SIZE = 4095; + + /** + * Make sure that failing to consume one stream doesn't block other streams. + */ + @Ignore + @Test + public void ensureIndependentSteams() throws Exception { + ensureIndependentSteams((b) -> (location -> new PerformanceTestServer(b, location))); + } + + /** + * Make sure that failing to consume one stream doesn't block other streams. + */ + @Ignore + @Test + public void ensureIndependentSteamsWithCallbacks() throws Exception { + ensureIndependentSteams((b) -> (location -> new PerformanceTestServer(b, location, + new BackpressureStrategy.CallbackBackpressureStrategy(), true))); + } + + /** + * Test to make sure stream doesn't go faster than the consumer is consuming. 
+   */
+  @Ignore
+  @Test
+  public void ensureWaitUntilProceed() throws Exception {
+    ensureWaitUntilProceed(new PollingBackpressureStrategy(), false);
+  }
+
+  /**
+   * Test to make sure stream doesn't go faster than the consumer is consuming using a callback-based
+   * backpressure strategy.
+   */
+  @Ignore
+  @Test
+  public void ensureWaitUntilProceedWithCallbacks() throws Exception {
+    ensureWaitUntilProceed(new RecordingCallbackBackpressureStrategy(), true);
+  }
+
+  /**
+   * Make sure that failing to consume one stream doesn't block other streams.
+   *
+   * @param serverConstructor given an allocator, yields a factory that builds the server for a location.
+   */
+  private static void ensureIndependentSteams(Function<BufferAllocator, Function<Location, PerformanceTestServer>>
+      serverConstructor) throws Exception {
+    try (
+        final BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
+        final PerformanceTestServer server = FlightTestUtil.getStartedServer(
+            (location) -> (serverConstructor.apply(a).apply(location)));
+        final FlightClient client = FlightClient.builder(a, server.getLocation()).build()
+    ) {
+      try (FlightStream fs1 = client.getStream(client.getInfo(
+          TestPerf.getPerfFlightDescriptor(110L * BATCH_SIZE, BATCH_SIZE, 1))
+          .getEndpoints().get(0).getTicket())) {
+        consume(fs1, 10);
+
+        // stop consuming fs1 but make sure we can consume a large amount of fs2.
+        try (FlightStream fs2 = client.getStream(client.getInfo(
+            TestPerf.getPerfFlightDescriptor(200L * BATCH_SIZE, BATCH_SIZE, 1))
+            .getEndpoints().get(0).getTicket())) {
+          consume(fs2, 100);
+
+          // Alternate between the two streams, then drain whatever is left of each.
+          consume(fs1, 100);
+          consume(fs2, 100);
+
+          consume(fs1);
+          consume(fs2);
+        }
+      }
+    }
+  }
+
+  /**
+   * Make sure that a stream doesn't go faster than the consumer is consuming.
+   *
+   * <p>The client sleeps for {@code wait} ms before reading anything; the test then asserts that the
+   * strategy recorded more than {@code wait - epsilon} ms of waiting on the producer side.
+   *
+   * @param bpStrategy the backpressure strategy under test; it also reports the time spent waiting.
+   * @param isNonBlocking if true, the producer loop is submitted to a single-thread executor instead of
+   *     being run inside {@code getStream} itself.
+   */
+  private static void ensureWaitUntilProceed(SleepTimeRecordingBackpressureStrategy bpStrategy, boolean isNonBlocking)
+      throws Exception {
+    // request some values.
+    final long wait = 3000;
+    final long epsilon = 1000;
+
+    try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {
+
+      final FlightProducer producer = new NoOpFlightProducer() {
+
+        @Override
+        public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
+          bpStrategy.register(listener);
+          final Runnable loadData = () -> {
+            int batches = 0;
+            final Schema pojoSchema = new Schema(ImmutableList.of(Field.nullable("a", MinorType.BIGINT.getType())));
+            try (VectorSchemaRoot root = VectorSchemaRoot.create(pojoSchema, allocator)) {
+              listener.start(root);
+              while (true) {
+                bpStrategy.waitForListener(0);
+                if (batches > 100) {
+                  root.clear();
+                  listener.completed();
+                  return;
+                }
+
+                root.allocateNew();
+                // Same batch size as the rest of this test class (previously a duplicated literal 4095).
+                root.setRowCount(BATCH_SIZE);
+                listener.putNext();
+                batches++;
+              }
+            }
+          };
+
+          if (!isNonBlocking) {
+            loadData.run();
+          } else {
+            final ExecutorService service = Executors.newSingleThreadExecutor();
+            service.submit(loadData);
+            service.shutdown();
+          }
+        }
+      };
+
+
+      try (
+          BufferAllocator serverAllocator = allocator.newChildAllocator("server", 0, Long.MAX_VALUE);
+          FlightServer server =
+              FlightTestUtil.getStartedServer((location) -> FlightServer.builder(serverAllocator, location, producer)
+                  .build());
+          BufferAllocator clientAllocator = allocator.newChildAllocator("client", 0, Long.MAX_VALUE);
+          FlightClient client =
+              FlightClient
+                  .builder(clientAllocator, server.getLocation())
+                  .build();
+          FlightStream stream = client.getStream(new Ticket(new byte[1]))
+      ) {
+        VectorSchemaRoot root = stream.getRoot();
+        root.clear();
+        Thread.sleep(wait);
+        while (stream.next()) {
+          root.clear();
+        }
+        long expected = wait - epsilon;
+        Assert.assertTrue(
+            String.format("Expected a sleep of at least %dms but only slept for %d", expected,
+                bpStrategy.getSleepTime()), bpStrategy.getSleepTime() > expected);
+
+      }
+    }
+  }
+
+  /** Reads the stream to its end, clearing the root after every batch. */
+  private static void consume(FlightStream stream) {
+    VectorSchemaRoot root = stream.getRoot();
+    while (stream.next()) {
+      root.clear();
+    }
+  }
+
+  /** Reads at most {@code batches} batches from the stream, clearing the root after each one. */
+  private static void consume(FlightStream stream, int batches) {
+    VectorSchemaRoot root = stream.getRoot();
+    while (batches > 0 && stream.next()) {
+      root.clear();
+      batches--;
+    }
+  }
+
+  private interface SleepTimeRecordingBackpressureStrategy extends BackpressureStrategy {
+    /**
+     * Returns the total time spent waiting on the listener to be ready.
+     * @return the total time spent waiting on the listener to be ready.
+     */
+    long getSleepTime();
+  }
+
+  /**
+   * Implementation of a backpressure strategy that polls on isReady and records amount of time spent in Thread.sleep().
+   */
+  private static class PollingBackpressureStrategy implements SleepTimeRecordingBackpressureStrategy {
+    private final AtomicLong sleepTime = new AtomicLong(0);
+    private FlightProducer.ServerStreamListener listener;
+
+    @Override
+    public long getSleepTime() {
+      return sleepTime.get();
+    }
+
+    @Override
+    public void register(FlightProducer.ServerStreamListener listener) {
+      this.listener = listener;
+    }
+
+    /**
+     * Polls the registered listener in 1 ms sleeps until it reports ready.
+     *
+     * <p>The {@code timeout} argument is not consulted: this implementation always waits until ready.
+     * An interrupt does not end the wait, but it is remembered and the thread's interrupt status is
+     * restored before returning instead of being silently discarded.
+     *
+     * @param timeout unused by this implementation.
+     * @return always {@code WaitResult.READY}.
+     */
+    @Override
+    public WaitResult waitForListener(long timeout) {
+      boolean interrupted = false;
+      try {
+        while (!listener.isReady()) {
+          try {
+            Thread.sleep(1);
+            // Only completed sleeps are counted.
+            sleepTime.addAndGet(1L);
+          } catch (InterruptedException e) {
+            // Keep polling as before, but do not lose the interrupt.
+            interrupted = true;
+          }
+        }
+      } finally {
+        if (interrupted) {
+          Thread.currentThread().interrupt();
+        }
+      }
+      return WaitResult.READY;
+    }
+  }
+
+  /**
+   * Implementation of a backpressure strategy that uses callbacks to detect changes in client readiness state
+   * and records spent time waiting.
+ */ + private static class RecordingCallbackBackpressureStrategy extends BackpressureStrategy.CallbackBackpressureStrategy + implements SleepTimeRecordingBackpressureStrategy { + private final AtomicLong sleepTime = new AtomicLong(0); + + @Override + public long getSleepTime() { + return sleepTime.get(); + } + + @Override + public WaitResult waitForListener(long timeout) { + final long startTime = System.currentTimeMillis(); + final WaitResult result = super.waitForListener(timeout); + sleepTime.addAndGet(System.currentTimeMillis() - startTime); + return result; + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java new file mode 100644 index 000000000..e29cd07ce --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java @@ -0,0 +1,567 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.Consumer; + +import org.apache.arrow.flight.FlightClient.ClientStreamListener; +import org.apache.arrow.flight.impl.Flight; +import org.apache.arrow.flight.impl.Flight.FlightDescriptor.DescriptorType; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.ipc.message.IpcOption; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Assert; +import org.junit.Test; + +import com.google.common.base.Charsets; +import com.google.protobuf.ByteString; + +import io.grpc.MethodDescriptor; + +/** + * Test the operations of a basic flight service. + */ +public class TestBasicOperation { + + @Test + public void fastPathDefaults() { + Assert.assertTrue(ArrowMessage.ENABLE_ZERO_COPY_READ); + Assert.assertFalse(ArrowMessage.ENABLE_ZERO_COPY_WRITE); + } + + /** + * ARROW-6017: we should be able to construct locations for unknown schemes. 
+ */ + @Test + public void unknownScheme() throws URISyntaxException { + final Location location = new Location("s3://unknown"); + Assert.assertEquals("s3", location.getUri().getScheme()); + } + + @Test + public void unknownSchemeRemote() throws Exception { + test(c -> { + try { + final FlightInfo info = c.getInfo(FlightDescriptor.path("test")); + Assert.assertEquals(new URI("https://example.com"), info.getEndpoints().get(0).getLocations().get(0).getUri()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + }); + } + + @Test + public void roundTripTicket() throws Exception { + final Ticket ticket = new Ticket(new byte[]{0, 1, 2, 3, 4, 5}); + Assert.assertEquals(ticket, Ticket.deserialize(ticket.serialize())); + } + + @Test + public void roundTripInfo() throws Exception { + final Map<String, String> metadata = new HashMap<>(); + metadata.put("foo", "bar"); + final Schema schema = new Schema(Arrays.asList( + Field.nullable("a", new ArrowType.Int(32, true)), + Field.nullable("b", new ArrowType.FixedSizeBinary(32)) + ), metadata); + final FlightInfo info1 = new FlightInfo(schema, FlightDescriptor.path(), Collections.emptyList(), -1, -1); + final FlightInfo info2 = new FlightInfo(schema, FlightDescriptor.command(new byte[2]), + Collections.singletonList(new FlightEndpoint( + new Ticket(new byte[10]), Location.forGrpcDomainSocket("/tmp/test.sock"))), 200, 500); + final FlightInfo info3 = new FlightInfo(schema, FlightDescriptor.path("a", "b"), + Arrays.asList(new FlightEndpoint( + new Ticket(new byte[10]), Location.forGrpcDomainSocket("/tmp/test.sock")), + new FlightEndpoint( + new Ticket(new byte[10]), Location.forGrpcDomainSocket("/tmp/test.sock"), + Location.forGrpcInsecure("localhost", 50051)) + ), 200, 500); + + Assert.assertEquals(info1, FlightInfo.deserialize(info1.serialize())); + Assert.assertEquals(info2, FlightInfo.deserialize(info2.serialize())); + Assert.assertEquals(info3, FlightInfo.deserialize(info3.serialize())); + } + + @Test + 
public void roundTripDescriptor() throws Exception { + final FlightDescriptor cmd = FlightDescriptor.command("test command".getBytes(StandardCharsets.UTF_8)); + Assert.assertEquals(cmd, FlightDescriptor.deserialize(cmd.serialize())); + final FlightDescriptor path = FlightDescriptor.path("foo", "bar", "test.arrow"); + Assert.assertEquals(path, FlightDescriptor.deserialize(path.serialize())); + } + + @Test + public void getDescriptors() throws Exception { + test(c -> { + int count = 0; + for (FlightInfo i : c.listFlights(Criteria.ALL)) { + count += 1; + } + Assert.assertEquals(1, count); + }); + } + + @Test + public void getDescriptorsWithCriteria() throws Exception { + test(c -> { + int count = 0; + for (FlightInfo i : c.listFlights(new Criteria(new byte[]{1}))) { + count += 1; + } + Assert.assertEquals(0, count); + }); + } + + @Test + public void getDescriptor() throws Exception { + test(c -> { + System.out.println(c.getInfo(FlightDescriptor.path("hello")).getDescriptor()); + }); + } + + @Test + public void getSchema() throws Exception { + test(c -> { + System.out.println(c.getSchema(FlightDescriptor.path("hello")).getSchema()); + }); + } + + + @Test + public void listActions() throws Exception { + test(c -> { + for (ActionType at : c.listActions()) { + System.out.println(at.getType()); + } + }); + } + + @Test + public void doAction() throws Exception { + test(c -> { + Iterator<Result> stream = c.doAction(new Action("hello")); + + Assert.assertTrue(stream.hasNext()); + Result r = stream.next(); + Assert.assertArrayEquals("world".getBytes(Charsets.UTF_8), r.getBody()); + }); + test(c -> { + Iterator<Result> stream = c.doAction(new Action("hellooo")); + + Assert.assertTrue(stream.hasNext()); + Result r = stream.next(); + Assert.assertArrayEquals("world".getBytes(Charsets.UTF_8), r.getBody()); + + Assert.assertTrue(stream.hasNext()); + r = stream.next(); + Assert.assertArrayEquals("!".getBytes(Charsets.UTF_8), r.getBody()); + Assert.assertFalse(stream.hasNext()); + 
}); + } + + @Test + public void putStream() throws Exception { + test((c, a) -> { + final int size = 10; + + IntVector iv = new IntVector("c1", a); + + try (VectorSchemaRoot root = VectorSchemaRoot.of(iv)) { + ClientStreamListener listener = c + .startPut(FlightDescriptor.path("hello"), root, new AsyncPutListener()); + + //batch 1 + root.allocateNew(); + for (int i = 0; i < size; i++) { + iv.set(i, i); + } + iv.setValueCount(size); + root.setRowCount(size); + listener.putNext(); + + // batch 2 + + root.allocateNew(); + for (int i = 0; i < size; i++) { + iv.set(i, i + size); + } + iv.setValueCount(size); + root.setRowCount(size); + listener.putNext(); + root.clear(); + listener.completed(); + + // wait for ack to avoid memory leaks. + listener.getResult(); + } + }); + } + + @Test + public void propagateErrors() throws Exception { + test(client -> { + FlightTestUtil.assertCode(FlightStatusCode.UNIMPLEMENTED, () -> { + client.doAction(new Action("invalid-action")).forEachRemaining(action -> Assert.fail()); + }); + }); + } + + @Test + public void getStream() throws Exception { + test(c -> { + try (final FlightStream stream = c.getStream(new Ticket(new byte[0]))) { + VectorSchemaRoot root = stream.getRoot(); + IntVector iv = (IntVector) root.getVector("c1"); + int value = 0; + while (stream.next()) { + for (int i = 0; i < root.getRowCount(); i++) { + Assert.assertEquals(value, iv.get(i)); + value++; + } + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + + /** Ensure the client is configured to accept large messages. 
*/
+  @Test
+  public void getStreamLargeBatch() throws Exception {
+    test(c -> {
+      try (final FlightStream stream = c.getStream(new Ticket(Producer.TICKET_LARGE_BATCH))) {
+        Assert.assertEquals(128, stream.getRoot().getFieldVectors().size());
+        Assert.assertTrue(stream.next());
+        Assert.assertEquals(65536, stream.getRoot().getRowCount());
+        Assert.assertTrue(stream.next());
+        Assert.assertEquals(65536, stream.getRoot().getRowCount());
+        Assert.assertFalse(stream.next());
+      } catch (Exception e) {
+        throw new RuntimeException(e);
+      }
+    });
+  }
+
+  /** Ensure the server is configured to accept large messages. */
+  @Test
+  public void startPutLargeBatch() throws Exception {
+    try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
+      final List<FieldVector> vectors = new ArrayList<>();
+      for (int col = 0; col < 128; col++) {
+        final BigIntVector vector = new BigIntVector("f" + col, allocator);
+        for (int row = 0; row < 65536; row++) {
+          vector.setSafe(row, row);
+        }
+        vectors.add(vector);
+      }
+      test(c -> {
+        try (final VectorSchemaRoot root = new VectorSchemaRoot(vectors)) {
+          root.setRowCount(65536);
+          final ClientStreamListener stream = c.startPut(FlightDescriptor.path(""), root, new SyncPutListener());
+          // The same batch is sent twice.
+          stream.putNext();
+          stream.putNext();
+          stream.completed();
+          stream.getResult();
+        } catch (Exception e) {
+          throw new RuntimeException(e);
+        }
+      });
+    }
+  }
+
+  /** Runs {@code consumer} against a fresh server/client pair; the per-test allocator is not exposed. */
+  private void test(Consumer<FlightClient> consumer) throws Exception {
+    test((c, a) -> {
+      consumer.accept(c);
+    });
+  }
+
+  /**
+   * Starts a server backed by {@link Producer}, connects a client to it, and runs {@code consumer} with
+   * that client and a child allocator named "testcase".
+   *
+   * <p>try-with-resources closes in reverse declaration order: the child allocator, the client, the
+   * server, the producer, and finally the root allocator.
+   */
+  private void test(BiConsumer<FlightClient, BufferAllocator> consumer) throws Exception {
+    try (
+        BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
+        Producer producer = new Producer(a);
+        FlightServer s =
+            FlightTestUtil.getStartedServer(
+                (location) -> FlightServer.builder(a, location, producer).build()
+            )) {
+
+      try (
+          FlightClient c = FlightClient.builder(a, s.getLocation()).build()
+      ) {
+        try (BufferAllocator testAllocator = a.newChildAllocator("testcase", 0, Long.MAX_VALUE)) {
+          consumer.accept(c, testAllocator);
+        }
+      }
+    }
+  }
+
+  /** Helper method to convert an ArrowMessage into a Protobuf message. */
+  private Flight.FlightData arrowMessageToProtobuf(
+      MethodDescriptor.Marshaller<ArrowMessage> marshaller, ArrowMessage message) throws IOException {
+    final ByteArrayOutputStream baos = new ByteArrayOutputStream();
+    try (final InputStream serialized = marshaller.stream(message)) {
+      final byte[] buf = new byte[1024];
+      while (true) {
+        int read = serialized.read(buf);
+        // A negative count signals end of stream.
+        if (read < 0) {
+          break;
+        }
+        baos.write(buf, 0, read);
+      }
+    }
+    final byte[] serializedMessage = baos.toByteArray();
+    return Flight.FlightData.parseFrom(serializedMessage);
+  }
+
+  /** ARROW-10962: accept FlightData messages generated by Protobuf (which can omit empty fields). */
+  @Test
+  public void testProtobufRecordBatchCompatibility() throws Exception {
+    final Schema schema = new Schema(Collections.singletonList(Field.nullable("foo", new ArrowType.Int(32, true))));
+    try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+         final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+      final VectorUnloader unloader = new VectorUnloader(root);
+      root.setRowCount(0);
+      final MethodDescriptor.Marshaller<ArrowMessage> marshaller = ArrowMessage.createMarshaller(allocator);
+      try (final ArrowMessage message = new ArrowMessage(
+          unloader.getRecordBatch(), /* appMetadata */ null, /* tryZeroCopy */ false, IpcOption.DEFAULT)) {
+        Assert.assertEquals(ArrowMessage.HeaderType.RECORD_BATCH, message.getMessageType());
+        // Should have at least one empty body buffer (there may be multiple for e.g. data and validity)
+        Iterator<ArrowBuf> iterator = message.getBufs().iterator();
+        Assert.assertTrue(iterator.hasNext());
+        while (iterator.hasNext()) {
+          Assert.assertEquals(0, iterator.next().capacity());
+        }
+        // Re-encode without the data body field to mimic a Protobuf-generated message.
+        final Flight.FlightData protobufData = arrowMessageToProtobuf(marshaller, message)
+            .toBuilder()
+            .clearDataBody()
+            .build();
+        Assert.assertEquals(0, protobufData.getDataBody().size());
+        ArrowMessage parsedMessage = marshaller.parse(new ByteArrayInputStream(protobufData.toByteArray()));
+        // Should have an empty body buffer
+        Iterator<ArrowBuf> parsedIterator = parsedMessage.getBufs().iterator();
+        Assert.assertTrue(parsedIterator.hasNext());
+        Assert.assertEquals(0, parsedIterator.next().capacity());
+        // Should have only one (the parser synthesizes exactly one); in the case of empty buffers, this is equivalent
+        Assert.assertFalse(parsedIterator.hasNext());
+        // Should not throw
+        final ArrowRecordBatch rb = parsedMessage.asRecordBatch();
+        // JUnit takes the expected value first, then the actual one.
+        Assert.assertEquals(0, rb.computeBodyLength());
+      }
+    }
+  }
+
+  /** ARROW-10962: accept FlightData messages generated by Protobuf (which can omit empty fields). 
*/ + @Test + public void testProtobufSchemaCompatibility() throws Exception { + final Schema schema = new Schema(Collections.singletonList(Field.nullable("foo", new ArrowType.Int(32, true)))); + try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) { + final MethodDescriptor.Marshaller<ArrowMessage> marshaller = ArrowMessage.createMarshaller(allocator); + Flight.FlightDescriptor descriptor = FlightDescriptor.command(new byte[0]).toProtocol(); + try (final ArrowMessage message = new ArrowMessage(descriptor, schema, IpcOption.DEFAULT)) { + Assert.assertEquals(ArrowMessage.HeaderType.SCHEMA, message.getMessageType()); + // Should have no body buffers + Assert.assertFalse(message.getBufs().iterator().hasNext()); + final Flight.FlightData protobufData = arrowMessageToProtobuf(marshaller, message) + .toBuilder() + .setDataBody(ByteString.EMPTY) + .build(); + Assert.assertEquals(0, protobufData.getDataBody().size()); + final ArrowMessage parsedMessage = marshaller.parse(new ByteArrayInputStream(protobufData.toByteArray())); + // Should have no body buffers + Assert.assertFalse(parsedMessage.getBufs().iterator().hasNext()); + // Should not throw + parsedMessage.asSchema(); + } + } + } + + /** + * An example FlightProducer for test purposes. 
+   */
+  public static class Producer implements FlightProducer, AutoCloseable {
+    static final byte[] TICKET_LARGE_BATCH = "large-batch".getBytes(StandardCharsets.UTF_8);
+
+    private final BufferAllocator allocator;
+
+    public Producer(BufferAllocator allocator) {
+      super();
+      this.allocator = allocator;
+    }
+
+    /**
+     * Lists a single fixed flight, or nothing at all when a non-empty criteria expression is given.
+     */
+    @Override
+    public void listFlights(CallContext context, Criteria criteria,
+        StreamListener<FlightInfo> listener) {
+      if (criteria.getExpression().length > 0) {
+        // Don't send anything if criteria are set
+        listener.onCompleted();
+        // The stream is finished; falling through would call onNext/onCompleted on it a second time.
+        return;
+      }
+
+      Flight.FlightInfo getInfo = Flight.FlightInfo.newBuilder()
+          .setFlightDescriptor(Flight.FlightDescriptor.newBuilder()
+              .setType(DescriptorType.CMD)
+              .setCmd(ByteString.copyFrom("cool thing", Charsets.UTF_8)))
+          .build();
+      try {
+        listener.onNext(new FlightInfo(getInfo));
+      } catch (URISyntaxException e) {
+        listener.onError(e);
+        return;
+      }
+      listener.onCompleted();
+    }
+
+    /** Accepts an upload by reading and discarding every batch. */
+    @Override
+    public Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener<PutResult> ackStream) {
+      return () -> {
+        while (flightStream.next()) {
+          // Drain the stream
+        }
+      };
+    }
+
+    /**
+     * Streams two batches of ten ints in column "c1" (values 0..9, then 10..19), unless the ticket is
+     * {@link #TICKET_LARGE_BATCH}, in which case the large-batch stream is sent instead.
+     */
+    @Override
+    public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) {
+      if (Arrays.equals(TICKET_LARGE_BATCH, ticket.getBytes())) {
+        getLargeBatch(listener);
+        return;
+      }
+      final int size = 10;
+
+      IntVector iv = new IntVector("c1", allocator);
+      VectorSchemaRoot root = VectorSchemaRoot.of(iv);
+      listener.start(root);
+
+      //batch 1
+      root.allocateNew();
+      for (int i = 0; i < size; i++) {
+        iv.set(i, i);
+      }
+      iv.setValueCount(size);
+      root.setRowCount(size);
+      listener.putNext();
+
+      // batch 2
+
+      root.allocateNew();
+      for (int i = 0; i < size; i++) {
+        iv.set(i, i + size);
+      }
+      iv.setValueCount(size);
+      root.setRowCount(size);
+      listener.putNext();
+      root.clear();
+      listener.completed();
+    }
+
+    /** Sends the same 128-column, 65536-row batch of longs twice. */
+    private void getLargeBatch(ServerStreamListener listener) {
+      final List<FieldVector> vectors = new ArrayList<>();
+      for (int col = 0; col < 128; col++) {
+        final BigIntVector vector = new BigIntVector("f" + col, allocator);
+        for (int row = 0; row < 65536; row++) {
+          vector.setSafe(row, row);
+        }
+        vectors.add(vector);
+      }
+      try (final VectorSchemaRoot root = new VectorSchemaRoot(vectors)) {
+        root.setRowCount(65536);
+        listener.start(root);
+        listener.putNext();
+        listener.putNext();
+        listener.completed();
+      }
+    }
+
+    /** Closes the allocator this producer was constructed with. */
+    @Override
+    public void close() throws Exception {
+      allocator.close();
+    }
+
+    /**
+     * Returns the same fixed info for every descriptor: a command descriptor plus one endpoint located
+     * at https://example.com.
+     */
+    @Override
+    public FlightInfo getFlightInfo(CallContext context,
+        FlightDescriptor descriptor) {
+      try {
+        Flight.FlightInfo getInfo = Flight.FlightInfo.newBuilder()
+            .setFlightDescriptor(Flight.FlightDescriptor.newBuilder()
+                .setType(DescriptorType.CMD)
+                .setCmd(ByteString.copyFrom("cool thing", Charsets.UTF_8)))
+            .addEndpoint(
+                Flight.FlightEndpoint.newBuilder().addLocation(new Location("https://example.com").toProtocol()))
+            .build();
+        return new FlightInfo(getInfo);
+      } catch (URISyntaxException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    /**
+     * Answers "hello" with one result ("world") and "hellooo" with two ("world", "!"); any other action
+     * type is reported to the listener as UNIMPLEMENTED.
+     */
+    @Override
+    public void doAction(CallContext context, Action action,
+        StreamListener<Result> listener) {
+      switch (action.getType()) {
+        case "hello": {
+          listener.onNext(new Result("world".getBytes(Charsets.UTF_8)));
+          listener.onCompleted();
+          break;
+        }
+        case "hellooo": {
+          listener.onNext(new Result("world".getBytes(Charsets.UTF_8)));
+          listener.onNext(new Result("!".getBytes(Charsets.UTF_8)));
+          listener.onCompleted();
+          break;
+        }
+        default:
+          listener.onError(CallStatus.UNIMPLEMENTED.withDescription("Action not implemented: " + action.getType())
+              .toRuntimeException());
+      }
+    }
+
+    /** Advertises three action types: "get", "put" and "hello". */
+    @Override
+    public void listActions(CallContext context,
+        StreamListener<ActionType> listener) {
+      listener.onNext(new ActionType("get", ""));
+      listener.onNext(new ActionType("put", ""));
+      listener.onNext(new ActionType("hello", ""));
+      listener.onCompleted();
+    }
+
+  }
+
+
+}
diff --git 
a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestCallOptions.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestCallOptions.java new file mode 100644 index 000000000..45e3e4960 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestCallOptions.java @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight; + +import java.io.IOException; +import java.time.Duration; +import java.time.Instant; +import java.util.Iterator; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.junit.Assert; +import org.junit.Ignore; +import org.junit.Test; + +import io.grpc.Metadata; + +public class TestCallOptions { + + @Test + @Ignore + public void timeoutFires() { + // Ignored due to CI flakiness + test((client) -> { + Instant start = Instant.now(); + Iterator<Result> results = client.doAction(new Action("hang"), CallOptions.timeout(1, TimeUnit.SECONDS)); + try { + results.next(); + Assert.fail("Call should have failed"); + } catch (RuntimeException e) { + Assert.assertTrue(e.getMessage(), e.getMessage().contains("deadline exceeded")); + } + Instant end = Instant.now(); + Assert.assertTrue("Call took over 1500 ms despite timeout", Duration.between(start, end).toMillis() < 1500); + }); + } + + @Test + @Ignore + public void underTimeout() { + // Ignored due to CI flakiness + test((client) -> { + Instant start = Instant.now(); + // This shouldn't fail and it should complete within the timeout + Iterator<Result> results = client.doAction(new Action("fast"), CallOptions.timeout(2, TimeUnit.SECONDS)); + Assert.assertArrayEquals(new byte[]{42, 42}, results.next().getBody()); + Instant end = Instant.now(); + Assert.assertTrue("Call took over 2500 ms despite timeout", Duration.between(start, end).toMillis() < 2500); + }); + } + + @Test + public void singleProperty() { + final FlightCallHeaders headers = new FlightCallHeaders(); + headers.insert("key", "value"); + testHeaders(headers); + } + + @Test + public void multipleProperties() { + final FlightCallHeaders headers = new FlightCallHeaders(); + headers.insert("key", "value"); + headers.insert("key2", "value2"); + testHeaders(headers); + } + + @Test + public void 
binaryProperties() { + final FlightCallHeaders headers = new FlightCallHeaders(); + headers.insert("key-bin", "value".getBytes()); + headers.insert("key3-bin", "ëfßæ".getBytes()); + testHeaders(headers); + } + + @Test + public void mixedProperties() { + final FlightCallHeaders headers = new FlightCallHeaders(); + headers.insert("key", "value"); + headers.insert("key3-bin", "ëfßæ".getBytes()); + testHeaders(headers); + } + + private void testHeaders(CallHeaders headers) { + try ( + BufferAllocator a = new RootAllocator(Long.MAX_VALUE); + HeaderProducer producer = new HeaderProducer(); + FlightServer s = + FlightTestUtil.getStartedServer((location) -> FlightServer.builder(a, location, producer).build()); + FlightClient client = FlightClient.builder(a, s.getLocation()).build()) { + client.doAction(new Action(""), new HeaderCallOption(headers)).hasNext(); + + final CallHeaders incomingHeaders = producer.headers(); + for (String key : headers.keys()) { + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + Assert.assertArrayEquals(headers.getByte(key), incomingHeaders.getByte(key)); + } else { + Assert.assertEquals(headers.get(key), incomingHeaders.get(key)); + } + } + } catch (InterruptedException | IOException e) { + throw new RuntimeException(e); + } + } + + void test(Consumer<FlightClient> testFn) { + try ( + BufferAllocator a = new RootAllocator(Long.MAX_VALUE); + Producer producer = new Producer(); + FlightServer s = + FlightTestUtil.getStartedServer((location) -> FlightServer.builder(a, location, producer).build()); + FlightClient client = FlightClient.builder(a, s.getLocation()).build()) { + testFn.accept(client); + } catch (InterruptedException | IOException e) { + throw new RuntimeException(e); + } + } + + static class HeaderProducer extends NoOpFlightProducer implements AutoCloseable { + CallHeaders headers; + + @Override + public void close() { + } + + public CallHeaders headers() { + return headers; + } + + @Override + public void doAction(CallContext 
context, Action action, StreamListener<Result> listener) { + this.headers = context.getMiddleware(FlightConstants.HEADER_KEY).headers(); + listener.onCompleted(); + } + } + + static class Producer extends NoOpFlightProducer implements AutoCloseable { + + Producer() { + } + + @Override + public void close() { + } + + @Override + public void doAction(CallContext context, Action action, StreamListener<Result> listener) { + switch (action.getType()) { + case "hang": { + try { + Thread.sleep(25000); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + listener.onNext(new Result(new byte[]{})); + listener.onCompleted(); + return; + } + case "fast": { + try { + Thread.sleep(500); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + listener.onNext(new Result(new byte[]{42, 42})); + listener.onCompleted(); + return; + } + default: { + throw new UnsupportedOperationException(action.getType()); + } + } + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestClientMiddleware.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestClientMiddleware.java new file mode 100644 index 000000000..ccfc9f2d1 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestClientMiddleware.java @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** + * A basic test of client middleware using a simplified OpenTracing-like example. + */ +@RunWith(JUnit4.class) +public class TestClientMiddleware { + + /** + * Test that a client middleware can fail a call before it starts by throwing a {@link FlightRuntimeException}. + */ + @Test + public void clientMiddleware_failCallBeforeSending() { + test(new NoOpFlightProducer(), null, Collections.singletonList(new CallRejector.Factory()), + (allocator, client) -> { + FlightTestUtil.assertCode(FlightStatusCode.UNAVAILABLE, client::listActions); + }); + } + + /** + * Test an OpenTracing-like scenario where client and server middleware work together to propagate a request ID + * without explicit intervention from the service implementation. 
+ */ + @Test + public void middleware_propagateHeader() { + final Context context = new Context("span id"); + test(new NoOpFlightProducer(), + new TestServerMiddleware.ServerMiddlewarePair<>( + FlightServerMiddleware.Key.of("test"), new ServerSpanInjector.Factory()), + Collections.singletonList(new ClientSpanInjector.Factory(context)), + (allocator, client) -> { + FlightTestUtil.assertCode(FlightStatusCode.UNIMPLEMENTED, () -> client.listActions().forEach(actionType -> { + })); + }); + Assert.assertEquals(context.outgoingSpanId, context.incomingSpanId); + Assert.assertNotNull(context.finalStatus); + Assert.assertEquals(FlightStatusCode.UNIMPLEMENTED, context.finalStatus.code()); + } + + /** Ensure both server and client can send and receive multi-valued headers (both binary and text values). */ + @Test + public void testMultiValuedHeaders() { + final MultiHeaderClientMiddlewareFactory clientFactory = new MultiHeaderClientMiddlewareFactory(); + test(new NoOpFlightProducer(), + new TestServerMiddleware.ServerMiddlewarePair<>( + FlightServerMiddleware.Key.of("test"), new MultiHeaderServerMiddlewareFactory()), + Collections.singletonList(clientFactory), + (allocator, client) -> { + FlightTestUtil.assertCode(FlightStatusCode.UNIMPLEMENTED, () -> client.listActions().forEach(actionType -> { + })); + }); + // The server echoes the headers we send back to us, so ensure all the ones we sent are present with the correct + // values in the correct order. 
+ for (final Map.Entry<String, List<byte[]>> entry : EXPECTED_BINARY_HEADERS.entrySet()) { + // Compare header values entry-by-entry because byte arrays don't compare via equals + final List<byte[]> receivedValues = clientFactory.lastBinaryHeaders.get(entry.getKey()); + Assert.assertNotNull("Missing for header: " + entry.getKey(), receivedValues); + Assert.assertEquals( + "Missing or wrong value for header: " + entry.getKey(), + entry.getValue().size(), receivedValues.size()); + for (int i = 0; i < entry.getValue().size(); i++) { + Assert.assertArrayEquals(entry.getValue().get(i), receivedValues.get(i)); + } + } + for (final Map.Entry<String, List<String>> entry : EXPECTED_TEXT_HEADERS.entrySet()) { + Assert.assertEquals( + "Missing or wrong value for header: " + entry.getKey(), + entry.getValue(), clientFactory.lastTextHeaders.get(entry.getKey())); + } + } + + private static <T extends FlightServerMiddleware> void test(FlightProducer producer, + TestServerMiddleware.ServerMiddlewarePair<T> serverMiddleware, + List<FlightClientMiddleware.Factory> clientMiddleware, + BiConsumer<BufferAllocator, FlightClient> body) { + try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) { + final FlightServer server = FlightTestUtil + .getStartedServer(location -> { + final FlightServer.Builder builder = FlightServer.builder(allocator, location, producer); + if (serverMiddleware != null) { + builder.middleware(serverMiddleware.key, serverMiddleware.factory); + } + return builder.build(); + }); + FlightClient.Builder builder = FlightClient.builder(allocator, server.getLocation()); + clientMiddleware.forEach(builder::intercept); + try (final FlightServer ignored = server; + final FlightClient client = builder.build() + ) { + body.accept(allocator, client); + } + } catch (InterruptedException | IOException e) { + throw new RuntimeException(e); + } + } + + /** + * A server middleware component that reads a request ID from incoming headers and sends the request ID 
back on + * outgoing headers. + */ + static class ServerSpanInjector implements FlightServerMiddleware { + + private final String spanId; + + public ServerSpanInjector(String spanId) { + this.spanId = spanId; + } + + @Override + public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { + outgoingHeaders.insert("x-span", spanId); + } + + @Override + public void onCallCompleted(CallStatus status) { + + } + + @Override + public void onCallErrored(Throwable err) { + + } + + static class Factory implements FlightServerMiddleware.Factory<ServerSpanInjector> { + + @Override + public ServerSpanInjector onCallStarted(CallInfo info, CallHeaders incomingHeaders, RequestContext context) { + return new ServerSpanInjector(incomingHeaders.get("x-span")); + } + } + } + + /** + * A client middleware component that, given a mock OpenTracing-like "request context", sends the request ID in the + * context on outgoing headers and reads it from incoming headers. + */ + static class ClientSpanInjector implements FlightClientMiddleware { + + private final Context context; + + public ClientSpanInjector(Context context) { + this.context = context; + } + + @Override + public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { + outgoingHeaders.insert("x-span", context.outgoingSpanId); + } + + @Override + public void onHeadersReceived(CallHeaders incomingHeaders) { + context.incomingSpanId = incomingHeaders.get("x-span"); + } + + @Override + public void onCallCompleted(CallStatus status) { + context.finalStatus = status; + } + + static class Factory implements FlightClientMiddleware.Factory { + + private final Context context; + + Factory(Context context) { + this.context = context; + } + + @Override + public FlightClientMiddleware onCallStarted(CallInfo info) { + return new ClientSpanInjector(context); + } + } + } + + /** + * A mock OpenTracing-like "request context". 
+ */ + static class Context { + + final String outgoingSpanId; + String incomingSpanId; + CallStatus finalStatus; + + Context(String spanId) { + this.outgoingSpanId = spanId; + } + } + + /** + * A client middleware that fails outgoing calls. + */ + static class CallRejector implements FlightClientMiddleware { + + @Override + public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { + } + + @Override + public void onHeadersReceived(CallHeaders incomingHeaders) { + } + + @Override + public void onCallCompleted(CallStatus status) { + } + + static class Factory implements FlightClientMiddleware.Factory { + + @Override + public FlightClientMiddleware onCallStarted(CallInfo info) { + throw CallStatus.UNAVAILABLE.withDescription("Rejecting call.").toRuntimeException(); + } + } + } + + // Used to test that middleware can send and receive multi-valued text and binary headers. + static final Map<String, List<byte[]>> EXPECTED_BINARY_HEADERS = new HashMap<String, List<byte[]>>() {{ + put("x-binary-bin", Arrays.asList(new byte[] {0}, new byte[]{1})); + }}; + static final Map<String, List<String>> EXPECTED_TEXT_HEADERS = new HashMap<String, List<String>>() {{ + put("x-text", Arrays.asList("foo", "bar")); + }}; + + static class MultiHeaderServerMiddlewareFactory implements + FlightServerMiddleware.Factory<MultiHeaderServerMiddleware> { + @Override + public MultiHeaderServerMiddleware onCallStarted(CallInfo info, CallHeaders incomingHeaders, + RequestContext context) { + // Echo the headers back to the client. Copy values out of CallHeaders since the underlying gRPC metadata + // object isn't safe to use after this function returns. 
+ Map<String, List<byte[]>> binaryHeaders = new HashMap<>(); + Map<String, List<String>> textHeaders = new HashMap<>(); + for (final String key : incomingHeaders.keys()) { + if (key.endsWith("-bin")) { + binaryHeaders.compute(key, (ignored, values) -> { + if (values == null) { + values = new ArrayList<>(); + } + incomingHeaders.getAllByte(key).forEach(values::add); + return values; + }); + } else { + textHeaders.compute(key, (ignored, values) -> { + if (values == null) { + values = new ArrayList<>(); + } + incomingHeaders.getAll(key).forEach(values::add); + return values; + }); + } + } + return new MultiHeaderServerMiddleware(binaryHeaders, textHeaders); + } + } + + static class MultiHeaderServerMiddleware implements FlightServerMiddleware { + private final Map<String, List<byte[]>> binaryHeaders; + private final Map<String, List<String>> textHeaders; + + MultiHeaderServerMiddleware(Map<String, List<byte[]>> binaryHeaders, Map<String, List<String>> textHeaders) { + this.binaryHeaders = binaryHeaders; + this.textHeaders = textHeaders; + } + + @Override + public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { + binaryHeaders.forEach((key, values) -> values.forEach(value -> outgoingHeaders.insert(key, value))); + textHeaders.forEach((key, values) -> values.forEach(value -> outgoingHeaders.insert(key, value))); + } + + @Override + public void onCallCompleted(CallStatus status) {} + + @Override + public void onCallErrored(Throwable err) {} + } + + static class MultiHeaderClientMiddlewareFactory implements FlightClientMiddleware.Factory { + Map<String, List<byte[]>> lastBinaryHeaders = null; + Map<String, List<String>> lastTextHeaders = null; + + @Override + public FlightClientMiddleware onCallStarted(CallInfo info) { + return new MultiHeaderClientMiddleware(this); + } + } + + static class MultiHeaderClientMiddleware implements FlightClientMiddleware { + private final MultiHeaderClientMiddlewareFactory factory; + + public 
MultiHeaderClientMiddleware(MultiHeaderClientMiddlewareFactory factory) { + this.factory = factory; + } + + @Override + public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { + for (final Map.Entry<String, List<byte[]>> entry : EXPECTED_BINARY_HEADERS.entrySet()) { + entry.getValue().forEach((value) -> outgoingHeaders.insert(entry.getKey(), value)); + Assert.assertTrue(outgoingHeaders.containsKey(entry.getKey())); + } + for (final Map.Entry<String, List<String>> entry : EXPECTED_TEXT_HEADERS.entrySet()) { + entry.getValue().forEach((value) -> outgoingHeaders.insert(entry.getKey(), value)); + Assert.assertTrue(outgoingHeaders.containsKey(entry.getKey())); + } + } + + @Override + public void onHeadersReceived(CallHeaders incomingHeaders) { + factory.lastBinaryHeaders = new HashMap<>(); + factory.lastTextHeaders = new HashMap<>(); + incomingHeaders.keys().forEach(header -> { + if (header.endsWith("-bin")) { + final List<byte[]> values = new ArrayList<>(); + incomingHeaders.getAllByte(header).forEach(values::add); + factory.lastBinaryHeaders.put(header, values); + } else { + final List<String> values = new ArrayList<>(); + incomingHeaders.getAll(header).forEach(values::add); + factory.lastTextHeaders.put(header, values); + } + }); + } + + @Override + public void onCallCompleted(CallStatus status) {} + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDictionaryUtils.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDictionaryUtils.java new file mode 100644 index 000000000..b5bf117c6 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDictionaryUtils.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.TreeSet; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Test; + +import com.google.common.collect.ImmutableList; + +/** + * Test cases for {@link DictionaryUtils}. + */ +public class TestDictionaryUtils { + + @Test + public void testReuseSchema() { + FieldType varcharType = new FieldType(true, new ArrowType.Utf8(), null); + FieldType intType = new FieldType(true, new ArrowType.Int(32, true), null); + + ImmutableList<Field> build = ImmutableList.of( + new Field("stringCol", varcharType, null), + new Field("intCol", intType, null)); + + Schema schema = new Schema(build); + Schema newSchema = DictionaryUtils.generateSchema(schema, null, new TreeSet<>()); + + // assert that no new schema is created. 
+ assertTrue(schema == newSchema); + } + + @Test + public void testCreateSchema() { + try (BufferAllocator allocator = new RootAllocator(1024)) { + DictionaryEncoding dictionaryEncoding = + new DictionaryEncoding(0, true, new ArrowType.Int(8, true)); + VarCharVector dictVec = new VarCharVector("dict vector", allocator); + Dictionary dictionary = new Dictionary(dictVec, dictionaryEncoding); + DictionaryProvider dictProvider = new DictionaryProvider.MapDictionaryProvider(dictionary); + TreeSet<Long> dictionaryUsed = new TreeSet<>(); + + FieldType encodedVarcharType = new FieldType(true, new ArrowType.Int(8, true), dictionaryEncoding); + FieldType intType = new FieldType(true, new ArrowType.Int(32, true), null); + + ImmutableList<Field> build = ImmutableList.of( + new Field("stringCol", encodedVarcharType, null), + new Field("intCol", intType, null)); + + Schema schema = new Schema(build); + Schema newSchema = DictionaryUtils.generateSchema(schema, dictProvider, dictionaryUsed); + + // assert that a new schema is created. + assertTrue(schema != newSchema); + + // assert the column is converted as expected + ArrowType newColType = newSchema.getFields().get(0).getType(); + assertEquals(new ArrowType.Utf8(), newColType); + + assertEquals(1, dictionaryUsed.size()); + assertEquals(0, dictionaryUsed.first()); + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java new file mode 100644 index 000000000..70394e11e --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java @@ -0,0 +1,536 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collections; +import java.util.stream.IntStream; + +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.testing.ValueVectorDataPopulator; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class TestDoExchange { + static byte[] EXCHANGE_DO_GET = "do-get".getBytes(StandardCharsets.UTF_8); + static byte[] EXCHANGE_DO_PUT = 
"do-put".getBytes(StandardCharsets.UTF_8); + static byte[] EXCHANGE_ECHO = "echo".getBytes(StandardCharsets.UTF_8); + static byte[] EXCHANGE_METADATA_ONLY = "only-metadata".getBytes(StandardCharsets.UTF_8); + static byte[] EXCHANGE_TRANSFORM = "transform".getBytes(StandardCharsets.UTF_8); + static byte[] EXCHANGE_CANCEL = "cancel".getBytes(StandardCharsets.UTF_8); + + private BufferAllocator allocator; + private FlightServer server; + private FlightClient client; + + @Before + public void setUp() throws Exception { + allocator = new RootAllocator(Integer.MAX_VALUE); + final Location serverLocation = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, 0); + server = FlightServer.builder(allocator, serverLocation, new Producer(allocator)).build(); + server.start(); + final Location clientLocation = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort()); + client = FlightClient.builder(allocator, clientLocation).build(); + } + + @After + public void tearDown() throws Exception { + AutoCloseables.close(client, server, allocator); + } + + /** Test a pure-metadata flow. */ + @Test + public void testDoExchangeOnlyMetadata() throws Exception { + // Send a particular descriptor to the server and check for a particular response pattern. 
+ try (final FlightClient.ExchangeReaderWriter stream = + client.doExchange(FlightDescriptor.command(EXCHANGE_METADATA_ONLY))) { + final FlightStream reader = stream.getReader(); + + // Server starts by sending a message without data (hence no VectorSchemaRoot should be present) + assertTrue(reader.next()); + assertFalse(reader.hasRoot()); + assertEquals(42, reader.getLatestMetadata().getInt(0)); + + // Write a metadata message to the server (without sending any data) + ArrowBuf buf = allocator.buffer(4); + buf.writeInt(84); + stream.getWriter().putMetadata(buf); + + // Check that the server echoed the metadata back to us + assertTrue(reader.next()); + assertFalse(reader.hasRoot()); + assertEquals(84, reader.getLatestMetadata().getInt(0)); + + // Close our write channel and ensure the server also closes theirs + stream.getWriter().completed(); + assertFalse(reader.next()); + } + } + + /** Emulate a DoGet with a DoExchange. */ + @Test + public void testDoExchangeDoGet() throws Exception { + try (final FlightClient.ExchangeReaderWriter stream = + client.doExchange(FlightDescriptor.command(EXCHANGE_DO_GET))) { + final FlightStream reader = stream.getReader(); + VectorSchemaRoot root = reader.getRoot(); + IntVector iv = (IntVector) root.getVector("a"); + int value = 0; + while (reader.next()) { + for (int i = 0; i < root.getRowCount(); i++) { + assertFalse(String.format("Row %d should not be null", value), iv.isNull(i)); + assertEquals(value, iv.get(i)); + value++; + } + } + assertEquals(100, value); + } + } + + /** Emulate a DoPut with a DoExchange. 
*/ + @Test + public void testDoExchangeDoPut() throws Exception { + final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true)))); + try (final FlightClient.ExchangeReaderWriter stream = + client.doExchange(FlightDescriptor.command(EXCHANGE_DO_PUT)); + final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + IntVector iv = (IntVector) root.getVector("a"); + iv.allocateNew(); + + stream.getWriter().start(root); + int counter = 0; + for (int i = 0; i < 10; i++) { + ValueVectorDataPopulator.setVector(iv, IntStream.range(0, i).boxed().toArray(Integer[]::new)); + root.setRowCount(i); + counter += i; + stream.getWriter().putNext(); + + assertTrue(stream.getReader().next()); + assertFalse(stream.getReader().hasRoot()); + // For each write, the server sends back a metadata message containing the index of the last written batch + final ArrowBuf metadata = stream.getReader().getLatestMetadata(); + assertEquals(counter, metadata.getInt(0)); + } + stream.getWriter().completed(); + + while (stream.getReader().next()) { + // Drain the stream. Otherwise closing the stream sends a CANCEL which seriously screws with the server. + // CANCEL -> runs onCancel handler -> closes the FlightStream early + } + } + } + + /** Test a DoExchange that echoes the client message. 
*/ + @Test + public void testDoExchangeEcho() throws Exception { + final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true)))); + try (final FlightClient.ExchangeReaderWriter stream = client.doExchange(FlightDescriptor.command(EXCHANGE_ECHO)); + final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + final FlightStream reader = stream.getReader(); + + // First try writing metadata without starting the Arrow data stream + ArrowBuf buf = allocator.buffer(4); + buf.writeInt(42); + stream.getWriter().putMetadata(buf); + buf = allocator.buffer(4); + buf.writeInt(84); + stream.getWriter().putMetadata(buf); + + // Ensure that the server echoes the metadata back, also without starting its data stream + assertTrue(reader.next()); + assertFalse(reader.hasRoot()); + assertEquals(42, reader.getLatestMetadata().getInt(0)); + assertTrue(reader.next()); + assertFalse(reader.hasRoot()); + assertEquals(84, reader.getLatestMetadata().getInt(0)); + + // Write data and check that it gets echoed back. + IntVector iv = (IntVector) root.getVector("a"); + iv.allocateNew(); + stream.getWriter().start(root); + for (int i = 0; i < 10; i++) { + iv.setSafe(0, i); + root.setRowCount(1); + stream.getWriter().putNext(); + + assertTrue(reader.next()); + assertNull(reader.getLatestMetadata()); + assertEquals(root.getSchema(), reader.getSchema()); + assertEquals(i, ((IntVector) reader.getRoot().getVector("a")).get(0)); + } + + // Complete the stream so that the server knows not to expect any more messages from us. + stream.getWriter().completed(); + // The server will end its side of the call, so this shouldn't block or indicate that + // there is more data. + assertFalse("We should not be waiting for any messages", reader.next()); + } + } + + /** Write some data, have it transformed, then read it back. 
*/ + @Test + public void testTransform() throws Exception { + final Schema schema = new Schema(Arrays.asList( + Field.nullable("a", new ArrowType.Int(32, true)), + Field.nullable("b", new ArrowType.Int(32, true)))); + try (final FlightClient.ExchangeReaderWriter stream = + client.doExchange(FlightDescriptor.command(EXCHANGE_TRANSFORM))) { + // Write ten batches of data to the stream, where batch N contains N rows of data (N in [0, 10)) + final FlightStream reader = stream.getReader(); + final FlightClient.ClientStreamListener writer = stream.getWriter(); + try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + writer.start(root); + for (int batchIndex = 0; batchIndex < 10; batchIndex++) { + for (final FieldVector rawVec : root.getFieldVectors()) { + final IntVector vec = (IntVector) rawVec; + ValueVectorDataPopulator.setVector(vec, IntStream.range(0, batchIndex).boxed().toArray(Integer[]::new)); + } + root.setRowCount(batchIndex); + writer.putNext(); + } + } + // Indicate that we're done writing so that the server does not expect more data. + writer.completed(); + + // Read back data. We expect the server to double each value in each row of each batch. 
+ assertEquals(schema, reader.getSchema()); + final VectorSchemaRoot root = reader.getRoot(); + for (int batchIndex = 0; batchIndex < 10; batchIndex++) { + assertTrue("Didn't receive batch #" + batchIndex, reader.next()); + assertEquals(batchIndex, root.getRowCount()); + for (final FieldVector rawVec : root.getFieldVectors()) { + final IntVector vec = (IntVector) rawVec; + for (int row = 0; row < batchIndex; row++) { + assertEquals(2 * row, vec.get(row)); + } + } + } + + // The server also sends back a metadata-only message containing the message count + assertTrue("There should be one extra message", reader.next()); + assertEquals(10, reader.getLatestMetadata().getInt(0)); + assertFalse("There should be no more data", reader.next()); + } + } + + /** Write some data, have it transformed, then read it back. Use the zero-copy optimization. */ + @Test + public void testTransformZeroCopy() throws Exception { + final int rowsPerBatch = 4096; + final Schema schema = new Schema(Arrays.asList( + Field.nullable("a", new ArrowType.Int(32, true)), + Field.nullable("b", new ArrowType.Int(32, true)))); + try (final FlightClient.ExchangeReaderWriter stream = + client.doExchange(FlightDescriptor.command(EXCHANGE_TRANSFORM))) { + // Write ten batches of data to the stream, where batch N contains 1024 rows of data (N in [0, 10)) + final FlightStream reader = stream.getReader(); + final FlightClient.ClientStreamListener writer = stream.getWriter(); + try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + writer.start(root); + // Enable the zero-copy optimization + writer.setUseZeroCopy(true); + for (int batchIndex = 0; batchIndex < 100; batchIndex++) { + for (final FieldVector rawVec : root.getFieldVectors()) { + final IntVector vec = (IntVector) rawVec; + for (int row = 0; row < rowsPerBatch; row++) { + // Use a value that'll be different per batch, so we can detect if we accidentally + // reuse a buffer (and overwrite a buffer that hasn't yet been sent 
over the network) + vec.setSafe(row, batchIndex + row); + } + } + root.setRowCount(rowsPerBatch); + writer.putNext(); + // Allocate new buffers every time since we don't know if gRPC has written the buffer + // to the network yet + root.allocateNew(); + } + } + // Indicate that we're done writing so that the server does not expect more data. + writer.completed(); + + // Read back data. We expect the server to double each value in each row of each batch. + assertEquals(schema, reader.getSchema()); + final VectorSchemaRoot root = reader.getRoot(); + for (int batchIndex = 0; batchIndex < 100; batchIndex++) { + assertTrue("Didn't receive batch #" + batchIndex, reader.next()); + assertEquals(rowsPerBatch, root.getRowCount()); + for (final FieldVector rawVec : root.getFieldVectors()) { + final IntVector vec = (IntVector) rawVec; + for (int row = 0; row < rowsPerBatch; row++) { + assertEquals(2 * (batchIndex + row), vec.get(row)); + } + } + } + + // The server also sends back a metadata-only message containing the message count + assertTrue("There should be one extra message", reader.next()); + assertEquals(100, reader.getLatestMetadata().getInt(0)); + assertFalse("There should be no more data", reader.next()); + } + } + + /** Have the server immediately cancel; ensure the client doesn't hang. */ + @Test + public void testServerCancel() throws Exception { + try (final FlightClient.ExchangeReaderWriter stream = + client.doExchange(FlightDescriptor.command(EXCHANGE_CANCEL))) { + final FlightStream reader = stream.getReader(); + final FlightClient.ClientStreamListener writer = stream.getWriter(); + + final FlightRuntimeException fre = assertThrows(FlightRuntimeException.class, reader::next); + assertEquals(FlightStatusCode.CANCELLED, fre.status().code()); + assertEquals("expected", fre.status().description()); + + // Before, this would hang forever, because the writer checks if the stream is ready and not cancelled. 
+ // However, the cancellation flag (was) only updated by reading, and the stream is never ready once the call ends. + // The test looks weird since normally, an application shouldn't try to write after the read fails. However, + // an application that isn't reading data wouldn't notice, and would instead get stuck on the write. + // Here, we read first to avoid a race condition in the test itself. + writer.putMetadata(allocator.getEmpty()); + } + } + + /** Have the server immediately cancel; ensure the server cleans up the FlightStream. */ + @Test + public void testServerCancelLeak() throws Exception { + try (final FlightClient.ExchangeReaderWriter stream = + client.doExchange(FlightDescriptor.command(EXCHANGE_CANCEL))) { + final FlightStream reader = stream.getReader(); + final FlightClient.ClientStreamListener writer = stream.getWriter(); + try (final VectorSchemaRoot root = VectorSchemaRoot.create(Producer.SCHEMA, allocator)) { + writer.start(root); + final IntVector ints = (IntVector) root.getVector("a"); + for (int i = 0; i < 128; i++) { + for (int row = 0; row < 1024; row++) { + ints.setSafe(row, row); + } + root.setRowCount(1024); + writer.putNext(); + } + } + + final FlightRuntimeException fre = assertThrows(FlightRuntimeException.class, reader::next); + assertEquals(FlightStatusCode.CANCELLED, fre.status().code()); + assertEquals("expected", fre.status().description()); + } + } + + /** Have the client cancel without reading; ensure memory is not leaked. */ + @Test + public void testClientCancel() throws Exception { + try (final FlightClient.ExchangeReaderWriter stream = + client.doExchange(FlightDescriptor.command(EXCHANGE_DO_GET))) { + final FlightStream reader = stream.getReader(); + reader.cancel("", null); + // Cancel should be idempotent + reader.cancel("", null); + } + } + + /** Have the client close the stream without reading; ensure memory is not leaked. 
*/ + @Test + public void testClientClose() throws Exception { + try (final FlightClient.ExchangeReaderWriter stream = + client.doExchange(FlightDescriptor.command(EXCHANGE_DO_GET))) { + assertEquals(Producer.SCHEMA, stream.getReader().getSchema()); + } + // Intentionally leak the allocator in this test. gRPC has a bug where it does not wait for all calls to complete + // when shutting down the server, so this test will fail otherwise because it closes the allocator while the + // server-side call still has memory allocated. + // TODO(ARROW-9586): fix this once we track outstanding RPCs outside of gRPC. + // https://stackoverflow.com/questions/46716024/ + allocator = null; + client = null; + } + + static class Producer extends NoOpFlightProducer { + static final Schema SCHEMA = new Schema( + Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true)))); + private final BufferAllocator allocator; + + Producer(BufferAllocator allocator) { + this.allocator = allocator; + } + + @Override + public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) { + if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_METADATA_ONLY)) { + metadataOnly(context, reader, writer); + } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_DO_GET)) { + doGet(context, reader, writer); + } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_DO_PUT)) { + doPut(context, reader, writer); + } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_ECHO)) { + echo(context, reader, writer); + } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_TRANSFORM)) { + transform(context, reader, writer); + } else if (Arrays.equals(reader.getDescriptor().getCommand(), EXCHANGE_CANCEL)) { + cancel(context, reader, writer); + } else { + writer.error(CallStatus.UNIMPLEMENTED.withDescription("Command not implemented").toRuntimeException()); + } + } + + /** Emulate DoGet. 
*/ + private void doGet(CallContext context, FlightStream reader, ServerStreamListener writer) { + try (VectorSchemaRoot root = VectorSchemaRoot.create(SCHEMA, allocator)) { + writer.start(root); + root.allocateNew(); + IntVector iv = (IntVector) root.getVector("a"); + + for (int i = 0; i < 100; i += 2) { + iv.set(0, i); + iv.set(1, i + 1); + root.setRowCount(2); + writer.putNext(); + } + } + writer.completed(); + } + + /** Emulate DoPut. */ + private void doPut(CallContext context, FlightStream reader, ServerStreamListener writer) { + int counter = 0; + while (reader.next()) { + if (!reader.hasRoot()) { + writer.error(CallStatus.INVALID_ARGUMENT.withDescription("Message has no data").toRuntimeException()); + return; + } + counter += reader.getRoot().getRowCount(); + + final ArrowBuf pong = allocator.buffer(4); + pong.writeInt(counter); + writer.putMetadata(pong); + } + writer.completed(); + } + + /** Exchange metadata without ever exchanging data. */ + private void metadataOnly(CallContext context, FlightStream reader, ServerStreamListener writer) { + final ArrowBuf buf = allocator.buffer(4); + buf.writeInt(42); + writer.putMetadata(buf); + assertTrue(reader.next()); + assertNotNull(reader.getLatestMetadata()); + reader.getLatestMetadata().getReferenceManager().retain(); + writer.putMetadata(reader.getLatestMetadata()); + writer.completed(); + } + + /** Echo the client's response back to it. 
*/ + private void echo(CallContext context, FlightStream reader, ServerStreamListener writer) { + VectorSchemaRoot root = null; + VectorLoader loader = null; + while (reader.next()) { + if (reader.hasRoot()) { + if (root == null) { + root = VectorSchemaRoot.create(reader.getSchema(), allocator); + loader = new VectorLoader(root); + writer.start(root); + } + VectorUnloader unloader = new VectorUnloader(reader.getRoot()); + try (final ArrowRecordBatch arb = unloader.getRecordBatch()) { + loader.load(arb); + } + if (reader.getLatestMetadata() != null) { + reader.getLatestMetadata().getReferenceManager().retain(); + writer.putNext(reader.getLatestMetadata()); + } else { + writer.putNext(); + } + } else { + // Pure metadata + reader.getLatestMetadata().getReferenceManager().retain(); + writer.putMetadata(reader.getLatestMetadata()); + } + } + if (root != null) { + root.close(); + } + writer.completed(); + } + + /** Accept a set of messages, then return some result. */ + private void transform(CallContext context, FlightStream reader, ServerStreamListener writer) { + final Schema schema = reader.getSchema(); + for (final Field field : schema.getFields()) { + if (!(field.getType() instanceof ArrowType.Int)) { + writer.error(CallStatus.INVALID_ARGUMENT.withDescription("Invalid type: " + field).toRuntimeException()); + return; + } + final ArrowType.Int intType = (ArrowType.Int) field.getType(); + if (!intType.getIsSigned() || intType.getBitWidth() != 32) { + writer.error(CallStatus.INVALID_ARGUMENT.withDescription("Must be i32: " + field).toRuntimeException()); + return; + } + } + int batches = 0; + try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + writer.start(root); + writer.setUseZeroCopy(true); + final VectorLoader loader = new VectorLoader(root); + final VectorUnloader unloader = new VectorUnloader(reader.getRoot()); + while (reader.next()) { + try (final ArrowRecordBatch batch = unloader.getRecordBatch()) { + loader.load(batch); + } + 
batches++; + for (final FieldVector rawVec : root.getFieldVectors()) { + final IntVector vec = (IntVector) rawVec; + for (int i = 0; i < root.getRowCount(); i++) { + if (!vec.isNull(i)) { + vec.set(i, vec.get(i) * 2); + } + } + } + writer.putNext(); + } + } + final ArrowBuf count = allocator.buffer(4); + count.writeInt(batches); + writer.putMetadata(count); + writer.completed(); + } + + /** Immediately cancel the call. */ + private void cancel(CallContext context, FlightStream reader, ServerStreamListener writer) { + writer.error(CallStatus.CANCELLED.withDescription("expected").toRuntimeException()); + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestErrorMetadata.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestErrorMetadata.java new file mode 100644 index 000000000..2c62bc7fa --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestErrorMetadata.java @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight; + +import org.apache.arrow.flight.perf.impl.PerfOuterClass; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.junit.Assert; +import org.junit.Test; + +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.rpc.Status; + +import io.grpc.Metadata; +import io.grpc.StatusRuntimeException; +import io.grpc.protobuf.ProtoUtils; +import io.grpc.protobuf.StatusProto; + +public class TestErrorMetadata { + private static final Metadata.BinaryMarshaller<Status> marshaller = + ProtoUtils.metadataMarshaller(Status.getDefaultInstance()); + + /** Ensure metadata attached to a gRPC error is propagated. */ + @Test + public void testGrpcMetadata() throws Exception { + PerfOuterClass.Perf perf = PerfOuterClass.Perf.newBuilder() + .setStreamCount(12) + .setRecordsPerBatch(1000) + .setRecordsPerStream(1000000L) + .build(); + StatusRuntimeExceptionProducer producer = new StatusRuntimeExceptionProducer(perf); + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final FlightServer s = + FlightTestUtil.getStartedServer( + (location) -> { + return FlightServer.builder(allocator, location, producer).build(); + }); + final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) { + final CallStatus flightStatus = FlightTestUtil.assertCode(FlightStatusCode.CANCELLED, () -> { + FlightStream stream = client.getStream(new Ticket("abs".getBytes())); + stream.next(); + }); + PerfOuterClass.Perf newPerf = null; + ErrorFlightMetadata metadata = flightStatus.metadata(); + Assert.assertNotNull(metadata); + Assert.assertEquals(2, metadata.keys().size()); + Assert.assertTrue(metadata.containsKey("grpc-status-details-bin")); + Status status = marshaller.parseBytes(metadata.getByte("grpc-status-details-bin")); + for (Any details : status.getDetailsList()) { + if (details.is(PerfOuterClass.Perf.class)) { + 
try { + newPerf = details.unpack(PerfOuterClass.Perf.class); + } catch (InvalidProtocolBufferException e) { + Assert.fail(); + } + } + } + Assert.assertNotNull(newPerf); + Assert.assertEquals(perf, newPerf); + } + } + + /** Ensure metadata attached to a Flight error is propagated. */ + @Test + public void testFlightMetadata() throws Exception { + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final FlightServer s = + FlightTestUtil.getStartedServer( + (location) -> FlightServer.builder(allocator, location, new CallStatusProducer()).build()); + final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) { + CallStatus flightStatus = FlightTestUtil.assertCode(FlightStatusCode.INVALID_ARGUMENT, () -> { + FlightStream stream = client.getStream(new Ticket(new byte[0])); + stream.next(); + }); + ErrorFlightMetadata metadata = flightStatus.metadata(); + Assert.assertNotNull(metadata); + Assert.assertEquals("foo", metadata.get("x-foo")); + Assert.assertArrayEquals(new byte[]{1}, metadata.getByte("x-bar-bin")); + + flightStatus = FlightTestUtil.assertCode(FlightStatusCode.INVALID_ARGUMENT, () -> { + client.getInfo(FlightDescriptor.command(new byte[0])); + }); + metadata = flightStatus.metadata(); + Assert.assertNotNull(metadata); + Assert.assertEquals("foo", metadata.get("x-foo")); + Assert.assertArrayEquals(new byte[]{1}, metadata.getByte("x-bar-bin")); + } + } + + private static class StatusRuntimeExceptionProducer extends NoOpFlightProducer { + private final PerfOuterClass.Perf perf; + + private StatusRuntimeExceptionProducer(PerfOuterClass.Perf perf) { + this.perf = perf; + } + + @Override + public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { + StatusRuntimeException sre = StatusProto.toStatusRuntimeException(Status.newBuilder() + .setCode(1) + .setMessage("Testing 1 2 3") + .addDetails(Any.pack(perf, "arrow/meta/types")) + .build()); + listener.error(sre); + } + } + + private 
static class CallStatusProducer extends NoOpFlightProducer { + ErrorFlightMetadata metadata; + + CallStatusProducer() { + this.metadata = new ErrorFlightMetadata(); + metadata.insert("x-foo", "foo"); + metadata.insert("x-bar-bin", new byte[]{1}); + } + + @Override + public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { + listener.error(CallStatus.INVALID_ARGUMENT.withDescription("Failed").withMetadata(metadata).toRuntimeException()); + } + + @Override + public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { + throw CallStatus.INVALID_ARGUMENT.withDescription("Failed").withMetadata(metadata).toRuntimeException(); + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightClient.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightClient.java new file mode 100644 index 000000000..30e351e94 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightClient.java @@ -0,0 +1,225 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight; + +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.List; + +import org.apache.arrow.flight.FlightClient.ClientStreamListener; +import org.apache.arrow.flight.TestBasicOperation.Producer; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryEncoder; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Assert; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.jupiter.api.Assertions; + +public class TestFlightClient { + /** + * ARROW-5063: make sure two clients to the same location can be closed independently. 
+ */ + @Test + public void independentShutdown() throws Exception { + try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); + final FlightServer server = FlightTestUtil.getStartedServer( + location -> FlightServer.builder(allocator, location, + new Producer(allocator)).build())) { + final Location location = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort()); + final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true)))); + try (final FlightClient client1 = FlightClient.builder(allocator, location).build(); + final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + // Use startPut as this ensures the RPC won't finish until we want it to + final ClientStreamListener listener = client1.startPut(FlightDescriptor.path("test"), root, + new AsyncPutListener()); + try (final FlightClient client2 = FlightClient.builder(allocator, location).build()) { + client2.listActions().forEach(actionType -> Assert.assertNotNull(actionType.getType())); + } + listener.completed(); + listener.getResult(); + } + } + } + + /** + * ARROW-5978: make sure that we can properly close a client/stream after requesting dictionaries. + */ + @Ignore // Unfortunately this test is flaky in CI. 
+ @Test + public void freeDictionaries() throws Exception { + final Schema expectedSchema = new Schema(Collections + .singletonList(new Field("encoded", + new FieldType(true, new ArrowType.Int(32, true), new DictionaryEncoding(1L, false, null)), null))); + try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); + final BufferAllocator serverAllocator = allocator.newChildAllocator("flight-server", 0, Integer.MAX_VALUE); + final FlightServer server = FlightTestUtil.getStartedServer( + location -> FlightServer.builder(serverAllocator, location, + new DictionaryProducer(serverAllocator)).build())) { + final Location location = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort()); + try (final FlightClient client = FlightClient.builder(allocator, location).build()) { + try (final FlightStream stream = client.getStream(new Ticket(new byte[0]))) { + Assert.assertTrue(stream.next()); + Assert.assertNotNull(stream.getDictionaryProvider().lookup(1)); + final VectorSchemaRoot root = stream.getRoot(); + Assert.assertEquals(expectedSchema, root.getSchema()); + Assert.assertEquals(6, root.getVector("encoded").getValueCount()); + try (final ValueVector decoded = DictionaryEncoder + .decode(root.getVector("encoded"), stream.getDictionaryProvider().lookup(1))) { + Assert.assertFalse(decoded.isNull(1)); + Assert.assertTrue(decoded instanceof VarCharVector); + Assert.assertArrayEquals("one".getBytes(StandardCharsets.UTF_8), ((VarCharVector) decoded).get(1)); + } + Assert.assertFalse(stream.next()); + } + // Closing stream fails if it doesn't free dictionaries; closing dictionaries fails (refcount goes negative) + // if reference isn't retained in ArrowMessage + } + } + } + + /** + * ARROW-5978: make sure that dictionary ownership can't be claimed twice. + */ + @Ignore // Unfortunately this test is flaky in CI. 
+ @Test + public void ownDictionaries() throws Exception { + try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); + final BufferAllocator serverAllocator = allocator.newChildAllocator("flight-server", 0, Integer.MAX_VALUE); + final FlightServer server = FlightTestUtil.getStartedServer( + location -> FlightServer.builder(serverAllocator, location, + new DictionaryProducer(serverAllocator)).build())) { + final Location location = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort()); + try (final FlightClient client = FlightClient.builder(allocator, location).build()) { + try (final FlightStream stream = client.getStream(new Ticket(new byte[0]))) { + Assert.assertTrue(stream.next()); + Assert.assertFalse(stream.next()); + final DictionaryProvider provider = stream.takeDictionaryOwnership(); + Assertions.assertThrows(IllegalStateException.class, stream::takeDictionaryOwnership); + Assertions.assertThrows(IllegalStateException.class, stream::getDictionaryProvider); + DictionaryUtils.closeDictionaries(stream.getSchema(), provider); + } + } + } + } + + /** + * ARROW-5978: make sure that dictionaries can be used after closing the stream. + */ + @Ignore // Unfortunately this test is flaky in CI. 
+ @Test + public void useDictionariesAfterClose() throws Exception { + try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); + final BufferAllocator serverAllocator = allocator.newChildAllocator("flight-server", 0, Integer.MAX_VALUE); + final FlightServer server = FlightTestUtil.getStartedServer( + location -> FlightServer.builder(serverAllocator, location, new DictionaryProducer(serverAllocator)) + .build())) { + final Location location = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort()); + try (final FlightClient client = FlightClient.builder(allocator, location).build()) { + final VectorSchemaRoot root; + final DictionaryProvider provider; + try (final FlightStream stream = client.getStream(new Ticket(new byte[0]))) { + final VectorUnloader unloader = new VectorUnloader(stream.getRoot()); + root = VectorSchemaRoot.create(stream.getSchema(), allocator); + final VectorLoader loader = new VectorLoader(root); + while (stream.next()) { + try (final ArrowRecordBatch arb = unloader.getRecordBatch()) { + loader.load(arb); + } + } + provider = stream.takeDictionaryOwnership(); + } + try (final ValueVector decoded = DictionaryEncoder + .decode(root.getVector("encoded"), provider.lookup(1))) { + Assert.assertFalse(decoded.isNull(1)); + Assert.assertTrue(decoded instanceof VarCharVector); + Assert.assertArrayEquals("one".getBytes(StandardCharsets.UTF_8), ((VarCharVector) decoded).get(1)); + } + root.close(); + DictionaryUtils.closeDictionaries(root.getSchema(), provider); + } + } + } + + static class DictionaryProducer extends NoOpFlightProducer { + + private final BufferAllocator allocator; + + public DictionaryProducer(BufferAllocator allocator) { + this.allocator = allocator; + } + + @Override + public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { + final byte[] zero = "zero".getBytes(StandardCharsets.UTF_8); + final byte[] one = "one".getBytes(StandardCharsets.UTF_8); + final byte[] two = 
"two".getBytes(StandardCharsets.UTF_8); + try (final VarCharVector dictionaryVector = newVarCharVector("dictionary", allocator)) { + final DictionaryProvider.MapDictionaryProvider provider = new DictionaryProvider.MapDictionaryProvider(); + + dictionaryVector.allocateNew(512, 3); + dictionaryVector.setSafe(0, zero, 0, zero.length); + dictionaryVector.setSafe(1, one, 0, one.length); + dictionaryVector.setSafe(2, two, 0, two.length); + dictionaryVector.setValueCount(3); + + final Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + provider.put(dictionary); + + final FieldVector encodedVector; + try (final VarCharVector unencoded = newVarCharVector("encoded", allocator)) { + unencoded.allocateNewSafe(); + unencoded.set(1, one); + unencoded.set(2, two); + unencoded.set(3, zero); + unencoded.set(4, two); + unencoded.setValueCount(6); + encodedVector = (FieldVector) DictionaryEncoder.encode(unencoded, dictionary); + } + + final List<Field> fields = Collections.singletonList(encodedVector.getField()); + final List<FieldVector> vectors = Collections.singletonList(encodedVector); + + try (final VectorSchemaRoot root = new VectorSchemaRoot(fields, vectors, encodedVector.getValueCount())) { + listener.start(root, provider); + listener.putNext(); + listener.completed(); + } + } + } + + private static VarCharVector newVarCharVector(String name, BufferAllocator allocator) { + return (VarCharVector) + FieldType.nullable(new ArrowType.Utf8()).createNewSingleVector(name, allocator, null); + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java new file mode 100644 index 000000000..65ef12a8a --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) 
under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import static org.junit.jupiter.api.Assertions.fail; + +import org.apache.arrow.flight.impl.Flight; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import io.grpc.stub.ServerCallStreamObserver; + +public class TestFlightService { + + private BufferAllocator allocator; + + @Before + public void setup() { + allocator = new RootAllocator(Long.MAX_VALUE); + } + + @After + public void cleanup() throws Exception { + AutoCloseables.close(allocator); + } + + @Test + public void testFlightServiceWithNoAuthHandlerOrInterceptors() { + // This test is for ARROW-10491. There was a bug where FlightService would try to access the RequestContext, + // but the RequestContext was getting set to null because no interceptors were active to initialize it + // when using FlightService directly rather than starting up a FlightServer. 
+ + // Arrange + final FlightProducer producer = new NoOpFlightProducer() { + @Override + public void getStream(CallContext context, Ticket ticket, + ServerStreamListener listener) { + listener.completed(); + } + }; + + // This response observer notifies that the test failed if onError() is called. + final ServerCallStreamObserver<ArrowMessage> observer = new ServerCallStreamObserver<ArrowMessage>() { + @Override + public boolean isCancelled() { + return false; + } + + @Override + public void setOnCancelHandler(Runnable runnable) { + + } + + @Override + public void setCompression(String s) { + + } + + @Override + public boolean isReady() { + return false; + } + + @Override + public void setOnReadyHandler(Runnable runnable) { + + } + + @Override + public void disableAutoInboundFlowControl() { + + } + + @Override + public void request(int i) { + + } + + @Override + public void setMessageCompression(boolean b) { + + } + + @Override + public void onNext(ArrowMessage arrowMessage) { + + } + + @Override + public void onError(Throwable throwable) { + fail(throwable); + } + + @Override + public void onCompleted() { + + } + }; + final FlightService flightService = new FlightService(allocator, producer, null, null); + + // Act + flightService.doGetCustom(Flight.Ticket.newBuilder().build(), observer); + + // fail() would have been called if an error happened during doGetCustom(), so this test passed. + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLargeMessage.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLargeMessage.java new file mode 100644 index 000000000..629b6f5eb --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLargeMessage.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Stream; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Assert; +import org.junit.Test; +/** + * Tests that a Flight client and a Flight server accept large message payloads by default. + */ +public class TestLargeMessage { + /** + * Make sure a Flight client accepts large message payloads by default. 
+ */ + @Test + public void getLargeMessage() throws Exception { + try (final BufferAllocator a = new RootAllocator(Long.MAX_VALUE); + final Producer producer = new Producer(a); + final FlightServer s = + FlightTestUtil.getStartedServer((location) -> FlightServer.builder(a, location, producer).build())) { + + try (FlightClient client = FlightClient.builder(a, s.getLocation()).build()) { + try (FlightStream stream = client.getStream(new Ticket(new byte[]{})); + VectorSchemaRoot root = stream.getRoot()) { + while (stream.next()) { + for (final Field field : root.getSchema().getFields()) { + int value = 0; + final IntVector iv = (IntVector) root.getVector(field.getName()); + for (int i = 0; i < root.getRowCount(); i++) { + Assert.assertEquals(value, iv.get(i)); + value++; + } + } + } + } + } + } + } + + /** + * Make sure a Flight server accepts large message payloads by default. + */ + @Test + public void putLargeMessage() throws Exception { + try (final BufferAllocator a = new RootAllocator(Long.MAX_VALUE); + final Producer producer = new Producer(a); + final FlightServer s = + FlightTestUtil.getStartedServer((location) -> FlightServer.builder(a, location, producer).build() + )) { + + try (FlightClient client = FlightClient.builder(a, s.getLocation()).build(); + BufferAllocator testAllocator = a.newChildAllocator("testcase", 0, Long.MAX_VALUE); + VectorSchemaRoot root = generateData(testAllocator)) { + final FlightClient.ClientStreamListener listener = client.startPut(FlightDescriptor.path("hello"), root, + new AsyncPutListener()); + listener.putNext(); + listener.completed(); + listener.getResult(); + } + } + } + /** + * Builds a root of ten nullable signed 32-bit int columns named c1 to c10 with 128 * 1024 rows, where row i of + * every column holds the value i. The caller is responsible for closing the returned root. 
+ */ + private static VectorSchemaRoot generateData(BufferAllocator allocator) { + final int size = 128 * 1024; + final List<String> fieldNames = Arrays.asList("c1", "c2", "c3", "c4", "c5", "c6", "c7", "c8", "c9", "c10"); + final Stream<Field> fields = fieldNames + .stream() + .map(fieldName -> new Field(fieldName, FieldType.nullable(new ArrowType.Int(32, true)), 
null)); + final Schema schema = new Schema(fields::iterator, null); + + final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + root.allocateNew(); + for (final String fieldName : fieldNames) { + final IntVector iv = (IntVector) root.getVector(fieldName); + iv.setValueCount(size); + for (int i = 0; i < size; i++) { + iv.set(i, i); + } + } + root.setRowCount(size); + return root; + } + /** + * A producer that sends the generated data on DoGet and reads uploads to the end on DoPut. Closing it closes the + * allocator it was given. + */ + private static class Producer implements FlightProducer, AutoCloseable { + private final BufferAllocator allocator; + + Producer(BufferAllocator allocator) { + this.allocator = allocator; + } + + @Override + public void getStream(CallContext context, Ticket ticket, + ServerStreamListener listener) { + try (VectorSchemaRoot root = generateData(allocator)) { + listener.start(root); + listener.putNext(); + listener.completed(); + } + } + + @Override + public void listFlights(CallContext context, Criteria criteria, + StreamListener<FlightInfo> listener) { + + } + + @Override + public FlightInfo getFlightInfo(CallContext context, + FlightDescriptor descriptor) { + return null; + } + + @Override + public Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener<PutResult> ackStream) { + return () -> { + try (VectorSchemaRoot root = flightStream.getRoot()) { + while (flightStream.next()) { + ; + } + } + }; + } + + @Override + public void doAction(CallContext context, Action action, + StreamListener<Result> listener) { + listener.onCompleted(); + } + + @Override + public void listActions(CallContext context, + StreamListener<ActionType> listener) { + + } + + @Override + public void close() throws Exception { + allocator.close(); + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLeak.java 
b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLeak.java new file mode 100644 index 000000000..6e2870499 --- /dev/null +++ 
b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLeak.java @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.Arrays; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Test; + +/** + * Tests for scenarios where Flight could leak memory. 
+ */ +public class TestLeak { + + private static final int ROWS = 2048; + /** + * Returns a schema of eleven nullable double-precision floating point fields named 0 through 10. + */ + private static Schema getSchema() { + return new Schema(Arrays.asList( + Field.nullable("0", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), + Field.nullable("1", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), + Field.nullable("2", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), + Field.nullable("3", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), + Field.nullable("4", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), + Field.nullable("5", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), + Field.nullable("6", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), + Field.nullable("7", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), + Field.nullable("8", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), + Field.nullable("9", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), + Field.nullable("10", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)) + )); + } + + /** + * Ensure that if the client cancels, the server does not leak memory. + * + * <p>In gRPC, canceling the stream from the client sends an event to the server. Once processed, gRPC will start + * silently rejecting messages sent by the server. However, Flight depends on gRPC processing these messages in order + * to free the associated memory. 
+ */ + @Test + public void testCancelingDoGetDoesNotLeak() throws Exception { + final CountDownLatch callFinished = new CountDownLatch(1); + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final FlightServer s = + FlightTestUtil.getStartedServer( + (location) -> FlightServer.builder(allocator, location, new LeakFlightProducer(allocator, callFinished)) + .build()); + final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) { + + final FlightStream stream = client.getStream(new Ticket(new byte[0])); + stream.getRoot(); + stream.cancel("Cancel", null); + + // Wait for the call to finish. (Closing the allocator while a call is ongoing is a guaranteed leak.) + callFinished.await(60, TimeUnit.SECONDS); + + s.shutdown(); + s.awaitTermination(); + } + } + /** + * Ensure that writing to a DoPut stream does not block the client once the server has cancelled the call. + */ + @Test + public void testCancelingDoPutDoesNotBlock() throws Exception { + final CountDownLatch callFinished = new CountDownLatch(1); + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final FlightServer s = + FlightTestUtil.getStartedServer( + (location) -> FlightServer.builder(allocator, location, new LeakFlightProducer(allocator, callFinished)) + .build()); + final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) { + + try (final VectorSchemaRoot root = VectorSchemaRoot.create(getSchema(), allocator)) { + final FlightDescriptor descriptor = FlightDescriptor.command(new byte[0]); + final SyncPutListener listener = new SyncPutListener(); + final FlightClient.ClientStreamListener stream = client.startPut(descriptor, root, listener); + // Wait for the server to cancel + callFinished.await(60, TimeUnit.SECONDS); + + for (int col = 0; col < 11; col++) { + final Float8Vector vector = (Float8Vector) root.getVector(Integer.toString(col)); + vector.allocateNew(); + for (int row = 0; row < ROWS; row++) { + vector.setSafe(row, 10.); + } + } + root.setRowCount(ROWS); + // Unlike DoGet, this method fairly reliably will 
write the message to the stream, so even without the fix + // for ARROW-7343, this won't leak memory. + // However, it will block if FlightClient doesn't check for cancellation. + stream.putNext(); + stream.completed(); + } + + s.shutdown(); + s.awaitTermination(); + } + } + + /** + * A FlightProducer for the leak tests. On DoGet it starts the stream and then writes its only batch from the cancel + * handler; on DoPut it reports CANCELLED to the client. Both paths count down the given latch when they finish. + */ + private static class LeakFlightProducer extends NoOpFlightProducer { + + private final BufferAllocator allocator; + private final CountDownLatch callFinished; + + public LeakFlightProducer(BufferAllocator allocator, CountDownLatch callFinished) { + this.allocator = allocator; + this.callFinished = callFinished; + } + + @Override + public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { + BufferAllocator childAllocator = allocator.newChildAllocator("foo", 0, Long.MAX_VALUE); + VectorSchemaRoot root = VectorSchemaRoot.create(TestLeak.getSchema(), childAllocator); + root.allocateNew(); + listener.start(root); + + // We can't poll listener#isCancelled since gRPC has two distinct "is cancelled" flags. + // TODO: should we continue leaking gRPC semantics? Can we even avoid this? + listener.setOnCancelHandler(() -> { + try { + for (int col = 0; col < 11; col++) { + final Float8Vector vector = (Float8Vector) root.getVector(Integer.toString(col)); + vector.allocateNew(); + for (int row = 0; row < ROWS; row++) { + vector.setSafe(row, 10.); + } + } + root.setRowCount(ROWS); + // Once the call is "really cancelled" (the handler passed to setOnCancelHandler has run/is running), this call is actually a + // no-op on the gRPC side and will leak the ArrowMessage unless Flight checks for this. 
+ listener.putNext(); + listener.completed(); + } finally { + try { + root.close(); + childAllocator.close(); + } finally { + // Don't let the test hang if we throw above + callFinished.countDown(); + } + } + }); + } + + @Override + public Runnable acceptPut(CallContext context, + FlightStream flightStream, StreamListener<PutResult> ackStream) { + return () -> { + flightStream.getRoot(); + ackStream.onError(CallStatus.CANCELLED.withDescription("CANCELLED").toRuntimeException()); + callFinished.countDown(); + ackStream.onCompleted(); + }; + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java new file mode 100644 index 000000000..83a694bf3 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java @@ -0,0 +1,319 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Collections; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.message.IpcOption; +import org.apache.arrow.vector.types.MetadataVersion; +import org.apache.arrow.vector.types.UnionMode; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +/** + * Test clients/servers with different metadata versions. + */ +public class TestMetadataVersion { + private static BufferAllocator allocator; + private static Schema schema; + private static IpcOption optionV4; + private static IpcOption optionV5; + private static Schema unionSchema; + + @BeforeClass + public static void setUpClass() { + allocator = new RootAllocator(Integer.MAX_VALUE); + schema = new Schema(Collections.singletonList(Field.nullable("foo", new ArrowType.Int(32, true)))); + unionSchema = new Schema( + Collections.singletonList(Field.nullable("union", new ArrowType.Union(UnionMode.Dense, new int[]{0})))); + + // avoid writing legacy ipc format by default + optionV4 = new IpcOption(false, MetadataVersion.V4); + optionV5 = IpcOption.DEFAULT; + } + + @AfterClass + public static void tearDownClass() { + allocator.close(); + } + + @Test + public void testGetFlightInfoV4() throws Exception { + try (final FlightServer server = startServer(optionV4); + final FlightClient client = connect(server)) { + final FlightInfo result = 
client.getInfo(FlightDescriptor.command(new byte[0])); + assertEquals(schema, result.getSchema()); + } + } + + @Test + public void testGetSchemaV4() throws Exception { + try (final FlightServer server = startServer(optionV4); + final FlightClient client = connect(server)) { + final SchemaResult result = client.getSchema(FlightDescriptor.command(new byte[0])); + assertEquals(schema, result.getSchema()); + } + } + /** + * Checks that a union schema is rejected with V4 metadata when building a SchemaResult or FlightInfo, when a server + * starts a DoGet stream, and when a client starts a DoPut stream. + */ + @Test + public void testUnionCheck() throws Exception { + assertThrows(IllegalArgumentException.class, () -> new SchemaResult(unionSchema, optionV4)); + assertThrows(IllegalArgumentException.class, () -> + new FlightInfo(unionSchema, FlightDescriptor.command(new byte[0]), Collections.emptyList(), -1, -1, optionV4)); + try (final FlightServer server = startServer(optionV4); + final FlightClient client = connect(server); + final FlightStream stream = client.getStream(new Ticket("union".getBytes(StandardCharsets.UTF_8)))) { + final FlightRuntimeException err = assertThrows(FlightRuntimeException.class, stream::next); + assertTrue(err.getMessage(), err.getMessage().contains("Cannot write union with V4 metadata")); + } + + try (final FlightServer server = startServer(optionV4); + final FlightClient client = connect(server); + final VectorSchemaRoot root = VectorSchemaRoot.create(unionSchema, allocator)) { + final FlightDescriptor descriptor = FlightDescriptor.command(new byte[0]); + final SyncPutListener reader = new SyncPutListener(); + final FlightClient.ClientStreamListener listener = client.startPut(descriptor, reader); + final IllegalArgumentException err = assertThrows(IllegalArgumentException.class, + () -> listener.start(root, null, optionV4)); + assertTrue(err.getMessage(), err.getMessage().contains("Cannot write union with V4 metadata")); + } + } + + @Test + public void testPutV4() throws Exception { + try (final FlightServer server = startServer(optionV4); + final FlightClient client = connect(server); + final VectorSchemaRoot root = 
VectorSchemaRoot.create(schema, allocator)) { + generateData(root); + final FlightDescriptor descriptor = FlightDescriptor.command(new byte[0]); + final SyncPutListener reader = new SyncPutListener(); + final FlightClient.ClientStreamListener listener = client.startPut(descriptor, reader); + listener.start(root, null, optionV4); + listener.putNext(); + listener.completed(); + listener.getResult(); + } + } + + @Test + public void testGetV4() throws Exception { + try (final FlightServer server = startServer(optionV4); + final FlightClient client = connect(server); + final FlightStream stream = client.getStream(new Ticket(new byte[0]))) { + assertTrue(stream.next()); + assertEquals(optionV4.metadataVersion, stream.metadataVersion); + validateRoot(stream.getRoot()); + assertFalse(stream.next()); + } + } + + @Test + public void testExchangeV4ToV5() throws Exception { + try (final FlightServer server = startServer(optionV5); + final FlightClient client = connect(server); + final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + final FlightClient.ExchangeReaderWriter stream = client.doExchange(FlightDescriptor.command(new byte[0]))) { + stream.getWriter().start(root, null, optionV4); + generateData(root); + stream.getWriter().putNext(); + stream.getWriter().completed(); + assertTrue(stream.getReader().next()); + assertEquals(optionV5.metadataVersion, stream.getReader().metadataVersion); + validateRoot(stream.getReader().getRoot()); + assertFalse(stream.getReader().next()); + } + } + + @Test + public void testExchangeV5ToV4() throws Exception { + try (final FlightServer server = startServer(optionV4); + final FlightClient client = connect(server); + final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + final FlightClient.ExchangeReaderWriter stream = client.doExchange(FlightDescriptor.command(new byte[0]))) { + stream.getWriter().start(root, null, optionV5); + generateData(root); + stream.getWriter().putNext(); + 
stream.getWriter().completed(); + assertTrue(stream.getReader().next()); + assertEquals(optionV4.metadataVersion, stream.getReader().metadataVersion); + validateRoot(stream.getReader().getRoot()); + assertFalse(stream.getReader().next()); + } + } + + @Test + public void testExchangeV4ToV4() throws Exception { + try (final FlightServer server = startServer(optionV4); + final FlightClient client = connect(server); + final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + final FlightClient.ExchangeReaderWriter stream = client.doExchange(FlightDescriptor.command(new byte[0]))) { + stream.getWriter().start(root, null, optionV4); + generateData(root); + stream.getWriter().putNext(); + stream.getWriter().completed(); + assertTrue(stream.getReader().next()); + assertEquals(optionV4.metadataVersion, stream.getReader().metadataVersion); + validateRoot(stream.getReader().getRoot()); + assertFalse(stream.getReader().next()); + } + } + /** + * Writes the values 0, 1 and 4 into the foo column and sets the row count to 3. + */ + private static void generateData(VectorSchemaRoot root) { + assertEquals(schema, root.getSchema()); + final IntVector vector = (IntVector) root.getVector("foo"); + vector.setSafe(0, 0); + vector.setSafe(1, 1); + vector.setSafe(2, 4); + root.setRowCount(3); + } + /** + * Asserts that the root has the test schema and holds exactly the three rows written by generateData. + */ + private static void validateRoot(VectorSchemaRoot root) { + assertEquals(schema, root.getSchema()); + assertEquals(3, root.getRowCount()); + final IntVector vector = (IntVector) root.getVector("foo"); + assertEquals(0, vector.get(0)); + assertEquals(1, vector.get(1)); + assertEquals(4, vector.get(2)); + } + /** + * Builds and starts a server at localhost port 0 whose producer uses the given IPC option. 
+ */ + FlightServer startServer(IpcOption option) throws Exception { + Location location = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, 0); + VersionFlightProducer producer = new VersionFlightProducer(allocator, option); + final FlightServer server = FlightServer.builder(allocator, location, producer).build(); + server.start(); + return server; + } + /** + * Creates a client for the given server, addressed by localhost and the port the server reports. + */ + FlightClient connect(FlightServer server) { + Location location = 
Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort()); + return FlightClient.builder(allocator, location).build(); + } + /** + * A producer that uses one fixed IPC option for everything it sends and, on DoPut, asserts that the uploaded stream + * reports the same metadata version. + */ + static final class VersionFlightProducer extends NoOpFlightProducer { + private final BufferAllocator allocator; + private final IpcOption option; + + VersionFlightProducer(BufferAllocator allocator, IpcOption option) { + this.allocator = allocator; + this.option = option; + } + + @Override + public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { + return new FlightInfo(schema, descriptor, Collections.emptyList(), -1, -1, option); + } + + @Override + public SchemaResult getSchema(CallContext context, FlightDescriptor descriptor) { + return new SchemaResult(schema, option); + } + + @Override + public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { + if (Arrays.equals("union".getBytes(StandardCharsets.UTF_8), ticket.getBytes())) { + try (final VectorSchemaRoot root = VectorSchemaRoot.create(unionSchema, allocator)) { + listener.start(root, null, option); + } catch (IllegalArgumentException e) { + listener.error(CallStatus.INTERNAL.withCause(e).withDescription(e.getMessage()).toRuntimeException()); + return; + } + listener.error(CallStatus.INTERNAL.withDescription("Expected exception not raised").toRuntimeException()); + return; + } + try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + listener.start(root, null, option); + generateData(root); + listener.putNext(); + listener.completed(); + } + } + + @Override + public Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener<PutResult> ackStream) { + return () -> { + try { + assertTrue(flightStream.next()); + assertEquals(option.metadataVersion, flightStream.metadataVersion); + validateRoot(flightStream.getRoot()); + } catch (AssertionError err) { + // gRPC doesn't propagate stack traces across the wire. 
+ err.printStackTrace(); + ackStream.onError(CallStatus.INVALID_ARGUMENT + .withCause(err) + .withDescription("Server assertion failed: " + err) + .toRuntimeException()); + return; + } catch (RuntimeException err) { + err.printStackTrace(); + ackStream.onError(CallStatus.INTERNAL + .withCause(err) + .withDescription("Server assertion failed: " + err) + .toRuntimeException()); + return; + } + ackStream.onCompleted(); + }; + } + + @Override + public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) { + try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + try { + assertTrue(reader.next()); + validateRoot(reader.getRoot()); + assertFalse(reader.next()); + } catch (AssertionError err) { + // gRPC doesn't propagate stack traces across the wire. + err.printStackTrace(); + writer.error(CallStatus.INVALID_ARGUMENT + .withCause(err) + .withDescription("Server assertion failed: " + err) + .toRuntimeException()); + return; + } catch (RuntimeException err) { + err.printStackTrace(); + writer.error(CallStatus.INTERNAL + .withCause(err) + .withDescription("Server assertion failed: " + err) + .toRuntimeException()); + return; + } + + writer.start(root, null, option); + generateData(root); + writer.putNext(); + writer.completed(); + } + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerMiddleware.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerMiddleware.java new file mode 100644 index 000000000..1f3e35ca3 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerMiddleware.java @@ -0,0 +1,360 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.BiConsumer; + +import org.apache.arrow.flight.FlightClient.ClientStreamListener; +import org.apache.arrow.flight.FlightServerMiddleware.Factory; +import org.apache.arrow.flight.FlightServerMiddleware.Key; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class TestServerMiddleware { + + private static final RuntimeException EXPECTED_EXCEPTION = new RuntimeException("test"); + + /** + * Make sure errors in DoPut are intercepted. 
+ */ + @Test + public void doPutErrors() { + test( + new ErrorProducer(EXPECTED_EXCEPTION), + (allocator, client) -> { + final FlightDescriptor descriptor = FlightDescriptor.path("test"); + try (final VectorSchemaRoot root = VectorSchemaRoot.create(new Schema(Collections.emptyList()), allocator)) { + final ClientStreamListener listener = client.startPut(descriptor, root, new SyncPutListener()); + listener.completed(); + FlightTestUtil.assertCode(FlightStatusCode.INTERNAL, listener::getResult); + } + }, (recorder) -> { + final CallStatus status = recorder.statusFuture.get(); + Assert.assertNotNull(status); + Assert.assertNotNull(status.cause()); + Assert.assertEquals(FlightStatusCode.INTERNAL, status.code()); + }); + // Check the status after server shutdown (to make sure gRPC finishes pending calls on the server side) + } + + /** + * Make sure custom error codes in DoPut are intercepted: the client sees UNAVAILABLE, and the middleware records + * UNAVAILABLE with the original description and no cause. + */ + @Test + public void doPutCustomCode() { + test( + new ErrorProducer(CallStatus.UNAVAILABLE.withDescription("description").toRuntimeException()), + (allocator, client) -> { + final FlightDescriptor descriptor = FlightDescriptor.path("test"); + try (final VectorSchemaRoot root = VectorSchemaRoot.create(new Schema(Collections.emptyList()), allocator)) { + final ClientStreamListener listener = client.startPut(descriptor, root, new SyncPutListener()); + listener.completed(); + FlightTestUtil.assertCode(FlightStatusCode.UNAVAILABLE, listener::getResult); + } + }, (recorder) -> { + final CallStatus status = recorder.statusFuture.get(); + Assert.assertNotNull(status); + Assert.assertNull(status.cause()); + Assert.assertEquals(FlightStatusCode.UNAVAILABLE, status.code()); + Assert.assertEquals("description", status.description()); + }); + } + + /** + * Make sure uncaught exceptions in DoPut are intercepted. 
+ */ + @Test + public void doPutUncaught() { + test(new ServerErrorProducer(EXPECTED_EXCEPTION), + (allocator, client) -> { + final FlightDescriptor descriptor = FlightDescriptor.path("test"); + try (final VectorSchemaRoot root = VectorSchemaRoot.create(new Schema(Collections.emptyList()), allocator)) { + final ClientStreamListener listener = client.startPut(descriptor, root, new SyncPutListener()); + listener.completed(); + listener.getResult(); + } + }, (recorder) -> { + final CallStatus status = recorder.statusFuture.get(); + final Throwable err = recorder.errFuture.get(); + Assert.assertNotNull(status); + Assert.assertEquals(FlightStatusCode.OK, status.code()); + Assert.assertNull(status.cause()); + Assert.assertNotNull(err); + Assert.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage()); + }); + } + /** + * Make sure an exception thrown by ListFlights after the call completed reaches the middleware while the recorded + * status stays OK. + */ + @Test + public void listFlightsUncaught() { + test(new ServerErrorProducer(EXPECTED_EXCEPTION), + (allocator, client) -> client.listFlights(new Criteria(new byte[0])).forEach((action) -> { + }), (recorder) -> { + final CallStatus status = recorder.statusFuture.get(); + final Throwable err = recorder.errFuture.get(); + Assert.assertNotNull(status); + Assert.assertEquals(FlightStatusCode.OK, status.code()); + Assert.assertNull(status.cause()); + Assert.assertNotNull(err); + Assert.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage()); + }); + } + /** + * Make sure an exception thrown by DoAction after the call completed reaches the middleware while the recorded + * status stays OK. 
+ */ + @Test + public void doActionUncaught() { + test(new ServerErrorProducer(EXPECTED_EXCEPTION), + (allocator, client) -> client.doAction(new Action("test")).forEachRemaining(result -> { + }), (recorder) -> { + final CallStatus status = recorder.statusFuture.get(); + final Throwable err = recorder.errFuture.get(); + Assert.assertNotNull(status); + Assert.assertEquals(FlightStatusCode.OK, status.code()); + Assert.assertNull(status.cause()); + Assert.assertNotNull(err); + Assert.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage()); + }); + } + /** + * Make sure an exception thrown by ListActions after the call completed reaches the middleware while the recorded + * status stays OK. + */ + @Test + public void listActionsUncaught() { + 
test(new ServerErrorProducer(EXPECTED_EXCEPTION), + (allocator, client) -> client.listActions().forEach(result -> { + }), (recorder) -> { + final CallStatus status = recorder.statusFuture.get(); + final Throwable err = recorder.errFuture.get(); + Assert.assertNotNull(status); + Assert.assertEquals(FlightStatusCode.OK, status.code()); + Assert.assertNull(status.cause()); + Assert.assertNotNull(err); + Assert.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage()); + }); + } + /** + * Make sure an exception thrown by GetFlightInfo is seen as INTERNAL by the client and recorded by the middleware + * as INTERNAL with a cause carrying the expected message. + */ + @Test + public void getFlightInfoUncaught() { + test(new ServerErrorProducer(EXPECTED_EXCEPTION), + (allocator, client) -> { + FlightTestUtil.assertCode(FlightStatusCode.INTERNAL, () -> client.getInfo(FlightDescriptor.path("test"))); + }, (recorder) -> { + final CallStatus status = recorder.statusFuture.get(); + Assert.assertNotNull(status); + Assert.assertEquals(FlightStatusCode.INTERNAL, status.code()); + Assert.assertNotNull(status.cause()); + Assert.assertEquals(EXPECTED_EXCEPTION.getMessage(), status.cause().getMessage()); + }); + } + /** + * Make sure an exception thrown by DoGet after the stream completed reaches the middleware while the recorded + * status stays OK. 
+ */ + @Test + public void doGetUncaught() { + test(new ServerErrorProducer(EXPECTED_EXCEPTION), + (allocator, client) -> { + try (final FlightStream stream = client.getStream(new Ticket(new byte[0]))) { + while (stream.next()) { + } + } catch (Exception e) { + Assert.fail(e.toString()); + } + }, (recorder) -> { + final CallStatus status = recorder.statusFuture.get(); + final Throwable err = recorder.errFuture.get(); + Assert.assertNotNull(status); + Assert.assertEquals(FlightStatusCode.OK, status.code()); + Assert.assertNull(status.cause()); + Assert.assertNotNull(err); + Assert.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage()); + }); + } + + /** + * A middleware that records the status and the error reported for a call. 
+ */ + static class ErrorRecorder implements FlightServerMiddleware { + + CompletableFuture<CallStatus> statusFuture = new CompletableFuture<>(); + CompletableFuture<Throwable> errFuture = new CompletableFuture<>(); + + @Override + public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { + } + + @Override + public void onCallCompleted(CallStatus status) { + statusFuture.complete(status); + } + + @Override + public void onCallErrored(Throwable err) { + errFuture.complete(err); + } + + static class Factory implements FlightServerMiddleware.Factory<ErrorRecorder> { + + ErrorRecorder instance = new ErrorRecorder(); + + @Override + public ErrorRecorder onCallStarted(CallInfo info, CallHeaders incomingHeaders, RequestContext context) { + return instance; + } + } + } + + /** + * A producer that throws the given exception on a call. + */ + static class ErrorProducer extends NoOpFlightProducer { + + final RuntimeException error; + + ErrorProducer(RuntimeException t) { + error = t; + } + + @Override + public Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener<PutResult> ackStream) { + return () -> { + // Drain queue to avoid FlightStream#close cancelling the call + while (flightStream.next()) { + } + throw error; + }; + } + } + + /** + * A producer that throws the given exception on a call, but only after sending a success to the client. 
+ */ + static class ServerErrorProducer extends NoOpFlightProducer { + + final RuntimeException error; + + ServerErrorProducer(RuntimeException t) { + error = t; + } + + @Override + public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { + try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); + final VectorSchemaRoot root = VectorSchemaRoot.create(new Schema(Collections.emptyList()), allocator)) { + listener.start(root); + listener.completed(); + } + throw error; + } + + @Override + public void listFlights(CallContext context, Criteria criteria, StreamListener<FlightInfo> listener) { + listener.onCompleted(); + throw error; + } + + @Override + public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { + throw error; + } + + @Override + public Runnable acceptPut(CallContext context, FlightStream flightStream, StreamListener<PutResult> ackStream) { + return () -> { + while (flightStream.next()) { + } + ackStream.onCompleted(); + throw error; + }; + } + + @Override + public void doAction(CallContext context, Action action, StreamListener<Result> listener) { + listener.onCompleted(); + throw error; + } + + @Override + public void listActions(CallContext context, StreamListener<ActionType> listener) { + listener.onCompleted(); + throw error; + } + } + + static class ServerMiddlewarePair<T extends FlightServerMiddleware> { + + final FlightServerMiddleware.Key<T> key; + final FlightServerMiddleware.Factory<T> factory; + + ServerMiddlewarePair(Key<T> key, Factory<T> factory) { + this.key = key; + this.factory = factory; + } + } + + /** + * Spin up a service with the given middleware and producer. + * + * @param producer The Flight producer to use. + * @param middleware A list of middleware to register. + * @param body A function to run as the body of the test. + * @param <T> The middleware type. 
+ */ + static <T extends FlightServerMiddleware> void test(FlightProducer producer, List<ServerMiddlewarePair<T>> middleware, + BiConsumer<BufferAllocator, FlightClient> body) { + try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) { + final FlightServer server = FlightTestUtil + .getStartedServer(location -> { + final FlightServer.Builder builder = FlightServer.builder(allocator, location, producer); + middleware.forEach(pair -> builder.middleware(pair.key, pair.factory)); + return builder.build(); + }); + try (final FlightServer ignored = server; + final FlightClient client = FlightClient.builder(allocator, server.getLocation()).build() + ) { + body.accept(allocator, client); + } + } catch (InterruptedException | IOException e) { + throw new RuntimeException(e); + } + } + + static void test(FlightProducer producer, BiConsumer<BufferAllocator, FlightClient> body, + ErrorConsumer<ErrorRecorder> verify) { + final ErrorRecorder.Factory factory = new ErrorRecorder.Factory(); + final List<ServerMiddlewarePair<ErrorRecorder>> middleware = Collections + .singletonList(new ServerMiddlewarePair<>(Key.of("m"), factory)); + test(producer, middleware, (allocator, client) -> { + body.accept(allocator, client); + try { + verify.accept(factory.instance); + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + } + + @FunctionalInterface + interface ErrorConsumer<T> { + void accept(T obj) throws Exception; + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerOptions.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerOptions.java new file mode 100644 index 000000000..363ad443e --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerOptions.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +import java.io.File; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; + +import org.apache.arrow.flight.TestBasicOperation.Producer; +import org.apache.arrow.flight.auth.ServerAuthHandler; +import org.apache.arrow.flight.impl.FlightServiceGrpc; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import io.grpc.MethodDescriptor; +import io.grpc.ServerServiceDefinition; +import io.grpc.netty.NettyServerBuilder; + +@RunWith(JUnit4.class) +public class TestServerOptions { + + @Test + public void builderConsumer() throws Exception { + final AtomicBoolean consumerCalled = new AtomicBoolean(); + final Consumer<NettyServerBuilder> consumer = (builder) -> consumerCalled.set(true); + + try ( + 
BufferAllocator a = new RootAllocator(Long.MAX_VALUE); + Producer producer = new Producer(a); + FlightServer s = + FlightTestUtil.getStartedServer( + (location) -> FlightServer.builder(a, location, producer) + .transportHint("grpc.builderConsumer", consumer).build() + )) { + Assert.assertTrue(consumerCalled.get()); + } + } + + /** + * Make sure that if Flight supplies a default executor to gRPC, then it is closed along with the server. + */ + @Test + public void defaultExecutorClosed() throws Exception { + final ExecutorService executor; + try ( + BufferAllocator a = new RootAllocator(Long.MAX_VALUE); + FlightServer server = + FlightTestUtil.getStartedServer( + (location) -> FlightServer.builder(a, location, new NoOpFlightProducer()) + .build() + )) { + assertNotNull(server.grpcExecutor); + executor = server.grpcExecutor; + } + Assert.assertTrue(executor.isShutdown()); + } + + /** + * Make sure that if the user provides an executor to gRPC, then Flight does not close it. + */ + @Test + public void suppliedExecutorNotClosed() throws Exception { + final ExecutorService executor = Executors.newSingleThreadExecutor(); + try { + try ( + BufferAllocator a = new RootAllocator(Long.MAX_VALUE); + FlightServer server = + FlightTestUtil.getStartedServer( + (location) -> FlightServer.builder(a, location, new NoOpFlightProducer()) + .executor(executor) + .build() + )) { + Assert.assertNull(server.grpcExecutor); + } + Assert.assertFalse(executor.isShutdown()); + } finally { + executor.shutdown(); + } + } + + @Test + public void domainSocket() throws Exception { + Assume.assumeTrue("We have a native transport available", FlightTestUtil.isNativeTransportAvailable()); + final File domainSocket = File.createTempFile("flight-unit-test-", ".sock"); + Assert.assertTrue(domainSocket.delete()); + // Domain socket paths have a platform-dependent limit. Set a conservative limit and skip the test if the temporary + // file name is too long. 
(We do not assume a particular platform-dependent temporary directory path.) + Assume.assumeTrue("The domain socket path is not too long", domainSocket.getAbsolutePath().length() < 100); + final Location location = Location.forGrpcDomainSocket(domainSocket.getAbsolutePath()); + try ( + BufferAllocator a = new RootAllocator(Long.MAX_VALUE); + Producer producer = new Producer(a); + FlightServer s = + FlightTestUtil.getStartedServer( + (port) -> FlightServer.builder(a, location, producer).build() + )) { + try (FlightClient c = FlightClient.builder(a, location).build()) { + try (FlightStream stream = c.getStream(new Ticket(new byte[0]))) { + VectorSchemaRoot root = stream.getRoot(); + IntVector iv = (IntVector) root.getVector("c1"); + int value = 0; + while (stream.next()) { + for (int i = 0; i < root.getRowCount(); i++) { + Assert.assertEquals(value, iv.get(i)); + value++; + } + } + } + } + } + } + + @Test + public void checkReflectionMetadata() { + // This metadata is needed for gRPC reflection to work. 
+ final ExecutorService executorService = Executors.newSingleThreadExecutor(); + try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) { + final FlightBindingService service = new FlightBindingService(allocator, new NoOpFlightProducer(), + ServerAuthHandler.NO_OP, executorService); + final ServerServiceDefinition definition = service.bindService(); + assertEquals(FlightServiceGrpc.getServiceDescriptor().getSchemaDescriptor(), + definition.getServiceDescriptor().getSchemaDescriptor()); + + // definedMethods: descriptors taken from the bound ServerServiceDefinition's method definitions. + // serviceMethods: descriptors taken from that definition's ServiceDescriptor. + final Map<String, MethodDescriptor<?, ?>> definedMethods = new HashMap<>(); + final Map<String, MethodDescriptor<?, ?>> serviceMethods = new HashMap<>(); + + // Make sure that the reflection metadata object is identical across all the places where it's accessible + definition.getMethods().forEach( + method -> definedMethods.put(method.getMethodDescriptor().getFullMethodName(), method.getMethodDescriptor())); + definition.getServiceDescriptor().getMethods().forEach( + method -> serviceMethods.put(method.getFullMethodName(), method)); + + for (final MethodDescriptor<?, ?> descriptor : FlightServiceGrpc.getServiceDescriptor().getMethods()) { + final String methodName = descriptor.getFullMethodName(); + Assert.assertTrue("Method is missing from ServerServiceDefinition: " + methodName, + definedMethods.containsKey(methodName)); + // Check the ServiceDescriptor-derived map here so that a missing entry fails with this message + // instead of a NullPointerException on serviceMethods.get(...) below. + Assert.assertTrue("Method is missing from ServiceDescriptor: " + methodName, + serviceMethods.containsKey(methodName)); + + assertEquals(descriptor.getSchemaDescriptor(), definedMethods.get(methodName).getSchemaDescriptor()); + assertEquals(descriptor.getSchemaDescriptor(), serviceMethods.get(methodName).getSchemaDescriptor()); + } + } finally { + executorService.shutdown(); + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestTls.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestTls.java new file mode 100644 index 000000000..c5cd871e2 --- /dev/null +++
b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestTls.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.Iterator; +import java.util.function.Consumer; + +import org.apache.arrow.flight.FlightClient.Builder; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.junit.Assert; +import org.junit.Test; + +/** + * Tests for TLS in Flight. + */ +public class TestTls { + + /** + * Test a basic request over TLS. 
+ */ + @Test + public void connectTls() { + test((builder) -> { + try (final InputStream roots = new FileInputStream(FlightTestUtil.exampleTlsRootCert().toFile()); + final FlightClient client = builder.trustedCertificates(roots).build()) { + final Iterator<Result> responses = client.doAction(new Action("hello-world")); + final byte[] response = responses.next().getBody(); + Assert.assertEquals("Hello, world!", new String(response, StandardCharsets.UTF_8)); + Assert.assertFalse(responses.hasNext()); + } catch (InterruptedException | IOException e) { + throw new RuntimeException(e); + } + }); + } + + /** + * Make sure that connections are rejected when the root certificate isn't trusted. + */ + @Test + public void rejectInvalidCert() { + test((builder) -> { + try (final FlightClient client = builder.build()) { + final Iterator<Result> responses = client.doAction(new Action("hello-world")); + FlightTestUtil.assertCode(FlightStatusCode.UNAVAILABLE, () -> responses.next().getBody()); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + } + + /** + * Make sure that connections are rejected when the hostname doesn't match. + */ + @Test + public void rejectHostname() { + test((builder) -> { + try (final InputStream roots = new FileInputStream(FlightTestUtil.exampleTlsRootCert().toFile()); + final FlightClient client = builder.trustedCertificates(roots).overrideHostname("fakehostname") + .build()) { + final Iterator<Result> responses = client.doAction(new Action("hello-world")); + FlightTestUtil.assertCode(FlightStatusCode.UNAVAILABLE, () -> responses.next().getBody()); + } catch (InterruptedException | IOException e) { + throw new RuntimeException(e); + } + }); + } + + /** + * Test a basic request over TLS with server certificate verification disabled.
+ */ + @Test + public void connectTlsDisableServerVerification() { + test((builder) -> { + try (final FlightClient client = builder.verifyServer(false).build()) { + final Iterator<Result> responses = client.doAction(new Action("hello-world")); + final byte[] response = responses.next().getBody(); + Assert.assertEquals("Hello, world!", new String(response, StandardCharsets.UTF_8)); + Assert.assertFalse(responses.hasNext()); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + } + + void test(Consumer<Builder> testFn) { + final FlightTestUtil.CertKeyPair certKey = FlightTestUtil.exampleTlsCerts().get(0); + try ( + BufferAllocator a = new RootAllocator(Long.MAX_VALUE); + Producer producer = new Producer(); + FlightServer s = + FlightTestUtil.getStartedServer( + (location) -> { + try { + return FlightServer.builder(a, location, producer) + .useTls(certKey.cert, certKey.key) + .build(); + } catch (IOException e) { + throw new RuntimeException(e); + } + })) { + final Builder builder = FlightClient.builder(a, Location.forGrpcTls(FlightTestUtil.LOCALHOST, s.getPort())); + testFn.accept(builder); + } catch (InterruptedException | IOException e) { + throw new RuntimeException(e); + } + } + + static class Producer extends NoOpFlightProducer implements AutoCloseable { + + @Override + public void doAction(CallContext context, Action action, StreamListener<Result> listener) { + if (action.getType().equals("hello-world")) { + listener.onNext(new Result("Hello, world!".getBytes(StandardCharsets.UTF_8))); + listener.onCompleted(); + return; + } + listener + .onError(CallStatus.UNIMPLEMENTED.withDescription("Invalid action " + action.getType()).toRuntimeException()); + } + + @Override + public void close() { + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java new file mode 100644 index 
000000000..c18f5709b --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.auth; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Optional; + +import org.apache.arrow.flight.Criteria; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightStatusCode; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.FlightTestUtil; +import org.apache.arrow.flight.NoOpFlightProducer; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Ignore; 
+import org.junit.Test; + +import com.google.common.collect.ImmutableList; + +public class TestBasicAuth { + + private static final String USERNAME = "flight"; + private static final String PASSWORD = "woohoo"; + private static final byte[] VALID_TOKEN = "my_token".getBytes(StandardCharsets.UTF_8); + + private FlightClient client; + private FlightServer server; + private BufferAllocator allocator; + + @Test + public void validAuth() { + client.authenticateBasic(USERNAME, PASSWORD); + Assert.assertTrue(ImmutableList.copyOf(client.listFlights(Criteria.ALL)).size() == 0); + } + + // ARROW-7722: this test occasionally leaks memory + @Ignore + @Test + public void asyncCall() throws Exception { + client.authenticateBasic(USERNAME, PASSWORD); + client.listFlights(Criteria.ALL); + try (final FlightStream s = client.getStream(new Ticket(new byte[1]))) { + while (s.next()) { + Assert.assertEquals(4095, s.getRoot().getRowCount()); + } + } + } + + @Test + public void invalidAuth() { + FlightTestUtil.assertCode(FlightStatusCode.UNAUTHENTICATED, () -> { + client.authenticateBasic(USERNAME, "WRONG"); + }); + + FlightTestUtil.assertCode(FlightStatusCode.UNAUTHENTICATED, () -> { + client.listFlights(Criteria.ALL).forEach(action -> Assert.fail()); + }); + } + + @Test + public void didntAuth() { + FlightTestUtil.assertCode(FlightStatusCode.UNAUTHENTICATED, () -> { + client.listFlights(Criteria.ALL).forEach(action -> Assert.fail()); + }); + } + + @Before + public void setup() throws IOException { + allocator = new RootAllocator(Long.MAX_VALUE); + final BasicServerAuthHandler.BasicAuthValidator validator = new BasicServerAuthHandler.BasicAuthValidator() { + + @Override + public Optional<String> isValid(byte[] token) { + if (Arrays.equals(token, VALID_TOKEN)) { + return Optional.of(USERNAME); + } + return Optional.empty(); + } + + @Override + public byte[] getToken(String username, String password) { + if (USERNAME.equals(username) && PASSWORD.equals(password)) { + return VALID_TOKEN; + 
} else { + throw new IllegalArgumentException("invalid credentials"); + } + } + }; + + server = FlightTestUtil.getStartedServer((location) -> FlightServer.builder( + allocator, + location, + new NoOpFlightProducer() { + @Override + public void listFlights(CallContext context, Criteria criteria, + StreamListener<FlightInfo> listener) { + if (!context.peerIdentity().equals(USERNAME)) { + listener.onError(new IllegalArgumentException("Invalid username")); + return; + } + listener.onCompleted(); + } + + @Override + public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { + if (!context.peerIdentity().equals(USERNAME)) { + listener.error(new IllegalArgumentException("Invalid username")); + return; + } + final Schema pojoSchema = new Schema(ImmutableList.of(Field.nullable("a", + Types.MinorType.BIGINT.getType()))); + try (VectorSchemaRoot root = VectorSchemaRoot.create(pojoSchema, allocator)) { + listener.start(root); + root.allocateNew(); + root.setRowCount(4095); + listener.putNext(); + listener.completed(); + } + } + }).authHandler(new BasicServerAuthHandler(validator)).build()); + client = FlightClient.builder(allocator, server.getLocation()).build(); + } + + @After + public void shutdown() throws Exception { + AutoCloseables.close(client, server, allocator); + } + +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/auth2/TestBasicAuth2.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/auth2/TestBasicAuth2.java new file mode 100644 index 000000000..9bec32f1b --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/auth2/TestBasicAuth2.java @@ -0,0 +1,232 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.auth2; + +import java.io.IOException; + +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.Criteria; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightStatusCode; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.FlightTestUtil; +import org.apache.arrow.flight.NoOpFlightProducer; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.grpc.CredentialCallOption; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; + +import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; + +public class TestBasicAuth2 { + + private static final String USERNAME_1 = "flight1"; + private static final String USERNAME_2 = "flight2"; + private static final String NO_USERNAME = ""; + private static 
final String PASSWORD_1 = "woohoo1"; + private static final String PASSWORD_2 = "woohoo2"; + private BufferAllocator allocator; + private FlightServer server; + private FlightClient client; + private FlightClient client2; + + @Before + public void setup() throws Exception { + allocator = new RootAllocator(Long.MAX_VALUE); + startServerAndClient(); + } + + private FlightProducer getFlightProducer() { + return new NoOpFlightProducer() { + @Override + public void listFlights(CallContext context, Criteria criteria, + StreamListener<FlightInfo> listener) { + if (!context.peerIdentity().equals(USERNAME_1) && !context.peerIdentity().equals(USERNAME_2)) { + listener.onError(new IllegalArgumentException("Invalid username")); + return; + } + listener.onCompleted(); + } + + @Override + public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { + if (!context.peerIdentity().equals(USERNAME_1) && !context.peerIdentity().equals(USERNAME_2)) { + listener.error(new IllegalArgumentException("Invalid username")); + return; + } + final Schema pojoSchema = new Schema(ImmutableList.of(Field.nullable("a", + Types.MinorType.BIGINT.getType()))); + try (VectorSchemaRoot root = VectorSchemaRoot.create(pojoSchema, allocator)) { + listener.start(root); + root.allocateNew(); + root.setRowCount(4095); + listener.putNext(); + listener.completed(); + } + } + }; + } + + private void startServerAndClient() throws IOException { + final FlightProducer flightProducer = getFlightProducer(); + this.server = FlightTestUtil.getStartedServer((location) -> FlightServer + .builder(allocator, location, flightProducer) + .headerAuthenticator(new GeneratedBearerTokenAuthenticator( + new BasicCallHeaderAuthenticator(this::validate))) + .build()); + + this.client = FlightClient.builder(allocator, server.getLocation()) + .build(); + } + + @After + public void shutdown() throws Exception { + AutoCloseables.close(client, client2, server, allocator); + client = null; + client2 = null; 
+ server = null; + allocator = null; + } + + private void startClient2() throws IOException { + client2 = FlightClient.builder(allocator, server.getLocation()) + .build(); + } + + private CallHeaderAuthenticator.AuthResult validate(String username, String password) { + if (Strings.isNullOrEmpty(username)) { + throw CallStatus.UNAUTHENTICATED.withDescription("Credentials not supplied.").toRuntimeException(); + } + final String identity; + if (USERNAME_1.equals(username) && PASSWORD_1.equals(password)) { + identity = USERNAME_1; + } else if (USERNAME_2.equals(username) && PASSWORD_2.equals(password)) { + identity = USERNAME_2; + } else { + throw CallStatus.UNAUTHENTICATED.withDescription("Username or password is invalid.").toRuntimeException(); + } + return () -> identity; + } + + @Test + public void validAuthWithBearerAuthServer() throws IOException { + testValidAuth(client); + } + + @Test + public void validAuthWithMultipleClientsWithSameCredentialsWithBearerAuthServer() throws IOException { + startClient2(); + testValidAuthWithMultipleClientsWithSameCredentials(client, client2); + } + + @Test + public void validAuthWithMultipleClientsWithDifferentCredentialsWithBearerAuthServer() throws IOException { + startClient2(); + testValidAuthWithMultipleClientsWithDifferentCredentials(client, client2); + } + + // ARROW-7722: this test occasionally leaks memory + @Ignore + @Test + public void asyncCall() throws Exception { + final CredentialCallOption bearerToken = client + .authenticateBasicToken(USERNAME_1, PASSWORD_1).get(); + client.listFlights(Criteria.ALL, bearerToken); + // Pass the bearer token explicitly, as every other authenticated call in this class does. + try (final FlightStream s = client.getStream(new Ticket(new byte[1]), bearerToken)) { + while (s.next()) { + Assert.assertEquals(4095, s.getRoot().getRowCount()); + } + } + } + + @Test + public void invalidAuthWithBearerAuthServer() throws IOException { + testInvalidAuth(client); + } + + @Test + public void didntAuthWithBearerAuthServer() throws IOException { + didntAuth(client); + } + + private void
testValidAuth(FlightClient client) { + final CredentialCallOption bearerToken = client + .authenticateBasicToken(USERNAME_1, PASSWORD_1).get(); + Assert.assertTrue(ImmutableList.copyOf(client + .listFlights(Criteria.ALL, bearerToken)) + .isEmpty()); + } + + private void testValidAuthWithMultipleClientsWithSameCredentials( + FlightClient client1, FlightClient client2) { + final CredentialCallOption bearerToken1 = client1 + .authenticateBasicToken(USERNAME_1, PASSWORD_1).get(); + final CredentialCallOption bearerToken2 = client2 + .authenticateBasicToken(USERNAME_1, PASSWORD_1).get(); + Assert.assertTrue(ImmutableList.copyOf(client1 + .listFlights(Criteria.ALL, bearerToken1)) + .isEmpty()); + Assert.assertTrue(ImmutableList.copyOf(client2 + .listFlights(Criteria.ALL, bearerToken2)) + .isEmpty()); + } + + private void testValidAuthWithMultipleClientsWithDifferentCredentials( + FlightClient client1, FlightClient client2) { + final CredentialCallOption bearerToken1 = client1 + .authenticateBasicToken(USERNAME_1, PASSWORD_1).get(); + final CredentialCallOption bearerToken2 = client2 + .authenticateBasicToken(USERNAME_2, PASSWORD_2).get(); + Assert.assertTrue(ImmutableList.copyOf(client1 + .listFlights(Criteria.ALL, bearerToken1)) + .isEmpty()); + Assert.assertTrue(ImmutableList.copyOf(client2 + .listFlights(Criteria.ALL, bearerToken2)) + .isEmpty()); + } + + private void testInvalidAuth(FlightClient client) { + FlightTestUtil.assertCode(FlightStatusCode.UNAUTHENTICATED, () -> + client.authenticateBasicToken(USERNAME_1, "WRONG")); + + FlightTestUtil.assertCode(FlightStatusCode.UNAUTHENTICATED, () -> + client.authenticateBasicToken(NO_USERNAME, PASSWORD_1)); + + FlightTestUtil.assertCode(FlightStatusCode.UNAUTHENTICATED, () -> + client.listFlights(Criteria.ALL).forEach(action -> Assert.fail())); + } + + private void didntAuth(FlightClient client) { + FlightTestUtil.assertCode(FlightStatusCode.UNAUTHENTICATED, () -> + client.listFlights(Criteria.ALL).forEach(action -> 
Assert.fail())); + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/TestCookieHandling.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/TestCookieHandling.java new file mode 100644 index 000000000..f205f9a3b --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/TestCookieHandling.java @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.client; + +import java.io.IOException; + +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallInfo; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.Criteria; +import org.apache.arrow.flight.ErrorFlightMetadata; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightMethod; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightServerMiddleware; +import org.apache.arrow.flight.FlightTestUtil; +import org.apache.arrow.flight.NoOpFlightProducer; +import org.apache.arrow.flight.RequestContext; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; + +/** + * Tests for correct handling of cookies from the FlightClient using {@link ClientCookieMiddleware}. 
+ */ +public class TestCookieHandling { + private static final String SET_COOKIE_HEADER = "Set-Cookie"; + private static final String COOKIE_HEADER = "Cookie"; + private BufferAllocator allocator; + private FlightServer server; + private FlightClient client; + + private ClientCookieMiddlewareTestFactory testFactory = new ClientCookieMiddlewareTestFactory(); + private ClientCookieMiddleware cookieMiddleware = new ClientCookieMiddleware(testFactory); + + @Before + public void setup() throws Exception { + allocator = new RootAllocator(Long.MAX_VALUE); + startServerAndClient(); + } + + @After + public void cleanup() throws Exception { + testFactory = new ClientCookieMiddlewareTestFactory(); + cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); + AutoCloseables.close(client, server, allocator); + client = null; + server = null; + allocator = null; + } + + @Test + public void basicCookie() { + CallHeaders headersToSend = new ErrorFlightMetadata(); + headersToSend.insert(SET_COOKIE_HEADER, "k=v"); + cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); + cookieMiddleware.onHeadersReceived(headersToSend); + Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString()); + } + + @Test + public void cookieStaysAfterMultipleRequests() { + CallHeaders headersToSend = new ErrorFlightMetadata(); + headersToSend.insert(SET_COOKIE_HEADER, "k=v"); + cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); + cookieMiddleware.onHeadersReceived(headersToSend); + Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString()); + + headersToSend = new ErrorFlightMetadata(); + cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); + cookieMiddleware.onHeadersReceived(headersToSend); + Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString()); + + headersToSend = new ErrorFlightMetadata(); + cookieMiddleware = testFactory.onCallStarted(new 
CallInfo(FlightMethod.DO_ACTION)); + cookieMiddleware.onHeadersReceived(headersToSend); + Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString()); + } + + @Ignore + @Test + public void cookieAutoExpires() { + CallHeaders headersToSend = new ErrorFlightMetadata(); + headersToSend.insert(SET_COOKIE_HEADER, "k=v; Max-Age=2"); + cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); + cookieMiddleware.onHeadersReceived(headersToSend); + // Note: using max-age changes cookie version from 0->1, which quotes values. + Assert.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString()); + + headersToSend = new ErrorFlightMetadata(); + cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); + cookieMiddleware.onHeadersReceived(headersToSend); + Assert.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString()); + + try { + Thread.sleep(5000); + } catch (InterruptedException ignored) { + } + + // Verify that the k cookie was discarded because it expired. + Assert.assertTrue(cookieMiddleware.getValidCookiesAsString().isEmpty()); + } + + @Test + public void cookieExplicitlyExpires() { + CallHeaders headersToSend = new ErrorFlightMetadata(); + headersToSend.insert(SET_COOKIE_HEADER, "k=v; Max-Age=2"); + cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); + cookieMiddleware.onHeadersReceived(headersToSend); + // Note: using max-age changes cookie version from 0->1, which quotes values. + Assert.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString()); + + // Note: The JDK treats Max-Age < 0 as not expired and treats 0 as expired. + // This violates the RFC, which states that less than zero and zero should both be expired. 
+ headersToSend = new ErrorFlightMetadata(); + headersToSend.insert(SET_COOKIE_HEADER, "k=v; Max-Age=0"); + cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); + cookieMiddleware.onHeadersReceived(headersToSend); + + // Verify that the k cookie was discarded because the server told the client it is expired. + Assert.assertTrue(cookieMiddleware.getValidCookiesAsString().isEmpty()); + } + + @Ignore + @Test + public void cookieExplicitlyExpiresWithMaxAgeMinusOne() { + CallHeaders headersToSend = new ErrorFlightMetadata(); + headersToSend.insert(SET_COOKIE_HEADER, "k=v; Max-Age=2"); + cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); + cookieMiddleware.onHeadersReceived(headersToSend); + // Note: using max-age changes cookie version from 0->1, which quotes values. + Assert.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString()); + + headersToSend = new ErrorFlightMetadata(); + + // The Java HttpCookie class has a bug where it uses a -1 maxAge to indicate + // a persistent cookie, when the RFC spec says this should mean the cookie expires immediately. + headersToSend.insert(SET_COOKIE_HEADER, "k=v; Max-Age=-1"); + cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); + cookieMiddleware.onHeadersReceived(headersToSend); + + // Verify that the k cookie was discarded because the server told the client it is expired. 
+ Assert.assertTrue(cookieMiddleware.getValidCookiesAsString().isEmpty()); + } + + @Test + public void changeCookieValue() { + CallHeaders headersToSend = new ErrorFlightMetadata(); + headersToSend.insert(SET_COOKIE_HEADER, "k=v"); + cookieMiddleware.onHeadersReceived(headersToSend); + Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString()); + + headersToSend = new ErrorFlightMetadata(); + headersToSend.insert(SET_COOKIE_HEADER, "k=v2"); + cookieMiddleware.onHeadersReceived(headersToSend); + Assert.assertEquals("k=v2", cookieMiddleware.getValidCookiesAsString()); + } + + @Test + public void multipleCookiesWithSetCookie() { + CallHeaders headersToSend = new ErrorFlightMetadata(); + headersToSend.insert(SET_COOKIE_HEADER, "firstKey=firstVal"); + headersToSend.insert(SET_COOKIE_HEADER, "secondKey=secondVal"); + cookieMiddleware.onHeadersReceived(headersToSend); + Assert.assertEquals("firstKey=firstVal; secondKey=secondVal", cookieMiddleware.getValidCookiesAsString()); + } + + @Test + public void cookieStaysAfterMultipleRequestsEndToEnd() { + client.handshake(); + Assert.assertEquals("k=v", testFactory.clientCookieMiddleware.getValidCookiesAsString()); + client.handshake(); + Assert.assertEquals("k=v", testFactory.clientCookieMiddleware.getValidCookiesAsString()); + client.listFlights(Criteria.ALL); + Assert.assertEquals("k=v", testFactory.clientCookieMiddleware.getValidCookiesAsString()); + } + + /** + * A server middleware component that injects SET_COOKIE_HEADER into the outgoing headers. 
+ */ + static class SetCookieHeaderInjector implements FlightServerMiddleware { + private final Factory factory; + + public SetCookieHeaderInjector(Factory factory) { + this.factory = factory; + } + + @Override + public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { + if (!factory.receivedCookieHeader) { + outgoingHeaders.insert(SET_COOKIE_HEADER, "k=v"); + } + } + + @Override + public void onCallCompleted(CallStatus status) { + + } + + @Override + public void onCallErrored(Throwable err) { + + } + + static class Factory implements FlightServerMiddleware.Factory<SetCookieHeaderInjector> { + private boolean receivedCookieHeader = false; + + @Override + public SetCookieHeaderInjector onCallStarted(CallInfo info, CallHeaders incomingHeaders, + RequestContext context) { + receivedCookieHeader = null != incomingHeaders.get(COOKIE_HEADER); + return new SetCookieHeaderInjector(this); + } + } + } + + public static class ClientCookieMiddlewareTestFactory extends ClientCookieMiddleware.Factory { + + private ClientCookieMiddleware clientCookieMiddleware; + + @Override + public ClientCookieMiddleware onCallStarted(CallInfo info) { + this.clientCookieMiddleware = new ClientCookieMiddleware(this); + return this.clientCookieMiddleware; + } + } + + private void startServerAndClient() throws IOException { + final FlightProducer flightProducer = new NoOpFlightProducer() { + public void listFlights(CallContext context, Criteria criteria, + StreamListener<FlightInfo> listener) { + listener.onCompleted(); + } + }; + + this.server = FlightTestUtil.getStartedServer((location) -> FlightServer + .builder(allocator, location, flightProducer) + .middleware(FlightServerMiddleware.Key.of("test"), new SetCookieHeaderInjector.Factory()) + .build()); + + this.client = FlightClient.builder(allocator, server.getLocation()) + .intercept(testFactory) + .build(); + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java 
b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java new file mode 100644 index 000000000..fb157f45e --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/example/TestExampleServer.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.example; + +import java.io.IOException; + +import org.apache.arrow.flight.AsyncPutListener; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightClient.ClientStreamListener; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.FlightTestUtil; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.junit.After; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; + +/** + * Ensure that example server supports get and put. 
+ */ +public class TestExampleServer { + + private BufferAllocator allocator; + private BufferAllocator caseAllocator; + private ExampleFlightServer server; + private FlightClient client; + + @Before + public void start() throws IOException { + allocator = new RootAllocator(Long.MAX_VALUE); + + Location l = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, 12233); + if (!Boolean.getBoolean("disableServer")) { + System.out.println("Starting server."); + server = new ExampleFlightServer(allocator, l); + server.start(); + } else { + System.out.println("Skipping server startup."); + } + client = FlightClient.builder(allocator, l).build(); + caseAllocator = allocator.newChildAllocator("test-case", 0, Long.MAX_VALUE); + } + + @After + public void after() throws Exception { + AutoCloseables.close(server, client, caseAllocator, allocator); + } + + @Test + @Ignore + public void putStream() { + BufferAllocator a = caseAllocator; + final int size = 10; + + IntVector iv = new IntVector("c1", a); + + VectorSchemaRoot root = VectorSchemaRoot.of(iv); + ClientStreamListener listener = client.startPut(FlightDescriptor.path("hello"), root, + new AsyncPutListener()); + + //batch 1 + root.allocateNew(); + for (int i = 0; i < size; i++) { + iv.set(i, i); + } + iv.setValueCount(size); + root.setRowCount(size); + listener.putNext(); + + // batch 2 + + root.allocateNew(); + for (int i = 0; i < size; i++) { + iv.set(i, i + size); + } + iv.setValueCount(size); + root.setRowCount(size); + listener.putNext(); + root.clear(); + listener.completed(); + + // wait for ack to avoid memory leaks. 
+ listener.getResult(); + + FlightInfo info = client.getInfo(FlightDescriptor.path("hello")); + try (final FlightStream stream = client.getStream(info.getEndpoints().get(0).getTicket())) { + VectorSchemaRoot newRoot = stream.getRoot(); + while (stream.next()) { + newRoot.clear(); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/grpc/TestStatusUtils.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/grpc/TestStatusUtils.java new file mode 100644 index 000000000..5d76e8ae1 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/grpc/TestStatusUtils.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.grpc; + +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightStatusCode; +import org.junit.Assert; +import org.junit.Test; + +import io.grpc.Metadata; +import io.grpc.Status; + +public class TestStatusUtils { + + @Test + public void testParseTrailers() { + Status status = Status.CANCELLED; + Metadata trailers = new Metadata(); + + // gRPC can have trailers with certain metadata keys beginning with ":", such as ":status". + // See https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md + trailers.put(StatusUtils.keyOfAscii(":status"), "502"); + trailers.put(StatusUtils.keyOfAscii("date"), "Fri, 13 Sep 2015 11:23:58 GMT"); + trailers.put(StatusUtils.keyOfAscii("content-type"), "text/html"); + + CallStatus callStatus = StatusUtils.fromGrpcStatusAndTrailers(status, trailers); + + Assert.assertEquals(FlightStatusCode.CANCELLED, callStatus.code()); + Assert.assertTrue(callStatus.metadata().containsKey(":status")); + Assert.assertEquals("502", callStatus.metadata().get(":status")); + Assert.assertTrue(callStatus.metadata().containsKey("date")); + Assert.assertEquals("Fri, 13 Sep 2015 11:23:58 GMT", callStatus.metadata().get("date")); + Assert.assertTrue(callStatus.metadata().containsKey("content-type")); + Assert.assertEquals("text/html", callStatus.metadata().get("content-type")); + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java new file mode 100644 index 000000000..7794ed748 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/PerformanceTestServer.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.perf; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import org.apache.arrow.flight.BackpressureStrategy; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.NoOpFlightProducer; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.perf.impl.PerfOuterClass.Perf; +import org.apache.arrow.flight.perf.impl.PerfOuterClass.Token; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.InvalidProtocolBufferException; + 
+public class PerformanceTestServer implements AutoCloseable { + + private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(PerformanceTestServer.class); + + private final FlightServer flightServer; + private final Location location; + private final BufferAllocator allocator; + private final PerfProducer producer; + private final boolean isNonBlocking; + + public PerformanceTestServer(BufferAllocator incomingAllocator, Location location) { + this(incomingAllocator, location, new BackpressureStrategy() { + private FlightProducer.ServerStreamListener listener; + + @Override + public void register(FlightProducer.ServerStreamListener listener) { + this.listener = listener; + } + + @Override + public WaitResult waitForListener(long timeout) { + while (!listener.isReady() && !listener.isCancelled()) { + // busy wait + } + return WaitResult.READY; + } + }, false); + } + + public PerformanceTestServer(BufferAllocator incomingAllocator, Location location, BackpressureStrategy bpStrategy, + boolean isNonBlocking) { + this.allocator = incomingAllocator.newChildAllocator("perf-server", 0, Long.MAX_VALUE); + this.location = location; + this.producer = new PerfProducer(bpStrategy); + this.flightServer = FlightServer.builder(this.allocator, location, producer).build(); + this.isNonBlocking = isNonBlocking; + } + + public Location getLocation() { + return location; + } + + public void start() throws IOException { + flightServer.start(); + } + + @Override + public void close() throws Exception { + AutoCloseables.close(flightServer, allocator); + } + + private final class PerfProducer extends NoOpFlightProducer { + private final BackpressureStrategy bpStrategy; + + private PerfProducer(BackpressureStrategy bpStrategy) { + this.bpStrategy = bpStrategy; + } + + @Override + public void getStream(CallContext context, Ticket ticket, + ServerStreamListener listener) { + bpStrategy.register(listener); + final Runnable loadData = () -> { + VectorSchemaRoot root = 
null; + try { + Token token = Token.parseFrom(ticket.getBytes()); + Perf perf = token.getDefinition(); + Schema schema = Schema.deserialize(ByteBuffer.wrap(perf.getSchema().toByteArray())); + root = VectorSchemaRoot.create(schema, allocator); + BigIntVector a = (BigIntVector) root.getVector("a"); + BigIntVector b = (BigIntVector) root.getVector("b"); + BigIntVector c = (BigIntVector) root.getVector("c"); + BigIntVector d = (BigIntVector) root.getVector("d"); + listener.setUseZeroCopy(true); + listener.start(root); + root.allocateNew(); + + int current = 0; + long i = token.getStart(); + while (i < token.getEnd()) { + if (listener.isCancelled()) { + root.clear(); + return; + } + + if (TestPerf.VALIDATE) { + a.setSafe(current, i); + } + + i++; + current++; + if (i % perf.getRecordsPerBatch() == 0) { + root.setRowCount(current); + + bpStrategy.waitForListener(0); + if (listener.isCancelled()) { + root.clear(); + return; + } + listener.putNext(); + current = 0; + root.allocateNew(); + } + } + + // send last partial batch. 
+ if (current != 0) { + root.setRowCount(current); + listener.putNext(); + } + listener.completed(); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } finally { + try { + AutoCloseables.close(root); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + }; + + if (!isNonBlocking) { + loadData.run(); + } else { + final ExecutorService service = Executors.newSingleThreadExecutor(); + service.submit(loadData); + service.shutdown(); + } + } + + @Override + public FlightInfo getFlightInfo(CallContext context, + FlightDescriptor descriptor) { + try { + Preconditions.checkArgument(descriptor.isCommand()); + Perf exec = Perf.parseFrom(descriptor.getCommand()); + + final Schema pojoSchema = new Schema(ImmutableList.of( + Field.nullable("a", MinorType.BIGINT.getType()), + Field.nullable("b", MinorType.BIGINT.getType()), + Field.nullable("c", MinorType.BIGINT.getType()), + Field.nullable("d", MinorType.BIGINT.getType()) + )); + + Token token = Token.newBuilder().setDefinition(exec) + .setStart(0) + .setEnd(exec.getRecordsPerStream()) + .build(); + final Ticket ticket = new Ticket(token.toByteArray()); + + List<FlightEndpoint> endpoints = new ArrayList<>(); + for (int i = 0; i < exec.getStreamCount(); i++) { + endpoints.add(new FlightEndpoint(ticket, getLocation())); + } + + return new FlightInfo(pojoSchema, descriptor, endpoints, -1, + exec.getRecordsPerStream() * exec.getStreamCount()); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + } + } +} + + + diff --git a/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/TestPerf.java b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/TestPerf.java new file mode 100644 index 000000000..9e2d7cc54 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/TestPerf.java @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + 
* contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.perf; + +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.FlightTestUtil; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.perf.impl.PerfOuterClass.Perf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Test; + +import com.google.common.base.MoreObjects; +import com.google.common.base.Stopwatch; +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import 
com.google.common.util.concurrent.ListeningExecutorService; +import com.google.common.util.concurrent.MoreExecutors; +import com.google.protobuf.ByteString; + +@org.junit.Ignore +public class TestPerf { + + public static final boolean VALIDATE = false; + + public static FlightDescriptor getPerfFlightDescriptor(long recordCount, int recordsPerBatch, int streamCount) { + final Schema pojoSchema = new Schema(ImmutableList.of( + Field.nullable("a", MinorType.BIGINT.getType()), + Field.nullable("b", MinorType.BIGINT.getType()), + Field.nullable("c", MinorType.BIGINT.getType()), + Field.nullable("d", MinorType.BIGINT.getType()) + )); + + ByteString serializedSchema = ByteString.copyFrom(pojoSchema.toByteArray()); + + return FlightDescriptor.command(Perf.newBuilder() + .setRecordsPerStream(recordCount) + .setRecordsPerBatch(recordsPerBatch) + .setSchema(serializedSchema) + .setStreamCount(streamCount) + .build() + .toByteArray()); + } + + public static void main(String[] args) throws Exception { + new TestPerf().throughput(); + } + + @Test + public void throughput() throws Exception { + final int numRuns = 10; + ListeningExecutorService pool = MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(4)); + double [] throughPuts = new double[numRuns]; + + for (int i = 0; i < numRuns; i++) { + try ( + final BufferAllocator a = new RootAllocator(Long.MAX_VALUE); + final PerformanceTestServer server = + FlightTestUtil.getStartedServer((location) -> new PerformanceTestServer(a, location)); + final FlightClient client = FlightClient.builder(a, server.getLocation()).build(); + ) { + final FlightInfo info = client.getInfo(getPerfFlightDescriptor(50_000_000L, 4095, 2)); + List<ListenableFuture<Result>> results = info.getEndpoints() + .stream() + .map(t -> new Consumer(client, t.getTicket())) + .map(t -> pool.submit(t)) + .collect(Collectors.toList()); + + final Result r = Futures.whenAllSucceed(results).call(() -> { + Result res = new Result(); + for (ListenableFuture<Result> 
f : results) { + res.add(f.get()); + } + return res; + }, pool).get(); + + double seconds = r.nanos * 1.0d / 1000 / 1000 / 1000; + throughPuts[i] = (r.bytes * 1.0d / 1024 / 1024) / seconds; + System.out.println(String.format( + "Transferred %d records totaling %s bytes at %f MiB/s. %f record/s. %f batch/s.", + r.rows, + r.bytes, + throughPuts[i], + (r.rows * 1.0d) / seconds, + (r.batches * 1.0d) / seconds + )); + } + } + pool.shutdown(); + + System.out.println("Summary: "); + double average = Arrays.stream(throughPuts).sum() / numRuns; + double sqrSum = Arrays.stream(throughPuts).map(val -> val - average).map(val -> val * val).sum(); + double stddev = Math.sqrt(sqrSum / numRuns); + System.out.println(String.format("Average throughput: %f MiB/s, standard deviation: %f MiB/s", + average, stddev)); + } + + private final class Consumer implements Callable<Result> { + + private final FlightClient client; + private final Ticket ticket; + + public Consumer(FlightClient client, Ticket ticket) { + super(); + this.client = client; + this.ticket = ticket; + } + + @Override + public Result call() throws Exception { + final Result r = new Result(); + Stopwatch watch = Stopwatch.createStarted(); + try (final FlightStream stream = client.getStream(ticket)) { + final VectorSchemaRoot root = stream.getRoot(); + try { + BigIntVector a = (BigIntVector) root.getVector("a"); + while (stream.next()) { + int rows = root.getRowCount(); + long aSum = r.aSum; + for (int i = 0; i < rows; i++) { + if (VALIDATE) { + aSum += a.get(i); + } + } + r.bytes += rows * 32; + r.rows += rows; + r.aSum = aSum; + r.batches++; + } + + r.nanos = watch.elapsed(TimeUnit.NANOSECONDS); + return r; + } finally { + root.clear(); + } + } + } + + } + + private final class Result { + private long rows; + private long aSum; + private long bytes; + private long nanos; + private long batches; + + public void add(Result r) { + rows += r.rows; + aSum += r.aSum; + bytes += r.bytes; + batches += r.batches; + nanos = 
Math.max(nanos, r.nanos); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("rows", rows) + .add("aSum", aSum) + .add("batches", batches) + .add("bytes", bytes) + .add("nanos", nanos) + .toString(); + } + } +} diff --git a/src/arrow/java/flight/flight-core/src/test/protobuf/perf.proto b/src/arrow/java/flight/flight-core/src/test/protobuf/perf.proto new file mode 100644 index 000000000..99f35a9e6 --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/protobuf/perf.proto @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * <p> + * http://www.apache.org/licenses/LICENSE-2.0 + * <p> + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +option java_package = "org.apache.arrow.flight.perf.impl"; + +message Perf { + bytes schema = 1; + int32 stream_count = 2; + int64 records_per_stream = 3; + int32 records_per_batch = 4; +} + +/* + * Payload of ticket + */ +message Token { + + // definition of entire flight. 
+ Perf definition = 1; + + // inclusive start + int64 start = 2; + + // exclusive end + int64 end = 3; + +} + diff --git a/src/arrow/java/flight/flight-core/src/test/resources/logback.xml b/src/arrow/java/flight/flight-core/src/test/resources/logback.xml new file mode 100644 index 000000000..444b2ed6d --- /dev/null +++ b/src/arrow/java/flight/flight-core/src/test/resources/logback.xml @@ -0,0 +1,28 @@ +<?xml version="1.0" encoding="UTF-8" ?> +<!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor + license agreements. See the NOTICE file distributed with this work for additional + information regarding copyright ownership. The ASF licenses this file to + You under the Apache License, Version 2.0 (the "License"); you may not use + this file except in compliance with the License. You may obtain a copy of + the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required + by applicable law or agreed to in writing, software distributed under the + License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS + OF ANY KIND, either express or implied. See the License for the specific + language governing permissions and limitations under the License. 
--> + +<configuration> <!-- Logging setup for the tests: a Lilith socket appender plus an info-level logger for org.apache.arrow. --> + <statusListener class="ch.qos.logback.core.status.NopStatusListener"/> + <appender name="SOCKET" + class="de.huxhorn.lilith.logback.appender.ClassicMultiplexSocketAppender"> + <Compressing>true</Compressing> + <ReconnectionDelay>10000</ReconnectionDelay> + <IncludeCallerData>true</IncludeCallerData> + <RemoteHosts>${LILITH_HOSTNAME:-localhost}</RemoteHosts> + </appender> + + <logger name="org.apache.arrow" additivity="false"> + <level value="info" /> + <appender-ref ref="FILE" /> <!-- NOTE(review): this file declares no appender named FILE (only SOCKET); confirm which appender is intended. --> + </logger> + +</configuration> diff --git a/src/arrow/java/flight/flight-grpc/pom.xml b/src/arrow/java/flight/flight-grpc/pom.xml new file mode 100644 index 000000000..1968484a1 --- /dev/null +++ b/src/arrow/java/flight/flight-grpc/pom.xml @@ -0,0 +1,132 @@ +<?xml version="1.0" encoding="UTF-8"?> +<!-- Licensed to the Apache Software Foundation (ASF) under one or more contributor + license agreements. See the NOTICE file distributed with this work for additional + information regarding copyright ownership. The ASF licenses this file to + You under the Apache License, Version 2.0 (the "License"); you may not use + this file except in compliance with the License. You may obtain a copy of + the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required + by applicable law or agreed to in writing, software distributed under the + License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS + OF ANY KIND, either express or implied. See the License for the specific + language governing permissions and limitations under the License. 
--> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + <parent> + <artifactId>arrow-java-root</artifactId> + <groupId>org.apache.arrow</groupId> + <version>6.0.1</version> + <relativePath>../../pom.xml</relativePath> + </parent> + <modelVersion>4.0.0</modelVersion> + + <artifactId>flight-grpc</artifactId> + <name>Arrow Flight GRPC</name> + <description>(Experimental)Contains utility class to expose Flight gRPC service and client</description> + <packaging>jar</packaging> + + <properties> + <dep.grpc.version>1.41.0</dep.grpc.version> + <dep.protobuf.version>3.7.1</dep.protobuf.version> + <forkCount>1</forkCount> + </properties> + + <dependencies> + <dependency> + <groupId>org.apache.arrow</groupId> + <artifactId>flight-core</artifactId> + <version>${project.version}</version> + <exclusions> + <exclusion> + <groupId>io.netty</groupId> + <artifactId>netty-transport-native-unix-common</artifactId> + </exclusion> + <exclusion> + <groupId>io.netty</groupId> + <artifactId>netty-transport-native-kqueue</artifactId> + </exclusion> + <exclusion> + <groupId>io.netty</groupId> + <artifactId>netty-transport-native-epoll</artifactId> + </exclusion> + </exclusions> + </dependency> + <dependency> + <groupId>io.grpc</groupId> + <artifactId>grpc-core</artifactId> + <version>${dep.grpc.version}</version> + </dependency> + <dependency> + <groupId>io.grpc</groupId> + <artifactId>grpc-stub</artifactId> + <version>${dep.grpc.version}</version> + </dependency> + <dependency> + <groupId>org.apache.arrow</groupId> + <artifactId>arrow-memory-core</artifactId> + <version>${project.version}</version> + <scope>compile</scope> + </dependency> + <dependency> + <groupId>org.apache.arrow</groupId> + <artifactId>arrow-memory-netty</artifactId> + <version>${project.version}</version> + <scope>runtime</scope> + </dependency> + <dependency> + 
<groupId>io.grpc</groupId> + <artifactId>grpc-protobuf</artifactId> + <version>${dep.grpc.version}</version> + </dependency> + <dependency> + <groupId>com.google.guava</groupId> + <artifactId>guava</artifactId> + </dependency> + <dependency> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + <version>${dep.protobuf.version}</version> + </dependency> + <dependency> + <groupId>io.grpc</groupId> + <artifactId>grpc-api</artifactId> + <version>${dep.grpc.version}</version> + </dependency> + </dependencies> + + <build> + <extensions> + <!-- provides os.detected.classifier (i.e. linux-x86_64, osx-x86_64) property --> + <extension> + <groupId>kr.motd.maven</groupId> + <artifactId>os-maven-plugin</artifactId> + <version>1.5.0.Final</version> + </extension> + </extensions> + <plugins> + <plugin> + <groupId>org.xolstice.maven.plugins</groupId> + <artifactId>protobuf-maven-plugin</artifactId> + <version>0.5.0</version> + <configuration> + <protocArtifact>com.google.protobuf:protoc:${dep.protobuf.version}:exe:${os.detected.classifier}</protocArtifact> + <clearOutputDirectory>false</clearOutputDirectory> + <pluginId>grpc-java</pluginId> + <pluginArtifact>io.grpc:protoc-gen-grpc-java:${dep.grpc.version}:exe:${os.detected.classifier}</pluginArtifact> + </configuration> + <executions> + <execution> + <id>test</id> <!-- Generates the Java and gRPC stubs from the protos under src/test/protobuf. --> + <configuration> + <protoSourceRoot>${basedir}/src/test/protobuf</protoSourceRoot> + <outputDirectory>${project.build.directory}/generated-test-sources//protobuf</outputDirectory> <!-- NOTE(review): the path contains a doubled slash; presumably harmless, confirm. --> + </configuration> + <goals> + <goal>compile</goal> + <goal>compile-custom</goal> + </goals> + </execution> + </executions> + </plugin> + </plugins> + </build> + +</project> diff --git a/src/arrow/java/flight/flight-grpc/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java b/src/arrow/java/flight/flight-grpc/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java new file mode 100644 index 000000000..eb5e492b4 --- /dev/null +++ 
b/src/arrow/java/flight/flight-grpc/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.Collections; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.TimeUnit; + +import org.apache.arrow.flight.auth.ServerAuthHandler; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.VisibleForTesting; + +import io.grpc.BindableService; +import io.grpc.CallOptions; +import io.grpc.ClientCall; +import io.grpc.ConnectivityState; +import io.grpc.ManagedChannel; +import io.grpc.MethodDescriptor; + +/** + * Exposes Flight GRPC service & client. + */ +public class FlightGrpcUtils { + /** + * Proxy class for ManagedChannel that makes closure a no-op. + * Shutting the proxy down only marks the proxy itself as shut down; the wrapped channel is never shut down by it. + * Once the proxy (or the wrapped channel) is shut down, {@code newCall} throws IllegalStateException. 
+ */ + @VisibleForTesting + static class NonClosingProxyManagedChannel extends ManagedChannel { + private final ManagedChannel channel; + // Volatile: the flag is written by shutdown() and read by newCall()/getState(), which may run on different threads. + private volatile boolean isShutdown; + + NonClosingProxyManagedChannel(ManagedChannel channel) { + this.channel = channel; + this.isShutdown = channel.isShutdown(); + } + + @Override + public ManagedChannel shutdown() { + isShutdown = true; + return this; + } + + @Override + public boolean isShutdown() { + if (this.channel.isShutdown()) { + // If the underlying channel is shut down, ensure we're updated to match. + shutdown(); + } + return isShutdown; + } + + @Override + public boolean isTerminated() { + return this.isShutdown(); + } + + @Override + public ManagedChannel shutdownNow() { + return shutdown(); + } + + @Override + public boolean awaitTermination(long l, TimeUnit timeUnit) { + // Don't actually await termination, since it'll be a no-op, so simply return whether or not + // the channel has been shut down already. + return this.isShutdown(); + } + + @Override + public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall( + MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions) { + if (this.isShutdown()) { + throw new IllegalStateException("Channel has been shut down."); + } + + return this.channel.newCall(methodDescriptor, callOptions); + } + + @Override + public String authority() { + return this.channel.authority(); + } + + @Override + public ConnectivityState getState(boolean requestConnection) { + if (this.isShutdown()) { + return ConnectivityState.SHUTDOWN; + } + + return this.channel.getState(requestConnection); + } + + @Override + public void notifyWhenStateChanged(ConnectivityState source, Runnable callback) { + // The proxy has no insight into the underlying channel state changes, so we'll have to leak the abstraction + // a bit here and simply pass to the underlying channel, even though it will never transition to shutdown via + // the proxy. 
This should be fine, since it's mainly targeted at the FlightClient and there's no getter for + // the channel. + this.channel.notifyWhenStateChanged(source, callback); + } + + @Override + public void resetConnectBackoff() { + this.channel.resetConnectBackoff(); + } + + @Override + public void enterIdle() { + this.channel.enterIdle(); + } + } + + private FlightGrpcUtils() {} + + /** + * Creates a Flight service. + * @param allocator Memory allocator + * @param producer Specifies the service api + * @param authHandler Authentication handler + * @param executor Executor service + * @return FlightBindingService + */ + public static BindableService createFlightService(BufferAllocator allocator, FlightProducer producer, + ServerAuthHandler authHandler, ExecutorService executor) { + return new FlightBindingService(allocator, producer, authHandler, executor); + } + + /** + * Creates a Flight client. + * @param incomingAllocator Memory allocator + * @param channel provides a connection to a gRPC server. + * @return FlightClient bound to the given channel + */ + public static FlightClient createFlightClient(BufferAllocator incomingAllocator, ManagedChannel channel) { + return new FlightClient(incomingAllocator, channel, Collections.emptyList()); + } + + /** + * Creates a Flight client. + * @param incomingAllocator Memory allocator + * @param channel provides a connection to a gRPC server. Will not be closed on closure of the returned FlightClient. + * @return FlightClient that talks through a non-closing proxy of the given channel 
+ */ + public static FlightClient createFlightClientWithSharedChannel( + BufferAllocator incomingAllocator, ManagedChannel channel) { + return new FlightClient(incomingAllocator, new NonClosingProxyManagedChannel(channel), Collections.emptyList()); + } +} diff --git a/src/arrow/java/flight/flight-grpc/src/test/java/org/apache/arrow/flight/TestFlightGrpcUtils.java b/src/arrow/java/flight/flight-grpc/src/test/java/org/apache/arrow/flight/TestFlightGrpcUtils.java new file mode 100644 index 000000000..142a0f937 --- /dev/null +++ b/src/arrow/java/flight/flight-grpc/src/test/java/org/apache/arrow/flight/TestFlightGrpcUtils.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import org.apache.arrow.flight.auth.ServerAuthHandler; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import com.google.protobuf.Empty; + +import io.grpc.BindableService; +import io.grpc.ConnectivityState; +import io.grpc.ManagedChannel; +import io.grpc.Server; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.StreamObserver; + +/** + * Unit test which adds 2 services to same server end point. + * It also covers the shutdown behavior of the non-closing channel proxy used for shared channels. + */ +public class TestFlightGrpcUtils { + private Server server; + private BufferAllocator allocator; + private String serverName; + + @Before + public void setup() throws IOException { + //Defines flight service + allocator = new RootAllocator(Integer.MAX_VALUE); + final NoOpFlightProducer producer = new NoOpFlightProducer(); + final ServerAuthHandler authHandler = ServerAuthHandler.NO_OP; + final ExecutorService exec = Executors.newCachedThreadPool(); + final BindableService flightBindingService = FlightGrpcUtils.createFlightService(allocator, producer, + authHandler, exec); + + //initializes server with 2 services - FlightBindingService & TestService + serverName = InProcessServerBuilder.generateName(); + server = InProcessServerBuilder.forName(serverName) + .directExecutor() + .addService(flightBindingService) + .addService(new TestServiceAdapter()) + .build(); + server.start(); + } + + @After + public void cleanup() { + // NOTE(review): only the server is stopped here; the allocator and the executor created in setup() are not + // released by this class - confirm this is intended. + server.shutdownNow(); + } + + /** + * This test checks if multiple gRPC services can be added to the same + * server endpoint and if they can be used by different clients via the same channel. 
+ * @throws IOException If server fails to start. + */ + @Test + public void testMultipleGrpcServices() throws IOException { + //Initializes channel so that multiple clients can communicate with server + final ManagedChannel managedChannel = InProcessChannelBuilder.forName(serverName) + .directExecutor() + .build(); + + //Defines flight client and calls service method. Since we use a NoOpFlightProducer we expect the service + //to throw a FlightRuntimeException + final FlightClient flightClient = FlightGrpcUtils.createFlightClient(allocator, managedChannel); + final Iterable<ActionType> actionTypes = flightClient.listActions(); + assertThrows(FlightRuntimeException.class, () -> actionTypes.forEach( + actionType -> System.out.println(actionType.toString()))); + + //Define Test client as a blocking stub and call test method which correctly returns an empty protobuf object + final TestServiceGrpc.TestServiceBlockingStub blockingStub = TestServiceGrpc.newBlockingStub(managedChannel); + Assert.assertEquals(Empty.newBuilder().build(), blockingStub.test(Empty.newBuilder().build())); + } + + @Test + public void testShutdown() throws IOException, InterruptedException { + //Initializes channel so that multiple clients can communicate with server + final ManagedChannel managedChannel = InProcessChannelBuilder.forName(serverName) + .directExecutor() + .build(); + + //Defines a flight client on top of a shared channel; closing the client must leave the underlying + //channel open, which the assertions below verify + final FlightClient flightClient = FlightGrpcUtils.createFlightClientWithSharedChannel(allocator, managedChannel); + + // Should be a no-op. 
+ flightClient.close(); + Assert.assertFalse(managedChannel.isShutdown()); + Assert.assertFalse(managedChannel.isTerminated()); + Assert.assertEquals(ConnectivityState.IDLE, managedChannel.getState(false)); + managedChannel.shutdownNow(); + } + + @Test + public void testProxyChannel() throws IOException, InterruptedException { + //Initializes channel so that multiple clients can communicate with server + final ManagedChannel managedChannel = InProcessChannelBuilder.forName(serverName) + .directExecutor() + .build(); + + final FlightGrpcUtils.NonClosingProxyManagedChannel proxyChannel = + new FlightGrpcUtils.NonClosingProxyManagedChannel(managedChannel); + Assert.assertFalse(proxyChannel.isShutdown()); + Assert.assertFalse(proxyChannel.isTerminated()); + // Shutting down the proxy must flip only the proxy's own state; the wrapped channel is checked further down. + proxyChannel.shutdown(); + Assert.assertTrue(proxyChannel.isShutdown()); + Assert.assertTrue(proxyChannel.isTerminated()); + Assert.assertEquals(ConnectivityState.SHUTDOWN, proxyChannel.getState(false)); + try { + proxyChannel.newCall(null, null); + Assert.fail(); + } catch (IllegalStateException e) { + // This is expected, since the proxy channel is shut down. 
+ } + + Assert.assertFalse(managedChannel.isShutdown()); + Assert.assertFalse(managedChannel.isTerminated()); + Assert.assertEquals(ConnectivityState.IDLE, managedChannel.getState(false)); + + managedChannel.shutdownNow(); + } + + @Test + public void testProxyChannelWithClosedChannel() throws IOException, InterruptedException { + //Initializes channel so that multiple clients can communicate with server + final ManagedChannel managedChannel = InProcessChannelBuilder.forName(serverName) + .directExecutor() + .build(); + + final FlightGrpcUtils.NonClosingProxyManagedChannel proxyChannel = + new FlightGrpcUtils.NonClosingProxyManagedChannel(managedChannel); + Assert.assertFalse(proxyChannel.isShutdown()); + Assert.assertFalse(proxyChannel.isTerminated()); + // Shutting down the wrapped channel must be reflected by the proxy. + managedChannel.shutdownNow(); + Assert.assertTrue(proxyChannel.isShutdown()); + Assert.assertTrue(proxyChannel.isTerminated()); + Assert.assertEquals(ConnectivityState.SHUTDOWN, proxyChannel.getState(false)); + try { + proxyChannel.newCall(null, null); + Assert.fail(); + } catch (IllegalStateException e) { + // This is expected, since the proxy channel is shut down. + } + + Assert.assertTrue(managedChannel.isShutdown()); + Assert.assertTrue(managedChannel.isTerminated()); + Assert.assertEquals(ConnectivityState.SHUTDOWN, managedChannel.getState(false)); + } + + /** + * Private class used for testing purposes that overrides service behavior. + */ + private class TestServiceAdapter extends TestServiceGrpc.TestServiceImplBase { + + /** + * gRPC service that receives an empty object & returns an empty protobuf object. 
+ * @param request google.protobuf.Empty + * @param responseObserver google.protobuf.Empty + */ + @Override + public void test(Empty request, StreamObserver<Empty> responseObserver) { + responseObserver.onNext(Empty.newBuilder().build()); + responseObserver.onCompleted(); + } + } +} + +diff --git a/src/arrow/java/flight/flight-grpc/src/test/protobuf/test.proto b/src/arrow/java/flight/flight-grpc/src/test/protobuf/test.proto new file mode 100644 index 000000000..6fa1890b2 --- /dev/null +++ b/src/arrow/java/flight/flight-grpc/src/test/protobuf/test.proto @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; + +option java_package = "org.apache.arrow.flight"; + +import "google/protobuf/empty.proto"; + +service TestService { + rpc Test(google.protobuf.Empty) returns (google.protobuf.Empty) {} +} |