summaryrefslogtreecommitdiffstats
path: root/mycli/config.py
blob: 5d711093a2a1518f716269dfcae717d08008ddeb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
from copy import copy
from io import BytesIO, TextIOWrapper
import logging
import os
from os.path import exists
import struct
import sys
from typing import Union, IO

from configobj import ConfigObj, ConfigObjError
import pyaes

try:
    import importlib.resources as resources
except ImportError:
    # Python < 3.7
    import importlib_resources as resources

# ``basestring`` is only a builtin on Python 2; when the bare name lookup
# raises NameError, alias it to ``str`` so the isinstance() checks in this
# module keep working.
try:
    basestring
except NameError:
    basestring = str


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def log(logger, level, message):
    """Emit *message* through *logger*, or on stderr as a fallback.

    A logger whose direct parent is the root logger is treated as
    "logging isn't initialized": in that case the message is printed to
    stderr instead of being sent through the logging machinery.

    :param logger: the logger to inspect and (possibly) log through
    :param level: the logging level used when going through *logger*
    :param message: the text to emit
    """
    if logger.parent.name == 'root':
        print(message, file=sys.stderr)
        return

    logger.log(level, message)


def read_config_file(f, list_values=True):
    """Parse a single config file with ConfigObj.

    A string *f* is treated as a path and has ``~`` expanded first; any
    other object is handed to ConfigObj unchanged.

    *list_values* set to `True` is the default behavior of ConfigObj.
    When it is disabled, values are not parsed for lists
    (e.g. 'a,b,c' -> ['a', 'b', 'c']) and are not unquoted. It is
    disabled when reading MySQL config files so that commas in passwords
    are interpreted correctly.

    :param f: a path, or an object accepted by ConfigObj
    :param bool list_values: forwarded to ConfigObj
    :return: the parsed config; the partially parsed config when a
             ConfigObjError is raised; None on IOError/OSError
    """
    if isinstance(f, basestring):
        f = os.path.expanduser(f)

    try:
        return ConfigObj(f, interpolation=False, encoding='utf8',
                         list_values=list_values)
    except ConfigObjError as e:
        parse_warning = ("Unable to parse line {0} of config file "
                         "'{1}'.".format(e.line_number, f))
        log(logger, logging.WARNING, parse_warning)
        log(logger, logging.WARNING, "Using successfully parsed config values.")
        # Keep whatever ConfigObj managed to parse before the error.
        return e.config
    except (IOError, OSError) as e:
        read_warning = ("You don't have permission to read "
                        "config file '{0}'.".format(e.filename))
        log(logger, logging.WARNING, read_warning)
        return None


def get_included_configs(config_file: Union[str, TextIOWrapper]) -> list:
    """Get the config files pulled into *config_file* by ``!includedir``.

    Every line of *config_file* that starts with ``!includedir`` names a
    directory (the last whitespace-separated token on the line). Each
    ``*.cnf`` file directly inside such a directory is returned. Within
    one directory the files are returned in sorted name order, so the
    result does not depend on the arbitrary order of ``os.listdir``.

    "Normal" configs should be passed as file paths. The only exception
    is .mylogin which is decoded into a stream. However, it never
    contains include directives and so will be ignored by this
    function.

    :param config_file: path of the config file to scan; anything that is
                        not a path to an existing regular file is ignored
    :return: list of paths to the included ``.cnf`` files (possibly empty)
    """
    if not isinstance(config_file, str) or not os.path.isfile(config_file):
        return []
    included_configs = []

    try:
        with open(config_file) as f:
            for line in f:
                if not line.startswith('!includedir'):
                    continue
                include_dir = line.strip().split()[-1]
                if not os.path.isdir(include_dir):
                    continue
                included_configs.extend(
                    os.path.join(include_dir, filename)
                    for filename in sorted(os.listdir(include_dir))
                    if filename.endswith('.cnf')
                )
    except (PermissionError, UnicodeDecodeError):
        # Best effort: an unreadable file or directory just ends the scan;
        # whatever was collected so far is still returned.
        pass
    return included_configs


def read_config_files(files, list_values=True):
    """Read and merge a list of config files.

    Starts from the packaged default config and merges each file of
    *files* into it, in order. Files referenced through ``!includedir``
    are queued immediately after the file that includes them.

    :param files: iterable of config files (paths or file-like objects);
                  it is not modified
    :param bool list_values: forwarded to the config reader
    :return: the merged config; its ``filename`` is that of the last
             non-empty config merged
    """
    config = create_default_config(list_values=list_values)
    pending = list(files)
    while pending:
        current = pending.pop(0)
        parsed = read_config_file(current, list_values=list_values)

        # Expand includes only if this file could be read: the reader
        # returns None when the file is unreadable, and scanning it for
        # include directives would just hit the same error again.
        if parsed is not None:
            pending = get_included_configs(current) + pending
        if parsed:
            config.merge(parsed)
            config.filename = parsed.filename

    return config


