123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145 |
- #include <mpi.h>
- #include <omp.h>
- #include <cassert>
- #include <cmath>
- #include <cstdio>
- #include <cstring>
- #include <iostream>
- #include <vector>
- #define OMP_THREADS 2
/// Exact integer square root.
///
/// @param n  Value expected to be a non-negative perfect square.
/// @return   r such that r * r == n. If n is negative or not a perfect
///           square, the assert fires in debug builds; when asserts are
///           compiled out (NDEBUG) the function returns -1 instead of
///           running off the end without a return value.
inline int sqrti(int n) {
  assert(n >= 0 && "sqrti: argument must be non-negative");
  if (n < 0) {
    return -1;
  }
  // Compute in double (exact for every int) and round to the nearest
  // integer, then verify with pure integer arithmetic rather than a
  // floating-point tolerance.
  const int root =
      static_cast<int>(std::lround(std::sqrt(static_cast<double>(n))));
  const bool is_perfect_square = static_cast<long long>(root) * root == n;
  assert(is_perfect_square && "sqrti: argument must be a perfect square");
  return is_perfect_square ? root : -1;
}
/// Square-matrix view over an int buffer supplied by the caller.
///
/// The class only stores the pointer and the dimension; it never allocates
/// or frees the buffer. Two addressing schemes are offered:
///   - operator[]  : plain row-major addressing;
///   - split_map() : block-contiguous addressing, where the matrix is cut
///                   into sqrt_q x sqrt_q sub-blocks and each sub-block
///                   occupies one contiguous run of the buffer.
class MatWrap {
 private:
  int* data_ = nullptr;  // borrowed storage holding n_ * n_ ints
  int n_ = 0;            // matrix dimension (rows == columns)

 public:
  /// Creates an empty view (null buffer, dimension 0).
  MatWrap() = default;

  /// Wraps `data`, interpreted as an n x n matrix.
  MatWrap(int* data, int n) : data_(data), n_(n) {}

  /// Returns a pointer to the start of row `n` in row-major layout.
  int* operator[](size_t n) const { return data_ + n * n_; }

  /// Returns the address of logical element (i, j) in the block-contiguous
  /// layout.
  ///
  /// @param sqrt_q  Number of blocks per side; must be positive and divide
  ///                the matrix dimension, otherwise the computed offset
  ///                would fall outside the buffer.
  /// @param i       Row index in the full matrix.
  /// @param j       Column index in the full matrix.
  int* split_map(int sqrt_q, int i, int j) const {
    assert(sqrt_q > 0 && n_ % sqrt_q == 0);
    const int sub_n = n_ / sqrt_q;
    // Offset = (index of the block containing (i, j)) * block size
    //          + row-major offset of (i, j) inside that block.
    return data_ + (i / sub_n * sqrt_q + j / sub_n) * sub_n * sub_n +
           i % sub_n * sub_n + j % sub_n;
  }

  /// Prints the matrix to stdout, one row per line, 5 columns per entry.
  ///
  /// @param mapping  When true, elements are read through split_map();
  ///                 otherwise through row-major operator[].
  /// @param sqrt_q   Blocks per side; must be a valid split_map() argument
  ///                 whenever `mapping` is true (the default -1 is not).
  void print(bool mapping = false, int sqrt_q = -1) const {
    for (int i = 0; i < n_; ++i) {
      for (int j = 0; j < n_; ++j) {
        if (mapping) {
          printf("%5d", *(split_map(sqrt_q, i, j)));
        } else {
          printf("%5d", operator[](i)[j]);
        }
      }
      printf("\n");
    }
  }

