Diffstat
-rwxr-xr-x  bin/psn  384
1 file changed, 384 insertions, 0 deletions
@@ -0,0 +1,384 @@
#!/usr/bin/env python
#
# psn -- Linux Process Snapper by Tanel Poder [https://0x.tools]
# Copyright 2019-2021 Tanel Poder
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
# SPDX-License-Identifier: GPL-2.0-or-later
#
# NOTES
#
# If you have only the python3 binary on your system (no "python" command
# without a suffix), set up the OS default alternative or a symlink so that
# a "python" command exists.
#
# For example (on RHEL 8):
#     sudo alternatives --set python /usr/bin/python3
#
# One exception is RHEL 5: use the following shebang line, without quotes
# (if you have the additional python26 package installed from EPEL):
#
#     "#!/usr/bin/env python26"
#

PSN_VERSION = '1.2.6'

import sys, os, os.path, time, datetime
import re
import sqlite3
import logging
from signal import signal, SIGPIPE, SIG_DFL

sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + '/lib/0xtools')

import argparse
import psnproc as proc, psnreport

# don't print a "broken pipe" error when piped into commands like "head"
signal(SIGPIPE, SIG_DFL)

### Defaults ###
default_group_by = 'comm,state'

### CLI handling ###
parser = argparse.ArgumentParser()
parser.add_argument('-d', dest="sample_seconds", metavar='seconds', type=int, default=5, help='number of seconds to sample for')
parser.add_argument('-p', '--pid', metavar='pid', default=None, nargs='?', help='process id to sample (including all its threads), or process name regex, or omit for system-wide sampling')
# TODO implement -t filtering below
parser.add_argument('-t', '--thread', metavar='tid', default=None, nargs='?', help='thread/task id to sample (not implemented yet)')

parser.add_argument('-r', '--recursive', default=False, action='store_true', help='also sample and report for descendant processes')
parser.add_argument('-a', '--all-states', default=False, action='store_true', help='display threads in all states, including idle ones')

parser.add_argument('--sample-hz', default=20, type=int, help='sample rate in Hz (default: %(default)d)')
parser.add_argument('--ps-hz', default=2, type=int, help='sample rate of new processes in Hz (default: %(default)d)')

parser.add_argument('-o', '--output-sample-db', metavar='filename', default=':memory:', type=str, help='path of sqlite3 database to persist samples to, defaults to in-memory/transient')
parser.add_argument('-i', '--input-sample-db', metavar='filename', default=None, type=str, help='path of sqlite3 database to read samples from instead of actively sampling')

parser.add_argument('-s', '--select', metavar='csv-columns', default='', help='additional columns to report')
parser.add_argument('-g', '--group-by', metavar='csv-columns', default=default_group_by, help='columns to aggregate by in reports')
parser.add_argument('-G', '--append-group-by', metavar='csv-columns', default=None, help='default + additional columns to aggregate by in reports')
#parser.add_argument('-f', '--filter', metavar='filter-sql', default='1=1', help="sample schema SQL to apply to filter report. ('active' and 'idle' keywords may also be used)")

parser.add_argument('--list', default=None, action='store_true', help='list all available columns')

# specify csv sources to be captured in full. use with -o or --output-sample-db to capture proc data for manual analysis
parser.add_argument('--sources', metavar='csv-source-names', default='', help=argparse.SUPPRESS)
# capture all sources in full. use with -o or --output-sample-db
parser.add_argument('--all-sources', default=False, action='store_true', help=argparse.SUPPRESS)

parser.add_argument('--debug', default=False, action='store_true', help=argparse.SUPPRESS)
parser.add_argument('--profile', default=False, action='store_true', help=argparse.SUPPRESS)
parser.add_argument('--show-yourself', default=False, action='store_true', help=argparse.SUPPRESS)

args = parser.parse_args()
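
# Illustrative invocations (flags as defined above; extra column names for
# -s/-g come from whatever "psn --list" prints on your system):
#
#   psn                      # system-wide sample for 5 seconds, grouped by comm,state
#   psn -p 1234 -r -d 10     # sample pid 1234 plus its descendants for 10 seconds
#   psn -o samples.db        # persist raw samples into an on-disk sqlite3 database
#   psn -i samples.db        # report from a previously captured sample database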

### Set up Logging ###
if args.debug:
    logging.basicConfig(stream=sys.stderr, level=logging.DEBUG, format='%(asctime)s %(message)s')
else:
    logging.basicConfig(stream=sys.stderr, level=logging.INFO, format='%(asctime)s %(message)s')

start_time = time.time()
sample_period_s = 1. / args.sample_hz
sample_ps_period_s = 1. / args.ps_hz

if args.list:
    for p in proc.all_sources:
        print("\n" + p.name)
        print("====================================")
        for c in p.available_columns:
            if c[1]:
                print("%-30s %5s" % (c[0], c[1].__name__))

    exit(0)

if args.all_sources:
    args.sources = proc.all_sources
else:
    args.sources = [s for s in proc.all_sources if s.name in args.sources.split(',')]

args.select = args.select.split(',')
args.group_by = args.group_by.split(',')
if args.append_group_by:
    args.append_group_by = args.append_group_by.split(',')
else:
    args.append_group_by = []

if not args.all_states:
    included_states = 'active'
else:
    included_states = None

### Report definition ###

active_process_sample = psnreport.Report(
    'Active Threads',
    projection=args.select + ['samples', 'avg_threads'],
    dimensions=args.group_by + args.append_group_by,
    order=['samples'],
    where=[included_states]
)

system_activity_by_user = psnreport.Report(
    'Active Threads by User',
    projection=args.select + ['uid', 'comm', 'samples', 'avg_threads'],
    dimensions=args.group_by + args.append_group_by,
)

reports = [active_process_sample]  # , system_activity_by_user]


# consolidate sources required for all reports
sources = {}  # source -> set(cols)
for r in reports:
    for s, cs in r.sources.items():
        source_cols = sources.get(s, set())
        source_cols.update(cs)
        sources[s] = source_cols

# add sources requested for full sampling
for full_source in args.sources:
    sources[full_source] = set([c[0] for c in full_source.available_columns])

# update sources with the set of required columns
for s, cs in sources.items():
    s.set_stored_columns(cs)


def sqlexec(conn, sql):
    logging.debug(sql)
    result = conn.execute(sql)  # return the cursor so callers can fetch rows
    logging.debug('Done')
    return result

def sqlexecmany(conn, sql, samples):
    logging.debug(sql)
    conn.executemany(sql, samples)
    logging.debug('Done')

### Schema setup ###
if args.input_sample_db:
    conn = sqlite3.connect(args.input_sample_db)
else:
    conn = sqlite3.connect(args.output_sample_db)
    sqlexec(conn, 'pragma synchronous=0')

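    # create_table() derives DDL from each source's schema_columns; for an
    # illustrative source 'stat' storing (event_time, pid, task, state) it emits:
    #   CREATE TABLE IF NOT EXISTS 'stat' (event_time TEXT, pid INTEGER, task INTEGER, state TEXT)
    #   CREATE INDEX IF NOT EXISTS 'stat_event_idx' ON stat (event_time, pid, task)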
    def create_table(conn, s):
        def sqlite_type(python_type):
            return {int: 'INTEGER', float: 'REAL', str: 'TEXT'}.get(python_type)

        sqlite_cols = [(c[0], sqlite_type(c[1])) for c in s.schema_columns]
        sqlite_cols_sql = ', '.join([' '.join(c) for c in sqlite_cols])
        sql = "CREATE TABLE IF NOT EXISTS '%s' (%s)" % (s.name, sqlite_cols_sql)
        sqlexec(conn, sql)

        sql = "CREATE INDEX IF NOT EXISTS '%s_event_idx' ON %s (event_time, pid, task)" % (s.name, s.name)
        sqlexec(conn, sql)

    for s in sources:
        create_table(conn, s)


### Sampling utility functions ###
def get_matching_processes(pid_arg=None, recursive=False):
    # pid_arg can be a single pid, comma-separated pids, or a regex on the process executable name

    process_children = {}  # pid -> [children-pids]
    pid_basename = []

    # TODO use os.listdir() when -r or a process name regex is not needed
    for line in os.popen('ps -A -o pid,ppid,comm', 'r').readlines()[1:]:
        tokens = line.split()
        pid = int(tokens[0])
        ppid = int(tokens[1])

        children = process_children.get(ppid, [])
        children.append(pid)
        process_children[ppid] = children

        path_cmd = tokens[2]
        #cmd = os.path.basename(path_cmd)
        cmd = path_cmd
        pid_basename.append((pid, cmd))

    if pid_arg:
        try:
            arg_pids = [int(p) for p in pid_arg.split(',')]
            selected_pids = [p for p, b in pid_basename if p in arg_pids]
        except ValueError:
            selected_pids = [p for p, b in pid_basename if re.search(pid_arg, b)]

        # recursive pid walking is not needed when looking for all system pids anyway
        if recursive:
            i = 0
            while i < len(selected_pids):
                children = process_children.get(selected_pids[i], [])
                selected_pids.extend(children)
                i += 1
    else:
        selected_pids = [p for p, b in pid_basename]

    # deduplicate pids & remove the pSnapper pid itself (pSnapper doesn't consume any resources when it's not running)
    selected_pids = list(set(selected_pids))
    if not args.show_yourself:
        if os.getpid() in selected_pids:
            selected_pids.remove(os.getpid())
    return selected_pids


def get_process_tasks(pids):
    tasks_by_pid = {}
    for pid in pids:
        try:
            tasks_by_pid[pid] = os.listdir('/proc/%s/task' % pid)
        except OSError:
            pass  # the process may no longer exist by the time we get here
    return tasks_by_pid


def get_utc_now():
    if sys.version_info >= (3, 2):
        return datetime.datetime.now(datetime.timezone.utc)
    else:
        return datetime.datetime.utcnow()
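
# Pacing note: the loop below sleeps until the next absolute tick
# (start_time + n * sample_period_s) rather than for a fixed interval, so
# per-iteration measurement overhead does not accumulate as timing drift.
# The pid list itself is refreshed only at --ps-hz, since forking "ps" is
# far more expensive than reading a few /proc files.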

### Main sampling loop ###
def main():
    final_warning = ''   # warning text to show at the end of the output, if any
    sample_seconds = args.sample_seconds
    sample_ioerrors = 0  # counts failed /proc accesses (typically permission issues)

    if args.input_sample_db:
        num_sample_events = int(sqlexec(conn, 'SELECT COUNT(DISTINCT(event_time)) FROM stat').fetchone()[0])
        total_measure_s = 0.
    else:
        print("")
        print('Linux Process Snapper v%s by Tanel Poder [https://0x.tools]' % PSN_VERSION)
        print('Sampling /proc/%s for %d seconds...' % (', '.join([s.name for s in sources.keys()]), args.sample_seconds), end=' ')
        sys.stdout.flush()
        num_sample_events = 0
        num_ps_samples = 0
        total_measure_s = 0.

        selected_pids = None
        process_tasks = {}

        try:
            while time.time() < start_time + args.sample_seconds:
                measure_begin_time = get_utc_now()
                event_time = measure_begin_time.isoformat()

                # refresh matching pids at a much lower frequency than process sampling, as the underlying ps is expensive
                if not selected_pids or time.time() > start_time + num_ps_samples * sample_ps_period_s:
                    selected_pids = get_matching_processes(args.pid, args.recursive)
                    process_tasks = get_process_tasks(selected_pids)
                    num_ps_samples += 1

                if not selected_pids:
                    print('No matching processes found:', args.pid)
                    sys.exit(1)

                for pid in selected_pids:
                    try:
                        # if any process-level samples fail, don't insert any sample event rows for the process or its tasks...
                        process_samples = [s.sample(event_time, pid, pid) for s in sources.keys() if not s.task_level]

                        for s, samples in zip([s for s in sources.keys() if not s.task_level], process_samples):
                            sqlexecmany(conn, s.insert_sql, samples)

                        for task in process_tasks.get(pid, []):
                            try:
                                # ...but if a task disappears mid-sample, simply discard data for that task only
                                task_samples = [s.sample(event_time, pid, task) for s in sources.keys() if s.task_level]

                                for s, samples in zip([s for s in sources.keys() if s.task_level], task_samples):
                                    # TODO factor out into sqlexecmany()
                                    conn.executemany(s.insert_sql, samples)

                            except IOError as e:
                                sample_ioerrors += 1
                                logging.debug(e)
                                continue

                    except IOError as e:
                        sample_ioerrors += 1
                        logging.debug(e)
                        continue

                conn.commit()

                num_sample_events += 1
                measure_delta = get_utc_now() - measure_begin_time
                total_measure_s += measure_delta.seconds + (measure_delta.microseconds * 0.000001)

                sleep_until = start_time + num_sample_events * sample_period_s
                time.sleep(max(0, sleep_until - time.time()))

            print('finished.')
            print()
        except KeyboardInterrupt:
            sample_seconds = time.time() - start_time
            print('sampling interrupted.')
            print()

    ### Query and report ###
    for r in reports:
        r.output_report(conn)

    print('samples: %s' % num_sample_events, end=' ')
    if not args.input_sample_db:
        print('(expected: %s)' % int(sample_seconds * args.sample_hz))
        if num_sample_events < 10:
            final_warning += 'Warning: less than 10 samples captured. Reduce the number of processes monitored or run pSnapper for longer.'
    else:
        print()

    first_table_name = list(sources.keys())[0].name
    print('total processes: %s, threads: %s' % (conn.execute('SELECT COUNT(DISTINCT(pid)) FROM ' + first_table_name).fetchone()[0],
                                                conn.execute('SELECT COUNT(DISTINCT(task)) FROM ' + first_table_name).fetchone()[0]))
    print('runtime: %.2f, measure time: %.2f' % (time.time() - start_time, total_measure_s))
    print()

    # TODO needs some improvement as some file IO errors come from processes/threads exiting
    # before we get to sample all files of interest under the PID
    if sample_ioerrors >= 100:
        final_warning = 'Warning: %d /proc file accesses failed. Run as root or avoid restricted\n         proc-files like "syscall", or measure only your own processes' % sample_ioerrors

    if final_warning:
        print(final_warning)
        print()


## dump python self-profiling data on exit
def exit(r):
    if args.profile:
        import pstats
        p = pstats.Stats('psn-prof.dmp')
        p.sort_stats('cumulative').print_stats(30)
    sys.exit(r)
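
# With --profile, main() runs under cProfile below and its data lands in
# psn-prof.dmp; exit() above then prints the top 30 calls by cumulative time.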

## let's do this!
if args.profile:
    import cProfile
    cProfile.run('main()', 'psn-prof.dmp')
else:
    main()

exit(0)
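
The script's psnproc dependency is not part of this commit, but the interface it must provide can be read off the usage above. A minimal stand-in, inferred purely from this file (the class name and the nr_threads column are invented for illustration; the real psnproc sources are richer):

    # minimal sketch of a proc source as this script consumes it
    class StubSource:
        name = 'stub'
        task_level = False  # False: sampled once per process; True: once per task/thread
        available_columns = [('event_time', str), ('pid', int), ('task', int), ('nr_threads', int)]

        def set_stored_columns(self, cols):
            # the script passes in the union of columns needed by all reports
            self.schema_columns = [c for c in self.available_columns if c[0] in cols]
            names = ', '.join(c[0] for c in self.schema_columns)
            binds = ', '.join('?' * len(self.schema_columns))
            self.insert_sql = "INSERT INTO '%s' (%s) VALUES (%s)" % (self.name, names, binds)

        def sample(self, event_time, pid, task):
            # must return a list of row tuples matching schema_columns
            return [(event_time, pid, task, 1)]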
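
For context on what a task-level sample() call ultimately does, here is a minimal sketch of reading one thread's state from /proc (field layout per proc(5); the real psnproc readers cover many more files and columns):

    def read_task_state(pid, task):
        # comm may contain spaces or parens, so anchor on the last ')' (see proc(5))
        with open('/proc/%d/task/%d/stat' % (pid, task)) as f:
            data = f.read()
        comm = data[data.index('(') + 1:data.rindex(')')]
        state = data[data.rindex(')') + 2:].split()[0]  # R, S, D, Z, ...
        return comm, state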
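
Finally, since -o writes samples into an ordinary sqlite3 file, a capture can be examined without pSnapper at all. A sketch, assuming a prior run such as psn -o samples.db with the stat source stored (only event_time, pid and task are implied by the index DDL above; other column names vary by source):

    import sqlite3

    conn = sqlite3.connect('samples.db')
    row = conn.execute("""SELECT COUNT(DISTINCT event_time),
                                 COUNT(DISTINCT pid),
                                 COUNT(DISTINCT task)
                          FROM stat""").fetchone()
    print('samples: %d  processes: %d  threads: %d' % row)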