Diffstat (limited to 'src/arrow/cpp/src/arrow/flight/client.cc')
-rw-r--r-- | src/arrow/cpp/src/arrow/flight/client.cc | 1355
1 file changed, 1355 insertions, 0 deletions
diff --git a/src/arrow/cpp/src/arrow/flight/client.cc b/src/arrow/cpp/src/arrow/flight/client.cc new file mode 100644 index 000000000..f9728f849 --- /dev/null +++ b/src/arrow/cpp/src/arrow/flight/client.cc @@ -0,0 +1,1355 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/client.h" + +// Platform-specific defines +#include "arrow/flight/platform.h" + +#include <map> +#include <memory> +#include <mutex> +#include <sstream> +#include <string> +#include <unordered_map> +#include <utility> + +#ifdef GRPCPP_PP_INCLUDE +#include <grpcpp/grpcpp.h> +#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) +#include <grpcpp/security/tls_credentials_options.h> +#endif +#else +#include <grpc++/grpc++.h> +#endif + +#include <grpc/grpc_security_constants.h> + +#include "arrow/buffer.h" +#include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" +#include "arrow/record_batch.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/table.h" +#include "arrow/type.h" +#include "arrow/util/logging.h" +#include "arrow/util/uri.h" + +#include "arrow/flight/client_auth.h" +#include "arrow/flight/client_header_internal.h" +#include "arrow/flight/client_middleware.h" +#include "arrow/flight/internal.h" +#include "arrow/flight/middleware.h" +#include "arrow/flight/middleware_internal.h" +#include "arrow/flight/serialization_internal.h" +#include "arrow/flight/types.h" + +namespace arrow { + +namespace flight { + +namespace pb = arrow::flight::protocol; + +const char* kWriteSizeDetailTypeId = "flight::FlightWriteSizeStatusDetail"; + +FlightCallOptions::FlightCallOptions() + : timeout(-1), + read_options(ipc::IpcReadOptions::Defaults()), + write_options(ipc::IpcWriteOptions::Defaults()) {} + +const char* FlightWriteSizeStatusDetail::type_id() const { + return kWriteSizeDetailTypeId; +} + +std::string FlightWriteSizeStatusDetail::ToString() const { + std::stringstream ss; + ss << "IPC payload size (" << actual_ << " bytes) exceeded soft limit (" << limit_ + << " bytes)"; + return ss.str(); +} + +std::shared_ptr<FlightWriteSizeStatusDetail> FlightWriteSizeStatusDetail::UnwrapStatus( + const arrow::Status& status) { + if (!status.detail() || status.detail()->type_id() != kWriteSizeDetailTypeId) { + return nullptr; + } + return std::dynamic_pointer_cast<FlightWriteSizeStatusDetail>(status.detail()); +} + +FlightClientOptions FlightClientOptions::Defaults() { return FlightClientOptions(); } + +Status FlightStreamReader::ReadAll(std::shared_ptr<Table>* table, + const StopToken& stop_token) { + std::vector<std::shared_ptr<RecordBatch>> batches; + RETURN_NOT_OK(ReadAll(&batches, stop_token)); + ARROW_ASSIGN_OR_RAISE(auto schema, GetSchema()); + return Table::FromRecordBatches(schema, std::move(batches)).Value(table); +} + +struct 
ClientRpc { + grpc::ClientContext context; + + explicit ClientRpc(const FlightCallOptions& options) { + if (options.timeout.count() >= 0) { + std::chrono::system_clock::time_point deadline = + std::chrono::time_point_cast<std::chrono::system_clock::time_point::duration>( + std::chrono::system_clock::now() + options.timeout); + context.set_deadline(deadline); + } + for (auto header : options.headers) { + context.AddMetadata(header.first, header.second); + } + } + + /// \brief Add an auth token via an auth handler + Status SetToken(ClientAuthHandler* auth_handler) { + if (auth_handler) { + std::string token; + RETURN_NOT_OK(auth_handler->GetToken(&token)); + context.AddMetadata(internal::kGrpcAuthHeader, token); + } + return Status::OK(); + } +}; + +/// Helper that manages Finish() of a gRPC stream. +/// +/// When we encounter an error (e.g. could not decode an IPC message), +/// we want to provide both the client-side error context and any +/// available server-side context. This helper helps wrap up that +/// logic. +/// +/// This class protects the stream with a flag (so that Finish is +/// idempotent), and drains the read side (so that Finish won't hang). +/// +/// The template lets us abstract between DoGet/DoExchange and DoPut, +/// which respectively read internal::FlightData and pb::PutResult. +template <typename Stream, typename ReadT> +class FinishableStream { + public: + FinishableStream(std::shared_ptr<ClientRpc> rpc, std::shared_ptr<Stream> stream) + : rpc_(rpc), stream_(stream), finished_(false), server_status_() {} + virtual ~FinishableStream() = default; + + /// \brief Get the underlying stream. + std::shared_ptr<Stream> stream() const { return stream_; } + + /// \brief Finish the call, adding server context to the given status. + virtual Status Finish(Status st) { + if (finished_) { + return MergeStatus(std::move(st)); + } + + // Drain the read side, as otherwise gRPC Finish() will hang. We + // only call Finish() when the client closes the writer or the + // reader finishes, so it's OK to assume the client no longer + // wants to read and drain the read side. (If the client wants to + // indicate that it is done writing, but not done reading, it + // should use DoneWriting. + ReadT message; + while (internal::ReadPayload(stream_.get(), &message)) { + // Drain the read side to avoid gRPC hanging in Finish() + } + + server_status_ = internal::FromGrpcStatus(stream_->Finish(), &rpc_->context); + finished_ = true; + + return MergeStatus(std::move(st)); + } + + private: + Status MergeStatus(Status&& st) { + if (server_status_.ok()) { + return std::move(st); + } + return Status::FromDetailAndArgs( + server_status_.code(), server_status_.detail(), server_status_.message(), + ". Client context: ", st.ToString(), + ". gRPC client debug context: ", rpc_->context.debug_error_string()); + } + + std::shared_ptr<ClientRpc> rpc_; + std::shared_ptr<Stream> stream_; + bool finished_; + Status server_status_; +}; + +/// Helper that manages \a Finish() of a read-write gRPC stream. +/// +/// This also calls \a WritesDone() and protects itself with a mutex +/// to enable sharing between the reader and writer. 
+template <typename Stream, typename ReadT> +class FinishableWritableStream : public FinishableStream<Stream, ReadT> { + public: + FinishableWritableStream(std::shared_ptr<ClientRpc> rpc, + std::shared_ptr<std::mutex> read_mutex, + std::shared_ptr<Stream> stream) + : FinishableStream<Stream, ReadT>(rpc, stream), + finish_mutex_(), + read_mutex_(read_mutex), + done_writing_(false) {} + virtual ~FinishableWritableStream() = default; + + /// \brief Indicate to gRPC that the write half of the stream is done. + Status DoneWriting() { + // This is only used by the writer side of a stream, so it need + // not be protected with a lock. + if (done_writing_) { + return Status::OK(); + } + done_writing_ = true; + if (!this->stream()->WritesDone()) { + // Error happened, try to close the stream to get more detailed info + return Finish(MakeFlightError(FlightStatusCode::Internal, + "Could not flush pending record batches")); + } + return Status::OK(); + } + + Status Finish(Status st) override { + // This may be used concurrently by reader/writer side of a + // stream, so it needs to be protected. + std::lock_guard<std::mutex> guard(finish_mutex_); + + // Now that we're shared between a reader and writer, we need to + // protect ourselves from being called while there's an + // outstanding read. + std::unique_lock<std::mutex> read_guard(*read_mutex_, std::try_to_lock); + if (!read_guard.owns_lock()) { + return MakeFlightError( + FlightStatusCode::Internal, + "Cannot close stream with pending read operation. Client context: " + + st.ToString()); + } + + // Try to flush pending writes. Don't use our WritesDone() to + // avoid recursion. + bool finished_writes = done_writing_ || this->stream()->WritesDone(); + done_writing_ = true; + + st = FinishableStream<Stream, ReadT>::Finish(std::move(st)); + + if (!finished_writes) { + return Status::FromDetailAndArgs( + st.code(), st.detail(), st.message(), + ". 
Additionally, could not finish writing record batches before closing"); + } + return st; + } + + private: + std::mutex finish_mutex_; + std::shared_ptr<std::mutex> read_mutex_; + bool done_writing_; +}; + +class GrpcAddCallHeaders : public AddCallHeaders { + public: + explicit GrpcAddCallHeaders(std::multimap<grpc::string, grpc::string>* metadata) + : metadata_(metadata) {} + ~GrpcAddCallHeaders() override = default; + + void AddHeader(const std::string& key, const std::string& value) override { + metadata_->insert(std::make_pair(key, value)); + } + + private: + std::multimap<grpc::string, grpc::string>* metadata_; +}; + +class GrpcClientInterceptorAdapter : public grpc::experimental::Interceptor { + public: + explicit GrpcClientInterceptorAdapter( + std::vector<std::unique_ptr<ClientMiddleware>> middleware) + : middleware_(std::move(middleware)), received_headers_(false) {} + + void Intercept(grpc::experimental::InterceptorBatchMethods* methods) { + using InterceptionHookPoints = grpc::experimental::InterceptionHookPoints; + if (methods->QueryInterceptionHookPoint( + InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) { + GrpcAddCallHeaders add_headers(methods->GetSendInitialMetadata()); + for (const auto& middleware : middleware_) { + middleware->SendingHeaders(&add_headers); + } + } + + if (methods->QueryInterceptionHookPoint( + InterceptionHookPoints::POST_RECV_INITIAL_METADATA)) { + if (!methods->GetRecvInitialMetadata()->empty()) { + ReceivedHeaders(*methods->GetRecvInitialMetadata()); + } + } + + if (methods->QueryInterceptionHookPoint(InterceptionHookPoints::POST_RECV_STATUS)) { + DCHECK_NE(nullptr, methods->GetRecvStatus()); + DCHECK_NE(nullptr, methods->GetRecvTrailingMetadata()); + ReceivedHeaders(*methods->GetRecvTrailingMetadata()); + const Status status = internal::FromGrpcStatus(*methods->GetRecvStatus()); + for (const auto& middleware : middleware_) { + middleware->CallCompleted(status); + } + } + + methods->Proceed(); + } + + private: + void ReceivedHeaders( + const std::multimap<grpc::string_ref, grpc::string_ref>& metadata) { + if (received_headers_) { + return; + } + received_headers_ = true; + CallHeaders headers; + for (const auto& entry : metadata) { + headers.insert({util::string_view(entry.first.data(), entry.first.length()), + util::string_view(entry.second.data(), entry.second.length())}); + } + for (const auto& middleware : middleware_) { + middleware->ReceivedHeaders(headers); + } + } + + std::vector<std::unique_ptr<ClientMiddleware>> middleware_; + // When communicating with a gRPC-Java server, the server may not + // send back headers if the call fails right away. Instead, the + // headers will be consolidated into the trailers. We don't want to + // call the client middleware callback twice, so instead track + // whether we saw headers - if not, then we need to check trailers. 
+ bool received_headers_; +}; + +class GrpcClientInterceptorAdapterFactory + : public grpc::experimental::ClientInterceptorFactoryInterface { + public: + GrpcClientInterceptorAdapterFactory( + std::vector<std::shared_ptr<ClientMiddlewareFactory>> middleware) + : middleware_(middleware) {} + + grpc::experimental::Interceptor* CreateClientInterceptor( + grpc::experimental::ClientRpcInfo* info) override { + std::vector<std::unique_ptr<ClientMiddleware>> middleware; + + FlightMethod flight_method = FlightMethod::Invalid; + util::string_view method(info->method()); + if (method.ends_with("/Handshake")) { + flight_method = FlightMethod::Handshake; + } else if (method.ends_with("/ListFlights")) { + flight_method = FlightMethod::ListFlights; + } else if (method.ends_with("/GetFlightInfo")) { + flight_method = FlightMethod::GetFlightInfo; + } else if (method.ends_with("/GetSchema")) { + flight_method = FlightMethod::GetSchema; + } else if (method.ends_with("/DoGet")) { + flight_method = FlightMethod::DoGet; + } else if (method.ends_with("/DoPut")) { + flight_method = FlightMethod::DoPut; + } else if (method.ends_with("/DoExchange")) { + flight_method = FlightMethod::DoExchange; + } else if (method.ends_with("/DoAction")) { + flight_method = FlightMethod::DoAction; + } else if (method.ends_with("/ListActions")) { + flight_method = FlightMethod::ListActions; + } else { + DCHECK(false) << "Unknown Flight method: " << info->method(); + } + + const CallInfo flight_info{flight_method}; + for (auto& factory : middleware_) { + std::unique_ptr<ClientMiddleware> instance; + factory->StartCall(flight_info, &instance); + if (instance) { + middleware.push_back(std::move(instance)); + } + } + return new GrpcClientInterceptorAdapter(std::move(middleware)); + } + + private: + std::vector<std::shared_ptr<ClientMiddlewareFactory>> middleware_; +}; + +class GrpcClientAuthSender : public ClientAuthSender { + public: + explicit GrpcClientAuthSender( + std::shared_ptr< + grpc::ClientReaderWriter<pb::HandshakeRequest, pb::HandshakeResponse>> + stream) + : stream_(stream) {} + + Status Write(const std::string& token) override { + pb::HandshakeRequest response; + response.set_payload(token); + if (stream_->Write(response)) { + return Status::OK(); + } + return internal::FromGrpcStatus(stream_->Finish()); + } + + private: + std::shared_ptr<grpc::ClientReaderWriter<pb::HandshakeRequest, pb::HandshakeResponse>> + stream_; +}; + +class GrpcClientAuthReader : public ClientAuthReader { + public: + explicit GrpcClientAuthReader( + std::shared_ptr< + grpc::ClientReaderWriter<pb::HandshakeRequest, pb::HandshakeResponse>> + stream) + : stream_(stream) {} + + Status Read(std::string* token) override { + pb::HandshakeResponse request; + if (stream_->Read(&request)) { + *token = std::move(*request.mutable_payload()); + return Status::OK(); + } + return internal::FromGrpcStatus(stream_->Finish()); + } + + private: + std::shared_ptr<grpc::ClientReaderWriter<pb::HandshakeRequest, pb::HandshakeResponse>> + stream_; +}; + +// An ipc::MessageReader that adapts any readable gRPC stream +// returning FlightData. 
+template <typename Reader> +class GrpcIpcMessageReader : public ipc::MessageReader { + public: + GrpcIpcMessageReader( + std::shared_ptr<ClientRpc> rpc, std::shared_ptr<std::mutex> read_mutex, + std::shared_ptr<FinishableStream<Reader, internal::FlightData>> stream, + std::shared_ptr<internal::PeekableFlightDataReader<std::shared_ptr<Reader>>> + peekable_reader, + std::shared_ptr<Buffer>* app_metadata) + : rpc_(rpc), + read_mutex_(read_mutex), + stream_(std::move(stream)), + peekable_reader_(peekable_reader), + app_metadata_(app_metadata), + stream_finished_(false) {} + + ::arrow::Result<std::unique_ptr<ipc::Message>> ReadNextMessage() override { + if (stream_finished_) { + return nullptr; + } + internal::FlightData* data; + { + auto guard = read_mutex_ ? std::unique_lock<std::mutex>(*read_mutex_) + : std::unique_lock<std::mutex>(); + peekable_reader_->Next(&data); + } + if (!data) { + stream_finished_ = true; + return stream_->Finish(Status::OK()); + } + // Validate IPC message + auto result = data->OpenMessage(); + if (!result.ok()) { + return stream_->Finish(std::move(result).status()); + } + *app_metadata_ = std::move(data->app_metadata); + return result; + } + + private: + // The RPC context lifetime must be coupled to the ClientReader + std::shared_ptr<ClientRpc> rpc_; + // Guard reads with a mutex to prevent concurrent reads if the write + // side calls Finish(). Nullable as DoGet doesn't need this. + std::shared_ptr<std::mutex> read_mutex_; + std::shared_ptr<FinishableStream<Reader, internal::FlightData>> stream_; + std::shared_ptr<internal::PeekableFlightDataReader<std::shared_ptr<Reader>>> + peekable_reader_; + // A reference to GrpcStreamReader.app_metadata_. That class + // can't access the app metadata because when it Peek()s the stream, + // it may be looking at a dictionary batch, not the record + // batch. Updating it here ensures the reader is always updated with + // the last metadata message read. 
+ std::shared_ptr<Buffer>* app_metadata_; + bool stream_finished_; +}; + +/// The implementation of the public-facing API for reading from a +/// FlightData stream +template <typename Reader> +class GrpcStreamReader : public FlightStreamReader { + public: + GrpcStreamReader(std::shared_ptr<ClientRpc> rpc, std::shared_ptr<std::mutex> read_mutex, + const ipc::IpcReadOptions& options, StopToken stop_token, + std::shared_ptr<FinishableStream<Reader, internal::FlightData>> stream) + : rpc_(rpc), + read_mutex_(read_mutex), + options_(options), + stop_token_(std::move(stop_token)), + stream_(stream), + peekable_reader_(new internal::PeekableFlightDataReader<std::shared_ptr<Reader>>( + stream->stream())), + app_metadata_(nullptr) {} + + Status EnsureDataStarted() { + if (!batch_reader_) { + bool skipped_to_data = false; + { + auto guard = TakeGuard(); + skipped_to_data = peekable_reader_->SkipToData(); + } + // peek() until we find the first data message; discard metadata + if (!skipped_to_data) { + return OverrideWithServerError(MakeFlightError( + FlightStatusCode::Internal, "Server never sent a data message")); + } + + auto message_reader = + std::unique_ptr<ipc::MessageReader>(new GrpcIpcMessageReader<Reader>( + rpc_, read_mutex_, stream_, peekable_reader_, &app_metadata_)); + auto result = + ipc::RecordBatchStreamReader::Open(std::move(message_reader), options_); + RETURN_NOT_OK(OverrideWithServerError(std::move(result).Value(&batch_reader_))); + } + return Status::OK(); + } + arrow::Result<std::shared_ptr<Schema>> GetSchema() override { + RETURN_NOT_OK(EnsureDataStarted()); + return batch_reader_->schema(); + } + Status Next(FlightStreamChunk* out) override { + internal::FlightData* data; + { + auto guard = TakeGuard(); + peekable_reader_->Peek(&data); + } + if (!data) { + out->app_metadata = nullptr; + out->data = nullptr; + return stream_->Finish(Status::OK()); + } + + if (!data->metadata) { + // Metadata-only (data->metadata is the IPC header) + out->app_metadata = data->app_metadata; + out->data = nullptr; + { + auto guard = TakeGuard(); + peekable_reader_->Next(&data); + } + return Status::OK(); + } + + if (!batch_reader_) { + RETURN_NOT_OK(EnsureDataStarted()); + // Re-peek here since EnsureDataStarted() advances the stream + return Next(out); + } + RETURN_NOT_OK(batch_reader_->ReadNext(&out->data)); + out->app_metadata = std::move(app_metadata_); + return Status::OK(); + } + Status ReadAll(std::vector<std::shared_ptr<RecordBatch>>* batches) override { + return ReadAll(batches, stop_token_); + } + Status ReadAll(std::vector<std::shared_ptr<RecordBatch>>* batches, + const StopToken& stop_token) override { + FlightStreamChunk chunk; + + while (true) { + if (stop_token.IsStopRequested()) { + Cancel(); + return stop_token.Poll(); + } + RETURN_NOT_OK(Next(&chunk)); + if (!chunk.data) break; + batches->emplace_back(std::move(chunk.data)); + } + return Status::OK(); + } + Status ReadAll(std::shared_ptr<Table>* table) override { + return ReadAll(table, stop_token_); + } + using FlightStreamReader::ReadAll; + void Cancel() override { rpc_->context.TryCancel(); } + + private: + std::unique_lock<std::mutex> TakeGuard() { + return read_mutex_ ? 
std::unique_lock<std::mutex>(*read_mutex_) + : std::unique_lock<std::mutex>(); + } + + Status OverrideWithServerError(Status&& st) { + if (st.ok()) { + return std::move(st); + } + return stream_->Finish(std::move(st)); + } + + friend class GrpcIpcMessageReader<Reader>; + std::shared_ptr<ClientRpc> rpc_; + // Guard reads with a lock to prevent Finish()/Close() from being + // called on the writer while the reader has a pending + // read. Nullable, as DoGet() doesn't need this. + std::shared_ptr<std::mutex> read_mutex_; + ipc::IpcReadOptions options_; + StopToken stop_token_; + std::shared_ptr<FinishableStream<Reader, internal::FlightData>> stream_; + std::shared_ptr<internal::PeekableFlightDataReader<std::shared_ptr<Reader>>> + peekable_reader_; + std::shared_ptr<ipc::RecordBatchReader> batch_reader_; + std::shared_ptr<Buffer> app_metadata_; +}; + +// The next two classes implement writing to a FlightData stream. +// Similarly to the read side, we want to reuse the implementation of +// RecordBatchWriter. As a result, these two classes are intertwined +// in order to pass application metadata "through" RecordBatchWriter. +// In order to get application-specific metadata to the +// IpcPayloadWriter, DoPutPayloadWriter takes a pointer to +// GrpcStreamWriter. GrpcStreamWriter updates a metadata field on +// write; DoPutPayloadWriter reads that metadata field to determine +// what to write. + +template <typename ProtoReadT, typename FlightReadT> +class DoPutPayloadWriter; + +template <typename ProtoReadT, typename FlightReadT> +class GrpcStreamWriter : public FlightStreamWriter { + public: + ~GrpcStreamWriter() override = default; + + using GrpcStream = grpc::ClientReaderWriter<pb::FlightData, ProtoReadT>; + + explicit GrpcStreamWriter( + const FlightDescriptor& descriptor, std::shared_ptr<ClientRpc> rpc, + int64_t write_size_limit_bytes, const ipc::IpcWriteOptions& options, + std::shared_ptr<FinishableWritableStream<GrpcStream, FlightReadT>> writer) + : app_metadata_(nullptr), + batch_writer_(nullptr), + writer_(std::move(writer)), + rpc_(std::move(rpc)), + write_size_limit_bytes_(write_size_limit_bytes), + options_(options), + descriptor_(descriptor), + writer_closed_(false) {} + + static Status Open( + const FlightDescriptor& descriptor, std::shared_ptr<Schema> schema, + const ipc::IpcWriteOptions& options, std::shared_ptr<ClientRpc> rpc, + int64_t write_size_limit_bytes, + std::shared_ptr<FinishableWritableStream<GrpcStream, FlightReadT>> writer, + std::unique_ptr<FlightStreamWriter>* out); + + Status CheckStarted() { + if (!batch_writer_) { + return Status::Invalid("Writer not initialized. Call Begin() with a schema."); + } + return Status::OK(); + } + + Status Begin(const std::shared_ptr<Schema>& schema, + const ipc::IpcWriteOptions& options) override { + if (batch_writer_) { + return Status::Invalid("This writer has already been started."); + } + std::unique_ptr<ipc::internal::IpcPayloadWriter> payload_writer( + new DoPutPayloadWriter<ProtoReadT, FlightReadT>( + descriptor_, std::move(rpc_), write_size_limit_bytes_, writer_, this)); + // XXX: this does not actually write the message to the stream. + // See Close(). 
+ ARROW_ASSIGN_OR_RAISE(batch_writer_, ipc::internal::OpenRecordBatchWriter( + std::move(payload_writer), schema, options)); + return Status::OK(); + } + + Status Begin(const std::shared_ptr<Schema>& schema) override { + return Begin(schema, options_); + } + + Status WriteRecordBatch(const RecordBatch& batch) override { + RETURN_NOT_OK(CheckStarted()); + return WriteWithMetadata(batch, nullptr); + } + + Status WriteMetadata(std::shared_ptr<Buffer> app_metadata) override { + FlightPayload payload{}; + payload.app_metadata = app_metadata; + auto status = internal::WritePayload(payload, writer_->stream().get()); + if (status.IsIOError()) { + return writer_->Finish(MakeFlightError(FlightStatusCode::Internal, + "Could not write metadata to stream")); + } + return status; + } + + Status WriteWithMetadata(const RecordBatch& batch, + std::shared_ptr<Buffer> app_metadata) override { + RETURN_NOT_OK(CheckStarted()); + app_metadata_ = app_metadata; + return batch_writer_->WriteRecordBatch(batch); + } + + Status DoneWriting() override { + // Do not CheckStarted - DoneWriting applies to data and metadata + if (batch_writer_) { + // Close the writer if we have one; this will force it to flush any + // remaining data, before we close the write side of the stream. + writer_closed_ = true; + Status st = batch_writer_->Close(); + if (!st.ok()) { + return writer_->Finish(std::move(st)); + } + } + return writer_->DoneWriting(); + } + + Status Close() override { + // Do not CheckStarted - Close applies to data and metadata + if (batch_writer_ && !writer_closed_) { + // This is important! Close() calls + // IpcPayloadWriter::CheckStarted() which will force the initial + // schema message to be written to the stream. This is required + // to unstick the server, else the client and the server end up + // waiting for each other. This happens if the client never + // wrote anything before calling Close(). + writer_closed_ = true; + return writer_->Finish(batch_writer_->Close()); + } + return writer_->Finish(Status::OK()); + } + + ipc::WriteStats stats() const override { + ARROW_CHECK_NE(batch_writer_, nullptr); + return batch_writer_->stats(); + } + + private: + friend class DoPutPayloadWriter<ProtoReadT, FlightReadT>; + std::shared_ptr<Buffer> app_metadata_; + std::unique_ptr<ipc::RecordBatchWriter> batch_writer_; + std::shared_ptr<FinishableWritableStream<GrpcStream, FlightReadT>> writer_; + + // Fields used to lazy-initialize the IpcPayloadWriter. They're + // invalid once Begin() is called. + std::shared_ptr<ClientRpc> rpc_; + int64_t write_size_limit_bytes_; + ipc::IpcWriteOptions options_; + FlightDescriptor descriptor_; + bool writer_closed_; +}; + +/// A IpcPayloadWriter implementation that writes to a gRPC stream of +/// FlightData messages. 
+template <typename ProtoReadT, typename FlightReadT> +class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter { + public: + using GrpcStream = grpc::ClientReaderWriter<pb::FlightData, ProtoReadT>; + + DoPutPayloadWriter( + const FlightDescriptor& descriptor, std::shared_ptr<ClientRpc> rpc, + int64_t write_size_limit_bytes, + std::shared_ptr<FinishableWritableStream<GrpcStream, FlightReadT>> writer, + GrpcStreamWriter<ProtoReadT, FlightReadT>* stream_writer) + : descriptor_(descriptor), + rpc_(rpc), + write_size_limit_bytes_(write_size_limit_bytes), + writer_(std::move(writer)), + first_payload_(true), + stream_writer_(stream_writer) {} + + ~DoPutPayloadWriter() override = default; + + Status Start() override { return Status::OK(); } + + Status WritePayload(const ipc::IpcPayload& ipc_payload) override { + FlightPayload payload; + payload.ipc_message = ipc_payload; + + if (first_payload_) { + // First Flight message needs to encore the Flight descriptor + if (ipc_payload.type != ipc::MessageType::SCHEMA) { + return Status::Invalid("First IPC message should be schema"); + } + // Write the descriptor to begin with + RETURN_NOT_OK(internal::ToPayload(descriptor_, &payload.descriptor)); + first_payload_ = false; + } else if (ipc_payload.type == ipc::MessageType::RECORD_BATCH && + stream_writer_->app_metadata_) { + payload.app_metadata = std::move(stream_writer_->app_metadata_); + } + + if (write_size_limit_bytes_ > 0) { + // Check if the total size is greater than the user-configured + // soft-limit. + int64_t size = ipc_payload.body_length + ipc_payload.metadata->size(); + if (payload.descriptor) { + size += payload.descriptor->size(); + } + if (payload.app_metadata) { + size += payload.app_metadata->size(); + } + if (size > write_size_limit_bytes_) { + return arrow::Status( + arrow::StatusCode::Invalid, "IPC payload size exceeded soft limit", + std::make_shared<FlightWriteSizeStatusDetail>(write_size_limit_bytes_, size)); + } + } + + auto status = internal::WritePayload(payload, writer_->stream().get()); + if (status.IsIOError()) { + return writer_->Finish(MakeFlightError(FlightStatusCode::Internal, + "Could not write record batch to stream")); + } + return status; + } + + Status Close() override { + // Closing is handled one layer up in GrpcStreamWriter::Close + return Status::OK(); + } + + protected: + const FlightDescriptor descriptor_; + std::shared_ptr<ClientRpc> rpc_; + int64_t write_size_limit_bytes_; + std::shared_ptr<FinishableWritableStream<GrpcStream, FlightReadT>> writer_; + bool first_payload_; + GrpcStreamWriter<ProtoReadT, FlightReadT>* stream_writer_; +}; + +template <typename ProtoReadT, typename FlightReadT> +Status GrpcStreamWriter<ProtoReadT, FlightReadT>::Open( + const FlightDescriptor& descriptor, + std::shared_ptr<Schema> schema, // this schema is nullable + const ipc::IpcWriteOptions& options, std::shared_ptr<ClientRpc> rpc, + int64_t write_size_limit_bytes, + std::shared_ptr<FinishableWritableStream<GrpcStream, FlightReadT>> writer, + std::unique_ptr<FlightStreamWriter>* out) { + std::unique_ptr<GrpcStreamWriter<ProtoReadT, FlightReadT>> instance( + new GrpcStreamWriter<ProtoReadT, FlightReadT>( + descriptor, std::move(rpc), write_size_limit_bytes, options, writer)); + if (schema) { + // The schema was provided (DoPut). Eagerly write the schema and + // descriptor together as the first message. + RETURN_NOT_OK(instance->Begin(schema, options)); + } else { + // The schema was not provided (DoExchange). 
Eagerly write just + // the descriptor as the first message. Note that if the client + // calls Begin() to send data, we'll send a redundant descriptor. + FlightPayload payload{}; + RETURN_NOT_OK(internal::ToPayload(descriptor, &payload.descriptor)); + auto status = internal::WritePayload(payload, instance->writer_->stream().get()); + if (status.IsIOError()) { + return writer->Finish(MakeFlightError(FlightStatusCode::Internal, + "Could not write descriptor to stream")); + } + RETURN_NOT_OK(status); + } + *out = std::move(instance); + return Status::OK(); +} + +FlightMetadataReader::~FlightMetadataReader() = default; + +class GrpcMetadataReader : public FlightMetadataReader { + public: + explicit GrpcMetadataReader( + std::shared_ptr<grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>> reader, + std::shared_ptr<std::mutex> read_mutex) + : reader_(reader), read_mutex_(read_mutex) {} + + Status ReadMetadata(std::shared_ptr<Buffer>* out) override { + std::lock_guard<std::mutex> guard(*read_mutex_); + pb::PutResult message; + if (reader_->Read(&message)) { + *out = Buffer::FromString(std::move(*message.mutable_app_metadata())); + } else { + // Stream finished + *out = nullptr; + } + return Status::OK(); + } + + private: + std::shared_ptr<grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>> reader_; + std::shared_ptr<std::mutex> read_mutex_; +}; + +namespace { +// Dummy self-signed certificate to be used because TlsCredentials +// requires root CA certs, even if you are skipping server +// verification. +#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) +constexpr char kDummyRootCert[] = + "-----BEGIN CERTIFICATE-----\n" + "MIICwzCCAaugAwIBAgIJAM12DOkcaqrhMA0GCSqGSIb3DQEBBQUAMBQxEjAQBgNV\n" + "BAMTCWxvY2FsaG9zdDAeFw0yMDEwMDcwODIyNDFaFw0zMDEwMDUwODIyNDFaMBQx\n" + "EjAQBgNVBAMTCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoC\n" + "ggEBALjJ8KPEpF0P4GjMPrJhjIBHUL0AX9E4oWdgJRCSFkPKKEWzQabTQBikMOhI\n" + "W4VvBMaHEBuECE5OEyrzDRiAO354I4F4JbBfxMOY8NIW0uWD6THWm2KkCzoZRIPW\n" + "yZL6dN+mK6cEH+YvbNuy5ZQGNjGG43tyiXdOCAc4AI9POeTtjdMpbbpR2VY4Ad/E\n" + "oTEiS3gNnN7WIAdgMhCJxjzvPwKszV3f7pwuTHzFMsuHLKr6JeaVUYfbi4DxxC8Z\n" + "k6PF6dLlLf3ngTSLBJyaXP1BhKMvz0TaMK3F0y2OGwHM9J8np2zWjTlNVEzffQZx\n" + "SWMOQManlJGs60xYx9KCPJMZZsMCAwEAAaMYMBYwFAYDVR0RBA0wC4IJbG9jYWxo\n" + "b3N0MA0GCSqGSIb3DQEBBQUAA4IBAQC0LrmbcNKgO+D50d/wOc+vhi9K04EZh8bg\n" + "WYAK1kLOT4eShbzqWGV/1EggY4muQ6ypSELCLuSsg88kVtFQIeRilA6bHFqQSj6t\n" + "sqgh2cWsMwyllCtmX6Maf3CLb2ZdoJlqUwdiBdrbIbuyeAZj3QweCtLKGSQzGDyI\n" + "KH7G8nC5d0IoRPiCMB6RnMMKsrhviuCdWbAFHop7Ff36JaOJ8iRa2sSf2OXE8j/5\n" + "obCXCUvYHf4Zw27JcM2AnnQI9VJLnYxis83TysC5s2Z7t0OYNS9kFmtXQbUNlmpS\n" + "doQ/Eu47vWX7S0TXeGziGtbAOKxbHE0BGGPDOAB/jGW/JVbeTiXY\n" + "-----END CERTIFICATE-----\n"; +#endif +} // namespace +class FlightClient::FlightClientImpl { + public: + Status Connect(const Location& location, const FlightClientOptions& options) { + const std::string& scheme = location.scheme(); + + std::stringstream grpc_uri; + std::shared_ptr<grpc::ChannelCredentials> creds; + if (scheme == kSchemeGrpc || scheme == kSchemeGrpcTcp || scheme == kSchemeGrpcTls) { + grpc_uri << arrow::internal::UriEncodeHost(location.uri_->host()) << ':' + << location.uri_->port_text(); + + if (scheme == kSchemeGrpcTls) { + if (options.disable_server_verification) { +#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) + namespace ge = GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS; + + // A callback to supply to TlsCredentialsOptions that accepts any server + // arguments. 
+ struct NoOpTlsAuthorizationCheck + : public ge::TlsServerAuthorizationCheckInterface { + int Schedule(ge::TlsServerAuthorizationCheckArg* arg) override { + arg->set_success(1); + arg->set_status(GRPC_STATUS_OK); + return 0; + } + }; + auto server_authorization_check = std::make_shared<NoOpTlsAuthorizationCheck>(); + noop_auth_check_ = std::make_shared<ge::TlsServerAuthorizationCheckConfig>( + server_authorization_check); +#if defined(GRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS) + auto certificate_provider = + std::make_shared<grpc::experimental::StaticDataCertificateProvider>( + kDummyRootCert); +#if defined(GRPC_USE_TLS_CHANNEL_CREDENTIALS_OPTIONS_ROOT_CERTS) + grpc::experimental::TlsChannelCredentialsOptions tls_options( + certificate_provider); +#else + // While gRPC >= 1.36 does not require a root cert (it has a default) + // in practice the path it hardcodes is broken. See grpc/grpc#21655. + grpc::experimental::TlsChannelCredentialsOptions tls_options; + tls_options.set_certificate_provider(certificate_provider); +#endif + tls_options.watch_root_certs(); + tls_options.set_root_cert_name("dummy"); + tls_options.set_server_verification_option( + grpc_tls_server_verification_option::GRPC_TLS_SKIP_ALL_SERVER_VERIFICATION); + tls_options.set_server_authorization_check_config(noop_auth_check_); +#elif defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) + auto materials_config = std::make_shared<ge::TlsKeyMaterialsConfig>(); + materials_config->set_pem_root_certs(kDummyRootCert); + ge::TlsCredentialsOptions tls_options( + GRPC_SSL_DONT_REQUEST_CLIENT_CERTIFICATE, + GRPC_TLS_SKIP_ALL_SERVER_VERIFICATION, materials_config, + std::shared_ptr<ge::TlsCredentialReloadConfig>(), noop_auth_check_); +#endif + creds = ge::TlsCredentials(tls_options); +#else + return Status::NotImplemented( + "Using encryption with server verification disabled is unsupported. " + "Please use a release of Arrow Flight built with gRPC 1.27 or higher."); +#endif + } else { + grpc::SslCredentialsOptions ssl_options; + if (!options.tls_root_certs.empty()) { + ssl_options.pem_root_certs = options.tls_root_certs; + } + if (!options.cert_chain.empty()) { + ssl_options.pem_cert_chain = options.cert_chain; + } + if (!options.private_key.empty()) { + ssl_options.pem_private_key = options.private_key; + } + creds = grpc::SslCredentials(ssl_options); + } + } else { + creds = grpc::InsecureChannelCredentials(); + } + } else if (scheme == kSchemeGrpcUnix) { + grpc_uri << "unix://" << location.uri_->path(); + creds = grpc::InsecureChannelCredentials(); + } else { + return Status::NotImplemented("Flight scheme " + scheme + " is not supported."); + } + + grpc::ChannelArguments args; + // We can't set the same config value twice, so for values where + // we want to set defaults, keep them in a map and update them; + // then update them all at once + std::unordered_map<std::string, int> default_args; + // Try to reconnect quickly at first, in case the server is still starting up + default_args[GRPC_ARG_INITIAL_RECONNECT_BACKOFF_MS] = 100; + // Receive messages of any size + default_args[GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH] = -1; + // Setting this arg enables each client to open it's own TCP connection to server, + // not sharing one single connection, which becomes bottleneck under high load. + default_args[GRPC_ARG_USE_LOCAL_SUBCHANNEL_POOL] = 1; + + if (options.override_hostname != "") { + args.SetSslTargetNameOverride(options.override_hostname); + } + + // Allow setting generic gRPC options. 
+ for (const auto& arg : options.generic_options) { + if (util::holds_alternative<int>(arg.second)) { + default_args[arg.first] = util::get<int>(arg.second); + } else if (util::holds_alternative<std::string>(arg.second)) { + args.SetString(arg.first, util::get<std::string>(arg.second)); + } + // Otherwise unimplemented + } + for (const auto& pair : default_args) { + args.SetInt(pair.first, pair.second); + } + + std::vector<std::unique_ptr<grpc::experimental::ClientInterceptorFactoryInterface>> + interceptors; + interceptors.emplace_back( + new GrpcClientInterceptorAdapterFactory(std::move(options.middleware))); + + stub_ = pb::FlightService::NewStub( + grpc::experimental::CreateCustomChannelWithInterceptors( + grpc_uri.str(), creds, args, std::move(interceptors))); + + write_size_limit_bytes_ = options.write_size_limit_bytes; + return Status::OK(); + } + + Status Authenticate(const FlightCallOptions& options, + std::unique_ptr<ClientAuthHandler> auth_handler) { + auth_handler_ = std::move(auth_handler); + ClientRpc rpc(options); + std::shared_ptr<grpc::ClientReaderWriter<pb::HandshakeRequest, pb::HandshakeResponse>> + stream = stub_->Handshake(&rpc.context); + GrpcClientAuthSender outgoing{stream}; + GrpcClientAuthReader incoming{stream}; + RETURN_NOT_OK(auth_handler_->Authenticate(&outgoing, &incoming)); + // Explicitly close our side of the connection + bool finished_writes = stream->WritesDone(); + RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish(), &rpc.context)); + if (!finished_writes) { + return MakeFlightError(FlightStatusCode::Internal, + "Could not finish writing before closing"); + } + return Status::OK(); + } + + arrow::Result<std::pair<std::string, std::string>> AuthenticateBasicToken( + const FlightCallOptions& options, const std::string& username, + const std::string& password) { + // Add basic auth headers to outgoing headers. + ClientRpc rpc(options); + internal::AddBasicAuthHeaders(&rpc.context, username, password); + + std::shared_ptr<grpc::ClientReaderWriter<pb::HandshakeRequest, pb::HandshakeResponse>> + stream = stub_->Handshake(&rpc.context); + GrpcClientAuthSender outgoing{stream}; + GrpcClientAuthReader incoming{stream}; + + // Explicitly close our side of the connection. + bool finished_writes = stream->WritesDone(); + RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish(), &rpc.context)); + if (!finished_writes) { + return MakeFlightError(FlightStatusCode::Internal, + "Could not finish writing before closing"); + } + + // Grab bearer token from incoming headers. 
+ return internal::GetBearerTokenHeader(rpc.context); + } + + Status ListFlights(const FlightCallOptions& options, const Criteria& criteria, + std::unique_ptr<FlightListing>* listing) { + pb::Criteria pb_criteria; + RETURN_NOT_OK(internal::ToProto(criteria, &pb_criteria)); + + ClientRpc rpc(options); + RETURN_NOT_OK(rpc.SetToken(auth_handler_.get())); + std::unique_ptr<grpc::ClientReader<pb::FlightInfo>> stream( + stub_->ListFlights(&rpc.context, pb_criteria)); + + std::vector<FlightInfo> flights; + + pb::FlightInfo pb_info; + while (!options.stop_token.IsStopRequested() && stream->Read(&pb_info)) { + FlightInfo::Data info_data; + RETURN_NOT_OK(internal::FromProto(pb_info, &info_data)); + flights.emplace_back(std::move(info_data)); + } + if (options.stop_token.IsStopRequested()) rpc.context.TryCancel(); + RETURN_NOT_OK(options.stop_token.Poll()); + listing->reset(new SimpleFlightListing(std::move(flights))); + return internal::FromGrpcStatus(stream->Finish(), &rpc.context); + } + + Status DoAction(const FlightCallOptions& options, const Action& action, + std::unique_ptr<ResultStream>* results) { + pb::Action pb_action; + RETURN_NOT_OK(internal::ToProto(action, &pb_action)); + + ClientRpc rpc(options); + RETURN_NOT_OK(rpc.SetToken(auth_handler_.get())); + std::unique_ptr<grpc::ClientReader<pb::Result>> stream( + stub_->DoAction(&rpc.context, pb_action)); + + pb::Result pb_result; + + std::vector<Result> materialized_results; + while (!options.stop_token.IsStopRequested() && stream->Read(&pb_result)) { + Result result; + RETURN_NOT_OK(internal::FromProto(pb_result, &result)); + materialized_results.emplace_back(std::move(result)); + } + if (options.stop_token.IsStopRequested()) rpc.context.TryCancel(); + RETURN_NOT_OK(options.stop_token.Poll()); + + *results = std::unique_ptr<ResultStream>( + new SimpleResultStream(std::move(materialized_results))); + return internal::FromGrpcStatus(stream->Finish(), &rpc.context); + } + + Status ListActions(const FlightCallOptions& options, std::vector<ActionType>* types) { + pb::Empty empty; + + ClientRpc rpc(options); + RETURN_NOT_OK(rpc.SetToken(auth_handler_.get())); + std::unique_ptr<grpc::ClientReader<pb::ActionType>> stream( + stub_->ListActions(&rpc.context, empty)); + + pb::ActionType pb_type; + ActionType type; + while (!options.stop_token.IsStopRequested() && stream->Read(&pb_type)) { + RETURN_NOT_OK(internal::FromProto(pb_type, &type)); + types->emplace_back(std::move(type)); + } + if (options.stop_token.IsStopRequested()) rpc.context.TryCancel(); + RETURN_NOT_OK(options.stop_token.Poll()); + return internal::FromGrpcStatus(stream->Finish(), &rpc.context); + } + + Status GetFlightInfo(const FlightCallOptions& options, + const FlightDescriptor& descriptor, + std::unique_ptr<FlightInfo>* info) { + pb::FlightDescriptor pb_descriptor; + pb::FlightInfo pb_response; + + RETURN_NOT_OK(internal::ToProto(descriptor, &pb_descriptor)); + + ClientRpc rpc(options); + RETURN_NOT_OK(rpc.SetToken(auth_handler_.get())); + Status s = internal::FromGrpcStatus( + stub_->GetFlightInfo(&rpc.context, pb_descriptor, &pb_response), &rpc.context); + RETURN_NOT_OK(s); + + FlightInfo::Data info_data; + RETURN_NOT_OK(internal::FromProto(pb_response, &info_data)); + info->reset(new FlightInfo(std::move(info_data))); + return Status::OK(); + } + + Status GetSchema(const FlightCallOptions& options, const FlightDescriptor& descriptor, + std::unique_ptr<SchemaResult>* schema_result) { + pb::FlightDescriptor pb_descriptor; + pb::SchemaResult pb_response; + + 
RETURN_NOT_OK(internal::ToProto(descriptor, &pb_descriptor)); + + ClientRpc rpc(options); + RETURN_NOT_OK(rpc.SetToken(auth_handler_.get())); + Status s = internal::FromGrpcStatus( + stub_->GetSchema(&rpc.context, pb_descriptor, &pb_response), &rpc.context); + RETURN_NOT_OK(s); + + std::string str; + RETURN_NOT_OK(internal::FromProto(pb_response, &str)); + schema_result->reset(new SchemaResult(str)); + return Status::OK(); + } + + Status DoGet(const FlightCallOptions& options, const Ticket& ticket, + std::unique_ptr<FlightStreamReader>* out) { + using StreamReader = GrpcStreamReader<grpc::ClientReader<pb::FlightData>>; + pb::Ticket pb_ticket; + internal::ToProto(ticket, &pb_ticket); + + auto rpc = std::make_shared<ClientRpc>(options); + RETURN_NOT_OK(rpc->SetToken(auth_handler_.get())); + std::shared_ptr<grpc::ClientReader<pb::FlightData>> stream = + stub_->DoGet(&rpc->context, pb_ticket); + auto finishable_stream = std::make_shared< + FinishableStream<grpc::ClientReader<pb::FlightData>, internal::FlightData>>( + rpc, stream); + *out = std::unique_ptr<StreamReader>(new StreamReader( + rpc, nullptr, options.read_options, options.stop_token, finishable_stream)); + // Eagerly read the schema + return static_cast<StreamReader*>(out->get())->EnsureDataStarted(); + } + + Status DoPut(const FlightCallOptions& options, const FlightDescriptor& descriptor, + const std::shared_ptr<Schema>& schema, + std::unique_ptr<FlightStreamWriter>* out, + std::unique_ptr<FlightMetadataReader>* reader) { + using GrpcStream = grpc::ClientReaderWriter<pb::FlightData, pb::PutResult>; + using StreamWriter = GrpcStreamWriter<pb::PutResult, pb::PutResult>; + + auto rpc = std::make_shared<ClientRpc>(options); + RETURN_NOT_OK(rpc->SetToken(auth_handler_.get())); + std::shared_ptr<GrpcStream> stream = stub_->DoPut(&rpc->context); + // The writer drains the reader on close to avoid hanging inside + // gRPC. Concurrent reads are unsafe, so a mutex protects this operation. + std::shared_ptr<std::mutex> read_mutex = std::make_shared<std::mutex>(); + auto finishable_stream = + std::make_shared<FinishableWritableStream<GrpcStream, pb::PutResult>>( + rpc, read_mutex, stream); + *reader = + std::unique_ptr<FlightMetadataReader>(new GrpcMetadataReader(stream, read_mutex)); + return StreamWriter::Open(descriptor, schema, options.write_options, rpc, + write_size_limit_bytes_, finishable_stream, out); + } + + Status DoExchange(const FlightCallOptions& options, const FlightDescriptor& descriptor, + std::unique_ptr<FlightStreamWriter>* writer, + std::unique_ptr<FlightStreamReader>* reader) { + using GrpcStream = grpc::ClientReaderWriter<pb::FlightData, pb::FlightData>; + using StreamReader = GrpcStreamReader<GrpcStream>; + using StreamWriter = GrpcStreamWriter<pb::FlightData, internal::FlightData>; + + auto rpc = std::make_shared<ClientRpc>(options); + RETURN_NOT_OK(rpc->SetToken(auth_handler_.get())); + std::shared_ptr<grpc::ClientReaderWriter<pb::FlightData, pb::FlightData>> stream = + stub_->DoExchange(&rpc->context); + // The writer drains the reader on close to avoid hanging inside + // gRPC. Concurrent reads are unsafe, so a mutex protects this operation. 
+ std::shared_ptr<std::mutex> read_mutex = std::make_shared<std::mutex>(); + auto finishable_stream = + std::make_shared<FinishableWritableStream<GrpcStream, internal::FlightData>>( + rpc, read_mutex, stream); + *reader = std::unique_ptr<StreamReader>(new StreamReader( + rpc, read_mutex, options.read_options, options.stop_token, finishable_stream)); + // Do not eagerly read the schema. There may be metadata messages + // before any data is sent, or data may not be sent at all. + return StreamWriter::Open(descriptor, nullptr, options.write_options, rpc, + write_size_limit_bytes_, finishable_stream, writer); + } + + private: + std::unique_ptr<pb::FlightService::Stub> stub_; + std::shared_ptr<ClientAuthHandler> auth_handler_; +#if defined(GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS) + // Scope the TlsServerAuthorizationCheckConfig to be at the class instance level, since + // it gets created during Connect() and needs to persist to DoAction() calls. gRPC does + // not correctly increase the reference count of this object: + // https://github.com/grpc/grpc/issues/22287 + std::shared_ptr< + GRPC_NAMESPACE_FOR_TLS_CREDENTIALS_OPTIONS::TlsServerAuthorizationCheckConfig> + noop_auth_check_; +#endif + int64_t write_size_limit_bytes_; +}; + +FlightClient::FlightClient() { impl_.reset(new FlightClientImpl); } + +FlightClient::~FlightClient() {} + +Status FlightClient::Connect(const Location& location, + std::unique_ptr<FlightClient>* client) { + return Connect(location, FlightClientOptions::Defaults(), client); +} + +Status FlightClient::Connect(const Location& location, const FlightClientOptions& options, + std::unique_ptr<FlightClient>* client) { + client->reset(new FlightClient); + return (*client)->impl_->Connect(location, options); +} + +Status FlightClient::Authenticate(const FlightCallOptions& options, + std::unique_ptr<ClientAuthHandler> auth_handler) { + return impl_->Authenticate(options, std::move(auth_handler)); +} + +arrow::Result<std::pair<std::string, std::string>> FlightClient::AuthenticateBasicToken( + const FlightCallOptions& options, const std::string& username, + const std::string& password) { + return impl_->AuthenticateBasicToken(options, username, password); +} + +Status FlightClient::DoAction(const FlightCallOptions& options, const Action& action, + std::unique_ptr<ResultStream>* results) { + return impl_->DoAction(options, action, results); +} + +Status FlightClient::ListActions(const FlightCallOptions& options, + std::vector<ActionType>* actions) { + return impl_->ListActions(options, actions); +} + +Status FlightClient::GetFlightInfo(const FlightCallOptions& options, + const FlightDescriptor& descriptor, + std::unique_ptr<FlightInfo>* info) { + return impl_->GetFlightInfo(options, descriptor, info); +} + +Status FlightClient::GetSchema(const FlightCallOptions& options, + const FlightDescriptor& descriptor, + std::unique_ptr<SchemaResult>* schema_result) { + return impl_->GetSchema(options, descriptor, schema_result); +} + +Status FlightClient::ListFlights(std::unique_ptr<FlightListing>* listing) { + return ListFlights({}, {}, listing); +} + +Status FlightClient::ListFlights(const FlightCallOptions& options, + const Criteria& criteria, + std::unique_ptr<FlightListing>* listing) { + return impl_->ListFlights(options, criteria, listing); +} + +Status FlightClient::DoGet(const FlightCallOptions& options, const Ticket& ticket, + std::unique_ptr<FlightStreamReader>* stream) { + return impl_->DoGet(options, ticket, stream); +} + +Status FlightClient::DoPut(const FlightCallOptions& 
options, + const FlightDescriptor& descriptor, + const std::shared_ptr<Schema>& schema, + std::unique_ptr<FlightStreamWriter>* stream, + std::unique_ptr<FlightMetadataReader>* reader) { + return impl_->DoPut(options, descriptor, schema, stream, reader); +} + +Status FlightClient::DoExchange(const FlightCallOptions& options, + const FlightDescriptor& descriptor, + std::unique_ptr<FlightStreamWriter>* writer, + std::unique_ptr<FlightStreamReader>* reader) { + return impl_->DoExchange(options, descriptor, writer, reader); +} + +} // namespace flight +} // namespace arrow |
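For context, here is a minimal consumer-side sketch (not part of the diff) of how the FlightClient implemented in this file is typically driven: connect to a server, ask for a FlightInfo, then stream record batches with DoGet. The host, port, descriptor path, and the function name ReadFlight are placeholders; error handling is reduced to Status propagation and the sketch assumes the FlightInfo contains at least one endpoint.

// Hypothetical usage sketch for the client API defined above.
#include <memory>

#include "arrow/flight/client.h"
#include "arrow/status.h"
#include "arrow/table.h"

arrow::Status ReadFlight() {
  // Placeholder host/port.
  arrow::flight::Location location;
  ARROW_RETURN_NOT_OK(
      arrow::flight::Location::ForGrpcTcp("localhost", 31337, &location));

  std::unique_ptr<arrow::flight::FlightClient> client;
  ARROW_RETURN_NOT_OK(arrow::flight::FlightClient::Connect(location, &client));

  // Ask the server to describe the dataset and where to fetch it.
  auto descriptor = arrow::flight::FlightDescriptor::Path({"example.parquet"});
  std::unique_ptr<arrow::flight::FlightInfo> info;
  ARROW_RETURN_NOT_OK(client->GetFlightInfo({}, descriptor, &info));

  // DoGet returns a FlightStreamReader; ReadAll drains the stream into a Table.
  // For brevity this assumes a single endpoint.
  std::unique_ptr<arrow::flight::FlightStreamReader> stream;
  ARROW_RETURN_NOT_OK(client->DoGet({}, info->endpoints()[0].ticket, &stream));
  std::shared_ptr<arrow::Table> table;
  ARROW_RETURN_NOT_OK(stream->ReadAll(&table));
  return arrow::Status::OK();
}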
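A companion sketch for the upload direction, again hypothetical: DoPut hands back the FlightStreamWriter and FlightMetadataReader pair created in FlightClientImpl::DoPut above. Passing the schema up front takes the eager-write branch of GrpcStreamWriter::Open, so the schema and descriptor go out as the first message. The descriptor path and the function name WriteFlight are placeholders, and the table is assumed to be supplied by the caller.

// Hypothetical DoPut sketch: upload an existing Table batch by batch.
#include <memory>

#include "arrow/flight/client.h"
#include "arrow/record_batch.h"
#include "arrow/status.h"
#include "arrow/table.h"

arrow::Status WriteFlight(arrow::flight::FlightClient* client,
                          const std::shared_ptr<arrow::Table>& table) {
  auto descriptor = arrow::flight::FlightDescriptor::Path({"uploaded.parquet"});
  std::unique_ptr<arrow::flight::FlightStreamWriter> writer;
  std::unique_ptr<arrow::flight::FlightMetadataReader> metadata_reader;
  ARROW_RETURN_NOT_OK(client->DoPut({}, descriptor, table->schema(), &writer,
                                    &metadata_reader));

  // Slice the table into record batches and send each one.
  arrow::TableBatchReader batch_reader(*table);
  std::shared_ptr<arrow::RecordBatch> batch;
  while (true) {
    ARROW_RETURN_NOT_OK(batch_reader.ReadNext(&batch));
    if (!batch) break;
    ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*batch));
  }

  // DoneWriting flushes pending batches and half-closes the gRPC stream;
  // Close() then finishes the call and surfaces any server-side error.
  ARROW_RETURN_NOT_OK(writer->DoneWriting());
  return writer->Close();
}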