rueki

Pytorch Multi-Gpu 학습했을 때 weight file 저장 및 로딩 방법 본문

pytorch

Pytorch Multi-Gpu 학습했을 때 weight file 저장 및 로딩 방법

륵기 2023. 1. 30. 09:36
728x90
반응형

 

isinstance로 모델 타입 확인하기

-save

if isinstance(model, (DataParallel, DistributedDataParallel)):
    torch.save(model.module.state_dict(), model_save_name)
else:
    torch.save(model.state_dict(), model_save_name)

-load

state_dict = torch.load(model_name, map_location=current_gpu_device)
if isinstance(model, (DataParallel, DistributedDataParallel)):
    model.module.load_state_dict(state_dict)
else:
    model.load_state_dict(state_dict)

 

multi-gpu에서 학습한 것을 single-gpu에서 돌릴 때는 weight file에서 module 삭제하면 될 듯하다.

추론 시에 똑같이 모든 gpu를 사용할 가능성은 적으니..

checkpoints = torch.load("/kaggle/working/resnet50_cifar10_best_acc.ckp")

if isinstance(model, nn.DataParallel):  # GPU 병렬사용 적용
    model.module.load_state_dict(torch.load(checkpoints))
else: # GPU 병렬사용을 안할 경우
    state_dict = torch.load(checkpoints) 
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] # remove `module.`  ## module 키 제거
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)

그런데 이것도 안 되는 경우가 있어서 그냥 가중치 key값에 module을 다 대체해버렸다.

checkpoint = torch.load('가중치 파일 경로')
for key in list(checkpoint.keys()):
    if 'module.' in key:
        checkpoint[key.replace('module.', '')] = checkpoint[key]
        del checkpoint[key]
model.load_state_dict(checkpoint)
728x90
반응형
Comments