#include <mpi.h>
#include <cassert>
#include <cmath>
#include <iostream>

int main(int argc, char *argv[]) {
  int rank, size;
  int send, recv;
  // MPI init
  MPI_Init(&argc, &argv);
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);
  // check proc size is power of 2
  float steps_f = std::log2(size);
  int steps = (int)steps_f;
  assert(std::abs(steps - steps_f) < 1e-9);
  // init value
  send = rank;
  // sum
  int base = 1;
  for (int i = 0; i < steps; ++i) {
    int group = rank / base;
    int offset = rank % base;
    int target = (group % 2 ? group - 1 : group + 1) * base + offset;
    if (rank < target) {
      MPI_Send(reinterpret_cast<void*>(&send), 1, MPI_INT, target, target, MPI_COMM_WORLD);
      MPI_Recv(reinterpret_cast<void*>(&recv), 1, MPI_INT, target, rank, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    } else {
      MPI_Recv(reinterpret_cast<void*>(&recv), 1, MPI_INT, target, rank, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
      MPI_Send(reinterpret_cast<void*>(&send), 1, MPI_INT, target, target, MPI_COMM_WORLD);
    }
    send += recv;
    base *= 2;
  }
  // print result
  std::cout << "rank:" << rank << ", sum:" << send << std::endl;
  MPI_Finalize();
}