summaryrefslogtreecommitdiffstats
path: root/src/jaegertracing/thrift/lib/netstd/Thrift/Transport/Client/TTlsSocketTransport.cs
diff options
context:
space:
mode:
Diffstat (limited to 'src/jaegertracing/thrift/lib/netstd/Thrift/Transport/Client/TTlsSocketTransport.cs')
-rw-r--r--src/jaegertracing/thrift/lib/netstd/Thrift/Transport/Client/TTlsSocketTransport.cs267
1 files changed, 267 insertions, 0 deletions
diff --git a/src/jaegertracing/thrift/lib/netstd/Thrift/Transport/Client/TTlsSocketTransport.cs b/src/jaegertracing/thrift/lib/netstd/Thrift/Transport/Client/TTlsSocketTransport.cs
new file mode 100644
index 000000000..9295bb01b
--- /dev/null
+++ b/src/jaegertracing/thrift/lib/netstd/Thrift/Transport/Client/TTlsSocketTransport.cs
@@ -0,0 +1,267 @@
+// Licensed to the Apache Software Foundation(ASF) under one
+// or more contributor license agreements.See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+using System;
+using System.Net;
+using System.Net.Security;
+using System.Net.Sockets;
+using System.Security.Authentication;
+using System.Security.Cryptography.X509Certificates;
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace Thrift.Transport.Client
+{
+ //TODO: check for correct work
+
+ // ReSharper disable once InconsistentNaming
+ public class TTlsSocketTransport : TStreamTransport
+ {
+ private readonly X509Certificate2 _certificate;
+ private readonly RemoteCertificateValidationCallback _certValidator;
+ private readonly IPAddress _host;
+ private readonly bool _isServer;
+ private readonly LocalCertificateSelectionCallback _localCertificateSelectionCallback;
+ private readonly int _port;
+ private readonly SslProtocols _sslProtocols;
+ private TcpClient _client;
+ private SslStream _secureStream;
+ private int _timeout;
+
+ public TTlsSocketTransport(TcpClient client, X509Certificate2 certificate, bool isServer = false,
+ RemoteCertificateValidationCallback certValidator = null,
+ LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
+ SslProtocols sslProtocols = SslProtocols.Tls12)
+ {
+ _client = client;
+ _certificate = certificate;
+ _certValidator = certValidator;
+ _localCertificateSelectionCallback = localCertificateSelectionCallback;
+ _sslProtocols = sslProtocols;
+ _isServer = isServer;
+
+ if (isServer && certificate == null)
+ {
+ throw new ArgumentException("TTlsSocketTransport needs certificate to be used for server",
+ nameof(certificate));
+ }
+
+ if (IsOpen)
+ {
+ InputStream = client.GetStream();
+ OutputStream = client.GetStream();
+ }
+ }
+
+ public TTlsSocketTransport(IPAddress host, int port, string certificatePath,
+ RemoteCertificateValidationCallback certValidator = null,
+ LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
+ SslProtocols sslProtocols = SslProtocols.Tls12)
+ : this(host, port, 0,
+ new X509Certificate2(certificatePath),
+ certValidator,
+ localCertificateSelectionCallback,
+ sslProtocols)
+ {
+ }
+
+ public TTlsSocketTransport(IPAddress host, int port,
+ X509Certificate2 certificate = null,
+ RemoteCertificateValidationCallback certValidator = null,
+ LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
+ SslProtocols sslProtocols = SslProtocols.Tls12)
+ : this(host, port, 0,
+ certificate,
+ certValidator,
+ localCertificateSelectionCallback,
+ sslProtocols)
+ {
+ }
+
+ public TTlsSocketTransport(IPAddress host, int port, int timeout,
+ X509Certificate2 certificate,
+ RemoteCertificateValidationCallback certValidator = null,
+ LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
+ SslProtocols sslProtocols = SslProtocols.Tls12)
+ {
+ _host = host;
+ _port = port;
+ _timeout = timeout;
+ _certificate = certificate;
+ _certValidator = certValidator;
+ _localCertificateSelectionCallback = localCertificateSelectionCallback;
+ _sslProtocols = sslProtocols;
+
+ InitSocket();
+ }
+
+ public TTlsSocketTransport(string host, int port, int timeout,
+ X509Certificate2 certificate,
+ RemoteCertificateValidationCallback certValidator = null,
+ LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
+ SslProtocols sslProtocols = SslProtocols.Tls12)
+ {
+ try
+ {
+ var entry = Dns.GetHostEntry(host);
+ if (entry.AddressList.Length == 0)
+ throw new TTransportException(TTransportException.ExceptionType.Unknown, "unable to resolve host name");
+
+ var addr = entry.AddressList[0];
+
+ _host = new IPAddress(addr.GetAddressBytes(), addr.ScopeId);
+ _port = port;
+ _timeout = timeout;
+ _certificate = certificate;
+ _certValidator = certValidator;
+ _localCertificateSelectionCallback = localCertificateSelectionCallback;
+ _sslProtocols = sslProtocols;
+
+ InitSocket();
+ }
+ catch (SocketException e)
+ {
+ throw new TTransportException(TTransportException.ExceptionType.Unknown, e.Message, e);
+ }
+ }
+
+ public int Timeout
+ {
+ set { _client.ReceiveTimeout = _client.SendTimeout = _timeout = value; }
+ }
+
+ public TcpClient TcpClient => _client;
+
+ public IPAddress Host => _host;
+
+ public int Port => _port;
+
+ public override bool IsOpen
+ {
+ get
+ {
+ if (_client == null)
+ {
+ return false;
+ }
+
+ return _client.Connected;
+ }
+ }
+
+ private void InitSocket()
+ {
+ _client = new TcpClient();
+ _client.ReceiveTimeout = _client.SendTimeout = _timeout;
+ _client.Client.NoDelay = true;
+ }
+
+ private bool DefaultCertificateValidator(object sender, X509Certificate certificate, X509Chain chain,
+ SslPolicyErrors sslValidationErrors)
+ {
+ return sslValidationErrors == SslPolicyErrors.None;
+ }
+
+ public override async Task OpenAsync(CancellationToken cancellationToken)
+ {
+ if (IsOpen)
+ {
+ throw new TTransportException(TTransportException.ExceptionType.AlreadyOpen, "Socket already connected");
+ }
+
+ if (_host == null)
+ {
+ throw new TTransportException(TTransportException.ExceptionType.NotOpen, "Cannot open null host");
+ }
+
+ if (_port <= 0)
+ {
+ throw new TTransportException(TTransportException.ExceptionType.NotOpen, "Cannot open without port");
+ }
+
+ if (_client == null)
+ {
+ InitSocket();
+ }
+
+ if (_client != null)
+ {
+ await _client.ConnectAsync(_host, _port);
+ await SetupTlsAsync();
+ }
+ }
+
+ public async Task SetupTlsAsync()
+ {
+ var validator = _certValidator ?? DefaultCertificateValidator;
+
+ if (_localCertificateSelectionCallback != null)
+ {
+ _secureStream = new SslStream(_client.GetStream(), false, validator, _localCertificateSelectionCallback);
+ }
+ else
+ {
+ _secureStream = new SslStream(_client.GetStream(), false, validator);
+ }
+
+ try
+ {
+ if (_isServer)
+ {
+ // Server authentication
+ await
+ _secureStream.AuthenticateAsServerAsync(_certificate, _certValidator != null, _sslProtocols,
+ true);
+ }
+ else
+ {
+ // Client authentication
+ var certs = _certificate != null
+ ? new X509CertificateCollection {_certificate}
+ : new X509CertificateCollection();
+
+ var targetHost = _host.ToString();
+ await _secureStream.AuthenticateAsClientAsync(targetHost, certs, _sslProtocols, true);
+ }
+ }
+ catch (Exception)
+ {
+ Close();
+ throw;
+ }
+
+ InputStream = _secureStream;
+ OutputStream = _secureStream;
+ }
+
+ public override void Close()
+ {
+ base.Close();
+ if (_client != null)
+ {
+ _client.Dispose();
+ _client = null;
+ }
+
+ if (_secureStream != null)
+ {
+ _secureStream.Dispose();
+ _secureStream = null;
+ }
+ }
+ }
+}