# inference.py

from quantization import *
from yolo import *
import numpy as np
import math
import random
random.seed(0)
import time
import cv2, sys, os, argparse, glob, shutil
import torch
from get_map import letterbox, scale_boxes
from yolo import non_max_suppression, plot_one_box, yolov8layer
from config import names, net_shape, bin_path, output_names, fl, model_type
os.environ['GLOG_minloglevel'] = '3'

PROJECT_NAME = os.environ['PROJECT_NAME']
qb_file = PROJECT_NAME + '/model_quantization/checkpoint_quan.qb'
qk_file = PROJECT_NAME + '/model_quantization/checkpoint_quan.qk'
infer_img = PROJECT_NAME + '/ezb_640x640.png'
video_save_path = PROJECT_NAME + '/video_result.mp4'
sh_command = 'sh sh/ezb_inference.sh'
# one random BGR color per class for box drawing
color_arr = [[random.randint(0, 255) for _ in range(3)] for _ in names]
def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--path', type=str, required=True, help='image path / image folder path / video path')
    parser.add_argument('--type', type=str, choices=['caffe', 'qkqb', 'ezb', 'all'], required=True, help='caffe, qkqb, ezb or all')
    parser.add_argument('--iou', type=float, default=0.45, help='IoU threshold for NMS')
    parser.add_argument('--conf', type=float, default=0.25, help='confidence threshold')
    parser.add_argument('--video', action='store_true', help='video inference')
    parser.add_argument('--video_batch_size', type=int, default=32, help='video batch size')
    parser.add_argument('--device', type=int, default=1, help='device id')
    opt = parser.parse_known_args()[0]
    return opt
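
# Usage sketch (hypothetical input paths; adjust to your project layout):
#   python inference.py --path samples/dog.jpg --type caffe
#   python inference.py --path samples/ --type all
#   python inference.py --path clip.mp4 --type qkqb --video --video_batch_size 16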
def letterbox_resize(image_path):
    """Letterbox an image to net_shape and save it for the ezb simulator."""
    ori_img = cv2.imdecode(np.fromfile(image_path, np.uint8), cv2.IMREAD_COLOR)
    ori_img_shape = ori_img.shape
    img, ratio, (dw, dh) = letterbox(ori_img, new_shape=net_shape, auto=False)
    cv2.imwrite(infer_img, img)
    print(f'image_path:{image_path} ori_shape:{ori_img_shape} '
          f'resize_shape:(w:{int(ratio[1] * ori_img_shape[1])}, h:{int(ratio[0] * ori_img_shape[0])}) '
          f'padding:(dw:{dw}, dh:{dh}) save_success -> {infer_img}.')
    return ori_img
class CustomDataset(BaseDataset):
    """Yields letterboxed CHW float32 tensors from a list of images or image paths."""
    def __init__(self, image_list, is_path_arr=False):
        super().__init__()
        self.image_list = image_list
        self.is_path_arr = is_path_arr

    def __getitem__(self, item):
        if self.is_path_arr:
            process_image = cv2.imdecode(np.fromfile(self.image_list[item], np.uint8), cv2.IMREAD_COLOR)
        else:
            process_image = self.image_list[item]
        # letterbox to the network input shape, preserving aspect ratio
        process_image, ratio, (dw, dh) = letterbox(process_image, new_shape=net_shape, auto=False)
        # alternative preprocessing, left disabled:
        # process_image = cv2.resize(process_image, (640, 640))
        # srcnp = cv2.cvtColor(process_image, cv2.COLOR_BGR2RGB)
        srcnp = process_image.astype(np.float32) / 256
        srcnp = np.transpose(srcnp, [2, 0, 1])  # HWC -> CHW
        return srcnp

    def __len__(self):
        return len(self.image_list)
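
# Note: frames stay in BGR channel order (the BGR->RGB swap above is left
# disabled) and are scaled by 1/256 rather than 1/255, presumably to match
# how the quantized model was calibrated.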
class Base_Detector:
    def __init__(self, opt, im_shape=net_shape) -> None:
        self.opt = opt
        self.im_shape = im_shape

    def __call__(self, image):
        pass

    def post_processing(self, ori_image, result):
        # decode raw YOLOv8 head outputs, run NMS, rescale boxes, then draw
        result = yolov8layer(result)
        pred = non_max_suppression([torch.from_numpy(result)], conf_thres=self.opt.conf, iou_thres=self.opt.iou)[0]
        pred[:, :4] = scale_boxes(self.im_shape, pred[:, :4], ori_image.shape[:2])
        for *xyxy, conf, cls in reversed(pred):
            x1, y1, x2, y2 = xyxy
            plot_one_box([int(x1), int(y1), int(x2), int(y2)], ori_image,
                         label=f'{names[int(cls)]} {float(conf):.2f}', color=color_arr[int(cls)])
        return ori_image

    def inference(self, image=None):
        pass
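
# Each backend subclass only needs to implement __call__ (load inputs, run the
# forward pass) and can reuse post_processing; the batch variants override
# post_processing to handle a list of predictions instead of a single image.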
class Caffe_Detector(Base_Detector):
    def __init__(self, opt, im_shape=net_shape) -> None:
        super().__init__(opt, im_shape)
        self.net = Net(opt.device)
        self.prototxt_file = f'{PROJECT_NAME}/model_caffe/model_0.prototxt'
        self.caffemodel_file = f'{PROJECT_NAME}/model_caffe/model_0.caffemodel'

    def __call__(self, path):
        image = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_COLOR)
        pred = self.net.src_forward(self.prototxt_file, self.caffemodel_file, CustomDataset([image]), 1)
        result = self.post_processing(image, [pred['output0'].copy(), pred['output1'].copy(), pred['output2'].copy()])
        return result


class Caffe_Batch_Detector(Base_Detector):
    def __init__(self, opt, im_shape=net_shape) -> None:
        super().__init__(opt, im_shape)
        self.net = Net(opt.device)
        self.prototxt_file = f'{PROJECT_NAME}/model_caffe/model_0.prototxt'
        self.caffemodel_file = f'{PROJECT_NAME}/model_caffe/model_0.caffemodel'

    def __call__(self, path):
        pred = self.net.src_forward(self.prototxt_file, self.caffemodel_file, CustomDataset(path, is_path_arr=True), len(path))
        result = self.post_processing(path, [pred['output0'].copy(), pred['output1'].copy(), pred['output2'].copy()])
        return result

    def post_processing(self, image_path, result):
        pred = yolov8layer(result)
        pred = non_max_suppression(pred, conf_thres=self.opt.conf, iou_thres=self.opt.iou)
        image_result = []
        for i in range(len(pred)):
            pred_temp = pred[i]
            ori_image = cv2.imdecode(np.fromfile(image_path[i], np.uint8), cv2.IMREAD_COLOR)
            pred_temp[:, :4] = scale_boxes(self.im_shape, pred_temp[:, :4], ori_image.shape[:2])
            for *xyxy, conf, cls in reversed(pred_temp):
                x1, y1, x2, y2 = xyxy
                plot_one_box([int(x1), int(y1), int(x2), int(y2)], ori_image,
                             label=f'{names[int(cls)]} {float(conf):.2f}', color=color_arr[int(cls)])
            image_result.append(ori_image.copy())
        return image_result
class Qkqb_Detector(Base_Detector):
    def __init__(self, opt, im_shape=net_shape) -> None:
        super().__init__(opt, im_shape)
        self.net = Net(opt.device)
        self.qk_file = f'{PROJECT_NAME}/model_quantization/checkpoint_quan.qk'
        self.qb_file = f'{PROJECT_NAME}/model_quantization/checkpoint_quan.qb'

    def __call__(self, path):
        image = cv2.imdecode(np.fromfile(path, np.uint8), cv2.IMREAD_COLOR)
        pred = self.net.easy_forward(self.qk_file, self.qb_file, CustomDataset([image]), 1)
        result = self.post_processing(image, [pred[output_names[0]].copy(), pred[output_names[1]].copy(), pred[output_names[2]].copy()])
        return result


class Qkqb_Batch_Detector(Base_Detector):
    def __init__(self, opt, im_shape=net_shape) -> None:
        super().__init__(opt, im_shape)
        self.net = Net(opt.device)
        self.qk_file = f'{PROJECT_NAME}/model_quantization/checkpoint_quan.qk'
        self.qb_file = f'{PROJECT_NAME}/model_quantization/checkpoint_quan.qb'

    def __call__(self, path):
        pred = self.net.easy_forward(self.qk_file, self.qb_file, CustomDataset(path, is_path_arr=True), len(path))
        result = self.post_processing(path, [pred[output_names[0]].copy(), pred[output_names[1]].copy(), pred[output_names[2]].copy()])
        return result

    def post_processing(self, image_path, result):
        pred = yolov8layer(result)
        pred = non_max_suppression(pred, conf_thres=self.opt.conf, iou_thres=self.opt.iou)
        image_result = []
        for i in range(len(pred)):
            pred_temp = pred[i]
            ori_image = cv2.imdecode(np.fromfile(image_path[i], np.uint8), cv2.IMREAD_COLOR)
            pred_temp[:, :4] = scale_boxes(self.im_shape, pred_temp[:, :4], ori_image.shape[:2])
            for *xyxy, conf, cls in reversed(pred_temp):
                x1, y1, x2, y2 = xyxy
                plot_one_box([int(x1), int(y1), int(x2), int(y2)], ori_image,
                             label=f'{names[int(cls)]} {float(conf):.2f}', color=color_arr[int(cls)])
            image_result.append(ori_image.copy())
        return image_result
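
# Note: the two batch detectors share identical post-processing; they differ
# only in the forward call (src_forward on prototxt/caffemodel vs easy_forward
# on the qk/qb quantization checkpoints) and in the output blob names.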
class Ezb_Detector(Base_Detector):
    def __init__(self, opt, im_shape=net_shape) -> None:
        super().__init__(opt, im_shape)

    def __call__(self, path):
        image = letterbox_resize(path)
        os.system(sh_command)  # run the ezb simulator inference script
        result = self.post_processing(image, self.get_single_result_yolov8())
        return result

    def get_single_result_yolov8(self):
        # read the three raw output head blobs written by the simulator and
        # dequantize each with its fractional bit length fl[i]
        shapes = Helper.get_caffe_output_shapes(qk_file)
        outputs = []
        for i in range(3):
            bw = Helper.get_quantize_out_bw(qk_file, output_names[i])
            sim_res = np.fromfile(PROJECT_NAME + bin_path[i], dtype=np.int8 if bw == 8 else np.int16)
            outputs.append(Helper.hw_data_to_caffe_int_data(sim_res, shapes[output_names[i]]) / math.pow(2, fl[i]))
        return outputs
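
# Dequantization refresher: a fixed-point integer r with fractional length f
# represents the real value r / 2**f. With a hypothetical fl[i] == 4, an int8
# raw value of 40 decodes to 40 / 16 == 2.5. The bit width (8 vs 16) only
# determines how the raw file is read, not the scale.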
if __name__ == '__main__':
    opt = parse_opt()
    # select the model backend according to --type
    if opt.type != 'all':
        if opt.type == 'caffe':
            model = Caffe_Batch_Detector(opt) if opt.video else Caffe_Detector(opt)
        elif opt.type == 'qkqb':
            model = Qkqb_Batch_Detector(opt) if opt.video else Qkqb_Detector(opt)
        else:
            model = Ezb_Detector(opt)
    else:
        model = [Caffe_Detector(opt), Qkqb_Detector(opt), Ezb_Detector(opt)]
    # is the given path a single file?
    if os.path.isfile(opt.path):
        # video inference?
        if opt.video:
            time_arr, batch_image_path = [], []
            cap = cv2.VideoCapture(opt.path)
            # mp4v matches the .mp4 output container; frame size comes from the source
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
            fps = cap.get(cv2.CAP_PROP_FPS) or 25.0  # fall back to 25 fps if the source reports none
            out = cv2.VideoWriter(video_save_path, fourcc, fps, size)
            count, video_count = 0, int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            folder_save_path = f'{PROJECT_NAME}/inference_video_save/'
            if os.path.exists(folder_save_path):
                shutil.rmtree(folder_save_path)
            os.makedirs(folder_save_path, exist_ok=True)
            video_temp_save_path = f'{PROJECT_NAME}/video_temp/'
            if os.path.exists(video_temp_save_path):
                shutil.rmtree(video_temp_save_path)
            os.makedirs(video_temp_save_path, exist_ok=True)
            if cap.isOpened():
                while True:
                    flag, image = cap.read()
                    if image is None:
                        break
                    # spool frames to the temp folder so the batch detector can re-read them by path
                    cv2.imwrite(video_temp_save_path + f'{len(os.listdir(video_temp_save_path))}.jpg', image)
                    batch_image_path.append(video_temp_save_path + f'{len(batch_image_path)}.jpg')
                    if len(batch_image_path) == opt.video_batch_size:
                        since = time.time()
                        image_arr = model(batch_image_path)
                        time_arr.append(time.time() - since)
                        batch_image_path = []
                        if os.path.exists(video_temp_save_path):
                            shutil.rmtree(video_temp_save_path)
                        os.makedirs(video_temp_save_path, exist_ok=True)
                        for image in image_arr:
                            out.write(image)
                            count += 1
                            cv2.imwrite(f'{folder_save_path + str(count)}.jpg', image)
                            print(f'video:{count}/{video_count}, img shape:{image.shape}, using time:{time_arr[-1]:.4f}s, mean time:{np.mean(time_arr):.4f}s, save frame in {folder_save_path + str(count)}.jpg')
                # flush the last partial batch
                if len(batch_image_path) != 0:
                    since = time.time()
                    image_arr = model(batch_image_path)
                    time_arr.append(time.time() - since)
                    for image in image_arr:
                        out.write(image)
                        count += 1
                        cv2.imwrite(f'{folder_save_path + str(count)}.jpg', image)
                        print(f'video:{count}/{video_count}, img shape:{image.shape}, using time:{time_arr[-1]:.4f}s, mean time:{np.mean(time_arr):.4f}s, save frame in {folder_save_path + str(count)}.jpg')
            if os.path.exists(video_temp_save_path):
                shutil.rmtree(video_temp_save_path)
            out.release()
            print('done.')
        else:
            # 'all' runs every backend and tiles the results side by side
            if isinstance(model, list):
                image = np.concatenate([m(opt.path) for m in model], axis=1)
            # single-model inference
            else:
                image = model(opt.path)
            cv2.imwrite(infer_img, image)
            print(f'inference success. save image in {infer_img}')
    # folder inference
    else:
        folder_save_path = f'{PROJECT_NAME}/inference_img_save/'
        if os.path.exists(folder_save_path):
            shutil.rmtree(folder_save_path)
        os.makedirs(folder_save_path, exist_ok=True)
        image_listdir = sorted(glob.glob(f'{opt.path}/*'))  # sorted so output numbering is deterministic
        for idx, img_path in enumerate(image_listdir):
            if isinstance(model, list):
                image = np.concatenate([m(img_path) for m in model], axis=1)
            else:
                image = model(img_path)
            cv2.imwrite(f'{folder_save_path + str(idx + 1)}.jpg', image)
            print('-' * 20 + f'{idx + 1}/{len(image_listdir)} inference success. save image in {folder_save_path + str(idx + 1)}.jpg' + '-' * 20)
    # release network handles for backends that own one
    if isinstance(model, list):
        for model_ in model:
            if isinstance(model_, (Caffe_Detector, Caffe_Batch_Detector, Qkqb_Detector, Qkqb_Batch_Detector)):
                model_.net.release()
    else:
        if isinstance(model, (Caffe_Detector, Caffe_Batch_Detector, Qkqb_Detector, Qkqb_Batch_Detector)):
            model.net.release()