123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142 |
#include <cassert>
#include <climits>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <vector>

#include <mpi.h>
#include <omp.h>
- #define OMP_THREADS 1
// Dense n x n matrix product on row-major buffers.
//
// Computes c = a * b when `init` is true and c += a * b otherwise (default).
//
//   a, b : n*n input matrices; only read.
//   c    : n*n output / accumulator; must not overlap a or b.
//   n    : matrix side length; nothing is touched when n <= 0.
//   init : when true, c is zeroed first, so it may be uninitialised on entry.
void MatrixMult(float *a, float *b, float *c, int n, bool init = false) {
#pragma omp parallel for num_threads(OMP_THREADS)
  for (int i = 0; i < n; ++i) {
    // Row offsets are computed in size_t so i * n cannot overflow int.
    const std::size_t row_offset = static_cast<std::size_t>(i) * n;
    const float *a_row = a + row_offset;
    float *c_row = c + row_offset;

    // The init test is loop-invariant, so the row is cleared once up front
    // instead of re-testing the flag for every element.
    if (init) {
      for (int j = 0; j < n; ++j) {
        c_row[j] = 0.0f;
      }
    }

    // i-k-j order: the innermost loop walks contiguous rows of b and c
    // instead of striding through a column of b. Every c[i][j] still
    // receives its products in ascending k, i.e. the same sequence of
    // additions as the textbook i-j-k order.
    for (int k = 0; k < n; ++k) {
      const float a_ik = a_row[k];
      const float *b_row = b + static_cast<std::size_t>(k) * n;
      for (int j = 0; j < n; ++j) {
        c_row[j] += a_ik * b_row[j];
      }
    }
  }
}
// Element-wise (Hadamard) product of two n x n matrices: c[i] = a[i] * b[i].
// Despite the name this is not a dot product or a matrix multiplication.
//
//   a, b : n*n inputs.
//   c    : n*n output; may be the same buffer as a or b, because every
//          element is read before it is written.
//   n    : matrix side length.
void MatrixDot(float *a, float *b, float *c, int n) {
  // n * n is evaluated in 64 bits: in int it overflows (undefined behaviour)
  // as soon as n exceeds 46340.
  const long long count = static_cast<long long>(n) * n;
#pragma omp parallel for num_threads(OMP_THREADS)
  for (long long i = 0; i < count; ++i) {
    c[i] = a[i] * b[i];
  }
}
// Element-wise sum of two n x n matrices: c[i] = a[i] + b[i].
//
//   a, b : n*n inputs.
//   c    : n*n output; may be the same buffer as a or b, because every
//          element is read before it is written.
//   n    : matrix side length.
void MatrixAdd(float *a, float *b, float *c, int n) {
  // n * n is evaluated in 64 bits: in int it overflows (undefined behaviour)
  // as soon as n exceeds 46340.
  const long long count = static_cast<long long>(n) * n;
#pragma omp parallel for num_threads(OMP_THREADS)
  for (long long i = 0; i < count; ++i) {
    c[i] = a[i] + b[i];
  }
}
// Applies the hyperbolic tangent to each of the first n entries of x and
// stores the results in y. x and y may be the same buffer: every entry is
// read before the matching entry is written.
void tanh(float *x, float *y, int n) {
  const float *src = x;
  float *dst = y;
#pragma omp parallel for num_threads(OMP_THREADS)
  for (int idx = 0; idx < n; ++idx) {
    const float value = src[idx];
    dst[idx] = std::tanh(value);
  }
}
// Applies the logistic function 1 / (1 + e^(-v)) to each of the first n
// entries of x and stores the results in y. x and y may be the same buffer:
// every entry is read before the matching entry is written.
void sigmoid(float *x, float *y, int n) {
  const float *src = x;
  float *dst = y;
#pragma omp parallel for num_threads(OMP_THREADS)
  for (int idx = 0; idx < n; ++idx) {
    const float decay = std::exp(-src[idx]);
    dst[idx] = 1.0f / (1.0f + decay);
  }
}
- int main(int argc, char **argv) {
- int rank, size;
- MPI_Init(&argc, &argv);
- MPI_Comm_size(MPI_COMM_WORLD, &size);
- MPI_Comm_rank(MPI_COMM_WORLD, &rank);
- // check 4 MPI procs
- assert(size == 4);
- int seq_len;
- int n; // n == in_feat == in_batch == out_feat
- float *w, *x, *h, *c;
- float *ft, *ot, *gt;
- float *sub_w, *sub_x, *sub_h;
- float *ifgo;
- // load data and broadcast
- if (rank == 0) {
- FILE *fp = fopen("input.txt", "rt");
- fscanf(fp, "%d%d", &seq_len, &n);
- // alloc
- w = new float[8 * n * n];
- x = new float[seq_len * n * n];
- h = new float[seq_len * n * n];
- c = new float[n * n];
- ft = new float[n * n];
- ot = new float[n * n];
- gt = new float[n * n];
- for (int i = 0; i < seq_len * n * n; ++i) {
- fscanf(fp, "%f", &x[i]);
- }
- for (int i = 0; i < 8 * n * n; ++i) {
- fscanf(fp, "%f", &w[i]);
- }
- fclose(fp);
- }
- MPI_Bcast(&n, 1, MPI_INT, 0, MPI_COMM_WORLD);
- MPI_Bcast(&seq_len, 1, MPI_INT, 0, MPI_COMM_WORLD);
- ifgo = new float[n * n];
- sub_w = new float[n * n * 2];
- sub_x = new float[n * n];
- sub_h = new float[n * n];
- MPI_Scatter(w, 2 * n * n, MPI_FLOAT, sub_w, 2 * n * n, MPI_FLOAT, 0,
- MPI_COMM_WORLD);
- // computing
- for (int i = 0; i < seq_len; ++i) {
- // broadcast x
- if (rank == 0) {
- memcpy(sub_x, x + i * n * n, n * n * sizeof(float));
- }
- MPI_Bcast(sub_x, n * n, MPI_FLOAT, 0, MPI_COMM_WORLD);
- // broadcast h
- if (i == 0) {
- for (int i = 0; i < n * n; ++i) {
- sub_h[i] = 0.0f;
- }
- } else {
- MPI_Bcast(sub_h, n * n, MPI_FLOAT, 0, MPI_COMM_WORLD);
- }
- // matrix multply
- MatrixMult(sub_w, sub_x, ifgo, n, true);
- MatrixMult(sub_w + n * n, sub_h, ifgo, n);
- // active
- if (rank == 2) {
- tanh(ifgo, ifgo, n * n);
- } else {
- sigmoid(ifgo, ifgo, n * n);
- }
- // gather result
- if (rank == 0) {
- MPI_Recv(ft, n * n, MPI_FLOAT, 1, 1, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
- MPI_Recv(gt, n * n, MPI_FLOAT, 2, 2, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
- MPI_Recv(ot, n * n, MPI_FLOAT, 3, 3, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
- } else if (rank == 1) {
- MPI_Send(ifgo, n * n, MPI_FLOAT, 0, 1, MPI_COMM_WORLD);
- } else if (rank == 2) {
- MPI_Send(ifgo, n * n, MPI_FLOAT, 0, 2, MPI_COMM_WORLD);
- } else if (rank == 3) {
- MPI_Send(ifgo, n * n, MPI_FLOAT, 0, 3, MPI_COMM_WORLD);
- }
- // compute ct ht
- if (rank == 0) {
- MatrixDot(ft, c, c, n);
- MatrixDot(ifgo, gt, gt, n);
- MatrixAdd(c, gt, c, n);
- tanh(c, sub_h, n * n);
- MatrixDot(ot, sub_h, sub_h, n);
- memcpy(h + i * n * n, sub_h, n * n * sizeof(float));
- }
- }
- MPI_Finalize();
- }
|