본문 바로가기
AI

CutMix

by Ladun 2022. 1. 27.

 이번 글에서는 Data augmentation 기법 중 하나인 CutMix에 대해서 살펴볼 것입니다. 기존에 이미지 기반 테스크에서 성능을 높이기 위해서 이미지의 일부분을 잘라서 0으로 채우거나(Cutout) 다른 이미지와 겹치는(Mixup)과 같은 여러 기법이 사용되었습니다. 이러한 방법들을 통해서 이미지의 덜 중요한 부분까지 포커싱하게 만드는 regional dropout 전략을 사용해왔습니다. 하지만 Table 1에서 보이는 것과 같이 이미지의 정보가 손실되거나 왜곡되는 현상 때문에 오히려 성능이 감소하는 문제가 발생했습니다. CutMix는 기존 방법에서 더 나아가 cut-and-paste 방법을 취해서 현 이미지의 패치를 다른 이미지의 패치로 채우는 기법을 사용하여 높은 성능을 가져왔습니다.

 


Algorithm

이미지 $ x \in R^{W\times H\times C} $, 라벨 $y$에 대하여 CutMix는 두 개의 데이터 $(x_A, y_A), (x_B, y_B)$로 부터 새로운 데이터 $(\widetilde x, \widetilde y) $를 아래 식에 따라 만들어냅니다.

$$ \widetilde x = M\odot x_A + (1-M)\odot x_b\\\widetilde y = \lambda y_A + (1-\lambda)y_B$$

 

위 식의 $M \in \{0, 1\}^{W\times H}$는 $x_A$의 어느 부분을 지우고 $x_B$로부터 그 부분을 채울 것인지를 정하는 마스크이며, $\lambda$는 베타 분포 $~Beta(\alpha, \alpha)$로부터 추출된 0과 1사이로 정의된 두 데이터의 결합 비율을 의미합니다.

 

즉, 샘플링한 $\lambda$의 비율에 따라서 마스크 $M$을 생성하고 마스크 $M$을 이용하여 $x_A$의 일정부분을 $x_B$로 교체하는 알고리즘이다. 이에 해당하는 코드는 다음과 같습니다.

def cutmix(batch, alpha):
    data, targets = batch

    indices = torch.randperm(data.size(0))
    shuffled_data = data[indices]
    shuffled_targets = targets[indices]

    lam = np.random.beta(alpha, alpha)

    image_h, image_w = data.shape[2:]
    cx = np.random.uniform(0, image_w)
    cy = np.random.uniform(0, image_h)
    w = image_w * np.sqrt(1 - lam)
    h = image_h * np.sqrt(1 - lam)
    x0 = int(np.round(max(cx - w / 2, 0)))
    x1 = int(np.round(min(cx + w / 2, image_w)))
    y0 = int(np.round(max(cy - h / 2, 0)))
    y1 = int(np.round(min(cy + h / 2, image_h)))

    data[:, :, y0:y1, x0:x1] = shuffled_data[:, :, y0:y1, x0:x1]
    targets = (targets, shuffled_targets, lam)

    return data, targets

출처: https://github.com/hysts/pytorch_cutmix/blob/master/cutmix.py

위 코드를 보면 Label $\widetilde y$에 대한 수식이 안 보이는 것을 확인할 수 있습니다. 이는 실질적으로 $\lambda$를 loss의 가중치로 두어서 학습을 진행합니다. 즉, 아래 수식과 같이 loss를 가지게 됩니다.

$$ loss(\widetilde y) = \lambda* loss(y_A, f(\widetilde x)) + (1 - \lambda) * loss(y_B, f(\widetilde x))$$

이를 코드로 보면 다음과 같습니다.

class CutMixCriterion:
    def __init__(self, reduction):
        self.criterion = nn.CrossEntropyLoss(reduction=reduction)

    def __call__(self, preds, targets):
        targets1, targets2, lam = targets
        return lam * self.criterion(
            preds, targets1) + (1 - lam) * self.criterion(preds, targets2)

출처: https://github.com/hysts/pytorch_cutmix/blob/master/cutmix.py

loss와 똑같이 accuracy나 f1 스코어를 구할 때도 가중치를 둬서 구하면 됩니다.

$$ acc = \lambda* acc(y_A, f(\widetilde x)) + (1 - \lambda) * acc(y_B, f(\widetilde x))$$

 


Reference

'AI' 카테고리의 다른 글

Model의 Bias-Variance  (0) 2022.02.03
Macro-average, Micro-average  (0) 2022.02.03
Cross Entropy를 사용하는 이유  (0) 2021.11.12
Backpropagation  (0) 2021.11.03
[Metric] Recall과 Precision  (0) 2021.10.03