summaryrefslogtreecommitdiffstats
path: root/test/unittests/test_prun.py
blob: 7e987bf1d7842a8eb598a2b5bab329bcc31b4b4f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import typing

import crmsh.constants
import crmsh.prun.prun
import crmsh.prun.runner

import unittest
from unittest import mock


class TestPrun(unittest.TestCase):
    """Unit tests for ``crmsh.prun.prun.prun``.

    ``Runner.add_task`` and ``Runner.run`` are patched out, so nothing is
    executed: each test checks which tasks ``prun`` queues for every host
    and the keys of the dict it returns.
    """

    @staticmethod
    def _expected_task(args, cmdline, host, ssh_user):
        """Return a matcher for the task ``prun`` is expected to queue.

        :param args: expected argv of the task
        :param cmdline: expected task input (the encoded command line)
        :param host: expected ``context["host"]``
        :param ssh_user: expected ``context["ssh_user"]``

        Both stdout and stderr are expected to be captured.
        """
        return TaskArgumentsEq(
            args,
            cmdline,
            stdout=crmsh.prun.runner.Task.Capture,
            stderr=crmsh.prun.runner.Task.Capture,
            context={"host": host, "ssh_user": ssh_user},
        )

    def _assert_results_cover(self, results, hosts):
        """Assert ``prun`` returned a dict keyed by exactly ``hosts``."""
        self.assertIsInstance(results, dict)
        self.assertSetEqual(set(hosts), set(results.keys()))

    @mock.patch("os.geteuid")
    @mock.patch("crmsh.userdir.getuser")
    @mock.patch("crmsh.prun.prun._is_local_host")
    @mock.patch("crmsh.user_of_host.UserOfHost.user_pair_for_ssh")
    @mock.patch("crmsh.prun.runner.Runner.run")
    @mock.patch("crmsh.prun.runner.Runner.add_task")
    def test_prun(
            self,
            mock_runner_add_task: mock.MagicMock,
            mock_runner_run: mock.MagicMock,
            mock_user_pair_for_ssh: mock.MagicMock,
            mock_is_local_host: mock.MagicMock,
            mock_getuser: mock.MagicMock,
            mock_geteuid: mock.MagicMock,
    ):
        """With ssh user pair (alice, bob), each remote command is wrapped as
        ``su alice --login -c 'ssh ... bob@<host> sudo -H /bin/sh'``."""
        host_cmdline = {"host1": "foo", "host2": "bar"}
        mock_user_pair_for_ssh.return_value = "alice", "bob"
        mock_is_local_host.return_value = False
        mock_getuser.return_value = 'root'
        mock_geteuid.return_value = 0
        results = crmsh.prun.prun.prun(host_cmdline)
        mock_user_pair_for_ssh.assert_has_calls([
            mock.call("host1"),
            mock.call("host2"),
        ])
        mock_is_local_host.assert_has_calls([
            mock.call("host1"),
            mock.call("host2"),
        ])
        mock_runner_add_task.assert_has_calls([
            mock.call(self._expected_task(
                ['su', 'alice', '--login', '-c', 'ssh {} bob@host1 sudo -H /bin/sh'.format(crmsh.constants.SSH_OPTION)],
                b'foo',
                'host1',
                'bob',
            )),
            mock.call(self._expected_task(
                ['su', 'alice', '--login', '-c', 'ssh {} bob@host2 sudo -H /bin/sh'.format(crmsh.constants.SSH_OPTION)],
                b'bar',
                'host2',
                'bob',
            )),
        ])
        mock_runner_run.assert_called_once()
        self._assert_results_cover(results, host_cmdline)

    @mock.patch("os.geteuid")
    @mock.patch("crmsh.userdir.getuser")
    @mock.patch("crmsh.prun.prun._is_local_host")
    @mock.patch("crmsh.user_of_host.UserOfHost.user_pair_for_ssh")
    @mock.patch("crmsh.prun.runner.Runner.run")
    @mock.patch("crmsh.prun.runner.Runner.add_task")
    def test_prun_root(
            self,
            mock_runner_add_task: mock.MagicMock,
            mock_runner_run: mock.MagicMock,
            mock_user_pair_for_ssh: mock.MagicMock,
            mock_is_local_host: mock.MagicMock,
            mock_getuser: mock.MagicMock,
            mock_geteuid: mock.MagicMock,
    ):
        """With ssh user pair (root, root), ssh is launched through
        ``/bin/sh -c`` without ``su``, and ``os.geteuid`` is never called."""
        host_cmdline = {"host1": "foo", "host2": "bar"}
        mock_user_pair_for_ssh.return_value = "root", "root"
        mock_is_local_host.return_value = False
        mock_getuser.return_value = 'root'
        mock_geteuid.return_value = 0
        results = crmsh.prun.prun.prun(host_cmdline)
        mock_geteuid.assert_not_called()
        mock_user_pair_for_ssh.assert_has_calls([
            mock.call("host1"),
            mock.call("host2"),
        ])
        mock_is_local_host.assert_has_calls([
            mock.call("host1"),
            mock.call("host2"),
        ])
        mock_runner_add_task.assert_has_calls([
            mock.call(self._expected_task(
                ['/bin/sh', '-c', 'ssh {} root@host1 sudo -H /bin/sh'.format(crmsh.constants.SSH_OPTION)],
                b'foo',
                'host1',
                'root',
            )),
            mock.call(self._expected_task(
                ['/bin/sh', '-c', 'ssh {} root@host2 sudo -H /bin/sh'.format(crmsh.constants.SSH_OPTION)],
                b'bar',
                'host2',
                'root',
            )),
        ])
        mock_runner_run.assert_called_once()
        self._assert_results_cover(results, host_cmdline)

    @mock.patch("os.geteuid")
    @mock.patch("crmsh.userdir.getuser")
    @mock.patch("crmsh.prun.prun._is_local_host")
    @mock.patch("crmsh.user_of_host.UserOfHost.user_pair_for_ssh")
    @mock.patch("crmsh.prun.runner.Runner.run")
    @mock.patch("crmsh.prun.runner.Runner.add_task")
    def test_prun_localhost(
            self,
            mock_runner_add_task: mock.MagicMock,
            mock_runner_run: mock.MagicMock,
            mock_user_pair_for_ssh: mock.MagicMock,
            mock_is_local_host: mock.MagicMock,
            mock_getuser: mock.MagicMock,
            mock_geteuid: mock.MagicMock,
    ):
        """A local host runs ``/bin/sh`` directly and no ssh user pair is
        looked up."""
        host_cmdline = {"host1": "foo"}
        mock_is_local_host.return_value = True
        mock_getuser.return_value = 'root'
        mock_geteuid.return_value = 0
        results = crmsh.prun.prun.prun(host_cmdline)
        mock_user_pair_for_ssh.assert_not_called()
        mock_is_local_host.assert_called_once_with('host1')
        mock_runner_add_task.assert_called_once_with(
            self._expected_task(['/bin/sh'], b'foo', 'host1', 'root')
        )
        mock_runner_run.assert_called_once()
        self._assert_results_cover(results, host_cmdline)


class TaskArgumentsEq(crmsh.prun.runner.Task):
    """A Task that compares equal to any Task built with the same arguments.

    Used as the expected value inside ``mock.call(...)`` assertions. Because
    it subclasses ``Task``, Python consults this ``__eq__`` first even when
    the instance is the right-hand operand of ``==``, so the field-wise
    comparison applies whichever side mock puts it on.
    """

    def __eq__(self, other):
        """Compare argv, input, stdout/stderr config and context with ``other``."""
        if not isinstance(other, crmsh.prun.runner.Task):
            # Defer to the other operand; if it also declines, Python falls
            # back to an identity check, which yields False as before.
            return NotImplemented
        return (
            self.args == other.args
            and self.input == other.input
            and self.stdout_config == other.stdout_config
            and self.stderr_config == other.stderr_config
            and self.context == other.context
        )

    def __repr__(self):
        # Show the compared fields so failing mock assertions are readable.
        return '{}(args={!r}, input={!r}, stdout_config={!r}, stderr_config={!r}, context={!r})'.format(
            type(self).__name__,
            self.args,
            self.input,
            self.stdout_config,
            self.stderr_config,
            self.context,
        )