// bitree.cc (1.3 KB)

  1. #include <mpi.h>
  2. #include <cassert>
  3. #include <cmath>
  4. #include <iostream>
  5. int main(int argc, char *argv[]) {
  6. int rank, size;
  7. int send, recv;
  8. // MPI init
  9. MPI_Init(&argc, &argv);
  10. MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  11. MPI_Comm_size(MPI_COMM_WORLD, &size);
  12. // check proc size is power of 2
  13. float steps_f = std::log2(size);
  14. int steps = (int)steps_f;
  15. assert(std::abs(steps - steps_f) < 1e-9);
  16. // init value
  17. send = rank;
  18. // sum
  19. int base = 1;
  20. for (int i = 0; i < steps; ++i) {
  21. if (rank % base == 0) {
  22. int b = base * 2;
  23. if (rank % b == 0) { // recv
  24. MPI_Recv(&recv, 1, MPI_INT, rank + base, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
  25. } else { // send
  26. MPI_Send(&send, 1, MPI_INT, rank - base, 0, MPI_COMM_WORLD);
  27. }
  28. }
  29. send += recv;
  30. base *= 2;
  31. }
  32. base /= 2;
  33. for (int i = 0; i < steps; ++i) {
  34. if (rank % base == 0) {
  35. int b = base * 2;
  36. if (rank % b == 0) { // recv
  37. MPI_Send(&send, 1, MPI_INT, rank + base, 0, MPI_COMM_WORLD);
  38. } else { // send
  39. MPI_Recv(&send, 1, MPI_INT, rank - base, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
  40. }
  41. }
  42. base /= 2;
  43. }
  44. // print result
  45. std::cout << "rank:" << rank << ", sum:" << send << std::endl;
  46. MPI_Finalize();
  47. }