EeasyModel.h 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. #pragma once
  2. #include <string>
  3. #include <unistd.h>
  4. #include "ax_type.h"
  5. #include "VX/vx.h"
  6. #include "VX/vx_vendors.h"
  7. #include "VX/vx_types.h"
  8. #include "libnn/net_api.h"
  9. #include "nlohmann/json.hpp"
  10. #include "Detection.hpp"
  11. #include "AIStatus.h"
  12. #include "Common.h"
  13. /**
  14. * @brief EZModel模型类
  15. *
  16. * 该类包含了用于管理和操作Eeasy模型的各种成员变量和方法。
  17. */
  18. class EeasyModel
  19. {
  20. public:
  21. /**
  22. * @brief 预处理结构体:用于表示图像预处理的状态和配置。
  23. */
  24. struct PreprocessConfig
  25. {
  26. vx_size dstHandle; ///< 预处理用的句柄
  27. vx_context handleContext; ///< 预处理用的context
  28. int netWidth, netHeight; ///< 网络的宽高
  29. int srcWidth, srcHeight; ///< 媒体流的宽高
  30. float scaleInfo; ///< 缩放因子
  31. vx_size dstImageSize; ///< 预处理结果图像大小
  32. vx_size dstVirtualAddress; ///< 预处理结果虚拟地址
  33. };
  34. /**
  35. * @brief 模型结构体:用于表示神经网络模型的状态和配置。
  36. */
  37. struct ModelConfig
  38. {
  39. vx_context context; ///< 模型用的context
  40. vx_graph graph; ///< 模型用的graph
  41. int netWidth, netHeight; ///< 网络的宽高
  42. // unsigned int anchorsG[30] = {0}; ///< yolo_anchor数组
  43. std::string modelType; ///< 模型类型
  44. std::vector<std::string> classNames; ///< 类别名称数组
  45. std::vector<std::string> inputBlobsName; ///< 输入 blob 的名称数组
  46. std::vector<vx_tensor> inputBlobsTensor; ///< 输入 blob 的张量数组
  47. std::vector<std::string> outputBlobsName; ///< 输出 blob 的名称数组
  48. std::vector<vx_tensor> outputBlobsTensor; ///< 输出 blob 的张量数组
  49. std::vector<int8_t *> outputBlobsPtr; ///< 输出 blob 的张量指针(request用)
  50. std::vector<int> outputBlobsPtrSize; ///< 输出 blob 的张量指针(request用)
  51. ///< 输入数据的维度
  52. // std::vector<std::array<vx_size, 4>> inputDataDims;
  53. // std::vector<vx_size[4]> outputDataDims; ///< 输出数据的维度
  54. std::vector<std::array<vx_size, 4>> outputDataDims;
  55. // vx_size outputDataDims[4];
  56. // std::vector<detection::Box> proposals; ///< 推理解码后的box
  57. vx_enum dataType[5]; ///< 数据类型数组
  58. int floatingPoint[5]; ///< 浮点数位数数组
  59. // int floatingPoint[5] = {0}; ///< 浮点数位数数组
  60. unsigned int strides[5] = {0}; ///< 步长数组
  61. };
  62. struct PostprocessConfig
  63. {
  64. std::string name;
  65. unsigned int yoloAnchor[30] = {0}; ///< yolo_anchor数组
  66. // 模型传入
  67. // std::string modelType; ///< 模型类型
  68. // std::vector<std::string> classNames; ///< 类别名称数组
  69. // vx_enum dataType[5]; ///< 数据类型数组
  70. // int floatingPoint[5] = {0}; ///< 浮点数位数数组
  71. // unsigned int strides[5] = {0}; ///< 步长数组
  72. // std::vector<vx_size[4]> outputDataDims; ///< 输出数据的维度
  73. // vx_size outputDataDims[4];
  74. // int srcWidth, srcHeight; ///< 媒体流的宽高
  75. // float scaleInfo; ///< 缩放因子
  76. };
  77. /**
  78. * @brief 初始化预处理资源
  79. *
  80. * 根据提供的配置信息初始化预处理资源。
  81. *
  82. * @param cfg 预处理配置的 JSON 对象
  83. * @param preprocess_st 预处理状态结构体指针
  84. * @return 表示预处理资源初始化成功或失败的错误代码。
  85. * - AIStatus::StatusCode::SUCCESS :初始化成功
  86. * - AIStatus::StatusCode::FAILED :初始化失败(已创建资源销毁成功)
  87. * - AIStatus::EeasyErrorCode::RELEASE_PREPROCESS_RESOURCES_ERROR :初始化失败(已创建资源销毁也失败)
  88. */
  89. int initializePreprocessResource(const nlohmann::json_abi_v3_11_2::json cfg, PreprocessConfig *preprocessConfig);
  90. /**
  91. * @brief 初始化模型资源
  92. *
  93. * 根据提供的配置信息初始化模型资源。
  94. *
  95. * @param cfg 模型配置的 JSON 对象
  96. * @param model_st 模型状态结构体指针
  97. * @return 表示模型资源初始化成功或失败的错误代码。
  98. * - AIStatus::StatusCode::SUCCESS :初始化成功
  99. * - AIStatus::StatusCode::FAILED :初始化失败(已创建资源销毁成功)
  100. * - AIStatus::EeasyErrorCode::RELEASE_MODEL_RESOURCES_ERROR :初始化成功
  101. */
  102. int initializeModelResource(const nlohmann::json_abi_v3_11_2::json cfg, ModelConfig *modelConfig);
  103. /**
  104. * @brief 预处理图像
  105. *
  106. * 使用提供的预处理状态和图像帧状态执行图像预处理。
  107. *
  108. * @param preprocess_st 预处理状态结构体指针
  109. * @param frame_st 图像帧状态结构体指针
  110. * @return 表示预处理图像操作成功或失败的错误代码。
  111. * - AIStatus::StatusCode::SUCCESS :预处理图像成功
  112. * - AIStatus::EeasyErrorCode::INFER_IMAGECONVERT_ERROR :ImageConvert处理失败
  113. */
  114. int preprocessImage(PreprocessConfig *preprocessConfig, frame_t *frame);
  115. /**
  116. * @brief 模型推理
  117. *
  118. * 使用提供的预处理状态和模型状态执行推理。
  119. *
  120. * @param preprocess_st 预处理状态结构体指针
  121. * @param model_st 模型状态结构体指针
  122. * @param boundingBoxes 存储推理结果的边界框列表
  123. * @param class_attributes 类别属性的映射表
  124. * @return 表示模型推理成功或失败的错误代码。
  125. * - AIStatus::StatusCode::SUCCESS :模型推理成功
  126. * - AIStatus::EeasyErrorCode::INFER_INPUTDATAFROMMEM_ERROR :ImportNetInputDataFromMem操作失败
  127. * - AIStatus::EeasyErrorCode::INFER_PROCESSGRAPH_ERROR :vxProcessGraph操作失败
  128. * - AIStatus::EeasyErrorCode::INFER_FINISHGRAPH_ERROR :vxFinish操作失败
  129. */
  130. int inferModel(unsigned char *preProcessedImg,int preProcessedImgSize, ModelConfig *modelConfig);
  131. // std::vector<detection::Box> *proposals, std::unordered_map<std::string, float> classAttributes);
  132. // int postprocess(std::vector<int8_t *> outputBlobsPtr, PostprocessConfig *postProcessConfig, std::vector<BoundingBox> *boundingBoxes, std::unordered_map<std::string, float> classAttributes);
  133. int postprocess(std::vector<BoundingBox> *boundingBoxes,
  134. std::vector<int8_t *> outputBlobsPtr,
  135. std::vector<std::array<size_t, 4>> outputDataDims,
  136. unsigned int strides[5],
  137. int floatingPoint[5],
  138. vx_enum dataType[5],
  139. unsigned int yoloAnchor[30],
  140. std::vector<std::string> classNames,
  141. std::unordered_map<std::string, float> classAttributes,
  142. float scaleInfo,int srcWidth,int srcHeight);
  143. int postprocess(std::vector<BoundingBox> *boundingBoxes);
  144. /**
  145. * @brief 释放预处理资源
  146. *
  147. * 使用提供的预处理状态释放相关资源。
  148. *
  149. * @param preprocess_st 预处理状态结构体指针
  150. * @return 表示释放预处理资源成功或失败的错误代码。
  151. * - AIStatus::StatusCode::SUCCESS :释放预处理资源成功
  152. * - AIStatus::EeasyErrorCode::RELEASE_PREPROCESS_RESOURCES_ERROR :释放预处理资源失败
  153. */
  154. int releasePreprocessResources(PreprocessConfig *preprocessConfig);
  155. int releaseModelResources(ModelConfig *modelConfig);
  156. private:
  157. /**
  158. * @brief 将图像格式转换为EasyImage格式。
  159. *
  160. * 此函数将给定的图像格式转换为EasyImage格式。
  161. *
  162. * @param asj_format 输入图像格式,支持的格式包括 YUV420SPNV21。
  163. * @param[out] out 输出的EasyImage格式。
  164. *
  165. * @return 成功转换返回0,否则返回-1。
  166. */
  167. int convertImageFormatToEeasyImage(image_format_t asjFormat, img_fmt *out);
  168. };