1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
|
# (C) Copyright 2007
# Andreas Kloeckner <inform -at- tiker.net>
#
# Use, modification and distribution is subject to the Boost Software
# License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
# http://www.boost.org/LICENSE_1_0.txt)
#
# Authors: Andreas Kloeckner
import boost.mpi as mpi
import random
import sys
# A history is sent back to rank 0 once it lists this many visited ranks.
MAX_GENERATIONS = 20
# Message tags used between rank 0 and the other ranks.
# Debug payload; rank 0 prints it with a "[DBG <source>]" prefix.
TAG_DEBUG = 0
# A history: the list of ranks the packet has visited so far.
TAG_DATA = 1
# Sent by rank 0 to stop a rank, and sent back to rank 0 by that rank in reply.
TAG_TERMINATE = 2
# Sent to rank 0 with the history as received, before it is forwarded.
TAG_PROGRESS_REPORT = 3
class TagGroupListener:
    """Class to help listen for only a given set of tags.

    One non-blocking receive is kept posted per tag; wait() returns the
    first one that completes.

    This is contrived: typically you could just listen for
    mpi.any_tag and filter.
    """

    def __init__(self, comm, tags):
        """Store the communicator and the tags to listen for.

        comm -- object whose irecv(tag=...) posts a non-blocking receive
        tags -- iterable of the tags of interest
        """
        self.tags = tags
        self.comm = comm
        # Maps tag -> outstanding receive request for that tag.
        self.active_requests = {}

    def wait(self):
        """Wait for a message carrying any of the configured tags.

        Returns a (status, data) pair for the receive that completed.
        """
        # (Re)post a receive for every tag that has none outstanding.
        for tag in self.tags:
            if tag not in self.active_requests:
                self.active_requests[tag] = self.comm.irecv(tag=tag)

        # list() hands over a real list on Python 2 and Python 3 alike.
        requests = mpi.RequestList(list(self.active_requests.values()))
        data, status, _ = mpi.wait_any(requests)
        # The completed request is used up; it is re-posted on the next call.
        del self.active_requests[status.tag]
        return status, data

    def cancel(self):
        """Cancel every outstanding receive and forget about it."""
        # values() instead of the Python-2-only itervalues().
        # NOTE: the cancelled requests are dropped without being waited on.
        for request in self.active_requests.values():
            request.cancel()
        self.active_requests = {}
def rank0():
    """Run the rank-0 side of the test.

    Sends 15 empty histories per other rank to randomly chosen ranks, then
    processes incoming messages by tag: completed histories are collected,
    progress reports are counted per history length, debug messages are
    printed, and termination replies are recorded.  Once every history has
    come back, each other rank is sent TAG_TERMINATE.

    Returns after printing "OK", once every progress-report count equals
    the number of histories sent and every other rank has replied with
    TAG_TERMINATE.
    """
    worker_count = mpi.size - 1
    sent_histories = worker_count * 15
    print("sending %d packets on their way" % sent_histories)

    send_reqs = mpi.RequestList()
    for _ in range(sent_histories):
        dest = random.randrange(1, mpi.size)
        send_reqs.append(mpi.world.isend(dest, TAG_DATA, []))
    # With a single process nothing was sent, so there is nothing to wait
    # for (waiting on an empty request list is presumably an error --
    # TODO confirm against the Boost.MPI Python bindings).
    if sent_histories:
        mpi.wait_all(send_reqs)

    completed_histories = []
    # Maps history length -> number of progress reports seen at that length.
    progress_reports = {}
    # Ranks that have replied with TAG_TERMINATE.
    dead_kids = []

    tgl = TagGroupListener(
        mpi.world,
        [TAG_DATA, TAG_DEBUG, TAG_PROGRESS_REPORT, TAG_TERMINATE])

    def is_complete():
        # Every length seen so far must have been reported once per history.
        for count in progress_reports.values():
            if count != sent_histories:
                return False
        return len(dead_kids) == worker_count

    # Checking before blocking keeps a single-process run from waiting for
    # messages nobody can send; with other ranks present the first check is
    # always false, so the loop behaves like a check-after-each-message loop.
    while not is_complete():
        status, data = tgl.wait()

        if status.tag == TAG_DATA:
            #print "received completed history %s from %d" % (data, status.source)
            completed_histories.append(data)
            if len(completed_histories) == sent_histories:
                print("all histories received, exiting")
                for rank in range(1, mpi.size):
                    mpi.world.send(rank, TAG_TERMINATE, None)
        elif status.tag == TAG_PROGRESS_REPORT:
            length = len(data)
            progress_reports[length] = progress_reports.get(length, 0) + 1
        elif status.tag == TAG_DEBUG:
            print("[DBG %d] %s" % (status.source, data))
        elif status.tag == TAG_TERMINATE:
            dead_kids.append(status.source)
        else:
            print("unexpected tag %d from %d" % (status.tag, status.source))

    print("OK")
def comm_rank():
    """Run the side of the test executed by every rank other than 0.

    Loops receiving messages.  A TAG_DATA history is first reported to
    rank 0 as received, then this rank is appended and the history is
    forwarded: to rank 0 once it holds MAX_GENERATIONS entries, otherwise
    to a randomly chosen rank in [1, mpi.size).  A TAG_TERMINATE message is
    answered with TAG_TERMINATE to rank 0 and ends the loop.  Any other
    tag is reported on standard output.
    """
    while True:
        data, status = mpi.world.recv(return_status=True)

        if status.tag == TAG_DATA:
            # Report before appending, so rank 0 sees the pre-visit length.
            mpi.world.send(0, TAG_PROGRESS_REPORT, data)
            data.append(mpi.rank)
            if len(data) >= MAX_GENERATIONS:
                dest = 0
            else:
                dest = random.randrange(1, mpi.size)
            mpi.world.send(dest, TAG_DATA, data)
        elif status.tag == TAG_TERMINATE:
            mpi.world.send(0, TAG_TERMINATE, 0)
            break
        else:
            print("[DIRECTDBG %d] unexpected tag %d from %d"
                  % (mpi.rank, status.tag, status.source))
def main():
    """Run the part of the program that belongs to this process's rank.

    This program sends around messages consisting of lists of visited nodes
    randomly. After MAX_GENERATIONS, they are returned to rank 0.
    """
    # Rank 0 runs rank0(); every other rank runs comm_rank().
    role = rank0 if mpi.rank == 0 else comm_rank
    role()


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|