// Distributed block matrix multiplication (Fox's algorithm) on a
// sqrt(q) x sqrt(q) MPI process grid, with an OpenMP-parallel local GEMM.
// The input matrix (diagonal forced to 1) is repeatedly squared,
// floor(log2(n)) + 1 times in total, the usual repeated-squaring scheme
// for reachability-style computations.

// NOTE: the original header names were lost; the includes below are
// reconstructed from what the code actually uses.
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <cassert>
#include <algorithm>
#include <mpi.h>
#include <omp.h>

#define OMP_THREADS 2

// Integer square root that asserts its argument is a perfect square
// (the process count must form a square grid).
inline int sqrti(int n) {
  float x = std::sqrt(static_cast<float>(n));
  int r = static_cast<int>(x);
  assert(std::abs(r - x) < 1e-9);
  return r;
}

// Thin wrapper that views a flat int buffer as an n x n row-major matrix.
class MatWrap {
 private:
  int* data_;
  int n_;

 public:
  MatWrap() : data_(nullptr), n_(0) {}
  MatWrap(int* data, int n) : data_(data), n_(n) {}

  int* operator[](size_t n) const { return data_ + n * n_; }

  // Maps global element (i, j) to its address in the block layout used by
  // MPI_Scatter/MPI_Gather: the matrix is stored as sqrt_q * sqrt_q
  // contiguous sub_n x sub_n blocks, one block per process.
  int* split_map(int sqrt_q, int i, int j) const {
    int sub_n = n_ / sqrt_q;
    return data_ + (i / sub_n * sqrt_q + j / sub_n) * sub_n * sub_n +
           i % sub_n * sub_n + j % sub_n;
  }

  void print(bool mapping = false, int sqrt_q = -1) const {
    for (int i = 0; i < n_; ++i) {
      for (int j = 0; j < n_; ++j) {
        if (mapping) {
          printf("%5d", *split_map(sqrt_q, i, j));
        } else {
          printf("%5d", operator[](i)[j]);
        }
      }
      printf("\n");
    }
  }

  friend void MatMultAdd(const MatWrap& a, const MatWrap& b, MatWrap& c);
};

// c += a * b for equally sized square matrices; the outer loop is
// parallelized with OpenMP.
void MatMultAdd(const MatWrap& a, const MatWrap& b, MatWrap& c) {
  assert(a.n_ == b.n_);
  assert(a.n_ == c.n_);
  int n = a.n_;
#pragma omp parallel for num_threads(OMP_THREADS)
  for (int i = 0; i < n; ++i) {
    for (int j = 0; j < n; ++j) {
      for (int k = 0; k < n; ++k) {
        c[i][j] += a[i][k] * b[k][j];
      }
    }
  }
}

int main(int argc, char** argv) {
  int rank, size;
  int n, sqrt_q, sub_n;
  int* mat_a = nullptr;
  int* sub_mat_comm;
  int *sub_mat_a, *sub_mat_b, *sub_mat_c;

  MPI_Init(&argc, &argv);
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);
  sqrt_q = sqrti(size);

  if (rank == 0) {
    // Load the matrix into block layout and set the diagonal to 1.
    FILE* fp = fopen("data.txt", "rt");
    assert(fp != nullptr);
    fscanf(fp, "%d", &n);
    mat_a = new int[n * n];
    MatWrap ma(mat_a, n);
    for (int i = 0; i < n; ++i) {
      for (int j = 0; j < n; ++j) {
        fscanf(fp, "%d", ma.split_map(sqrt_q, i, j));
      }
      *ma.split_map(sqrt_q, i, i) = 1;
    }
    fclose(fp);
  }

  // Broadcast the matrix size; it must split evenly across the grid.
  MPI_Bcast(&n, 1, MPI_INT, 0, MPI_COMM_WORLD);
  assert(n % sqrt_q == 0);
  sub_n = n / sqrt_q;

  // Allocate the local blocks.
  sub_mat_a = new int[sub_n * sub_n];
  sub_mat_b = new int[sub_n * sub_n];
  sub_mat_c = new int[sub_n * sub_n];
  sub_mat_comm = new int[sub_n * sub_n];
  MatWrap sub_comm(sub_mat_comm, sub_n);
  MatWrap sub_b(sub_mat_b, sub_n);
  MatWrap sub_c(sub_mat_c, sub_n);

  // Repeated squaring: iteration k turns the gathered matrix into its square.
  for (int k = 0; k <= std::log2f(n); ++k) {
    // Reset the local result block.
    for (int i = 0; i < sub_n; ++i) {
      for (int j = 0; j < sub_n; ++j) {
        sub_c[i][j] = 0;
      }
    }

    // Distribute one block of the current matrix to each process; A and B
    // start out identical because we are squaring.
    MPI_Scatter(mat_a, sub_n * sub_n, MPI_INT, sub_mat_a, sub_n * sub_n,
                MPI_INT, 0, MPI_COMM_WORLD);
    MPI_Scatter(mat_a, sub_n * sub_n, MPI_INT, sub_mat_b, sub_n * sub_n,
                MPI_INT, 0, MPI_COMM_WORLD);

    // Split the world communicator into per-column and per-row communicators.
    MPI_Comm col_world, row_world;
    int col_rank = rank % sqrt_q;
    int row_rank = rank / sqrt_q;
    MPI_Comm_split(MPI_COMM_WORLD, col_rank, row_rank, &col_world);
    MPI_Comm_split(MPI_COMM_WORLD, row_rank, col_rank, &row_world);

    // Fox's algorithm: sqrt_q rounds of broadcast, multiply, shift.
    for (int i = 0; i < sqrt_q; ++i) {
      // Broadcast the selected A block along the process row.
      int send_root = (row_rank + i) % sqrt_q;
      if (col_rank == send_root) {
        memcpy(sub_mat_comm, sub_mat_a, sub_n * sub_n * sizeof(int));
      }
      MPI_Bcast(sub_mat_comm, sub_n * sub_n, MPI_INT, send_root, row_world);

      // Accumulate the local block product.
      MatMultAdd(sub_comm, sub_b, sub_c);

      // Shift the B block one step up within its process column.
      MPI_Sendrecv_replace(sub_mat_b, sub_n * sub_n, MPI_INT,
                           (row_rank + sqrt_q - 1) % sqrt_q, 1,
                           (row_rank + 1) % sqrt_q, 1, col_world,
                           MPI_STATUS_IGNORE);
    }

    // Collect the result blocks back into mat_a on the root.
    MPI_Gather(sub_mat_c, sub_n * sub_n, MPI_INT, mat_a, sub_n * sub_n,
               MPI_INT, 0, MPI_COMM_WORLD);

    // Free the per-iteration communicators before the next round.
    MPI_Comm_free(&col_world);
    MPI_Comm_free(&row_world);

    // Print the intermediate result on the root.
    if (rank == 0) {
      MatWrap mc(mat_a, n);
      printf("loop:%d\n", k);
      mc.print(true, sqrt_q);
    }
  }

  delete[] sub_mat_a;
  delete[] sub_mat_b;
  delete[] sub_mat_c;
  delete[] sub_mat_comm;
  delete[] mat_a;

  MPI_Finalize();
  return 0;
}
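
// Usage sketch (the commands and file name below are assumptions, not from
// the original source). The code itself requires "data.txt" to start with n
// followed by n*n integers, the process count q to be a perfect square, and
// sqrt(q) to divide n. With the common MPI compiler wrappers a build and a
// run on a 2 x 2 process grid might look like:
//
//   mpicxx -O2 -fopenmp fox.cpp -o fox
//   mpirun -np 4 ./fox
//
// The source file name "fox.cpp" is illustrative.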