世界讯息:YOLOv5和ResNet152模型预测
发布时间:2023-05-31 22:20:13
来源:博客园
(资料图片仅供参考)
我已经训练完成了yolov5检测和resnet152分类的模型,下面开始对一张图片进行检测分类。
首先用YOLOv5算法对图片中的猫和狗进行检测,接着将检测到的目标裁剪出来,最后用ResNet152对裁剪出的图片进行品种分类。
首先我有以下这些训练好的模型
猫狗检测模型、猫的品种分类模型、狗的品种分类模型
我的预测文件my_detect.py
import functools
import os
import sys
from pathlib import Path

import torch

from tools_detect import draw_box_and_save_img, dataLoad, predict_classify, detect_img_2_classify_img, get_time_uuid

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.common import DetectMultiBackend
from utils.general import (non_max_suppression)
from utils.plots import save_one_box
import config as cfg

conf_thres = cfg.conf_thres
iou_thres = cfg.iou_thres
detect_size = cfg.detect_img_size
classify_size = cfg.classify_img_size

# Detections covering less than this fraction of the image are skipped.
# tools_detect.draw_box_and_save_img uses the same 0.01 threshold; keep the two
# in sync so the per-object numbering ("狗1", "狗2", ...) matches the drawn image.
_MIN_AREA_RATIO = 0.01


@functools.lru_cache(maxsize=4)
def _load_detector(weights, device, imgsz):
    """Load and warm up a YOLOv5 detector once per (weights, device, imgsz)."""
    model = DetectMultiBackend(weights, device=device)
    model.warmup(imgsz=(1, 3, *imgsz))  # warmup
    return model


def detect_img(img, device, detect_weights="", detect_class=None, save_dir=""):
    """Detect cats/dogs in an encoded image and classify the breed of each one.

    Args:
        img: encoded image bytes (e.g. the raw contents of a .jpg file).
        device: torch device the detector and classifiers run on.
        detect_weights: path to the YOLOv5 detection weights.
        detect_class: display names indexed by detector class id. Detections whose
            name equals ``detect_class[0]`` go to the cat classifier, all others to
            the dog classifier. Defaults to the detector's own ``names``.
        save_dir: output path of the annotated image, forwarded unchanged to
            ``draw_box_and_save_img``.

    Returns:
        ``(ret_li, ret_bytes)`` where ``ret_li`` is a list of dicts with keys
        ``"details"`` (breed name), ``"label"`` (class name + running index) and
        ``"conf"`` (detection confidence rounded to 2 decimals), and ``ret_bytes``
        is the annotated image returned by ``draw_box_and_save_img``.
    """
    imgsz = (detect_size, detect_size)
    im0s, im = dataLoad(img, imgsz, device)

    # Loading + warmup is expensive; reuse the model across calls.
    model = _load_detector(detect_weights, device, imgsz)
    if detect_class is None:
        detect_class = model.names

    # Pure inference: no autograd graph. (no_grad rather than inference_mode,
    # because draw_box_and_save_img later edits the result tensors in place.)
    with torch.no_grad():
        pred = model(im, augment=False, visualize=False)
    pred = non_max_suppression(pred, conf_thres, iou_thres, None, False, max_det=1000)

    im0 = im0s.copy()
    # Draws boxes / saves the image. It also rescales pred's boxes IN PLACE from
    # letterboxed-input coordinates to im0 coordinates, which the loop below relies on.
    ret_bytes = draw_box_and_save_img(pred, names=model.names, class_names=detect_class,
                                      save_dir=save_dir, im0=im0, im=im)

    im0_arc = int(im0.shape[0]) * int(im0.shape[1])
    ret_li = []
    count = 1
    for det in reversed(pred[0]):
        # det = [x1, y1, x2, y2, conf, cls]
        x1, y1, x2, y2 = (int(v) for v in det[:4])
        if (x2 - x1) * (y2 - y1) / im0_arc < _MIN_AREA_RATIO:
            continue  # target too small

        # Crop the detection and convert it to the classifier's input tensor.
        im_crop = save_one_box(det[:4], im0, file=Path("im.jpg"), gain=1.1, pad=10,
                               square=False, BGR=False, save=False)
        im_crop = detect_img_2_classify_img(im_crop, classify_size, device)

        label = detect_class[int(det[-1])]
        if label == detect_class[0]:
            weight, breeds = cfg.cat_weight, cfg.cat_class
        else:
            weight, breeds = cfg.dog_weight, cfg.dog_class
        classify_predict = predict_classify(weight, im_crop, device)

        ret_li.append({
            "details": breeds[int(classify_predict)],
            "label": label + str(count),
            "conf": round(float(det[-2]), 2),
        })
        count += 1
    return ret_li, ret_bytes


