728x90
아직 모르는 게 많지만, 작성한 코드입니다.
train 폴더 내에 label 이름별 폴더가 들어 있고, 그 내부에 이미지가 들어있는 구조입니다.
test 폴더 같은 계층에 존재한다는 가정하에 향후 이용을 위해
train 폴더 상위 폴더에서 os.path.join으로 train을 더하고 os.walk로 내부 폴더와 파일명을 가져옵니다.
for문을 돌려서 위치에 맞게 이미지 파일 경로와 라벨을 저장하고,
라벨은 학습을 위해 딕셔너리를 활용해 숫자로 변경합니다.
적절한 변형을 가한 후 window창을 띄워 확인합니다.
import cv2
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
class CustomDataset(Dataset):
def __init__(self, root_path, mode='train', transform=None):
self.file_paths = list()
self.labels = list()
self.mode = mode
self.transform = transform
root_path = os.path.join(root_path, mode)
for a, b, c in os.walk(root_path):
if c:
sub_dir_name = a.split(os.sep)[-1]
for file_name in c:
self.labels.append(self.label_dict[sub_dir_name])
self.file_paths.append(os.path.join(a,file_name))
else:
self.label_dict = dict(zip(b,range(len(b))))
def __len__(self):
return len(self.file_paths)
def __getitem__(self, idx):
label = self.labels[idx]
file_path = self.file_paths[idx]
# 경로에 한글이 있다면, cv2.imread가 에러가 납니다.
# 해결 방법은 복잡하니 임시 방편으로 아래와 같이
# replace로 상대경로를 이용해서 한글이 들어간 부분을 제외할 수 있습니다.
# file_path = file_path.replace('python 파일이 들어있는 경로', ".")
# image = cv2.imread(file_path)
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Image.open은 잘 작동합니다.
image = Image.open(file_path).convert("RGB")
image = np.array(image)
if self.transform:
# albumentation 일때만 해당
augmented = self.transform(image=image)
image = augmented['image']
return image, label
album_transform = A.Compose([
A.Resize(256, 256),
A.RandomCrop(224, 224),
A.HorizontalFlip(),
A.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
ToTensorV2()
])
root_path = '파일 경로'
dataset = CustomDataset(root_path=root_path, mode='train', transform=album_transform)
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)
for images, labels in data_loader:
# tensor --> numpy
np_imgs = images.cpu().numpy()
np_labels = labels.cpu().numpy()
np_imgs = np_imgs[0].transpose(1,2,0)
cv2.imshow(str(np_labels[0]), np_imgs)
key = cv2.waitKey() # 누른 키 저장
cv2.destroyAllWindows()
if key == ord('q'): # 누른 키가 'q'이면 for문 종료
break
'IT > AI' 카테고리의 다른 글
머신러닝 기초 (0) | 2024.03.07 |
---|---|
인공 신경망 기초 (0) | 2024.03.06 |
[이미지] 번호판 생성 [Python] - cv2, PIL (2) | 2023.09.01 |
[이미지] Numpy로 이미지 처리 기초 [Python]-cv2 (2) | 2023.08.29 |