#! /usr/bin/env python
# -*- Mode: python; tab-width: 4; indent-tabs-mode: t -*-
#
# This file is part of the LibreOffice project.
#
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
#

"""
This script is to fix precompiled headers.

This script runs in two modes.
In one mode, it starts with a header
that doesn't compile. It finds the
minimum number of includes to remove
from the header to get a successful
run of the command (i.e. a compile).

In the second mode, it starts with a
header that compiles fine but contains
one or more required includes, without
which it wouldn't compile; the script
identifies those includes.

Usage: ./bin/update_pch_bisect ./vcl/inc/pch/precompiled_vcl.hxx "make vcl.build" --find-required --verbose
"""

import sys
import re
import os
import unittest
import subprocess

# When True, log() prints nothing; main() clears it for --verbose.
SILENT = True
# Bisection mode: True searches for conflicting includes
# (--find-conflicts, the default); False searches for required
# includes (--find-required).
FIND_CONFLICTS = True

# Per-line marks. update_pch() writes a line unchanged when its mark is
# <= TEST_ON and prefixes it with '//' otherwise, so the numeric order
# of these values is significant.
# IGNORE: not an '#include' line; never toggled by the bisection.
IGNORE = 0
# GOOD: include left enabled; in conflicts mode it built fine, in
# required mode it was found to be required.
GOOD = 1
# TEST_ON: include enabled for the trial in progress.
TEST_ON = 2
# TEST_OFF: include commented out for the trial in progress.
TEST_OFF = 3
# BAD: include left commented out; in conflicts mode it is the
# conflicting include, in required mode it is not needed.
BAD = 4

def run(command):
    """ Run `command` quietly and report whether it succeeded.

        The command string is tokenized with shell-like quoting rules,
        so every whitespace-separated word becomes its own argument and
        quoted arguments may contain spaces. All output of the child
        process is discarded.

        Returns True when the command exits with status 0. Returns
        False on a non-zero exit status, or when the command cannot be
        started (the error is then written to stderr).
    """
    import shlex

    try:
        cmd = shlex.split(command)
        # Bisecting calls this many times; the context manager makes
        # sure the null-device handle is closed after every run.
        with open(os.devnull, 'w') as devnull:
            status = subprocess.call(cmd, stdout=devnull,
                                     stderr=subprocess.STDOUT, close_fds=True)
        return status == 0
    except Exception as e:
        sys.stderr.write('Error: {}\n'.format(e))
        return False

def update_pch(filename, lines, marks):
    """ Rewrite `filename` from `lines` according to `marks`.

        A line whose mark is IGNORE, GOOD or TEST_ON (i.e. <= TEST_ON)
        is written unchanged; any other line is written prefixed with
        '//' so that it is commented out.

        filename -- path of the header to overwrite.
        lines    -- the header's lines, including line terminators.
        marks    -- one mark per line, in the same order as `lines`.
    """
    with open(filename, 'w') as f:
        for line, mark in zip(lines, marks):
            if mark <= TEST_ON:
                f.write(line)
            else:
                f.write('//' + line)

def log(*args, **kwargs):
    """ Forward the arguments to print() unless SILENT is set. """
    if SILENT:
        return
    print(*args, **kwargs)

