123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572 |
- #include <cmath>
- #include <iostream>
- #include <random>
- #include <vector>
// Training inputs: 204 samples, each a vector of 8 values (matching
// n_input == 8); every value in this table is 0.0 or 1.0.
// main() pairs train_data[j] with train_label[j]: it passes them to
// mlp_backward for the SGD updates and to mlp_forward/accuracy to report
// training accuracy after each epoch.
// NOTE(review): the rule mapping each input pattern to its label is not
// documented in this file — confirm with whoever produced the data.
const std::vector<std::vector<double>> train_data = {
{0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0},
{1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0},
{1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0},
{1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0},
{1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0},
{1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0},
{0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0},
{1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0},
{0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0},
{0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0},
{1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0},
{1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0},
{1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0},
{1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0},
{0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0},
{1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0},
{0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0},
{1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0},
{0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0},
{1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0},
{1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0},
{1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0},
{1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0},
{0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0},
{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0},
{1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0},
{1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0},
{1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0},
{1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0},
{0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0},
{0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0},
{1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0},
{0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0},
{0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0},
{1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0},
{1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0},
{1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0},
{0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0},
{1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0},
{0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0},
{1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0},
{0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0},
{1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0},
{0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0},
{1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0},
{1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0},
{0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0},
{1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0},
{0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0},
{0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0},
{1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0},
{0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0},
{0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0},
{0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0},
{1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0},
{0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0},
{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
{0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0},
{1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0},
{1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0},
{0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0},
{0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0},
{1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0},
{0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0},
{0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0},
{0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0},
{0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0},
{1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0},
{0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0},
{1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0},
{1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0},
{1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0},
{0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0},
{0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0},
{1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0},
{1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0},
{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0},
{0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0},
{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0},
{0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0},
{0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0},
{0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0},
{1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0},
{1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0},
{0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0},
{1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0},
{0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0},
{1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0},
{1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0},
{0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0},
{1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0},
{0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0},
{1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0},
{1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0},
{1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0},
{1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0},
{0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0},
{1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0},
{0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0},
{0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0},
{1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0},
{1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0},
{0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0},
{0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0},
{1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0},
{0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0},
{0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0},
{0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0},
{1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0},
{0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0},
{1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0},
{0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0},
{0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0},
{1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0},
{0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0},
{0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0},
{0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0},
{0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0},
{1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0},
{1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0},
{0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0},
{0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0},
{0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0},
{0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0},
{1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0},
{0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0},
{1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
{0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0},
{1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0},
{1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0},
{0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0},
{1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0},
{1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0},
{0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0},
{1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0},
{0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0},
{0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0},
{0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0},
{1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0},
{0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0},
{1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0},
{1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0},
{0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0},
{0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0},
{0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0},
{1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0},
{0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0},
{1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0},
{1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0},
{1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0},
{0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0},
{0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0},
{0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0},
{0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0},
{1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0},
{0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0},
{0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0},
{1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0},
{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0},
{1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0},
{0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
{0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0},
{0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0},
{0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0},
{0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0},
{1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0},
{1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0},
{1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0},
{1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0},
{1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0},
{0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0},
{0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0},
{0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0},
{1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0},
{1.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0},
{1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0},
{0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0},
{0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0},
{1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0},
{1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0},
{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0},
{1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0},
{1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0},
{1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0},
{1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0},
{1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0},
{0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0},
{0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0},
{1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0},
{1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0},
{1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0},
{1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0},
{0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0},
{0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0},
{0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0},
{1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0},
{1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0},
{0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0},
{1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0},
{1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0},
{0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0},
{1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0},
{0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0},
{0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0}};
// Training targets: 204 vectors of 5 values (matching n_output == 5); every
// value in this table is 0.0 or 1.0. train_label[j] is the target for
// train_data[j]. Entries are written two per line, so line k of the table
// holds the targets for samples 2k and 2k+1.
const std::vector<std::vector<double>> train_label = {
{0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0},
{0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0},
{0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 0.0, 0.0},
{0.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0},
{0.0, 1.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 1.0},
{1.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 1.0, 0.0, 0.0, 1.0},
{1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 0.0, 1.0},
{0.0, 0.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0},
{0.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0},
{0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0},
{0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0},
{1.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 1.0},
{0.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0},
{0.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 1.0},
{0.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 0.0},
{0.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 1.0, 1.0},
{1.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 1.0},
{1.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0},
{1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 0.0},
{1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0},
{1.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 1.0},
{0.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 1.0},
{0.0, 1.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 0.0},
{0.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 0.0, 1.0},
{0.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0},
{1.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0},
{0.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 1.0},
{1.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 1.0},
{0.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0},
{1.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0},
{0.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 0.0},
{1.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 1.0},
{0.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 0.0},
{0.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 0.0},
{0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 0.0},
{1.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 0.0, 0.0},
{1.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 0.0},
{1.0, 0.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0},
{0.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 0.0},
{0.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 0.0, 0.0, 1.0, 1.0},
{0.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 0.0, 0.0, 1.0, 1.0},
{0.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 1.0},
{0.0, 0.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 1.0},
{0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 0.0, 0.0},
{0.0, 1.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 0.0},
{1.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 1.0},
{0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 1.0},
{1.0, 0.0, 0.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 1.0},
{0.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0},
{1.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 1.0},
{0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 1.0},
{0.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 1.0},
{0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 1.0},
{0.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0},
{1.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 0.0},
{0.0, 1.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 1.0},
{0.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 0.0},
{0.0, 1.0, 1.0, 1.0, 1.0}, {0.0, 0.0, 1.0, 0.0, 1.0},
{0.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 0.0},
{0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 1.0, 0.0, 1.0, 0.0},
{0.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 0.0},
{0.0, 1.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 1.0},
{0.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 0.0},
{0.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 1.0, 1.0},
{1.0, 0.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 1.0},
{0.0, 1.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 1.0},
{1.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 0.0},
{1.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 1.0},
{1.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 1.0},
{1.0, 1.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 1.0, 1.0, 0.0},
{0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 1.0},
{0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 0.0},
{0.0, 0.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 0.0},
{0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0},
{1.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 0.0},
{1.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 1.0},
{0.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 0.0},
{1.0, 1.0, 0.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 1.0},
{0.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 1.0, 1.0, 1.0, 0.0},
{0.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 0.0},
{0.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 0.0},
{0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 0.0},
{1.0, 0.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 0.0},
{1.0, 0.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 0.0, 1.0},
{1.0, 0.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 1.0},
{0.0, 0.0, 0.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 0.0},
{0.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 0.0},
{1.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0},
{0.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0},
{1.0, 0.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0},
{0.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 1.0},
{1.0, 1.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 0.0, 1.0},
{1.0, 1.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 0.0},
{1.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 1.0, 1.0},
{1.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0},
{1.0, 1.0, 0.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 0.0},
{1.0, 0.0, 0.0, 0.0, 0.0}, {0.0, 0.0, 0.0, 0.0, 0.0},
{0.0, 1.0, 1.0, 1.0, 1.0}, {1.0, 1.0, 1.0, 0.0, 1.0},
{1.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 0.0},
{0.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 1.0},
{0.0, 1.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 1.0, 1.0, 0.0},
{0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0}};
// Validation inputs: 52 samples of 8 values each (n_input == 8), every value
// 0.0 or 1.0. main() only evaluates these (mlp_forward + accuracy after each
// epoch); they are never passed to mlp_backward. val_data[j] pairs with
// val_label[j].
const std::vector<std::vector<double>> val_data = {
{1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0},
{0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0},
{1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0},
{0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0},
{0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0},
{1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0},
{0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0},
{1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0},
{1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0},
{0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0},
{0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0},
{0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0},
{0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0},
{1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0},
{0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0},
{1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0},
{1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0},
{1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0},
{1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0},
{0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0},
{0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0},
{1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0},
{1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0},
{0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0},
{0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0},
{0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0},
{0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0},
{0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0},
{0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0},
{0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0},
{1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0},
{0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0},
{0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0},
{0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0},
{0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0},
{0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0},
{1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0},
{1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0},
{1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0},
{1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0},
{1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0},
{0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0},
{1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0},
{1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0},
{0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0},
{1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0},
{0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0},
{0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 1.0},
{0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0},
{1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0},
{1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0},
{1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0}};
// Validation targets: 52 vectors of 5 values (n_output == 5), every value 0.0
// or 1.0. val_label[j] is the target for val_data[j]; entries are written two
// per line.
const std::vector<std::vector<double>> val_label = {
{1.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 1.0},
{1.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 0.0},
{0.0, 0.0, 1.0, 0.0, 0.0}, {1.0, 1.0, 1.0, 0.0, 0.0},
{0.0, 1.0, 1.0, 0.0, 1.0}, {1.0, 1.0, 0.0, 1.0, 0.0},
{1.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0},
{0.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 0.0, 0.0, 1.0},
{0.0, 1.0, 0.0, 1.0, 1.0}, {1.0, 0.0, 0.0, 1.0, 0.0},
{0.0, 0.0, 1.0, 1.0, 1.0}, {1.0, 0.0, 1.0, 0.0, 0.0},
{1.0, 0.0, 0.0, 0.0, 1.0}, {0.0, 1.0, 1.0, 1.0, 1.0},
{1.0, 0.0, 1.0, 1.0, 1.0}, {0.0, 1.0, 0.0, 1.0, 1.0},
{0.0, 0.0, 1.0, 0.0, 1.0}, {1.0, 1.0, 0.0, 0.0, 1.0},
{0.0, 1.0, 1.0, 0.0, 1.0}, {0.0, 1.0, 0.0, 0.0, 1.0},
{0.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 0.0},
{0.0, 0.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 0.0},
{1.0, 0.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 1.0},
{0.0, 1.0, 1.0, 1.0, 0.0}, {0.0, 0.0, 1.0, 0.0, 0.0},
{1.0, 0.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 0.0},
{0.0, 0.0, 1.0, 0.0, 1.0}, {0.0, 0.0, 1.0, 1.0, 1.0},
{1.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 0.0},
{1.0, 1.0, 1.0, 0.0, 0.0}, {0.0, 1.0, 1.0, 1.0, 0.0},
{1.0, 1.0, 0.0, 0.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 1.0},
{1.0, 0.0, 1.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 0.0, 1.0},
{0.0, 0.0, 1.0, 1.0, 0.0}, {0.0, 1.0, 1.0, 0.0, 0.0},
{0.0, 1.0, 0.0, 0.0, 1.0}, {1.0, 0.0, 0.0, 0.0, 1.0},
{0.0, 1.0, 0.0, 1.0, 0.0}, {1.0, 0.0, 0.0, 1.0, 0.0},
{1.0, 1.0, 1.0, 0.0, 0.0}, {1.0, 0.0, 1.0, 0.0, 0.0}};
// Network topology: n_input -> n_hidden_1 -> n_hidden_2 -> n_output, with a
// sigmoid applied after every layer (see mlp_forward).
constexpr int n_input = 8;
constexpr int n_hidden_1 = 17;
constexpr int n_hidden_2 = 9;
constexpr int n_output = 5;

// Training hyper-parameters.
constexpr double learning_rate = 0.5;  // SGD step size used by mlp_backward
constexpr int training_epochs = 500;   // passes over train_data made by main

// Layer parameters, grouped per layer. Row r of wK holds the incoming weights
// of unit r in layer K, so wK has (units in layer K) rows and (units in layer
// K-1) columns; bK holds one bias per unit. Every entry starts at 0.0 and is
// randomised by main before training.
std::vector<std::vector<double>> w1(n_hidden_1, std::vector<double>(n_input));
std::vector<double> b1(n_hidden_1);
std::vector<std::vector<double>> w2(n_hidden_2,
                                    std::vector<double>(n_hidden_1));
std::vector<double> b2(n_hidden_2);
std::vector<std::vector<double>> w3(n_output, std::vector<double>(n_hidden_2));
std::vector<double> b3(n_output);
// Dot product sum_k lhs[k] * rhs[k], accumulated left to right.
// The length is taken from lhs, so rhs must have at least lhs.size() entries
// (not checked). Returns 0.0 for an empty lhs.
inline double inner_prod(const std::vector<double> &lhs,
                         const std::vector<double> &rhs) {
  double sum = 0.0;
  std::vector<double>::size_type k = 0;
  while (k < lhs.size()) {
    sum += lhs[k] * rhs[k];
    ++k;
  }
  return sum;
}
// Logistic function 1 / (1 + e^(-x)). Mathematically the result lies in
// (0, 1); in double precision it saturates to exactly 0.0 or 1.0 for large |x|.
inline double sigmoid(double x) {
  const double denominator = 1.0 + std::exp(-x);
  return 1.0 / denominator;
}
- std::vector<double> mlp_forward(const std::vector<double> &input) {
- std::vector<double> z1(n_hidden_1);
- for (int i = 0; i < n_hidden_1; ++i) {
- z1[i] = sigmoid(inner_prod(input, w1[i]) + b1[i]);
- }
- std::vector<double> z2(n_hidden_2);
- for (int i = 0; i < n_hidden_2; ++i) {
- z2[i] = sigmoid(inner_prod(z1, w2[i]) + b2[i]);
- }
- std::vector<double> out(n_output);
- for (int i = 0; i < n_output; ++i) {
- out[i] = sigmoid(inner_prod(z2, w3[i]) + b3[i]);
- }
- return out;
- }
// Half the sum of squared differences: 0.5 * sum_k (h[k] - y[k])^2.
// Iterates over h.size() entries, so y must be at least as long as h (not
// checked). Returns 0.0 for an empty h.
double loss(const std::vector<double> &h, const std::vector<double> &y) {
  double sum_sq = 0.0;
  for (std::vector<double>::size_type k = 0; k < h.size(); ++k) {
    const double diff = h[k] - y[k];
    sum_sq += diff * diff;
  }
  return 0.5 * sum_sq;
}
// Fraction of output positions whose thresholded prediction equals the target.
//
// Each prediction h[i] is binarised at 0.5 (h[i] >= 0.5 -> 1.0, otherwise
// 0.0) and counted as correct when it is within 1e-9 of y[i].
//
// @param h  predicted activations (main passes the result of mlp_forward).
// @param y  target vector; must have at least h.size() entries (not checked).
// @return   matches / h.size(), in [0, 1]; NaN when h is empty (0 / 0.0).
double accuracy(const std::vector<double> &h, const std::vector<double> &y) {
  int count = 0;
  for (std::vector<double>::size_type i = 0; i < h.size(); ++i) {
    const double t = h[i] < 0.5 ? 0.0 : 1.0;
    // std::fabs, not unqualified `abs`: depending on the headers pulled in,
    // `abs` can resolve to C's `int abs(int)`, which truncates the difference
    // (e.g. 0.4 -> 0) and would count non-matching, non-binary targets as
    // correct.
    if (std::fabs(t - y[i]) < 1e-9) {
      ++count;
    }
  }
  return count / static_cast<double>(h.size());
}
- void mlp_backward(const std::vector<double> &input,
- const std::vector<double> &label) {
-
- std::vector<std::vector<double>> d_w1(n_hidden_1,
- std::vector<double>(n_input));
- std::vector<std::vector<double>> d_w2(n_hidden_2,
- std::vector<double>(n_hidden_1));
- std::vector<std::vector<double>> d_w3(n_output,
- std::vector<double>(n_hidden_2));
-
- std::vector<double> d_b1(n_hidden_1);
- std::vector<double> d_b2(n_hidden_2);
- std::vector<double> d_b3(n_output);
- for (int i = 0; i < n_hidden_1; ++i) {
- double l0 = loss(mlp_forward(input), label);
- b1[i] += 1e-9;
- double l = loss(mlp_forward(input), label);
- b1[i] -= 1e-9;
- d_b1[i] = (l - l0) / 1e-9 * learning_rate;
- for (int j = 0; j < n_input; ++j) {
- w1[i][j] += 1e-9;
- double l = loss(mlp_forward(input), label);
- w1[i][j] -= 1e-9;
- d_w1[i][j] = (l - l0) / 1e-9 * learning_rate;
- }
- }
- for (int i = 0; i < n_hidden_2; ++i) {
- double l0 = loss(mlp_forward(input), label);
- b2[i] += 1e-9;
- double l = loss(mlp_forward(input), label);
- b2[i] -= 1e-9;
- d_b2[i] = (l - l0) / 1e-9 * learning_rate;
- for (int j = 0; j < n_hidden_1; ++j) {
- w2[i][j] += 1e-9;
- double l = loss(mlp_forward(input), label);
- w2[i][j] -= 1e-9;
- d_w2[i][j] = (l - l0) / 1e-9 * learning_rate;
- }
- }
- for (int i = 0; i < n_output; ++i) {
- double l0 = loss(mlp_forward(input), label);
- b3[i] += 1e-9;
- double l = loss(mlp_forward(input), label);
- b3[i] -= 1e-9;
- d_b3[i] = (l - l0) / 1e-9 * learning_rate;
- for (int j = 0; j < n_hidden_2; ++j) {
- w3[i][j] += 1e-9;
- double l = loss(mlp_forward(input), label);
- w3[i][j] -= 1e-9;
- d_w3[i][j] = (l - l0) / 1e-9 * learning_rate;
- }
- }
- for (int i = 0; i < n_hidden_1; ++i) {
- b1[i] -= d_b1[i];
- for (int j = 0; j < n_input; ++j) {
- w1[i][j] -= d_w1[i][j];
- }
- }
- for (int i = 0; i < n_hidden_2; ++i) {
- b2[i] -= d_b2[i];
- for (int j = 0; j < n_hidden_1; ++j) {
- w2[i][j] -= d_w2[i][j];
- }
- }
- for (int i = 0; i < n_output; ++i) {
- b3[i] -= d_b3[i];
- for (int j = 0; j < n_hidden_2; ++j) {
- w3[i][j] -= d_w3[i][j];
- }
- }
- }
- int main() {
-
- std::random_device rd;
- std::mt19937 gen(rd());
- std::uniform_real_distribution<> dis(-1.0, 1.0);
- for (int i = 0; i < n_hidden_1; ++i) {
- b1[i] = dis(gen);
- for (int j = 0; j < n_input; ++j) {
- w1[i][j] = dis(gen);
- }
- }
- for (int i = 0; i < n_hidden_2; ++i) {
- b2[i] = dis(gen);
- for (int j = 0; j < n_hidden_1; ++j) {
- w2[i][j] = dis(gen);
- }
- }
- for (int i = 0; i < n_output; ++i) {
- b3[i] = dis(gen);
- for (int j = 0; j < n_hidden_2; ++j) {
- w3[i][j] = dis(gen);
- }
- }
-
- int n = train_label.size();
- int m = val_label.size();
- for (int i = 0; i < training_epochs; ++i) {
- for (int j = 0; j < n; ++j) {
- mlp_backward(train_data[j], train_label[j]);
- }
- double acc = 0.0f;
- for (int j = 0; j < n; ++j) {
- acc += accuracy(mlp_forward(train_data[j]), train_label[j]);
- }
- std::cout << "Train Accuracy:" << acc / n << " ";
- acc = 0.0f;
- for (int j = 0; j < m; ++j) {
- acc += accuracy(mlp_forward(val_data[j]), val_label[j]);
- }
- std::cout << "Val Accuracy:" << acc / m << std::endl;
- }
- }
|