diff options
Diffstat (limited to 'tests/features/environment.py')
-rw-r--r-- | tests/features/environment.py | 227 |
1 files changed, 227 insertions, 0 deletions
diff --git a/tests/features/environment.py b/tests/features/environment.py new file mode 100644 index 0000000..50ac5fa --- /dev/null +++ b/tests/features/environment.py @@ -0,0 +1,227 @@ +import copy +import os +import sys +import db_utils as dbutils +import fixture_utils as fixutils +import pexpect +import tempfile +import shutil +import signal + + +from steps import wrappers + + +def before_all(context): + """Set env parameters.""" + env_old = copy.deepcopy(dict(os.environ)) + os.environ["LINES"] = "100" + os.environ["COLUMNS"] = "100" + os.environ["PAGER"] = "cat" + os.environ["EDITOR"] = "ex" + os.environ["VISUAL"] = "ex" + os.environ["PROMPT_TOOLKIT_NO_CPR"] = "1" + + context.package_root = os.path.abspath( + os.path.dirname(os.path.dirname(os.path.dirname(__file__))) + ) + fixture_dir = os.path.join(context.package_root, "tests/features/fixture_data") + + print("package root:", context.package_root) + print("fixture dir:", fixture_dir) + + os.environ["COVERAGE_PROCESS_START"] = os.path.join( + context.package_root, ".coveragerc" + ) + + context.exit_sent = False + + vi = "_".join([str(x) for x in sys.version_info[:3]]) + db_name = context.config.userdata.get("pg_test_db", "pgcli_behave_tests") + db_name_full = f"{db_name}_{vi}" + + # Store get params from config. + context.conf = { + "host": context.config.userdata.get( + "pg_test_host", os.getenv("PGHOST", "localhost") + ), + "user": context.config.userdata.get( + "pg_test_user", os.getenv("PGUSER", "postgres") + ), + "pass": context.config.userdata.get( + "pg_test_pass", os.getenv("PGPASSWORD", None) + ), + "port": context.config.userdata.get( + "pg_test_port", os.getenv("PGPORT", "5432") + ), + "cli_command": ( + context.config.userdata.get("pg_cli_command", None) + or '{python} -c "{startup}"'.format( + python=sys.executable, + startup="; ".join( + [ + "import coverage", + "coverage.process_startup()", + "import pgcli.main", + "pgcli.main.cli(auto_envvar_prefix='BEHAVE')", + ] + ), + ) + ), + "dbname": db_name_full, + "dbname_tmp": db_name_full + "_tmp", + "vi": vi, + "pager_boundary": "---boundary---", + } + os.environ["PAGER"] = "{0} {1} {2}".format( + sys.executable, + os.path.join(context.package_root, "tests/features/wrappager.py"), + context.conf["pager_boundary"], + ) + + # Store old env vars. + context.pgenv = { + "PGDATABASE": os.environ.get("PGDATABASE", None), + "PGUSER": os.environ.get("PGUSER", None), + "PGHOST": os.environ.get("PGHOST", None), + "PGPASSWORD": os.environ.get("PGPASSWORD", None), + "PGPORT": os.environ.get("PGPORT", None), + "XDG_CONFIG_HOME": os.environ.get("XDG_CONFIG_HOME", None), + "PGSERVICEFILE": os.environ.get("PGSERVICEFILE", None), + } + + # Set new env vars. + os.environ["PGDATABASE"] = context.conf["dbname"] + os.environ["PGUSER"] = context.conf["user"] + os.environ["PGHOST"] = context.conf["host"] + os.environ["PGPORT"] = context.conf["port"] + os.environ["PGSERVICEFILE"] = os.path.join(fixture_dir, "mock_pg_service.conf") + + if context.conf["pass"]: + os.environ["PGPASSWORD"] = context.conf["pass"] + else: + if "PGPASSWORD" in os.environ: + del os.environ["PGPASSWORD"] + os.environ["BEHAVE_WARN"] = "moderate" + + context.cn = dbutils.create_db( + context.conf["host"], + context.conf["user"], + context.conf["pass"], + context.conf["dbname"], + context.conf["port"], + ) + context.pgbouncer_available = dbutils.pgbouncer_available( + hostname=context.conf["host"], + password=context.conf["pass"], + username=context.conf["user"], + ) + context.fixture_data = fixutils.read_fixture_files() + + # use temporary directory as config home + context.env_config_home = tempfile.mkdtemp(prefix="pgcli_home_") + os.environ["XDG_CONFIG_HOME"] = context.env_config_home + show_env_changes(env_old, dict(os.environ)) + + +def show_env_changes(env_old, env_new): + """Print out all test-specific env values.""" + print("--- os.environ changed values: ---") + all_keys = env_old.keys() | env_new.keys() + for k in sorted(all_keys): + old_value = env_old.get(k, "") + new_value = env_new.get(k, "") + if new_value and old_value != new_value: + print(f'{k}="{new_value}"') + print("-" * 20) + + +def after_all(context): + """ + Unset env parameters. + """ + dbutils.close_cn(context.cn) + dbutils.drop_db( + context.conf["host"], + context.conf["user"], + context.conf["pass"], + context.conf["dbname"], + context.conf["port"], + ) + + # Remove temp config directory + shutil.rmtree(context.env_config_home) + + # Restore env vars. + for k, v in context.pgenv.items(): + if k in os.environ and v is None: + del os.environ[k] + elif v: + os.environ[k] = v + + +def before_step(context, _): + context.atprompt = False + + +def is_known_problem(scenario): + """TODO: why is this not working in 3.12?""" + if sys.version_info >= (3, 12): + return scenario.name in ( + 'interrupt current query via "ctrl + c"', + "run the cli with --username", + "run the cli with --user", + "run the cli with --port", + ) + return False + + +def before_scenario(context, scenario): + if scenario.name == "list databases": + # not using the cli for that + return + if is_known_problem(scenario): + scenario.skip() + currentdb = None + if "pgbouncer" in scenario.feature.tags: + if context.pgbouncer_available: + os.environ["PGDATABASE"] = "pgbouncer" + os.environ["PGPORT"] = "6432" + currentdb = "pgbouncer" + else: + scenario.skip() + else: + # set env vars back to normal test database + os.environ["PGDATABASE"] = context.conf["dbname"] + os.environ["PGPORT"] = context.conf["port"] + wrappers.run_cli(context, currentdb=currentdb) + wrappers.wait_prompt(context) + + +def after_scenario(context, scenario): + """Cleans up after each scenario completes.""" + if hasattr(context, "cli") and context.cli and not context.exit_sent: + # Quit nicely. + if not getattr(context, "atprompt", False): + dbname = context.currentdb + context.cli.expect_exact(f"{dbname}>", timeout=5) + try: + context.cli.sendcontrol("c") + context.cli.sendcontrol("d") + except Exception as x: + print("Failed cleanup after scenario:") + print(x) + try: + context.cli.expect_exact(pexpect.EOF, timeout=5) + except pexpect.TIMEOUT: + print(f"--- after_scenario {scenario.name}: kill cli") + context.cli.kill(signal.SIGKILL) + if hasattr(context, "tmpfile_sql_help") and context.tmpfile_sql_help: + context.tmpfile_sql_help.close() + context.tmpfile_sql_help = None + + +# # TODO: uncomment to debug a failure +# def after_step(context, step): +# if step.status == "failed": +# import pdb; pdb.set_trace() |