Cross Pseudo Supervision: Consistency Regularization을 위한 간단하지만 강력한 방법론

안녕하세요, 오늘 소개할 논문은 중국 Peking University에서 연구한 Semi-Supervised Learning 관련 논문으로 논문의 제목은 Semi-Supervised Semantic Segmentation with Cross Pseudo Supervision 입니다. 본 논문은 21년 6월에 CVPR 2021에 Accept된 논문으로 등재 8개월 만에 24개의 논문에서 인용을 되었습니다.

Semi-supervised Learning

인공지능 모델을 학습하기 위해서는 학습을 위한 데이터 가공이 필요합니다. 현재까지도 수 많은 인공지능 학습 모델이 연구되어 오고는 있지만 그 중에서도 제일 중요한 것은 여전히 데이터입니다. Kaggle[1]에서 개최되는 인공지능 Challenge만 보더라도 2014년에 발표된 VGGNet[2]모델이 많은 우승을 하고있습니다. 따라서 데이터 가공이 인공지능 모델의 핵심이다라고 보더라도 과언이 아닙니다.

그림 1. Task에 따른 데이터 Annotation

하지만 데이터를 가공하는 작업은 높은 작업 비용을 요구하기 때문에 연구자들의 고민거리가 되어왔습니다. 위 그림 1의 Segmentation mask를 보더라도 수천 장의 이미지를 픽셀 수준의 라벨링을 해야하며, 더 나아가 “데이터 정제 > 데이터 라벨링 > 데이터 검수” 작업을 수행하여야 하기에 노동 집약적이며 비효율적입니다. 이를 극복하기 위하여, 이러한 문제의식은 아래와 같은 연구 분야로 확장되고 현재도 활발하게 연구가 수행중에 있습니다.

  • Semi-supervised Learning(SSL)
  • Unsupervised Learning (UL)
  • Active Learning(AL)
  • Domain Adaptation(DA)

이중에서 오늘 다루어 볼 연구는 Semi-supervised Learning 입니다. Semi-supervised Learning은 적은양의 가공 데이터를 활용하여 가공되지 않은 데이터를 효율적으로 학습에 활용하는 연구입니다.

Consistency regularization

Consistency regularization은 SSL에서 널리 쓰이는 방법론 중에 하나로써, 데이터에 perturbation을 주고 모델에 입력을 주어도 해당 perturbation이 데이터의 Critical 한 영향을 주는 것이 아니라면 일관성있는 결과를 보여줘야 한다는 전제로 SSL에 적용합니다. 이에 대한 대표적인 연구로 Π Model[3]과 Temporal ensemble[4]이 있습니다.

그림 2. Π Model과 Temporal ensemble

위 모델들과 같이 unlabeled x에 대하여 서로 다른 perturbation을 주고 이에 대한 Consistency를 맞추어 가는 방향으로 SSL은 학습이 이루어집니다. 오늘 소개드릴 CPS도 Consistency regularization 방법론을 사용합니다.

Cross Pseudo Supervision(CPS)

그림 3. (a) CPS (b) cross confidence consistency (c) mean teacher (d) pseudoSeg

본 논문에서 제안하는 Cross Pseudo Supervision(CPS, 위 그림3의 (a))는 아래 파이프라인을 통해 Consistency regularization을 수행합니다.

  1. 서로 다른 값으로 초기화된 두개의 세그멘테이션 모델을 생성

세그멘테이션 모델 f 에 서로 다르게 초기화된 가중치 1, θ2로 초기화를 하고 동일한 Augmentation 을 적용한 입력 X를 입력하여 두 개의 confidence map P1, P2 를 생성합니다.

  1. 동일한 아규멘테이션을 수행한 이미지 X를 두 모델에 입력하여 2개의 One-hot vector를 추론

추론된 P1, P2 에 Augmax를 주어 One-hot vector Y1, Y2를 생성합니다.

2. 두 모델의 결과를 교차(cross)하여 CELoss를 계산하여 역전파 수행

labeled data X에 대한 Loss Ls은 일반적인 CELoss를 통해 P1, P2의 Loss를 계산하며, CPS는 CELoss(P1, Y2), CELoss(P2, Y1)으로 Cross 하여 Lcps를 계산합니다.

결과적으로 최종 Loss L은 위의 식을 통해 산출됩니다. 여기서의 λ는 Cityscapes에서 6, PASCAL VOC에서 1.5를 사용하였을 때 가장 높은 성능을 보였다고 합니다.