// fox.cc
// Parallel multiplication of two square matrices over MPI, using a square
// grid of processes that each own one sub-block of every matrix.
  #include <mpi.h>

  #include <cassert>
  #include <cmath>
  #include <cstddef>
  #include <cstdio>
  #include <cstdlib>
  #include <cstring>
  #include <iostream>
  #include <vector>
  // Returns the exact integer square root of n.
  //
  // n must be a non-negative perfect square. Any other argument is a usage
  // error: it trips the assertion in debug builds and terminates the program
  // via std::abort() in every build, so the function never returns a bogus
  // root and never runs off its end when assertions are compiled out.
  inline int sqrti(int n) {
    if (n >= 0) {
      // A double represents every int exactly, so rounding the root and
      // squaring it back in integer arithmetic is an exact perfect-square
      // test (no floating-point tolerance involved).
      const long long root = std::llround(std::sqrt(static_cast<double>(n)));
      if (root * root == n) {
        return static_cast<int>(root);
      }
    }
    assert(false && "sqrti: argument is not a perfect square");
    std::abort();
  }
  // Non-owning view of an n x n float matrix stored in a caller-supplied
  // buffer. The class has no destructor and never frees data_.
  //
  // Two addressing schemes are offered over the same buffer:
  //   * operator[]  - plain row-major: element (i, j) is at i * n + j.
  //   * split_map() - block-contiguous: the matrix is cut into
  //                   sqrt_q x sqrt_q square blocks, each stored contiguously
  //                   (row-major inside the block), blocks ordered row-major.
  class MatWrap {
   private:
    float* data_;  // borrowed storage, n_ * n_ floats
    int n_;        // matrix dimension

   public:
    // An empty view: no storage, dimension 0 (safe to print, nothing to index).
    MatWrap() : data_(nullptr), n_(0) {}

    // Wraps `data`, which must point at at least n * n floats.
    MatWrap(float* data, int n) : data_(data), n_(n) {}

    // Returns a pointer to the first element of row `n` (row-major layout).
    float* operator[](std::size_t n) const {
      return data_ + n * static_cast<std::size_t>(n_);
    }

    // Returns the address of logical element (i, j) in the block-contiguous
    // layout with sqrt_q blocks per side.
    //
    // Requires 0 < sqrt_q <= n_; the block edge is n_ / sqrt_q, so n_ should
    // be a multiple of sqrt_q for every element to land inside the buffer.
    float* split_map(int sqrt_q, int i, int j) const {
      assert(sqrt_q > 0 && sqrt_q <= n_);
      const int sub_n = n_ / sqrt_q;
      const int block_index = i / sub_n * sqrt_q + j / sub_n;
      const int in_block = i % sub_n * sub_n + j % sub_n;
      return data_ + block_index * sub_n * sub_n + in_block;
    }

    // Prints the matrix to stdout, one row per line, each value as "%9.4f".
    //
    // With mapping == false elements are read row-major. With mapping == true
    // they are read through split_map(sqrt_q, ...); a non-positive sqrt_q
    // (including the default -1) is treated as a single block, which is the
    // same as row-major, instead of indexing outside the buffer.
    void print(bool mapping = false, int sqrt_q = -1) const {
      const int blocks = sqrt_q > 0 ? sqrt_q : 1;
      for (int i = 0; i < n_; ++i) {
        for (int j = 0; j < n_; ++j) {
          const float* cell = mapping ? split_map(blocks, i, j) : operator[](i) + j;
          printf("%9.4f", *cell);
        }
        printf("\n");
      }
    }

    // MatMultAdd reads n_ directly to check that operand sizes agree.
    friend void MatMultAdd(const MatWrap& a, const MatWrap& b, MatWrap& c);
  };
  44. void MatMultAdd(const MatWrap& a, const MatWrap& b, MatWrap& c) {
  45. assert(a.n_ == b.n_);
  46. assert(a.n_ == c.n_);
  47. int n = a.n_;
  48. for (int i = 0; i < n; ++i) {
  49. for (int j = 0; j < n; ++j) {
  50. for (int k = 0; k < n; ++k) {
  51. c[i][j] += a[i][k] * b[k][j];
  52. }
  53. }
  54. }
  55. }
  56. int main(int argc, char** argv) {
  57. int rank, size;
  58. int n, sqrt_q, sub_n;
  59. float *mat_a, *mat_b, *mat_c;
  60. float* sub_mat_comm;
  61. float *sub_mat_a, *sub_mat_b, *sub_mat_c;
  62. MPI_Init(&argc, &argv);
  63. MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  64. MPI_Comm_size(MPI_COMM_WORLD, &size);
  65. sqrt_q = sqrti(size);
  66. if (rank == 0) {
  67. // load matrix data
  68. FILE* fp = fopen("data.txt", "rt");
  69. fscanf(fp, "%d", &n);
  70. mat_a = new float[n * n];
  71. mat_b = new float[n * n];
  72. mat_c = new float[n * n];
  73. MatWrap ma(mat_a, n);
  74. MatWrap mb(mat_b, n);
  75. for (int i = 0; i < n; ++i) {
  76. for (int j = 0; j < n; ++j) {
  77. fscanf(fp, "%f", ma.split_map(sqrt_q, i, j));
  78. }
  79. }
  80. for (int i = 0; i < n; ++i) {
  81. for (int j = 0; j < n; ++j) {
  82. fscanf(fp, "%f", mb.split_map(sqrt_q, i, j));
  83. }
  84. }
  85. }
  86. // broadcast matrix size
  87. MPI_Bcast(&n, 1, MPI_INT, 0, MPI_COMM_WORLD);
  88. assert(n % sqrt_q == 0);
  89. sub_n = n / sqrt_q;
  90. // alloc space for submatrix
  91. sub_mat_a = new float[sub_n * sub_n];
  92. sub_mat_b = new float[sub_n * sub_n];
  93. sub_mat_c = new float[sub_n * sub_n];
  94. sub_mat_comm = new float[sub_n * sub_n];
  95. MatWrap sub_comm(sub_mat_comm, sub_n);
  96. MatWrap sub_b(sub_mat_b, sub_n);
  97. MatWrap sub_c(sub_mat_c, sub_n);
  98. // init sub_c
  99. for (int i = 0; i < sub_n; ++i) {
  100. for (int j = 0; j < sub_n; ++j) {
  101. sub_c[i][j] = 0.0f;
  102. }
  103. }
  104. // broadcast sub matrix
  105. MPI_Scatter(mat_a, sub_n * sub_n, MPI_FLOAT, sub_mat_a, sub_n * sub_n,
  106. MPI_FLOAT, 0, MPI_COMM_WORLD);
  107. MPI_Scatter(mat_b, sub_n * sub_n, MPI_FLOAT, sub_mat_b, sub_n * sub_n,
  108. MPI_FLOAT, 0, MPI_COMM_WORLD);
  109. // split comm in col and row
  110. MPI_Comm col_world, row_world;
  111. int col_rank = rank % sqrt_q;
  112. int row_rank = rank / sqrt_q;
  113. MPI_Comm_split(MPI_COMM_WORLD, col_rank, row_rank, &col_world);
  114. MPI_Comm_split(MPI_COMM_WORLD, row_rank, col_rank, &row_world);
  115. // compute
  116. for (int i = 0; i < sqrt_q; ++i) {
  117. // broadcast sub_a
  118. int send_root = (row_rank + i) % sqrt_q;
  119. if (col_rank == (row_rank + i) % sqrt_q) {
  120. memcpy(sub_mat_comm, sub_mat_a, sub_n * sub_n * sizeof(float));
  121. }
  122. MPI_Bcast(sub_mat_comm, sub_n * sub_n, MPI_FLOAT, send_root, row_world);
  123. // calculate sub mat gemm
  124. MatMultAdd(sub_comm, sub_b, sub_c);
  125. // swap sub_b
  126. MPI_Sendrecv_replace(
  127. sub_mat_b, sub_n * sub_n, MPI_FLOAT, (row_rank + sqrt_q - 1) % sqrt_q,
  128. 1, (row_rank + 1) % sqrt_q, 1, col_world, MPI_STATUS_IGNORE);
  129. }
  130. // gather result
  131. MPI_Gather(sub_mat_c, sub_n * sub_n, MPI_FLOAT, mat_c, sub_n * sub_n,
  132. MPI_FLOAT, 0, MPI_COMM_WORLD);
  133. // print result
  134. if (rank == 0) {
  135. MatWrap mc(mat_c, n);
  136. for (int i = 0; i < n; ++i) {
  137. for (int j = 0; j < n; ++j) {
  138. printf("%9.5f", *mc.split_map(sqrt_q, i, j));
  139. }
  140. printf("\n");
  141. }
  142. }
  143. MPI_Finalize();
  144. return 0;
  145. }