#include <mpi.h>
#include <cassert>
#include <cmath>
#include <iostream>

int main(int argc, char *argv[]) {
  int rank, size;
  int send, recv;
  // MPI init
  MPI_Init(&argc, &argv);
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);
  // check proc size is power of 2
  float steps_f = std::log2(size);
  int steps = (int)steps_f;
  assert(std::abs(steps - steps_f) < 1e-9);
  // init value
  send = rank;
  // sum
  int base = 1;
  for (int i = 0; i < steps; ++i) {
    if (rank % base == 0) {
      int b = base * 2;
      if (rank % b == 0) { // recv
        MPI_Recv(&recv, 1, MPI_INT, rank + base, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
      } else { // send
        MPI_Send(&send, 1, MPI_INT, rank - base, 0, MPI_COMM_WORLD);
      }
    }
    send += recv;
    base *= 2;
  }
  base /= 2;
  for (int i = 0; i < steps; ++i) {
    if (rank % base == 0) {
      int b = base * 2;
      if (rank % b == 0) { // recv
        MPI_Send(&send, 1, MPI_INT, rank + base, 0, MPI_COMM_WORLD);
      } else { // send
        MPI_Recv(&send, 1, MPI_INT, rank - base, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
      }
    }
    base /= 2;
  }
  // print result
  std::cout << "rank:" << rank << ", sum:" << send << std::endl;
  MPI_Finalize();
}