def create_default_config(list_values=True):
    """Load the ``myclirc`` file packaged with mycli as the default config.

    :param bool list_values: forwarded to the config reader
    :return: whatever ``read_config_file`` returns for the packaged file
    """
    import mycli
    # Close the resource handle once it has been parsed instead of leaving
    # it open until garbage collection.
    with resources.open_text(mycli, 'myclirc') as default_config_file:
        return read_config_file(default_config_file, list_values=list_values)


def write_default_config(destination, overwrite=False):
    """Write the packaged ``myclirc`` text to *destination*.

    ``~`` in *destination* is expanded. An existing file is left alone
    unless *overwrite* is true.

    :param destination: path of the file to write
    :param bool overwrite: replace the file when it already exists
    """
    import mycli
    template = resources.read_text(mycli, 'myclirc')
    target = os.path.expanduser(destination)

    if overwrite or not exists(target):
        with open(target, 'w') as out:
            out.write(template)


def get_mylogin_cnf_path():
    """Locate the login path file.

    The ``MYSQL_TEST_LOGIN_FILE`` environment variable wins when it is
    set. Otherwise the file is ``.mylogin.cnf`` inside ``%APPDATA%/MySQL``
    when ``APPDATA`` is set and non-empty, or inside the home directory.

    :return: the expanded path if the file exists, otherwise None
    """
    candidate = os.getenv('MYSQL_TEST_LOGIN_FILE')

    if candidate is None:
        app_data = os.getenv('APPDATA')
        if app_data:
            base_dir = os.path.join(app_data, 'MySQL')
        else:
            base_dir = '~'
        candidate = os.path.join(base_dir, '.mylogin.cnf')

    candidate = os.path.expanduser(candidate)

    if not exists(candidate):
        return None

    logger.debug("Found login path file at '{0}'".format(candidate))
    return candidate


def open_mylogin_cnf(name):
    """Open a readable version of .mylogin.cnf.

    The file is opened in binary mode, decrypted, and the decrypted
    bytes are wrapped for text reading.

    :param str name: The pathname of the file to be opened.
    :return: a TextIOWrapper over the decrypted contents, or None when
             the file cannot be opened or decrypted
    """
    try:
        with open(name, 'rb') as encrypted:
            decrypted = read_and_decrypt_mylogin_cnf(encrypted)
    except (OSError, IOError, ValueError):
        logger.error('Unable to open login path file.')
        return None

    if isinstance(decrypted, BytesIO):
        return TextIOWrapper(decrypted)

    # The decryptor did not hand back a bytes buffer.
    logger.error('Unable to read login path file.')
    return None


