1234567891011121314151617181920212223242526272829303132333435363738 |
- #include <mpi.h>
- #include <cassert>
- #include <cmath>
- #include <iostream>
- int main(int argc, char *argv[]) {
- int rank, size;
- int send, recv;
-
- MPI_Init(&argc, &argv);
- MPI_Comm_rank(MPI_COMM_WORLD, &rank);
- MPI_Comm_size(MPI_COMM_WORLD, &size);
-
- float steps_f = std::log2(size);
- int steps = (int)steps_f;
- assert(std::abs(steps - steps_f) < 1e-9);
-
- send = rank;
-
- int base = 1;
- for (int i = 0; i < steps; ++i) {
- int group = rank / base;
- int offset = rank % base;
- int target = (group % 2 ? group - 1 : group + 1) * base + offset;
- if (rank < target) {
- MPI_Send(reinterpret_cast<void*>(&send), 1, MPI_INT, target, target, MPI_COMM_WORLD);
- MPI_Recv(reinterpret_cast<void*>(&recv), 1, MPI_INT, target, rank, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
- } else {
- MPI_Recv(reinterpret_cast<void*>(&recv), 1, MPI_INT, target, rank, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
- MPI_Send(reinterpret_cast<void*>(&send), 1, MPI_INT, target, target, MPI_COMM_WORLD);
- }
- send += recv;
- base *= 2;
- }
-
- std::cout << "rank:" << rank << ", sum:" << send << std::endl;
- MPI_Finalize();
- }
|