// main.cc — distributed LSTM forward pass (4 MPI ranks, one gate per rank)
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>

#include <mpi.h>
#include <omp.h>

#define OMP_THREADS 1
  7. void MatrixMult(float *a, float *b, float *c, int n, bool init = false) {
  8. #pragma omp parallel for num_threads(OMP_THREADS)
  9. for (int i = 0; i < n; ++i) {
  10. for (int j = 0; j < n; ++j) {
  11. if (init) {
  12. c[i * n + j] = 0.0f;
  13. }
  14. for (int k = 0; k < n; ++k) {
  15. c[i * n + j] += a[i * n + k] * b[k * n + j];
  16. }
  17. }
  18. }
  19. }
  20. void MatrixDot(float *a, float *b, float *c, int n) {
  21. #pragma omp parallel for num_threads(OMP_THREADS)
  22. for (int i = 0; i < n * n; ++i) {
  23. c[i] = a[i] * b[i];
  24. }
  25. }
  26. void MatrixAdd(float *a, float *b, float *c, int n) {
  27. #pragma omp parallel for num_threads(OMP_THREADS)
  28. for (int i = 0; i < n * n; ++i) {
  29. c[i] = a[i] + b[i];
  30. }
  31. }
  32. void tanh(float *x, float *y, int n) {
  33. #pragma omp parallel for num_threads(OMP_THREADS)
  34. for (int i = 0; i < n; ++i) {
  35. y[i] = std::tanh(x[i]);
  36. }
  37. }
  38. void sigmoid(float *x, float *y, int n) {
  39. #pragma omp parallel for num_threads(OMP_THREADS)
  40. for (int i = 0; i < n; ++i) {
  41. y[i] = 1.0f / (1.0f + std::exp(-x[i]));
  42. }
  43. }
  44. int main(int argc, char **argv) {
  45. int rank, size;
  46. MPI_Init(&argc, &argv);
  47. MPI_Comm_size(MPI_COMM_WORLD, &size);
  48. MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  49. // check 4 MPI procs
  50. assert(size == 4);
  51. int seq_len;
  52. int n; // n == in_feat == in_batch == out_feat
  53. float *w, *x, *h, *c;
  54. float *ft, *ot, *gt;
  55. float *sub_w, *sub_x, *sub_h;
  56. float *ifgo;
  57. // load data and broadcast
  58. if (rank == 0) {
  59. FILE *fp = fopen("input.txt", "rt");
  60. fscanf(fp, "%d%d", &seq_len, &n);
  61. // alloc
  62. w = new float[8 * n * n];
  63. x = new float[seq_len * n * n];
  64. h = new float[seq_len * n * n];
  65. c = new float[n * n];
  66. ft = new float[n * n];
  67. ot = new float[n * n];
  68. gt = new float[n * n];
  69. for (int i = 0; i < seq_len * n * n; ++i) {
  70. fscanf(fp, "%f", &x[i]);
  71. }
  72. for (int i = 0; i < 8 * n * n; ++i) {
  73. fscanf(fp, "%f", &w[i]);
  74. }
  75. fclose(fp);
  76. }
  77. MPI_Bcast(&n, 1, MPI_INT, 0, MPI_COMM_WORLD);
  78. MPI_Bcast(&seq_len, 1, MPI_INT, 0, MPI_COMM_WORLD);
  79. ifgo = new float[n * n];
  80. sub_w = new float[n * n * 2];
  81. sub_x = new float[n * n];
  82. sub_h = new float[n * n];
  83. MPI_Scatter(w, 2 * n * n, MPI_FLOAT, sub_w, 2 * n * n, MPI_FLOAT, 0,
  84. MPI_COMM_WORLD);
  85. // computing
  86. for (int i = 0; i < seq_len; ++i) {
  87. // broadcast x
  88. if (rank == 0) {
  89. memcpy(sub_x, x + i * n * n, n * n * sizeof(float));
  90. }
  91. MPI_Bcast(sub_x, n * n, MPI_FLOAT, 0, MPI_COMM_WORLD);
  92. // broadcast h
  93. if (i == 0) {
  94. for (int i = 0; i < n * n; ++i) {
  95. sub_h[i] = 0.0f;
  96. }
  97. } else {
  98. MPI_Bcast(sub_h, n * n, MPI_FLOAT, 0, MPI_COMM_WORLD);
  99. }
  100. // matrix multply
  101. MatrixMult(sub_w, sub_x, ifgo, n, true);
  102. MatrixMult(sub_w + n * n, sub_h, ifgo, n);
  103. // active
  104. if (rank == 2) {
  105. tanh(ifgo, ifgo, n * n);
  106. } else {
  107. sigmoid(ifgo, ifgo, n * n);
  108. }
  109. // gather result
  110. if (rank == 0) {
  111. MPI_Recv(ft, n * n, MPI_FLOAT, 1, 1, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
  112. MPI_Recv(gt, n * n, MPI_FLOAT, 2, 2, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
  113. MPI_Recv(ot, n * n, MPI_FLOAT, 3, 3, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
  114. } else if (rank == 1) {
  115. MPI_Send(ifgo, n * n, MPI_FLOAT, 0, 1, MPI_COMM_WORLD);
  116. } else if (rank == 2) {
  117. MPI_Send(ifgo, n * n, MPI_FLOAT, 0, 2, MPI_COMM_WORLD);
  118. } else if (rank == 3) {
  119. MPI_Send(ifgo, n * n, MPI_FLOAT, 0, 3, MPI_COMM_WORLD);
  120. }
  121. // compute ct ht
  122. if (rank == 0) {
  123. MatrixDot(ft, c, c, n);
  124. MatrixDot(ifgo, gt, gt, n);
  125. MatrixAdd(c, gt, c, n);
  126. tanh(c, sub_h, n * n);
  127. MatrixDot(ot, sub_h, sub_h, n);
  128. memcpy(h + i * n * n, sub_h, n * n * sizeof(float));
  129. }
  130. }
  131. MPI_Finalize();
  132. }