Pytorch_VGG_with_CIFAR10

CIFAR10 데이터셋에 VGG 모델 적용해보기

모두를 위한 딥러닝 - 파이토치 강의 참고

  • VGG는 3x224x224의 input을 기본으로 만들어져 있다. 따라서, 이미지의 크기가 다를경우 이미지 크기를 조정하거나 모델을 수정해서 사용할 수 있다.

  • 이번에는 이미지 사이즈는 그대로 사용하며 VGG모델을 조금 수정해서 적용해보자.

  • VGG의 모델이 어떻게 생성되는지는 이전 포스트 - VGG 모델 생성 살펴보기에서 확인할 수 있다.

  • 이번에는 custom convolution layer 을 만들고 이를 maye_layers함수를 통해 생성한 후, 이를 통한 모델을 만들고자 한다.

  • convolution layer 13개와 fully connected layer3개를 가지는 VGG13 configuration을 생성한다.

1
cfg = [32,32,'M', 64,64,128,128,128,'M',256,256,256,512,512,512,'M']
  • 그리고, VGG Source Code에서 가져온 VGG class를 다음과 같이 일부 수정한다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class VGG(nn.Module):
def __init__(self, features, num_classes=1000, init_weights=True):
super(VGG, self).__init__()
self.features = features
self.classifier = nn.Sequential(
nn.Linear(512*4*4, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, num_classes)
)
if init_weights:
self._initialize_weights()

def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x

def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
  • 위의 코드는 Source Code에서 아래와 같은 두 가지 사항을 변경했다

    • a. nn.AdaptiveAvgPool2d를 삭제했다.

      • 왜냐하면, features layer를 통과하고 우리의 이미지 사이즈는 4x4이므로 nn.AdaptiveAvgPool2d((7, 7))을 통해 사이즈를 키워줄 필요가 없었다.
    • b. classifier layer의 fully connected layer의 input size가 (batch size x 4 x 4) 로 수정되었다.

      • 왜냐하면, 위와 마찬가지로 features를 통과한 이미지의 사이즈가 4x4이기 때문이다.
  • 이처럼, input size가 다를 경우 기존의 모델을 수정해서 사용할 수 있다.

  • 중요한점은, layer를 통과하면서 image의 size가 어떻게 변하는지를 알고 fully connected layer까지 수정해야 size에러가 발생하지 않는다는 점이다.

  • 이를 위해서, 직접 공식을 통해 계산할 수도 있으며 아래와 같이 forward함수에서 shape를 프린트하며 확인할 수도 있다.

1
2
3
4
5
6
7
8
9
10
11
12
13
class VGG(nn.Module):
.
.
.
def forward(self, x):
x = self.features(x)
print(x.shape) # features layer를 통과하고 shape을 확인해보자.
x = x.view(x.size(0), -1)
x = self.classifier(x)
return x
.
.
.
  • 이처럼, 사이즈가 다른 CIFAR10 이미지를 수정한 VGG모델에 넣어서 학습시켜볼 수 있었다.

Full Code

Full Code

Comments

Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×