summaryrefslogtreecommitdiffstats
path: root/bin/psn
blob: fefcfa1d7b0a7c1464fe39f94536bdd800c9e59e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
#!/usr/bin/env python
#
#  psn -- Linux Process Snapper by Tanel Poder [https://0x.tools]
#  Copyright 2019-2021 Tanel Poder
#
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 2 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License along
#  with this program; if not, write to the Free Software Foundation, Inc.,
#  51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
#  SPDX-License-Identifier: GPL-2.0-or-later
#
# NOTES
#
#  If you have only python3 binary in your system (no "python" command without any suffix),
#  then set up the OS default alternative or symlink so that you'd have a "python" command.
#
#  For example (on RHEL 8):
#    sudo alternatives --set python /usr/bin/python3
#
#  One exception on RHEL 5, use the following line, without quotes (if you have the
#  additional python26 package installed from EPEL):
#
#  "#!/usr/bin/env python26"
#

PSN_VERSION = '1.2.6'

import sys, os, os.path, time, datetime
import re
import sqlite3
import logging
from signal import signal, SIGPIPE, SIG_DFL

sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + '/lib/0xtools')

import argparse
import psnproc as proc, psnreport

# don't print "broken pipe" error when piped into commands like "head"
signal(SIGPIPE,SIG_DFL)

### Defaults ###
default_group_by='comm,state'


