개요
배치 정규화(Batch Normalization)은 2015년에 공개된 Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift 논문에서 제안된 개념입니다. 배치 정규화는 학습 속도 개선, 오버피팅 억제 등 여러 장점이 있어 현재 대부분의 딥러닝 논문을 보면 항상 사용되는 알고리즘 입니다.
Internal Covariate Shift(ICS)
모델 학습은 각 층에서는 입력값을 받아 가중치와 편향값을 더하고 활성화 함수를 거쳐 출력값을 만들어 다음 층으로 전달합니다. 그러나 이 과정에서 입력값과 출력값의 분포가 계속해서 달라지는 문제가 발생합니다. 이를 Internal Covariate Shift(ICS) 라고 합니다.
원인
ICS는 다양한 원인으로 발생합니다.
첫 번째 원인으로 비선형적인 활성화 함수가 있습니다. 활성화 함수는 비선형적 특징을 주어 복잡한 문제를 학습할 수 있도록 도와준다는 장점이 있지만, 그 과정에서 값을 특정 범위로 압축시키는 문제가 있습니다. 이 과정에서 기울기 소실/폭발(Gradient Vanishing/Exploding) 문제가 발생하여 학습 안정성이 크게 떨어지고 정보가 소실될 수 있습니다.
두 번째 원인으로는 가중치 초기값, 학습률 문제입니다. 초기 가중치 값이나 학습률이 지나치게 크면 급격한 변화로 학습의 안정성이 크게 떨어지게 됩니다.
배치 정규화(Batch Normalization)
논문에서는 위 문제를 해결하기 위한 방법으로 평균과 분산을 조정하는 배치 정규화를 제안했습니다. 기존에도 최초 데이터 입력시 Standardization, Normalization와 같은 방법을 적용하여 데이터의 크기를 일정하게 조정하는 방법이 있었지만, 배치 정규화는 학습 과정에서 신경망 내부에 포함되어 정규화를 진행한다는 차이가 있습니다. 이는 학습 과정에서 발생하는 ICS문제를 해결하는데 도움을 줍니다.
계산 과정
위 수식은 배치 정규화 계산 과정입니다.
\[\mu_{\mathcal{B}} \leftarrow \frac{1}{m}\sum_{i=1}^{m} x_i\] \[\sigma^2_{\mathcal{B}} \leftarrow \frac{1}{m}\sum_{i=1}^{m} (x_i-\mu_{\mathcal{B}})^2\] \[\hat{x}i \leftarrow \frac{x_i-\mu{\mathcal{B}}}{\sqrt{\sigma^2_{\mathcal{B}}+\epsilon}}\]미니배치의 평균과 분산을 구해 정규화를 진행합니다. 정규화 결과로 배치는 평균이 0, 분산이 1로 정규화 됩니다. 정규화 과정에서 분모에 있는 \(\epsilon\) 값은 0으로 나눠지는걸 방지하기 위한 매우 작은 값 입니다. 파이토치의 BN은 1e-5가 디폴트 값 입니다.
\[y_i \leftarrow \gamma \hat{x}_i + \beta\]정규화 된 배치 데이터에 \(\gamma\)를 곱하고 \(\beta\)를 더해 최종 출력 값을 만듭니다. 이떄 \(\gamma\)와 \(\beta\)는 학습 가능한 변수입니다. \(\gamma\)와 \(\beta\)값을 조정하여 네트워크가 필요한 분포로 조정을 진행합니다.
추론
배치 정규화는 학습 모드와 추론 모드가 나눠져 있습니다. 배치 정규화는 학습 단계에서는 들어오는 미니배치의 평균과 분산을 구해 정규화를 진행합니다. 그러나 추론 단계에서는 들어오는 배치의 크기가 1 또는 매우 작을 수 있어서 학습때와 마찬가지로 정규화를 진행하면 매우 불안정해지는 문제가 있습니다.
\[\text{running_mean} \leftarrow (1-\alpha)\cdot \text{running_mean} + \alpha\cdot \mu\mathcal{B}\] \[\text{running_var} \leftarrow (1-\alpha)\cdot \text{running_var} + \alpha\cdot \sigma^2\mathcal{B}\]그래서 학습 모드에서 미니배치에서 계산된 평균과 분산을 지속적으로 갱신하여 running_mean과 running_var을 저장하고, 추론 모드에서 활용합니다. 수식의 \(\alpha\)는 최근 배치를 어느정도 반영할지 정하는 모멘텀 계수 입니다.
\[\hat x=\frac{x-\text{running_mean}}{\sqrt{\text{running_var}+\epsilon}},\quad y=\gamma\hat x+\beta\]이렇게 계산된 값을 추론 모드에서 다음과 같이 활용하여 항상 고정된 값으로 정규화를 안정적으로 진행하게 됩니다.