nn

 

2023-03-08

1
import torch
2
import torch.distributed as dist
3
import torch.nn as nn
4
import torch.nn.functional as F
5
6
import os

# TODO: this snippet has not been verified end to end.

# `rank` and `world_size` must be known before the process group can be
# created. Read them from the RANK / WORLD_SIZE environment variables (set by
# launchers such as torchrun) and fall back to a single-process run otherwise.
rank = int(os.environ.get("RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))

# Bind this process to one GPU before any NCCL communication happens.
# NOTE(review): using the global rank as the GPU index assumes a single-node
# job (the rendezvous address below is localhost) -- confirm for multi-node.
torch.cuda.set_device(rank)

# Initialize the default process group unless it has already been set up.
if not dist.is_initialized():
    dist.init_process_group(
        backend='nccl',
        init_method='tcp://localhost:23456',
        rank=rank,
        world_size=world_size,
    )

# Load the model, place it on this process's GPU, and wrap it in
# `DistributedDataParallel`. With `device_ids=[rank]` the wrapped module must
# already live on that device, hence the explicit `.to(rank)`.
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
model = model.to(rank)
model = nn.parallel.DistributedDataParallel(model, device_ids=[rank])

参考