Torch

hub

2023-04-23

[[torch]] 中用于加载模型的函数,可以加载 [[github]] 仓库中的模型,也可以加载本地模型。

方法

list

查看可用模型

import torch
torch.hub.list("ultralytics/yolov5")
# 返回: ['custom', 'yolov5l', 'yolov5l6', 'yolov5m', 'yolov5m6', 'yolov5n', 'yolov5n6', 'yolov5s', 'yolov5s6', 'yolov5x', 'yolov5x6']

help

查看模型帮助

import torch
torch.hub.help("pytorch/vision", "resnet18")

load

加载模型

import torch
# 加载
# github 仓库和模型
model = torch.hub.load("ultralytics/yolov5", "yolov5s")
# 本地模型
model = torch.hub.load(
"ultralytics/yolov5",
"custom",
path="path/best.pt"
)
# 本地版本和模型
model = torch.hub.load(
"path/yolov5",
"custom",
path="path/best.pt",
source="local"
)

参数

  • device: GPU
  • _verbose: 静默加载

设置

# 设置
model.conf = 0.25 # 置信度
model.classes = [0] # 分类

使用

res = model("file.jpg")
res.print()
res.save()
res.show() # 预览图片
res.names: 所有 classess
res.xyxy[0]
res.files: 文件名 # 列表形式
# pandas 输出
res.pandas().xyxy[0]
res.pandas().xyxy[0].sort_values("xmin") # 排序从左到右
# json 输出
res.pandas().xyxy[0].to_json(orient="records")

pandas 包含:

  • names: 所有 classess #字典
  • files: 文件名 #列表
  • xyxy
  • xyxyn: 包含归一化
  • xywhn: 归一化
  • xywhn
  • n
  • t
  • s

加载到设备

model.cpu() # CPU
model.cuda() # GPU
model.to(device) # i.e. device=torch.device(0)

download_url_to_file

import torch
# 下载文件到本地
torch.hub.download_url_to_file('http://url/img.jpg', '/save_path/file_name')

参考