2_12.cc 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
  1. #include <cmath>
  2. #include <iostream>
  3. #include <random>
  4. #include <vector>
  // Training inputs: 204 samples, one row per sample, each row holding
  // n_input (8) binary features encoded as 0.0 / 1.0. Row k is paired with
  // train_label[k] in main(), which iterates train_label.size() and indexes
  // this table, so both tables must keep the same length and order.
  5. const std::vector<std::vector<double>> train_data = {
  6. {0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0},
  7. {1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0},
  8. {1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0},
  9. {1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0},
  10. {1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0},
  11. {1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0},
  12. {0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0},
  13. {1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0},
  14. {0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0},
  15. {0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0},
  16. {1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0},
  17. {1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0},
  18. {1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0},
  19. {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0},
  20. {0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0},
  21. {1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0},
  22. {0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0},
  23. {1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0},
  24. {0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0},
  25. {1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0},
  26. {1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0},
  27. {1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0},
  28. {1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0},
  29. {0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0},
  30. {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0},
  31. {1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0},
  32. {1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0},
  33. {1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0},
  34. {1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0},
  35. {0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0},
  36. {0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0},
  37. {1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0},
  38. {0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0},
  39. {0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0},
  40. {1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0},
  41. {1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0},
  42. {1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0},
  43. {0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0},
  44. {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0},
  45. {0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0},
  46. {1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0},
  47. {0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0},
  48. {1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0},
  49. {0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0},
  50. {1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0},
  51. {1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0},
  52. {0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0},
  53. {1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0},
  54. {0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0},
  55. {0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0},
  56. {1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0},
  57. {0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0},
  58. {0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0},
  59. {0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0},
  60. {1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0},
  61. {0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0},
  62. {1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
  63. {0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0},
  64. {1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0},
  65. {1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0},
  66. {0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0},
  67. {0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0},
  68. {1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0},
  69. {0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0},
  70. {0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0},
  71. {0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0},
  72. {0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0},
  73. {1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0},
  74. {0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0},
  75. {1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0},
  76. {1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0},
  77. {1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0},
  78. {0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0},
  79. {0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0},
  80. {1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0},
  81. {1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0},
  82. {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0},
  83. {0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0},
  84. {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0},
  85. {0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0},
  86. {0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0},
  87. {0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0},
  88. {1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0},
  89. {1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0},
  90. {0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0},
  91. {1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0},
  92. {0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0},
  93. {1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0},
  94. {1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0},
  95. {0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0},
  96. {1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0},
  97. {0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0},
  98. {1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0},
  99. {1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0},
  100. {1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0},
  101. {1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0},
  102. {0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0},
  103. {1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0},
  104. {0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0},
  105. {0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0},
  106. {1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0},
  107. {1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0},
  108. {0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0},
  109. {0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0},
  110. {1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0},
  111. {0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0},
  112. {0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0},
  113. {0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0},
  114. {1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0},
  115. {0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0},
  116. {1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0},
  117. {0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0},
  118. {0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0},
  119. {1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0},
  120. {0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0},
  121. {0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0},
  122. {0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0},
  123. {0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0},
  124. {1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0},
  125. {1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0},
  126. {0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0},
  127. {0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0},
  128. {0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0},
  129. {0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0},
  130. {1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0},
  131. {0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0},
  132. {1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
  133. {0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0},
  134. {1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0},
  135. {1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0},
  136. {0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0},
  137. {1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0},
  138. {1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0},
  139. {0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0},
  140. {1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0},
  141. {0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0},
  142. {0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0},
  143. {0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0},
  144. {1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0},
  145. {0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0},
  146. {1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0},
  147. {1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0},
  148. {0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0},
  149. {0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0},
  150. {0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0},
  151. {1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0},
  152. {0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0},
  153. {1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0},
  154. {1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0},
  155. {1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0},
  156. {0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0},
  157. {0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0},
  158. {0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0},
  159. {0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0},
  160. {1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0},
  161. {0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0},
  162. {0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0},
  163. {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0},
  164. {1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0},
  165. {1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0},
  166. {0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
  167. {0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0},
  168. {0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0},
  169. {0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0},
  170. {0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0},
  171. {1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0},
  172. {1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0},
  173. {1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0},
  174. {1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0},
  175. {1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0},
  176. {0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0},
  177. {0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0},
  178. {0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0},
  179. {1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0},
  180. {1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0},
  181. {1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0},
  182. {0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0},
  183. {0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0},
  184. {1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0},
  185. {1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0},
  186. {1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0},
  187. {1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0},
  188. {1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0},
  189. {1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0},
  190. {1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0},
  191. {1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0},
  192. {0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0},
  193. {0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0},
  194. {1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0},
  195. {1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0},
  196. {1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0},
  197. {1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0},
  198. {0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0},
  199. {0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
  200. {0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0},
  201. {1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0},
  202. {1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0},
  203. {0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0},
  204. {1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0},
  205. {1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0},
  206. {0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0},
  207. {1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0},
  208. {0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0},
  209. {0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0}};
  // Training targets: 204 rows (two per source line), one per train_data
  // sample and in the same order, each holding n_output (5) binary targets
  // encoded as 0.0 / 1.0. mlp_backward() fits mlp_forward()'s outputs to
  // these through loss(), and main() scores them with accuracy().
  210. const std::vector<std::vector<double>> train_label = {
  211. {0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0},
  212. {0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0},
  213. {0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 0.0, 0.0},
  214. {0.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0},
  215. {0.0, 1.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 1.0},
  216. {1.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 1.0, 0.0, 0.0, 1.0},
  217. {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 0.0, 1.0},
  218. {0.0, 0.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0},
  219. {0.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0},
  220. {0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0},
  221. {0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0},
  222. {1.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 1.0},
  223. {0.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0},
  224. {0.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 1.0},
  225. {0.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 0.0},
  226. {0.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 1.0},
  227. {1.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 1.0},
  228. {1.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0},
  229. {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 0.0},
  230. {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0},
  231. {1.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 1.0},
  232. {0.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 1.0},
  233. {0.0, 1.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 0.0},
  234. {0.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 0.0, 1.0},
  235. {0.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0},
  236. {1.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0},
  237. {0.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 1.0},
  238. {1.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 1.0},
  239. {0.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0},
  240. {1.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0},
  241. {0.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 0.0},
  242. {1.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 1.0},
  243. {0.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 0.0},
  244. {0.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 0.0},
  245. {0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 0.0},
  246. {1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 0.0, 0.0},
  247. {1.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 0.0},
  248. {1.0, 0.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0},
  249. {0.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 0.0},
  250. {0.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 0.0, 0.0, 1.0, 1.0},
  251. {0.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 0.0, 0.0, 1.0, 1.0},
  252. {0.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 1.0},
  253. {0.0, 0.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 1.0},
  254. {0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 0.0, 0.0},
  255. {0.0, 1.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 0.0},
  256. {1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0},
  257. {0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 1.0},
  258. {1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 1.0},
  259. {0.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0},
  260. {1.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 1.0},
  261. {0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 1.0},
  262. {0.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 1.0},
  263. {0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 1.0},
  264. {0.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0},
  265. {1.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 0.0},
  266. {0.0, 1.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 1.0},
  267. {0.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 0.0},
  268. {0.0, 1.0, 1.0, 1.0, 1.0}, {0.0, 0.0, 1.0, 0.0, 1.0},
  269. {0.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 0.0},
  270. {0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 0.0},
  271. {0.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 0.0},
  272. {0.0, 1.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 1.0},
  273. {0.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 0.0},
  274. {0.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 1.0},
  275. {1.0, 0.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 1.0},
  276. {0.0, 1.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 1.0},
  277. {1.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 0.0},
  278. {1.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 1.0},
  279. {1.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 1.0},
  280. {1.0, 1.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 1.0, 1.0, 0.0},
  281. {0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0},
  282. {0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 0.0},
  283. {0.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 0.0},
  284. {0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0},
  285. {1.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 0.0},
  286. {1.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 1.0},
  287. {0.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 0.0},
  288. {1.0, 1.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 1.0},
  289. {0.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 0.0},
  290. {0.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 0.0},
  291. {0.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 0.0},
  292. {0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 0.0},
  293. {1.0, 0.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 0.0},
  294. {1.0, 0.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 0.0, 1.0},
  295. {1.0, 0.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 1.0},
  296. {0.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 0.0},
  297. {0.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 0.0},
  298. {1.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0},
  299. {0.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0},
  300. {1.0, 0.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0},
  301. {0.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 1.0},
  302. {1.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 1.0},
  303. {1.0, 1.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 0.0},
  304. {1.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 1.0, 1.0},
  305. {1.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0},
  306. {1.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 0.0},
  307. {1.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0, 0.0},
  308. {0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 0.0, 1.0},
  309. {1.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 0.0},
  310. {0.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 1.0},
  311. {0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 0.0},
  312. {0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}};
  // Validation inputs: 52 samples in the same 8-feature 0.0 / 1.0 encoding
  // as train_data. They are never passed to mlp_backward(); main() only
  // evaluates accuracy on them after each training epoch. Row k is paired
  // with val_label[k].
  313. const std::vector<std::vector<double>> val_data = {
  314. {1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0},
  315. {0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0},
  316. {1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0},
  317. {0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0},
  318. {0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0},
  319. {1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0},
  320. {0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0},
  321. {1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0},
  322. {1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0},
  323. {0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0},
  324. {0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0},
  325. {0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0},
  326. {0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0},
  327. {1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0},
  328. {0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0},
  329. {1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0},
  330. {1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0},
  331. {1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0},
  332. {1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0},
  333. {0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0},
  334. {0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0},
  335. {1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0},
  336. {1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0},
  337. {0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0},
  338. {0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0},
  339. {0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0},
  340. {0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0},
  341. {0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0},
  342. {0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0},
  343. {0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0},
  344. {1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0},
  345. {0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0},
  346. {0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0},
  347. {0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0},
  348. {0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0},
  349. {0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0},
  350. {1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0},
  351. {1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0},
  352. {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0},
  353. {1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0},
  354. {1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0},
  355. {0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0},
  356. {1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0},
  357. {1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0},
  358. {0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0},
  359. {1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0},
  360. {0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0},
  361. {0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0},
  362. {0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0},
  363. {1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0},
  364. {1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0},
  365. {1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0}};
  // Validation targets: 52 rows (two per source line) of n_output (5) binary
  // values encoded as 0.0 / 1.0, paired index-by-index with val_data. main()
  // iterates val_label.size() and indexes val_data, so the lengths must match.
  366. const std::vector<std::vector<double>> val_label = {
  367. {1.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 1.0},
  368. {1.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 0.0},
  369. {0.0, 0.0, 1.0, 0.0, 0.0}, {1.0, 1.0, 1.0, 0.0, 0.0},
  370. {0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0},
  371. {1.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0},
  372. {0.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 1.0},
  373. {0.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0},
  374. {0.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 0.0},
  375. {1.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0},
  376. {1.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 1.0},
  377. {0.0, 0.0, 1.0, 0.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 1.0},
  378. {0.0, 1.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 1.0},
  379. {0.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 0.0},
  380. {0.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 0.0},
  381. {1.0, 0.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 1.0},
  382. {0.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 0.0, 1.0, 0.0, 0.0},
  383. {1.0, 0.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0},
  384. {0.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 1.0},
  385. {1.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 0.0},
  386. {1.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 0.0},
  387. {1.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 1.0},
  388. {1.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 1.0},
  389. {0.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 0.0},
  390. {0.0, 1.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 1.0},
  391. {0.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 0.0},
  392. {1.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 0.0, 0.0}};
  // ---- Hyper-parameters --------------------------------------------------
  // Multiplier applied to every finite-difference gradient in mlp_backward().
  constexpr double learning_rate = 0.5;
  // Number of full passes over the training set performed by main().
  constexpr int training_epochs = 500;
  // Layer widths of the network: n_input -> n_hidden_1 -> n_hidden_2 -> n_output.
  constexpr int n_hidden_1 = 17;
  constexpr int n_hidden_2 = 9;
  constexpr int n_input = 8;
  constexpr int n_output = 5;

  // ---- Trainable parameters (global state) -------------------------------
  // One row per destination unit: wK[r] holds the incoming weights of unit r
  // of layer K, and bK[r] is its bias. Everything starts at 0.0 here and is
  // re-drawn at random by main() before training.
  std::vector<std::vector<double>> w1(n_hidden_1,
                                      std::vector<double>(n_input, 0.0));
  std::vector<std::vector<double>> w2(n_hidden_2,
                                      std::vector<double>(n_hidden_1, 0.0));
  std::vector<std::vector<double>> w3(n_output,
                                      std::vector<double>(n_hidden_2, 0.0));
  std::vector<double> b1(n_hidden_1, 0.0);
  std::vector<double> b2(n_hidden_2, 0.0);
  std::vector<double> b3(n_output, 0.0);
  // Dot product: the sum of lhs[i] * rhs[i], accumulated in index order.
  // If the vectors differ in length, only their common prefix contributes,
  // rather than reading past the end of the shorter vector.
  inline double inner_prod(const std::vector<double> &lhs,
                           const std::vector<double> &rhs) {
    double ret = 0.0;
    auto l = lhs.begin();
    auto r = rhs.begin();
    // Walking both ranges together bounds the loop by the shorter vector and
    // avoids the signed/unsigned comparison of an `int` index with size().
    for (; l != lhs.end() && r != rhs.end(); ++l, ++r) {
      ret += *l * *r;
    }
    return ret;
  }
  // Logistic sigmoid 1 / (1 + e^-x); maps any real x into [0, 1].
  inline double sigmoid(double x) {
    const double decay = std::exp(-x);
    return 1.0 / (decay + 1.0);
  }
  417. std::vector<double> mlp_forward(const std::vector<double> &input) {
  418. std::vector<double> z1(n_hidden_1);
  419. for (int i = 0; i < n_hidden_1; ++i) {
  420. z1[i] = sigmoid(inner_prod(input, w1[i]) + b1[i]);
  421. }
  422. std::vector<double> z2(n_hidden_2);
  423. for (int i = 0; i < n_hidden_2; ++i) {
  424. z2[i] = sigmoid(inner_prod(z1, w2[i]) + b2[i]);
  425. }
  426. std::vector<double> out(n_output);
  427. for (int i = 0; i < n_output; ++i) {
  428. out[i] = sigmoid(inner_prod(z2, w3[i]) + b3[i]);
  429. }
  430. return out;
  431. }
  // Squared-error loss: 0.5 * sum over the indices of h of (h[k] - y[k])^2.
  // y is expected to be at least as long as h.
  double loss(const std::vector<double> &h, const std::vector<double> &y) {
    double sum_sq = 0.0;
    for (std::vector<double>::size_type k = 0; k < h.size(); ++k) {
      const double diff = h[k] - y[k];
      sum_sq += diff * diff;
    }
    return 0.5 * sum_sq;
  }
  // Fraction of outputs whose hard decision matches the label.
  // Each prediction h[i] is thresholded (h[i] < 0.5 -> 0.0, otherwise 1.0)
  // and counted as correct when it equals y[i] to within 1e-9.
  // Returns count / h.size(); this is NaN for an empty h.
  // NOTE(review): y is assumed to be at least as long as h.
  double accuracy(const std::vector<double> &h, const std::vector<double> &y) {
    int count = 0;
    for (std::vector<double>::size_type i = 0; i < h.size(); ++i) {
      const double t = h[i] < 0.5 ? 0.0 : 1.0;
      // std::fabs rather than unqualified abs: the latter can resolve to the
      // C int abs(int) and silently truncate a fractional difference to 0.
      if (std::fabs(t - y[i]) < 1e-9) {
        ++count;
      }
    }
    return count / static_cast<double>(h.size());
  }
  450. void mlp_backward(const std::vector<double> &input,
  451. const std::vector<double> &label) {
  452. // weights
  453. std::vector<std::vector<double>> d_w1(n_hidden_1,
  454. std::vector<double>(n_input));
  455. std::vector<std::vector<double>> d_w2(n_hidden_2,
  456. std::vector<double>(n_hidden_1));
  457. std::vector<std::vector<double>> d_w3(n_output,
  458. std::vector<double>(n_hidden_2));
  459. // biases
  460. std::vector<double> d_b1(n_hidden_1);
  461. std::vector<double> d_b2(n_hidden_2);
  462. std::vector<double> d_b3(n_output);
  463. for (int i = 0; i < n_hidden_1; ++i) {
  464. double l0 = loss(mlp_forward(input), label);
  465. b1[i] += 1e-9;
  466. double l = loss(mlp_forward(input), label);
  467. b1[i] -= 1e-9;
  468. d_b1[i] = (l - l0) / 1e-9 * learning_rate;
  469. for (int j = 0; j < n_input; ++j) {
  470. w1[i][j] += 1e-9;
  471. double l = loss(mlp_forward(input), label);
  472. w1[i][j] -= 1e-9;
  473. d_w1[i][j] = (l - l0) / 1e-9 * learning_rate;
  474. }
  475. }
  476. for (int i = 0; i < n_hidden_2; ++i) {
  477. double l0 = loss(mlp_forward(input), label);
  478. b2[i] += 1e-9;
  479. double l = loss(mlp_forward(input), label);
  480. b2[i] -= 1e-9;
  481. d_b2[i] = (l - l0) / 1e-9 * learning_rate;
  482. for (int j = 0; j < n_hidden_1; ++j) {
  483. w2[i][j] += 1e-9;
  484. double l = loss(mlp_forward(input), label);
  485. w2[i][j] -= 1e-9;
  486. d_w2[i][j] = (l - l0) / 1e-9 * learning_rate;
  487. }
  488. }
  489. for (int i = 0; i < n_output; ++i) {
  490. double l0 = loss(mlp_forward(input), label);
  491. b3[i] += 1e-9;
  492. double l = loss(mlp_forward(input), label);
  493. b3[i] -= 1e-9;
  494. d_b3[i] = (l - l0) / 1e-9 * learning_rate;
  495. for (int j = 0; j < n_hidden_2; ++j) {
  496. w3[i][j] += 1e-9;
  497. double l = loss(mlp_forward(input), label);
  498. w3[i][j] -= 1e-9;
  499. d_w3[i][j] = (l - l0) / 1e-9 * learning_rate;
  500. }
  501. }
  502. for (int i = 0; i < n_hidden_1; ++i) {
  503. b1[i] -= d_b1[i];
  504. for (int j = 0; j < n_input; ++j) {
  505. w1[i][j] -= d_w1[i][j];
  506. }
  507. }
  508. for (int i = 0; i < n_hidden_2; ++i) {
  509. b2[i] -= d_b2[i];
  510. for (int j = 0; j < n_hidden_1; ++j) {
  511. w2[i][j] -= d_w2[i][j];
  512. }
  513. }
  514. for (int i = 0; i < n_output; ++i) {
  515. b3[i] -= d_b3[i];
  516. for (int j = 0; j < n_hidden_2; ++j) {
  517. w3[i][j] -= d_w3[i][j];
  518. }
  519. }
  520. }
  521. int main() {
  522. // init
  523. std::random_device rd;
  524. std::mt19937 gen(rd());
  525. std::uniform_real_distribution<> dis(-1.0, 1.0);
  526. for (int i = 0; i < n_hidden_1; ++i) {
  527. b1[i] = dis(gen);
  528. for (int j = 0; j < n_input; ++j) {
  529. w1[i][j] = dis(gen);
  530. }
  531. }
  532. for (int i = 0; i < n_hidden_2; ++i) {
  533. b2[i] = dis(gen);
  534. for (int j = 0; j < n_hidden_1; ++j) {
  535. w2[i][j] = dis(gen);
  536. }
  537. }
  538. for (int i = 0; i < n_output; ++i) {
  539. b3[i] = dis(gen);
  540. for (int j = 0; j < n_hidden_2; ++j) {
  541. w3[i][j] = dis(gen);
  542. }
  543. }
  544. // train
  545. int n = train_label.size();
  546. int m = val_label.size();
  547. for (int i = 0; i < training_epochs; ++i) {
  548. for (int j = 0; j < n; ++j) {
  549. mlp_backward(train_data[j], train_label[j]);
  550. }
  551. double acc = 0.0f;
  552. for (int j = 0; j < n; ++j) {
  553. acc += accuracy(mlp_forward(train_data[j]), train_label[j]);
  554. }
  555. std::cout << "Train Accuracy:" << acc / n << " ";
  556. acc = 0.0f;
  557. for (int j = 0; j < m; ++j) {
  558. acc += accuracy(mlp_forward(val_data[j]), val_label[j]);
  559. }
  560. std::cout << "Val Accuracy:" << acc / m << std::endl;
  561. }
  562. }