pp11 5 年之前
父節點
當前提交
bbc9195c6a
共有 3 個文件被更改，包括 169 次插入和 1 次刪除
  1. 12 0
      b/data.py
  2. 13 0
      b/data.txt
  3. 144 1
      b/fox.cc

+ 12 - 0
b/data.py

@@ -0,0 +1,12 @@
+#!/usr/bin/env python3
+import numpy as np
+import sys
+n = int(sys.argv[1])
+f = open('data.txt', 'wt')
+f.write('%d\n'%n)
+a = np.random.random(size=(n, n))
+b = np.random.random(size=(n, n))
+np.savetxt(f, a, fmt='%.5f')
+np.savetxt(f, b, fmt='%.5f')
+np.savetxt(f, np.mat(a) * np.mat(b), fmt='%.5f')
+f.close()

+ 13 - 0
b/data.txt

@@ -0,0 +1,13 @@
+4
+0.61314 0.90078 0.31830 0.55924
+0.97482 0.69245 0.83941 0.05200
+0.20574 0.91299 0.30536 0.01616
+0.22978 0.48120 0.83645 0.15133
+0.11541 0.91907 0.36166 0.44359
+0.50170 0.84402 0.03947 0.48543
+0.10534 0.80784 0.24040 0.17560
+0.26392 0.96701 0.49676 0.90167
+0.70381 2.12171 0.61163 1.26940
+0.56205 2.20876 0.60751 0.96285
+0.51823 1.22198 0.19188 0.60266
+0.39599 1.43938 0.37835 0.61886

+ 144 - 1
b/fox.cc

@@ -1,7 +1,150 @@
 #include <mpi.h>
+#include <cassert>
+#include <cmath>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
 #include <iostream>
+#include <vector>
 
