Notice
Recent Posts
Recent Comments
Link
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | |||||
3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | 20 | 21 | 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | 30 |
Tags
- leetcode
- SQL코테
- t분포
- airflow
- CASE
- inner join
- update
- LSTM
- sigmoid
- 카이제곱분포
- 코딩테스트
- nlp논문
- MySQL
- NLP
- 자연어 논문
- 자연어처리
- 표준편차
- 설명의무
- 그룹바이
- Window Function
- torch
- GRU
- 서브쿼리
- SQL 날짜 데이터
- HackerRank
- sql
- Statistics
- 자연어 논문 리뷰
- 짝수
- 논문리뷰
Archives
- Today
- Total
HAZEL
[ PyTorch 02. ] Data Preprocess , DataLoader, 데이터 시각화 본문
DATA ANALYSIS/ML & DL
[ PyTorch 02. ] Data Preprocess , DataLoader, 데이터 시각화
Rmsid01 2021. 4. 30. 18:28
import os
from glob import glob
import torch
from torchvision import datasets, transforms # dataset 예제 변환, transform 예제 변환을 줌
1. Data Loader 부르기
# batch 사이즈를 데이터 로드에 직접 넣어줌
batch_size = 32
test_batch_size =32
# train 용도 이므로 True , 로컬에 데이터가 없으면 download 받을 것이므로 True
# 데이터를 변경시켜줄것이므로, 아래처럼 처리
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('dataset/', train = True , download= True,
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean = (0.5,) , std = (0.5,))
])),
batch_size = batch_size,
shuffle = True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('dataset', train = False,
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5))
])),
batch_size = batch_size,
shuffle = True)
* 파이토치 같은 경우, 텐서플로우와는 다르게, DataLoader 를 이용해서 불러주게된다.
- DataLoader 을 이용하면, transform 을 처리해주어서 나오게 됨. ( 굉장히 유용한 아이이다 )
* 위에서 불러온 데이터를 iteration 을 이용해서 한개만 불러오기
images, labels = next(iter(train_loader))
images.shape # torch.Size([32, 1, 28, 28])
labels.shape # torch.Size([32])
데이터를 보면, shape의 모양이 tf 와 다름을 알 수 있다.
텐서플로우는 [ 배치사이즈, height , width , 채널 ] , 토치 [ 배치사이즈 , 채널(channel) , height, width ] 로 둘의 순서가 다르다.
2. 데이터 시각화
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
images[0].shape # torch.Size([1, 28, 28])
torch_image = torch.squeeze(images[0]) # 0 번째를 없애준다.
torch_image.shape # torch.Size([28, 28])
# 토치를 넘파이화 해줌
image = torch_image.numpy()
image.shape
label # array (9)
plt.title(label)
plt.imshow(image, 'gray')
plt.show()
label 을 찍어보면, 9인 것을 확인 할 수있으며, 직접 시각화 해보면 이미지 역시 '9' 임을 볼 수 있다.