You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

69 lines
2.0 KiB

from PIL import ImageFilter
import random
import torch.utils.data as tud
import torchvision.datasets as datasets
from torchvision.io import image
class PreImager(tud.Dataset):
    """Dataset that yields one augmented image per class for each index.

    Image paths found under ``samples_dir`` are grouped by class index. Item
    ``i`` is built from the ``i``-th path of every class, so the dataset
    length is the size of the smallest class.
    """

    def __init__(self, samples_dir, aug):
        """Scan ``samples_dir`` with ``ImageFolder`` and bucket paths by class.

        Args:
            samples_dir: Root directory handed to ``datasets.ImageFolder``.
            aug: Callable applied to every loaded image in ``__getitem__``.
        """
        folder = datasets.ImageFolder(samples_dir)

        self.samples_dir = samples_dir
        self.aug = aug
        self.loader = folder.loader
        self.classes = folder.classes
        self.class_to_index = folder.class_to_idx
        self.class_number = len(self.classes)

        # One bucket of file paths per class index.
        per_class = [[] for _ in range(self.class_number)]
        for path, label in folder.imgs:
            per_class[label].append(path)
        self.images = per_class

        # Only indices valid for every class are exposed.
        self.length = min(map(len, per_class))

    def __getitem__(self, index):
        """Return ``(imgs, index)``; ``imgs[c]`` is the augmented image of class ``c``."""
        per_class_outputs = [
            self.aug(self.loader(self.images[class_id][index]))
            for class_id in range(self.class_number)
        ]
        return per_class_outputs, index

    def __len__(self):
        """Return the number of images in the smallest class."""
        return self.length
class TwoCropsTransform:
    """Produce a query/key pair by applying one transform twice to an input.

    The two views differ only insofar as ``base_transform`` itself is
    randomized (e.g. a random crop).
    """

    def __init__(self, base_transform):
        """Store the callable used to generate both views."""
        self.base_transform = base_transform

    def __call__(self, x):
        """Return ``[query, key]``: two successive applications to ``x``."""
        # Query is produced first, then key, each from the untouched input.
        return [self.base_transform(x) for _ in range(2)]
class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    Each call blurs the input with a radius drawn uniformly at random from
    the configured ``sigma`` range.
    """

    def __init__(self, sigma=None):
        """Configure the blur radius range.

        Args:
            sigma: Two-element sequence ``[low, high]`` bounding the radius.
                Defaults to ``[0.1, 2.0]``. A sequence passed explicitly is
                stored as-is (not copied).
        """
        # A fresh list per instance: a list literal used as the default
        # argument would be shared by every default-constructed instance.
        self.sigma = [.1, 2.] if sigma is None else sigma

    def __call__(self, x):
        """Blur ``x`` and return the result.

        Args:
            x: Object exposing ``filter(...)`` that accepts a PIL
                ``ImageFilter`` (e.g. a PIL image).
        """
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
        return x
class ImageFolderInstance(datasets.ImageFolder):
    """``ImageFolder`` variant whose items are ``(sample, index)`` pairs."""

    def __getitem__(self, index):
        """Load the sample at ``index`` and return it with the index itself.

        The class label stored next to the path is read but not returned.
        Only ``self.transform`` (when set) is applied to the loaded sample.
        """
        path, _label = self.samples[index]
        loaded = self.loader(path)
        transform = self.transform
        return (loaded if transform is None else transform(loaded)), index