// infer.cc — MindSpore Lite image-classification inference example:
// loads a .ms model, runs a directory of preprocessed images through it,
// and reports top-1 accuracy against imagenet_label.inc.
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>
#include "include/context.h"
#include "include/errorcode.h"
#include "include/lite_session.h"
#include "include/model.h"
#include "include/version.h"
#include "imagenet_label.inc"
  11. // model size limit: 64 MiB
  12. const int MAX_MODEL_SIZE = 64 * 1024 * 1024;
  13. // dataset
  14. const std::string IMAGE_FILE = "/data/val_data_c/%05d.bin";
  15. using namespace mindspore;
  16. using namespace mindspore::lite;
  17. using namespace mindspore::session;
  18. void read_image(int idx, void *tensor_buf, size_t size)
  19. {
  20. char image[128];
  21. sprintf(image, IMAGE_FILE.c_str(), idx);
  22. FILE *fp = fopen(image, "rb");
  23. fread(tensor_buf, sizeof(char), size, fp);
  24. fclose(fp);
  25. }
  26. void print_tensor(tensor::MSTensor *t)
  27. {
  28. float *data_ptr = static_cast<float *>(t->MutableData());
  29. for (int i = 0; i < t->ElementsNum(); ++i)
  30. {
  31. std::cout << data_ptr[i] << ", ";
  32. if (i % 13 == 12)
  33. {
  34. std::cout << std::endl;
  35. }
  36. }
  37. }
  38. int arg_max(tensor::MSTensor *t)
  39. {
  40. float *data_ptr = static_cast<float *>(t->MutableData());
  41. float max_val = 0.f;
  42. int max_idx = -1;
  43. for (int i = 0; i < t->ElementsNum(); ++i)
  44. {
  45. if (data_ptr[i] > max_val)
  46. {
  47. max_idx = i;
  48. max_val = data_ptr[i];
  49. }
  50. }
  51. return max_idx;
  52. }
  53. int main(int argc, const char *argv[])
  54. {
  55. if (argc != 3)
  56. {
  57. std::cout << "usage: ./classification your_model.ms image_num" << std::endl;
  58. return -1;
  59. }
  60. std::string version = mindspore::lite::Version();
  61. std::cout << "version: " << version << std::endl;
  62. // load model
  63. FILE *fp = fopen(argv[1], "rb");
  64. char *model_buf = new char[MAX_MODEL_SIZE];
  65. size_t model_size = fread(model_buf, sizeof(char), MAX_MODEL_SIZE, fp);
  66. fclose(fp);
  67. std::cout << "model: " << argv[1] << ", size: " << model_size << " Bytes" << std::endl;
  68. Model *model = Model::Import(model_buf, model_size);
  69. // create context
  70. Context *context = new (std::nothrow) Context;
  71. if (context == nullptr)
  72. {
  73. std::cerr << "New context failed while running %s", argv[1];
  74. return RET_ERROR;
  75. }
  76. CpuDeviceInfo &cpu_decice_info = context->device_list_[0].device_info_.cpu_device_info_;
  77. cpu_decice_info.cpu_bind_mode_ = HIGHER_CPU;
  78. context->thread_num_ = 2;
  79. // create session1
  80. LiteSession *session = LiteSession::CreateSession(context);
  81. delete (context);
  82. if (session == nullptr)
  83. {
  84. std::cerr << "CreateSession failed while running %s", argv[1];
  85. return RET_ERROR;
  86. }
  87. // compile graph
  88. int ret = session->CompileGraph(model);
  89. if (ret != RET_OK)
  90. {
  91. std::cerr << "CompileGraph failed" << std::endl;
  92. // session and model need to be released by users manually.
  93. delete (session);
  94. delete (model);
  95. return ret;
  96. }
  97. model->Free();
  98. // alloc input mem
  99. std::vector<tensor::MSTensor *> inputs = session->GetInputs();
  100. tensor::MSTensor *input = inputs.front();
  101. void *input_buf = input->MutableData();
  102. std::cout << "input tenosr num: " << inputs.size() << std::endl;
  103. std::cout << "input tensor[0] shape: ";
  104. for (int i : input->shape())
  105. {
  106. std::cout << i << " ";
  107. }
  108. std::cout << std::endl;
  109. // get output
  110. std::unordered_map<std::string, tensor::MSTensor *> outputs = session->GetOutputs();
  111. tensor::MSTensor *output = outputs.begin()->second;
  112. std::cout << "output tenosr num: " << outputs.size() << std::endl;
  113. std::cout << "output tensor[0] name: " << outputs.begin()->first << ", shape: ";
  114. void *output_buf = output->MutableData();
  115. for (int i : output->shape())
  116. {
  117. std::cout << i << " ";
  118. }
  119. std::cout << std::endl;
  120. // infer
  121. std::vector<int> result;
  122. int IMAGE_NUM = std::atoi(argv[2]);
  123. std::cout << "inference start" << std::endl;
  124. for (size_t i = 0; i < IMAGE_NUM; i++)
  125. {
  126. read_image(i, input_buf, input->Size());
  127. int ret = session->RunGraph();
  128. if (ret != RET_OK)
  129. {
  130. std::cerr << "Run graph failed." << std::endl;
  131. return RET_ERROR;
  132. }
  133. //print_tensor(output);
  134. //std::cout << arg_max(output) << std::endl;
  135. result.push_back(arg_max(output));
  136. std::cout << "\r" << i * 100 / IMAGE_NUM << "%, ";
  137. for (int j = 0; j < i * 80 / IMAGE_NUM; ++j)
  138. {
  139. std::cout << '*';
  140. }
  141. }
  142. std::cout << std::endl;
  143. std::cout << "inference finished" << std::endl;
  144. int correct = 0;
  145. for (int i = 0; i < IMAGE_NUM; ++i)
  146. {
  147. if (label[i] == result[i] - 1)
  148. {
  149. correct++;
  150. }
  151. }
  152. std::cout << "top1 acc: " << correct / (float)IMAGE_NUM << std::endl;
  153. return 0;
  154. }