Pytorch_MNIST_CNN

Pytorch MNIST에 CNN모델 만들기

모두를 위한 딥러닝 - 파이토치 강의 참고

  • Convolution layer를 활용해서 MNIST이미지를 학습시킬 CNN 모델을 만들어보자.

  • 아래 CNN 구조와 비슷한 모델을 설계해본다.

  • CNN 모델

  • torch.nn.Sequential을 통해 위의 그림에서 나타내는 각 layer와 Fully Connected layer를 표현할 수 있다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class CNN(torch.nn.Module):

def __init__(self):
super(CNN, self).__init__()
self.drop_prob = 0.5

self.layer1 = torch.nn.Sequential(
torch.nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2)
)
self.layer2 = torch.nn.Sequential(
torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2)
)
self.layer3 = torch.nn.Sequential(
torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2)
)
self.fc1 = torch.nn.Linear(3*3*128, 625, bias=True)
torch.nn.init.xavier_uniform_(self.fc1.weight)

self.layer4 = torch.nn.Sequential(
self.fc1,
torch.nn.ReLU(),
torch.nn.Dropout(p=self.drop_prob)
)

self.fc2 = torch.nn.Linear(625, 10, bias=True)
torch.nn.init.xavier_uniform_(self.fc2.weight)

def forward(self, x):
output = self.layer1(x)
# print(output.shape)
output = self.layer2(output)
# print(output.shape)
output = self.layer3(output)
# print(output.shape)
output = output.view(output.size(0), -1)
output = self.layer4(output)
output = self.fc2(output)
return output
  • Conv2dMaxPool2d을 지닌 layer를 통과하면서 output의 shape은 게속 변하게 된다.

  • 이러한 과정에서 Fully Connected layer를 지날때는 이전 layer에서 나온 output의 shape를 알아야한다.

  • 이때 이전 포스트에서 알아본 공식을 통해 output shape을 계산할 수 있다.

  • 또한, forward를 진행하면서 각각의 layer를 통과한 output의 shape를 출력하면서 확인할 수도 있다.

  • CNN 모델을 배우기 전에는 선형함수만을 통해 MNIST이미지를 학습했다면 이번 과정에서는 지금까지 배운 Dropout, weight initialization, Convolution layer 등을 사용하는 모델을 만들 수 있었다.

  • 학습과 평가는 기존 MNIST 학습 코드와 유사하고, Full Code 링크를 통해 확인할 수 있다.

Full Code

Full Code

Comments

Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×