summaryrefslogtreecommitdiffstats
path: root/libfreerdp/core/rdstls.c
diff options
context:
space:
mode:
Diffstat (limited to 'libfreerdp/core/rdstls.c')
-rw-r--r--libfreerdp/core/rdstls.c969
1 files changed, 969 insertions, 0 deletions
diff --git a/libfreerdp/core/rdstls.c b/libfreerdp/core/rdstls.c
new file mode 100644
index 0000000..94e0967
--- /dev/null
+++ b/libfreerdp/core/rdstls.c
@@ -0,0 +1,969 @@
+/**
+ * FreeRDP: A Remote Desktop Protocol Implementation
+ * RDSTLS Security protocol
+ *
+ * Copyright 2023 Joan Torres <joan.torres@suse.com>
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <freerdp/config.h>
+
+#include "settings.h"
+
+#include <freerdp/log.h>
+#include <freerdp/error.h>
+#include <freerdp/settings.h>
+
+#include <winpr/assert.h>
+#include <winpr/stream.h>
+#include <winpr/wlog.h>
+
+#include "rdstls.h"
+#include "transport.h"
+#include "utils.h"
+
+#define RDSTLS_VERSION_1 0x01
+
+#define RDSTLS_TYPE_CAPABILITIES 0x01
+#define RDSTLS_TYPE_AUTHREQ 0x02
+#define RDSTLS_TYPE_AUTHRSP 0x04
+
+#define RDSTLS_DATA_CAPABILITIES 0x01
+#define RDSTLS_DATA_PASSWORD_CREDS 0x01
+#define RDSTLS_DATA_AUTORECONNECT_COOKIE 0x02
+#define RDSTLS_DATA_RESULT_CODE 0x01
+
+typedef enum
+{
+ RDSTLS_STATE_INITIAL,
+ RDSTLS_STATE_CAPABILITIES,
+ RDSTLS_STATE_AUTH_REQ,
+ RDSTLS_STATE_AUTH_RSP,
+ RDSTLS_STATE_FINAL,
+} RDSTLS_STATE;
+
+struct rdp_rdstls
+{
+ BOOL server;
+ RDSTLS_STATE state;
+ rdpContext* context;
+ rdpTransport* transport;
+
+ UINT32 resultCode;
+ wLog* log;
+};
+
+/**
+ * Create new RDSTLS state machine.
+ *
+ * @param context A pointer to the rdp context to use
+ *
+ * @return new RDSTLS state machine.
+ */
+
+rdpRdstls* rdstls_new(rdpContext* context, rdpTransport* transport)
+{
+ WINPR_ASSERT(context);
+ WINPR_ASSERT(transport);
+
+ rdpSettings* settings = context->settings;
+ WINPR_ASSERT(settings);
+
+ rdpRdstls* rdstls = (rdpRdstls*)calloc(1, sizeof(rdpRdstls));
+
+ if (!rdstls)
+ return NULL;
+ rdstls->log = WLog_Get(FREERDP_TAG("core.rdstls"));
+ rdstls->context = context;
+ rdstls->transport = transport;
+ rdstls->server = settings->ServerMode;
+
+ rdstls->state = RDSTLS_STATE_INITIAL;
+
+ return rdstls;
+}
+
+/**
+ * Free RDSTLS state machine.
+ * @param rdstls The RDSTLS instance to free
+ */
+
+void rdstls_free(rdpRdstls* rdstls)
+{
+ free(rdstls);
+}
+
+static const char* rdstls_get_state_str(RDSTLS_STATE state)
+{
+ switch (state)
+ {
+ case RDSTLS_STATE_INITIAL:
+ return "RDSTLS_STATE_INITIAL";
+ case RDSTLS_STATE_CAPABILITIES:
+ return "RDSTLS_STATE_CAPABILITIES";
+ case RDSTLS_STATE_AUTH_REQ:
+ return "RDSTLS_STATE_AUTH_REQ";
+ case RDSTLS_STATE_AUTH_RSP:
+ return "RDSTLS_STATE_AUTH_RSP";
+ case RDSTLS_STATE_FINAL:
+ return "RDSTLS_STATE_FINAL";
+ default:
+ return "UNKNOWN";
+ }
+}
+
+static RDSTLS_STATE rdstls_get_state(rdpRdstls* rdstls)
+{
+ WINPR_ASSERT(rdstls);
+ return rdstls->state;
+}
+
+static BOOL check_transition(wLog* log, RDSTLS_STATE current, RDSTLS_STATE expected,
+ RDSTLS_STATE requested)
+{
+ if (requested != expected)
+ {
+ WLog_Print(log, WLOG_ERROR,
+ "Unexpected rdstls state transition from %s [%d] to %s [%d], expected %s [%d]",
+ rdstls_get_state_str(current), current, rdstls_get_state_str(requested),
+ requested, rdstls_get_state_str(expected), expected);
+ return FALSE;
+ }
+ return TRUE;
+}
+
+static BOOL rdstls_set_state(rdpRdstls* rdstls, RDSTLS_STATE state)
+{
+ BOOL rc = FALSE;
+ WINPR_ASSERT(rdstls);
+
+ WLog_Print(rdstls->log, WLOG_DEBUG, "-- %s\t--> %s", rdstls_get_state_str(rdstls->state),
+ rdstls_get_state_str(state));
+
+ switch (rdstls->state)
+ {
+ case RDSTLS_STATE_INITIAL:
+ rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_CAPABILITIES, state);
+ break;
+ case RDSTLS_STATE_CAPABILITIES:
+ rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_AUTH_REQ, state);
+ break;
+ case RDSTLS_STATE_AUTH_REQ:
+ rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_AUTH_RSP, state);
+ break;
+ case RDSTLS_STATE_AUTH_RSP:
+ rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_FINAL, state);
+ break;
+ case RDSTLS_STATE_FINAL:
+ rc = check_transition(rdstls->log, rdstls->state, RDSTLS_STATE_CAPABILITIES, state);
+ break;
+ default:
+ WLog_Print(rdstls->log, WLOG_ERROR,
+ "Invalid rdstls state %s [%d], requested transition to %s [%d]",
+ rdstls_get_state_str(rdstls->state), rdstls->state,
+ rdstls_get_state_str(state), state);
+ break;
+ }
+ if (rc)
+ rdstls->state = state;
+
+ return rc;
+}
+
+static BOOL rdstls_write_capabilities(rdpRdstls* rdstls, wStream* s)
+{
+ if (!Stream_EnsureRemainingCapacity(s, 6))
+ return FALSE;
+
+ Stream_Write_UINT16(s, RDSTLS_TYPE_CAPABILITIES);
+ Stream_Write_UINT16(s, RDSTLS_DATA_CAPABILITIES);
+ Stream_Write_UINT16(s, RDSTLS_VERSION_1);
+
+ return TRUE;
+}
+
+static SSIZE_T rdstls_write_string(wStream* s, const char* str)
+{
+ const size_t pos = Stream_GetPosition(s);
+
+ if (!Stream_EnsureRemainingCapacity(s, 2))
+ return -1;
+
+ if (!str)
+ {
+ /* Write unicode null */
+ Stream_Write_UINT16(s, 2);
+ if (!Stream_EnsureRemainingCapacity(s, 2))
+ return -1;
+
+ Stream_Write_UINT16(s, 0);
+ return (SSIZE_T)(Stream_GetPosition(s) - pos);
+ }
+
+ const size_t length = (strlen(str) + 1);
+
+ Stream_Write_UINT16(s, (UINT16)length * sizeof(WCHAR));
+
+ if (!Stream_EnsureRemainingCapacity(s, length * sizeof(WCHAR)))
+ return -1;
+
+ if (Stream_Write_UTF16_String_From_UTF8(s, length, str, length, TRUE) < 0)
+ return -1;
+
+ return (SSIZE_T)(Stream_GetPosition(s) - pos);
+}
+
+static BOOL rdstls_write_data(wStream* s, UINT32 length, const BYTE* data)
+{
+ WINPR_ASSERT(data || (length == 0));
+
+ if (!Stream_EnsureRemainingCapacity(s, 2))
+ return FALSE;
+
+ Stream_Write_UINT16(s, length);
+
+ if (!Stream_EnsureRemainingCapacity(s, length))
+ return FALSE;
+
+ Stream_Write(s, data, length);
+
+ return TRUE;
+}
+
+static BOOL rdstls_write_authentication_request_with_password(rdpRdstls* rdstls, wStream* s)
+{
+ rdpSettings* settings = rdstls->context->settings;
+ WINPR_ASSERT(settings);
+
+ if (!Stream_EnsureRemainingCapacity(s, 4))
+ return FALSE;
+
+ Stream_Write_UINT16(s, RDSTLS_TYPE_AUTHREQ);
+ Stream_Write_UINT16(s, RDSTLS_DATA_PASSWORD_CREDS);
+
+ if (!rdstls_write_data(s, settings->RedirectionGuidLength, settings->RedirectionGuid))
+ return FALSE;
+
+ if (rdstls_write_string(s, settings->Username) < 0)
+ return FALSE;
+
+ if (rdstls_write_string(s, settings->Domain) < 0)
+ return FALSE;
+
+ if (!rdstls_write_data(s, settings->RedirectionPasswordLength, settings->RedirectionPassword))
+ return FALSE;
+
+ return TRUE;
+}
+
+static BOOL rdstls_write_authentication_request_with_cookie(rdpRdstls* rdstls, wStream* s)
+{
+ // TODO
+ return FALSE;
+}
+
+static BOOL rdstls_write_authentication_response(rdpRdstls* rdstls, wStream* s)
+{
+ if (!Stream_EnsureRemainingCapacity(s, 8))
+ return FALSE;
+
+ Stream_Write_UINT16(s, RDSTLS_TYPE_AUTHRSP);
+ Stream_Write_UINT16(s, RDSTLS_DATA_RESULT_CODE);
+ Stream_Write_UINT32(s, rdstls->resultCode);
+
+ return TRUE;
+}
+
+static BOOL rdstls_process_capabilities(rdpRdstls* rdstls, wStream* s)
+{
+ UINT16 dataType = 0;
+ UINT16 supportedVersions = 0;
+
+ if (!Stream_CheckAndLogRequiredLengthWLog(rdstls->log, s, 4))
+ return FALSE;
+
+ Stream_Read_UINT16(s, dataType);
+ if (dataType != RDSTLS_DATA_CAPABILITIES)
+ {
+ WLog_Print(rdstls->log, WLOG_ERROR,
+ "received invalid DataType=0x%04" PRIX16 ", expected 0x%04" PRIX16, dataType,
+ RDSTLS_DATA_CAPABILITIES);
+ return FALSE;
+ }
+
+ Stream_Read_UINT16(s, supportedVersions);
+ if ((supportedVersions & RDSTLS_VERSION_1) == 0)
+ {
+ WLog_Print(rdstls->log, WLOG_ERROR,
+ "received invalid supportedVersions=0x%04" PRIX16 ", expected 0x%04" PRIX16,
+ supportedVersions, RDSTLS_VERSION_1);
+ return FALSE;
+ }
+
+ return TRUE;
+}
+
+static BOOL rdstls_read_unicode_string(wLog* log, wStream* s, char** str)
+{
+ UINT16 length = 0;
+
+ WINPR_ASSERT(str);
+
+ if (!Stream_CheckAndLogRequiredLengthWLog(log, s, 2))
+ return FALSE;
+
+ Stream_Read_UINT16(s, length);
+
+ if (!Stream_CheckAndLogRequiredLengthWLog(log, s, length))
+ return FALSE;
+
+ if (length <= 2)
+ {
+ Stream_Seek(s, length);
+ return TRUE;
+ }
+
+ *str = Stream_Read_UTF16_String_As_UTF8(s, length / sizeof(WCHAR), NULL);
+ if (!*str)
+ return FALSE;
+
+ return TRUE;
+}
+
+static BOOL rdstls_read_data(wLog* log, wStream* s, UINT16* pLength, const BYTE** pData)
+{
+ UINT16 length = 0;
+
+ WINPR_ASSERT(pLength);
+ WINPR_ASSERT(pData);
+
+ *pData = NULL;
+ *pLength = 0;
+ if (!Stream_CheckAndLogRequiredLengthWLog(log, s, 2))
+ return FALSE;
+
+ Stream_Read_UINT16(s, length);
+
+ if (!Stream_CheckAndLogRequiredLengthWLog(log, s, length))
+ return FALSE;
+
+ if (length <= 2)
+ {
+ Stream_Seek(s, length);
+ return TRUE;
+ }
+
+ *pData = Stream_ConstPointer(s);
+ *pLength = length;
+ return Stream_SafeSeek(s, length);
+}
+
+static BOOL rdstls_cmp_data(wLog* log, const char* field, const BYTE* serverData,
+ const UINT32 serverDataLength, const BYTE* clientData,
+ const UINT16 clientDataLength)
+{
+ if (serverDataLength > 0)
+ {
+ if (clientDataLength == 0)
+ {
+ WLog_Print(log, WLOG_ERROR, "expected %s", field);
+ return FALSE;
+ }
+
+ if (serverDataLength > UINT16_MAX || serverDataLength != clientDataLength ||
+ memcmp(serverData, clientData, serverDataLength) != 0)
+ {
+ WLog_Print(log, WLOG_ERROR, "%s verification failed", field);
+ return FALSE;
+ }
+ }
+
+ return TRUE;
+}
+
+static BOOL rdstls_cmp_str(wLog* log, const char* field, const char* serverStr,
+ const char* clientStr)
+{
+ if (!utils_str_is_empty(serverStr))
+ {
+ if (utils_str_is_empty(clientStr))
+ {
+ WLog_Print(log, WLOG_ERROR, "expected %s", field);
+ return FALSE;
+ }
+
+ if (strcmp(serverStr, clientStr) != 0)
+ {
+ WLog_Print(log, WLOG_ERROR, "%s verification failed", field);
+ return FALSE;
+ }
+ }
+
+ return TRUE;
+}
+
+static BOOL rdstls_process_authentication_request_with_password(rdpRdstls* rdstls, wStream* s)
+{
+ BOOL rc = FALSE;
+
+ const BYTE* clientRedirectionGuid = NULL;
+ UINT16 clientRedirectionGuidLength = 0;
+ char* clientPassword = NULL;
+ char* clientUsername = NULL;
+ char* clientDomain = NULL;
+
+ const BYTE* serverRedirectionGuid = NULL;
+ UINT16 serverRedirectionGuidLength = 0;
+ const char* serverPassword = NULL;
+ const char* serverUsername = NULL;
+ const char* serverDomain = NULL;
+
+ rdpSettings* settings = rdstls->context->settings;
+ WINPR_ASSERT(settings);
+
+ if (!rdstls_read_data(rdstls->log, s, &clientRedirectionGuidLength, &clientRedirectionGuid))
+ goto fail;
+
+ if (!rdstls_read_unicode_string(rdstls->log, s, &clientUsername))
+ goto fail;
+
+ if (!rdstls_read_unicode_string(rdstls->log, s, &clientDomain))
+ goto fail;
+
+ if (!rdstls_read_unicode_string(rdstls->log, s, &clientPassword))
+ goto fail;
+
+ serverRedirectionGuid = freerdp_settings_get_pointer(settings, FreeRDP_RedirectionGuid);
+ serverRedirectionGuidLength =
+ freerdp_settings_get_uint32(settings, FreeRDP_RedirectionGuidLength);
+ serverUsername = freerdp_settings_get_string(settings, FreeRDP_Username);
+ serverDomain = freerdp_settings_get_string(settings, FreeRDP_Domain);
+ serverPassword = freerdp_settings_get_string(settings, FreeRDP_Password);
+
+ rdstls->resultCode = ERROR_SUCCESS;
+
+ if (!rdstls_cmp_data(rdstls->log, "RedirectionGuid", serverRedirectionGuid,
+ serverRedirectionGuidLength, clientRedirectionGuid,
+ clientRedirectionGuidLength))
+ rdstls->resultCode = ERROR_LOGON_FAILURE;
+
+ if (!rdstls_cmp_str(rdstls->log, "UserName", serverUsername, clientUsername))
+ rdstls->resultCode = ERROR_LOGON_FAILURE;
+
+ if (!rdstls_cmp_str(rdstls->log, "Domain", serverDomain, clientDomain))
+ rdstls->resultCode = ERROR_LOGON_FAILURE;
+
+ if (!rdstls_cmp_str(rdstls->log, "Password", serverPassword, clientPassword))
+ rdstls->resultCode = ERROR_LOGON_FAILURE;
+
+ rc = TRUE;
+fail:
+ return rc;
+}
+
+static BOOL rdstls_process_authentication_request_with_cookie(rdpRdstls* rdstls, wStream* s)
+{
+ // TODO
+ return FALSE;
+}
+
+static BOOL rdstls_process_authentication_request(rdpRdstls* rdstls, wStream* s)
+{
+ UINT16 dataType = 0;
+
+ if (!Stream_CheckAndLogRequiredLengthWLog(rdstls->log, s, 2))
+ return FALSE;
+
+ Stream_Read_UINT16(s, dataType);
+ switch (dataType)
+ {
+ case RDSTLS_DATA_PASSWORD_CREDS:
+ if (!rdstls_process_authentication_request_with_password(rdstls, s))
+ return FALSE;
+ break;
+ case RDSTLS_DATA_AUTORECONNECT_COOKIE:
+ if (!rdstls_process_authentication_request_with_cookie(rdstls, s))
+ return FALSE;
+ break;
+ default:
+ WLog_Print(rdstls->log, WLOG_ERROR,
+ "received invalid DataType=0x%04" PRIX16 ", expected 0x%04" PRIX16
+ " or 0x%04" PRIX16,
+ dataType, RDSTLS_DATA_PASSWORD_CREDS, RDSTLS_DATA_AUTORECONNECT_COOKIE);
+ return FALSE;
+ }
+
+ return TRUE;
+}
+
+static BOOL rdstls_process_authentication_response(rdpRdstls* rdstls, wStream* s)
+{
+ UINT16 dataType = 0;
+ UINT32 resultCode = 0;
+
+ if (!Stream_CheckAndLogRequiredLengthWLog(rdstls->log, s, 6))
+ return FALSE;
+
+ Stream_Read_UINT16(s, dataType);
+ if (dataType != RDSTLS_DATA_RESULT_CODE)
+ {
+ WLog_Print(rdstls->log, WLOG_ERROR,
+ "received invalid DataType=0x%04" PRIX16 ", expected 0x%04" PRIX16, dataType,
+ RDSTLS_DATA_RESULT_CODE);
+ return FALSE;
+ }
+
+ Stream_Read_UINT32(s, resultCode);
+ if (resultCode != ERROR_SUCCESS)
+ {
+ WLog_Print(rdstls->log, WLOG_ERROR, "resultCode: %s [0x%08" PRIX32 "] %s",
+ freerdp_get_last_error_name(resultCode), resultCode,
+ freerdp_get_last_error_string(resultCode));
+ return FALSE;
+ }
+
+ return TRUE;
+}
+
+static BOOL rdstls_send(rdpTransport* transport, wStream* s, void* extra)
+{
+ rdpRdstls* rdstls = (rdpRdstls*)extra;
+ rdpSettings* settings = NULL;
+
+ WINPR_ASSERT(transport);
+ WINPR_ASSERT(s);
+ WINPR_ASSERT(rdstls);
+
+ settings = rdstls->context->settings;
+ WINPR_ASSERT(settings);
+
+ if (!Stream_EnsureRemainingCapacity(s, 2))
+ return FALSE;
+
+ Stream_Write_UINT16(s, RDSTLS_VERSION_1);
+
+ const RDSTLS_STATE state = rdstls_get_state(rdstls);
+ switch (state)
+ {
+ case RDSTLS_STATE_CAPABILITIES:
+ if (!rdstls_write_capabilities(rdstls, s))
+ return FALSE;
+ break;
+ case RDSTLS_STATE_AUTH_REQ:
+ if (settings->RedirectionFlags & LB_PASSWORD_IS_PK_ENCRYPTED)
+ {
+ if (!rdstls_write_authentication_request_with_password(rdstls, s))
+ return FALSE;
+ }
+ else if (settings->ServerAutoReconnectCookie != NULL)
+ {
+ if (!rdstls_write_authentication_request_with_cookie(rdstls, s))
+ return FALSE;
+ }
+ else
+ {
+ WLog_Print(rdstls->log, WLOG_ERROR,
+ "cannot authenticate with password or auto-reconnect cookie");
+ return FALSE;
+ }
+ break;
+ case RDSTLS_STATE_AUTH_RSP:
+ if (!rdstls_write_authentication_response(rdstls, s))
+ return FALSE;
+ break;
+ default:
+ WLog_Print(rdstls->log, WLOG_ERROR, "Invalid rdstls state %s [%d]",
+ rdstls_get_state_str(state), state);
+ return FALSE;
+ }
+
+ if (transport_write(rdstls->transport, s) < 0)
+ return FALSE;
+
+ return TRUE;
+}
+
+static int rdstls_recv(rdpTransport* transport, wStream* s, void* extra)
+{
+ UINT16 version = 0;
+ UINT16 pduType = 0;
+ rdpRdstls* rdstls = (rdpRdstls*)extra;
+
+ WINPR_ASSERT(transport);
+ WINPR_ASSERT(s);
+ WINPR_ASSERT(rdstls);
+
+ if (!Stream_CheckAndLogRequiredLengthWLog(rdstls->log, s, 4))
+ return FALSE;
+
+ Stream_Read_UINT16(s, version);
+ if (version != RDSTLS_VERSION_1)
+ {
+ WLog_Print(rdstls->log, WLOG_ERROR,
+ "received invalid RDSTLS Version=0x%04" PRIX16 ", expected 0x%04" PRIX16,
+ version, RDSTLS_VERSION_1);
+ return -1;
+ }
+
+ Stream_Read_UINT16(s, pduType);
+ switch (pduType)
+ {
+ case RDSTLS_TYPE_CAPABILITIES:
+ if (!rdstls_process_capabilities(rdstls, s))
+ return -1;
+ break;
+ case RDSTLS_TYPE_AUTHREQ:
+ if (!rdstls_process_authentication_request(rdstls, s))
+ return -1;
+ break;
+ case RDSTLS_TYPE_AUTHRSP:
+ if (!rdstls_process_authentication_response(rdstls, s))
+ return -1;
+ break;
+ default:
+ WLog_Print(rdstls->log, WLOG_ERROR, "unknown RDSTLS PDU type [0x%04" PRIx16 "]",
+ pduType);
+ return -1;
+ }
+
+ return 1;
+}
+
+#define rdstls_check_state_requirements(rdstls, expected) \
+ rdstls_check_state_requirements_((rdstls), (expected), __FILE__, __func__, __LINE__)
+static BOOL rdstls_check_state_requirements_(rdpRdstls* rdstls, RDSTLS_STATE expected,
+ const char* file, const char* fkt, size_t line)
+{
+ const RDSTLS_STATE current = rdstls_get_state(rdstls);
+ if (current == expected)
+ return TRUE;
+
+ const DWORD log_level = WLOG_ERROR;
+ if (WLog_IsLevelActive(rdstls->log, log_level))
+ WLog_PrintMessage(rdstls->log, WLOG_MESSAGE_TEXT, log_level, line, file, fkt,
+ "Unexpected rdstls state %s [%d], expected %s [%d]",
+ rdstls_get_state_str(current), current, rdstls_get_state_str(expected),
+ expected);
+
+ return FALSE;
+}
+
+static BOOL rdstls_send_capabilities(rdpRdstls* rdstls)
+{
+ BOOL rc = FALSE;
+ wStream* s = NULL;
+
+ if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_CAPABILITIES))
+ goto fail;
+
+ s = Stream_New(NULL, 512);
+ if (!s)
+ goto fail;
+
+ if (!rdstls_send(rdstls->transport, s, rdstls))
+ goto fail;
+
+ rc = rdstls_set_state(rdstls, RDSTLS_STATE_AUTH_REQ);
+fail:
+ Stream_Free(s, TRUE);
+ return rc;
+}
+
+static BOOL rdstls_recv_authentication_request(rdpRdstls* rdstls)
+{
+ BOOL rc = FALSE;
+ int status = 0;
+ wStream* s = NULL;
+
+ if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_REQ))
+ goto fail;
+
+ s = Stream_New(NULL, 4096);
+ if (!s)
+ goto fail;
+
+ status = transport_read_pdu(rdstls->transport, s);
+
+ if (status < 0)
+ goto fail;
+
+ status = rdstls_recv(rdstls->transport, s, rdstls);
+
+ if (status < 0)
+ goto fail;
+
+ rc = rdstls_set_state(rdstls, RDSTLS_STATE_AUTH_RSP);
+fail:
+ Stream_Free(s, TRUE);
+ return rc;
+}
+
+static BOOL rdstls_send_authentication_response(rdpRdstls* rdstls)
+{
+ BOOL rc = FALSE;
+ wStream* s = NULL;
+
+ if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_RSP))
+ goto fail;
+
+ s = Stream_New(NULL, 512);
+ if (!s)
+ goto fail;
+
+ if (!rdstls_send(rdstls->transport, s, rdstls))
+ goto fail;
+
+ rc = rdstls_set_state(rdstls, RDSTLS_STATE_FINAL);
+fail:
+ Stream_Free(s, TRUE);
+ return rc;
+}
+
+static BOOL rdstls_recv_capabilities(rdpRdstls* rdstls)
+{
+ BOOL rc = FALSE;
+ int status = 0;
+ wStream* s = NULL;
+
+ if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_CAPABILITIES))
+ goto fail;
+
+ s = Stream_New(NULL, 512);
+ if (!s)
+ goto fail;
+
+ status = transport_read_pdu(rdstls->transport, s);
+
+ if (status < 0)
+ goto fail;
+
+ status = rdstls_recv(rdstls->transport, s, rdstls);
+
+ if (status < 0)
+ goto fail;
+
+ rc = rdstls_set_state(rdstls, RDSTLS_STATE_AUTH_REQ);
+fail:
+ Stream_Free(s, TRUE);
+ return rc;
+}
+
+static BOOL rdstls_send_authentication_request(rdpRdstls* rdstls)
+{
+ BOOL rc = FALSE;
+ wStream* s = NULL;
+
+ if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_REQ))
+ goto fail;
+
+ s = Stream_New(NULL, 4096);
+ if (!s)
+ goto fail;
+
+ if (!rdstls_send(rdstls->transport, s, rdstls))
+ goto fail;
+
+ rc = rdstls_set_state(rdstls, RDSTLS_STATE_AUTH_RSP);
+fail:
+ Stream_Free(s, TRUE);
+ return rc;
+}
+
+static BOOL rdstls_recv_authentication_response(rdpRdstls* rdstls)
+{
+ BOOL rc = FALSE;
+ int status = 0;
+ wStream* s = NULL;
+
+ WINPR_ASSERT(rdstls);
+
+ if (!rdstls_check_state_requirements(rdstls, RDSTLS_STATE_AUTH_RSP))
+ goto fail;
+
+ s = Stream_New(NULL, 512);
+ if (!s)
+ goto fail;
+
+ status = transport_read_pdu(rdstls->transport, s);
+
+ if (status < 0)
+ goto fail;
+
+ status = rdstls_recv(rdstls->transport, s, rdstls);
+
+ if (status < 0)
+ goto fail;
+
+ rc = rdstls_set_state(rdstls, RDSTLS_STATE_FINAL);
+fail:
+ Stream_Free(s, TRUE);
+ return rc;
+}
+
+static int rdstls_server_authenticate(rdpRdstls* rdstls)
+{
+ if (!rdstls_set_state(rdstls, RDSTLS_STATE_CAPABILITIES))
+ return -1;
+
+ if (!rdstls_send_capabilities(rdstls))
+ return -1;
+
+ if (!rdstls_recv_authentication_request(rdstls))
+ return -1;
+
+ if (!rdstls_send_authentication_response(rdstls))
+ return -1;
+
+ if (rdstls->resultCode != 0)
+ return -1;
+
+ return 1;
+}
+
+static int rdstls_client_authenticate(rdpRdstls* rdstls)
+{
+ if (!rdstls_set_state(rdstls, RDSTLS_STATE_CAPABILITIES))
+ return -1;
+
+ if (!rdstls_recv_capabilities(rdstls))
+ return -1;
+
+ if (!rdstls_send_authentication_request(rdstls))
+ return -1;
+
+ if (!rdstls_recv_authentication_response(rdstls))
+ return -1;
+
+ return 1;
+}
+
+/**
+ * Authenticate using RDSTLS.
+ * @param rdstls The RDSTLS instance to use
+ *
+ * @return 1 if authentication is successful
+ */
+
+int rdstls_authenticate(rdpRdstls* rdstls)
+{
+ WINPR_ASSERT(rdstls);
+
+ if (rdstls->server)
+ return rdstls_server_authenticate(rdstls);
+ else
+ return rdstls_client_authenticate(rdstls);
+}
+
+static SSIZE_T rdstls_parse_pdu_data_type(wLog* log, UINT16 dataType, wStream* s)
+{
+ switch (dataType)
+ {
+ case RDSTLS_DATA_PASSWORD_CREDS:
+ {
+ UINT16 redirGuidLength = 0;
+ if (Stream_GetRemainingLength(s) < 2)
+ return 0;
+ Stream_Read_UINT16(s, redirGuidLength);
+
+ if (Stream_GetRemainingLength(s) < redirGuidLength)
+ return 0;
+ Stream_Seek(s, redirGuidLength);
+
+ UINT16 usernameLength = 0;
+ if (Stream_GetRemainingLength(s) < 2)
+ return 0;
+ Stream_Read_UINT16(s, usernameLength);
+
+ if (Stream_GetRemainingLength(s) < usernameLength)
+ return 0;
+ Stream_Seek(s, usernameLength);
+
+ UINT16 domainLength = 0;
+ if (Stream_GetRemainingLength(s) < 2)
+ return 0;
+ Stream_Read_UINT16(s, domainLength);
+
+ if (Stream_GetRemainingLength(s) < domainLength)
+ return 0;
+ Stream_Seek(s, domainLength);
+
+ UINT16 passwordLength = 0;
+ if (Stream_GetRemainingLength(s) < 2)
+ return 0;
+ Stream_Read_UINT16(s, passwordLength);
+
+ return Stream_GetPosition(s) + passwordLength;
+ }
+ case RDSTLS_DATA_AUTORECONNECT_COOKIE:
+ {
+ if (Stream_GetRemainingLength(s) < 4)
+ return 0;
+ Stream_Seek(s, 4);
+
+ UINT16 cookieLength = 0;
+ if (Stream_GetRemainingLength(s) < 2)
+ return 0;
+ Stream_Read_UINT16(s, cookieLength);
+
+ return 12u + cookieLength;
+ }
+ default:
+ WLog_Print(log, WLOG_ERROR, "invalid RDSLTS dataType");
+ return -1;
+ }
+}
+
+SSIZE_T rdstls_parse_pdu(wLog* log, wStream* stream)
+{
+ SSIZE_T pduLength = -1;
+ wStream sbuffer = { 0 };
+ wStream* s = Stream_StaticConstInit(&sbuffer, Stream_Buffer(stream), Stream_Length(stream));
+
+ UINT16 version = 0;
+ if (Stream_GetRemainingLength(s) < 2)
+ return 0;
+ Stream_Read_UINT16(s, version);
+ if (version != RDSTLS_VERSION_1)
+ {
+ WLog_Print(log, WLOG_ERROR, "invalid RDSTLS version");
+ return -1;
+ }
+
+ UINT16 pduType = 0;
+ if (Stream_GetRemainingLength(s) < 2)
+ return 0;
+ Stream_Read_UINT16(s, pduType);
+ switch (pduType)
+ {
+ case RDSTLS_TYPE_CAPABILITIES:
+ pduLength = 8;
+ break;
+ case RDSTLS_TYPE_AUTHREQ:
+ if (Stream_GetRemainingLength(s) < 2)
+ return 0;
+ UINT16 dataType = 0;
+ Stream_Read_UINT16(s, dataType);
+ pduLength = rdstls_parse_pdu_data_type(log, dataType, s);
+
+ break;
+ case RDSTLS_TYPE_AUTHRSP:
+ pduLength = 10;
+ break;
+ default:
+ WLog_Print(log, WLOG_ERROR, "invalid RDSTLS PDU type");
+ return -1;
+ }
+
+ return pduLength;
+}