diff options
Diffstat (limited to 'lib/ansible/module_utils/csharp')
-rw-r--r-- | lib/ansible/module_utils/csharp/Ansible.AccessToken.cs | 460 | ||||
-rw-r--r-- | lib/ansible/module_utils/csharp/Ansible.Basic.cs | 1489 | ||||
-rw-r--r-- | lib/ansible/module_utils/csharp/Ansible.Become.cs | 655 | ||||
-rw-r--r-- | lib/ansible/module_utils/csharp/Ansible.Privilege.cs | 443 | ||||
-rw-r--r-- | lib/ansible/module_utils/csharp/Ansible.Process.cs | 461 | ||||
-rw-r--r-- | lib/ansible/module_utils/csharp/__init__.py | 0 |
6 files changed, 3508 insertions, 0 deletions
diff --git a/lib/ansible/module_utils/csharp/Ansible.AccessToken.cs b/lib/ansible/module_utils/csharp/Ansible.AccessToken.cs new file mode 100644 index 0000000..48c4a19 --- /dev/null +++ b/lib/ansible/module_utils/csharp/Ansible.AccessToken.cs @@ -0,0 +1,460 @@ +using Microsoft.Win32.SafeHandles; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.ConstrainedExecution; +using System.Runtime.InteropServices; +using System.Security.Principal; +using System.Text; + +namespace Ansible.AccessToken +{ + internal class NativeHelpers + { + [StructLayout(LayoutKind.Sequential)] + public struct LUID_AND_ATTRIBUTES + { + public Luid Luid; + public UInt32 Attributes; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SID_AND_ATTRIBUTES + { + public IntPtr Sid; + public int Attributes; + } + + [StructLayout(LayoutKind.Sequential)] + public struct TOKEN_PRIVILEGES + { + public UInt32 PrivilegeCount; + [MarshalAs(UnmanagedType.ByValArray, SizeConst = 1)] + public LUID_AND_ATTRIBUTES[] Privileges; + } + + [StructLayout(LayoutKind.Sequential)] + public struct TOKEN_USER + { + public SID_AND_ATTRIBUTES User; + } + + public enum TokenInformationClass : uint + { + TokenUser = 1, + TokenPrivileges = 3, + TokenStatistics = 10, + TokenElevationType = 18, + TokenLinkedToken = 19, + } + } + + internal class NativeMethods + { + [DllImport("kernel32.dll", SetLastError = true)] + public static extern bool CloseHandle( + IntPtr hObject); + + [DllImport("advapi32.dll", SetLastError = true)] + public static extern bool DuplicateTokenEx( + SafeNativeHandle hExistingToken, + TokenAccessLevels dwDesiredAccess, + IntPtr lpTokenAttributes, + SecurityImpersonationLevel ImpersonationLevel, + TokenType TokenType, + out SafeNativeHandle phNewToken); + + [DllImport("kernel32.dll")] + public static extern SafeNativeHandle GetCurrentProcess(); + + [DllImport("advapi32.dll", SetLastError = true)] + public static extern bool GetTokenInformation( + SafeNativeHandle TokenHandle, + NativeHelpers.TokenInformationClass TokenInformationClass, + SafeMemoryBuffer TokenInformation, + UInt32 TokenInformationLength, + out UInt32 ReturnLength); + + [DllImport("advapi32.dll", SetLastError = true)] + public static extern bool ImpersonateLoggedOnUser( + SafeNativeHandle hToken); + + [DllImport("advapi32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + public static extern bool LogonUserW( + string lpszUsername, + string lpszDomain, + string lpszPassword, + LogonType dwLogonType, + LogonProvider dwLogonProvider, + out SafeNativeHandle phToken); + + [DllImport("advapi32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + public static extern bool LookupPrivilegeNameW( + string lpSystemName, + ref Luid lpLuid, + StringBuilder lpName, + ref UInt32 cchName); + + [DllImport("kernel32.dll", SetLastError = true)] + public static extern SafeNativeHandle OpenProcess( + ProcessAccessFlags dwDesiredAccess, + bool bInheritHandle, + UInt32 dwProcessId); + + [DllImport("advapi32.dll", SetLastError = true)] + public static extern bool OpenProcessToken( + SafeNativeHandle ProcessHandle, + TokenAccessLevels DesiredAccess, + out SafeNativeHandle TokenHandle); + + [DllImport("advapi32.dll", SetLastError = true)] + public static extern bool RevertToSelf(); + } + + internal class SafeMemoryBuffer : SafeHandleZeroOrMinusOneIsInvalid + { + public SafeMemoryBuffer() : base(true) { } + public SafeMemoryBuffer(int cb) : base(true) + { + base.SetHandle(Marshal.AllocHGlobal(cb)); + } + public SafeMemoryBuffer(IntPtr handle) : base(true) + { + base.SetHandle(handle); + } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + protected override bool ReleaseHandle() + { + Marshal.FreeHGlobal(handle); + return true; + } + } + + public enum LogonProvider + { + Default, + WinNT35, + WinNT40, + WinNT50, + } + + public enum LogonType + { + Interactive = 2, + Network = 3, + Batch = 4, + Service = 5, + Unlock = 7, + NetworkCleartext = 8, + NewCredentials = 9, + } + + [Flags] + public enum PrivilegeAttributes : uint + { + Disabled = 0x00000000, + EnabledByDefault = 0x00000001, + Enabled = 0x00000002, + Removed = 0x00000004, + UsedForAccess = 0x80000000, + } + + [Flags] + public enum ProcessAccessFlags : uint + { + Terminate = 0x00000001, + CreateThread = 0x00000002, + VmOperation = 0x00000008, + VmRead = 0x00000010, + VmWrite = 0x00000020, + DupHandle = 0x00000040, + CreateProcess = 0x00000080, + SetQuota = 0x00000100, + SetInformation = 0x00000200, + QueryInformation = 0x00000400, + SuspendResume = 0x00000800, + QueryLimitedInformation = 0x00001000, + Delete = 0x00010000, + ReadControl = 0x00020000, + WriteDac = 0x00040000, + WriteOwner = 0x00080000, + Synchronize = 0x00100000, + } + + public enum SecurityImpersonationLevel + { + Anonymous, + Identification, + Impersonation, + Delegation, + } + + public enum TokenElevationType + { + Default = 1, + Full, + Limited, + } + + public enum TokenType + { + Primary = 1, + Impersonation, + } + + [StructLayout(LayoutKind.Sequential)] + public struct Luid + { + public UInt32 LowPart; + public Int32 HighPart; + + public static explicit operator UInt64(Luid l) + { + return (UInt64)((UInt64)l.HighPart << 32) | (UInt64)l.LowPart; + } + } + + [StructLayout(LayoutKind.Sequential)] + public struct TokenStatistics + { + public Luid TokenId; + public Luid AuthenticationId; + public Int64 ExpirationTime; + public TokenType TokenType; + public SecurityImpersonationLevel ImpersonationLevel; + public UInt32 DynamicCharged; + public UInt32 DynamicAvailable; + public UInt32 GroupCount; + public UInt32 PrivilegeCount; + public Luid ModifiedId; + } + + public class PrivilegeInfo + { + public string Name; + public PrivilegeAttributes Attributes; + + internal PrivilegeInfo(NativeHelpers.LUID_AND_ATTRIBUTES la) + { + Name = TokenUtil.GetPrivilegeName(la.Luid); + Attributes = (PrivilegeAttributes)la.Attributes; + } + } + + public class SafeNativeHandle : SafeHandleZeroOrMinusOneIsInvalid + { + public SafeNativeHandle() : base(true) { } + public SafeNativeHandle(IntPtr handle) : base(true) { this.handle = handle; } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + protected override bool ReleaseHandle() + { + return NativeMethods.CloseHandle(handle); + } + } + + public class Win32Exception : System.ComponentModel.Win32Exception + { + private string _msg; + + public Win32Exception(string message) : this(Marshal.GetLastWin32Error(), message) { } + public Win32Exception(int errorCode, string message) : base(errorCode) + { + _msg = String.Format("{0} ({1}, Win32ErrorCode {2} - 0x{2:X8})", message, base.Message, errorCode); + } + + public override string Message { get { return _msg; } } + public static explicit operator Win32Exception(string message) { return new Win32Exception(message); } + } + + public class TokenUtil + { + public static SafeNativeHandle DuplicateToken(SafeNativeHandle hToken, TokenAccessLevels access, + SecurityImpersonationLevel impersonationLevel, TokenType tokenType) + { + SafeNativeHandle dupToken; + if (!NativeMethods.DuplicateTokenEx(hToken, access, IntPtr.Zero, impersonationLevel, tokenType, out dupToken)) + throw new Win32Exception("Failed to duplicate token"); + return dupToken; + } + + public static SecurityIdentifier GetTokenUser(SafeNativeHandle hToken) + { + using (SafeMemoryBuffer tokenInfo = GetTokenInformation(hToken, + NativeHelpers.TokenInformationClass.TokenUser)) + { + NativeHelpers.TOKEN_USER tokenUser = (NativeHelpers.TOKEN_USER)Marshal.PtrToStructure( + tokenInfo.DangerousGetHandle(), + typeof(NativeHelpers.TOKEN_USER)); + return new SecurityIdentifier(tokenUser.User.Sid); + } + } + + public static List<PrivilegeInfo> GetTokenPrivileges(SafeNativeHandle hToken) + { + using (SafeMemoryBuffer tokenInfo = GetTokenInformation(hToken, + NativeHelpers.TokenInformationClass.TokenPrivileges)) + { + NativeHelpers.TOKEN_PRIVILEGES tokenPrivs = (NativeHelpers.TOKEN_PRIVILEGES)Marshal.PtrToStructure( + tokenInfo.DangerousGetHandle(), + typeof(NativeHelpers.TOKEN_PRIVILEGES)); + + NativeHelpers.LUID_AND_ATTRIBUTES[] luidAttrs = + new NativeHelpers.LUID_AND_ATTRIBUTES[tokenPrivs.PrivilegeCount]; + PtrToStructureArray(luidAttrs, IntPtr.Add(tokenInfo.DangerousGetHandle(), + Marshal.SizeOf(tokenPrivs.PrivilegeCount))); + + return luidAttrs.Select(la => new PrivilegeInfo(la)).ToList(); + } + } + + public static TokenStatistics GetTokenStatistics(SafeNativeHandle hToken) + { + using (SafeMemoryBuffer tokenInfo = GetTokenInformation(hToken, + NativeHelpers.TokenInformationClass.TokenStatistics)) + { + TokenStatistics tokenStats = (TokenStatistics)Marshal.PtrToStructure( + tokenInfo.DangerousGetHandle(), + typeof(TokenStatistics)); + return tokenStats; + } + } + + public static TokenElevationType GetTokenElevationType(SafeNativeHandle hToken) + { + using (SafeMemoryBuffer tokenInfo = GetTokenInformation(hToken, + NativeHelpers.TokenInformationClass.TokenElevationType)) + { + return (TokenElevationType)Marshal.ReadInt32(tokenInfo.DangerousGetHandle()); + } + } + + public static SafeNativeHandle GetTokenLinkedToken(SafeNativeHandle hToken) + { + using (SafeMemoryBuffer tokenInfo = GetTokenInformation(hToken, + NativeHelpers.TokenInformationClass.TokenLinkedToken)) + { + return new SafeNativeHandle(Marshal.ReadIntPtr(tokenInfo.DangerousGetHandle())); + } + } + + public static IEnumerable<SafeNativeHandle> EnumerateUserTokens(SecurityIdentifier sid, + TokenAccessLevels access = TokenAccessLevels.Query) + { + foreach (System.Diagnostics.Process process in System.Diagnostics.Process.GetProcesses()) + { + // We always need the Query access level so we can query the TokenUser + using (process) + using (SafeNativeHandle hToken = TryOpenAccessToken(process, access | TokenAccessLevels.Query)) + { + if (hToken == null) + continue; + + if (!sid.Equals(GetTokenUser(hToken))) + continue; + + yield return hToken; + } + } + } + + public static void ImpersonateToken(SafeNativeHandle hToken) + { + if (!NativeMethods.ImpersonateLoggedOnUser(hToken)) + throw new Win32Exception("Failed to impersonate token"); + } + + public static SafeNativeHandle LogonUser(string username, string domain, string password, LogonType logonType, + LogonProvider logonProvider) + { + SafeNativeHandle hToken; + if (!NativeMethods.LogonUserW(username, domain, password, logonType, logonProvider, out hToken)) + throw new Win32Exception(String.Format("Failed to logon {0}", + String.IsNullOrEmpty(domain) ? username : domain + "\\" + username)); + + return hToken; + } + + public static SafeNativeHandle OpenProcess() + { + return NativeMethods.GetCurrentProcess(); + } + + public static SafeNativeHandle OpenProcess(Int32 pid, ProcessAccessFlags access, bool inherit) + { + SafeNativeHandle hProcess = NativeMethods.OpenProcess(access, inherit, (UInt32)pid); + if (hProcess.IsInvalid) + throw new Win32Exception(String.Format("Failed to open process {0} with access {1}", + pid, access.ToString())); + + return hProcess; + } + + public static SafeNativeHandle OpenProcessToken(SafeNativeHandle hProcess, TokenAccessLevels access) + { + SafeNativeHandle hToken; + if (!NativeMethods.OpenProcessToken(hProcess, access, out hToken)) + throw new Win32Exception(String.Format("Failed to open process token with access {0}", + access.ToString())); + + return hToken; + } + + public static void RevertToSelf() + { + if (!NativeMethods.RevertToSelf()) + throw new Win32Exception("Failed to revert thread impersonation"); + } + + internal static string GetPrivilegeName(Luid luid) + { + UInt32 nameLen = 0; + NativeMethods.LookupPrivilegeNameW(null, ref luid, null, ref nameLen); + + StringBuilder name = new StringBuilder((int)(nameLen + 1)); + if (!NativeMethods.LookupPrivilegeNameW(null, ref luid, name, ref nameLen)) + throw new Win32Exception("LookupPrivilegeName() failed"); + + return name.ToString(); + } + + private static SafeMemoryBuffer GetTokenInformation(SafeNativeHandle hToken, + NativeHelpers.TokenInformationClass infoClass) + { + UInt32 tokenLength; + bool res = NativeMethods.GetTokenInformation(hToken, infoClass, new SafeMemoryBuffer(IntPtr.Zero), 0, + out tokenLength); + int errCode = Marshal.GetLastWin32Error(); + if (!res && errCode != 24 && errCode != 122) // ERROR_INSUFFICIENT_BUFFER, ERROR_BAD_LENGTH + throw new Win32Exception(errCode, String.Format("GetTokenInformation({0}) failed to get buffer length", + infoClass.ToString())); + + SafeMemoryBuffer tokenInfo = new SafeMemoryBuffer((int)tokenLength); + if (!NativeMethods.GetTokenInformation(hToken, infoClass, tokenInfo, tokenLength, out tokenLength)) + throw new Win32Exception(String.Format("GetTokenInformation({0}) failed", infoClass.ToString())); + + return tokenInfo; + } + + private static void PtrToStructureArray<T>(T[] array, IntPtr ptr) + { + IntPtr ptrOffset = ptr; + for (int i = 0; i < array.Length; i++, ptrOffset = IntPtr.Add(ptrOffset, Marshal.SizeOf(typeof(T)))) + array[i] = (T)Marshal.PtrToStructure(ptrOffset, typeof(T)); + } + + private static SafeNativeHandle TryOpenAccessToken(System.Diagnostics.Process process, TokenAccessLevels access) + { + try + { + using (SafeNativeHandle hProcess = OpenProcess(process.Id, ProcessAccessFlags.QueryInformation, false)) + return OpenProcessToken(hProcess, access); + } + catch (Win32Exception) + { + return null; + } + } + } +} diff --git a/lib/ansible/module_utils/csharp/Ansible.Basic.cs b/lib/ansible/module_utils/csharp/Ansible.Basic.cs new file mode 100644 index 0000000..c68281e --- /dev/null +++ b/lib/ansible/module_utils/csharp/Ansible.Basic.cs @@ -0,0 +1,1489 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Management.Automation; +using System.Management.Automation.Runspaces; +using System.Reflection; +using System.Runtime.InteropServices; +using System.Security.AccessControl; +using System.Security.Principal; +#if CORECLR +using Newtonsoft.Json; +#else +using System.Web.Script.Serialization; +#endif + +// Newtonsoft.Json may reference a different System.Runtime version (6.x) than loaded by PowerShell 7.3 (7.x). +// Ignore CS1701 so the code can be compiled when warnings are reported as errors. +//NoWarn -Name CS1701 -CLR Core + +// System.Diagnostics.EventLog.dll reference different versioned dlls that are +// loaded in PSCore, ignore CS1702 so the code will ignore this warning +//NoWarn -Name CS1702 -CLR Core + +//AssemblyReference -Type Newtonsoft.Json.JsonConvert -CLR Core +//AssemblyReference -Type System.Diagnostics.EventLog -CLR Core +//AssemblyReference -Type System.Security.AccessControl.NativeObjectSecurity -CLR Core +//AssemblyReference -Type System.Security.AccessControl.DirectorySecurity -CLR Core +//AssemblyReference -Type System.Security.Principal.IdentityReference -CLR Core + +//AssemblyReference -Name System.Web.Extensions.dll -CLR Framework + +namespace Ansible.Basic +{ + public class AnsibleModule + { + public delegate void ExitHandler(int rc); + public static ExitHandler Exit = new ExitHandler(ExitModule); + + public delegate void WriteLineHandler(string line); + public static WriteLineHandler WriteLine = new WriteLineHandler(WriteLineModule); + + public static bool _DebugArgSpec = false; + + private static List<string> BOOLEANS_TRUE = new List<string>() { "y", "yes", "on", "1", "true", "t", "1.0" }; + private static List<string> BOOLEANS_FALSE = new List<string>() { "n", "no", "off", "0", "false", "f", "0.0" }; + + private string remoteTmp = Path.GetTempPath(); + private string tmpdir = null; + private HashSet<string> noLogValues = new HashSet<string>(); + private List<string> optionsContext = new List<string>(); + private List<string> warnings = new List<string>(); + private List<Dictionary<string, string>> deprecations = new List<Dictionary<string, string>>(); + private List<string> cleanupFiles = new List<string>(); + + private Dictionary<string, string> passVars = new Dictionary<string, string>() + { + // null values means no mapping, not used in Ansible.Basic.AnsibleModule + { "check_mode", "CheckMode" }, + { "debug", "DebugMode" }, + { "diff", "DiffMode" }, + { "keep_remote_files", "KeepRemoteFiles" }, + { "module_name", "ModuleName" }, + { "no_log", "NoLog" }, + { "remote_tmp", "remoteTmp" }, + { "selinux_special_fs", null }, + { "shell_executable", null }, + { "socket", null }, + { "string_conversion_action", null }, + { "syslog_facility", null }, + { "tmpdir", "tmpdir" }, + { "verbosity", "Verbosity" }, + { "version", "AnsibleVersion" }, + }; + private List<string> passBools = new List<string>() { "check_mode", "debug", "diff", "keep_remote_files", "no_log" }; + private List<string> passInts = new List<string>() { "verbosity" }; + private Dictionary<string, List<object>> specDefaults = new Dictionary<string, List<object>>() + { + // key - (default, type) - null is freeform + { "apply_defaults", new List<object>() { false, typeof(bool) } }, + { "aliases", new List<object>() { typeof(List<string>), typeof(List<string>) } }, + { "choices", new List<object>() { typeof(List<object>), typeof(List<object>) } }, + { "default", new List<object>() { null, null } }, + { "deprecated_aliases", new List<object>() { typeof(List<Hashtable>), typeof(List<Hashtable>) } }, + { "elements", new List<object>() { null, null } }, + { "mutually_exclusive", new List<object>() { typeof(List<List<string>>), typeof(List<object>) } }, + { "no_log", new List<object>() { false, typeof(bool) } }, + { "options", new List<object>() { typeof(Hashtable), typeof(Hashtable) } }, + { "removed_in_version", new List<object>() { null, typeof(string) } }, + { "removed_at_date", new List<object>() { null, typeof(DateTime) } }, + { "removed_from_collection", new List<object>() { null, typeof(string) } }, + { "required", new List<object>() { false, typeof(bool) } }, + { "required_by", new List<object>() { typeof(Hashtable), typeof(Hashtable) } }, + { "required_if", new List<object>() { typeof(List<List<object>>), typeof(List<object>) } }, + { "required_one_of", new List<object>() { typeof(List<List<string>>), typeof(List<object>) } }, + { "required_together", new List<object>() { typeof(List<List<string>>), typeof(List<object>) } }, + { "supports_check_mode", new List<object>() { false, typeof(bool) } }, + { "type", new List<object>() { "str", null } }, + }; + private Dictionary<string, Delegate> optionTypes = new Dictionary<string, Delegate>() + { + { "bool", new Func<object, bool>(ParseBool) }, + { "dict", new Func<object, Dictionary<string, object>>(ParseDict) }, + { "float", new Func<object, float>(ParseFloat) }, + { "int", new Func<object, int>(ParseInt) }, + { "json", new Func<object, string>(ParseJson) }, + { "list", new Func<object, List<object>>(ParseList) }, + { "path", new Func<object, string>(ParsePath) }, + { "raw", new Func<object, object>(ParseRaw) }, + { "sid", new Func<object, SecurityIdentifier>(ParseSid) }, + { "str", new Func<object, string>(ParseStr) }, + }; + + public Dictionary<string, object> Diff = new Dictionary<string, object>(); + public IDictionary Params = null; + public Dictionary<string, object> Result = new Dictionary<string, object>() { { "changed", false } }; + + public bool CheckMode { get; private set; } + public bool DebugMode { get; private set; } + public bool DiffMode { get; private set; } + public bool KeepRemoteFiles { get; private set; } + public string ModuleName { get; private set; } + public bool NoLog { get; private set; } + public int Verbosity { get; private set; } + public string AnsibleVersion { get; private set; } + + public string Tmpdir + { + get + { + if (tmpdir == null) + { +#if WINDOWS + SecurityIdentifier user = WindowsIdentity.GetCurrent().User; + DirectorySecurity dirSecurity = new DirectorySecurity(); + dirSecurity.SetOwner(user); + dirSecurity.SetAccessRuleProtection(true, false); // disable inheritance rules + FileSystemAccessRule ace = new FileSystemAccessRule(user, FileSystemRights.FullControl, + InheritanceFlags.ContainerInherit | InheritanceFlags.ObjectInherit, + PropagationFlags.None, AccessControlType.Allow); + dirSecurity.AddAccessRule(ace); + + string baseDir = Path.GetFullPath(Environment.ExpandEnvironmentVariables(remoteTmp)); + if (!Directory.Exists(baseDir)) + { + string failedMsg = null; + try + { +#if CORECLR + DirectoryInfo createdDir = Directory.CreateDirectory(baseDir); + FileSystemAclExtensions.SetAccessControl(createdDir, dirSecurity); +#else + Directory.CreateDirectory(baseDir, dirSecurity); +#endif + } + catch (Exception e) + { + failedMsg = String.Format("Failed to create base tmpdir '{0}': {1}", baseDir, e.Message); + } + + if (failedMsg != null) + { + string envTmp = Path.GetTempPath(); + Warn(String.Format("Unable to use '{0}' as temporary directory, falling back to system tmp '{1}': {2}", baseDir, envTmp, failedMsg)); + baseDir = envTmp; + } + else + { + NTAccount currentUser = (NTAccount)user.Translate(typeof(NTAccount)); + string warnMsg = String.Format("Module remote_tmp {0} did not exist and was created with FullControl to {1}, ", baseDir, currentUser.ToString()); + warnMsg += "this may cause issues when running as another user. To avoid this, create the remote_tmp dir with the correct permissions manually"; + Warn(warnMsg); + } + } + + string dateTime = DateTime.Now.ToFileTime().ToString(); + string dirName = String.Format("ansible-moduletmp-{0}-{1}", dateTime, new Random().Next(0, int.MaxValue)); + string newTmpdir = Path.Combine(baseDir, dirName); +#if CORECLR + DirectoryInfo tmpdirInfo = Directory.CreateDirectory(newTmpdir); + FileSystemAclExtensions.SetAccessControl(tmpdirInfo, dirSecurity); +#else + Directory.CreateDirectory(newTmpdir, dirSecurity); +#endif + tmpdir = newTmpdir; + + if (!KeepRemoteFiles) + cleanupFiles.Add(tmpdir); +#else + throw new NotImplementedException("Tmpdir is only supported on Windows"); +#endif + } + return tmpdir; + } + } + + public AnsibleModule(string[] args, IDictionary argumentSpec, IDictionary[] fragments = null) + { + // NoLog is not set yet, we cannot rely on FailJson to sanitize the output + // Do the minimum amount to get this running before we actually parse the params + Dictionary<string, string> aliases = new Dictionary<string, string>(); + try + { + ValidateArgumentSpec(argumentSpec); + + // Merge the fragments if present into the main arg spec. + if (fragments != null) + { + foreach (IDictionary fragment in fragments) + { + ValidateArgumentSpec(fragment); + MergeFragmentSpec(argumentSpec, fragment); + } + } + + // Used by ansible-test to retrieve the module argument spec, not designed for public use. + if (_DebugArgSpec) + { + // Cannot call exit here because it will be caught with the catch (Exception e) below. Instead + // just throw a new exception with a specific message and the exception block will handle it. + ScriptBlock.Create("Set-Variable -Name ansibleTestArgSpec -Value $args[0] -Scope Global" + ).Invoke(argumentSpec); + throw new Exception("ansible-test validate-modules check"); + } + + // Now make sure all the metadata keys are set to their defaults, this must be done after we've + // potentially output the arg spec for ansible-test. + SetArgumentSpecDefaults(argumentSpec); + + Params = GetParams(args); + aliases = GetAliases(argumentSpec, Params); + SetNoLogValues(argumentSpec, Params); + } + catch (Exception e) + { + if (e.Message == "ansible-test validate-modules check") + Exit(0); + + Dictionary<string, object> result = new Dictionary<string, object> + { + { "failed", true }, + { "msg", String.Format("internal error: {0}", e.Message) }, + { "exception", e.ToString() } + }; + WriteLine(ToJson(result)); + Exit(1); + } + + // Initialise public properties to the defaults before we parse the actual inputs + CheckMode = false; + DebugMode = false; + DiffMode = false; + KeepRemoteFiles = false; + ModuleName = "undefined win module"; + NoLog = (bool)argumentSpec["no_log"]; + Verbosity = 0; + AppDomain.CurrentDomain.ProcessExit += CleanupFiles; + + List<string> legalInputs = passVars.Keys.Select(v => "_ansible_" + v).ToList(); + legalInputs.AddRange(((IDictionary)argumentSpec["options"]).Keys.Cast<string>().ToList()); + legalInputs.AddRange(aliases.Keys.Cast<string>().ToList()); + CheckArguments(argumentSpec, Params, legalInputs); + + // Set a Ansible friendly invocation value in the result object + Dictionary<string, object> invocation = new Dictionary<string, object>() { { "module_args", Params } }; + Result["invocation"] = RemoveNoLogValues(invocation, noLogValues); + + if (!NoLog) + LogEvent(String.Format("Invoked with:\r\n {0}", FormatLogData(Params, 2)), sanitise: false); + } + + public static AnsibleModule Create(string[] args, IDictionary argumentSpec, IDictionary[] fragments = null) + { + return new AnsibleModule(args, argumentSpec, fragments); + } + + public void Debug(string message) + { + if (DebugMode) + LogEvent(String.Format("[DEBUG] {0}", message)); + } + + public void Deprecate(string message, string version) + { + Deprecate(message, version, null); + } + + public void Deprecate(string message, string version, string collectionName) + { + deprecations.Add(new Dictionary<string, string>() { + { "msg", message }, { "version", version }, { "collection_name", collectionName } }); + LogEvent(String.Format("[DEPRECATION WARNING] {0} {1}", message, version)); + } + + public void Deprecate(string message, DateTime date) + { + Deprecate(message, date, null); + } + + public void Deprecate(string message, DateTime date, string collectionName) + { + string isoDate = date.ToString("yyyy-MM-dd"); + deprecations.Add(new Dictionary<string, string>() { + { "msg", message }, { "date", isoDate }, { "collection_name", collectionName } }); + LogEvent(String.Format("[DEPRECATION WARNING] {0} {1}", message, isoDate)); + } + + public void ExitJson() + { + WriteLine(GetFormattedResults(Result)); + CleanupFiles(null, null); + Exit(0); + } + + public void FailJson(string message) { FailJson(message, null, null); } + public void FailJson(string message, ErrorRecord psErrorRecord) { FailJson(message, psErrorRecord, null); } + public void FailJson(string message, Exception exception) { FailJson(message, null, exception); } + private void FailJson(string message, ErrorRecord psErrorRecord, Exception exception) + { + Result["failed"] = true; + Result["msg"] = RemoveNoLogValues(message, noLogValues); + + + if (!Result.ContainsKey("exception") && (Verbosity > 2 || DebugMode)) + { + if (psErrorRecord != null) + { + string traceback = String.Format("{0}\r\n{1}", psErrorRecord.ToString(), psErrorRecord.InvocationInfo.PositionMessage); + traceback += String.Format("\r\n + CategoryInfo : {0}", psErrorRecord.CategoryInfo.ToString()); + traceback += String.Format("\r\n + FullyQualifiedErrorId : {0}", psErrorRecord.FullyQualifiedErrorId.ToString()); + traceback += String.Format("\r\n\r\nScriptStackTrace:\r\n{0}", psErrorRecord.ScriptStackTrace); + Result["exception"] = traceback; + } + else if (exception != null) + Result["exception"] = exception.ToString(); + } + + WriteLine(GetFormattedResults(Result)); + CleanupFiles(null, null); + Exit(1); + } + + public void LogEvent(string message, EventLogEntryType logEntryType = EventLogEntryType.Information, bool sanitise = true) + { + if (NoLog) + return; + +#if WINDOWS + string logSource = "Ansible"; + bool logSourceExists = false; + try + { + logSourceExists = EventLog.SourceExists(logSource); + } + catch (System.Security.SecurityException) { } // non admin users may not have permission + + if (!logSourceExists) + { + try + { + EventLog.CreateEventSource(logSource, "Application"); + } + catch (System.Security.SecurityException) + { + // Cannot call Warn as that calls LogEvent and we get stuck in a loop + warnings.Add(String.Format("Access error when creating EventLog source {0}, logging to the Application source instead", logSource)); + logSource = "Application"; + } + } + if (sanitise) + message = (string)RemoveNoLogValues(message, noLogValues); + message = String.Format("{0} - {1}", ModuleName, message); + + using (EventLog eventLog = new EventLog("Application")) + { + eventLog.Source = logSource; + try + { + eventLog.WriteEntry(message, logEntryType, 0); + } + catch (System.InvalidOperationException) { } // Ignore permission errors on the Application event log + catch (System.Exception e) + { + // Cannot call Warn as that calls LogEvent and we get stuck in a loop + warnings.Add(String.Format("Unknown error when creating event log entry: {0}", e.Message)); + } + } +#else + // Windows Event Log is only available on Windows + return; +#endif + } + + public void Warn(string message) + { + warnings.Add(message); + LogEvent(String.Format("[WARNING] {0}", message), EventLogEntryType.Warning); + } + + public static object FromJson(string json) { return FromJson<object>(json); } + public static T FromJson<T>(string json) + { +#if CORECLR + return JsonConvert.DeserializeObject<T>(json); +#else + JavaScriptSerializer jss = new JavaScriptSerializer(); + jss.MaxJsonLength = int.MaxValue; + jss.RecursionLimit = int.MaxValue; + return jss.Deserialize<T>(json); +#endif + } + + public static string ToJson(object obj) + { + // Using PowerShell to serialize the JSON is preferable over the native .NET libraries as it handles + // PS Objects a lot better than the alternatives. In case we are debugging in Visual Studio we have a + // fallback to the other libraries as we won't be dealing with PowerShell objects there. + if (Runspace.DefaultRunspace != null) + { + PSObject rawOut = ScriptBlock.Create("ConvertTo-Json -InputObject $args[0] -Depth 99 -Compress").Invoke(obj)[0]; + return rawOut.BaseObject as string; + } + else + { +#if CORECLR + return JsonConvert.SerializeObject(obj); +#else + JavaScriptSerializer jss = new JavaScriptSerializer(); + jss.MaxJsonLength = int.MaxValue; + jss.RecursionLimit = int.MaxValue; + return jss.Serialize(obj); +#endif + } + } + + public static IDictionary GetParams(string[] args) + { + if (args.Length > 0) + { + string inputJson = File.ReadAllText(args[0]); + Dictionary<string, object> rawParams = FromJson<Dictionary<string, object>>(inputJson); + if (!rawParams.ContainsKey("ANSIBLE_MODULE_ARGS")) + throw new ArgumentException("Module was unable to get ANSIBLE_MODULE_ARGS value from the argument path json"); + return (IDictionary)rawParams["ANSIBLE_MODULE_ARGS"]; + } + else + { + // $complex_args is already a Hashtable, no need to waste time converting to a dictionary + PSObject rawArgs = ScriptBlock.Create("$complex_args").Invoke()[0]; + return rawArgs.BaseObject as Hashtable; + } + } + + public static bool ParseBool(object value) + { + if (value.GetType() == typeof(bool)) + return (bool)value; + + List<string> booleans = new List<string>(); + booleans.AddRange(BOOLEANS_TRUE); + booleans.AddRange(BOOLEANS_FALSE); + + string stringValue = ParseStr(value).ToLowerInvariant().Trim(); + if (BOOLEANS_TRUE.Contains(stringValue)) + return true; + else if (BOOLEANS_FALSE.Contains(stringValue)) + return false; + + string msg = String.Format("The value '{0}' is not a valid boolean. Valid booleans include: {1}", + stringValue, String.Join(", ", booleans)); + throw new ArgumentException(msg); + } + + public static Dictionary<string, object> ParseDict(object value) + { + Type valueType = value.GetType(); + if (valueType == typeof(Dictionary<string, object>)) + return (Dictionary<string, object>)value; + else if (value is IDictionary) + return ((IDictionary)value).Cast<DictionaryEntry>().ToDictionary(kvp => (string)kvp.Key, kvp => kvp.Value); + else if (valueType == typeof(string)) + { + string stringValue = (string)value; + if (stringValue.StartsWith("{") && stringValue.EndsWith("}")) + return FromJson<Dictionary<string, object>>((string)value); + else if (stringValue.IndexOfAny(new char[1] { '=' }) != -1) + { + List<string> fields = new List<string>(); + List<char> fieldBuffer = new List<char>(); + char? inQuote = null; + bool inEscape = false; + string field; + + foreach (char c in stringValue.ToCharArray()) + { + if (inEscape) + { + fieldBuffer.Add(c); + inEscape = false; + } + else if (c == '\\') + inEscape = true; + else if (inQuote == null && (c == '\'' || c == '"')) + inQuote = c; + else if (inQuote != null && c == inQuote) + inQuote = null; + else if (inQuote == null && (c == ',' || c == ' ')) + { + field = String.Join("", fieldBuffer); + if (field != "") + fields.Add(field); + fieldBuffer = new List<char>(); + } + else + fieldBuffer.Add(c); + } + + field = String.Join("", fieldBuffer); + if (field != "") + fields.Add(field); + + return fields.Distinct().Select(i => i.Split(new[] { '=' }, 2)).ToDictionary(i => i[0], i => i.Length > 1 ? (object)i[1] : null); + } + else + throw new ArgumentException("string cannot be converted to a dict, must either be a JSON string or in the key=value form"); + } + + throw new ArgumentException(String.Format("{0} cannot be converted to a dict", valueType.FullName)); + } + + public static float ParseFloat(object value) + { + if (value.GetType() == typeof(float)) + return (float)value; + + string valueStr = ParseStr(value); + return float.Parse(valueStr); + } + + public static int ParseInt(object value) + { + Type valueType = value.GetType(); + if (valueType == typeof(int)) + return (int)value; + else + return Int32.Parse(ParseStr(value)); + } + + public static string ParseJson(object value) + { + // mostly used to ensure a dict is a json string as it may + // have been converted on the controller side + Type valueType = value.GetType(); + if (value is IDictionary) + return ToJson(value); + else if (valueType == typeof(string)) + return (string)value; + else + throw new ArgumentException(String.Format("{0} cannot be converted to json", valueType.FullName)); + } + + public static List<object> ParseList(object value) + { + if (value == null) + return null; + + Type valueType = value.GetType(); + if (valueType.IsGenericType && valueType.GetGenericTypeDefinition() == typeof(List<>)) + return (List<object>)value; + else if (valueType == typeof(ArrayList)) + return ((ArrayList)value).Cast<object>().ToList(); + else if (valueType.IsArray) + return ((object[])value).ToList(); + else if (valueType == typeof(string)) + return ((string)value).Split(',').Select(s => s.Trim()).ToList<object>(); + else if (valueType == typeof(int)) + return new List<object>() { value }; + else + throw new ArgumentException(String.Format("{0} cannot be converted to a list", valueType.FullName)); + } + + public static string ParsePath(object value) + { + string stringValue = ParseStr(value); + + // do not validate, expand the env vars if it starts with \\?\ as + // it is a special path designed for the NT kernel to interpret + if (stringValue.StartsWith(@"\\?\")) + return stringValue; + + stringValue = Environment.ExpandEnvironmentVariables(stringValue); + if (stringValue.IndexOfAny(Path.GetInvalidPathChars()) != -1) + throw new ArgumentException("string value contains invalid path characters, cannot convert to path"); + + // will fire an exception if it contains any invalid chars + Path.GetFullPath(stringValue); + return stringValue; + } + + public static object ParseRaw(object value) { return value; } + + public static SecurityIdentifier ParseSid(object value) + { + string stringValue = ParseStr(value); + + try + { + return new SecurityIdentifier(stringValue); + } + catch (ArgumentException) { } // ignore failures string may not have been a SID + + NTAccount account = new NTAccount(stringValue); + return (SecurityIdentifier)account.Translate(typeof(SecurityIdentifier)); + } + + public static string ParseStr(object value) { return value.ToString(); } + + private void ValidateArgumentSpec(IDictionary argumentSpec) + { + Dictionary<string, object> changedValues = new Dictionary<string, object>(); + foreach (DictionaryEntry entry in argumentSpec) + { + string key = (string)entry.Key; + + // validate the key is a valid argument spec key + if (!specDefaults.ContainsKey(key)) + { + string msg = String.Format("argument spec entry contains an invalid key '{0}', valid keys: {1}", + key, String.Join(", ", specDefaults.Keys)); + throw new ArgumentException(FormatOptionsContext(msg, " - ")); + } + + // ensure the value is casted to the type we expect + Type optionType = null; + if (entry.Value != null) + optionType = (Type)specDefaults[key][1]; + if (optionType != null) + { + Type actualType = entry.Value.GetType(); + bool invalid = false; + if (optionType.IsGenericType && optionType.GetGenericTypeDefinition() == typeof(List<>)) + { + // verify the actual type is not just a single value of the list type + Type entryType = optionType.GetGenericArguments()[0]; + object[] arrayElementTypes = new object[] + { + null, // ArrayList does not have an ElementType + entryType, + typeof(object), // Hope the object is actually entryType or it can at least be casted. + }; + + bool isArray = entry.Value is IList && arrayElementTypes.Contains(actualType.GetElementType()); + if (actualType == entryType || isArray) + { + object rawArray; + if (isArray) + rawArray = entry.Value; + else + rawArray = new object[1] { entry.Value }; + + MethodInfo castMethod = typeof(Enumerable).GetMethod("Cast").MakeGenericMethod(entryType); + MethodInfo toListMethod = typeof(Enumerable).GetMethod("ToList").MakeGenericMethod(entryType); + + var enumerable = castMethod.Invoke(null, new object[1] { rawArray }); + var newList = toListMethod.Invoke(null, new object[1] { enumerable }); + changedValues.Add(key, newList); + } + else if (actualType != optionType && !(actualType == typeof(List<object>))) + invalid = true; + } + else + invalid = actualType != optionType; + + if (invalid) + { + string msg = String.Format("argument spec for '{0}' did not match expected type {1}: actual type {2}", + key, optionType.FullName, actualType.FullName); + throw new ArgumentException(FormatOptionsContext(msg, " - ")); + } + } + + // recursively validate the spec + if (key == "options" && entry.Value != null) + { + IDictionary optionsSpec = (IDictionary)entry.Value; + foreach (DictionaryEntry optionEntry in optionsSpec) + { + optionsContext.Add((string)optionEntry.Key); + IDictionary optionMeta = (IDictionary)optionEntry.Value; + ValidateArgumentSpec(optionMeta); + optionsContext.RemoveAt(optionsContext.Count - 1); + } + } + + // validate the type and elements key type values are known types + if (key == "type" || key == "elements" && entry.Value != null) + { + Type valueType = entry.Value.GetType(); + if (valueType == typeof(string)) + { + string typeValue = (string)entry.Value; + if (!optionTypes.ContainsKey(typeValue)) + { + string msg = String.Format("{0} '{1}' is unsupported", key, typeValue); + msg = String.Format("{0}. Valid types are: {1}", FormatOptionsContext(msg, " - "), String.Join(", ", optionTypes.Keys)); + throw new ArgumentException(msg); + } + } + else if (!(entry.Value is Delegate)) + { + string msg = String.Format("{0} must either be a string or delegate, was: {1}", key, valueType.FullName); + throw new ArgumentException(FormatOptionsContext(msg, " - ")); + } + } + } + + // Outside of the spec iterator, change the values that were casted above + foreach (KeyValuePair<string, object> changedValue in changedValues) + argumentSpec[changedValue.Key] = changedValue.Value; + } + + private void MergeFragmentSpec(IDictionary argumentSpec, IDictionary fragment) + { + foreach (DictionaryEntry fragmentEntry in fragment) + { + string fragmentKey = fragmentEntry.Key.ToString(); + + if (argumentSpec.Contains(fragmentKey)) + { + // We only want to add new list entries and merge dictionary new keys and values. Leave the other + // values as is in the argument spec as that takes priority over the fragment. + if (fragmentEntry.Value is IDictionary) + { + MergeFragmentSpec((IDictionary)argumentSpec[fragmentKey], (IDictionary)fragmentEntry.Value); + } + else if (fragmentEntry.Value is IList) + { + IList specValue = (IList)argumentSpec[fragmentKey]; + foreach (object fragmentValue in (IList)fragmentEntry.Value) + specValue.Add(fragmentValue); + } + } + else + argumentSpec[fragmentKey] = fragmentEntry.Value; + } + } + + private void SetArgumentSpecDefaults(IDictionary argumentSpec) + { + foreach (KeyValuePair<string, List<object>> metadataEntry in specDefaults) + { + List<object> defaults = metadataEntry.Value; + object defaultValue = defaults[0]; + if (defaultValue != null && defaultValue.GetType() == typeof(Type).GetType()) + defaultValue = Activator.CreateInstance((Type)defaultValue); + + if (!argumentSpec.Contains(metadataEntry.Key)) + argumentSpec[metadataEntry.Key] = defaultValue; + } + + // Recursively set the defaults for any inner options. + foreach (DictionaryEntry entry in argumentSpec) + { + if (entry.Value == null || entry.Key.ToString() != "options") + continue; + + IDictionary optionsSpec = (IDictionary)entry.Value; + foreach (DictionaryEntry optionEntry in optionsSpec) + { + optionsContext.Add((string)optionEntry.Key); + IDictionary optionMeta = (IDictionary)optionEntry.Value; + SetArgumentSpecDefaults(optionMeta); + optionsContext.RemoveAt(optionsContext.Count - 1); + } + } + } + + private Dictionary<string, string> GetAliases(IDictionary argumentSpec, IDictionary parameters) + { + Dictionary<string, string> aliasResults = new Dictionary<string, string>(); + + foreach (DictionaryEntry entry in (IDictionary)argumentSpec["options"]) + { + string k = (string)entry.Key; + Hashtable v = (Hashtable)entry.Value; + + List<string> aliases = (List<string>)v["aliases"]; + object defaultValue = v["default"]; + bool required = (bool)v["required"]; + + if (defaultValue != null && required) + throw new ArgumentException(String.Format("required and default are mutually exclusive for {0}", k)); + + foreach (string alias in aliases) + { + aliasResults.Add(alias, k); + if (parameters.Contains(alias)) + parameters[k] = parameters[alias]; + } + + List<Hashtable> deprecatedAliases = (List<Hashtable>)v["deprecated_aliases"]; + foreach (Hashtable depInfo in deprecatedAliases) + { + foreach (string keyName in new List<string> { "name" }) + { + if (!depInfo.ContainsKey(keyName)) + { + string msg = String.Format("{0} is required in a deprecated_aliases entry", keyName); + throw new ArgumentException(FormatOptionsContext(msg, " - ")); + } + } + if (!depInfo.ContainsKey("version") && !depInfo.ContainsKey("date")) + { + string msg = "One of version or date is required in a deprecated_aliases entry"; + throw new ArgumentException(FormatOptionsContext(msg, " - ")); + } + if (depInfo.ContainsKey("version") && depInfo.ContainsKey("date")) + { + string msg = "Only one of version or date is allowed in a deprecated_aliases entry"; + throw new ArgumentException(FormatOptionsContext(msg, " - ")); + } + if (depInfo.ContainsKey("date") && depInfo["date"].GetType() != typeof(DateTime)) + { + string msg = "A deprecated_aliases date must be a DateTime object"; + throw new ArgumentException(FormatOptionsContext(msg, " - ")); + } + string collectionName = null; + if (depInfo.ContainsKey("collection_name")) + { + collectionName = (string)depInfo["collection_name"]; + } + string aliasName = (string)depInfo["name"]; + + if (parameters.Contains(aliasName)) + { + string msg = String.Format("Alias '{0}' is deprecated. See the module docs for more information", aliasName); + if (depInfo.ContainsKey("version")) + { + string depVersion = (string)depInfo["version"]; + Deprecate(FormatOptionsContext(msg, " - "), depVersion, collectionName); + } + if (depInfo.ContainsKey("date")) + { + DateTime depDate = (DateTime)depInfo["date"]; + Deprecate(FormatOptionsContext(msg, " - "), depDate, collectionName); + } + } + } + } + + return aliasResults; + } + + private void SetNoLogValues(IDictionary argumentSpec, IDictionary parameters) + { + foreach (DictionaryEntry entry in (IDictionary)argumentSpec["options"]) + { + string k = (string)entry.Key; + Hashtable v = (Hashtable)entry.Value; + + if ((bool)v["no_log"]) + { + object noLogObject = parameters.Contains(k) ? parameters[k] : null; + string noLogString = noLogObject == null ? "" : noLogObject.ToString(); + if (!String.IsNullOrEmpty(noLogString)) + noLogValues.Add(noLogString); + } + string collectionName = null; + if (v.ContainsKey("removed_from_collection")) + { + collectionName = (string)v["removed_from_collection"]; + } + + object removedInVersion = v["removed_in_version"]; + if (removedInVersion != null && parameters.Contains(k)) + Deprecate(String.Format("Param '{0}' is deprecated. See the module docs for more information", k), + removedInVersion.ToString(), collectionName); + + object removedAtDate = v["removed_at_date"]; + if (removedAtDate != null && parameters.Contains(k)) + Deprecate(String.Format("Param '{0}' is deprecated. See the module docs for more information", k), + (DateTime)removedAtDate, collectionName); + } + } + + private void CheckArguments(IDictionary spec, IDictionary param, List<string> legalInputs) + { + // initially parse the params and check for unsupported ones and set internal vars + CheckUnsupportedArguments(param, legalInputs); + + // Only run this check if we are at the root argument (optionsContext.Count == 0) + if (CheckMode && !(bool)spec["supports_check_mode"] && optionsContext.Count == 0) + { + Result["skipped"] = true; + Result["msg"] = String.Format("remote module ({0}) does not support check mode", ModuleName); + ExitJson(); + } + IDictionary optionSpec = (IDictionary)spec["options"]; + + CheckMutuallyExclusive(param, (IList)spec["mutually_exclusive"]); + CheckRequiredArguments(optionSpec, param); + + // set the parameter types based on the type spec value + foreach (DictionaryEntry entry in optionSpec) + { + string k = (string)entry.Key; + Hashtable v = (Hashtable)entry.Value; + + object value = param.Contains(k) ? param[k] : null; + if (value != null) + { + // convert the current value to the wanted type + Delegate typeConverter; + string type; + if (v["type"].GetType() == typeof(string)) + { + type = (string)v["type"]; + typeConverter = optionTypes[type]; + } + else + { + type = "delegate"; + typeConverter = (Delegate)v["type"]; + } + + try + { + value = typeConverter.DynamicInvoke(value); + param[k] = value; + } + catch (Exception e) + { + string msg = String.Format("argument for {0} is of type {1} and we were unable to convert to {2}: {3}", + k, value.GetType(), type, e.InnerException.Message); + FailJson(FormatOptionsContext(msg)); + } + + // ensure it matches the choices if there are choices set + List<string> choices = ((List<object>)v["choices"]).Select(x => x.ToString()).Cast<string>().ToList(); + if (choices.Count > 0) + { + List<string> values; + string choiceMsg; + if (type == "list") + { + values = ((List<object>)value).Select(x => x.ToString()).Cast<string>().ToList(); + choiceMsg = "one or more of"; + } + else + { + values = new List<string>() { value.ToString() }; + choiceMsg = "one of"; + } + + List<string> diffList = values.Except(choices, StringComparer.OrdinalIgnoreCase).ToList(); + List<string> caseDiffList = values.Except(choices).ToList(); + if (diffList.Count > 0) + { + string msg = String.Format("value of {0} must be {1}: {2}. Got no match for: {3}", + k, choiceMsg, String.Join(", ", choices), String.Join(", ", diffList)); + FailJson(FormatOptionsContext(msg)); + } + /* + For now we will just silently accept case insensitive choices, uncomment this if we want to add it back in + else if (caseDiffList.Count > 0) + { + // For backwards compatibility with Legacy.psm1 we need to be matching choices that are not case sensitive. + // We will warn the user it was case insensitive and tell them this will become case sensitive in the future. + string msg = String.Format( + "value of {0} was a case insensitive match of {1}: {2}. Checking of choices will be case sensitive in a future Ansible release. Case insensitive matches were: {3}", + k, choiceMsg, String.Join(", ", choices), String.Join(", ", caseDiffList.Select(x => RemoveNoLogValues(x, noLogValues))) + ); + Warn(FormatOptionsContext(msg)); + }*/ + } + } + } + + CheckRequiredTogether(param, (IList)spec["required_together"]); + CheckRequiredOneOf(param, (IList)spec["required_one_of"]); + CheckRequiredIf(param, (IList)spec["required_if"]); + CheckRequiredBy(param, (IDictionary)spec["required_by"]); + + // finally ensure all missing parameters are set to null and handle sub options + foreach (DictionaryEntry entry in optionSpec) + { + string k = (string)entry.Key; + IDictionary v = (IDictionary)entry.Value; + + if (!param.Contains(k)) + param[k] = null; + + CheckSubOption(param, k, v); + } + } + + private void CheckUnsupportedArguments(IDictionary param, List<string> legalInputs) + { + HashSet<string> unsupportedParameters = new HashSet<string>(); + HashSet<string> caseUnsupportedParameters = new HashSet<string>(); + List<string> removedParameters = new List<string>(); + + foreach (DictionaryEntry entry in param) + { + string paramKey = (string)entry.Key; + if (!legalInputs.Contains(paramKey, StringComparer.OrdinalIgnoreCase)) + unsupportedParameters.Add(paramKey); + else if (!legalInputs.Contains(paramKey)) + // For backwards compatibility we do not care about the case but we need to warn the users as this will + // change in a future Ansible release. + caseUnsupportedParameters.Add(paramKey); + else if (paramKey.StartsWith("_ansible_")) + { + removedParameters.Add(paramKey); + string key = paramKey.Replace("_ansible_", ""); + // skip setting NoLog if NoLog is already set to true (set by the module) + // or there's no mapping for this key + if ((key == "no_log" && NoLog == true) || (passVars[key] == null)) + continue; + + object value = entry.Value; + if (passBools.Contains(key)) + value = ParseBool(value); + else if (passInts.Contains(key)) + value = ParseInt(value); + + string propertyName = passVars[key]; + PropertyInfo property = typeof(AnsibleModule).GetProperty(propertyName); + FieldInfo field = typeof(AnsibleModule).GetField(propertyName, BindingFlags.NonPublic | BindingFlags.Instance); + if (property != null) + property.SetValue(this, value, null); + else if (field != null) + field.SetValue(this, value); + else + FailJson(String.Format("implementation error: unknown AnsibleModule property {0}", propertyName)); + } + } + foreach (string parameter in removedParameters) + param.Remove(parameter); + + if (unsupportedParameters.Count > 0) + { + legalInputs.RemoveAll(x => passVars.Keys.Contains(x.Replace("_ansible_", ""))); + string msg = String.Format("Unsupported parameters for ({0}) module: {1}", ModuleName, String.Join(", ", unsupportedParameters)); + msg = String.Format("{0}. Supported parameters include: {1}", FormatOptionsContext(msg), String.Join(", ", legalInputs)); + FailJson(msg); + } + + /* + // Uncomment when we want to start warning users around options that are not a case sensitive match to the spec + if (caseUnsupportedParameters.Count > 0) + { + legalInputs.RemoveAll(x => passVars.Keys.Contains(x.Replace("_ansible_", ""))); + string msg = String.Format("Parameters for ({0}) was a case insensitive match: {1}", ModuleName, String.Join(", ", caseUnsupportedParameters)); + msg = String.Format("{0}. Module options will become case sensitive in a future Ansible release. Supported parameters include: {1}", + FormatOptionsContext(msg), String.Join(", ", legalInputs)); + Warn(msg); + }*/ + + // Make sure we convert all the incorrect case params to the ones set by the module spec + foreach (string key in caseUnsupportedParameters) + { + string correctKey = legalInputs[legalInputs.FindIndex(s => s.Equals(key, StringComparison.OrdinalIgnoreCase))]; + object value = param[key]; + param.Remove(key); + param.Add(correctKey, value); + } + } + + private void CheckMutuallyExclusive(IDictionary param, IList mutuallyExclusive) + { + if (mutuallyExclusive == null) + return; + + foreach (object check in mutuallyExclusive) + { + List<string> mutualCheck = ((IList)check).Cast<string>().ToList(); + int count = 0; + foreach (string entry in mutualCheck) + if (param.Contains(entry)) + count++; + + if (count > 1) + { + string msg = String.Format("parameters are mutually exclusive: {0}", String.Join(", ", mutualCheck)); + FailJson(FormatOptionsContext(msg)); + } + } + } + + private void CheckRequiredArguments(IDictionary spec, IDictionary param) + { + List<string> missing = new List<string>(); + foreach (DictionaryEntry entry in spec) + { + string k = (string)entry.Key; + Hashtable v = (Hashtable)entry.Value; + + // set defaults for values not already set + object defaultValue = v["default"]; + if (defaultValue != null && !param.Contains(k)) + param[k] = defaultValue; + + // check required arguments + bool required = (bool)v["required"]; + if (required && !param.Contains(k)) + missing.Add(k); + } + if (missing.Count > 0) + { + string msg = String.Format("missing required arguments: {0}", String.Join(", ", missing)); + FailJson(FormatOptionsContext(msg)); + } + } + + private void CheckRequiredTogether(IDictionary param, IList requiredTogether) + { + if (requiredTogether == null) + return; + + foreach (object check in requiredTogether) + { + List<string> requiredCheck = ((IList)check).Cast<string>().ToList(); + List<bool> found = new List<bool>(); + foreach (string field in requiredCheck) + if (param.Contains(field)) + found.Add(true); + else + found.Add(false); + + if (found.Contains(true) && found.Contains(false)) + { + string msg = String.Format("parameters are required together: {0}", String.Join(", ", requiredCheck)); + FailJson(FormatOptionsContext(msg)); + } + } + } + + private void CheckRequiredOneOf(IDictionary param, IList requiredOneOf) + { + if (requiredOneOf == null) + return; + + foreach (object check in requiredOneOf) + { + List<string> requiredCheck = ((IList)check).Cast<string>().ToList(); + int count = 0; + foreach (string field in requiredCheck) + if (param.Contains(field)) + count++; + + if (count == 0) + { + string msg = String.Format("one of the following is required: {0}", String.Join(", ", requiredCheck)); + FailJson(FormatOptionsContext(msg)); + } + } + } + + private void CheckRequiredIf(IDictionary param, IList requiredIf) + { + if (requiredIf == null) + return; + + foreach (object check in requiredIf) + { + IList requiredCheck = (IList)check; + List<string> missing = new List<string>(); + List<string> missingFields = new List<string>(); + int maxMissingCount = 1; + bool oneRequired = false; + + if (requiredCheck.Count < 3 && requiredCheck.Count < 4) + FailJson(String.Format("internal error: invalid required_if value count of {0}, expecting 3 or 4 entries", requiredCheck.Count)); + else if (requiredCheck.Count == 4) + oneRequired = (bool)requiredCheck[3]; + + string key = (string)requiredCheck[0]; + object val = requiredCheck[1]; + IList requirements = (IList)requiredCheck[2]; + + if (ParseStr(param[key]) != ParseStr(val)) + continue; + + string term = "all"; + if (oneRequired) + { + maxMissingCount = requirements.Count; + term = "any"; + } + + foreach (string required in requirements.Cast<string>()) + if (!param.Contains(required)) + missing.Add(required); + + if (missing.Count >= maxMissingCount) + { + string msg = String.Format("{0} is {1} but {2} of the following are missing: {3}", + key, val.ToString(), term, String.Join(", ", missing)); + FailJson(FormatOptionsContext(msg)); + } + } + } + + private void CheckRequiredBy(IDictionary param, IDictionary requiredBy) + { + foreach (DictionaryEntry entry in requiredBy) + { + string key = (string)entry.Key; + if (!param.Contains(key)) + continue; + + List<string> missing = new List<string>(); + List<string> requires = ParseList(entry.Value).Cast<string>().ToList(); + foreach (string required in requires) + if (!param.Contains(required)) + missing.Add(required); + + if (missing.Count > 0) + { + string msg = String.Format("missing parameter(s) required by '{0}': {1}", key, String.Join(", ", missing)); + FailJson(FormatOptionsContext(msg)); + } + } + } + + private void CheckSubOption(IDictionary param, string key, IDictionary spec) + { + object value = param[key]; + + string type; + if (spec["type"].GetType() == typeof(string)) + type = (string)spec["type"]; + else + type = "delegate"; + + string elements = null; + Delegate typeConverter = null; + if (spec["elements"] != null && spec["elements"].GetType() == typeof(string)) + { + elements = (string)spec["elements"]; + typeConverter = optionTypes[elements]; + } + else if (spec["elements"] != null) + { + elements = "delegate"; + typeConverter = (Delegate)spec["elements"]; + } + + if (!(type == "dict" || (type == "list" && elements != null))) + // either not a dict, or list with the elements set, so continue + return; + else if (type == "list") + { + // cast each list element to the type specified + if (value == null) + return; + + List<object> newValue = new List<object>(); + foreach (object element in (List<object>)value) + { + if (elements == "dict") + newValue.Add(ParseSubSpec(spec, element, key)); + else + { + try + { + object newElement = typeConverter.DynamicInvoke(element); + newValue.Add(newElement); + } + catch (Exception e) + { + string msg = String.Format("argument for list entry {0} is of type {1} and we were unable to convert to {2}: {3}", + key, element.GetType(), elements, e.Message); + FailJson(FormatOptionsContext(msg)); + } + } + } + + param[key] = newValue; + } + else + param[key] = ParseSubSpec(spec, value, key); + } + + private object ParseSubSpec(IDictionary spec, object value, string context) + { + bool applyDefaults = (bool)spec["apply_defaults"]; + + // set entry to an empty dict if apply_defaults is set + IDictionary optionsSpec = (IDictionary)spec["options"]; + if (applyDefaults && optionsSpec.Keys.Count > 0 && value == null) + value = new Dictionary<string, object>(); + else if (optionsSpec.Keys.Count == 0 || value == null) + return value; + + optionsContext.Add(context); + Dictionary<string, object> newValue = (Dictionary<string, object>)ParseDict(value); + Dictionary<string, string> aliases = GetAliases(spec, newValue); + SetNoLogValues(spec, newValue); + + List<string> subLegalInputs = optionsSpec.Keys.Cast<string>().ToList(); + subLegalInputs.AddRange(aliases.Keys.Cast<string>().ToList()); + + CheckArguments(spec, newValue, subLegalInputs); + optionsContext.RemoveAt(optionsContext.Count - 1); + return newValue; + } + + private string GetFormattedResults(Dictionary<string, object> result) + { + if (!result.ContainsKey("invocation")) + result["invocation"] = new Dictionary<string, object>() { { "module_args", RemoveNoLogValues(Params, noLogValues) } }; + + if (warnings.Count > 0) + result["warnings"] = warnings; + + if (deprecations.Count > 0) + result["deprecations"] = deprecations; + + if (Diff.Count > 0 && DiffMode) + result["diff"] = Diff; + + return ToJson(result); + } + + private string FormatLogData(object data, int indentLevel) + { + if (data == null) + return "$null"; + + string msg = ""; + if (data is IList) + { + string newMsg = ""; + foreach (object value in (IList)data) + { + string entryValue = FormatLogData(value, indentLevel + 2); + newMsg += String.Format("\r\n{0}- {1}", new String(' ', indentLevel), entryValue); + } + msg += newMsg; + } + else if (data is IDictionary) + { + bool start = true; + foreach (DictionaryEntry entry in (IDictionary)data) + { + string newMsg = FormatLogData(entry.Value, indentLevel + 2); + if (!start) + msg += String.Format("\r\n{0}", new String(' ', indentLevel)); + msg += String.Format("{0}: {1}", (string)entry.Key, newMsg); + start = false; + } + } + else + msg = (string)RemoveNoLogValues(ParseStr(data), noLogValues); + + return msg; + } + + private object RemoveNoLogValues(object value, HashSet<string> noLogStrings) + { + Queue<Tuple<object, object>> deferredRemovals = new Queue<Tuple<object, object>>(); + object newValue = RemoveValueConditions(value, noLogStrings, deferredRemovals); + + while (deferredRemovals.Count > 0) + { + Tuple<object, object> data = deferredRemovals.Dequeue(); + object oldData = data.Item1; + object newData = data.Item2; + + if (oldData is IDictionary) + { + foreach (DictionaryEntry entry in (IDictionary)oldData) + { + object newElement = RemoveValueConditions(entry.Value, noLogStrings, deferredRemovals); + ((IDictionary)newData).Add((string)entry.Key, newElement); + } + } + else + { + foreach (object element in (IList)oldData) + { + object newElement = RemoveValueConditions(element, noLogStrings, deferredRemovals); + ((IList)newData).Add(newElement); + } + } + } + + return newValue; + } + + private object RemoveValueConditions(object value, HashSet<string> noLogStrings, Queue<Tuple<object, object>> deferredRemovals) + { + if (value == null) + return value; + + Type valueType = value.GetType(); + HashSet<Type> numericTypes = new HashSet<Type> + { + typeof(byte), typeof(sbyte), typeof(short), typeof(ushort), typeof(int), typeof(uint), + typeof(long), typeof(ulong), typeof(decimal), typeof(double), typeof(float) + }; + + if (numericTypes.Contains(valueType) || valueType == typeof(bool)) + { + string valueString = ParseStr(value); + if (noLogStrings.Contains(valueString)) + return "VALUE_SPECIFIED_IN_NO_LOG_PARAMETER"; + foreach (string omitMe in noLogStrings) + if (valueString.Contains(omitMe)) + return "VALUE_SPECIFIED_IN_NO_LOG_PARAMETER"; + } + else if (valueType == typeof(DateTime)) + value = ((DateTime)value).ToString("o"); + else if (value is IList) + { + List<object> newValue = new List<object>(); + deferredRemovals.Enqueue(new Tuple<object, object>((IList)value, newValue)); + value = newValue; + } + else if (value is IDictionary) + { + Hashtable newValue = new Hashtable(); + deferredRemovals.Enqueue(new Tuple<object, object>((IDictionary)value, newValue)); + value = newValue; + } + else + { + string stringValue = value.ToString(); + if (noLogStrings.Contains(stringValue)) + return "VALUE_SPECIFIED_IN_NO_LOG_PARAMETER"; + foreach (string omitMe in noLogStrings) + if (stringValue.Contains(omitMe)) + return (stringValue).Replace(omitMe, "********"); + value = stringValue; + } + return value; + } + + private void CleanupFiles(object s, EventArgs ev) + { + foreach (string path in cleanupFiles) + { + if (File.Exists(path)) + File.Delete(path); + else if (Directory.Exists(path)) + Directory.Delete(path, true); + } + cleanupFiles = new List<string>(); + } + + private string FormatOptionsContext(string msg, string prefix = " ") + { + if (optionsContext.Count > 0) + msg += String.Format("{0}found in {1}", prefix, String.Join(" -> ", optionsContext)); + return msg; + } + + [DllImport("kernel32.dll")] + private static extern IntPtr GetConsoleWindow(); + + private static void ExitModule(int rc) + { + // When running in a Runspace Environment.Exit will kill the entire + // process which is not what we want, detect if we are in a + // Runspace and call a ScriptBlock with exit instead. + if (Runspace.DefaultRunspace != null) + ScriptBlock.Create("Set-Variable -Name LASTEXITCODE -Value $args[0] -Scope Global; exit $args[0]").Invoke(rc); + else + { + // Used for local debugging in Visual Studio + if (System.Diagnostics.Debugger.IsAttached) + { + Console.WriteLine("Press enter to continue..."); + Console.ReadLine(); + } + Environment.Exit(rc); + } + } + + private static void WriteLineModule(string line) + { + Console.WriteLine(line); + } + } +} diff --git a/lib/ansible/module_utils/csharp/Ansible.Become.cs b/lib/ansible/module_utils/csharp/Ansible.Become.cs new file mode 100644 index 0000000..a6f645c --- /dev/null +++ b/lib/ansible/module_utils/csharp/Ansible.Become.cs @@ -0,0 +1,655 @@ +using Microsoft.Win32.SafeHandles; +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Runtime.ConstrainedExecution; +using System.Runtime.InteropServices; +using System.Security.AccessControl; +using System.Security.Principal; +using System.Text; +using Ansible.AccessToken; +using Ansible.Process; + +namespace Ansible.Become +{ + internal class NativeHelpers + { + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + public struct KERB_S4U_LOGON + { + public UInt32 MessageType; + public UInt32 Flags; + public LSA_UNICODE_STRING ClientUpn; + public LSA_UNICODE_STRING ClientRealm; + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Ansi)] + public struct LSA_STRING + { + public UInt16 Length; + public UInt16 MaximumLength; + [MarshalAs(UnmanagedType.LPStr)] public string Buffer; + + public static implicit operator string(LSA_STRING s) + { + return s.Buffer; + } + + public static implicit operator LSA_STRING(string s) + { + if (s == null) + s = ""; + + LSA_STRING lsaStr = new LSA_STRING + { + Buffer = s, + Length = (UInt16)s.Length, + MaximumLength = (UInt16)(s.Length + 1), + }; + return lsaStr; + } + } + + [StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)] + public struct LSA_UNICODE_STRING + { + public UInt16 Length; + public UInt16 MaximumLength; + public IntPtr Buffer; + } + + [StructLayout(LayoutKind.Sequential)] + public struct SECURITY_LOGON_SESSION_DATA + { + public UInt32 Size; + public Luid LogonId; + public LSA_UNICODE_STRING UserName; + public LSA_UNICODE_STRING LogonDomain; + public LSA_UNICODE_STRING AuthenticationPackage; + public SECURITY_LOGON_TYPE LogonType; + } + + [StructLayout(LayoutKind.Sequential)] + public struct TOKEN_SOURCE + { + [MarshalAs(UnmanagedType.ByValArray, SizeConst = 8)] public char[] SourceName; + public Luid SourceIdentifier; + } + + public enum SECURITY_LOGON_TYPE + { + System = 0, // Used only by the System account + Interactive = 2, + Network, + Batch, + Service, + Proxy, + Unlock, + NetworkCleartext, + NewCredentials, + RemoteInteractive, + CachedInteractive, + CachedRemoteInteractive, + CachedUnlock + } + } + + internal class NativeMethods + { + [DllImport("advapi32.dll", SetLastError = true)] + public static extern bool AllocateLocallyUniqueId( + out Luid Luid); + + [DllImport("advapi32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + public static extern bool CreateProcessWithTokenW( + SafeNativeHandle hToken, + LogonFlags dwLogonFlags, + [MarshalAs(UnmanagedType.LPWStr)] string lpApplicationName, + StringBuilder lpCommandLine, + Process.NativeHelpers.ProcessCreationFlags dwCreationFlags, + Process.SafeMemoryBuffer lpEnvironment, + [MarshalAs(UnmanagedType.LPWStr)] string lpCurrentDirectory, + Process.NativeHelpers.STARTUPINFOEX lpStartupInfo, + out Process.NativeHelpers.PROCESS_INFORMATION lpProcessInformation); + + [DllImport("kernel32.dll")] + public static extern UInt32 GetCurrentThreadId(); + + [DllImport("user32.dll", SetLastError = true)] + public static extern NoopSafeHandle GetProcessWindowStation(); + + [DllImport("user32.dll", SetLastError = true)] + public static extern NoopSafeHandle GetThreadDesktop( + UInt32 dwThreadId); + + [DllImport("secur32.dll", SetLastError = true)] + public static extern UInt32 LsaDeregisterLogonProcess( + IntPtr LsaHandle); + + [DllImport("secur32.dll", SetLastError = true)] + public static extern UInt32 LsaFreeReturnBuffer( + IntPtr Buffer); + + [DllImport("secur32.dll", SetLastError = true)] + public static extern UInt32 LsaGetLogonSessionData( + ref Luid LogonId, + out SafeLsaMemoryBuffer ppLogonSessionData); + + [DllImport("secur32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + public static extern UInt32 LsaLogonUser( + SafeLsaHandle LsaHandle, + NativeHelpers.LSA_STRING OriginName, + LogonType LogonType, + UInt32 AuthenticationPackage, + IntPtr AuthenticationInformation, + UInt32 AuthenticationInformationLength, + IntPtr LocalGroups, + NativeHelpers.TOKEN_SOURCE SourceContext, + out SafeLsaMemoryBuffer ProfileBuffer, + out UInt32 ProfileBufferLength, + out Luid LogonId, + out SafeNativeHandle Token, + out IntPtr Quotas, + out UInt32 SubStatus); + + [DllImport("secur32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + public static extern UInt32 LsaLookupAuthenticationPackage( + SafeLsaHandle LsaHandle, + NativeHelpers.LSA_STRING PackageName, + out UInt32 AuthenticationPackage); + + [DllImport("advapi32.dll")] + public static extern UInt32 LsaNtStatusToWinError( + UInt32 Status); + + [DllImport("secur32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + public static extern UInt32 LsaRegisterLogonProcess( + NativeHelpers.LSA_STRING LogonProcessName, + out SafeLsaHandle LsaHandle, + out IntPtr SecurityMode); + } + + internal class SafeLsaHandle : SafeHandleZeroOrMinusOneIsInvalid + { + public SafeLsaHandle() : base(true) { } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + protected override bool ReleaseHandle() + { + UInt32 res = NativeMethods.LsaDeregisterLogonProcess(handle); + return res == 0; + } + } + + internal class SafeLsaMemoryBuffer : SafeHandleZeroOrMinusOneIsInvalid + { + public SafeLsaMemoryBuffer() : base(true) { } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + protected override bool ReleaseHandle() + { + UInt32 res = NativeMethods.LsaFreeReturnBuffer(handle); + return res == 0; + } + } + + internal class NoopSafeHandle : SafeHandle + { + public NoopSafeHandle() : base(IntPtr.Zero, false) { } + public override bool IsInvalid { get { return false; } } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + protected override bool ReleaseHandle() { return true; } + } + + [Flags] + public enum LogonFlags + { + WithProfile = 0x00000001, + NetcredentialsOnly = 0x00000002 + } + + public class BecomeUtil + { + private static List<string> SERVICE_SIDS = new List<string>() + { + "S-1-5-18", // NT AUTHORITY\SYSTEM + "S-1-5-19", // NT AUTHORITY\LocalService + "S-1-5-20" // NT AUTHORITY\NetworkService + }; + private static int WINDOWS_STATION_ALL_ACCESS = 0x000F037F; + private static int DESKTOP_RIGHTS_ALL_ACCESS = 0x000F01FF; + + public static Result CreateProcessAsUser(string username, string password, string command) + { + return CreateProcessAsUser(username, password, LogonFlags.WithProfile, LogonType.Interactive, + null, command, null, null, ""); + } + + public static Result CreateProcessAsUser(string username, string password, LogonFlags logonFlags, LogonType logonType, + string lpApplicationName, string lpCommandLine, string lpCurrentDirectory, IDictionary environment, + string stdin) + { + byte[] stdinBytes; + if (String.IsNullOrEmpty(stdin)) + stdinBytes = new byte[0]; + else + { + if (!stdin.EndsWith(Environment.NewLine)) + stdin += Environment.NewLine; + stdinBytes = new UTF8Encoding(false).GetBytes(stdin); + } + return CreateProcessAsUser(username, password, logonFlags, logonType, lpApplicationName, lpCommandLine, + lpCurrentDirectory, environment, stdinBytes); + } + + /// <summary> + /// Creates a process as another user account. This method will attempt to run as another user with the + /// highest possible permissions available. The main privilege required is the SeDebugPrivilege, without + /// this privilege you can only run as a local or domain user if the username and password is specified. + /// </summary> + /// <param name="username">The username of the runas user</param> + /// <param name="password">The password of the runas user</param> + /// <param name="logonFlags">LogonFlags to control how to logon a user when the password is specified</param> + /// <param name="logonType">Controls what type of logon is used, this only applies when the password is specified</param> + /// <param name="lpApplicationName">The name of the executable or batch file to executable</param> + /// <param name="lpCommandLine">The command line to execute, typically this includes lpApplication as the first argument</param> + /// <param name="lpCurrentDirectory">The full path to the current directory for the process, null will have the same cwd as the calling process</param> + /// <param name="environment">A dictionary of key/value pairs to define the new process environment</param> + /// <param name="stdin">Bytes sent to the stdin pipe</param> + /// <returns>Ansible.Process.Result object that contains the command output and return code</returns> + public static Result CreateProcessAsUser(string username, string password, LogonFlags logonFlags, LogonType logonType, + string lpApplicationName, string lpCommandLine, string lpCurrentDirectory, IDictionary environment, byte[] stdin) + { + // While we use STARTUPINFOEX having EXTENDED_STARTUPINFO_PRESENT causes a parameter validation error + Process.NativeHelpers.ProcessCreationFlags creationFlags = Process.NativeHelpers.ProcessCreationFlags.CREATE_UNICODE_ENVIRONMENT; + Process.NativeHelpers.PROCESS_INFORMATION pi = new Process.NativeHelpers.PROCESS_INFORMATION(); + Process.NativeHelpers.STARTUPINFOEX si = new Process.NativeHelpers.STARTUPINFOEX(); + si.startupInfo.dwFlags = Process.NativeHelpers.StartupInfoFlags.USESTDHANDLES; + + SafeFileHandle stdoutRead, stdoutWrite, stderrRead, stderrWrite, stdinRead, stdinWrite; + ProcessUtil.CreateStdioPipes(si, out stdoutRead, out stdoutWrite, out stderrRead, out stderrWrite, + out stdinRead, out stdinWrite); + FileStream stdinStream = new FileStream(stdinWrite, FileAccess.Write); + + // $null from PowerShell ends up as an empty string, we need to convert back as an empty string doesn't + // make sense for these parameters + if (lpApplicationName == "") + lpApplicationName = null; + + if (lpCurrentDirectory == "") + lpCurrentDirectory = null; + + // A user may have 2 tokens, 1 limited and 1 elevated. GetUserTokens will return both token to ensure + // we don't close one of the pairs while the process is still running. If the process tries to retrieve + // one of the pairs and the token handle is closed then it will fail with ERROR_NO_SUCH_LOGON_SESSION. + List<SafeNativeHandle> userTokens = GetUserTokens(username, password, logonType); + try + { + using (Process.SafeMemoryBuffer lpEnvironment = ProcessUtil.CreateEnvironmentPointer(environment)) + { + bool launchSuccess = false; + StringBuilder commandLine = new StringBuilder(lpCommandLine); + foreach (SafeNativeHandle token in userTokens) + { + // GetUserTokens could return null if an elevated token could not be retrieved. + if (token == null) + continue; + + if (NativeMethods.CreateProcessWithTokenW(token, logonFlags, lpApplicationName, + commandLine, creationFlags, lpEnvironment, lpCurrentDirectory, si, out pi)) + { + launchSuccess = true; + break; + } + } + + if (!launchSuccess) + throw new Process.Win32Exception("CreateProcessWithTokenW() failed"); + } + return ProcessUtil.WaitProcess(stdoutRead, stdoutWrite, stderrRead, stderrWrite, stdinStream, stdin, + pi.hProcess); + } + finally + { + userTokens.Where(t => t != null).ToList().ForEach(t => t.Dispose()); + } + } + + private static List<SafeNativeHandle> GetUserTokens(string username, string password, LogonType logonType) + { + List<SafeNativeHandle> userTokens = new List<SafeNativeHandle>(); + + SafeNativeHandle systemToken = null; + bool impersonated = false; + string becomeSid = username; + if (logonType != LogonType.NewCredentials) + { + // If prefixed with .\, we are becoming a local account, strip the prefix + if (username.StartsWith(".\\")) + username = username.Substring(2); + + NTAccount account = new NTAccount(username); + becomeSid = ((SecurityIdentifier)account.Translate(typeof(SecurityIdentifier))).Value; + + // Grant access to the current Windows Station and Desktop to the become user + GrantAccessToWindowStationAndDesktop(account); + + // Try and impersonate a SYSTEM token, we need a SYSTEM token to either become a well known service + // account or have administrative rights on the become access token. + // If we ultimately are becoming the SYSTEM account we want the token with the most privileges available. + // https://github.com/ansible/ansible/issues/71453 + bool mostPrivileges = becomeSid == "S-1-5-18"; + systemToken = GetPrimaryTokenForUser(new SecurityIdentifier("S-1-5-18"), + new List<string>() { "SeTcbPrivilege" }, mostPrivileges); + if (systemToken != null) + { + try + { + TokenUtil.ImpersonateToken(systemToken); + impersonated = true; + } + catch (Process.Win32Exception) { } // We tried, just rely on current user's permissions. + } + } + + // We require impersonation if becoming a service sid or becoming a user without a password + if (!impersonated && (SERVICE_SIDS.Contains(becomeSid) || String.IsNullOrEmpty(password))) + throw new Exception("Failed to get token for NT AUTHORITY\\SYSTEM required for become as a service account or an account without a password"); + + try + { + if (becomeSid == "S-1-5-18") + userTokens.Add(systemToken); + // Cannot use String.IsEmptyOrNull() as an empty string is an account that doesn't have a pass. + // We only use S4U if no password was defined or it was null + else if (!SERVICE_SIDS.Contains(becomeSid) && password == null && logonType != LogonType.NewCredentials) + { + // If no password was specified, try and duplicate an existing token for that user or use S4U to + // generate one without network credentials + SecurityIdentifier sid = new SecurityIdentifier(becomeSid); + SafeNativeHandle becomeToken = GetPrimaryTokenForUser(sid); + if (becomeToken != null) + { + userTokens.Add(GetElevatedToken(becomeToken)); + userTokens.Add(becomeToken); + } + else + { + becomeToken = GetS4UTokenForUser(sid, logonType); + userTokens.Add(null); + userTokens.Add(becomeToken); + } + } + else + { + string domain = null; + switch (becomeSid) + { + case "S-1-5-19": + logonType = LogonType.Service; + domain = "NT AUTHORITY"; + username = "LocalService"; + break; + case "S-1-5-20": + logonType = LogonType.Service; + domain = "NT AUTHORITY"; + username = "NetworkService"; + break; + default: + // Trying to become a local or domain account + if (username.Contains(@"\")) + { + string[] userSplit = username.Split(new char[1] { '\\' }, 2); + domain = userSplit[0]; + username = userSplit[1]; + } + else if (!username.Contains("@")) + domain = "."; + break; + } + + SafeNativeHandle hToken = TokenUtil.LogonUser(username, domain, password, logonType, + LogonProvider.Default); + + // Get the elevated token for a local/domain accounts only + if (!SERVICE_SIDS.Contains(becomeSid)) + userTokens.Add(GetElevatedToken(hToken)); + userTokens.Add(hToken); + } + } + finally + { + if (impersonated) + TokenUtil.RevertToSelf(); + } + + return userTokens; + } + + private static SafeNativeHandle GetPrimaryTokenForUser(SecurityIdentifier sid, + List<string> requiredPrivileges = null, bool mostPrivileges = false) + { + // According to CreateProcessWithTokenW we require a token with + // TOKEN_QUERY, TOKEN_DUPLICATE and TOKEN_ASSIGN_PRIMARY + // Also add in TOKEN_IMPERSONATE so we can get an impersonated token + TokenAccessLevels dwAccess = TokenAccessLevels.Query | + TokenAccessLevels.Duplicate | + TokenAccessLevels.AssignPrimary | + TokenAccessLevels.Impersonate; + + SafeNativeHandle userToken = null; + int privilegeCount = 0; + + foreach (SafeNativeHandle hToken in TokenUtil.EnumerateUserTokens(sid, dwAccess)) + { + // Filter out any Network logon tokens, using become with that is useless when S4U + // can give us a Batch logon + NativeHelpers.SECURITY_LOGON_TYPE tokenLogonType = GetTokenLogonType(hToken); + if (tokenLogonType == NativeHelpers.SECURITY_LOGON_TYPE.Network) + continue; + + List<string> actualPrivileges = TokenUtil.GetTokenPrivileges(hToken).Select(x => x.Name).ToList(); + + // If the token has less or the same number of privileges than the current token, skip it. + if (mostPrivileges && privilegeCount >= actualPrivileges.Count) + continue; + + // Check that the required privileges are on the token + if (requiredPrivileges != null) + { + int missing = requiredPrivileges.Where(x => !actualPrivileges.Contains(x)).Count(); + if (missing > 0) + continue; + } + + // Duplicate the token to convert it to a primary token with the access level required. + try + { + userToken = TokenUtil.DuplicateToken(hToken, TokenAccessLevels.MaximumAllowed, + SecurityImpersonationLevel.Anonymous, TokenType.Primary); + privilegeCount = actualPrivileges.Count; + } + catch (Process.Win32Exception) + { + continue; + } + + // If we don't care about getting the token with the most privileges, escape the loop as we already + // have a token. + if (!mostPrivileges) + break; + } + + return userToken; + } + + private static SafeNativeHandle GetS4UTokenForUser(SecurityIdentifier sid, LogonType logonType) + { + NTAccount becomeAccount = (NTAccount)sid.Translate(typeof(NTAccount)); + string[] userSplit = becomeAccount.Value.Split(new char[1] { '\\' }, 2); + string domainName = userSplit[0]; + string username = userSplit[1]; + bool domainUser = domainName.ToLowerInvariant() != Environment.MachineName.ToLowerInvariant(); + + NativeHelpers.LSA_STRING logonProcessName = "ansible"; + SafeLsaHandle lsaHandle; + IntPtr securityMode; + UInt32 res = NativeMethods.LsaRegisterLogonProcess(logonProcessName, out lsaHandle, out securityMode); + if (res != 0) + throw new Process.Win32Exception((int)NativeMethods.LsaNtStatusToWinError(res), "LsaRegisterLogonProcess() failed"); + + using (lsaHandle) + { + NativeHelpers.LSA_STRING packageName = domainUser ? "Kerberos" : "MICROSOFT_AUTHENTICATION_PACKAGE_V1_0"; + UInt32 authPackage; + res = NativeMethods.LsaLookupAuthenticationPackage(lsaHandle, packageName, out authPackage); + if (res != 0) + throw new Process.Win32Exception((int)NativeMethods.LsaNtStatusToWinError(res), + String.Format("LsaLookupAuthenticationPackage({0}) failed", (string)packageName)); + + int usernameLength = username.Length * sizeof(char); + int domainLength = domainName.Length * sizeof(char); + int authInfoLength = (Marshal.SizeOf(typeof(NativeHelpers.KERB_S4U_LOGON)) + usernameLength + domainLength); + IntPtr authInfo = Marshal.AllocHGlobal((int)authInfoLength); + try + { + IntPtr usernamePtr = IntPtr.Add(authInfo, Marshal.SizeOf(typeof(NativeHelpers.KERB_S4U_LOGON))); + IntPtr domainPtr = IntPtr.Add(usernamePtr, usernameLength); + + // KERB_S4U_LOGON has the same structure as MSV1_0_S4U_LOGON (local accounts) + NativeHelpers.KERB_S4U_LOGON s4uLogon = new NativeHelpers.KERB_S4U_LOGON + { + MessageType = 12, // KerbS4ULogon + Flags = 0, + ClientUpn = new NativeHelpers.LSA_UNICODE_STRING + { + Length = (UInt16)usernameLength, + MaximumLength = (UInt16)usernameLength, + Buffer = usernamePtr, + }, + ClientRealm = new NativeHelpers.LSA_UNICODE_STRING + { + Length = (UInt16)domainLength, + MaximumLength = (UInt16)domainLength, + Buffer = domainPtr, + }, + }; + Marshal.StructureToPtr(s4uLogon, authInfo, false); + Marshal.Copy(username.ToCharArray(), 0, usernamePtr, username.Length); + Marshal.Copy(domainName.ToCharArray(), 0, domainPtr, domainName.Length); + + Luid sourceLuid; + if (!NativeMethods.AllocateLocallyUniqueId(out sourceLuid)) + throw new Process.Win32Exception("AllocateLocallyUniqueId() failed"); + + NativeHelpers.TOKEN_SOURCE tokenSource = new NativeHelpers.TOKEN_SOURCE + { + SourceName = "ansible\0".ToCharArray(), + SourceIdentifier = sourceLuid, + }; + + // Only Batch or Network will work with S4U, prefer Batch but use Network if asked + LogonType lsaLogonType = logonType == LogonType.Network + ? LogonType.Network + : LogonType.Batch; + SafeLsaMemoryBuffer profileBuffer; + UInt32 profileBufferLength; + Luid logonId; + SafeNativeHandle hToken; + IntPtr quotas; + UInt32 subStatus; + + res = NativeMethods.LsaLogonUser(lsaHandle, logonProcessName, lsaLogonType, authPackage, + authInfo, (UInt32)authInfoLength, IntPtr.Zero, tokenSource, out profileBuffer, out profileBufferLength, + out logonId, out hToken, out quotas, out subStatus); + if (res != 0) + throw new Process.Win32Exception((int)NativeMethods.LsaNtStatusToWinError(res), + String.Format("LsaLogonUser() failed with substatus {0}", subStatus)); + + profileBuffer.Dispose(); + return hToken; + } + finally + { + Marshal.FreeHGlobal(authInfo); + } + } + } + + private static SafeNativeHandle GetElevatedToken(SafeNativeHandle hToken) + { + TokenElevationType tet = TokenUtil.GetTokenElevationType(hToken); + // We already have the best token we can get, no linked token is really available. + if (tet != TokenElevationType.Limited) + return null; + + SafeNativeHandle linkedToken = TokenUtil.GetTokenLinkedToken(hToken); + TokenStatistics tokenStats = TokenUtil.GetTokenStatistics(linkedToken); + + // We can only use a token if it's a primary one (we had the SeTcbPrivilege set) + if (tokenStats.TokenType == TokenType.Primary) + return linkedToken; + else + return null; + } + + private static NativeHelpers.SECURITY_LOGON_TYPE GetTokenLogonType(SafeNativeHandle hToken) + { + TokenStatistics stats = TokenUtil.GetTokenStatistics(hToken); + + SafeLsaMemoryBuffer sessionDataPtr; + UInt32 res = NativeMethods.LsaGetLogonSessionData(ref stats.AuthenticationId, out sessionDataPtr); + if (res != 0) + // Default to Network, if we weren't able to get the actual type treat it as an error and assume + // we don't want to run a process with the token + return NativeHelpers.SECURITY_LOGON_TYPE.Network; + + using (sessionDataPtr) + { + NativeHelpers.SECURITY_LOGON_SESSION_DATA sessionData = (NativeHelpers.SECURITY_LOGON_SESSION_DATA)Marshal.PtrToStructure( + sessionDataPtr.DangerousGetHandle(), typeof(NativeHelpers.SECURITY_LOGON_SESSION_DATA)); + return sessionData.LogonType; + } + } + + private static void GrantAccessToWindowStationAndDesktop(IdentityReference account) + { + GrantAccess(account, NativeMethods.GetProcessWindowStation(), WINDOWS_STATION_ALL_ACCESS); + GrantAccess(account, NativeMethods.GetThreadDesktop(NativeMethods.GetCurrentThreadId()), DESKTOP_RIGHTS_ALL_ACCESS); + } + + private static void GrantAccess(IdentityReference account, NoopSafeHandle handle, int accessMask) + { + GenericSecurity security = new GenericSecurity(false, ResourceType.WindowObject, handle, AccessControlSections.Access); + security.AddAccessRule(new GenericAccessRule(account, accessMask, AccessControlType.Allow)); + security.Persist(handle, AccessControlSections.Access); + } + + private class GenericSecurity : NativeObjectSecurity + { + public GenericSecurity(bool isContainer, ResourceType resType, SafeHandle objectHandle, AccessControlSections sectionsRequested) + : base(isContainer, resType, objectHandle, sectionsRequested) { } + public new void Persist(SafeHandle handle, AccessControlSections includeSections) { base.Persist(handle, includeSections); } + public new void AddAccessRule(AccessRule rule) { base.AddAccessRule(rule); } + public override Type AccessRightType { get { throw new NotImplementedException(); } } + public override AccessRule AccessRuleFactory(System.Security.Principal.IdentityReference identityReference, int accessMask, bool isInherited, + InheritanceFlags inheritanceFlags, PropagationFlags propagationFlags, AccessControlType type) + { throw new NotImplementedException(); } + public override Type AccessRuleType { get { return typeof(AccessRule); } } + public override AuditRule AuditRuleFactory(System.Security.Principal.IdentityReference identityReference, int accessMask, bool isInherited, + InheritanceFlags inheritanceFlags, PropagationFlags propagationFlags, AuditFlags flags) + { throw new NotImplementedException(); } + public override Type AuditRuleType { get { return typeof(AuditRule); } } + } + + private class GenericAccessRule : AccessRule + { + public GenericAccessRule(IdentityReference identity, int accessMask, AccessControlType type) : + base(identity, accessMask, false, InheritanceFlags.None, PropagationFlags.None, type) + { } + } + } +} diff --git a/lib/ansible/module_utils/csharp/Ansible.Privilege.cs b/lib/ansible/module_utils/csharp/Ansible.Privilege.cs new file mode 100644 index 0000000..2c0b266 --- /dev/null +++ b/lib/ansible/module_utils/csharp/Ansible.Privilege.cs @@ -0,0 +1,443 @@ +using Microsoft.Win32.SafeHandles; +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.ConstrainedExecution; +using System.Runtime.InteropServices; +using System.Security.Principal; +using System.Text; + +namespace Ansible.Privilege +{ + internal class NativeHelpers + { + [StructLayout(LayoutKind.Sequential)] + public struct LUID + { + public UInt32 LowPart; + public Int32 HighPart; + } + + [StructLayout(LayoutKind.Sequential)] + public struct LUID_AND_ATTRIBUTES + { + public LUID Luid; + public PrivilegeAttributes Attributes; + } + + [StructLayout(LayoutKind.Sequential)] + public struct TOKEN_PRIVILEGES + { + public UInt32 PrivilegeCount; + [MarshalAs(UnmanagedType.ByValArray, SizeConst = 1)] + public LUID_AND_ATTRIBUTES[] Privileges; + } + } + + internal class NativeMethods + { + [DllImport("advapi32.dll", SetLastError = true)] + public static extern bool AdjustTokenPrivileges( + SafeNativeHandle TokenHandle, + [MarshalAs(UnmanagedType.Bool)] bool DisableAllPrivileges, + SafeMemoryBuffer NewState, + UInt32 BufferLength, + SafeMemoryBuffer PreviousState, + out UInt32 ReturnLength); + + [DllImport("kernel32.dll")] + public static extern bool CloseHandle( + IntPtr hObject); + + [DllImport("kernel32")] + public static extern SafeWaitHandle GetCurrentProcess(); + + [DllImport("advapi32.dll", SetLastError = true)] + public static extern bool GetTokenInformation( + SafeNativeHandle TokenHandle, + UInt32 TokenInformationClass, + SafeMemoryBuffer TokenInformation, + UInt32 TokenInformationLength, + out UInt32 ReturnLength); + + [DllImport("advapi32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + public static extern bool LookupPrivilegeName( + string lpSystemName, + ref NativeHelpers.LUID lpLuid, + StringBuilder lpName, + ref UInt32 cchName); + + [DllImport("advapi32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + public static extern bool LookupPrivilegeValue( + string lpSystemName, + string lpName, + out NativeHelpers.LUID lpLuid); + + [DllImport("advapi32.dll", SetLastError = true)] + public static extern bool OpenProcessToken( + SafeHandle ProcessHandle, + TokenAccessLevels DesiredAccess, + out SafeNativeHandle TokenHandle); + } + + internal class SafeMemoryBuffer : SafeHandleZeroOrMinusOneIsInvalid + { + public SafeMemoryBuffer() : base(true) { } + public SafeMemoryBuffer(int cb) : base(true) + { + base.SetHandle(Marshal.AllocHGlobal(cb)); + } + public SafeMemoryBuffer(IntPtr handle) : base(true) + { + base.SetHandle(handle); + } + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + protected override bool ReleaseHandle() + { + Marshal.FreeHGlobal(handle); + return true; + } + } + + internal class SafeNativeHandle : SafeHandleZeroOrMinusOneIsInvalid + { + public SafeNativeHandle() : base(true) { } + public SafeNativeHandle(IntPtr handle) : base(true) { this.handle = handle; } + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + protected override bool ReleaseHandle() + { + return NativeMethods.CloseHandle(handle); + } + } + + public class Win32Exception : System.ComponentModel.Win32Exception + { + private string _msg; + public Win32Exception(string message) : this(Marshal.GetLastWin32Error(), message) { } + public Win32Exception(int errorCode, string message) : base(errorCode) + { + _msg = String.Format("{0} ({1}, Win32ErrorCode {2})", message, base.Message, errorCode); + } + public override string Message { get { return _msg; } } + public static explicit operator Win32Exception(string message) { return new Win32Exception(message); } + } + + [Flags] + public enum PrivilegeAttributes : uint + { + Disabled = 0x00000000, + EnabledByDefault = 0x00000001, + Enabled = 0x00000002, + Removed = 0x00000004, + UsedForAccess = 0x80000000, + } + + public class PrivilegeEnabler : IDisposable + { + private SafeHandle process; + private Dictionary<string, bool?> previousState; + + /// <summary> + /// Temporarily enables the privileges specified and reverts once the class is disposed. + /// </summary> + /// <param name="strict">Whether to fail if any privilege failed to be enabled, if false then this will continue silently</param> + /// <param name="privileges">A list of privileges to enable</param> + public PrivilegeEnabler(bool strict, params string[] privileges) + { + if (privileges.Length > 0) + { + process = PrivilegeUtil.GetCurrentProcess(); + Dictionary<string, bool?> newState = new Dictionary<string, bool?>(); + for (int i = 0; i < privileges.Length; i++) + newState.Add(privileges[i], true); + try + { + previousState = PrivilegeUtil.SetTokenPrivileges(process, newState, strict); + } + catch (Win32Exception e) + { + throw new Win32Exception(e.NativeErrorCode, String.Format("Failed to enable privilege(s) {0}", String.Join(", ", privileges))); + } + } + } + + public void Dispose() + { + // disables any privileges that were enabled by this class + if (previousState != null) + PrivilegeUtil.SetTokenPrivileges(process, previousState); + GC.SuppressFinalize(this); + } + ~PrivilegeEnabler() { this.Dispose(); } + } + + public class PrivilegeUtil + { + private static readonly UInt32 TOKEN_PRIVILEGES = 3; + + /// <summary> + /// Checks if the specific privilege constant is a valid privilege name + /// </summary> + /// <param name="name">The privilege constant (Se*Privilege) is valid</param> + /// <returns>true if valid, else false</returns> + public static bool CheckPrivilegeName(string name) + { + NativeHelpers.LUID luid; + if (!NativeMethods.LookupPrivilegeValue(null, name, out luid)) + { + int errCode = Marshal.GetLastWin32Error(); + if (errCode != 1313) // ERROR_NO_SUCH_PRIVILEGE + throw new Win32Exception(errCode, String.Format("LookupPrivilegeValue({0}) failed", name)); + return false; + } + else + { + return true; + } + } + + /// <summary> + /// Disables the privilege specified + /// </summary> + /// <param name="token">The process token to that contains the privilege to disable</param> + /// <param name="privilege">The privilege constant to disable</param> + /// <returns>The previous state that can be passed to SetTokenPrivileges to revert the action</returns> + public static Dictionary<string, bool?> DisablePrivilege(SafeHandle token, string privilege) + { + return SetTokenPrivileges(token, new Dictionary<string, bool?>() { { privilege, false } }); + } + + /// <summary> + /// Disables all the privileges + /// </summary> + /// <param name="token">The process token to that contains the privilege to disable</param> + /// <returns>The previous state that can be passed to SetTokenPrivileges to revert the action</returns> + public static Dictionary<string, bool?> DisableAllPrivileges(SafeHandle token) + { + return AdjustTokenPrivileges(token, null, false); + } + + /// <summary> + /// Enables the privilege specified + /// </summary> + /// <param name="token">The process token to that contains the privilege to enable</param> + /// <param name="privilege">The privilege constant to enable</param> + /// <returns>The previous state that can be passed to SetTokenPrivileges to revert the action</returns> + public static Dictionary<string, bool?> EnablePrivilege(SafeHandle token, string privilege) + { + return SetTokenPrivileges(token, new Dictionary<string, bool?>() { { privilege, true } }); + } + + /// <summary> + /// Get's the status of all the privileges on the token specified + /// </summary> + /// <param name="token">The process token to get the privilege status on</param> + /// <returns>Dictionary where the key is the privilege constant and the value is the PrivilegeAttributes flags</returns> + public static Dictionary<String, PrivilegeAttributes> GetAllPrivilegeInfo(SafeHandle token) + { + SafeNativeHandle hToken = null; + if (!NativeMethods.OpenProcessToken(token, TokenAccessLevels.Query, out hToken)) + throw new Win32Exception("OpenProcessToken() failed"); + + using (hToken) + { + UInt32 tokenLength = 0; + NativeMethods.GetTokenInformation(hToken, TOKEN_PRIVILEGES, new SafeMemoryBuffer(0), 0, out tokenLength); + + NativeHelpers.LUID_AND_ATTRIBUTES[] privileges; + using (SafeMemoryBuffer privilegesPtr = new SafeMemoryBuffer((int)tokenLength)) + { + if (!NativeMethods.GetTokenInformation(hToken, TOKEN_PRIVILEGES, privilegesPtr, tokenLength, out tokenLength)) + throw new Win32Exception("GetTokenInformation() for TOKEN_PRIVILEGES failed"); + + NativeHelpers.TOKEN_PRIVILEGES privilegeInfo = (NativeHelpers.TOKEN_PRIVILEGES)Marshal.PtrToStructure( + privilegesPtr.DangerousGetHandle(), typeof(NativeHelpers.TOKEN_PRIVILEGES)); + privileges = new NativeHelpers.LUID_AND_ATTRIBUTES[privilegeInfo.PrivilegeCount]; + PtrToStructureArray(privileges, IntPtr.Add(privilegesPtr.DangerousGetHandle(), Marshal.SizeOf(privilegeInfo.PrivilegeCount))); + } + + return privileges.ToDictionary(p => GetPrivilegeName(p.Luid), p => p.Attributes); + } + } + + /// <summary> + /// Get a handle to the current process for use with the methods above + /// </summary> + /// <returns>SafeWaitHandle handle of the current process token</returns> + public static SafeWaitHandle GetCurrentProcess() + { + return NativeMethods.GetCurrentProcess(); + } + + /// <summary> + /// Removes a privilege from the token. This operation is irreversible + /// </summary> + /// <param name="token">The process token to that contains the privilege to remove</param> + /// <param name="privilege">The privilege constant to remove</param> + public static void RemovePrivilege(SafeHandle token, string privilege) + { + SetTokenPrivileges(token, new Dictionary<string, bool?>() { { privilege, null } }); + } + + /// <summary> + /// Do a bulk set of multiple privileges + /// </summary> + /// <param name="token">The process token to use when setting the privilege state</param> + /// <param name="state">A dictionary that contains the privileges to set, the key is the constant name and the value can be; + /// true - enable the privilege + /// false - disable the privilege + /// null - remove the privilege (this cannot be reversed) + /// </param> + /// <param name="strict">When true, will fail if one privilege failed to be set, otherwise it will silently continue</param> + /// <returns>The previous state that can be passed to SetTokenPrivileges to revert the action</returns> + public static Dictionary<string, bool?> SetTokenPrivileges(SafeHandle token, IDictionary state, bool strict = true) + { + NativeHelpers.LUID_AND_ATTRIBUTES[] privilegeAttr = new NativeHelpers.LUID_AND_ATTRIBUTES[state.Count]; + int i = 0; + + foreach (DictionaryEntry entry in state) + { + string key = (string)entry.Key; + NativeHelpers.LUID luid; + if (!NativeMethods.LookupPrivilegeValue(null, key, out luid)) + throw new Win32Exception(String.Format("LookupPrivilegeValue({0}) failed", key)); + + PrivilegeAttributes attributes; + switch ((bool?)entry.Value) + { + case true: + attributes = PrivilegeAttributes.Enabled; + break; + case false: + attributes = PrivilegeAttributes.Disabled; + break; + default: + attributes = PrivilegeAttributes.Removed; + break; + } + + privilegeAttr[i].Luid = luid; + privilegeAttr[i].Attributes = attributes; + i++; + } + + return AdjustTokenPrivileges(token, privilegeAttr, strict); + } + + private static Dictionary<string, bool?> AdjustTokenPrivileges(SafeHandle token, NativeHelpers.LUID_AND_ATTRIBUTES[] newState, bool strict) + { + bool disableAllPrivileges; + SafeMemoryBuffer newStatePtr; + NativeHelpers.LUID_AND_ATTRIBUTES[] oldStatePrivileges; + UInt32 returnLength; + + if (newState == null) + { + disableAllPrivileges = true; + newStatePtr = new SafeMemoryBuffer(0); + } + else + { + disableAllPrivileges = false; + + // Need to manually marshal the bytes requires for newState as the constant size + // of LUID_AND_ATTRIBUTES is set to 1 and can't be overridden at runtime, TOKEN_PRIVILEGES + // always contains at least 1 entry so we need to calculate the extra size if there are + // nore than 1 LUID_AND_ATTRIBUTES entry + int tokenPrivilegesSize = Marshal.SizeOf(typeof(NativeHelpers.TOKEN_PRIVILEGES)); + int luidAttrSize = 0; + if (newState.Length > 1) + luidAttrSize = Marshal.SizeOf(typeof(NativeHelpers.LUID_AND_ATTRIBUTES)) * (newState.Length - 1); + int totalSize = tokenPrivilegesSize + luidAttrSize; + byte[] newStateBytes = new byte[totalSize]; + + // get the first entry that includes the struct details + NativeHelpers.TOKEN_PRIVILEGES tokenPrivileges = new NativeHelpers.TOKEN_PRIVILEGES() + { + PrivilegeCount = (UInt32)newState.Length, + Privileges = new NativeHelpers.LUID_AND_ATTRIBUTES[1], + }; + if (newState.Length > 0) + tokenPrivileges.Privileges[0] = newState[0]; + int offset = StructureToBytes(tokenPrivileges, newStateBytes, 0); + + // copy the remaining LUID_AND_ATTRIBUTES (if any) + for (int i = 1; i < newState.Length; i++) + offset += StructureToBytes(newState[i], newStateBytes, offset); + + // finally create the pointer to the byte array we just created + newStatePtr = new SafeMemoryBuffer(newStateBytes.Length); + Marshal.Copy(newStateBytes, 0, newStatePtr.DangerousGetHandle(), newStateBytes.Length); + } + + using (newStatePtr) + { + SafeNativeHandle hToken; + if (!NativeMethods.OpenProcessToken(token, TokenAccessLevels.Query | TokenAccessLevels.AdjustPrivileges, out hToken)) + throw new Win32Exception("OpenProcessToken() failed with Query and AdjustPrivileges"); + + using (hToken) + { + if (!NativeMethods.AdjustTokenPrivileges(hToken, disableAllPrivileges, newStatePtr, 0, new SafeMemoryBuffer(0), out returnLength)) + { + int errCode = Marshal.GetLastWin32Error(); + if (errCode != 122) // ERROR_INSUFFICIENT_BUFFER + throw new Win32Exception(errCode, "AdjustTokenPrivileges() failed to get old state size"); + } + + using (SafeMemoryBuffer oldStatePtr = new SafeMemoryBuffer((int)returnLength)) + { + bool res = NativeMethods.AdjustTokenPrivileges(hToken, disableAllPrivileges, newStatePtr, returnLength, oldStatePtr, out returnLength); + int errCode = Marshal.GetLastWin32Error(); + + // even when res == true, ERROR_NOT_ALL_ASSIGNED may be set as the last error code + // fail if we are running with strict, otherwise ignore those privileges + if (!res || ((strict && errCode != 0) || (!strict && !(errCode == 0 || errCode == 0x00000514)))) + throw new Win32Exception(errCode, "AdjustTokenPrivileges() failed"); + + // Marshal the oldStatePtr to the struct + NativeHelpers.TOKEN_PRIVILEGES oldState = (NativeHelpers.TOKEN_PRIVILEGES)Marshal.PtrToStructure( + oldStatePtr.DangerousGetHandle(), typeof(NativeHelpers.TOKEN_PRIVILEGES)); + oldStatePrivileges = new NativeHelpers.LUID_AND_ATTRIBUTES[oldState.PrivilegeCount]; + PtrToStructureArray(oldStatePrivileges, IntPtr.Add(oldStatePtr.DangerousGetHandle(), Marshal.SizeOf(oldState.PrivilegeCount))); + } + } + } + + return oldStatePrivileges.ToDictionary(p => GetPrivilegeName(p.Luid), p => (bool?)p.Attributes.HasFlag(PrivilegeAttributes.Enabled)); + } + + private static string GetPrivilegeName(NativeHelpers.LUID luid) + { + UInt32 nameLen = 0; + NativeMethods.LookupPrivilegeName(null, ref luid, null, ref nameLen); + + StringBuilder name = new StringBuilder((int)(nameLen + 1)); + if (!NativeMethods.LookupPrivilegeName(null, ref luid, name, ref nameLen)) + throw new Win32Exception("LookupPrivilegeName() failed"); + + return name.ToString(); + } + + private static void PtrToStructureArray<T>(T[] array, IntPtr ptr) + { + IntPtr ptrOffset = ptr; + for (int i = 0; i < array.Length; i++, ptrOffset = IntPtr.Add(ptrOffset, Marshal.SizeOf(typeof(T)))) + array[i] = (T)Marshal.PtrToStructure(ptrOffset, typeof(T)); + } + + private static int StructureToBytes<T>(T structure, byte[] array, int offset) + { + int size = Marshal.SizeOf(structure); + using (SafeMemoryBuffer structPtr = new SafeMemoryBuffer(size)) + { + Marshal.StructureToPtr(structure, structPtr.DangerousGetHandle(), false); + Marshal.Copy(structPtr.DangerousGetHandle(), array, offset, size); + } + + return size; + } + } +} + diff --git a/lib/ansible/module_utils/csharp/Ansible.Process.cs b/lib/ansible/module_utils/csharp/Ansible.Process.cs new file mode 100644 index 0000000..f4c68f0 --- /dev/null +++ b/lib/ansible/module_utils/csharp/Ansible.Process.cs @@ -0,0 +1,461 @@ +using Microsoft.Win32.SafeHandles; +using System; +using System.Collections; +using System.IO; +using System.Linq; +using System.Runtime.ConstrainedExecution; +using System.Runtime.InteropServices; +using System.Text; +using System.Threading; + +namespace Ansible.Process +{ + internal class NativeHelpers + { + [StructLayout(LayoutKind.Sequential)] + public class SECURITY_ATTRIBUTES + { + public UInt32 nLength; + public IntPtr lpSecurityDescriptor; + public bool bInheritHandle = false; + public SECURITY_ATTRIBUTES() + { + nLength = (UInt32)Marshal.SizeOf(this); + } + } + + [StructLayout(LayoutKind.Sequential)] + public class STARTUPINFO + { + public UInt32 cb; + public IntPtr lpReserved; + [MarshalAs(UnmanagedType.LPWStr)] public string lpDesktop; + [MarshalAs(UnmanagedType.LPWStr)] public string lpTitle; + public UInt32 dwX; + public UInt32 dwY; + public UInt32 dwXSize; + public UInt32 dwYSize; + public UInt32 dwXCountChars; + public UInt32 dwYCountChars; + public UInt32 dwFillAttribute; + public StartupInfoFlags dwFlags; + public UInt16 wShowWindow; + public UInt16 cbReserved2; + public IntPtr lpReserved2; + public SafeFileHandle hStdInput; + public SafeFileHandle hStdOutput; + public SafeFileHandle hStdError; + public STARTUPINFO() + { + cb = (UInt32)Marshal.SizeOf(this); + } + } + + [StructLayout(LayoutKind.Sequential)] + public class STARTUPINFOEX + { + public STARTUPINFO startupInfo; + public IntPtr lpAttributeList; + public STARTUPINFOEX() + { + startupInfo = new STARTUPINFO(); + startupInfo.cb = (UInt32)Marshal.SizeOf(this); + } + } + + [StructLayout(LayoutKind.Sequential)] + public struct PROCESS_INFORMATION + { + public IntPtr hProcess; + public IntPtr hThread; + public int dwProcessId; + public int dwThreadId; + } + + [Flags] + public enum ProcessCreationFlags : uint + { + CREATE_NEW_CONSOLE = 0x00000010, + CREATE_UNICODE_ENVIRONMENT = 0x00000400, + EXTENDED_STARTUPINFO_PRESENT = 0x00080000 + } + + [Flags] + public enum StartupInfoFlags : uint + { + USESTDHANDLES = 0x00000100 + } + + [Flags] + public enum HandleFlags : uint + { + None = 0, + INHERIT = 1 + } + } + + internal class NativeMethods + { + [DllImport("kernel32.dll", SetLastError = true)] + public static extern bool AllocConsole(); + + [DllImport("shell32.dll", SetLastError = true)] + public static extern SafeMemoryBuffer CommandLineToArgvW( + [MarshalAs(UnmanagedType.LPWStr)] string lpCmdLine, + out int pNumArgs); + + [DllImport("kernel32.dll", SetLastError = true)] + public static extern bool CreatePipe( + out SafeFileHandle hReadPipe, + out SafeFileHandle hWritePipe, + NativeHelpers.SECURITY_ATTRIBUTES lpPipeAttributes, + UInt32 nSize); + + [DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + public static extern bool CreateProcessW( + [MarshalAs(UnmanagedType.LPWStr)] string lpApplicationName, + StringBuilder lpCommandLine, + IntPtr lpProcessAttributes, + IntPtr lpThreadAttributes, + bool bInheritHandles, + NativeHelpers.ProcessCreationFlags dwCreationFlags, + SafeMemoryBuffer lpEnvironment, + [MarshalAs(UnmanagedType.LPWStr)] string lpCurrentDirectory, + NativeHelpers.STARTUPINFOEX lpStartupInfo, + out NativeHelpers.PROCESS_INFORMATION lpProcessInformation); + + [DllImport("kernel32.dll", SetLastError = true)] + public static extern bool FreeConsole(); + + [DllImport("kernel32.dll", SetLastError = true)] + public static extern IntPtr GetConsoleWindow(); + + [DllImport("kernel32.dll", SetLastError = true)] + public static extern bool GetExitCodeProcess( + SafeWaitHandle hProcess, + out UInt32 lpExitCode); + + [DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Unicode)] + public static extern uint SearchPathW( + [MarshalAs(UnmanagedType.LPWStr)] string lpPath, + [MarshalAs(UnmanagedType.LPWStr)] string lpFileName, + [MarshalAs(UnmanagedType.LPWStr)] string lpExtension, + UInt32 nBufferLength, + [MarshalAs(UnmanagedType.LPTStr)] StringBuilder lpBuffer, + out IntPtr lpFilePart); + + [DllImport("kernel32.dll", SetLastError = true)] + public static extern bool SetConsoleCP( + UInt32 wCodePageID); + + [DllImport("kernel32.dll", SetLastError = true)] + public static extern bool SetConsoleOutputCP( + UInt32 wCodePageID); + + [DllImport("kernel32.dll", SetLastError = true)] + public static extern bool SetHandleInformation( + SafeFileHandle hObject, + NativeHelpers.HandleFlags dwMask, + NativeHelpers.HandleFlags dwFlags); + + [DllImport("kernel32.dll")] + public static extern UInt32 WaitForSingleObject( + SafeWaitHandle hHandle, + UInt32 dwMilliseconds); + } + + internal class SafeMemoryBuffer : SafeHandleZeroOrMinusOneIsInvalid + { + public SafeMemoryBuffer() : base(true) { } + public SafeMemoryBuffer(int cb) : base(true) + { + base.SetHandle(Marshal.AllocHGlobal(cb)); + } + public SafeMemoryBuffer(IntPtr handle) : base(true) + { + base.SetHandle(handle); + } + + [ReliabilityContract(Consistency.WillNotCorruptState, Cer.MayFail)] + protected override bool ReleaseHandle() + { + Marshal.FreeHGlobal(handle); + return true; + } + } + + public class Win32Exception : System.ComponentModel.Win32Exception + { + private string _msg; + + public Win32Exception(string message) : this(Marshal.GetLastWin32Error(), message) { } + public Win32Exception(int errorCode, string message) : base(errorCode) + { + _msg = String.Format("{0} ({1}, Win32ErrorCode {2})", message, base.Message, errorCode); + } + + public override string Message { get { return _msg; } } + public static explicit operator Win32Exception(string message) { return new Win32Exception(message); } + } + + public class Result + { + public string StandardOut { get; internal set; } + public string StandardError { get; internal set; } + public uint ExitCode { get; internal set; } + } + + public class ProcessUtil + { + /// <summary> + /// Parses a command line string into an argv array according to the Windows rules + /// </summary> + /// <param name="lpCommandLine">The command line to parse</param> + /// <returns>An array of arguments interpreted by Windows</returns> + public static string[] ParseCommandLine(string lpCommandLine) + { + int numArgs; + using (SafeMemoryBuffer buf = NativeMethods.CommandLineToArgvW(lpCommandLine, out numArgs)) + { + if (buf.IsInvalid) + throw new Win32Exception("Error parsing command line"); + IntPtr[] strptrs = new IntPtr[numArgs]; + Marshal.Copy(buf.DangerousGetHandle(), strptrs, 0, numArgs); + return strptrs.Select(s => Marshal.PtrToStringUni(s)).ToArray(); + } + } + + /// <summary> + /// Searches the path for the executable specified. Will throw a Win32Exception if the file is not found. + /// </summary> + /// <param name="lpFileName">The executable to search for</param> + /// <returns>The full path of the executable to search for</returns> + public static string SearchPath(string lpFileName) + { + StringBuilder sbOut = new StringBuilder(0); + IntPtr filePartOut = IntPtr.Zero; + UInt32 res = NativeMethods.SearchPathW(null, lpFileName, null, (UInt32)sbOut.Capacity, sbOut, out filePartOut); + if (res == 0) + { + int lastErr = Marshal.GetLastWin32Error(); + if (lastErr == 2) // ERROR_FILE_NOT_FOUND + throw new FileNotFoundException(String.Format("Could not find file '{0}'.", lpFileName)); + else + throw new Win32Exception(String.Format("SearchPathW({0}) failed to get buffer length", lpFileName)); + } + + sbOut.EnsureCapacity((int)res); + if (NativeMethods.SearchPathW(null, lpFileName, null, (UInt32)sbOut.Capacity, sbOut, out filePartOut) == 0) + throw new Win32Exception(String.Format("SearchPathW({0}) failed", lpFileName)); + + return sbOut.ToString(); + } + + public static Result CreateProcess(string command) + { + return CreateProcess(null, command, null, null, String.Empty); + } + + public static Result CreateProcess(string lpApplicationName, string lpCommandLine, string lpCurrentDirectory, + IDictionary environment) + { + return CreateProcess(lpApplicationName, lpCommandLine, lpCurrentDirectory, environment, String.Empty); + } + + public static Result CreateProcess(string lpApplicationName, string lpCommandLine, string lpCurrentDirectory, + IDictionary environment, string stdin) + { + return CreateProcess(lpApplicationName, lpCommandLine, lpCurrentDirectory, environment, stdin, null); + } + + public static Result CreateProcess(string lpApplicationName, string lpCommandLine, string lpCurrentDirectory, + IDictionary environment, byte[] stdin) + { + return CreateProcess(lpApplicationName, lpCommandLine, lpCurrentDirectory, environment, stdin, null); + } + + public static Result CreateProcess(string lpApplicationName, string lpCommandLine, string lpCurrentDirectory, + IDictionary environment, string stdin, string outputEncoding) + { + byte[] stdinBytes; + if (String.IsNullOrEmpty(stdin)) + stdinBytes = new byte[0]; + else + { + if (!stdin.EndsWith(Environment.NewLine)) + stdin += Environment.NewLine; + stdinBytes = new UTF8Encoding(false).GetBytes(stdin); + } + return CreateProcess(lpApplicationName, lpCommandLine, lpCurrentDirectory, environment, stdinBytes, outputEncoding); + } + + /// <summary> + /// Creates a process based on the CreateProcess API call. + /// </summary> + /// <param name="lpApplicationName">The name of the executable or batch file to execute</param> + /// <param name="lpCommandLine">The command line to execute, typically this includes lpApplication as the first argument</param> + /// <param name="lpCurrentDirectory">The full path to the current directory for the process, null will have the same cwd as the calling process</param> + /// <param name="environment">A dictionary of key/value pairs to define the new process environment</param> + /// <param name="stdin">A byte array to send over the stdin pipe</param> + /// <param name="outputEncoding">The character encoding for decoding stdout/stderr output of the process.</param> + /// <returns>Result object that contains the command output and return code</returns> + public static Result CreateProcess(string lpApplicationName, string lpCommandLine, string lpCurrentDirectory, + IDictionary environment, byte[] stdin, string outputEncoding) + { + NativeHelpers.ProcessCreationFlags creationFlags = NativeHelpers.ProcessCreationFlags.CREATE_UNICODE_ENVIRONMENT | + NativeHelpers.ProcessCreationFlags.EXTENDED_STARTUPINFO_PRESENT; + NativeHelpers.PROCESS_INFORMATION pi = new NativeHelpers.PROCESS_INFORMATION(); + NativeHelpers.STARTUPINFOEX si = new NativeHelpers.STARTUPINFOEX(); + si.startupInfo.dwFlags = NativeHelpers.StartupInfoFlags.USESTDHANDLES; + + SafeFileHandle stdoutRead, stdoutWrite, stderrRead, stderrWrite, stdinRead, stdinWrite; + CreateStdioPipes(si, out stdoutRead, out stdoutWrite, out stderrRead, out stderrWrite, out stdinRead, + out stdinWrite); + FileStream stdinStream = new FileStream(stdinWrite, FileAccess.Write); + + // $null from PowerShell ends up as an empty string, we need to convert back as an empty string doesn't + // make sense for these parameters + if (lpApplicationName == "") + lpApplicationName = null; + + if (lpCurrentDirectory == "") + lpCurrentDirectory = null; + + using (SafeMemoryBuffer lpEnvironment = CreateEnvironmentPointer(environment)) + { + // Create console with utf-8 CP if no existing console is present + bool isConsole = false; + if (NativeMethods.GetConsoleWindow() == IntPtr.Zero) + { + isConsole = NativeMethods.AllocConsole(); + + // Set console input/output codepage to UTF-8 + NativeMethods.SetConsoleCP(65001); + NativeMethods.SetConsoleOutputCP(65001); + } + + try + { + StringBuilder commandLine = new StringBuilder(lpCommandLine); + if (!NativeMethods.CreateProcessW(lpApplicationName, commandLine, IntPtr.Zero, IntPtr.Zero, + true, creationFlags, lpEnvironment, lpCurrentDirectory, si, out pi)) + { + throw new Win32Exception("CreateProcessW() failed"); + } + } + finally + { + if (isConsole) + NativeMethods.FreeConsole(); + } + } + + return WaitProcess(stdoutRead, stdoutWrite, stderrRead, stderrWrite, stdinStream, stdin, pi.hProcess, + outputEncoding); + } + + internal static void CreateStdioPipes(NativeHelpers.STARTUPINFOEX si, out SafeFileHandle stdoutRead, + out SafeFileHandle stdoutWrite, out SafeFileHandle stderrRead, out SafeFileHandle stderrWrite, + out SafeFileHandle stdinRead, out SafeFileHandle stdinWrite) + { + NativeHelpers.SECURITY_ATTRIBUTES pipesec = new NativeHelpers.SECURITY_ATTRIBUTES(); + pipesec.bInheritHandle = true; + + if (!NativeMethods.CreatePipe(out stdoutRead, out stdoutWrite, pipesec, 0)) + throw new Win32Exception("STDOUT pipe setup failed"); + if (!NativeMethods.SetHandleInformation(stdoutRead, NativeHelpers.HandleFlags.INHERIT, 0)) + throw new Win32Exception("STDOUT pipe handle setup failed"); + + if (!NativeMethods.CreatePipe(out stderrRead, out stderrWrite, pipesec, 0)) + throw new Win32Exception("STDERR pipe setup failed"); + if (!NativeMethods.SetHandleInformation(stderrRead, NativeHelpers.HandleFlags.INHERIT, 0)) + throw new Win32Exception("STDERR pipe handle setup failed"); + + if (!NativeMethods.CreatePipe(out stdinRead, out stdinWrite, pipesec, 0)) + throw new Win32Exception("STDIN pipe setup failed"); + if (!NativeMethods.SetHandleInformation(stdinWrite, NativeHelpers.HandleFlags.INHERIT, 0)) + throw new Win32Exception("STDIN pipe handle setup failed"); + + si.startupInfo.hStdOutput = stdoutWrite; + si.startupInfo.hStdError = stderrWrite; + si.startupInfo.hStdInput = stdinRead; + } + + internal static SafeMemoryBuffer CreateEnvironmentPointer(IDictionary environment) + { + IntPtr lpEnvironment = IntPtr.Zero; + if (environment != null && environment.Count > 0) + { + StringBuilder environmentString = new StringBuilder(); + foreach (DictionaryEntry kv in environment) + environmentString.AppendFormat("{0}={1}\0", kv.Key, kv.Value); + environmentString.Append('\0'); + + lpEnvironment = Marshal.StringToHGlobalUni(environmentString.ToString()); + } + return new SafeMemoryBuffer(lpEnvironment); + } + + internal static Result WaitProcess(SafeFileHandle stdoutRead, SafeFileHandle stdoutWrite, SafeFileHandle stderrRead, + SafeFileHandle stderrWrite, FileStream stdinStream, byte[] stdin, IntPtr hProcess, string outputEncoding = null) + { + // Default to using UTF-8 as the output encoding, this should be a sane default for most scenarios. + outputEncoding = String.IsNullOrEmpty(outputEncoding) ? "utf-8" : outputEncoding; + Encoding encodingInstance = Encoding.GetEncoding(outputEncoding); + + FileStream stdoutFS = new FileStream(stdoutRead, FileAccess.Read, 4096); + StreamReader stdout = new StreamReader(stdoutFS, encodingInstance, true, 4096); + stdoutWrite.Close(); + + FileStream stderrFS = new FileStream(stderrRead, FileAccess.Read, 4096); + StreamReader stderr = new StreamReader(stderrFS, encodingInstance, true, 4096); + stderrWrite.Close(); + + stdinStream.Write(stdin, 0, stdin.Length); + stdinStream.Close(); + + string stdoutStr, stderrStr = null; + GetProcessOutput(stdout, stderr, out stdoutStr, out stderrStr); + UInt32 rc = GetProcessExitCode(hProcess); + + return new Result + { + StandardOut = stdoutStr, + StandardError = stderrStr, + ExitCode = rc + }; + } + + internal static void GetProcessOutput(StreamReader stdoutStream, StreamReader stderrStream, out string stdout, out string stderr) + { + var sowait = new EventWaitHandle(false, EventResetMode.ManualReset); + var sewait = new EventWaitHandle(false, EventResetMode.ManualReset); + string so = null, se = null; + ThreadPool.QueueUserWorkItem((s) => + { + so = stdoutStream.ReadToEnd(); + sowait.Set(); + }); + ThreadPool.QueueUserWorkItem((s) => + { + se = stderrStream.ReadToEnd(); + sewait.Set(); + }); + foreach (var wh in new WaitHandle[] { sowait, sewait }) + wh.WaitOne(); + stdout = so; + stderr = se; + } + + internal static UInt32 GetProcessExitCode(IntPtr processHandle) + { + SafeWaitHandle hProcess = new SafeWaitHandle(processHandle, true); + NativeMethods.WaitForSingleObject(hProcess, 0xFFFFFFFF); + + UInt32 exitCode; + if (!NativeMethods.GetExitCodeProcess(hProcess, out exitCode)) + throw new Win32Exception("GetExitCodeProcess() failed"); + return exitCode; + } + } +} diff --git a/lib/ansible/module_utils/csharp/__init__.py b/lib/ansible/module_utils/csharp/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/lib/ansible/module_utils/csharp/__init__.py |