def bisect(lines, marks, min, max, update, command):
    """ Recursively narrow down which includes in [min, max) matter.

        Marks every non-IGNORE line of the range for testing (enabled
        when finding conflicts, disabled when finding required
        includes), applies the marks with `update` and runs `command`.
        When the command fails, the range is split in half and each
        half is bisected in turn until a single line is isolated.

        lines   -- the header's lines.
        marks   -- one mark per line; modified in place.
        min     -- index of the first line of the range.
        max     -- index one past the last line of the range.
        update  -- callable(lines, marks) that applies the marks.
        command -- callable() returning True on success.

        Returns `marks`.
    """
    log('Bisecting [{}, {}].'.format(min+1, max))
    for i in range(min, max):
        if marks[i] != IGNORE:
            marks[i] = TEST_ON if FIND_CONFLICTS else TEST_OFF

    assume_fail = False
    if not FIND_CONFLICTS:
        # With no include enabled at all there is nothing to build
        # with, so skip the command and treat the trial as failed.
        on_list = [x for x in marks if x in (TEST_ON, GOOD)]
        assume_fail = (len(on_list) == 0)

    update(lines, marks)
    if assume_fail or not command():
        # Failed
        log('Failed [{}, {}].'.format(min+1, max))
        if min >= max - 1:
            if not FIND_CONFLICTS:
                # Try with this one alone.
                marks[min] = TEST_ON
                update(lines, marks)
                if command():
                    log(' Found @{}: {}'.format(min+1, lines[min].strip('\n')))
                    marks[min] = GOOD
                    return marks
            else:
                log(' Found @{}: {}'.format(min+1, lines[min].strip('\n')))
            # Conflicts mode: this is the conflicting include.
            # Required mode: enabling it did not help, so it is not needed.
            marks[min] = BAD
            return marks

        # Bisect: restore the range to its neutral state first, so that
        # each half is tested against the other half's baseline.
        for i in range(min, max):
            if marks[i] != IGNORE:
                marks[i] = TEST_OFF if FIND_CONFLICTS else TEST_ON

        # Floor division keeps the midpoint an integer index; true
        # division would produce a float that range() rejects.
        half = min + ((max - min) // 2)
        marks = bisect(lines, marks, min, half, update, command)
        marks = bisect(lines, marks, half, max, update, command)
    else:
        # Success
        if FIND_CONFLICTS:
            log(' Good [{}, {}].'.format(min+1, max))
            for i in range(min, max):
                if marks[i] != IGNORE:
                    marks[i] = GOOD

    return marks

def get_filename(line):
    """ Reduce an '#include <name>' line to just `name`.

        Text that does not contain an angle-bracket include is
        returned unchanged.
    """
    include_pattern = r'(.*#include\s*)<(.*)>(.*)'
    bracket_contents = r'\2'
    return re.compile(include_pattern).sub(bracket_contents, line)

def get_marks(lines):
    """ Build the initial mark list for `lines`.

        Every line starting with '#include' is marked TEST_ON, every
        other line IGNORE.

        Returns a tuple (marks, first, end) where `first` is the index
        of the first include line and `end` is one past the index of
        the last include line. When there is no include line at all,
        `first` is -1 and `end` is 0.
    """
    marks = []
    first = -1
    last = -1
    for i, line in enumerate(lines):
        if line.startswith('#include'):
            marks.append(TEST_ON)
            if first < 0:
                first = i
            last = i
        else:
            marks.append(IGNORE)

    return (marks, first, last + 1)

def main():
    """ Bisect the includes of the header named on the command line.

        Reads the header path from sys.argv[1], the build command from
        sys.argv[2] and the optional flags (--find-conflicts,
        --find-required, --verbose) from the remaining arguments.

        The header is rewritten in place for every trial. When a
        precondition check fails, it is left as last written. On
        success the resulting includes are printed, one per line.

        Returns 0 on success and 1 on an unknown option or a failed
        precondition.
    """
    global FIND_CONFLICTS
    global SILENT

    filename = sys.argv[1]
    command = sys.argv[2]

    for opt in sys.argv[3:]:
        if opt == '--find-conflicts':
            FIND_CONFLICTS = True
        elif opt == '--find-required':
            FIND_CONFLICTS = False
        elif opt == '--verbose':
            SILENT = False
        else:
            sys.stderr.write('Error: Unknown option [{}].\n'.format(opt))
            return 1

    with open(filename) as f:
        lines = f.readlines()

    # `end` is exclusive: it is already one past the last include line.
    (marks, first, end) = get_marks(lines)

    # Test preconditions.
    log('Validating all-excluded state...')
    for i in range(first, end):
        if marks[i] != IGNORE:
            marks[i] = TEST_OFF
    update_pch(filename, lines, marks)
    res = run(command)

    if FIND_CONFLICTS:
        # Must build all excluded.
        if not res:
            sys.stderr.write("Error: broken state when all excluded, fix first and try again.")
            return 1
    else:
        # If builds all excluded, we can't bisect.
        if res:
            sys.stderr.write("Done: in good state when all excluded, nothing to do.")
            return 1

        # Must build all included.
        log('Validating all-included state...')
        for i in range(first, end):
            if marks[i] != IGNORE:
                marks[i] = TEST_ON
        update_pch(filename, lines, marks)
        if not run(command):
            sys.stderr.write("Error: broken state without modifying, fix first and try again.")
            return 1

    # Pass `end` as is: bisect() treats its upper bound as exclusive,
    # so adding one would reach past the include range (and past the
    # end of `marks` when the last include is the last line).
    marks = bisect(lines, marks, first, end,
                   lambda l, m: update_pch(filename, l, m),
                   lambda: run(command))
    if not FIND_CONFLICTS:
        # Simplify further, as sometimes we can have
        # false positives due to the benign nature
        # of includes that are not absolutely required.
        for i in range(len(marks)):
            if marks[i] == GOOD:
                # Re-test each candidate by disabling it on its own.
                marks[i] = TEST_OFF
                update_pch(filename, lines, marks)
                if not run(command):
                    # Revert.
                    marks[i] = GOOD
                else:
                    marks[i] = BAD
            elif marks[i] == TEST_OFF:
                marks[i] = TEST_ON

    update_pch(filename, lines, marks)

    log('')
    wanted = BAD if FIND_CONFLICTS else GOOD
    for line, mark in zip(lines, marks):
        if mark == wanted:
            print("'{}',".format(get_filename(line.strip('\n'))))

    return 0

if __name__ == '__main__':

    # Two mandatory arguments plus up to two options: run the tool and
    # exit with its status.
    if 3 <= len(sys.argv) <= 5:
        sys.exit(main())

    # Any other argument count: show the usage text, then fall through
    # to the unit tests defined below.
    usage_lines = (
        'Usage: {} <pch> <command> [--find-conflicts]|[--find-required] [--verbose]\n'.format(sys.argv[0]),
        '    --find-conflicts - Finds all conflicting includes. (Default)',
        '                       Must compile without any includes.\n',
        '    --find-required - Finds all required includes.',
        '                      Must compile with all includes.\n',
        '    --verbose - print noisy progress.',
        'Example: ./bin/update_pch_bisect ./vcl/inc/pch/precompiled_vcl.hxx "make vcl.build" --find-required --verbose',
        '\nRunning unit-tests...',
    )
    for usage_line in usage_lines:
        print(usage_line)


class TestBisectConflict(unittest.TestCase):
    """ Checks bisect() in find-conflicts mode.

        The "build" is simulated: it fails for as long as BAD_LINE is
        present, uncommented, among the lines.
    """
    # Sample header the conflicting include is inserted into.
    TEST = """ /* Test header. */
#include <memory>
#include <set>
#include <algorithm>
#include <vector>
/* blah blah */
"""
    # The include that makes the simulated build fail.
    BAD_LINE = "#include <bad>"

    def setUp(self):
        global FIND_CONFLICTS
        FIND_CONFLICTS = True

    def _update_func(self, lines, marks):
        """ Update function called by bisect.
            Keeps the marked-up lines in memory, commenting out
            the disabled ones as update_pch does on disk.
        """
        self.lines = []
        for line, mark in zip(lines, marks):
            if mark <= TEST_ON:
                self.lines.append(line)
            else:
                self.lines.append('//' + line)

    def _test_func(self):
        """ Command function called by bisect.
            Returns True on Success, False on failure.
        """
        # If the bad line is still there, fail.
        return self.BAD_LINE not in self.lines

    def test_success(self):
        """ A header without the bad line yields no BAD mark. """
        lines = self.TEST.split('\n')
        (marks, min, max) = get_marks(lines)
        marks = bisect(lines, marks, min, max,
                       lambda l, m: self._update_func(l, m),
                       lambda: self._test_func())
        self.assertTrue(BAD not in marks)

    def test_conflict(self):
        """ The bad line is the only BAD mark, wherever it is inserted. """
        lines = self.TEST.split('\n')
        for pos in range(len(lines) + 1):
            lines = self.TEST.split('\n')
            lines.insert(pos, self.BAD_LINE)
            (marks, min, max) = get_marks(lines)

            marks = bisect(lines, marks, min, max,
                           lambda l, m: self._update_func(l, m),
                           lambda: self._test_func())
            for i, mark in enumerate(marks):
                if i == pos:
                    self.assertEqual(BAD, mark)
                else:
                    self.assertNotEqual(BAD, mark)

class TestBisectRequired(unittest.TestCase):
    """ Checks bisect() in find-required mode.

        The "build" is simulated: it succeeds only while REQ_LINE is
        present, uncommented, among the lines.
    """
    # Sample header the required include is inserted into.
    TEST = """#include <algorithm>
#include <set>
#include <map>
#include <vector>
"""
    # The include without which the simulated build fails.
    REQ_LINE = "#include <req>"

    def setUp(self):
        global FIND_CONFLICTS
        FIND_CONFLICTS = False

    def _update_func(self, lines, marks):
        """ Update function called by bisect.
            Keeps the marked-up lines in memory, commenting out
            the disabled ones as update_pch does on disk.
        """
        self.lines = []
        for line, mark in zip(lines, marks):
            if mark <= TEST_ON:
                self.lines.append(line)
            else:
                self.lines.append('//' + line)

    def _test_func(self):
        """ Command function called by bisect.
            Returns True on Success, False on failure.
        """
        # If the required line is not there, fail.
        found = self.REQ_LINE in self.lines
        return found

    def test_success(self):
        """ A header without the required line yields no GOOD mark. """
        lines = self.TEST.split('\n')
        (marks, min, max) = get_marks(lines)
        marks = bisect(lines, marks, min, max,
                       lambda l, m: self._update_func(l, m),
                       lambda: self._test_func())
        self.assertTrue(GOOD not in marks)

    def test_required(self):
        """ The required line is the only GOOD mark, wherever it is inserted. """
        lines = self.TEST.split('\n')
        for pos in range(len(lines) + 1):
            lines = self.TEST.split('\n')
            lines.insert(pos, self.REQ_LINE)
            (marks, min, max) = get_marks(lines)

            marks = bisect(lines, marks, min, max,
                           lambda l, m: self._update_func(l, m),
                           lambda: self._test_func())
            for i, mark in enumerate(marks):
                if i == pos:
                    self.assertEqual(GOOD, mark)
                else:
                    self.assertNotEqual(GOOD, mark)

# Run the unit tests only when executed as a script, so that importing
# this file neither runs the tests nor exits the interpreter. A valid
# tool invocation has already exited before this point is reached.
if __name__ == '__main__':
    unittest.main()

# vim: set et sw=4 ts=4 expandtab: