rueki

imbalance data일때 weightedRandomSampler 적용하기 본문

code tip

imbalance data일때 weightedRandomSampler 적용하기

륵기 2021. 8. 30. 22:14

이미지 분류에서 클래스마다 존재하는 데이터의 수가 항상 비슷하게 존재하지 않은다. 불균형적인 데이터일때 적용할 수 있는 방법 중 하나를 소개한다.

WeightedRandomSampler라는 방법인데 Dataloader를 불러올때 같이 선언할 수 있다고 한다.

import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

###############################################
#위에 데이터셋 class, 및 등등 선언되어있다 가정#
###############################################

#class 0 : 43200개, class 1 : 4800개
class_counts = y_train.value_counts().to_list() #43200, 4800
num_samples = sum(class_counts) # 48000 - 전체 데이터 갯수
labels = y_train.to_list()

#클래스별 가중치 부여 [48000/43200, 48000/4800] => class 1에 가중치 높게 부여하게 됨
class_weights = [num_samples / class_counts[i] for i in range(len(class_counts))] 

# 해당 데이터의 label에 해당되는 가중치
weights = [class_weights[labels[i]] for i in range(int(num_samples))] #해당 레이블마다의 가중치 비율
sampler = WeightedRandomSampler(torch.DoubleTensor(weights), int(num_samples))


#dataloader
train_loader = DataLoader(
    train_dataset, batch_size=params['batch_size'], sampler = sampler,
    num_workers=params['num_workers'], pin_memory=True)

val_loader = DataLoader(
    valid_dataset, batch_size=params['batch_size'], shuffle=False,
    num_workers=params['num_workers'], pin_memory=True)

중요한 부분은 데이터 갯수에 따른 각 클래스에 해당되는 데이터 존재 비율이다. 이를 하나의 가중치로 나타내서 imbalance 데이터를 다루는 문제를 해결할 수 있다고 한다.

Comments