HAZEL

[ PyTorch 02. ] Data Preprocess , DataLoader, 데이터 시각화 본문

DATA ANALYSIS/ML & DL

[ PyTorch 02. ] Data Preprocess , DataLoader, 데이터 시각화

Rmsid01 2021. 4. 30. 18:28

 

import os
from glob import glob

import torch
from torchvision import datasets, transforms # dataset 예제 변환, transform 예제 변환을 줌

 

1. Data Loader 부르기

 

# batch 사이즈를 데이터 로드에 직접 넣어줌

batch_size = 32
test_batch_size =32

# train 용도 이므로 True , 로컬에 데이터가 없으면 download 받을 것이므로 True
# 데이터를 변경시켜줄것이므로, 아래처럼 처리 
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('dataset/', train = True , download= True,
                  transform = transforms.Compose([
                      transforms.ToTensor(),
                      transforms.Normalize(mean = (0.5,) , std = (0.5,))
                  ])),

    batch_size = batch_size,
    shuffle  = True) 

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('dataset', train = False,
                    transform = transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.5,), (0.5))
                    ])),
    batch_size = batch_size,
    shuffle = True)
    
    

 

* 파이토치 같은 경우, 텐서플로우와는 다르게, DataLoader 를 이용해서 불러주게된다. 

- DataLoader 을 이용하면, transform 을 처리해주어서 나오게 됨.  ( 굉장히 유용한 아이이다 )

 

* 위에서 불러온 데이터를 iteration 을 이용해서 한개만 불러오기

images, labels = next(iter(train_loader))

images.shape  # torch.Size([32, 1, 28, 28])

labels.shape # torch.Size([32])

 

데이터를 보면, shape의 모양이 tf 와 다름을 알 수 있다.

텐서플로우는 [ 배치사이즈, height , width , 채널 ] , 토치 [ 배치사이즈 , 채널(channel) , height, width ] 로 둘의 순서가 다르다.

 

2. 데이터 시각화

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

images[0].shape  #  torch.Size([1, 28, 28])

torch_image = torch.squeeze(images[0]) # 0 번째를 없애준다. 
torch_image.shape   #  torch.Size([28, 28])

# 토치를 넘파이화 해줌 
image = torch_image.numpy()
image.shape

label   # array (9)

plt.title(label)
plt.imshow(image, 'gray')
plt.show()

 

label 을 찍어보면, 9인 것을 확인 할 수있으며, 직접 시각화 해보면 이미지 역시 '9' 임을 볼 수 있다.