# TODO reuse code between encryption and decryption
def encrypt_mylogin_cnf(plaintext: IO[str]):
    """Encryption of .mylogin.cnf file, analogous to calling
    mysql_config_editor.

    Code is based on the python implementation by Kristian Koehntopp
    https://github.com/isotopp/mysql-config-coder

    Output layout: a 4-byte zero header, the 20-byte random login key,
    then for every input line a 4-byte little-endian length followed by
    the AES-ECB ciphertext of the padded line. The little-endian length
    matches what read_and_decrypt_mylogin_cnf unpacks.

    Each line is encoded to bytes *before* its padding is computed, so
    lines with multi-byte characters are padded to a whole number of
    16-byte blocks as well.

    :param plaintext: text stream to encrypt, read line by line
    :return: io.BytesIO holding the encrypted data, positioned at 0
    """
    def realkey(key):
        """Create the AES key from the login key."""
        rkey = bytearray(16)
        for i, byte in enumerate(key):
            rkey[i % 16] ^= byte
        return bytes(rkey)

    def encode_line(data, buf_len):
        """Pad *data* (bytes) up to *buf_len* bytes and encrypt it."""
        pad_len = buf_len - len(data)
        # Every pad byte carries the pad length (1..16).
        padded = data + bytes([pad_len]) * pad_len
        return b''.join(
            aes.encrypt(padded[i: i + 16])
            for i in range(0, len(padded), 16)
        )

    LOGIN_KEY_LENGTH = 20
    key = os.urandom(LOGIN_KEY_LENGTH)
    # One cipher object is enough for all lines.
    aes = pyaes.AESModeOfOperationECB(realkey(key))

    outfile = BytesIO()

    outfile.write(struct.pack("<i", 0))
    outfile.write(key)

    while True:
        line = plaintext.readline()
        if not line:
            break
        data = line.encode()
        # Always add at least one pad byte: round up to the next multiple
        # of 16, or add a full block if already aligned.
        buf_len = (len(data) // 16 + 1) * 16

        outfile.write(struct.pack("<i", buf_len))
        outfile.write(encode_line(data, buf_len))

    outfile.seek(0)
    return outfile


def read_and_decrypt_mylogin_cnf(f):
    """Read and decrypt the contents of .mylogin.cnf.

    This decryption algorithm mimics the code in MySQL's
    mysql_config_editor.cc.

    The file holds a 4-byte unused buffer, then the 20-byte login key,
    then a sequence of records, each a 4-byte little-endian length
    followed by that many bytes of AES-ECB ciphertext. The login key is
    XOR-folded into the 16-byte key used by the cipher.

    :param f: an I/O object opened in binary mode
    :return: the decrypted login path file
    :rtype: io.BytesIO or None
    """

    # Width in bytes of the length prefix stored before each ciphertext.
    MAX_CIPHER_STORE_LEN = 4

    LOGIN_KEY_LEN = 20

    # Skip the unused buffer at the start of the file.
    header = f.read(4)
    if not header or len(header) != 4:
        logger.error('Login path file is blank or incomplete.')
        return None

    # The login key comes next.
    key = f.read(LOGIN_KEY_LEN)

    # Turn the login key into byte values; a short key makes ord() fail.
    try:
        key_values = [ord(key[i:i + 1]) for i in range(LOGIN_KEY_LEN)]
    except TypeError:
        # ord() was unable to get the value of the byte.
        logger.error('Unable to generate login path AES key.')
        return None

    # Fold the login key into the 16-byte real key.
    folded = [0] * 16
    for position, value in enumerate(key_values):
        folded[position % 16] ^= value
    rkey = struct.pack('16B', *folded)

    # Decrypted output accumulates in this bytes buffer.
    plaintext = BytesIO()
    aes = pyaes.AESModeOfOperationECB(rkey)

    while True:
        # Each record starts with the length of its ciphertext.
        len_buf = f.read(MAX_CIPHER_STORE_LEN)
        if len(len_buf) < MAX_CIPHER_STORE_LEN:
            break
        (cipher_len,) = struct.unpack("<i", len_buf)

        # Decrypt the record one 16-byte block at a time.
        cipher = f.read(cipher_len)
        blocks = [aes.decrypt(cipher[start: start + 16])
                  for start in range(0, cipher_len, 16)]
        plain = _remove_pad(b''.join(blocks))
        if plain is not False:
            plaintext.write(plain)

    if not plaintext.tell():
        logger.error('No data successfully decrypted from login path file.')
        return None

    plaintext.seek(0)
    return plaintext


def str_to_bool(s):
    """Convert a string value to its corresponding boolean value.

    Booleans are returned unchanged. Strings are matched
    case-insensitively: 'true', 'on' and '1' give True; 'false', 'off'
    and '0' give False.

    :raises TypeError: if *s* is neither a bool nor a string
    :raises ValueError: if the string is not one of the recognized values
    """
    if isinstance(s, bool):
        return s
    if not isinstance(s, basestring):
        raise TypeError('argument must be a string')

    recognized = {
        'true': True, 'on': True, '1': True,
        'false': False, 'off': False, '0': False,
    }
    value = recognized.get(s.lower())
    if value is None:
        raise ValueError('not a recognized boolean value: {0}'.format(s))
    return value


def strip_matching_quotes(s):
    """Remove matching, surrounding quotes from a string.

    Only a string of at least two characters whose first and last
    characters are the same quote character (single or double) is
    changed; anything else, including non-strings, is returned as is.

    This is the same logic that ConfigObj uses when parsing config
    values.

    """
    if not isinstance(s, basestring) or len(s) < 2:
        return s

    opening, closing = s[0], s[-1]
    if opening == closing and opening in ('"', "'"):
        return s[1:-1]
    return s


def _remove_pad(line):
    """Remove the pad from the *line*.

    The value of the last byte is taken as the pad length; the pad is
    that many trailing bytes, all of which must be identical.

    :param line: the decrypted, still padded, bytes
    :return: *line* without its pad, or False when the pad is missing or
             invalid (callers must test with ``is False``, since a valid
             result may be empty)
    """
    try:
        # Determine pad length.
        pad_length = ord(line[-1:])
    except TypeError:
        # ord() was unable to get the value of the byte.
        logger.warning('Unable to remove pad.')
        return False

    if (pad_length == 0 or pad_length > len(line)
            or len(set(line[-pad_length:])) != 1):
        # Pad length must be at least 1 and no greater than the length of
        # the plaintext, and the pad should have a single unique byte.
        # Zero is rejected explicitly: line[-0:] is the whole line, so a
        # zero pad would otherwise be validated against the wrong bytes.
        logger.warning('Invalid pad found in login path file.')
        return False

    return line[:-pad_length]