summaryrefslogtreecommitdiffstats
path: root/src/boost/libs/mpi/test/python/nonblocking_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/boost/libs/mpi/test/python/nonblocking_test.py')
-rw-r--r--src/boost/libs/mpi/test/python/nonblocking_test.py131
1 files changed, 131 insertions, 0 deletions
diff --git a/src/boost/libs/mpi/test/python/nonblocking_test.py b/src/boost/libs/mpi/test/python/nonblocking_test.py
new file mode 100644
index 00000000..73b451c5
--- /dev/null
+++ b/src/boost/libs/mpi/test/python/nonblocking_test.py
@@ -0,0 +1,131 @@
+# (C) Copyright 2007
+# Andreas Kloeckner <inform -at- tiker.net>
+#
+# Use, modification and distribution is subject to the Boost Software
+# License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
+# http://www.boost.org/LICENSE_1_0.txt)
+#
+# Authors: Andreas Kloeckner
+
+
+
+
+import boost.mpi as mpi
+import random
+import sys
+
+MAX_GENERATIONS = 20
+TAG_DEBUG = 0
+TAG_DATA = 1
+TAG_TERMINATE = 2
+TAG_PROGRESS_REPORT = 3
+
+
+
+
+class TagGroupListener:
+ """Class to help listen for only a given set of tags.
+
+ This is contrived: Typicallly you could just listen for
+ mpi.any_tag and filter."""
+ def __init__(self, comm, tags):
+ self.tags = tags
+ self.comm = comm
+ self.active_requests = {}
+
+ def wait(self):
+ for tag in self.tags:
+ if tag not in self.active_requests:
+ self.active_requests[tag] = self.comm.irecv(tag=tag)
+ requests = mpi.RequestList(self.active_requests.values())
+ data, status, index = mpi.wait_any(requests)
+ del self.active_requests[status.tag]
+ return status, data
+
+ def cancel(self):
+ for r in self.active_requests.itervalues():
+ r.cancel()
+ #r.wait()
+ self.active_requests = {}
+
+
+
+def rank0():
+ sent_histories = (mpi.size-1)*15
+ print "sending %d packets on their way" % sent_histories
+ send_reqs = mpi.RequestList()
+ for i in range(sent_histories):
+ dest = random.randrange(1, mpi.size)
+ send_reqs.append(mpi.world.isend(dest, TAG_DATA, []))
+
+ mpi.wait_all(send_reqs)
+
+ completed_histories = []
+ progress_reports = {}
+ dead_kids = []
+
+ tgl = TagGroupListener(mpi.world,
+ [TAG_DATA, TAG_DEBUG, TAG_PROGRESS_REPORT, TAG_TERMINATE])
+
+ def is_complete():
+ for i in progress_reports.values():
+ if i != sent_histories:
+ return False
+ return len(dead_kids) == mpi.size-1
+
+ while True:
+ status, data = tgl.wait()
+
+ if status.tag == TAG_DATA:
+ #print "received completed history %s from %d" % (data, status.source)
+ completed_histories.append(data)
+ if len(completed_histories) == sent_histories:
+ print "all histories received, exiting"
+ for rank in range(1, mpi.size):
+ mpi.world.send(rank, TAG_TERMINATE, None)
+ elif status.tag == TAG_PROGRESS_REPORT:
+ progress_reports[len(data)] = progress_reports.get(len(data), 0) + 1
+ elif status.tag == TAG_DEBUG:
+ print "[DBG %d] %s" % (status.source, data)
+ elif status.tag == TAG_TERMINATE:
+ dead_kids.append(status.source)
+ else:
+ print "unexpected tag %d from %d" % (status.tag, status.source)
+
+ if is_complete():
+ break
+
+ print "OK"
+
+def comm_rank():
+ while True:
+ data, status = mpi.world.recv(return_status=True)
+ if status.tag == TAG_DATA:
+ mpi.world.send(0, TAG_PROGRESS_REPORT, data)
+ data.append(mpi.rank)
+ if len(data) >= MAX_GENERATIONS:
+ dest = 0
+ else:
+ dest = random.randrange(1, mpi.size)
+ mpi.world.send(dest, TAG_DATA, data)
+ elif status.tag == TAG_TERMINATE:
+ from time import sleep
+ mpi.world.send(0, TAG_TERMINATE, 0)
+ break
+ else:
+ print "[DIRECTDBG %d] unexpected tag %d from %d" % (mpi.rank, status.tag, status.source)
+
+
+def main():
+ # this program sends around messages consisting of lists of visited nodes
+ # randomly. After MAX_GENERATIONS, they are returned to rank 0.
+
+ if mpi.rank == 0:
+ rank0()
+ else:
+ comm_rank()
+
+
+
+if __name__ == "__main__":
+ main()