Changes
This commit is contained in:
parent
ce7e190d5d
commit
298a928632
8
main.py
8
main.py
@ -2,7 +2,6 @@ import builtins
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import shutil
|
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
@ -159,16 +158,15 @@ def worker(gpu, ngpus_per_node, args):
|
|||||||
eval_dir = os.path.join(args.data, 'train')
|
eval_dir = os.path.join(args.data, 'train')
|
||||||
|
|
||||||
# center-crop augmentation
|
# center-crop augmentation
|
||||||
eval_augmentation = aug.moco_eval()
|
|
||||||
|
|
||||||
pre_train_dataset = pcl.loader.PreImager(pre_train_dir, eval_augmentation)
|
pre_train_dataset = pcl.loader.PreImager(pre_train_dir, aug.moco_eval())
|
||||||
|
|
||||||
train_dataset = pcl.loader.ImageFolderInstance(
|
train_dataset = pcl.loader.ImageFolderInstance(
|
||||||
train_dir,
|
train_dir,
|
||||||
pcl.loader.TwoCropsTransform(eval_augmentation))
|
pcl.loader.TwoCropsTransform(aug.only_rotate()))
|
||||||
eval_dataset = pcl.loader.ImageFolderInstance(
|
eval_dataset = pcl.loader.ImageFolderInstance(
|
||||||
eval_dir,
|
eval_dir,
|
||||||
eval_augmentation)
|
aug.moco_eval())
|
||||||
|
|
||||||
if args.distributed:
|
if args.distributed:
|
||||||
pre_train_sampler = torch.utils.data.distributed.DistributedSampler(pre_train_dataset)
|
pre_train_sampler = torch.utils.data.distributed.DistributedSampler(pre_train_dataset)
|
||||||
|
@ -2,7 +2,6 @@ from PIL import ImageFilter
|
|||||||
import random
|
import random
|
||||||
import torch.utils.data as tud
|
import torch.utils.data as tud
|
||||||
import torchvision.datasets as datasets
|
import torchvision.datasets as datasets
|
||||||
from torchvision.io import image
|
|
||||||
|
|
||||||
|
|
||||||
class PreImager(tud.Dataset):
|
class PreImager(tud.Dataset):
|
||||||
|
8
requirements.txt
Normal file
8
requirements.txt
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
torch
|
||||||
|
torchvision
|
||||||
|
Pillow
|
||||||
|
numpy
|
||||||
|
scikit-image
|
||||||
|
tqdm
|
||||||
|
matplotlib
|
||||||
|
scikit-learn
|
@ -1,6 +1,7 @@
|
|||||||
import torchvision.transforms as transforms
|
import torchvision.transforms as transforms
|
||||||
import pcl.loader
|
import pcl.loader
|
||||||
|
|
||||||
|
|
||||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||||
std=[0.229, 0.224, 0.225])
|
std=[0.229, 0.224, 0.225])
|
||||||
|
|
||||||
@ -32,8 +33,16 @@ def moco_v1():
|
|||||||
|
|
||||||
def moco_eval():
|
def moco_eval():
|
||||||
return transforms.Compose([
|
return transforms.Compose([
|
||||||
transforms.Resize([512, 512]),
|
transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
|
||||||
# transforms.CenterCrop(512),
|
# transforms.CenterCrop(512),
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(),
|
||||||
normalize
|
normalize
|
||||||
])
|
])
|
||||||
|
|
||||||
|
def only_rotate():
|
||||||
|
return transforms.Compose([
|
||||||
|
transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
|
||||||
|
transforms.RandomRotation(90),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
normalize
|
||||||
|
])
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def data_process(cluster_result, images, gpu):
|
def data_process(cluster_result, images, gpu):
|
||||||
im_q = []
|
im_q = []
|
||||||
im_k = []
|
im_k = []
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import math
|
import math
|
||||||
|
import random
|
||||||
|
from skimage.metrics import structural_similarity as SSIM
|
||||||
|
from skimage.metrics import peak_signal_noise_ratio as PSNR
|
||||||
def proto_with_quality(output, target, output_proto, target_proto, criterion, acc_proto, images, num_cluster):
|
def proto_with_quality(output, target, output_proto, target_proto, criterion, acc_proto, images, num_cluster):
|
||||||
# InfoNCE loss
|
# InfoNCE loss
|
||||||
loss = criterion(output, target)
|
loss = criterion(output, target)
|
||||||
@ -18,8 +21,30 @@ def proto_with_quality(output, target, output_proto, target_proto, criterion, ac
|
|||||||
loss += loss_proto
|
loss += loss_proto
|
||||||
|
|
||||||
# Quality loss
|
# Quality loss
|
||||||
mse = np.mean((images[1]/255.0-images[2]/255.0)**2)
|
im_q = torch.split(images[0], split_size_or_sections=1, dim=0)
|
||||||
psnr = 20 * math.log10(1 / math.sqrt(mse))
|
im_q = [torch.squeeze(im, dim=0) for im in im_q]
|
||||||
|
im_k = torch.split(images[1], split_size_or_sections=1, dim=0)
|
||||||
|
im_k = [torch.squeeze(im, dim=0) for im in im_k]
|
||||||
|
|
||||||
|
l_psnr = []
|
||||||
|
l_ssim = []
|
||||||
|
for i in range(min(len(im_q), len(im_k))):
|
||||||
|
k = im_k[i]
|
||||||
|
q_index = random.randint(0,i-2)
|
||||||
|
if q_index >= i:
|
||||||
|
q_index += 1
|
||||||
|
q = im_q[q_index]
|
||||||
|
psnr_temp = PSNR(k,q)
|
||||||
|
if psnr_temp >= 50:
|
||||||
|
psnr_temp = 0
|
||||||
|
elif psnr_temp <= 30:
|
||||||
|
psnr_temp = 1
|
||||||
|
else:
|
||||||
|
psnr_temp = (50-psnr_temp)/20
|
||||||
|
l_psnr.append(psnr_temp)
|
||||||
|
l_ssim.append(1-SSIM(k,q))
|
||||||
|
|
||||||
|
loss += np.mean(l_psnr)+np.mean(l_ssim)
|
||||||
|
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
Loading…
x
Reference in New Issue
Block a user