summaryrefslogtreecommitdiffstats
path: root/src/boost/libs/mpi/example/generate_collect.cpp
blob: 5579d50d2e64fed003bce515ca97d0034ae55e6d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
// Copyright (C) 2006 Douglas Gregor <doug.gregor@gmail.com>

// Use, modification and distribution is subject to the Boost Software
// License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt)

// An example using Boost.MPI's split() operation on communicators to
// create separate data-generating processes and data-collecting
// processes.
#include <boost/mpi.hpp>
#include <boost/serialization/vector.hpp>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <vector>
namespace mpi = boost::mpi;

// Tags attached to the point-to-point messages exchanged in this example.
//   msg_data_packet    - a generator sends a std::vector<int> of samples to
//                        the master collector over the world communicator.
//   msg_broadcast_data - the master collector tells another collector (over
//                        the collectors' communicator) that a broadcast of
//                        data follows; the payload is the world rank of the
//                        generator the data came from.
//   msg_finished       - empty message: sent by generator 0 to the master
//                        collector, then relayed by the master collector to
//                        the other collectors, to end their receive loops.
enum message_tags { msg_data_packet, msg_broadcast_data, msg_finished };

// Generator side of the example: produce between one and three blocks of
// random integers and send each block to the master collector.
//
// local - communicator containing only the generator processes.
// world - communicator containing every process; used to reach the
//         master collector.
//
// The generators are expected to occupy the lowest ranks of `world`
// (main() assigns them that way), so the number of generators is also the
// world rank of the first collector.
void generate_data(mpi::communicator local, mpi::communicator world)
{
  // The rank of the master collector within the world communicator
  const int master_collector = local.size();

  // Offset the seed by the world rank so that generators started within the
  // same second do not all produce the same sequence. The explicit cast
  // makes the time_t -> unsigned narrowing visible.
  std::srand(static_cast<unsigned int>(std::time(0)) + world.rank());

  // Send out several blocks of random data to the master collector.
  const int num_data_blocks = std::rand() % 3 + 1;
  for (int block = 0; block < num_data_blocks; ++block) {
    // Generate between 0 and 999 random samples for this block.
    const int num_samples = std::rand() % 1000;
    std::vector<int> data;
    data.reserve(num_samples);  // size is known up front; avoid regrowth
    for (int i = 0; i < num_samples; ++i) {
      data.push_back(std::rand());
    }

    // Send our data to the master collector process.
    std::cout << "Generator #" << local.rank() << " sends some data..."
              << std::endl;
    world.send(master_collector, msg_data_packet, data);
  }

  // Wait for all of the generators to complete. The call is written as
  // (local.barrier)() presumably to stop a function-like macro named
  // `barrier` from expanding here -- keep the parentheses.
  // NOTE(review): the barrier only synchronizes the generators with each
  // other; whether every data packet is guaranteed to be matched by the
  // collector before the msg_finished sent below should be confirmed
  // against MPI's message-ordering rules.
  (local.barrier)();

  // The first generator will send the message to the master collector
  // indicating that we're done.
  if (local.rank() == 0)
    world.send(master_collector, msg_finished);
}

// Collector side of the example.
//
// local - communicator containing only the collector processes.
// world - communicator containing every process.
//
// One collector, the "master", receives everything the generators send over
// `world` and relays it to the remaining collectors over `local`: it sends
// to local ranks 1..size-1 and broadcasts with root 0. Every other collector
// just waits on `local` for announcements from the master.
void collect_data(mpi::communicator local, mpi::communicator world)
{
  const int collector_count = local.size();

  // World rank of the master collector.
  const int master_world_rank = world.size() - collector_count;

  if (world.rank() != master_world_rank) {
    // Ordinary collector: loop until the master says we are finished.
    for (;;) {
      mpi::status incoming = local.probe();
      const int tag = incoming.tag();

      if (tag == msg_broadcast_data) {
        // The announcement carries the rank of the generator that produced
        // the data about to be broadcast.
        int originator;
        local.recv(incoming.source(), tag, originator);

        // Take part in the broadcast rooted at local rank 0.
        std::vector<int> data;
        broadcast(local, data, 0);

        std::cout << "Collector #" << local.rank()
                  << " is processing data from generator #" << originator
                  << "." << std::endl;
      } else if (tag == msg_finished) {
        // Consume the (empty) termination message and stop.
        local.recv(incoming.source(), tag);
        return;
      }
    }
  }

  // Master collector: service messages arriving on the world communicator.
  for (;;) {
    mpi::status incoming = world.probe();
    const int tag = incoming.tag();
    const int sender = incoming.source();

    if (tag == msg_data_packet) {
      // Pull in the block of samples.
      std::vector<int> data;
      world.recv(sender, tag, data);

      // Announce the upcoming broadcast (and who the data came from) to
      // every other collector...
      for (int dest = 1; dest < collector_count; ++dest) {
        local.send(dest, msg_broadcast_data, sender);
      }

      // ...then broadcast the samples themselves.
      broadcast(local, data, 0);
    } else if (tag == msg_finished) {
      // Consume the (empty) termination message.
      world.recv(sender, tag);

      // Pass the termination notice on to every other collector.
      for (int dest = 1; dest < collector_count; ++dest) {
        local.send(dest, msg_finished);
      }
      return;
    }
  }
}

// Entry point: initialise MPI, split the world communicator into a
// generator group and a collector group, and run the matching role.
//
// Processes whose world rank is below 2 * size / 3 (integer arithmetic)
// become generators; the rest become collectors. At least 3 processes are
// required.
int main(int argc, char* argv[])
{
  mpi::environment env(argc, argv);
  mpi::communicator world;

  const int process_count = world.size();
  const int my_rank = world.rank();

  if (process_count < 3) {
    // Only rank 0 prints the diagnostic; every rank then calls abort.
    if (my_rank == 0) {
      std::cerr << "Error: this example requires at least 3 processes."
                << std::endl;
    }
    env.abort(-1);
  }

  // The lowest-ranked two thirds (rounded down) generate data.
  const int generator_count = 2 * process_count / 3;
  const bool is_generator = my_rank < generator_count;

  // Colour 0 groups the generators, colour 1 the collectors.
  const int color = is_generator ? 0 : 1;
  mpi::communicator local = world.split(color);

  if (is_generator) {
    generate_data(local, world);
  } else {
    collect_data(local, world);
  }

  return 0;
}