closure.cc 4.2 KB
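// closure.cc
//
// Computes the reflexive-transitive closure of a directed graph by repeatedly
// squaring its boolean adjacency matrix. Each squaring is distributed over a
// sqrt(q) x sqrt(q) grid of MPI processes using a Fox-style
// broadcast-multiply-roll scheme, and the per-block multiplication is threaded
// with OpenMP. Rank 0 reads the matrix from "data.txt".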

#include <mpi.h>
#include <omp.h>

#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>

#define OMP_THREADS 2
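// Integer square root of a perfect square; used to derive the process-grid
// dimension (sqrt_q) from the MPI world size.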
inline int sqrti(int n) {
  float x = std::sqrt(static_cast<float>(n));
  int r = static_cast<int>(x);
  // Anything other than a perfect square is a usage error.
  assert(std::abs(r - x) < 1e-9);
  return r;
}
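// Non-owning view over an n x n row-major int matrix. split_map() maps a
// global (i, j) index into the block layout used for distribution: the matrix
// is stored as sqrt_q * sqrt_q contiguous sub-blocks, so MPI_Scatter hands
// each rank exactly the block it owns.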
class MatWrap {
 private:
  int* data_;
  int n_;

 public:
  MatWrap() : data_(nullptr), n_(0) {}
  MatWrap(int* data, int n) : data_(data), n_(n) {}

  int* operator[](size_t i) const { return data_ + i * n_; }

  // Address of element (i, j) in the block-distributed layout.
  int* split_map(int sqrt_q, int i, int j) const {
    int sub_n = n_ / sqrt_q;
    return data_ + (i / sub_n * sqrt_q + j / sub_n) * sub_n * sub_n +
           i % sub_n * sub_n + j % sub_n;
  }

  void print(bool mapping = false, int sqrt_q = -1) const {
    for (int i = 0; i < n_; ++i) {
      for (int j = 0; j < n_; ++j) {
        if (mapping) {
          printf("%5d", *split_map(sqrt_q, i, j));
        } else {
          printf("%5d", operator[](i)[j]);
        }
      }
      printf("\n");
    }
  }

  friend void MatMultAdd(const MatWrap& a, const MatWrap& b, MatWrap& c);
};
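// c += a * b over the boolean semiring: each accumulated entry is clamped to
// 0/1, which is all the reachability computation needs. Rows of c are split
// across OMP_THREADS OpenMP threads, so no synchronization is required.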
void MatMultAdd(const MatWrap& a, const MatWrap& b, MatWrap& c) {
  assert(a.n_ == b.n_);
  assert(a.n_ == c.n_);
  int n = a.n_;
#pragma omp parallel for num_threads(OMP_THREADS)
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < n; ++j) {
      for (int k = 0; k < n; ++k) {
        c[i][j] = static_cast<bool>(c[i][j] + a[i][k] * b[k][j]);
      }
    }
  }
}
int main(int argc, char** argv) {
  int rank, size;
  int n, sqrt_q, sub_n;
  int* mat_a = nullptr;  // full matrix, allocated on rank 0 only
  int* sub_mat_comm;
  int *sub_mat_a, *sub_mat_b, *sub_mat_c;
  MPI_Init(&argc, &argv);
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);
  sqrt_q = sqrti(size);
  if (rank == 0) {
    // load matrix data (stored directly in block-distributed order)
    FILE* fp = fopen("data.txt", "rt");
    assert(fp != nullptr);
    fscanf(fp, "%d", &n);
    mat_a = new int[n * n];
    MatWrap ma(mat_a, n);
    for (int i = 0; i < n; ++i) {
      for (int j = 0; j < n; ++j) {
        fscanf(fp, "%d", ma.split_map(sqrt_q, i, j));
      }
      // add a self-loop so the result is the reflexive-transitive closure
      *ma.split_map(sqrt_q, i, i) = 1;
    }
    fclose(fp);
  }
  // broadcast matrix size
  MPI_Bcast(&n, 1, MPI_INT, 0, MPI_COMM_WORLD);
  assert(n % sqrt_q == 0);
  sub_n = n / sqrt_q;
  // alloc space for submatrix
  sub_mat_a = new int[sub_n * sub_n];
  sub_mat_b = new int[sub_n * sub_n];
  sub_mat_c = new int[sub_n * sub_n];
  sub_mat_comm = new int[sub_n * sub_n];
  MatWrap sub_comm(sub_mat_comm, sub_n);
  MatWrap sub_b(sub_mat_b, sub_n);
  MatWrap sub_c(sub_mat_c, sub_n);
  // Repeatedly square the boolean matrix; after about log2(n) squarings the
  // reachability matrix stops changing, i.e. the closure is reached.
  for (int k = 0; k <= std::log2f(n); ++k) {
    // init sub_c
    for (int i = 0; i < sub_n; ++i) {
      for (int j = 0; j < sub_n; ++j) {
        sub_c[i][j] = 0;
      }
    }
    // scatter one block of the current matrix to every rank, as both operands
    // of this squaring
    MPI_Scatter(mat_a, sub_n * sub_n, MPI_INT, sub_mat_a, sub_n * sub_n,
                MPI_INT, 0, MPI_COMM_WORLD);
    MPI_Scatter(mat_a, sub_n * sub_n, MPI_INT, sub_mat_b, sub_n * sub_n,
                MPI_INT, 0, MPI_COMM_WORLD);
    // split the process grid into per-column and per-row communicators
    MPI_Comm col_world, row_world;
    int col_rank = rank % sqrt_q;  // column index in the process grid
    int row_rank = rank / sqrt_q;  // row index in the process grid
    MPI_Comm_split(MPI_COMM_WORLD, col_rank, row_rank, &col_world);
    MPI_Comm_split(MPI_COMM_WORLD, row_rank, col_rank, &row_world);
    // compute: Fox-style broadcast-multiply-roll stages
    for (int i = 0; i < sqrt_q; ++i) {
      // broadcast the sub_a block on the i-th generalized diagonal along its row
      int send_root = (row_rank + i) % sqrt_q;
      if (col_rank == send_root) {
        memcpy(sub_mat_comm, sub_mat_a, sub_n * sub_n * sizeof(int));
      }
      MPI_Bcast(sub_mat_comm, sub_n * sub_n, MPI_INT, send_root, row_world);
      // accumulate the local block product
      MatMultAdd(sub_comm, sub_b, sub_c);
      // roll sub_b one step up its column
      MPI_Sendrecv_replace(
          sub_mat_b, sub_n * sub_n, MPI_INT, (row_rank + sqrt_q - 1) % sqrt_q,
          1, (row_rank + 1) % sqrt_q, 1, col_world, MPI_STATUS_IGNORE);
    }
    // the communicators are re-created every iteration; free them to avoid a leak
    MPI_Comm_free(&col_world);
    MPI_Comm_free(&row_world);
    // gather the squared blocks back into mat_a on rank 0 (block layout)
    MPI_Gather(sub_mat_c, sub_n * sub_n, MPI_INT, mat_a, sub_n * sub_n, MPI_INT,
               0, MPI_COMM_WORLD);
    // print intermediate result
    if (rank == 0) {
      MatWrap mc(mat_a, n);
      printf("loop:%d\n", k);
      mc.print(true, sqrt_q);
    }
  }
  delete[] sub_mat_a;
  delete[] sub_mat_b;
  delete[] sub_mat_c;
  delete[] sub_mat_comm;
  delete[] mat_a;  // nullptr on non-root ranks, so this is safe everywhere
  MPI_Finalize();
  return 0;
}
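
// Example usage (assumptions: an OpenMPI-style toolchain and a 4-process run,
// i.e. a 2 x 2 process grid; adjust to your environment):
//
//   data.txt -- first the matrix size n, then n * n adjacency entries in
//   row-major order, e.g.:
//     4
//     0 1 0 0
//     0 0 1 0
//     0 0 0 1
//     0 0 0 0
//
//   mpic++ -fopenmp closure.cc -o closure
//   mpirun -np 4 ./closure
//
// The number of MPI processes must be a perfect square whose root divides n.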