def start_predict(img, save_dir=""):
    """Run detection + breed classification with the settings from ``config``.

    Returns the ``(ret_li, ret_bytes)`` pair produced by ``detect_img``.
    """
    return detect_img(img, cfg.device, cfg.detect_weight, cfg.detect_class, save_dir)


if __name__ == "__main__":
    name = get_time_uuid()
    save_dir = f"./save/{name}.jpg"
    # Other sample inputs:
    #   ./test_img/hashiqi20230312_00010.jpg
    #   ./test_img/kejiquan20230312_00046.jpg
    path = r"./test_img/hashiqi20230312_00116.jpg"
    with open(path, "rb") as f:
        img = f.read()
    img_ret_li, img_bytes = start_predict(img, save_dir=save_dir)
    print(img_ret_li)
我的tools_detect.py文件
import datetime
import functools
import os
import random
import sys
import time
from pathlib import Path

import torch
from PIL import Image
from torch import nn

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

# YOLOv5 helpers are imported only after ROOT is on sys.path, so these imports
# do not depend on the working directory already being importable.
from utils.augmentations import letterbox
from utils.general import (cv2, scale_boxes, xyxy2xywh)
from utils.plots import Annotator, colors
import numpy as np

# Boxes covering less than this fraction of the image are neither drawn nor labelled.
_MIN_AREA_RATIO = 0.01


def bytes_to_ndarray(byte_img):
    """Decode encoded image bytes into a 3-channel BGR numpy array.

    Raises:
        ValueError: if ``byte_img`` is empty or cannot be decoded as an image.
    """
    buf = np.asarray(bytearray(byte_img), dtype="uint8")
    if buf.size == 0:
        raise ValueError("empty image data")
    image = cv2.imdecode(buf, cv2.IMREAD_COLOR)
    if image is None:  # imdecode signals failure by returning None
        raise ValueError("could not decode image data")
    return image


def ndarray_to_bytes(ndarray_img):
    """Encode an image array as JPEG and return the encoded file bytes.

    Raises:
        ValueError: if OpenCV reports that encoding failed.
    """
    ok, buf = cv2.imencode(".jpg", ndarray_img)
    if not ok:
        raise ValueError("could not encode image as JPEG")
    # buf is a uint8 array holding the JPEG stream; its raw bytes are the file.
    return buf.tobytes()


def get_time_uuid():
    """Return a timestamp-based id, e.g. ``20220525140635467912`` + 3 random digits.

    Layout: ``YYYYmmddHHMMSS`` + 6-digit microseconds + a random 100-999 suffix
    (always 23 characters). The suffix lowers the chance of collisions when
    several ids are created within the same microsecond.
    """
    # strftime("%f") always yields 6 digits; str(datetime) omits the fraction
    # entirely when microsecond == 0, which produced shorter ids.
    stamp = datetime.datetime.fromtimestamp(time.time()).strftime("%Y%m%d%H%M%S%f")
    return stamp + str(random.randint(100, 999))


def dataLoad(img, img_size, device, half=False):
    """Decode image bytes and build the detector's input tensor.

    Returns:
        ``(image, im)``: the decoded BGR image, and a letterboxed RGB tensor on
        ``device`` scaled to 0.0-1.0 with a leading batch dimension (fp16 when
        ``half`` is true, otherwise fp32).
    """
    image = bytes_to_ndarray(img)
    im = letterbox(image, img_size)[0]  # padded resize
    im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
    im = np.ascontiguousarray(im)  # contiguous
    im = torch.from_numpy(im).to(device)
    im = im.half() if half else im.float()  # uint8 to fp16/32
    im /= 255  # 0 - 255 to 0.0 - 1.0
    if len(im.shape) == 3:
        im = im[None]  # expand for batch dim
    return image, im


