|
@@ -1,7 +1,150 @@
|
|
#include <mpi.h>
|
|
#include <mpi.h>
|
|
|
|
+#include <cassert>
|
|
|
|
+#include <cmath>
|
|
|
|
+#include <cstdio>
|
|
|
|
+#include <cstring>
|
|
#include <iostream>
|
|
#include <iostream>
|
|
|
|
+#include <vector>
|
|
|
|
|
|
-int main() {
|
|
|
|
|
|
/// Returns the exact integer square root of `n`.
///
/// `n` is expected to be a perfect square. If it is not (or if it is
/// negative), the function trips an assertion; in builds where assertions are
/// compiled out it returns -1 instead of running off the end of the function.
///
/// @param n value whose square root is requested
/// @return r such that r * r == n, or -1 when no such integer exists
inline int sqrti(int n) {
  if (n >= 0) {
    // Start from the floating-point estimate, then correct it with exact
    // integer arithmetic so rounding can never accept a non-square.
    long long r = static_cast<long long>(std::sqrt(static_cast<double>(n)));
    while (r > 0 && r * r > n) {
      --r;
    }
    while ((r + 1) * (r + 1) <= n) {
      ++r;
    }
    if (r * r == n) {
      return static_cast<int>(r);
    }
  }
  assert(false && "sqrti: argument is not a perfect square");
  return -1;
}
|
|
|
|
|
|
|
|
/// Non-owning view over a square `n x n` matrix of floats stored in one
/// contiguous buffer. The wrapper never allocates or frees the buffer.
///
/// Two addressing schemes are offered:
///  - `operator[]`: plain row-major access (`m[i][j]`).
///  - `split_map`: blocked access, where the matrix is cut into
///    `sqrt_q x sqrt_q` square tiles and each tile is stored contiguously,
///    tiles ordered row by row.
class MatWrap {
 private:
  float* data_;  // start of the wrapped buffer (not owned)
  int n_;        // matrix dimension (rows == columns)

 public:
  /// Creates an empty view (null buffer, dimension 0).
  MatWrap() : data_(nullptr), n_(0) {}

  /// Wraps `data`, which must hold at least `n * n` floats.
  MatWrap(float* data, int n) : data_(data), n_(n) {}

  /// Returns a pointer to the first element of row `n` (row-major layout).
  float* operator[](size_t n) const { return data_ + n * n_; }

  /// Returns the address of logical element (i, j) in the blocked layout.
  ///
  /// @param sqrt_q number of tiles per matrix side; must be positive and
  ///               divide the matrix dimension evenly
  /// @param i      logical row index
  /// @param j      logical column index
  float* split_map(int sqrt_q, int i, int j) const {
    // A non-positive or non-dividing tile count would yield a zero/negative
    // tile size and an out-of-bounds address, so reject it up front.
    assert(sqrt_q > 0 && n_ >= sqrt_q && n_ % sqrt_q == 0);
    const int sub_n = n_ / sqrt_q;
    const int tile_index = i / sub_n * sqrt_q + j / sub_n;
    const int offset_in_tile = i % sub_n * sub_n + j % sub_n;
    return data_ + tile_index * sub_n * sub_n + offset_in_tile;
  }

  /// Prints the matrix to stdout, one row per line.
  ///
  /// @param mapping when true, elements are fetched through `split_map`
  ///                (blocked layout); otherwise through `operator[]`
  /// @param sqrt_q  tiles per side, required (and must be valid for
  ///                `split_map`) when `mapping` is true
  void print(bool mapping = false, int sqrt_q = -1) const {
    for (int i = 0; i < n_; ++i) {
      for (int j = 0; j < n_; ++j) {
        if (mapping) {
          printf("%9.4f", *(split_map(sqrt_q, i, j)));
        } else {
          printf("%9.4f", operator[](i)[j]);
        }
      }
      printf("\n");
    }
  }

  friend void MatMultAdd(const MatWrap& a, const MatWrap& b, MatWrap& c);
};
|
|
|
|
+
|
|
|
|
+void MatMultAdd(const MatWrap& a, const MatWrap& b, MatWrap& c) {
|
|
|
|
+ assert(a.n_ == b.n_);
|
|
|
|
+ assert(a.n_ == c.n_);
|
|
|
|
+ int n = a.n_;
|
|
|
|
+ for (int i = 0; i < n; ++i) {
|
|
|
|
+ for (int j = 0; j < n; ++j) {
|
|
|
|
+ for (int k = 0; k < n; ++k) {
|
|
|
|
+ c[i][j] += a[i][k] * b[k][j];
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+int main(int argc, char** argv) {
|
|
|
|
+ int rank, size;
|
|
|
|
+ int n, sqrt_q, sub_n;
|
|
|
|
+ float *mat_a, *mat_b, *mat_c;
|
|
|
|
+ float* sub_mat_comm;
|
|
|
|
+ float *sub_mat_a, *sub_mat_b, *sub_mat_c;
|
|
|
|
+ MPI_Init(&argc, &argv);
|
|
|
|
+ MPI_Comm_rank(MPI_COMM_WORLD, &rank);
|
|
|
|
+ MPI_Comm_size(MPI_COMM_WORLD, &size);
|
|
|
|
+ sqrt_q = sqrti(size);
|
|
|
|
+ if (rank == 0) {
|
|
|
|
+ // load matrix data
|
|
|
|
+ FILE* fp = fopen("data.txt", "rt");
|
|
|
|
+ fscanf(fp, "%d", &n);
|
|
|
|
+ mat_a = new float[n * n];
|
|
|
|
+ mat_b = new float[n * n];
|
|
|
|
+ mat_c = new float[n * n];
|
|
|
|
+ MatWrap ma(mat_a, n);
|
|
|
|
+ MatWrap mb(mat_b, n);
|
|
|
|
+ for (int i = 0; i < n; ++i) {
|
|
|
|
+ for (int j = 0; j < n; ++j) {
|
|
|
|
+ fscanf(fp, "%f", ma.split_map(sqrt_q, i, j));
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ for (int i = 0; i < n; ++i) {
|
|
|
|
+ for (int j = 0; j < n; ++j) {
|
|
|
|
+ fscanf(fp, "%f", mb.split_map(sqrt_q, i, j));
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ // broadcast matrix size
|
|
|
|
+ MPI_Bcast(&n, 1, MPI_INT, 0, MPI_COMM_WORLD);
|
|
|
|
+ assert(n % sqrt_q == 0);
|
|
|
|
+ sub_n = n / sqrt_q;
|
|
|
|
+ // alloc space for submatrix
|
|
|
|
+ sub_mat_a = new float[sub_n * sub_n];
|
|
|
|
+ sub_mat_b = new float[sub_n * sub_n];
|
|
|
|
+ sub_mat_c = new float[sub_n * sub_n];
|
|
|
|
+ sub_mat_comm = new float[sub_n * sub_n];
|
|
|
|
+ MatWrap sub_comm(sub_mat_comm, sub_n);
|
|
|
|
+ MatWrap sub_b(sub_mat_b, sub_n);
|
|
|
|
+ MatWrap sub_c(sub_mat_c, sub_n);
|
|
|
|
+ // init sub_c
|
|
|
|
+ for (int i = 0; i < sub_n; ++i) {
|
|
|
|
+ for (int j = 0; j < sub_n; ++j) {
|
|
|
|
+ sub_c[i][j] = 0.0f;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ // broadcast sub matrix
|
|
|
|
+ MPI_Scatter(mat_a, sub_n * sub_n, MPI_FLOAT, sub_mat_a, sub_n * sub_n,
|
|
|
|
+ MPI_FLOAT, 0, MPI_COMM_WORLD);
|
|
|
|
+ MPI_Scatter(mat_b, sub_n * sub_n, MPI_FLOAT, sub_mat_b, sub_n * sub_n,
|
|
|
|
+ MPI_FLOAT, 0, MPI_COMM_WORLD);
|
|
|
|
+ // split comm in col and row
|
|
|
|
+ MPI_Comm col_world, row_world;
|
|
|
|
+ int col_rank = rank % sqrt_q;
|
|
|
|
+ int row_rank = rank / sqrt_q;
|
|
|
|
+ MPI_Comm_split(MPI_COMM_WORLD, col_rank, row_rank, &col_world);
|
|
|
|
+ MPI_Comm_split(MPI_COMM_WORLD, row_rank, col_rank, &row_world);
|
|
|
|
+ // compute
|
|
|
|
+ for (int i = 0; i < sqrt_q; ++i) {
|
|
|
|
+ // broadcast sub_a
|
|
|
|
+ int send_root = (row_rank + i) % sqrt_q;
|
|
|
|
+ if (col_rank == (row_rank + i) % sqrt_q) {
|
|
|
|
+ memcpy(sub_mat_comm, sub_mat_a, sub_n * sub_n * sizeof(float));
|
|
|
|
+ }
|
|
|
|
+ MPI_Bcast(sub_mat_comm, sub_n * sub_n, MPI_FLOAT, send_root, row_world);
|
|
|
|
+ // calculate sub mat gemm
|
|
|
|
+ MatMultAdd(sub_comm, sub_b, sub_c);
|
|
|
|
+ // swap sub_b
|
|
|
|
+ MPI_Sendrecv_replace(
|
|
|
|
+ sub_mat_b, sub_n * sub_n, MPI_FLOAT, (row_rank + sqrt_q - 1) % sqrt_q,
|
|
|
|
+ 1, (row_rank + 1) % sqrt_q, 1, col_world, MPI_STATUS_IGNORE);
|
|
|
|
+ }
|
|
|
|
+ // gather result
|
|
|
|
+ MPI_Gather(sub_mat_c, sub_n * sub_n, MPI_FLOAT, mat_c, sub_n * sub_n,
|
|
|
|
+ MPI_FLOAT, 0, MPI_COMM_WORLD);
|
|
|
|
+ // print result
|
|
|
|
+ if (rank == 0) {
|
|
|
|
+ MatWrap mc(mat_c, n);
|
|
|
|
+ for (int i = 0; i < n; ++i) {
|
|
|
|
+ for (int j = 0; j < n; ++j) {
|
|
|
|
+ printf("%9.5f", *mc.split_map(sqrt_q, i, j));
|
|
|
|
+ }
|
|
|
|
+ printf("\n");
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+ MPI_Finalize();
|
|
return 0;
|
|
return 0;
|
|
}
|
|
}
|