diff options
Diffstat (limited to 'src/boost/libs/mpi/test/python/nonblocking_test.py')
-rw-r--r-- | src/boost/libs/mpi/test/python/nonblocking_test.py | 131 |
1 files changed, 131 insertions, 0 deletions
diff --git a/src/boost/libs/mpi/test/python/nonblocking_test.py b/src/boost/libs/mpi/test/python/nonblocking_test.py new file mode 100644 index 00000000..73b451c5 --- /dev/null +++ b/src/boost/libs/mpi/test/python/nonblocking_test.py @@ -0,0 +1,131 @@ +# (C) Copyright 2007 +# Andreas Kloeckner <inform -at- tiker.net> +# +# Use, modification and distribution is subject to the Boost Software +# License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at +# http://www.boost.org/LICENSE_1_0.txt) +# +# Authors: Andreas Kloeckner + + + + +import boost.mpi as mpi +import random +import sys + +MAX_GENERATIONS = 20 +TAG_DEBUG = 0 +TAG_DATA = 1 +TAG_TERMINATE = 2 +TAG_PROGRESS_REPORT = 3 + + + + +class TagGroupListener: + """Class to help listen for only a given set of tags. + + This is contrived: Typicallly you could just listen for + mpi.any_tag and filter.""" + def __init__(self, comm, tags): + self.tags = tags + self.comm = comm + self.active_requests = {} + + def wait(self): + for tag in self.tags: + if tag not in self.active_requests: + self.active_requests[tag] = self.comm.irecv(tag=tag) + requests = mpi.RequestList(self.active_requests.values()) + data, status, index = mpi.wait_any(requests) + del self.active_requests[status.tag] + return status, data + + def cancel(self): + for r in self.active_requests.itervalues(): + r.cancel() + #r.wait() + self.active_requests = {} + + + +def rank0(): + sent_histories = (mpi.size-1)*15 + print "sending %d packets on their way" % sent_histories + send_reqs = mpi.RequestList() + for i in range(sent_histories): + dest = random.randrange(1, mpi.size) + send_reqs.append(mpi.world.isend(dest, TAG_DATA, [])) + + mpi.wait_all(send_reqs) + + completed_histories = [] + progress_reports = {} + dead_kids = [] + + tgl = TagGroupListener(mpi.world, + [TAG_DATA, TAG_DEBUG, TAG_PROGRESS_REPORT, TAG_TERMINATE]) + + def is_complete(): + for i in progress_reports.values(): + if i != sent_histories: + return False + return len(dead_kids) == mpi.size-1 + + while True: + status, data = tgl.wait() + + if status.tag == TAG_DATA: + #print "received completed history %s from %d" % (data, status.source) + completed_histories.append(data) + if len(completed_histories) == sent_histories: + print "all histories received, exiting" + for rank in range(1, mpi.size): + mpi.world.send(rank, TAG_TERMINATE, None) + elif status.tag == TAG_PROGRESS_REPORT: + progress_reports[len(data)] = progress_reports.get(len(data), 0) + 1 + elif status.tag == TAG_DEBUG: + print "[DBG %d] %s" % (status.source, data) + elif status.tag == TAG_TERMINATE: + dead_kids.append(status.source) + else: + print "unexpected tag %d from %d" % (status.tag, status.source) + + if is_complete(): + break + + print "OK" + +def comm_rank(): + while True: + data, status = mpi.world.recv(return_status=True) + if status.tag == TAG_DATA: + mpi.world.send(0, TAG_PROGRESS_REPORT, data) + data.append(mpi.rank) + if len(data) >= MAX_GENERATIONS: + dest = 0 + else: + dest = random.randrange(1, mpi.size) + mpi.world.send(dest, TAG_DATA, data) + elif status.tag == TAG_TERMINATE: + from time import sleep + mpi.world.send(0, TAG_TERMINATE, 0) + break + else: + print "[DIRECTDBG %d] unexpected tag %d from %d" % (mpi.rank, status.tag, status.source) + + +def main(): + # this program sends around messages consisting of lists of visited nodes + # randomly. After MAX_GENERATIONS, they are returned to rank 0. + + if mpi.rank == 0: + rank0() + else: + comm_rank() + + + +if __name__ == "__main__": + main() |