CHEN Yihui 3 years ago
parent
commit
4666941c89
1 changed file with 110 additions and 74 deletions

+ 110 - 74
infer.cc

@@ -5,17 +5,21 @@
 #include "include/errorcode.h"
 #include "include/lite_session.h"
 
+#include <thread>
 #include <iostream>
 #include <string>
+#include <vector>
 #include <cstdio>
+#include <utility>
+#include <tuple>
+#include <unordered_map>
 
 #include "imagenet_label.inc"
 
-// model size limit: 64 MiB
-const int MAX_MODEL_SIZE = 64 * 1024 * 1024;
+// model size limit: 256 MiB
+const int MAX_MODEL_SIZE = 256 * 1024 * 1024;
 
 // dataset
-const std::string IMAGE_FILE = "/data/val_data_c/%05d.bin";
+std::string IMAGE_FILE = "/data/val_data_c/%05d.bin";
 
 using namespace mindspore;
 using namespace mindspore::lite;
@@ -59,13 +63,74 @@ int arg_max(tensor::MSTensor *t)
   return max_idx;
 }
 
+std::tuple<void *, void *, tensor::MSTensor *, tensor::MSTensor *> sessionInOut(mindspore::session::LiteSession *session)
+{
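+  // Grab the session's first input/output tensor and their mutable data buffers,
+  // so the caller can fill inputs and read predictions directly.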
+  // alloc input mem
+  std::vector<tensor::MSTensor *> inputs = session->GetInputs();
+  tensor::MSTensor *input = inputs.front();
+  void *input_buf = input->MutableData();
+  //std::cout << "input tensor num: " << inputs.size() << std::endl;
+  //std::cout << "input tensor[0] shape: ";
+  //for (int i : input->shape())
+  //{
+  //  std::cout << i << " ";
+  //}
+  //std::cout << std::endl;
+
+  // get output
+  std::unordered_map<std::string, tensor::MSTensor *> outputs = session->GetOutputs();
+  tensor::MSTensor *output = outputs.begin()->second;
+  //std::cout << "output tensor num: " << outputs.size() << std::endl;
+  //std::cout << "output tensor[0] name: " << outputs.begin()->first << ", shape: ";
+  void *output_buf = output->MutableData();
+  //for (int i : output->shape())
+  //{
+  //  std::cout << i << " ";
+  //}
+  //std::cout << std::endl;
+  return {input_buf, output_buf, input, output};
+}
+
+mindspore::session::LiteSession *GenerateSession(mindspore::lite::Model *model)
+{
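+  // Build a LiteSession with a default Context and compile the shared model into it;
+  // any failure is reported and nullptr is returned.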
+  if (model == nullptr)
+  {
+    std::cerr << "Read model file failed while running" << std::endl;
+    return nullptr;
+  }
+  Context *context = new (std::nothrow) mindspore::lite::Context;
+  if (context == nullptr)
+  {
+    std::cerr << "New context failed while running" << std::endl;
+    return nullptr;
+  }
+
+  LiteSession *session = mindspore::session::LiteSession::CreateSession(context);
+  delete (context);
+  if (session == nullptr)
+  {
+    std::cerr << "CreateSession failed while running" << std::endl;
+    return nullptr;
+  }
+  int ret = session->CompileGraph(model);
+  if (ret != mindspore::lite::RET_OK)
+  {
+    std::cerr << "CompileGraph failed while running" << std::endl;
+    delete (session);
+    return nullptr;
+  }
+
+  return session;
+}
+
 int main(int argc, const char *argv[])
 {
-  if (argc != 3)
+  if (argc != 5)
   {
-    std::cout << "usage: ./classification your_model.ms image_num" << std::endl;
+    std::cout << "usage: ./classification your_model.ms image_num thread_num imagenetval_path" << std::endl;
     return -1;
   }
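+  // argv[4]: printf-style path pattern for the preprocessed validation images,
+  // e.g. "/data/val_data_c/%05d.bin".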
+  IMAGE_FILE = std::string(argv[4]);
   std::string version = mindspore::lite::Version();
   std::cout << "version: " << version << std::endl;
 
@@ -78,88 +143,59 @@ int main(int argc, const char *argv[])
 
   Model *model = Model::Import(model_buf, model_size);
 
-  // create context
-  Context *context = new (std::nothrow) Context;
-  if (context == nullptr)
-  {
-    std::cerr << "New context failed while running %s", argv[1];
-    return RET_ERROR;
-  }
-  CpuDeviceInfo &cpu_decice_info = context->device_list_[0].device_info_.cpu_device_info_;
-  cpu_decice_info.cpu_bind_mode_ = HIGHER_CPU;
-  context->thread_num_ = 2;
+  int THREAD_NUM = std::atoi(argv[3]);
 
-  // create session1
-  LiteSession *session = LiteSession::CreateSession(context);
-  delete (context);
-  if (session == nullptr)
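+  // One independent session, with its own input/output buffers, per worker thread,
+  // so the threads can run inference concurrently without sharing session state.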
+  std::vector<mindspore::session::LiteSession *> sessions(THREAD_NUM);
+  std::vector<void *> inputs(THREAD_NUM);
+  std::vector<void *> outputs(THREAD_NUM);
+  std::vector<tensor::MSTensor *> inputTensors(THREAD_NUM);
+  std::vector<tensor::MSTensor *> outputTensors(THREAD_NUM);
+  for (int i = 0; i < THREAD_NUM; ++i)
   {
-    std::cerr << "CreateSession failed while running %s", argv[1];
-    return RET_ERROR;
+    sessions[i] = GenerateSession(model);
+    if (sessions[i] == nullptr)
+    {
+      // GenerateSession already reported the error.
+      return RET_ERROR;
+    }
+    std::tuple<void *, void *, tensor::MSTensor *, tensor::MSTensor *> inOut = sessionInOut(sessions[i]);
+    inputs[i] = std::get<0>(inOut);
+    outputs[i] = std::get<1>(inOut);
+    inputTensors[i] = std::get<2>(inOut);
+    outputTensors[i] = std::get<3>(inOut);
   }
 
-  // compile graph
-
-  int ret = session->CompileGraph(model);
-  if (ret != RET_OK)
-  {
-    std::cerr << "CompileGraph failed" << std::endl;
-    // session and model need to be released by users manually.
-    delete (session);
-    delete (model);
-    return ret;
-  }
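+  // Every session has compiled the graph, so the model buffer can be released to save memory.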
   model->Free();
 
-  // alloc input mem
-  std::vector<tensor::MSTensor *> inputs = session->GetInputs();
-  tensor::MSTensor *input = inputs.front();
-  void *input_buf = input->MutableData();
-  std::cout << "input tenosr num: " << inputs.size() << std::endl;
-  std::cout << "input tensor[0] shape: ";
-  for (int i : input->shape())
-  {
-    std::cout << i << " ";
-  }
-  std::cout << std::endl;
+  // infer
+  int IMAGE_NUM = std::atoi(argv[2]);
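+  // Pre-size the results so each worker can write result[i] without synchronization.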
+  std::vector<int> result(IMAGE_NUM);
 
-  // get output
-  std::unordered_map<std::string, tensor::MSTensor *> outputs = session->GetOutputs();
-  tensor::MSTensor *output = outputs.begin()->second;
-  std::cout << "output tenosr num: " << outputs.size() << std::endl;
-  std::cout << "output tensor[0] name: " << outputs.begin()->first << ", shape: ";
-  void *output_buf = output->MutableData();
-  for (int i : output->shape())
+  std::vector<std::thread> threads(THREAD_NUM);
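+  // Round-robin split: worker id handles images id, id + THREAD_NUM, id + 2 * THREAD_NUM, ...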
+  for (int id = 0; id < THREAD_NUM; ++id)
   {
-    std::cout << i << " ";
+    threads[id] = std::thread([&](int id) {
+      for (int i = id; i < IMAGE_NUM; i += THREAD_NUM)
+      {
+        read_image(i, inputs[id], inputTensors[id]->Size());
+        int ret = sessions[id]->RunGraph();
+        if (ret != RET_OK)
+        {
+          std::cerr << "Run graph failed." << std::endl;
+          return;
+        }
+        //print_tensor(outputTensors[id]);
+        //std::cout << arg_max(outputTensors[id]) << std::endl;
+        result[i] = arg_max(outputTensors[id]);
+      }
+    }, id);
   }
-  std::cout << std::endl;
-  // infer
-  std::vector<int> result;
-  int IMAGE_NUM = std::atoi(argv[2]);
 
-  std::cout << "inference start" << std::endl;
-  for (size_t i = 0; i < IMAGE_NUM; i++)
+  for (int i = 0; i < THREAD_NUM; ++i)
   {
-    read_image(i, input_buf, input->Size());
-    int ret = session->RunGraph();
-    if (ret != RET_OK)
-    {
-      std::cerr << "Run graph failed." << std::endl;
-      return RET_ERROR;
-    }
-    //print_tensor(output);
-    //std::cout << arg_max(output) << std::endl;
-    result.push_back(arg_max(output));
-    std::cout << "\r" << i * 100 / IMAGE_NUM << "%, ";
-    for (int j = 0; j < i * 80 / IMAGE_NUM; ++j)
-    {
-      std::cout << '*';
-    }
+    threads[i].join();
   }
-  std::cout << std::endl;
   std::cout << "inference finished" << std::endl;
-
+  for (int i = 0; i < THREAD_NUM; ++i)
+  {
+    delete sessions[i];
+  }
   int correct = 0;
   for (int i = 0; i < IMAGE_NUM; ++i)
   {