  friend void MatMultAdd(const MatWrap& a, const MatWrap& b, MatWrap& c);
};
- void MatMultAdd(const MatWrap& a, const MatWrap& b, MatWrap& c) {
- assert(a.n_ == b.n_);
- assert(a.n_ == c.n_);
- int n = a.n_;
- #pragma omp parallel for num_threads(OMP_THREADS)
- for (int i = 0; i < n; ++i) {
- for (int j = 0; j < n; ++j) {
- for (int k = 0; k < n; ++k) {
- c[i][j] = static_cast<bool>(c[i][j] + a[i][k] * b[k][j]);
- }
- }
- }
- }
- int main(int argc, char** argv) {
- int rank, size;
- int n, sqrt_q, sub_n;
- int* mat_a;
- int* sub_mat_comm;
- int *sub_mat_a, *sub_mat_b, *sub_mat_c;
- MPI_Init(&argc, &argv);
- MPI_Comm_rank(MPI_COMM_WORLD, &rank);
- MPI_Comm_size(MPI_COMM_WORLD, &size);
- sqrt_q = sqrti(size);
- if (rank == 0) {
- // load matrix data
- FILE* fp = fopen("data.txt", "rt");
- fscanf(fp, "%d", &n);
- mat_a = new int[n * n];
- MatWrap ma(mat_a, n);
- for (int i = 0; i < n; ++i) {
- for (int j = 0; j < n; ++j) {
- fscanf(fp, "%d", ma.split_map(sqrt_q, i, j));
- }
- *ma.split_map(sqrt_q, i, i) = 1;
- }
- }
- // broadcast matrix size
- MPI_Bcast(&n, 1, MPI_INT, 0, MPI_COMM_WORLD);
- assert(n % sqrt_q == 0);
- sub_n = n / sqrt_q;
- // alloc space for submatrix
- sub_mat_a = new int[sub_n * sub_n];
- sub_mat_b = new int[sub_n * sub_n];
- sub_mat_c = new int[sub_n * sub_n];
- sub_mat_comm = new int[sub_n * sub_n];
- MatWrap sub_comm(sub_mat_comm, sub_n);
- MatWrap sub_b(sub_mat_b, sub_n);
- MatWrap sub_c(sub_mat_c, sub_n);
- // init sub_c
- for (int k = 0; k <= std::log2f(n); ++k) {
- for (int i = 0; i < sub_n; ++i) {
- for (int j = 0; j < sub_n; ++j) {
- sub_c[i][j] = 0;
- }
- }
- // broadcast sub matrix
- MPI_Scatter(mat_a, sub_n * sub_n, MPI_INT, sub_mat_a, sub_n * sub_n,
- MPI_INT, 0, MPI_COMM_WORLD);
- MPI_Scatter(mat_a, sub_n * sub_n, MPI_INT, sub_mat_b, sub_n * sub_n,
- MPI_INT, 0, MPI_COMM_WORLD);
- // split comm in col and row
- MPI_Comm col_world, row_world;
- int col_rank = rank % sqrt_q;
- int row_rank = rank / sqrt_q;
- MPI_Comm_split(MPI_COMM_WORLD, col_rank, row_rank, &col_world);
- MPI_Comm_split(MPI_COMM_WORLD, row_rank, col_rank, &row_world);
- // compute
- for (int i = 0; i < sqrt_q; ++i) {
- // broadcast sub_a
- int send_root = (row_rank + i) % sqrt_q;
- if (col_rank == (row_rank + i) % sqrt_q) {
- memcpy(sub_mat_comm, sub_mat_a, sub_n * sub_n * sizeof(int));
- }
- MPI_Bcast(sub_mat_comm, sub_n * sub_n, MPI_INT, send_root, row_world);
- // calculate sub mat gemm
- MatMultAdd(sub_comm, sub_b, sub_c);
- // swap sub_b
- MPI_Sendrecv_replace(
- sub_mat_b, sub_n * sub_n, MPI_INT, (row_rank + sqrt_q - 1) % sqrt_q,
- 1, (row_rank + 1) % sqrt_q, 1, col_world, MPI_STATUS_IGNORE);
- }
- // gather result
- MPI_Gather(sub_mat_c, sub_n * sub_n, MPI_INT, mat_a, sub_n * sub_n, MPI_INT,
- 0, MPI_COMM_WORLD);
- // print result
- if (rank == 0) {
- MatWrap mc(mat_a, n);
- printf("loop:%d\n", k);
- mc.print(true, sqrt_q);
- }
- }
- MPI_Finalize();
- return 0;
- }
|