rueki
Pytorch Multi-Gpu 학습했을 때 weight file 저장 및 로딩 방법 본문
728x90
반응형
isinstance로 모델 타입 확인하기
-save
if isinstance(model, (DataParallel, DistributedDataParallel)):
torch.save(model.module.state_dict(), model_save_name)
else:
torch.save(model.state_dict(), model_save_name)
-load
state_dict = torch.load(model_name, map_location=current_gpu_device)
if isinstance(model, (DataParallel, DistributedDataParallel)):
model.module.load_state_dict(state_dict)
else:
model.load_state_dict(state_dict)
multi-gpu에서 학습한 것을 single-gpu에서 돌릴 때는 weight file에서 module 삭제하면 될 듯하다.
추론 시에 똑같이 모든 gpu를 사용할 가능성은 적으니..
checkpoints = torch.load("/kaggle/working/resnet50_cifar10_best_acc.ckp")
if isinstance(model, nn.DataParallel): # GPU 병렬사용 적용
model.module.load_state_dict(torch.load(checkpoints))
else: # GPU 병렬사용을 안할 경우
state_dict = torch.load(checkpoints)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.` ## module 키 제거
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
그런데 이것도 안 되는 경우가 있어서 그냥 가중치 key값에 module을 다 대체해버렸다.
checkpoint = torch.load('가중치 파일 경로')
for key in list(checkpoint.keys()):
if 'module.' in key:
checkpoint[key.replace('module.', '')] = checkpoint[key]
del checkpoint[key]
model.load_state_dict(checkpoint)
728x90
반응형
'pytorch' 카테고리의 다른 글
tensorboardx ssh 연결 및 모니터링 방법 (0) | 2023.07.31 |
---|---|
mmdetection 설치 (0) | 2023.02.22 |
Pytorch로 FLOPs 계산하는 방법 (0) | 2022.09.27 |
torch custom dataset code 참고용 (0) | 2021.04.08 |
MNIST 데이터 CNN으로 분류하기 - review (0) | 2020.08.22 |
Comments