// Distributed LSTM forward pass over a sequence, using 4 MPI ranks and OpenMP.
// Each rank owns one gate of the cell (rank 0: input gate i, rank 1: forget
// gate f, rank 2: candidate g, rank 3: output gate o) and computes
//   gate = act(W_x * x_t + W_h * h_{t-1})        (no bias terms)
// with sigmoid for i/f/o and tanh for g. Rank 0 then combines the gates:
//   c_t = f_t (*) c_{t-1} + i_t (*) g_t,   h_t = o_t (*) tanh(c_t)
// where (*) denotes the element-wise product.
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <mpi.h>

#define OMP_THREADS 1

// c += a * b (row-major n x n); if init is set, c is zeroed first.
void MatrixMult(float *a, float *b, float *c, int n, bool init = false) {
#pragma omp parallel for num_threads(OMP_THREADS)
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < n; ++j) {
      if (init) {
        c[i * n + j] = 0.0f;
      }
      for (int k = 0; k < n; ++k) {
        c[i * n + j] += a[i * n + k] * b[k * n + j];
      }
    }
  }
}

// Element-wise product of two n x n matrices.
void MatrixDot(float *a, float *b, float *c, int n) {
#pragma omp parallel for num_threads(OMP_THREADS)
  for (int i = 0; i < n * n; ++i) {
    c[i] = a[i] * b[i];
  }
}

// Element-wise sum of two n x n matrices.
void MatrixAdd(float *a, float *b, float *c, int n) {
#pragma omp parallel for num_threads(OMP_THREADS)
  for (int i = 0; i < n * n; ++i) {
    c[i] = a[i] + b[i];
  }
}

// Element-wise tanh over n values.
void tanh(float *x, float *y, int n) {
#pragma omp parallel for num_threads(OMP_THREADS)
  for (int i = 0; i < n; ++i) {
    y[i] = std::tanh(x[i]);
  }
}

// Element-wise logistic sigmoid over n values.
void sigmoid(float *x, float *y, int n) {
#pragma omp parallel for num_threads(OMP_THREADS)
  for (int i = 0; i < n; ++i) {
    y[i] = 1.0f / (1.0f + std::exp(-x[i]));
  }
}

int main(int argc, char **argv) {
  int rank, size;
  MPI_Init(&argc, &argv);
  MPI_Comm_size(MPI_COMM_WORLD, &size);
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  // check 4 MPI procs
  assert(size == 4);

  int seq_len;
  int n; // n == in_feat == in_batch == out_feat
  // rank-0-only buffers; left null on the other ranks
  float *w = nullptr, *x = nullptr, *h = nullptr, *c = nullptr;
  float *ft = nullptr, *ot = nullptr, *gt = nullptr;
  float *sub_w, *sub_x, *sub_h;
  float *ifgo;

  // load data on rank 0, then broadcast sizes
  if (rank == 0) {
    // input.txt layout: "seq_len n", then seq_len*n*n values of x,
    // then 8*n*n weights (for each gate in order i, f, g, o: W_x then W_h)
    FILE *fp = fopen("input.txt", "rt");
    fscanf(fp, "%d%d", &seq_len, &n);
    // alloc
    w = new float[8 * n * n];
    x = new float[seq_len * n * n];
    h = new float[seq_len * n * n];
    c = new float[n * n];
    ft = new float[n * n];
    ot = new float[n * n];
    gt = new float[n * n];
    // cell state starts at zero
    for (int i = 0; i < n * n; ++i) {
      c[i] = 0.0f;
    }
    for (int i = 0; i < seq_len * n * n; ++i) {
      fscanf(fp, "%f", &x[i]);
    }
    for (int i = 0; i < 8 * n * n; ++i) {
      fscanf(fp, "%f", &w[i]);
    }
    fclose(fp);
  }
  MPI_Bcast(&n, 1, MPI_INT, 0, MPI_COMM_WORLD);
  MPI_Bcast(&seq_len, 1, MPI_INT, 0, MPI_COMM_WORLD);

  ifgo = new float[n * n];
  sub_w = new float[n * n * 2];
  sub_x = new float[n * n];
  sub_h = new float[n * n];

  // each rank receives the weights of its gate: W_x followed by W_h
  MPI_Scatter(w, 2 * n * n, MPI_FLOAT, sub_w, 2 * n * n, MPI_FLOAT, 0,
              MPI_COMM_WORLD);

  // computing
  for (int i = 0; i < seq_len; ++i) {
    // broadcast x_t
    if (rank == 0) {
      memcpy(sub_x, x + i * n * n, n * n * sizeof(float));
    }
    MPI_Bcast(sub_x, n * n, MPI_FLOAT, 0, MPI_COMM_WORLD);
    // broadcast h_{t-1} (zero at the first step)
    if (i == 0) {
      for (int j = 0; j < n * n; ++j) {
        sub_h[j] = 0.0f;
      }
    } else {
      MPI_Bcast(sub_h, n * n, MPI_FLOAT, 0, MPI_COMM_WORLD);
    }
    // matrix multiply: ifgo = W_x * x_t + W_h * h_{t-1}
    MatrixMult(sub_w, sub_x, ifgo, n, true);
    MatrixMult(sub_w + n * n, sub_h, ifgo, n);
    // activation: tanh for the candidate gate, sigmoid for the others
    if (rank == 2) {
      tanh(ifgo, ifgo, n * n);
    } else {
      sigmoid(ifgo, ifgo, n * n);
    }
    // gather gate results on rank 0
    if (rank == 0) {
      MPI_Recv(ft, n * n, MPI_FLOAT, 1, 1, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
      MPI_Recv(gt, n * n, MPI_FLOAT, 2, 2, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
      MPI_Recv(ot, n * n, MPI_FLOAT, 3, 3, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    } else if (rank == 1) {
      MPI_Send(ifgo, n * n, MPI_FLOAT, 0, 1, MPI_COMM_WORLD);
    } else if (rank == 2) {
      MPI_Send(ifgo, n * n, MPI_FLOAT, 0, 2, MPI_COMM_WORLD);
    } else if (rank == 3) {
      MPI_Send(ifgo, n * n, MPI_FLOAT, 0, 3, MPI_COMM_WORLD);
    }
    // compute c_t and h_t on rank 0 (rank 0's ifgo holds the input gate)
    if (rank == 0) {
      MatrixDot(ft, c, c, n);
      MatrixDot(ifgo, gt, gt, n);
      MatrixAdd(c, gt, c, n);
      tanh(c, sub_h, n * n);
      MatrixDot(ot, sub_h, sub_h, n);
      memcpy(h + i * n * n, sub_h, n * n * sizeof(float));
    }
  }

  // cleanup
  delete[] ifgo;
  delete[] sub_w;
  delete[] sub_x;
  delete[] sub_h;
  if (rank == 0) {
    delete[] w;
    delete[] x;
    delete[] h;
    delete[] c;
    delete[] ft;
    delete[] ot;
    delete[] gt;
  }

  MPI_Finalize();
  return 0;
}
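
// Build/run sketch (assumed toolchain; the file name lstm_mpi.cpp is
// illustrative, not taken from the source):
//   mpic++ -fopenmp -O2 lstm_mpi.cpp -o lstm_mpi
//   mpirun -np 4 ./lstm_mpi
// The program requires exactly 4 ranks (see the assert in main) and reads
// input.txt from the working directory of rank 0.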