Detection.hpp 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
#pragma once
#include <float.h>

#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

#include "AIStatus.h"
  10. #ifndef CLIPRETINA
  11. #define CLIPRETINA(v, mn, mx) \
  12. { \
  13. if ((v) < (mn)) \
  14. { \
  15. (v) = (mn); \
  16. } \
  17. else if ((v) > (mx)) \
  18. { \
  19. (v) = (mx); \
  20. } \
  21. }
  22. #endif
  23. namespace detection
  24. {
  25. // sigmoid
  26. static float sigmoid(float x)
  27. {
  28. return 1.0 / (1.0 + exp(-x));
  29. }
  30. struct Box
  31. {
  32. float xyxy[4] = {0, 0, 0, 0};
  33. float object_score = 0;
  34. size_t index = 0;
  35. float score = 0;
  36. float area = 0;
  37. };
  38. static bool yolov_box_cmp(const Box &a, const Box &b)
  39. {
  40. return a.score > b.score;
  41. }
  42. static void yolo_nms(std::vector<Box> &boxes, const float &nms_threshold)
  43. {
  44. std::sort(boxes.begin(), boxes.end(), yolov_box_cmp);
  45. size_t current_index = 0;
  46. while (current_index < boxes.size())
  47. {
  48. Box current_box = boxes[current_index];
  49. size_t running_index = current_index + 1;
  50. while (running_index < boxes.size())
  51. {
  52. Box running_box = boxes[running_index];
  53. float xx1 = std::max(current_box.xyxy[0], running_box.xyxy[0]);
  54. float yy1 = std::max(current_box.xyxy[1], running_box.xyxy[1]);
  55. float xx2 = std::min(current_box.xyxy[2], running_box.xyxy[2]);
  56. float yy2 = std::min(current_box.xyxy[3], running_box.xyxy[3]);
  57. float w = std::max(0.0f, xx2 - xx1 + 1.0f);
  58. float h = std::max(0.0f, yy2 - yy1 + 1.0f);
  59. float inter_area = w * h;
  60. float union_area = current_box.area + running_box.area - inter_area;
  61. float overlap = inter_area / union_area;
  62. if (overlap > nms_threshold)
  63. {
  64. boxes.erase(boxes.begin() + running_index);
  65. }
  66. else
  67. {
  68. ++running_index;
  69. }
  70. }
  71. ++current_index;
  72. }
  73. }
  74. static void generate_proposals_yolov5(std::vector<Box> &proposals,
  75. const signed char *output_ptr,
  76. std::array<vx_size, 4U> output_size,
  77. const unsigned int &stride,
  78. const int &fl,
  79. const int &data_type,
  80. const unsigned int *anchor,
  81. const std::vector<std::string> class_names,
  82. std::unordered_map<std::string, float> class_attributes,
  83. float lookup_table[][6])
  84. {
  85. // printf("[%s Line:%d] dim:[%d %d %d %d] anchor:[%d,%d %d,%d %d,%d] stride:%d fl:%d data_type:%d class_num:%d\n", __FUNCTION__, __LINE__,
  86. // output_size[0], output_size[1], output_size[2], output_size[3],
  87. // anchor[0], anchor[1], anchor[2], anchor[3], anchor[4], anchor[5],
  88. // stride, fl, data_type, class_names.size());
  89. int H_algin = 0;
  90. if (data_type == 2)
  91. H_algin = (output_size[1] + 3) / 4 * 4;
  92. else
  93. H_algin = (output_size[1] + 1) / 2 * 2;
  94. int class_num = class_names.size();
  95. // #pragma unroll
  96. for (size_t a = 0; a < 3; ++a)
  97. { // anchor groups = 3
  98. for (size_t w = 0; w < output_size[0]; ++w)
  99. {
  100. for (size_t h = 0; h < output_size[1]; ++h)
  101. {
  102. Box box;
  103. size_t max_index = 0;
  104. float max_score = -1;
  105. // #pragma unroll
  106. // for (size_t c = 0; c < 4 + 1 + class_num; ++c)
  107. for (size_t c = 0; c < static_cast<size_t>(4 + 1 + class_num); ++c)
  108. {
  109. size_t ci = a * (4 + 1 + class_num) + c;
  110. size_t index = ci / 16 * output_size[0] * H_algin * 16 + w * H_algin * 16 + h * 16 + (ci % 16);
  111. // scale and sigmoid
  112. // float data = sigmoid(output_ptr[index] * 1.0 / pow(2, fl));
  113. float data = lookup_table[output_ptr[index] - (-128)][fl];
  114. if (c == 0)
  115. {
  116. data = (data * 2 - 0.5f + w) * static_cast<float>(stride);
  117. box.xyxy[c] = data;
  118. }
  119. else if (c == 1)
  120. {
  121. data = (data * 2 - 0.5f + h) * static_cast<float>(stride);
  122. box.xyxy[c] = data;
  123. }
  124. else if (c == 2 || c == 3)
  125. {
  126. data = powf((data * 2), 2) * anchor[a * 2 + c - 2];
  127. box.xyxy[c] = data;
  128. }
  129. else if (c == 4)
  130. {
  131. box.object_score = data;
  132. }
  133. else
  134. {
  135. if (data > max_score)
  136. {
  137. max_index = c - 5;
  138. max_score = data;
  139. }
  140. }
  141. }
  142. box.score = max_score * box.object_score;
  143. box.index = max_index;
  144. bool is_push = false;
  145. if (class_attributes.find(class_names[box.index]) != class_attributes.end())
  146. {
  147. if (box.object_score > class_attributes.find(class_names[box.index])->second &&
  148. box.score > class_attributes.find(class_names[box.index])->second)
  149. {
  150. is_push = true;
  151. }
  152. }
  153. else
  154. {
  155. if (box.object_score > class_attributes.find("all")->second &&
  156. box.score > class_attributes.find("all")->second)
  157. {
  158. is_push = true;
  159. }
  160. }
  161. // if (box.object_score > 0.5 && box.score > 0.5)
  162. // {
  163. // is_push = true;
  164. // }
  165. if (is_push)
  166. {
  167. // xywh -> xyxy
  168. float x = box.xyxy[0], y = box.xyxy[1], w = box.xyxy[2], h = box.xyxy[3];
  169. box.xyxy[0] = x - w / 2;
  170. box.xyxy[1] = y - h / 2;
  171. box.xyxy[2] = x + w / 2;
  172. box.xyxy[3] = y + h / 2;
  173. box.area = (box.xyxy[2] - box.xyxy[0] + 1) * (box.xyxy[3] - box.xyxy[1] + 1);
  174. proposals.push_back(box);
  175. // printf("[%s Line:%d] ------------box %f %f %f %f %f %d------------\n", __FUNCTION__, __LINE__, box.xyxy[0], box.xyxy[1], box.xyxy[2], box.xyxy[3], box.score, box.index);
  176. }
  177. }
  178. }
  179. }
  180. }
  181. }