Pytorch_ImageFolder

ImageFolder를 통해 데이터 가져오기

모두를 위한 딥러닝 - 파이토치 강의 참고

  • 분류된 이미지 데이터셋이 준비되어있다면 ImageFolder를 통해서 데이터를 가져올 수 있다.

  • 예를들어, 3개의 클래스를 가지는 이미지셋을 준비했다면 다음과 같은 폴더 형태로 담아내면 된다.

1
2
3
4
5
6
7
8
9
10
11
12
/project
ㄴdata
ㄴdataset_name
ㄴtrain_data
ㄴclass1
ㄴclass2
ㄴclass3
ㄴtest_data
ㄴclass1
ㄴclass2
ㄴclass3
ㄴImageFolder_EX.py
  • 위와 같이 각 클래스별로 데이터를 준비했다면 ImageFolder 를 이용해 데이터를 불러온다.

  • MNIST나 CIFAR10의 데이터를 불러올 때처럼, tansform을 통해 텐서로 변환해 가져오게 된다.

1
2
3
4
5
6
7
8
9
10
11
trans = transforms.Compose([
transforms.ToTensor()
])

train_data = torchvision.datasets.ImageFolder(root='data/custom_data/train_data', transform=trans)

data_loader = DataLoader(dataset=train_data, batch_size=8, shuffle=True)

test_data = torchvision.datasets.ImageFolder(root='data/custom_data/test_data', transform=trans)

test_loader = DataLoader(dataset=test_data, batch_size=len(test_data))
  • 이후 간단한 CNN 모델을 만들어서 자신이 가진 데이터셋을 학습시킬 수 있다.
1
2
3
4
5
6
7
8
9
10
11
12
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(3, 6, 5),
nn.ReLU(),
nn.MaxPool2d(2),
)
.
.
.
# Full Code 참조
  • 이전까지는 Pytorch에서 바로 다운로드하고 불러올 수 있는 MNIST, CIFAR10과 같은 데이터셋을 이용했다.

  • 하지만, 이번에는 자신이 가진 고유의 데이터셋을 ImageFolder를 통해 불러오고 학습시키는 과정을 알 수 있었다.

Full Code

Full Code

Comments

Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×