diff --git a/main.py b/main.py index abbeedf..7a2c878 100644 --- a/main.py +++ b/main.py @@ -2,7 +2,6 @@ import builtins import math import os import random -import shutil import time import warnings @@ -159,16 +158,15 @@ def worker(gpu, ngpus_per_node, args): eval_dir = os.path.join(args.data, 'train') # center-crop augmentation - eval_augmentation = aug.moco_eval() - pre_train_dataset = pcl.loader.PreImager(pre_train_dir, eval_augmentation) + pre_train_dataset = pcl.loader.PreImager(pre_train_dir, aug.moco_eval()) train_dataset = pcl.loader.ImageFolderInstance( train_dir, - pcl.loader.TwoCropsTransform(eval_augmentation)) + pcl.loader.TwoCropsTransform(aug.only_rotate())) eval_dataset = pcl.loader.ImageFolderInstance( eval_dir, - eval_augmentation) + aug.moco_eval()) if args.distributed: pre_train_sampler = torch.utils.data.distributed.DistributedSampler(pre_train_dataset) diff --git a/pcl/loader.py b/pcl/loader.py index 28a6bd5..f9f3cda 100644 --- a/pcl/loader.py +++ b/pcl/loader.py @@ -2,7 +2,6 @@ from PIL import ImageFilter import random import torch.utils.data as tud import torchvision.datasets as datasets -from torchvision.io import image class PreImager(tud.Dataset): diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..d9166e9 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +torch +torchvision +Pillow +numpy +scikit-image +tqdm +matplotlib +scikit-learn \ No newline at end of file diff --git a/scripts/augmentation.py b/scripts/augmentation.py index 9dc0974..6da5f2a 100644 --- a/scripts/augmentation.py +++ b/scripts/augmentation.py @@ -1,6 +1,7 @@ import torchvision.transforms as transforms import pcl.loader + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) @@ -32,8 +33,16 @@ def moco_v1(): def moco_eval(): return transforms.Compose([ - transforms.Resize([512, 512]), + transforms.RandomResizedCrop(224, scale=(0.2, 1.)), # transforms.CenterCrop(512), transforms.ToTensor(), normalize 
 ])
+
+def only_rotate():
+    return transforms.Compose([
+        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
+        transforms.RandomRotation(90),
+        transforms.ToTensor(),
+        normalize
+    ])
diff --git a/scripts/data_process.py b/scripts/data_process.py
index 6852f7b..d43b453 100644
--- a/scripts/data_process.py
+++ b/scripts/data_process.py
@@ -1,4 +1,6 @@
 import torch
+
+
 def data_process(cluster_result, images, gpu):
     im_q = []
     im_k = []
diff --git a/scripts/loss.py b/scripts/loss.py
index 473f4bd..d112d8c 100644
--- a/scripts/loss.py
+++ b/scripts/loss.py
@@ -1,6 +1,9 @@
 import torch
 import numpy as np
 import math
+import random
+from skimage.metrics import structural_similarity as SSIM
+from skimage.metrics import peak_signal_noise_ratio as PSNR
 def proto_with_quality(output, target, output_proto, target_proto, criterion, acc_proto, images, num_cluster):
     # InfoNCE loss
     loss = criterion(output, target)
@@ -18,8 +21,30 @@ def proto_with_quality(output, target, output_proto, target_proto, criterion, ac
         loss += loss_proto
 
     # Quality loss
-    mse = np.mean((images[1]/255.0-images[2]/255.0)**2)
-    psnr = 20 * math.log10(1 / math.sqrt(mse))
+    # skimage metrics need NumPy arrays, not torch tensors
+    im_q = torch.split(images[0], split_size_or_sections=1, dim=0)
+    im_q = [torch.squeeze(im, dim=0).detach().cpu().numpy() for im in im_q]
+    im_k = torch.split(images[1], split_size_or_sections=1, dim=0)
+    im_k = [torch.squeeze(im, dim=0).detach().cpu().numpy() for im in im_k]
+    l_psnr = []
+    l_ssim = []
+    for i in range(min(len(im_q), len(im_k))):
+        k = im_k[i]
+        # pick a random query index != i (assumes batch size >= 2)
+        q_index = random.randint(0, len(im_q) - 2)
+        if q_index >= i:
+            q_index += 1
+        q = im_q[q_index]
+        # float images require an explicit data_range for skimage metrics
+        dr = float(max(k.max(), q.max()) - min(k.min(), q.min()))
+        psnr_temp = PSNR(k, q, data_range=dr)
+        if psnr_temp >= 50:
+            psnr_temp = 0
+        elif psnr_temp <= 30:
+            psnr_temp = 1
+        else:
+            psnr_temp = (50-psnr_temp)/20
+        l_psnr.append(psnr_temp)
+        l_ssim.append(1-SSIM(k, q, channel_axis=0, data_range=dr))
+    loss += np.mean(l_psnr)+np.mean(l_ssim)
 
 
     return loss