def draw_box_and_save_img(pred, names, class_names, save_dir, im0, im):
    """Draw labelled boxes on ``im0``, optionally save it, and return it as JPEG bytes.

    Args:
        pred: per-image detection tensors whose rows are
            ``[x1, y1, x2, y2, conf, cls]`` in ``im``'s (letterboxed) coordinates.
            WARNING: the box columns are rescaled IN PLACE to ``im0`` coordinates;
            callers (e.g. my_detect.detect_img) rely on this.
        names: model class names; only used as the Annotator's font example text.
        class_names: display names indexed by class id.
        save_dir: output image path, e.g. ``./save/<id>.jpg``. YOLO-format labels
            are appended to ``<dir>/labels/<stem>.txt``. An empty value skips
            writing both files.
        im0: original BGR image.
        im: the network input tensor (only its spatial shape is used).

    Returns:
        The annotated image as JPEG bytes, or None if ``pred`` is empty.
    """
    save_path = save_dir
    fontpath = "./simsun.ttc"

    txt_path = None
    if save_path:
        label_dir = os.path.join(os.path.dirname(save_path), "labels")
        os.makedirs(label_dir, exist_ok=True)  # also creates the image's directory
        stem = os.path.splitext(os.path.basename(save_path))[0]
        txt_path = os.path.join(label_dir, stem)

    ret_bytes = None
    for det in pred:
        annotator = Annotator(im0, line_width=3, example=str(names), font=fontpath, pil=True)
        label_lines = []
        if len(det):
            det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
            count = 1
            im0_arc = int(im0.shape[0]) * int(im0.shape[1])
            gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
            for *xyxy, conf, cls in reversed(det):
                xyxy_arc = (int(xyxy[2]) - int(xyxy[0])) * (int(xyxy[3]) - int(xyxy[1]))
                if xyxy_arc / im0_arc < _MIN_AREA_RATIO:
                    continue  # target too small
                c = int(cls)  # integer class
                label = f"{class_names[c]}{count} {round(float(conf), 2)}"
                annotator.box_label(xyxy, label, color=colors(c, True))
                count += 1
                xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist()  # normalized xywh
                line = (float(cls), *xywh)  # label format
                label_lines.append(("%g " * len(line)).rstrip() % line + "\n")
            im0 = annotator.result()  # numpy.ndarray with the boxes drawn
        if save_path:
            if label_lines:
                with open(f"{txt_path}.txt", "a") as f:
                    f.writelines(label_lines)
            cv2.imwrite(save_path, im0)
        ret_bytes = ndarray_to_bytes(im0)
    return ret_bytes


@functools.lru_cache(maxsize=8)
def _load_classifier(model_path, device):
    """Load a pickled classification model once per (model_path, device), in eval mode."""
    # NOTE(security): torch.load unpickles arbitrary objects; only load trusted files.
    # NOTE: newer PyTorch releases may default to weights_only=True, which can
    # reject fully pickled models like these; confirm against the installed version.
    if torch.cuda.is_available():
        model = torch.load(model_path)
    else:
        model = torch.load(model_path, map_location="cpu")
    model.to(device)
    model.eval()
    return model


def predict_classify(model_path, img, device):
    """Classify ``img`` with the model stored at ``model_path``.

    Returns the arg-max class index per sample, squeezed (a 0-d tensor for a
    batch of one).
    """
    model = _load_classifier(model_path, device)
    with torch.no_grad():  # inference only: no autograd graph
        predicts = model(img)
    _, preds = torch.max(predicts, 1)
    return torch.squeeze(preds)


def detect_img_2_classify_img(img, classify_size, device):
    """Turn an HWC image crop into a ``(1, C, classify_size, classify_size)`` float32 tensor on ``device``.

    Pixel values are kept in their original range (no /255, no mean/std).
    NOTE(review): confirm this matches the preprocessing the classifiers were trained with.
    """
    im_crop1 = img.copy()  # the crop may be a non-contiguous view of the source image
    im_crop1 = np.float32(im_crop1)
    image = cv2.resize(im_crop1, (classify_size, classify_size))
    image = image.transpose((2, 0, 1))  # HWC to CHW
    im = torch.from_numpy(image).unsqueeze(0)  # add batch dim
    return im.to(device)
