// infer.cc — multi-threaded MindSpore Lite ImageNet classification benchmark.
#include "include/version.h"
#include "include/model.h"
#include "include/version.h"
#include "include/context.h"
#include "include/errorcode.h"
#include "include/lite_session.h"
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
#include "imagenet_label.inc"
  15. // model size limit: 256 MiB
  16. const int MAX_MODEL_SIZE = 256 * 1024 * 1024;
  17. // dataset
  18. std::string IMAGE_FILE = "/data/val_data_c/%05d.bin";
  19. using namespace mindspore;
  20. using namespace mindspore::lite;
  21. using namespace mindspore::session;
  22. void read_image(int idx, void *tensor_buf, size_t size)
  23. {
  24. char image[128];
  25. sprintf(image, IMAGE_FILE.c_str(), idx);
  26. FILE *fp = fopen(image, "rb");
  27. fread(tensor_buf, sizeof(char), size, fp);
  28. fclose(fp);
  29. }
  30. void print_tensor(tensor::MSTensor *t)
  31. {
  32. float *data_ptr = static_cast<float *>(t->MutableData());
  33. for (int i = 0; i < t->ElementsNum(); ++i)
  34. {
  35. std::cout << data_ptr[i] << ", ";
  36. if (i % 13 == 12)
  37. {
  38. std::cout << std::endl;
  39. }
  40. }
  41. }
  42. int arg_max(tensor::MSTensor *t)
  43. {
  44. float *data_ptr = static_cast<float *>(t->MutableData());
  45. float max_val = 0.f;
  46. int max_idx = -1;
  47. for (int i = 0; i < t->ElementsNum(); ++i)
  48. {
  49. if (data_ptr[i] > max_val)
  50. {
  51. max_idx = i;
  52. max_val = data_ptr[i];
  53. }
  54. }
  55. return max_idx;
  56. }
  57. std::tuple<void *, void *, tensor::MSTensor *, tensor::MSTensor *> sessionInOut(mindspore::session::LiteSession *session)
  58. {
  59. // alloc input mem
  60. std::vector<tensor::MSTensor *> inputs = session->GetInputs();
  61. tensor::MSTensor *input = inputs.front();
  62. void *input_buf = input->MutableData();
  63. //std::cout << "input tenosr num: " << inputs.size() << std::endl;
  64. //std::cout << "input tensor[0] shape: ";
  65. //for (int i : input->shape())
  66. //{
  67. // std::cout << i << " ";
  68. //}
  69. //std::cout << std::endl;
  70. // get output
  71. std::unordered_map<std::string, tensor::MSTensor *> outputs = session->GetOutputs();
  72. tensor::MSTensor *output = outputs.begin()->second;
  73. //std::cout << "output tenosr num: " << outputs.size() << std::endl;
  74. //std::cout << "output tensor[0] name: " << outputs.begin()->first << ", shape: ";
  75. void *output_buf = output->MutableData();
  76. ///for (int i : output->shape())
  77. ///{
  78. /// std::cout << i << " ";
  79. ///}
  80. ///std::cout << std::endl;
  81. return {input_buf, output_buf, input, output};
  82. }
  83. mindspore::session::LiteSession *GenerateSession(mindspore::lite::Model *model)
  84. {
  85. if (model == nullptr)
  86. {
  87. std::cerr << "Read model file failed while running" << std::endl;
  88. return nullptr;
  89. }
  90. Context *context = new (std::nothrow) mindspore::lite::Context;
  91. if (context == nullptr)
  92. {
  93. std::cerr << "New context failed while running" << std::endl;
  94. return nullptr;
  95. }
  96. LiteSession *session = mindspore::session::LiteSession::CreateSession(context);
  97. delete (context);
  98. if (session == nullptr)
  99. {
  100. std::cerr << "CreateSession failed while running" << std::endl;
  101. return nullptr;
  102. }
  103. int ret = session->CompileGraph(model);
  104. if (ret != mindspore::lite::RET_OK)
  105. {
  106. std::cout << "CompileGraph failed while running" << std::endl;
  107. delete (session);
  108. return nullptr;
  109. }
  110. return session;
  111. }
  112. int main(int argc, const char *argv[])
  113. {
  114. if (argc != 5)
  115. {
  116. std::cout << "usage: ./classification your_model.ms image_num thread_num imagenetval_path" << std::endl;
  117. return -1;
  118. }
  119. IMAGE_FILE = std::string(argv[4]);
  120. std::string version = mindspore::lite::Version();
  121. std::cout << "version: " << version << std::endl;
  122. // load model
  123. FILE *fp = fopen(argv[1], "rb");
  124. char *model_buf = new char[MAX_MODEL_SIZE];
  125. size_t model_size = fread(model_buf, sizeof(char), MAX_MODEL_SIZE, fp);
  126. fclose(fp);
  127. std::cout << "model: " << argv[1] << ", size: " << model_size << " Bytes" << std::endl;
  128. Model *model = Model::Import(model_buf, model_size);
  129. int THREAD_NUM = std::atoi(argv[3]);
  130. std::vector<mindspore::session::LiteSession *> sessions(THREAD_NUM);
  131. std::vector<void *> inputs(THREAD_NUM);
  132. std::vector<void *> outputs(THREAD_NUM);
  133. std::vector<tensor::MSTensor *> inputTensors(THREAD_NUM);
  134. std::vector<tensor::MSTensor *> outputTensors(THREAD_NUM);
  135. for (int i = 0; i < THREAD_NUM; ++i)
  136. {
  137. sessions[i] = GenerateSession(model);
  138. std::tuple<void *, void *, tensor::MSTensor *, tensor::MSTensor *> inOut = sessionInOut(sessions[i]);
  139. inputs[i] = std::get<0>(inOut);
  140. outputs[i] = std::get<1>(inOut);
  141. inputTensors[i] = std::get<2>(inOut);
  142. outputTensors[i] = std::get<3>(inOut);
  143. }
  144. model->Free();
  145. // infer
  146. int IMAGE_NUM = std::atoi(argv[2]);
  147. std::vector<int> result(IMAGE_NUM);
  148. std::vector<std::thread> threads(THREAD_NUM);
  149. for (int id = 0; id < THREAD_NUM; ++id)
  150. {
  151. threads[id] = std::thread([&](int id) {
  152. for (size_t i = id; i < IMAGE_NUM; i += THREAD_NUM)
  153. {
  154. read_image(i, inputs[id], inputTensors[id]->Size());
  155. int ret = sessions[id]->RunGraph();
  156. if (ret != RET_OK)
  157. {
  158. std::cerr << "Run graph failed." << std::endl;
  159. return RET_ERROR;
  160. }
  161. //print_tensor(output);
  162. //std::cout << arg_max(output) << std::endl;
  163. result[i] = arg_max(outputTensors[id]);
  164. }
  165. },
  166. id);
  167. }
  168. for (int i = 0; i < THREAD_NUM; ++i)
  169. {
  170. threads[i].join();
  171. }
  172. std::cout << "inference finished" << std::endl;
  173. for (int i = 0; i < THREAD_NUM; ++i)
  174. {
  175. delete sessions[i];
  176. }
  177. int correct = 0;
  178. for (int i = 0; i < IMAGE_NUM; ++i)
  179. {
  180. if (label[i] == result[i] - 1)
  181. {
  182. correct++;
  183. }
  184. }
  185. std::cout << "top1 acc: " << correct / (float)IMAGE_NUM << std::endl;
  186. return 0;
  187. }