summaryrefslogtreecommitdiffstats
path: root/src/boost/libs/mpi/test/python/nonblocking_test.py
blob: 73b451c53f5f4f6c9848399d504bc80b065ddf02 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# (C) Copyright 2007 
# Andreas Kloeckner <inform -at- tiker.net>
#
# Use, modification and distribution is subject to the Boost Software
# License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
# http://www.boost.org/LICENSE_1_0.txt)
#
#  Authors: Andreas Kloeckner




import boost.mpi as mpi
import random
import sys

# Once a history list has reached this length, the worker holding it sends
# it to rank 0 instead of to another randomly chosen worker.
MAX_GENERATIONS = 20
# Message tags used by this test.
# TAG_DEBUG: payload is printed by rank 0 together with the sender's rank.
TAG_DEBUG = 0
# TAG_DATA: payload is a history, i.e. a list of the ranks it has visited.
TAG_DATA = 1
# TAG_TERMINATE: sent by rank 0 to stop a worker; the worker sends the same
# tag back to rank 0 before leaving its receive loop.
TAG_TERMINATE = 2
# TAG_PROGRESS_REPORT: sent by a worker to rank 0 for every history received.
TAG_PROGRESS_REPORT = 3




class TagGroupListener:
    """Class to help listen for only a given set of tags.

    One non-blocking receive is kept outstanding per tag; wait() blocks
    until any one of them completes and hands back its status and data.

    This is contrived: typically you could just listen for
    mpi.any_tag and filter.
    """

    def __init__(self, comm, tags):
        """Store the communicator and the tags to listen for.

        comm -- object whose irecv(tag=...) posts a non-blocking receive.
        tags -- iterable of the tags wait() should listen for.
        """
        self.tags = tags
        self.comm = comm
        # Maps tag -> receive request currently outstanding for that tag.
        self.active_requests = {}

    def wait(self):
        """Block until a message with one of the tags has been received.

        Returns a (status, data) tuple for the receive that completed.
        """
        # Post a receive for every tag that has none outstanding; the one
        # consumed by the previous call is re-posted here.
        for tag in self.tags:
            if tag not in self.active_requests:
                self.active_requests[tag] = self.comm.irecv(tag=tag)
        # Materialise the values so that a real list is passed on both
        # Python 2 and Python 3 (where values() is only a view).
        requests = mpi.RequestList(list(self.active_requests.values()))
        data, status, _index = mpi.wait_any(requests)
        # The completed request is identified by its tag, not by its
        # position in the request list.
        del self.active_requests[status.tag]
        return status, data

    def cancel(self):
        """Cancel every outstanding receive and forget about all of them.

        The cancelled requests are not waited on afterwards.
        """
        # values() exists on both Python 2 and 3; itervalues() is 2-only.
        for request in self.active_requests.values():
            request.cancel()
        self.active_requests = {}



def rank0():
    """Run the coordinating (rank 0) side of the test.

    Sends (mpi.size-1)*15 empty histories to randomly chosen workers,
    then processes incoming messages by tag until the completion
    condition checked by is_complete() holds, and finally prints "OK".

    The print calls use the single-argument parenthesised form, which
    behaves the same under Python 2 and Python 3.
    """
    sent_histories = (mpi.size-1)*15
    print("sending %d packets on their way" % sent_histories)
    send_reqs = mpi.RequestList()
    for _ in range(sent_histories):
        dest = random.randrange(1, mpi.size)
        send_reqs.append(mpi.world.isend(dest, TAG_DATA, []))

    mpi.wait_all(send_reqs)

    # Histories that have come back to rank 0.
    completed_histories = []
    # Maps history length -> number of progress reports seen with that length.
    progress_reports = {}
    # Ranks of the workers that have answered the terminate message.
    dead_kids = []

    tgl = TagGroupListener(mpi.world,
            [TAG_DATA, TAG_DEBUG, TAG_PROGRESS_REPORT, TAG_TERMINATE])

    def is_complete():
        """Return True once all reports and all terminations are in.

        Every recorded history length must have been reported exactly
        sent_histories times, and every worker must have answered the
        terminate message.
        """
        if not all(count == sent_histories
                   for count in progress_reports.values()):
            return False
        return len(dead_kids) == mpi.size-1

    while True:
        status, data = tgl.wait()

        if status.tag == TAG_DATA:
            completed_histories.append(data)
            if len(completed_histories) == sent_histories:
                # Every history is back: tell all workers to shut down.
                print("all histories received, exiting")
                for rank in range(1, mpi.size):
                    mpi.world.send(rank, TAG_TERMINATE, None)
        elif status.tag == TAG_PROGRESS_REPORT:
            progress_reports[len(data)] = progress_reports.get(len(data), 0) + 1
        elif status.tag == TAG_DEBUG:
            print("[DBG %d] %s" % (status.source, data))
        elif status.tag == TAG_TERMINATE:
            dead_kids.append(status.source)
        else:
            print("unexpected tag %d from %d" % (status.tag, status.source))

        if is_complete():
            break

    print("OK")

def comm_rank():
    """Run the worker (non-zero rank) side of the test.

    Receives messages in a loop. For a TAG_DATA message the history is
    reported to rank 0 as received, this rank is appended to it, and it
    is forwarded: to rank 0 once it holds at least MAX_GENERATIONS
    entries, otherwise to a randomly chosen non-zero rank. A
    TAG_TERMINATE message is answered with TAG_TERMINATE to rank 0 and
    ends the loop. Any other tag is reported on standard output.

    The print call uses the single-argument parenthesised form, which
    behaves the same under Python 2 and Python 3.
    """
    while True:
        data, status = mpi.world.recv(return_status=True)
        if status.tag == TAG_DATA:
            # Report before appending, so the report carries the history
            # as it was received.
            mpi.world.send(0, TAG_PROGRESS_REPORT, data)
            data.append(mpi.rank)
            if len(data) >= MAX_GENERATIONS:
                dest = 0
            else:
                dest = random.randrange(1, mpi.size)
            mpi.world.send(dest, TAG_DATA, data)
        elif status.tag == TAG_TERMINATE:
            mpi.world.send(0, TAG_TERMINATE, 0)
            break
        else:
            print("[DIRECTDBG %d] unexpected tag %d from %d" % (mpi.rank, status.tag, status.source))


def main():
    """Run the role that matches this process's rank.

    This program sends around messages consisting of lists of visited
    nodes randomly. After MAX_GENERATIONS, they are returned to rank 0.
    Rank 0 runs rank0(); every other rank runs comm_rank().
    """
    role = rank0 if mpi.rank == 0 else comm_rank
    role()
        


# Only run the test when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()