我的config.py文件
import torch
import os

# Model weights live under ./weights, relative to the working directory.
# os.path.join keeps these paths valid on both Windows and POSIX; the former
# r".\weights" literal only resolved correctly on Windows.
base_path = os.path.join(".", "weights")

# Cat/dog detector (YOLOv5). detect_class is indexed by the detector's class id.
detect_weight = os.path.join(base_path, "cat_dog_detect", "best.pt")
detect_class = ["猫", "狗"]

# Cat breed classifier; list order = classifier output index.
cat_weight = os.path.join(base_path, "cat_predict", "best.pt")
cat_class = ["东方短毛猫", "亚洲豹猫", "加菲猫", "安哥拉猫", "布偶猫", "德文卷毛猫", "折耳猫", "无毛猫", "暹罗猫", "森林猫", "橘猫", "奶牛猫", "狞猫", "狮子猫", "狸花猫", "玳瑁猫", "白猫", "蓝猫", "蓝白猫", "薮猫", "金渐层猫", "阿比西尼亚猫", "黑猫"]

# Dog breed classifier; list order = classifier output index.
dog_weight = os.path.join(base_path, "dog_predict", "best.pt")
dog_class = ["中华田园犬", "博美犬", "吉娃娃", "哈士奇", "喜乐蒂", "巴哥犬", "德牧", "拉布拉多犬", "杜宾犬", "松狮犬", "柯基犬", "柴犬", "比格犬", "比熊", "法国斗牛犬", "秋田犬", "约克夏", "罗威纳犬", "腊肠犬", "萨摩耶", "西高地白梗犬", "贵宾犬", "边境牧羊犬", "金毛犬", "阿拉斯加犬", "雪纳瑞", "马尔济斯犬"]

# Compute device. Alternatives:
#   device = 0
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# Detector NMS thresholds.
conf_thres = 0.5
iou_thres = 0.45

# Square input sizes (pixels) for the detector and the breed classifiers.
detect_img_size = 416
classify_img_size = 160
整体文件结构
其中models和utils文件夹都是yolov5源码的文件
运行my_detect.py的结果
标签:
AD
更多相关文章
- 世界讯息:YOLOv5和ResNet152模型预测
- 至尊天神声望怎么刷_至尊天神声望如
- 法尔克:拜仁加入弗拉霍维奇争夺战
- 浙数文化:数字科技储备充足 社交
- 百名老区孩子开启大湾区研学体验之
- 明星餐饮为何逃不出短命魔咒?专家
- 全球快报:云从科技现6笔大宗交易
- 《抖音》抖音回应10亿元收购支付公
- 天天快资讯丨德尔菲娜·尚内亚克种
- 武林风一龙一虎一豹(武林风一龙)
- 天天热议:明天起 婚姻登记可“跨省
- 通化东宝:痛风双靶点抑制剂I期临床
- “六一”趁热上新70余款 乐高提速
- 6月1日起,部分城市可以使用交管121
- 全球速看:佳能携陶瓷3D打印服务亮
- 冠石科技(605588.SH):拟定增募资不
- 全球新动态:和评理 |排他性 “印
- 全球观焦点:《死侍3》受编剧罢工影
- 褥疮特效药粉末_褥疮特效药
- 天天报道:济宁中考实验操作考试什么
- 济南高新区在济南市首推 “三医联
- 动态:「小白」2023 上半年什么手
- 杭州成人吸烟率已降至18.8%,“无烟
- 开山股份(300257.SZ):控股股东减持
- 为什么中外合作办学项目不能转专业
- 贵定:“数字乡村” 数字化赋能乡
- 每日聚焦:云南普者黑:发展旅游产
- 华龙区残联开展“世界无烟日”签名
- 干枣和鲜枣哪个好?干枣和鲜枣的功
- 仙佑医药科技有限公司怎么样? 仙