// 2_12.cc: fully connected 8-17-9-5 sigmoid network trained by per-example gradient descent; prints train/validation accuracy each epoch.

#include <cmath>
#include <cstddef>
#include <iostream>
#include <random>
#include <vector>
  5. const std::vector<std::vector<double>> train_data =
  6. {{0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0}, {0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0}, {1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0}, {0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0}, {1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 
0.0, 1.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0}, {0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0}, {0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0}, {0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0}, 
{1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0}, {0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0}, {0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0}, {1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0}, {0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 0.0, 1.0, 0.0, 1.0, 
0.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0}, {1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 
1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0}, {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0}, {1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0}};
  7. const std::vector<std::vector<double>> train_label =
  8. {{0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 1.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 0.0}, {0.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 
0.0}, {1.0, 0.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 0.0, 0.0, 1.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 0.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 1.0}, {0.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 
0.0}, {1.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}};
  9. const std::vector<std::vector<double>> val_data =
  10. {{1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0}, {1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0}, {1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0}, {0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0}, {0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0}, {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 
0.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0}};
  11. const std::vector<std::vector<double>> val_label =
  12. {{1.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 0.0, 1.0, 0.0, 0.0}, {1.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 0.0, 1.0, 0.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 0.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 0.0, 0.0}};
  13. const double learning_rate = 0.5f;
  14. const int training_epochs = 500;
  15. const int n_hidden_1 = 17;
  16. const int n_hidden_2 = 9;
  17. const int n_input = 8;
  18. const int n_output = 5;
  19. // weights
  20. std::vector<std::vector<double>> w1(n_hidden_1, std::vector<double>(n_input));
  21. std::vector<std::vector<double>> w2(n_hidden_2, std::vector<double>(n_hidden_1));
  22. std::vector<std::vector<double>> w3(n_output, std::vector<double>(n_hidden_2));
  23. // biases
  24. std::vector<double> b1(n_hidden_1);
  25. std::vector<double> b2(n_hidden_2);
  26. std::vector<double> b3(n_output);
  27. inline double inner_prod(const std::vector<double> &lhs, const std::vector<double> &rhs)
  28. {
  29. double ret = 0.0f;
  30. for (int i = 0; i < lhs.size(); ++i)
  31. {
  32. ret += lhs[i] * rhs[i];
  33. }
  34. return ret;
  35. }
  36. inline double sigmoid(double x)
  37. {
  38. return 1.0f / (1.0f + std::exp(-x));
  39. }
  40. std::vector<double> mlp_forward(const std::vector<double> &input)
  41. {
  42. std::vector<double> z1(n_hidden_1);
  43. for (int i = 0; i < n_hidden_1; ++i)
  44. {
  45. z1[i] = sigmoid(inner_prod(input, w1[i]) + b1[i]);
  46. }
  47. std::vector<double> z2(n_hidden_2);
  48. for (int i = 0; i < n_hidden_2; ++i)
  49. {
  50. z2[i] = sigmoid(inner_prod(z1, w2[i]) + b2[i]);
  51. }
  52. std::vector<double> out(n_output);
  53. for (int i = 0; i < n_output; ++i)
  54. {
  55. out[i] = sigmoid(inner_prod(z2, w3[i]) + b3[i]);
  56. }
  57. return out;
  58. }
  59. double loss(const std::vector<double> &h,
  60. const std::vector<double> &y)
  61. {
  62. double ret = 0.0f;
  63. for (int i = 0; i < h.size(); ++i)
  64. {
  65. ret += (h[i] - y[i]) * (h[i] - y[i]);
  66. }
  67. return ret * 0.5f;
  68. }
  69. double accuracy(const std::vector<double> &h,
  70. const std::vector<double> &y)
  71. {
  72. double ret;
  73. int count = 0;
  74. for (int i = 0; i < h.size(); ++i)
  75. {
  76. double t = h[i] < 0.5f ? 0.0f : 1.0f;
  77. if (abs(t - y[i]) < 1e-9)
  78. {
  79. count++;
  80. }
  81. }
  82. return count / (double)h.size();
  83. }
  84. void mlp_backward(const std::vector<double> &input,
  85. const std::vector<double> &label)
  86. {
  87. // weights
  88. std::vector<std::vector<double>> d_w1(n_hidden_1, std::vector<double>(n_input));
  89. std::vector<std::vector<double>> d_w2(n_hidden_2, std::vector<double>(n_hidden_1));
  90. std::vector<std::vector<double>> d_w3(n_output, std::vector<double>(n_hidden_2));
  91. // biases
  92. std::vector<double> d_b1(n_hidden_1);
  93. std::vector<double> d_b2(n_hidden_2);
  94. std::vector<double> d_b3(n_output);
  95. for (int i = 0; i < n_hidden_1; ++i)
  96. {
  97. double l0 = loss(mlp_forward(input), label);
  98. b1[i] += 1e-9;
  99. double l = loss(mlp_forward(input), label);
  100. b1[i] -= 1e-9;
  101. d_b1[i] = (l - l0) / 1e-9 * learning_rate;
  102. for (int j = 0; j < n_input; ++j)
  103. {
  104. w1[i][j] += 1e-9;
  105. double l = loss(mlp_forward(input), label);
  106. w1[i][j] -= 1e-9;
  107. d_w1[i][j] = (l - l0) / 1e-9 * learning_rate;
  108. }
  109. }
  110. for (int i = 0; i < n_hidden_2; ++i)
  111. {
  112. double l0 = loss(mlp_forward(input), label);
  113. b2[i] += 1e-9;
  114. double l = loss(mlp_forward(input), label);
  115. b2[i] -= 1e-9;
  116. d_b2[i] = (l - l0) / 1e-9 * learning_rate;
  117. for (int j = 0; j < n_hidden_1; ++j)
  118. {
  119. w2[i][j] += 1e-9;
  120. double l = loss(mlp_forward(input), label);
  121. w2[i][j] -= 1e-9;
  122. d_w2[i][j] = (l - l0) / 1e-9 * learning_rate;
  123. }
  124. }
  125. for (int i = 0; i < n_output; ++i)
  126. {
  127. double l0 = loss(mlp_forward(input), label);
  128. b3[i] += 1e-9;
  129. double l = loss(mlp_forward(input), label);
  130. b3[i] -= 1e-9;
  131. d_b3[i] = (l - l0) / 1e-9 * learning_rate;
  132. for (int j = 0; j < n_hidden_2; ++j)
  133. {
  134. w3[i][j] += 1e-9;
  135. double l = loss(mlp_forward(input), label);
  136. w3[i][j] -= 1e-9;
  137. d_w3[i][j] = (l - l0) / 1e-9 * learning_rate;
  138. }
  139. }
  140. for (int i = 0; i < n_hidden_1; ++i)
  141. {
  142. b1[i] -= d_b1[i];
  143. for (int j = 0; j < n_input; ++j)
  144. {
  145. w1[i][j] -= d_w1[i][j];
  146. }
  147. }
  148. for (int i = 0; i < n_hidden_2; ++i)
  149. {
  150. b2[i] -= d_b2[i];
  151. for (int j = 0; j < n_hidden_1; ++j)
  152. {
  153. w2[i][j] -= d_w2[i][j];
  154. }
  155. }
  156. for (int i = 0; i < n_output; ++i)
  157. {
  158. b3[i] -= d_b3[i];
  159. for (int j = 0; j < n_hidden_2; ++j)
  160. {
  161. w3[i][j] -= d_w3[i][j];
  162. }
  163. }
  164. }
  165. int main()
  166. {
  167. // init
  168. std::random_device rd;
  169. std::mt19937 gen(rd());
  170. std::uniform_real_distribution<> dis(0.0, 1.0);
  171. for (int i = 0; i < n_hidden_1; ++i)
  172. {
  173. b1[i] = dis(gen);
  174. for (int j = 0; j < n_input; ++j)
  175. {
  176. w1[i][j] = dis(gen);
  177. }
  178. }
  179. for (int i = 0; i < n_hidden_2; ++i)
  180. {
  181. b2[i] = dis(gen);
  182. for (int j = 0; j < n_hidden_1; ++j)
  183. {
  184. w2[i][j] = dis(gen);
  185. }
  186. }
  187. for (int i = 0; i < n_output; ++i)
  188. {
  189. b3[i] = dis(gen);
  190. for (int j = 0; j < n_hidden_2; ++j)
  191. {
  192. w3[i][j] = dis(gen);
  193. }
  194. }
  195. // train
  196. int n = train_label.size();
  197. int m = val_label.size();
  198. for (int i = 0; i < training_epochs; ++i)
  199. {
  200. for (int j = 0; j < n; ++j)
  201. {
  202. mlp_backward(train_data[j], train_label[j]);
  203. }
  204. double acc = 0.0f;
  205. for (int j = 0; j < n; ++j)
  206. {
  207. acc += accuracy(mlp_forward(train_data[j]), train_label[j]);
  208. }
  209. std::cout << "Train Accuracy:" << acc / n << " ";
  210. acc = 0.0f;
  211. for (int j = 0; j < m; ++j)
  212. {
  213. acc += accuracy(mlp_forward(val_data[j]), val_label[j]);
  214. }
  215. std::cout << "Val Accuracy:" << acc / m << std::endl;
  216. }
  217. }