diff options
Diffstat (limited to 'test/features/environment.py')
-rw-r--r-- | test/features/environment.py | 133 |
1 files changed, 51 insertions, 82 deletions
diff --git a/test/features/environment.py b/test/features/environment.py index 1ea0f08..a3d3764 100644 --- a/test/features/environment.py +++ b/test/features/environment.py @@ -9,96 +9,72 @@ import pexpect from steps.wrappers import run_cli, wait_prompt -test_log_file = os.path.join(os.environ['HOME'], '.mycli.test.log') +test_log_file = os.path.join(os.environ["HOME"], ".mycli.test.log") -SELF_CONNECTING_FEATURES = ( - 'test/features/connection.feature', -) +SELF_CONNECTING_FEATURES = ("test/features/connection.feature",) -MY_CNF_PATH = os.path.expanduser('~/.my.cnf') -MY_CNF_BACKUP_PATH = f'{MY_CNF_PATH}.backup' -MYLOGIN_CNF_PATH = os.path.expanduser('~/.mylogin.cnf') -MYLOGIN_CNF_BACKUP_PATH = f'{MYLOGIN_CNF_PATH}.backup' +MY_CNF_PATH = os.path.expanduser("~/.my.cnf") +MY_CNF_BACKUP_PATH = f"{MY_CNF_PATH}.backup" +MYLOGIN_CNF_PATH = os.path.expanduser("~/.mylogin.cnf") +MYLOGIN_CNF_BACKUP_PATH = f"{MYLOGIN_CNF_PATH}.backup" def get_db_name_from_context(context): - return context.config.userdata.get( - 'my_test_db', None - ) or "mycli_behave_tests" - + return context.config.userdata.get("my_test_db", None) or "mycli_behave_tests" def before_all(context): """Set env parameters.""" - os.environ['LINES'] = "100" - os.environ['COLUMNS'] = "100" - os.environ['EDITOR'] = 'ex' - os.environ['LC_ALL'] = 'en_US.UTF-8' - os.environ['PROMPT_TOOLKIT_NO_CPR'] = '1' - os.environ['MYCLI_HISTFILE'] = os.devnull + os.environ["LINES"] = "100" + os.environ["COLUMNS"] = "100" + os.environ["EDITOR"] = "ex" + os.environ["LC_ALL"] = "en_US.UTF-8" + os.environ["PROMPT_TOOLKIT_NO_CPR"] = "1" + os.environ["MYCLI_HISTFILE"] = os.devnull - test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) - login_path_file = os.path.join(test_dir, 'mylogin.cnf') -# os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file + # test_dir = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) + # login_path_file = os.path.join(test_dir, "mylogin.cnf") + # os.environ['MYSQL_TEST_LOGIN_FILE'] = login_path_file - context.package_root = os.path.abspath( - os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + context.package_root = os.path.abspath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) - os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root, - '.coveragerc') + os.environ["COVERAGE_PROCESS_START"] = os.path.join(context.package_root, ".coveragerc") context.exit_sent = False - vi = '_'.join([str(x) for x in sys.version_info[:3]]) + vi = "_".join([str(x) for x in sys.version_info[:3]]) db_name = get_db_name_from_context(context) - db_name_full = '{0}_{1}'.format(db_name, vi) + db_name_full = "{0}_{1}".format(db_name, vi) # Store get params from config/environment variables context.conf = { - 'host': context.config.userdata.get( - 'my_test_host', - os.getenv('PYTEST_HOST', 'localhost') - ), - 'port': context.config.userdata.get( - 'my_test_port', - int(os.getenv('PYTEST_PORT', '3306')) - ), - 'user': context.config.userdata.get( - 'my_test_user', - os.getenv('PYTEST_USER', 'root') - ), - 'pass': context.config.userdata.get( - 'my_test_pass', - os.getenv('PYTEST_PASSWORD', None) - ), - 'cli_command': context.config.userdata.get( - 'my_cli_command', None) or - sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"', - 'dbname': db_name, - 'dbname_tmp': db_name_full + '_tmp', - 'vi': vi, - 'pager_boundary': '---boundary---', + "host": context.config.userdata.get("my_test_host", os.getenv("PYTEST_HOST", "localhost")), + "port": context.config.userdata.get("my_test_port", int(os.getenv("PYTEST_PORT", "3306"))), + "user": context.config.userdata.get("my_test_user", os.getenv("PYTEST_USER", "root")), + "pass": context.config.userdata.get("my_test_pass", os.getenv("PYTEST_PASSWORD", None)), + "cli_command": context.config.userdata.get("my_cli_command", None) + or sys.executable + ' -c "import coverage ; coverage.process_startup(); import mycli.main; mycli.main.cli()"', + "dbname": db_name, + "dbname_tmp": db_name_full + "_tmp", + "vi": vi, + "pager_boundary": "---boundary---", } _, my_cnf = mkstemp() - with open(my_cnf, 'w') as f: + with open(my_cnf, "w") as f: f.write( - '[client]\n' - 'pager={0} {1} {2}\n'.format( - sys.executable, os.path.join(context.package_root, - 'test/features/wrappager.py'), - context.conf['pager_boundary']) + "[client]\n" "pager={0} {1} {2}\n".format( + sys.executable, os.path.join(context.package_root, "test/features/wrappager.py"), context.conf["pager_boundary"] + ) ) - context.conf['defaults-file'] = my_cnf - context.conf['myclirc'] = os.path.join(context.package_root, 'test', - 'myclirc') + context.conf["defaults-file"] = my_cnf + context.conf["myclirc"] = os.path.join(context.package_root, "test", "myclirc") - context.cn = dbutils.create_db(context.conf['host'], context.conf['port'], - context.conf['user'], - context.conf['pass'], - context.conf['dbname']) + context.cn = dbutils.create_db( + context.conf["host"], context.conf["port"], context.conf["user"], context.conf["pass"], context.conf["dbname"] + ) context.fixture_data = fixutils.read_fixture_files() @@ -106,12 +82,10 @@ def before_all(context): def after_all(context): """Unset env parameters.""" dbutils.close_cn(context.cn) - dbutils.drop_db(context.conf['host'], context.conf['port'], - context.conf['user'], context.conf['pass'], - context.conf['dbname']) + dbutils.drop_db(context.conf["host"], context.conf["port"], context.conf["user"], context.conf["pass"], context.conf["dbname"]) # Restore env vars. - #for k, v in context.pgenv.items(): + # for k, v in context.pgenv.items(): # if k in os.environ and v is None: # del os.environ[k] # elif v: @@ -123,8 +97,8 @@ def before_step(context, _): def before_scenario(context, arg): - with open(test_log_file, 'w') as f: - f.write('') + with open(test_log_file, "w") as f: + f.write("") if arg.location.filename not in SELF_CONNECTING_FEATURES: run_cli(context) wait_prompt(context) @@ -140,23 +114,18 @@ def after_scenario(context, _): """Cleans up after each test complete.""" with open(test_log_file) as f: for line in f: - if 'error' in line.lower(): - raise RuntimeError(f'Error in log file: {line}') + if "error" in line.lower(): + raise RuntimeError(f"Error in log file: {line}") - if hasattr(context, 'cli') and not context.exit_sent: + if hasattr(context, "cli") and not context.exit_sent: # Quit nicely. if not context.atprompt: - user = context.conf['user'] - host = context.conf['host'] + user = context.conf["user"] + host = context.conf["host"] dbname = context.currentdb - context.cli.expect_exact( - '{0}@{1}:{2}>'.format( - user, host, dbname - ), - timeout=5 - ) - context.cli.sendcontrol('c') - context.cli.sendcontrol('d') + context.cli.expect_exact("{0}@{1}:{2}>".format(user, host, dbname), timeout=5) + context.cli.sendcontrol("c") + context.cli.sendcontrol("d") context.cli.expect_exact(pexpect.EOF, timeout=5) if os.path.exists(MY_CNF_BACKUP_PATH): |