-int main() {
+// Returns the exact integer square root of n.
+//
+// n must be a non-negative perfect square. Any other value is reported on
+// stderr and the process is aborted, in release builds as well, so the
+// function never returns a wrong root or falls off its end.
+inline int sqrti(int n) {
+  if (n >= 0) {
+    // The double-precision estimate is only a starting point; the two loops
+    // below settle on floor(sqrt(n)) using exact integer arithmetic.
+    int r = static_cast<int>(std::sqrt(static_cast<double>(n)));
+    while (static_cast<long long>(r) * r > n) {
+      --r;
+    }
+    while (static_cast<long long>(r + 1) * (r + 1) <= n) {
+      ++r;
+    }
+    if (static_cast<long long>(r) * r == n) {
+      return r;
+    }
+  }
+  std::fprintf(stderr, "sqrti: %d is not a perfect square\n", n);
+  std::abort();
+}
 
+// Non-owning view of an n x n float matrix held in a caller-supplied buffer.
+// The buffer can be addressed either as a plain row-major matrix
+// (operator[]) or as a grid of square blocks that are each stored
+// contiguously (split_map).
+class MatWrap {
+ private:
+  float* data_ = nullptr;  // first element of the buffer; never freed here
+  int n_ = 0;              // number of rows and of columns
+
+ public:
+  // Creates an empty view (null buffer, size 0).
+  MatWrap() {}
+  // Wraps `data`, which must hold at least n * n floats.
+  MatWrap(float* data, int n) : data_(data), n_(n) {}
+  // Returns a pointer to the start of row `n` in the row-major layout.
+  float* operator[](size_t n) const { return data_ + n * n_; }
+  // Returns the address of logical element (i, j) in the blocked layout:
+  // the matrix is cut into sqrt_q x sqrt_q blocks, the blocks are stored one
+  // after another in row-major block order, and each block is row-major
+  // inside. sqrt_q must be positive and must divide the matrix size.
+  float* split_map(int sqrt_q, int i, int j) const {
+    assert(sqrt_q > 0);
+    const int sub_n = n_ / sqrt_q;
+    // A zero block size would make every division below divide by zero.
+    assert(sub_n > 0 && sub_n * sqrt_q == n_);
+    const int block = i / sub_n * sqrt_q + j / sub_n;
+    const int offset_in_block = i % sub_n * sub_n + j % sub_n;
+    return data_ + block * sub_n * sub_n + offset_in_block;
+  }
+  // Prints the matrix to stdout, one row per line. With mapping == true the
+  // elements are fetched through split_map(sqrt_q, ...), so a valid sqrt_q
+  // must be supplied; otherwise the buffer is read as plain row-major.
+  void print(bool mapping = false, int sqrt_q = -1) const {
+    for (int i = 0; i < n_; ++i) {
+      for (int j = 0; j < n_; ++j) {
+        const float* cell = mapping ? split_map(sqrt_q, i, j) : operator[](i) + j;
+        printf("%9.4f", *cell);
+      }
+      printf("\n");
+    }
+  }
+  friend void MatMultAdd(const MatWrap& a, const MatWrap& b, MatWrap& c);
+};
+
+// Accumulates a matrix product into c: c += a * b.
+//
+// All three views must have the same size and are read in their row-major
+// layout. c is not cleared first, and it must not share storage with a or b.
+void MatMultAdd(const MatWrap& a, const MatWrap& b, MatWrap& c) {
+  assert(a.n_ == b.n_);
+  assert(a.n_ == c.n_);
+  const int n = a.n_;
+  // i-k-j order: the inner loop walks one row of b and one row of c
+  // contiguously instead of striding down a column of b. Each c[i][j] still
+  // receives its terms in ascending k, so the rounding is unchanged.
+  for (int i = 0; i < n; ++i) {
+    const float* a_row = a[i];
+    float* c_row = c[i];
+    for (int k = 0; k < n; ++k) {
+      const float a_ik = a_row[k];
+      const float* b_row = b[k];
+      for (int j = 0; j < n; ++j) {
+        c_row[j] += a_ik * b_row[j];
+      }
+    }
+  }
+}
+
+// Reads n * n whitespace-separated floats from fp, taken as an n x n matrix
+// in row-major order, and stores them in dst using the blocked layout of
+// MatWrap::split_map, so that each block of the matrix is contiguous in dst.
+// Returns false as soon as a value is missing or cannot be parsed.
+static bool ReadBlockedMatrix(FILE* fp, float* dst, int n, int sqrt_q) {
+  MatWrap mat(dst, n);
+  for (int i = 0; i < n; ++i) {
+    for (int j = 0; j < n; ++j) {
+      if (fscanf(fp, "%f", mat.split_map(sqrt_q, i, j)) != 1) {
+        return false;
+      }
+    }
+  }
+  return true;
+}
+
+// Multiplies the two n x n matrices stored in data.txt on a square grid of
+// MPI ranks and prints the product on rank 0.
+//
+// The number of ranks must be a perfect square and n a multiple of its
+// root. Each rank owns one block of A, B and C. In every step one block of A
+// is broadcast along each grid row, multiplied into C, and the blocks of B
+// are rolled one position up their grid column (Fox's algorithm, going by
+// the file name). Returns 0 on success; a missing or malformed data.txt
+// aborts the whole MPI job.
+int main(int argc, char** argv) {
+  MPI_Init(&argc, &argv);
+  int rank = 0;
+  int size = 0;
+  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+  MPI_Comm_size(MPI_COMM_WORLD, &size);
+  const int sqrt_q = sqrti(size);
+
+  int n = 0;
+  // The full matrices exist on rank 0 only; on other ranks these stay empty
+  // and their buffers are ignored by the scatter/gather calls.
+  std::vector<float> mat_a, mat_b, mat_c;
+  if (rank == 0) {
+    // load matrix data
+    FILE* fp = fopen("data.txt", "rt");
+    if (fp == nullptr) {
+      fprintf(stderr, "fox: cannot open data.txt\n");
+      MPI_Abort(MPI_COMM_WORLD, 1);
+      return 1;
+    }
+    // Check the size before split_map uses n / sqrt_q as a divisor.
+    bool ok = fscanf(fp, "%d", &n) == 1 && n > 0 && n % sqrt_q == 0;
+    if (ok) {
+      const size_t total = static_cast<size_t>(n) * n;
+      mat_a.resize(total);
+      mat_b.resize(total);
+      mat_c.resize(total);
+      ok = ReadBlockedMatrix(fp, mat_a.data(), n, sqrt_q) &&
+           ReadBlockedMatrix(fp, mat_b.data(), n, sqrt_q);
+    }
+    fclose(fp);
+    if (!ok) {
+      fprintf(stderr,
+              "fox: data.txt is malformed or its matrix size is not a "
+              "positive multiple of %d\n",
+              sqrt_q);
+      MPI_Abort(MPI_COMM_WORLD, 1);
+      return 1;
+    }
+  }
+  // broadcast matrix size
+  MPI_Bcast(&n, 1, MPI_INT, 0, MPI_COMM_WORLD);
+  assert(n % sqrt_q == 0);
+  const int sub_n = n / sqrt_q;
+  const int sub_len = sub_n * sub_n;
+  // Per-rank blocks. sub_mat_c starts at zero because MatMultAdd accumulates.
+  std::vector<float> sub_mat_a(sub_len);
+  std::vector<float> sub_mat_b(sub_len);
+  std::vector<float> sub_mat_c(sub_len, 0.0f);
+  std::vector<float> sub_mat_comm(sub_len);
+  MatWrap sub_comm(sub_mat_comm.data(), sub_n);
+  MatWrap sub_b(sub_mat_b.data(), sub_n);
+  MatWrap sub_c(sub_mat_c.data(), sub_n);
+  // The blocked layout makes consecutive chunks of sub_len floats whole
+  // blocks, so a plain scatter hands rank r block r of A and of B.
+  MPI_Scatter(mat_a.data(), sub_len, MPI_FLOAT, sub_mat_a.data(), sub_len,
+              MPI_FLOAT, 0, MPI_COMM_WORLD);
+  MPI_Scatter(mat_b.data(), sub_len, MPI_FLOAT, sub_mat_b.data(), sub_len,
+              MPI_FLOAT, 0, MPI_COMM_WORLD);
+  // split comm in col and row
+  MPI_Comm col_world, row_world;
+  const int col_rank = rank % sqrt_q;
+  const int row_rank = rank / sqrt_q;
+  MPI_Comm_split(MPI_COMM_WORLD, col_rank, row_rank, &col_world);
+  MPI_Comm_split(MPI_COMM_WORLD, row_rank, col_rank, &row_world);
+  // compute
+  for (int i = 0; i < sqrt_q; ++i) {
+    // In step i the rank in grid column (row_rank + i) % sqrt_q shares its
+    // block of A with the rest of its row.
+    const int send_root = (row_rank + i) % sqrt_q;
+    if (col_rank == send_root) {
+      memcpy(sub_mat_comm.data(), sub_mat_a.data(), sub_len * sizeof(float));
+    }
+    MPI_Bcast(sub_mat_comm.data(), sub_len, MPI_FLOAT, send_root, row_world);
+    // calculate sub mat gemm
+    MatMultAdd(sub_comm, sub_b, sub_c);
+    // Roll B one step up the column: send to the row above, receive from
+    // the row below.
+    MPI_Sendrecv_replace(
+        sub_mat_b.data(), sub_len, MPI_FLOAT, (row_rank + sqrt_q - 1) % sqrt_q,
+        1, (row_rank + 1) % sqrt_q, 1, col_world, MPI_STATUS_IGNORE);
+  }
+  // gather result
+  MPI_Gather(sub_mat_c.data(), sub_len, MPI_FLOAT, mat_c.data(), sub_len,
+             MPI_FLOAT, 0, MPI_COMM_WORLD);
+  // print result
+  if (rank == 0) {
+    MatWrap mc(mat_c.data(), n);
+    for (int i = 0; i < n; ++i) {
+      for (int j = 0; j < n; ++j) {
+        printf("%9.5f", *mc.split_map(sqrt_q, i, j));
+      }
+      printf("\n");
+    }
+  }
+  MPI_Comm_free(&col_world);
+  MPI_Comm_free(&row_world);
+  MPI_Finalize();
   return 0;
 }