diff options
Diffstat (limited to 'plugin/win_auth_client/common.cc')
-rw-r--r-- | plugin/win_auth_client/common.cc | 510 |
1 files changed, 510 insertions, 0 deletions
diff --git a/plugin/win_auth_client/common.cc b/plugin/win_auth_client/common.cc new file mode 100644 index 00000000..8b731925 --- /dev/null +++ b/plugin/win_auth_client/common.cc @@ -0,0 +1,510 @@ +/* Copyright (c) 2011, 2019, Oracle and/or its affiliates. All rights reserved. + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; version 2 of the License. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program; if not, write to the Free Software + Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1335 USA */ + +#include "common.h" +#include <sddl.h> // for ConvertSidToStringSid() +#include <secext.h> // for GetUserNameEx() + + +template <> void error_log_print<error_log_level::INFO>(const char *fmt, ...); +template <> void error_log_print<error_log_level::WARNING>(const char *fmt, ...); +template <> void error_log_print<error_log_level::ERROR>(const char *fmt, ...); + +/** + Option indicating desired level of logging. Values: + + 0 - no logging + 1 - log only error messages + 2 - additionally log warnings + 3 - additionally log info notes + 4 - also log debug messages + + Value of this option should be taken into account in the + implementation of error_log_vprint() function (see + log_client.cc). + + Note: No error or debug messages are logged in production code + (see logging macros in common.h). +*/ +int opt_auth_win_log_level= 2; + + +/** Connection class **************************************************/ + +/** + Create connection out of an active MYSQL_PLUGIN_VIO object. + + @param[in] vio pointer to a @c MYSQL_PLUGIN_VIO object used for + connection - it can not be NULL +*/ + +Connection::Connection(MYSQL_PLUGIN_VIO *vio): m_vio(vio), m_error(0) +{ + DBUG_ASSERT(vio); +} + + +/** + Write data to the connection. + + @param[in] blob data to be written + + @return 0 on success, VIO error code on failure. + + @note In case of error, VIO error code is stored in the connection object + and can be obtained with @c error() method. +*/ + +int Connection::write(const Blob &blob) +{ + m_error= m_vio->write_packet(m_vio, blob.ptr(), (int)blob.len()); + +#ifndef DBUG_OFF + if (m_error) + DBUG_PRINT("error", ("vio write error %d", m_error)); +#endif + + return m_error; +} + + +/** + Read data from connection. + + @return A Blob containing read packet or null Blob in case of error. + + @note In case of error, VIO error code is stored in the connection object + and can be obtained with @c error() method. +*/ + +Blob Connection::read() +{ + unsigned char *ptr; + int len= m_vio->read_packet(m_vio, &ptr); + + if (len < 0) + { + m_error= true; + return Blob(); + } + + return Blob(ptr, len); +} + + +/** Sid class *****************************************************/ + + +/** + Create Sid object corresponding to a given account name. + + @param[in] account_name name of a Windows account + + The account name can be in any form accepted by @c LookupAccountName() + function. + + @note In case of errors created object is invalid and its @c is_valid() + method returns @c false. +*/ + +Sid::Sid(const wchar_t *account_name): m_data(NULL) +#ifndef DBUG_OFF +, m_as_string(NULL) +#endif +{ + DWORD sid_size= 0, domain_size= 0; + bool success; + + // Determine required buffer sizes + + success= LookupAccountNameW(NULL, account_name, NULL, &sid_size, + NULL, &domain_size, &m_type); + + if (!success && GetLastError() != ERROR_INSUFFICIENT_BUFFER) + { +#ifndef DBUG_OFF + Error_message_buf error_buf; + DBUG_PRINT("error", ("Could not determine SID buffer size, " + "LookupAccountName() failed with error %X (%s)", + GetLastError(), get_last_error_message(error_buf))); +#endif + return; + } + + // Query for SID (domain is ignored) + + wchar_t *domain= new wchar_t[domain_size]; + m_data= (TOKEN_USER*) new BYTE[sid_size + sizeof(TOKEN_USER)]; + m_data->User.Sid= (BYTE*)m_data + sizeof(TOKEN_USER); + + success= LookupAccountNameW(NULL, account_name, + m_data->User.Sid, &sid_size, + domain, &domain_size, + &m_type); + + if (!success || !is_valid()) + { +#ifndef DBUG_OFF + Error_message_buf error_buf; + DBUG_PRINT("error", ("Could not determine SID of '%S', " + "LookupAccountName() failed with error %X (%s)", + account_name, GetLastError(), + get_last_error_message(error_buf))); +#endif + goto fail; + } + + goto end; + +fail: + if (m_data) + delete [] m_data; + m_data= NULL; + +end: + if (domain) + delete [] domain; +} + + +/** + Create Sid object corresponding to a given security token. + + @param[in] token security token of a Windows account + + @note In case of errors created object is invalid and its @c is_valid() + method returns @c false. +*/ + +Sid::Sid(HANDLE token): m_data(NULL) +#ifndef DBUG_OFF +, m_as_string(NULL) +#endif +{ + DWORD req_size= 0; + bool success; + + // Determine required buffer size + + success= GetTokenInformation(token, TokenUser, NULL, 0, &req_size); + if (!success && GetLastError() != ERROR_INSUFFICIENT_BUFFER) + { +#ifndef DBUG_OFF + Error_message_buf error_buf; + DBUG_PRINT("error", ("Could not determine SID buffer size, " + "GetTokenInformation() failed with error %X (%s)", + GetLastError(), get_last_error_message(error_buf))); +#endif + return; + } + + m_data= (TOKEN_USER*) new BYTE[req_size]; + success= GetTokenInformation(token, TokenUser, m_data, req_size, &req_size); + + if (!success || !is_valid()) + { + delete [] m_data; + m_data= NULL; +#ifndef DBUG_OFF + if (!success) + { + Error_message_buf error_buf; + DBUG_PRINT("error", ("Could not read SID from security token, " + "GetTokenInformation() failed with error %X (%s)", + GetLastError(), get_last_error_message(error_buf))); + } +#endif + } +} + + +Sid::~Sid() +{ + if (m_data) + delete [] m_data; +#ifndef DBUG_OFF + if (m_as_string) + LocalFree(m_as_string); +#endif +} + +/// Check if Sid object is valid. +bool Sid::is_valid(void) const +{ + return m_data && m_data->User.Sid && IsValidSid(m_data->User.Sid); +} + + +#ifndef DBUG_OFF + +/** + Produces string representation of the SID. + + @return String representation of the SID or NULL in case of errors. + + @note Memory allocated for the string is automatically freed in Sid's + destructor. +*/ + +const char* Sid::as_string() +{ + if (!m_data) + return NULL; + + if (!m_as_string) + { + bool success= ConvertSidToStringSid(m_data->User.Sid, &m_as_string); + + if (!success) + { +#ifndef DBUG_OFF + Error_message_buf error_buf; + DBUG_PRINT("error", ("Could not get textual representation of a SID, " + "ConvertSidToStringSid() failed with error %X (%s)", + GetLastError(), get_last_error_message(error_buf))); +#endif + m_as_string= NULL; + return NULL; + } + } + + return m_as_string; +} + +#endif + + +bool Sid::operator ==(const Sid &other) +{ + if (!is_valid() || !other.is_valid()) + return false; + + return EqualSid(m_data->User.Sid, other.m_data->User.Sid); +} + + +/** Generating User Principal Name *************************/ + +/** + Call Windows API functions to get UPN of the current user and store it + in internal buffer. +*/ + +UPN::UPN(): m_buf(NULL) +{ + wchar_t buf1[MAX_SERVICE_NAME_LENGTH]; + + // First we try to use GetUserNameEx. + + m_len= sizeof(buf1)/sizeof(wchar_t); + + if (!GetUserNameExW(NameUserPrincipal, buf1, (PULONG)&m_len)) + { + if (GetLastError()) + { +#ifndef DBUG_OFF + Error_message_buf error_buf; + DBUG_PRINT("note", ("When determining UPN" + ", GetUserNameEx() failed with error %X (%s)", + GetLastError(), get_last_error_message(error_buf))); +#endif + if (ERROR_MORE_DATA == GetLastError()) + ERROR_LOG(INFO, ("Buffer overrun when determining UPN:" + " need %ul characters but have %ul", + m_len, sizeof(buf1)/sizeof(WCHAR))); + } + + m_len= 0; // m_len == 0 indicates invalid UPN + return; + } + + /* + UPN is stored in buf1 in wide-char format - convert it to utf8 + for sending over network. + */ + + m_buf= wchar_to_utf8(buf1, &m_len); + + if(!m_buf) + ERROR_LOG(ERROR, ("Failed to convert UPN to utf8")); + + // Note: possible error would be indicated by the fact that m_buf is NULL. + return; +} + + +UPN::~UPN() +{ + if (m_buf) + free(m_buf); +} + + +/** + Convert a wide-char string to utf8 representation. + + @param[in] string null-terminated wide-char string to be converted + @param[in,out] len length of the string to be converted or 0; on + return length (in bytes, excluding terminating + null character) of the converted string + + If len is 0 then the length of the string will be computed by this function. + + @return Pointer to a buffer containing utf8 representation or NULL in + case of error. + + @note The returned buffer must be freed with @c free() call. +*/ + +char* wchar_to_utf8(const wchar_t *string, size_t *len) +{ + char *buf= NULL; + size_t str_len= len && *len ? *len : wcslen(string); + + /* + A conversion from utf8 to wchar_t will never take more than 3 bytes per + character, so a buffer of length 3 * str_len should be sufficient. + We check that assumption with an assertion later. + */ + + size_t buf_len= 3 * str_len; + + buf= (char*)malloc(buf_len + 1); + if (!buf) + { + DBUG_PRINT("error",("Out of memory when converting string '%S' to utf8", + string)); + return NULL; + } + + int res= WideCharToMultiByte(CP_UTF8, // convert to UTF-8 + 0, // conversion flags + string, // input buffer + (int)str_len, // its length + buf, (int)buf_len, // output buffer and its size + NULL, NULL); // default character (not used) + + if (res) + { + buf[res]= '\0'; + if (len) + *len= res; + return buf; + } + + // res is 0 which indicates error + +#ifndef DBUG_OFF + Error_message_buf error_buf; + DBUG_PRINT("error", ("Could not convert string '%S' to utf8" + ", WideCharToMultiByte() failed with error %X (%s)", + string, GetLastError(), + get_last_error_message(error_buf))); +#endif + + // Let's check our assumption about sufficient buffer size + DBUG_ASSERT(ERROR_INSUFFICIENT_BUFFER != GetLastError()); + + return NULL; +} + + +/** + Convert an utf8 string to a wide-char string. + + @param[in] string null-terminated utf8 string to be converted + @param[in,out] len length of the string to be converted or 0; on + return length (in chars) of the converted string + + If len is 0 then the length of the string will be computed by this function. + + @return Pointer to a buffer containing wide-char representation or NULL in + case of error. + + @note The returned buffer must be freed with @c free() call. +*/ + +wchar_t* utf8_to_wchar(const char *string, size_t *len) +{ + size_t buf_len; + + /* + Note: length (in bytes) of an utf8 string is always bigger than the + number of characters in this string. Hence a buffer of size len will + be sufficient. We add 1 for the terminating null character. + */ + + buf_len= len && *len ? *len : strlen(string); + wchar_t *buf= (wchar_t*)malloc((buf_len+1)*sizeof(wchar_t)); + + if (!buf) + { + DBUG_PRINT("error",("Out of memory when converting utf8 string '%s'" + " to wide-char representation", string)); + return NULL; + } + + size_t res; + res= MultiByteToWideChar(CP_UTF8, // convert from UTF-8 + 0, // conversion flags + string, // input buffer + (int)buf_len, // its size + buf, (int)buf_len); // output buffer and its size + if (res) + { + buf[res]= '\0'; + if (len) + *len= res; + return buf; + } + + // error in MultiByteToWideChar() + +#ifndef DBUG_OFF + Error_message_buf error_buf; + DBUG_PRINT("error", ("Could not convert UPN from UTF-8" + ", MultiByteToWideChar() failed with error %X (%s)", + GetLastError(), get_last_error_message(error_buf))); +#endif + + // Let's check our assumption about sufficient buffer size + DBUG_ASSERT(ERROR_INSUFFICIENT_BUFFER != GetLastError()); + + return NULL; +} + + +/** Error handling ****************************************************/ + + +/** + Returns error message corresponding to the last Windows error given + by GetLastError(). + + @note Error message is overwritten by next call to + @c get_last_error_message(). +*/ + +const char* get_last_error_message(Error_message_buf buf) +{ + int error= GetLastError(); + + buf[0]= '\0'; + FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM, + NULL, error, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + (LPTSTR)buf, ERRMSG_BUFSIZE , NULL ); + + return buf; +} |