def _positive_int(value):
    """argparse type callable: parse value as an integer and require it to be >= 1.

    Used for the sampling rates, which are later turned into periods with
    1. / rate, so zero would raise ZeroDivisionError and negative values
    would produce negative sleep periods.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError('invalid int value: %r' % (value,))
    if number < 1:
        raise argparse.ArgumentTypeError('must be a positive integer, got %r' % (value,))
    return number


### CLI handling ###
parser = argparse.ArgumentParser()
parser.add_argument('-d', dest="sample_seconds", metavar='seconds', type=int, default=5, help='number of seconds to sample for')
parser.add_argument('-p', '--pid', metavar='pid', default=None, nargs='?', help='process id to sample (including all its threads), or process name regex, or omit for system-wide sampling')
# TODO implement -t filtering below
parser.add_argument('-t', '--thread', metavar='tid', default=None, nargs='?', help='thread/task id to sample (not implemented yet)')

parser.add_argument('-r', '--recursive', default=False, action='store_true', help='also sample and report for descendant processes')
parser.add_argument('-a', '--all-states', default=False, action='store_true', help='display threads in all states, including idle ones')

# both rates are used as divisors when computing the sampling periods, so they must be >= 1
parser.add_argument('--sample-hz', default=20, type=_positive_int, help='sample rate in Hz (default: %(default)d)')
parser.add_argument('--ps-hz', default=2, type=_positive_int, help='sample rate of new processes in Hz (default: %(default)d)')

parser.add_argument('-o', '--output-sample-db', metavar='filename', default=':memory:', type=str, help='path of sqlite3 database to persist samples to, defaults to in-memory/transient')
parser.add_argument('-i', '--input-sample-db', metavar='filename', default=None, type=str, help='path of sqlite3 database to read samples from instead of actively sampling')

parser.add_argument('-s', '--select', metavar='csv-columns', default='', help='additional columns to report')
parser.add_argument('-g', '--group-by', metavar='csv-columns', default=default_group_by, help='columns to aggregate by in reports')
parser.add_argument('-G', '--append-group-by', metavar='csv-columns', default=None, help='default + additional columns to aggregate by in reports')
#parser.add_argument('-f', '--filter', metavar='filter-sql', default='1=1', help="sample schema SQL to apply to filter report. ('active' and 'idle' keywords may also be used)")

parser.add_argument('--list', default=None, action='store_true', help='list all available columns')

# specify csv sources to be captured in full. use with -o or --output-sample-db to capture proc data for manual analysis
parser.add_argument('--sources', metavar='csv-source-names', default='', help=argparse.SUPPRESS)
# capture all sources in full. use with -o or --output-sample-db
parser.add_argument('--all-sources', default=False, action='store_true', help=argparse.SUPPRESS)

# hidden troubleshooting switches (suppressed from --help)
parser.add_argument('--debug', default=False, action='store_true', help=argparse.SUPPRESS)
parser.add_argument('--profile', default=False, action='store_true', help=argparse.SUPPRESS)
parser.add_argument('--show-yourself', default=False, action='store_true', help=argparse.SUPPRESS)

args = parser.parse_args()


### Set up Logging
# --debug only changes the log level; destination and format are the same either way
logging.basicConfig(
    stream=sys.stderr,
    level=logging.DEBUG if args.debug else logging.INFO,
    format='%(asctime)s %(message)s',
)

# reference point for the sampling schedule, plus the sampling periods derived from the requested rates
start_time = time.time()
sample_period_s = 1.0 / args.sample_hz
sample_ps_period_s = 1.0 / args.ps_hz

# --list: print every source with its typed columns, then stop
if args.list:
    for source in proc.all_sources:
        print("\n" + source.name)
        print("====================================")
        for column in source.available_columns:
            # columns without a type entry are not listed
            if not column[1]:
                continue
            print("%-30s %5s" % (column[0], column[1].__name__))

    exit(0)

# resolve the sources requested for full capture into source objects
if args.all_sources:
    args.sources = proc.all_sources
else:
    requested_source_names = args.sources.split(',')
    args.sources = [s for s in proc.all_sources if s.name in requested_source_names]

# turn the csv column arguments into lists
args.select = args.select.split(',')
args.group_by = args.group_by.split(',')
args.append_group_by = args.append_group_by.split(',') if args.append_group_by else []

# unless -a/--all-states was given, reports are restricted to the 'active' state filter
included_states = None if args.all_states else 'active'

### Report definition ###

# Main report: user-selected columns (-s) plus 'samples' and 'avg_threads',
# grouped by the -g columns followed by the -G columns, ordered by 'samples'.
# The where list holds 'active' by default, or None when -a/--all-states is set.
active_process_sample = psnreport.Report(
    'Active Threads',
    projection=args.select+['samples', 'avg_threads'],
    dimensions=args.group_by+args.append_group_by,
    order=['samples'],
    where=[included_states]
)

# Per-user variant of the report. It is constructed here but is currently
# not part of the 'reports' list below, so it is never printed.
system_activity_by_user = psnreport.Report(
    'Active Threads by User',
    projection=args.select+['uid', 'comm', 'samples', 'avg_threads'],
    dimensions=args.group_by+args.append_group_by,
)

# reports that are actually produced; their sources decide which data gets sampled
reports = [active_process_sample]#, system_activity_by_user]


# consolidate sources required for all reports
sources = {} # source -> set(cols)
for report in reports:
    for source, columns in report.sources.items():
        # merge this report's columns into whatever was already requested for the source
        sources.setdefault(source, set()).update(columns)

# add sources requested for full sampling
for full_source in args.sources:
    # full capture replaces the column set with every available column name
    sources[full_source] = set(col[0] for col in full_source.available_columns)

# update sources with set of reqd columns
for source, columns in sources.items():
    source.set_stored_columns(columns)


def sqlexec(conn, sql):
    """Execute one SQL statement on conn, logging it at debug level.

    Returns the cursor produced by conn.execute() so that callers can
    fetch result rows, e.g. sqlexec(conn, 'SELECT ...').fetchone().
    """
    logging.debug(sql)
    cursor = conn.execute(sql)
    logging.debug('Done')
    return cursor

def sqlexecmany(conn, sql, samples):
    """Execute a parameterized SQL statement once per row in samples.

    samples is a sequence of parameter rows; the statement is logged at
    debug level and run through conn.executemany(). Returns the cursor.
    """
    logging.debug(sql)
    # executemany binds each row separately; execute() would treat the whole
    # list as the parameters of a single statement and fail
    cursor = conn.executemany(sql, samples)
    logging.debug('Done')
    return cursor

### Schema setup ###
if args.input_sample_db:
    # report-only mode: open the existing sample database as it is
    conn = sqlite3.connect(args.input_sample_db)
else:
    conn = sqlite3.connect(args.output_sample_db)
    sqlexec(conn, 'pragma synchronous=0')

    def create_table(conn, s):
        """Create the sample table and its event index for source s, if missing."""
        # python column type -> sqlite column type
        type_names = {int: 'INTEGER', float: 'REAL', str: 'TEXT'}

        column_defs = [(col[0], type_names.get(col[1])) for col in s.schema_columns]
        column_list_sql = ', '.join([' '.join(col_def) for col_def in column_defs])

        table_sql = "CREATE TABLE IF NOT EXISTS '%s' (%s)" % (s.name, column_list_sql)
        sqlexec(conn, table_sql)

        index_sql = "CREATE INDEX IF NOT EXISTS '%s_event_idx' ON %s (event_time, pid, task)" % (s.name, s.name)
        sqlexec(conn, index_sql)

    # one table per source that will be sampled
    for s in sources:
        create_table(conn, s)


### Sampling utility functions ###
def get_matching_processes(pid_arg=None, recursive=False):
    """Return the pids to sample as an unordered list without duplicates.

    pid_arg can be a single pid, comma-separated pids, or a regex that is
    searched in the command name reported by ps. When pid_arg is empty or
    None, every process listed by ps is selected. With recursive=True the
    descendants of the selected processes are included too. pSnapper's own
    pid is removed unless --show-yourself was given.
    """
    process_children = {} # pid -> [children-pids]
    pid_basename = []     # [(pid, command name)]

    # TODO use os.listdir() when -r or process name regex is not needed
    ps_output = os.popen('ps -A -o pid,ppid,comm', 'r')
    try:
        ps_lines = ps_output.readlines()[1:]  # skip the header line
    finally:
        ps_output.close()  # close the pipe explicitly instead of leaking it until garbage collection

    for line in ps_lines:
        tokens = line.split()
        if len(tokens) < 3:
            continue  # blank or truncated line, nothing to parse

        pid = int(tokens[0])
        ppid = int(tokens[1])
        process_children.setdefault(ppid, []).append(pid)

        # only the first whitespace-separated word of the command name is kept
        pid_basename.append((pid, tokens[2]))

    if pid_arg:
        try:
            arg_pids = set(int(p) for p in pid_arg.split(','))
        except ValueError:
            # not a list of numbers, so treat the argument as a regex on the command name
            selected_pids = [p for p, b in pid_basename if re.search(pid_arg, b)]
        else:
            selected_pids = [p for p, b in pid_basename if p in arg_pids]

        # recursive pid walking is not needed when looking for all system pids anyway
        if recursive:
            # breadth-first walk over the children map; 'seen' keeps each pid
            # from being queued more than once
            seen = set(selected_pids)
            i = 0
            while i < len(selected_pids):
                for child in process_children.get(selected_pids[i], []):
                    if child not in seen:
                        seen.add(child)
                        selected_pids.append(child)
                i += 1
    else:
        selected_pids = [p for p, b in pid_basename]

    # deduplicate pids & remove pSnapper pid (as pSnapper doesn't consume any resources when it's not running)
    selected_pids = list(set(selected_pids))
    if not args.show_yourself:
        own_pid = os.getpid()
        if own_pid in selected_pids:
            selected_pids.remove(own_pid)
    return selected_pids


def get_process_tasks(pids):
    """Map each pid to the directory entries found under /proc/<pid>/task.

    Pids whose task directory cannot be listed are left out of the result.
    """
    tasks_by_pid = {}
    for pid in pids:
        task_dir = '/proc/%s/task' % pid
        try:
            task_ids = os.listdir(task_dir)
        except OSError:
            # the process may no longer exist by the time we get here
            continue
        tasks_by_pid[pid] = task_ids
    return tasks_by_pid


def get_utc_now():
    """Return the current UTC time as a datetime.

    On Python 3.2+ the result is timezone-aware; older interpreters have no
    datetime.timezone, so a naive UTC datetime is returned there.
    """
    if sys.version_info < (3, 2):
        return datetime.datetime.utcnow()
    return datetime.datetime.now(datetime.timezone.utc)


## Main sampling loop ###
def main():
    """Collect samples (unless reading a saved database) and print the reports.

    With -i/--input-sample-db the samples already stored in that database
    are reported. Otherwise the selected processes and their tasks are
    sampled at args.sample_hz until args.sample_seconds have elapsed since
    start_time (or Ctrl-C is pressed); every sample event is committed to
    the sqlite connection. Afterwards each report is printed, followed by
    summary counters and any warnings.
    """
    final_warnings = []  # warning texts shown at the end of the output, if any

    # Set for both modes: the summary code at the end reads these even when
    # no live sampling took place (-i mode).
    sample_ioerrors = 0  # detect permissions issues with /proc access
    sample_seconds = args.sample_seconds
    total_measure_s = 0.

    if args.input_sample_db:
        count_sql = 'SELECT COUNT(DISTINCT(event_time)) FROM stat'
        logging.debug(count_sql)
        num_sample_events = int(conn.execute(count_sql).fetchone()[0])
    else:
        print("")
        print('Linux Process Snapper v%s by Tanel Poder [https://0x.tools]' % PSN_VERSION)
        # no newline here: 'finished.' / 'sampling interrupted.' completes this line
        sys.stdout.write('Sampling /proc/%s for %d seconds... ' % (', '.join([s.name for s in sources.keys()]), args.sample_seconds))
        sys.stdout.flush()
        num_sample_events = 0
        num_ps_samples = 0

        selected_pids = None
        process_tasks = {}

        # the set of sources does not change while sampling, so split it once
        process_sources = [s for s in sources.keys() if s.task_level == False]
        task_sources = [s for s in sources.keys() if s.task_level == True]

        try:
            while time.time() < start_time + args.sample_seconds:
                measure_begin_time = get_utc_now()
                event_time = measure_begin_time.isoformat()

                # refresh matching pids at a much lower frequency than process sampling as the underlying ps is expensive
                if not selected_pids or time.time() > start_time + num_ps_samples * sample_ps_period_s:
                    selected_pids = get_matching_processes(args.pid, args.recursive)
                    process_tasks = get_process_tasks(selected_pids)
                    num_ps_samples += 1

                    if not selected_pids:
                        print('No matching processes found: %s' % args.pid)
                        sys.exit(1)

                for pid in selected_pids:
                    try:
                        # if any process-level samples fail, don't insert any sample event rows for process or tasks...
                        process_samples = [s.sample(event_time, pid, pid) for s in process_sources]

                        for s, samples in zip(process_sources, process_samples):
                            conn.executemany(s.insert_sql, samples)

                        for task in process_tasks.get(pid, []):
                            try:
                                # ...but if a task disappears mid-sample simply discard data for that task only
                                task_samples = [s.sample(event_time, pid, task) for s in task_sources]

                                for s, samples in zip(task_sources, task_samples):
                                    conn.executemany(s.insert_sql, samples)

                            except IOError as e:
                                sample_ioerrors += 1
                                logging.debug(e)
                                continue

                    except IOError as e:
                        sample_ioerrors += 1
                        logging.debug(e)
                        continue

                conn.commit()

                num_sample_events += 1
                measure_delta = get_utc_now() - measure_begin_time
                total_measure_s += measure_delta.seconds + (measure_delta.microseconds * 0.000001)

                # sleep until the next slot of the fixed-rate schedule; never a negative amount
                sleep_until = start_time + num_sample_events * sample_period_s
                time.sleep(max(0, sleep_until - time.time()))

            print('finished.')
            print("")
        except KeyboardInterrupt:
            # report on what was captured so far, using the time actually sampled
            sample_seconds = time.time() - start_time
            print('sampling interrupted.')
            print("")


    ### Query and report ###
    for r in reports:
        r.output_report(conn)

    if args.input_sample_db:
        print('samples: %s' % num_sample_events)
    else:
        print('samples: %s (expected: %s)' % (num_sample_events, int(sample_seconds * args.sample_hz)))
        if num_sample_events < 10:
            final_warnings.append('Warning: less than 10 samples captured. Reduce the amount of processes monitored or run pSnapper for longer.')

    # process/thread totals are counted from the table of the first source
    first_table_name = list(sources.keys())[0].name
    print('total processes: %s, threads: %s' % (conn.execute('SELECT COUNT(DISTINCT(pid)) FROM ' + first_table_name).fetchone()[0], conn.execute('SELECT COUNT(DISTINCT(task)) FROM ' + first_table_name).fetchone()[0]))
    print('runtime: %.2f, measure time: %.2f' % (time.time() - start_time, total_measure_s))
    print("")

    # TODO needs some improvement as some file IO errors come from processes/threads exiting
    # before we get to sample all files of interest under the PID
    if sample_ioerrors >= 100:
        final_warnings.append('Warning: %d /proc file accesses failed. Run as root or avoid restricted\n         proc-files like "syscall" or measure only your own processes' % sample_ioerrors)

    if final_warnings:
        print('\n'.join(final_warnings))
        print("")


## dump python self-profiling data on exit
def exit(r):
    """Terminate with status r, printing the profiler summary first when --profile is set.

    The summary is read from the 'psn-prof.dmp' dump file and shows the top
    30 entries sorted by cumulative time.
    """
    if not args.profile:
        sys.exit(r)

    import pstats
    profile_stats = pstats.Stats('psn-prof.dmp')
    profile_stats.sort_stats('cumulative')
    profile_stats.print_stats(30)
    sys.exit(r)


## let's do this!
if not args.profile:
    main()
else:
    # run main() under the profiler and dump the stats where exit() expects to find them
    import cProfile
    cProfile.run('main()', 'psn-prof.dmp')

exit(0)