Fiber 2023-05-17 11:02:03 +08:00
parent ce7e190d5d
commit 298a928632
6 changed files with 50 additions and 9 deletions

View File

@@ -2,7 +2,6 @@ import builtins
 import math
 import os
 import random
-import shutil
 import time
 import warnings
@@ -159,16 +158,15 @@ def worker(gpu, ngpus_per_node, args):
     eval_dir = os.path.join(args.data, 'train')
     # center-crop augmentation
-    eval_augmentation = aug.moco_eval()
-    pre_train_dataset = pcl.loader.PreImager(pre_train_dir, eval_augmentation)
+    pre_train_dataset = pcl.loader.PreImager(pre_train_dir, aug.moco_eval())
     train_dataset = pcl.loader.ImageFolderInstance(
         train_dir,
-        pcl.loader.TwoCropsTransform(eval_augmentation))
+        pcl.loader.TwoCropsTransform(aug.only_rotate()))
     eval_dataset = pcl.loader.ImageFolderInstance(
         eval_dir,
-        eval_augmentation)
+        aug.moco_eval())
     if args.distributed:
         pre_train_sampler = torch.utils.data.distributed.DistributedSampler(pre_train_dataset)
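The training split now wraps aug.only_rotate() in pcl.loader.TwoCropsTransform, while the pre-training and evaluation splits call aug.moco_eval() directly. TwoCropsTransform itself is not part of this diff; in MoCo/PCL-style loaders it is usually a small wrapper that applies the same base transform twice to yield a query/key pair. A minimal sketch of that assumed behaviour (not the repository's actual implementation):

    class TwoCropsTransform:
        """Apply one base transform twice to the same image, returning a query/key pair."""

        def __init__(self, base_transform):
            self.base_transform = base_transform

        def __call__(self, x):
            q = self.base_transform(x)  # first random view
            k = self.base_transform(x)  # second, independently sampled view
            return [q, k]

With aug.only_rotate() as the base transform, the two views of each training image differ in both crop and rotation.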

View File

@@ -2,7 +2,6 @@ from PIL import ImageFilter
 import random
 import torch.utils.data as tud
 import torchvision.datasets as datasets
-from torchvision.io import image
 class PreImager(tud.Dataset):

requirements.txt (new file)
View File

@@ -0,0 +1,8 @@
+torch
+torchvision
+Pillow
+numpy
+scikit-image
+tqdm
+matplotlib
+scikit-learn
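These dependencies can be installed in one step with pip install -r requirements.txt; scikit-image supplies the PSNR/SSIM metrics used by the new quality loss further down in this commit.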

View File

@@ -1,6 +1,7 @@
 import torchvision.transforms as transforms
 import pcl.loader
 normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                  std=[0.229, 0.224, 0.225])
@@ -32,8 +33,16 @@ def moco_v1():
 def moco_eval():
     return transforms.Compose([
-        transforms.Resize([512, 512]),
+        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
         # transforms.CenterCrop(512),
         transforms.ToTensor(),
         normalize
     ])
+
+def only_rotate():
+    return transforms.Compose([
+        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
+        transforms.RandomRotation(90),
+        transforms.ToTensor(),
+        normalize
+    ])
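Both pipelines now emit 224x224 crops; RandomRotation(90) samples an angle uniformly from [-90, +90] degrees. A quick standalone check of the new pipeline (the image path is only an illustration):

    from PIL import Image
    import torchvision.transforms as transforms

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    only_rotate = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomRotation(90),   # random angle in [-90, +90] degrees
        transforms.ToTensor(),
        normalize,
    ])

    img = Image.open('example.jpg').convert('RGB')   # hypothetical sample image
    print(only_rotate(img).shape)                    # torch.Size([3, 224, 224])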

View File

@@ -1,4 +1,6 @@
 import torch
 def data_process(cluster_result, images, gpu):
     im_q = []
     im_k = []

View File

@@ -1,6 +1,9 @@
 import torch
 import numpy as np
 import math
+import random
+from skimage.metrics import structural_similarity as SSIM
+from skimage.metrics import peak_signal_noise_ratio as PSNR
 def proto_with_quality(output, target, output_proto, target_proto, criterion, acc_proto, images, num_cluster):
     # InfoNCE loss
     loss = criterion(output, target)
@@ -18,8 +21,30 @@ def proto_with_quality(output, target, output_proto, target_proto, criterion, ac
         loss += loss_proto
     # Quality loss
-    mse = np.mean((images[1]/255.0-images[2]/255.0)**2)
-    psnr = 20 * math.log10(1 / math.sqrt(mse))
+    im_q = torch.split(images[0], split_size_or_sections=1, dim=0)
+    im_q = [torch.squeeze(im, dim=0) for im in im_q]
+    im_k = torch.split(images[1], split_size_or_sections=1, dim=0)
+    im_k = [torch.squeeze(im, dim=0) for im in im_k]
+    l_psnr = []
+    l_ssim = []
+    for i in range(min(len(im_q), len(im_k))):
+        k = im_k[i]
+        q_index = random.randint(0, i-2)
+        if q_index >= i:
+            q_index += 1
+        q = im_q[q_index]
+        psnr_temp = PSNR(k, q)
+        if psnr_temp >= 50:
+            psnr_temp = 0
+        elif psnr_temp <= 30:
+            psnr_temp = 1
+        else:
+            psnr_temp = (50 - psnr_temp) / 20
+        l_psnr.append(psnr_temp)
+        l_ssim.append(1 - SSIM(k, q))
+    loss += np.mean(l_psnr) + np.mean(l_ssim)
     return loss
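The new quality term compares each key image against a randomly chosen non-matching query image: PSNR is mapped to a penalty in [0, 1] (0 at or above 50 dB, 1 at or below 30 dB, linear in between) and 1 - SSIM is added on top. Note that random.randint(0, i-2) in the committed loop raises ValueError on the first two iterations (i = 0 and i = 1), and that skimage expects numpy arrays rather than torch tensors. The following is a self-contained sketch of the same idea, with the assumptions called out in comments (tensor-to-numpy conversion, data_range choice, and an index draw that is valid for every i); it is not the committed code:

    import random
    import numpy as np
    from skimage.metrics import structural_similarity as SSIM
    from skimage.metrics import peak_signal_noise_ratio as PSNR

    def quality_penalty(im_q, im_k):
        """PSNR/SSIM penalty over mismatched (key, query) pairs; sketch, not the committed code."""
        # Assumption: im_q / im_k are (N, C, H, W) torch tensors; skimage wants HWC numpy arrays.
        q_list = [im.detach().cpu().numpy().transpose(1, 2, 0) for im in im_q]
        k_list = [im.detach().cpu().numpy().transpose(1, 2, 0) for im in im_k]
        n = min(len(q_list), len(k_list))
        if n < 2:
            return 0.0  # need at least two images to form a non-matching pair
        l_psnr, l_ssim = [], []
        for i in range(n):
            k = k_list[i]
            q_index = random.randrange(n - 1)   # valid for every i, unlike randint(0, i-2)
            if q_index >= i:
                q_index += 1                    # skip the matching index
            q = q_list[q_index]
            # Assumption: images are normalized floats, so data_range is taken from the key image.
            data_range = max(float(k.max() - k.min()), 1e-8)
            psnr = PSNR(k, q, data_range=data_range)
            # 0 penalty at >= 50 dB, 1 at <= 30 dB, linear in between (same thresholds as the diff)
            l_psnr.append(float(np.clip((50.0 - psnr) / 20.0, 0.0, 1.0)))
            l_ssim.append(1.0 - SSIM(k, q, channel_axis=-1, data_range=data_range))
        return float(np.mean(l_psnr) + np.mean(l_ssim))

Inside proto_with_quality the equivalent call would be loss = loss + quality_penalty(images[0], images[1]).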