pytorch
Pytorch Multi-Gpu 학습했을 때 weight file 저장 및 로딩 방법
륵기
2023. 1. 30. 09:36
728x90
반응형
isinstance로 모델 타입 확인하기
-save
if isinstance(model, (DataParallel, DistributedDataParallel)):
torch.save(model.module.state_dict(), model_save_name)
else:
torch.save(model.state_dict(), model_save_name)
-load
state_dict = torch.load(model_name, map_location=current_gpu_device)
if isinstance(model, (DataParallel, DistributedDataParallel)):
model.module.load_state_dict(state_dict)
else:
model.load_state_dict(state_dict)
multi-gpu에서 학습한 것을 single-gpu에서 돌릴 때는 weight file에서 module 삭제하면 될 듯하다.
추론 시에 똑같이 모든 gpu를 사용할 가능성은 적으니..
checkpoints = torch.load("/kaggle/working/resnet50_cifar10_best_acc.ckp")
if isinstance(model, nn.DataParallel): # GPU 병렬사용 적용
model.module.load_state_dict(torch.load(checkpoints))
else: # GPU 병렬사용을 안할 경우
state_dict = torch.load(checkpoints)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.` ## module 키 제거
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
그런데 이것도 안 되는 경우가 있어서 그냥 가중치 key값에 module을 다 대체해버렸다.
checkpoint = torch.load('가중치 파일 경로')
for key in list(checkpoint.keys()):
if 'module.' in key:
checkpoint[key.replace('module.', '')] = checkpoint[key]
del checkpoint[key]
model.load_state_dict(checkpoint)
728x90
반응형