#include <mpi.h>
#include <cassert>
#include <cmath>
#include <iostream>
#include <random>

// Number of server ranks: world ranks below this value act as servers, and it
// is also the modulus used to assign every rank to a service group.
#define SERVER_NUM 2
// Number of reduce / all-reduce / broadcast rounds executed by main().
#define LOOP 1024

// File-scope random number machinery used by main() to produce client values.
// `rd` supplies the seed, `gen` is the Mersenne Twister engine seeded from it
// once during static initialization, and `dis` maps engine output to doubles
// uniformly distributed over [0, 2047).
std::random_device rd{};
std::mt19937 gen{rd()};
std::uniform_real_distribution<double> dis{0.0, 2047.0};

int main(int argc, char *argv[]) {
  int rank, size;
  // MPI init
  MPI_Init(&argc, &argv);
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);
  // server or client
  int SorC = rank >= SERVER_NUM;
  MPI_Comm serverClient, service;
  MPI_Comm_split(MPI_COMM_WORLD, SorC, -1, &serverClient);
  int service_num = rank % SERVER_NUM;
  MPI_Comm_split(MPI_COMM_WORLD, service_num, SorC, &service);
  int service_rank;
  MPI_Comm_rank(service, &service_rank);
  // start
  float value;
  float sum;
  for (int i = 0; i < LOOP; ++i) {
    if (SorC) { // client
      value = dis(gen);
      // printf("c:%d, %d, %f\n", rank, service_rank, value);
    } else { // server
      value = 0;
    }
    MPI_Reduce(&value, &sum, 1, MPI_FLOAT, MPI_SUM, 0, service);
    if (!SorC) {
      // printf("s:%d, %d, %f\n", rank, service_rank, sum);
      MPI_Allreduce(&sum, &value, 1, MPI_FLOAT, MPI_SUM, serverClient);
      value /= (size - SERVER_NUM);
      // printf("s:%d, %d, %f\n", rank, service_rank, value);
    }
    MPI_Bcast(&value, 1, MPI_FLOAT, 0, service);
    // printf("c:%d, %f\n", rank, value);
  }
  MPI_Finalize();
}