# yolo.py

import os
import random
import time

import cv2
import numpy as np
import torch
import torchvision
import tqdm

from config import *  # provides net_shape and strides used below
from multiprocessing import Pool
from typing import Callable

random.seed(0)

def parallelise(function: Callable, data, chunksize=100, verbose=True, num_workers=os.cpu_count()):
    # os.cpu_count() can return None, and Pool needs at least 1 worker.
    num_workers = max(num_workers or 1, 1)
    pool = Pool(processes=num_workers)
    results = list(
        tqdm.tqdm(pool.imap(function, data, chunksize), total=len(data), disable=not verbose)
    )
    pool.close()
    pool.join()
    return results

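# Usage sketch (hypothetical): `function` must be picklable, i.e. defined at
# module top level, since multiprocessing sends it to the worker processes.
#
#   def square(v):
#       return v * v
#
#   results = parallelise(square, list(range(1000)), chunksize=50, verbose=False)
#   # results == [0, 1, 4, 9, ...]
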
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
    # Resize and pad image while meeting stride-multiple constraints
    shape = im.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better val mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return im, ratio, (dw, dh)

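# Example (verifiable from the logic above): padding a 480x640 BGR image to a
# 640x640 square adds 80 pixels of grey border on the top and bottom.
#
#   img = np.zeros((480, 640, 3), dtype=np.uint8)
#   padded, ratio, (dw, dh) = letterbox(img, new_shape=640, auto=False)
#   # padded.shape == (640, 640, 3); ratio == (1.0, 1.0); (dw, dh) == (0.0, 80.0)
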
def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y

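# Example: a 4-wide, 6-tall box centred at (10, 10) maps to corner form as
#   xywh2xyxy(np.array([[10., 10., 4., 6.]]))  # -> [[8., 7., 12., 13.]]
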
def box_iou(box1, box2, eps=1e-7):
    """
    Calculate intersection-over-union (IoU) of boxes.

    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
    Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py

    Args:
        box1 (torch.Tensor): A tensor of shape (N, 4) representing N bounding boxes.
        box2 (torch.Tensor): A tensor of shape (M, 4) representing M bounding boxes.
        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-7.

    Returns:
        (torch.Tensor): An NxM tensor containing the pairwise IoU values for every element in box1 and box2.
    """
    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)
    inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp_(0).prod(2)

    # IoU = inter / (area1 + area2 - inter)
    return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)

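# Worked example: boxes (0, 0, 2, 2) and (1, 1, 3, 3) overlap in a 1x1 square,
# so IoU = 1 / (4 + 4 - 1) ≈ 0.143.
#
#   box_iou(torch.tensor([[0., 0., 2., 2.]]), torch.tensor([[1., 1., 3., 3.]]))
#   # -> tensor([[0.1429]])
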
def non_max_suppression(
    prediction,
    conf_thres=0.25,
    iou_thres=0.45,
    classes=None,
    agnostic=False,
    multi_label=False,
    labels=(),
    max_det=300,
    nc=0,  # number of classes (optional)
    max_time_img=0.05,
    max_nms=30000,
    max_wh=7680,
):
    """
    Perform non-maximum suppression (NMS) on a set of boxes, with support for masks and multiple labels per box.

    Arguments:
        prediction (torch.Tensor): A tensor of shape (batch_size, num_classes + 4 + num_masks, num_boxes)
            containing the predicted boxes, classes, and masks. The tensor should be in the format
            output by a model, such as YOLO.
        conf_thres (float): The confidence threshold below which boxes will be filtered out.
            Valid values are between 0.0 and 1.0.
        iou_thres (float): The IoU threshold below which boxes will be filtered out during NMS.
            Valid values are between 0.0 and 1.0.
        classes (List[int]): A list of class indices to consider. If None, all classes will be considered.
        agnostic (bool): If True, the model is agnostic to the number of classes, and all
            classes will be considered as one.
        multi_label (bool): If True, each box may have multiple labels.
        labels (List[List[Union[int, float, torch.Tensor]]]): A list of lists, where each inner
            list contains the apriori labels for a given image. The list should be in the format
            output by a dataloader, with each label being a tuple of (class_index, x1, y1, x2, y2).
        max_det (int): The maximum number of boxes to keep after NMS.
        nc (int, optional): The number of classes output by the model. Any indices after this will be considered masks.
        max_time_img (float): The maximum time (seconds) for processing one image.
        max_nms (int): The maximum number of boxes passed into torchvision.ops.nms().
        max_wh (int): The maximum box width and height in pixels.

    Returns:
        (List[torch.Tensor]): A list of length batch_size, where each element is a tensor of
            shape (num_boxes, 6 + num_masks) containing the kept boxes, with columns
            (x1, y1, x2, y2, confidence, class, mask1, mask2, ...).
    """
    # Checks
    assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'
    assert 0 <= iou_thres <= 1, f'Invalid IoU {iou_thres}, valid values are between 0.0 and 1.0'
    if isinstance(prediction, (list, tuple)):  # YOLOv8 model in validation mode, output = (inference_out, loss_out)
        prediction = prediction[0]  # select only inference output

    device = prediction.device
    mps = 'mps' in device.type  # Apple MPS
    if mps:  # MPS not fully supported yet, convert tensors to CPU before NMS
        prediction = prediction.cpu()
    bs = prediction.shape[0]  # batch size
    nc = nc or (prediction.shape[1] - 4)  # number of classes
    nm = prediction.shape[1] - nc - 4  # number of masks
    mi = 4 + nc  # mask start index
    xc = prediction[:, 4:mi].amax(1) > conf_thres  # candidates

    # Settings
    # min_wh = 2  # (pixels) minimum box width and height
    time_limit = 0.5 + max_time_img * bs  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6 + nm), device=prediction.device)] * bs
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[:, 2:4] < min_wh) | (x[:, 2:4] > max_wh)).any(1), 4] = 0  # width-height
        x = x.transpose(0, -1)[xc[xi]]  # confidence

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            lb = labels[xi]
            v = torch.zeros((len(lb), nc + nm + 5), device=x.device)
            v[:, :4] = lb[:, 1:5]  # box
            v[range(len(lb)), lb[:, 0].long() + 4] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Detections matrix nx6 (xyxy, conf, cls)
        box, cls, mask = x.split((4, nc, nm), 1)
        box = xywh2xyxy(box)  # (center_x, center_y, width, height) to (x1, y1, x2, y2)
        if multi_label:
            i, j = (cls > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float(), mask[i]), 1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float(), mask), 1)[conf.view(-1) > conf_thres]

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        i = i[:max_det]  # limit detections
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # Update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if mps:
            output[xi] = output[xi].to(device)
        if (time.time() - t) > time_limit:
            break  # time limit exceeded

    return output

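# Usage sketch (hypothetical shapes, assuming an 80-class YOLOv8 head on a
# 640x640 input: 4 box values + 80 class scores over 8400 anchors):
#
#   pred = torch.randn(1, 84, 8400)  # stand-in for raw decoded model output
#   dets = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
#   # dets is a list of length 1; dets[0] has shape (num_kept, 6) with
#   # columns (x1, y1, x2, y2, confidence, class).
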
def plot_one_box(x, img, color=None, label=None, line_thickness=3):
    # Plots one bounding box on image img
    tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    color = color or [random.randint(0, 255) for _ in range(3)]
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    cv2.rectangle(img, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
    if label:
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(img, c1, c2, color, -1, cv2.LINE_AA)  # filled
        cv2.putText(img, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)

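# Usage sketch (hypothetical values): draws a labelled detection in-place on a
# BGR image.
#
#   img = cv2.imread('example.jpg')
#   plot_one_box([50, 60, 200, 220], img, color=(0, 255, 0), label='person 0.91')
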
def softmax(z):
    # Numerically stable softmax; used on 1-D logit vectors in dfl() below.
    e_z = np.exp(z - np.max(z))
    return e_z / e_z.sum(axis=0)

def dfl_conv(z):
    # Expected value of each row's distribution, sum_i i * p_i (the fixed
    # arange-weighted 1x1 conv of YOLOv8's DFL head).
    weights = np.expand_dims(np.arange(z.shape[-1]), axis=0)
    z = z * weights
    return np.sum(z, axis=-1)

def dfl(res):
    # Decode Distribution Focal Loss (DFL) outputs: each of the 4 box sides is
    # predicted as a discrete distribution over c1 bins; decoding takes the
    # softmax of each distribution and then its expected value.
    res = np.expand_dims(res, axis=0)
    b, c, a = res.shape
    c1 = c // 4
    res = res.reshape((b, 4, c1, a))
    res = np.transpose(res, axes=[0, 3, 1, 2])
    # res = parallelise(softmax, res.reshape((-1, c1)))
    # res = parallelise(dfl_conv, res)
    res = np.stack([softmax(i) for i in res.reshape((-1, c1))])
    res = dfl_conv(res)
    return np.transpose(res.reshape((b, a, 4)), axes=[0, 2, 1])

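# Worked example: with c1 = 4 bins, logits that put almost all mass on bin 2
# decode to a side length of ~2, since the decoded value is the expectation
# sum_i i * softmax(z)_i.
#
#   z = np.array([0., 0., 10., 0.])
#   p = softmax(z)       # ~[0, 0, 1, 0]
#   dfl_conv(p[None])    # -> array([~2.0])
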
def make_anchor(input_shape=net_shape, grid_cell_offset=0.5):
    # Generate one anchor point per grid cell (cell centres) and the matching
    # stride value for every detection level.
    anchor_points, stride = [], []
    for i in strides:
        h, w = input_shape[0] // i, input_shape[1] // i
        sx = np.arange(w) + grid_cell_offset
        sy = np.arange(h) + grid_cell_offset
        sx, sy = np.meshgrid(sx, sy)
        anchor_points.append(np.stack((sx, sy), -1).reshape((-1, 2)))
        stride.append(np.full((h * w, 1), i))
    return np.transpose(np.concatenate(anchor_points), axes=[1, 0]), np.transpose(np.concatenate(stride), axes=[1, 0])

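# Sanity check (assuming the usual config values net_shape = (640, 640) and
# strides = (8, 16, 32)): the three levels contribute 80*80 + 40*40 + 20*20
# = 8400 cells, so make_anchor() returns arrays of shape (2, 8400) and (1, 8400).
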
def yolov8layer(res, c1=16):
    # Decode raw YOLOv8 head outputs: apply DFL to the box logits, convert the
    # resulting ltrb distances (relative to each anchor point) to xywh, and
    # scale by the stride.
    n, c = res[0].shape[:2]
    res = np.concatenate([xi.reshape((n, c, -1)) for xi in res], 2)
    res_dfl = parallelise(dfl, res[:, :4 * c1])
    res = np.concatenate([np.concatenate(res_dfl, axis=0), res[:, 4 * c1:]], axis=1)
    anchor_points, stride = make_anchor()
    x1y1 = anchor_points - res[:, :2]
    x2y2 = anchor_points + res[:, 2:4]
    # x1y1x2y2 -> xywh
    c_xy = (x1y1 + x2y2) / 2
    wh = x2y2 - x1y1
    res[:, :4] = np.concatenate((c_xy, wh), axis=1) * stride
    return res
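
# Usage sketch (hypothetical shapes, assuming an 80-class model, a 640x640
# input, and the config values above): the head emits one feature map per
# stride with 4*c1 + num_classes = 144 channels each.
#
#   feats = [np.random.randn(1, 144, 80, 80),
#            np.random.randn(1, 144, 40, 40),
#            np.random.randn(1, 144, 20, 20)]
#   out = yolov8layer(feats)                          # (1, 84, 8400) xywh + scores
#   dets = non_max_suppression(torch.from_numpy(out))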