copy_img_select.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import os, random, shutil, tqdm, argparse, cv2
  2. import numpy as np
  3. import xml.etree.ElementTree as ET
  4. from config import names
  5. PROJECT_NAME = os.environ['PROJECT_NAME']
  6. def convert_annotation(xmlpath, xmlname, opt):
  7. with open(xmlpath, "r", encoding='utf-8') as in_file:
  8. tree = ET.parse(in_file)
  9. root = tree.getroot()
  10. img = cv2.imdecode(np.fromfile('{}/{}.{}'.format(opt.path, xmlname[:-4], opt.postfix), np.uint8), cv2.IMREAD_COLOR)
  11. h, w = img.shape[:2]
  12. for obj in root.iter('object'):
  13. cls = obj.find('name').text
  14. if cls not in names:
  15. continue
  16. xmlbox = obj.find('bndbox')
  17. b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),
  18. float(xmlbox.find('ymax').text))
  19. if ((b[1] - b[0]) * (b[3] - b[2])) / (h * w) > opt.min_area_ratio and ((b[1] - b[0]) * (b[3] - b[2])) / (h * w) < opt.max_area_ratio:
  20. return True
  21. return False
  22. def select_pic(opt):
  23. postfix = opt.postfix
  24. xmlpath = opt.xml_path
  25. need_copy = []
  26. list = os.listdir(xmlpath)
  27. random.shuffle(list)
  28. error_file_list = []
  29. for i in tqdm.tqdm(range(0, len(list)), desc=f'selecting... {len(need_copy)}/{opt.num}'):
  30. try:
  31. path = os.path.join(xmlpath, list[i])
  32. if ('.xml' in path) or ('.XML' in path):
  33. if convert_annotation(path, list[i], opt):
  34. need_copy.append(f'{list[i][:-3]}{postfix}')
  35. except Exception as e:
  36. error_file_list.append(list[i])
  37. if len(need_copy) >= opt.num:
  38. break
  39. print(f'this file convert failure\n{error_file_list}')
  40. return need_copy
  41. def parse_opt():
  42. parser = argparse.ArgumentParser()
  43. parser.add_argument('--path', type=str, required=True, help='base_image_path')
  44. parser.add_argument('--xml_path', type=str, required=True, help='base_xml_path')
  45. parser.add_argument('--num', type=int, required=True, help='image num')
  46. parser.add_argument('--postfix', type=str, default='jpg', help='image postfix')
  47. parser.add_argument('--min_area_ratio', type=float, default=0.0, help='min_area_ratio')
  48. parser.add_argument('--max_area_ratio', type=float, default=0.1, help='max_area_ratio')
  49. opt = parser.parse_known_args()[0]
  50. return opt
  51. if __name__ == '__main__':
  52. opt = parse_opt()
  53. pwd = os.getcwd()
  54. dest_image_path = f'{pwd}/{PROJECT_NAME}/img_train'
  55. if os.path.exists(dest_image_path):
  56. shutil.rmtree(dest_image_path)
  57. os.makedirs(dest_image_path, exist_ok=True)
  58. need_copy = select_pic(opt)
  59. for path in tqdm.tqdm(need_copy, desc=f'from {opt.path} copy to {dest_image_path}'):
  60. shutil.copy(f'{opt.path}/{path}', f'{dest_image_path}/{path}')