2023-05-17 09:21:06 +08:00
|
|
|
import torch
|
|
|
|
import numpy as np
|
|
|
|
import math
|
2023-05-17 11:02:03 +08:00
|
|
|
import random
|
|
|
|
from skimage.metrics import structural_similarity as SSIM
|
|
|
|
from skimage.metrics import peak_signal_noise_ratio as PSNR
|
2023-05-17 09:21:06 +08:00
|
|
|
def proto_with_quality(output, target, output_proto, target_proto, criterion, acc_proto, images, num_cluster):
|
|
|
|
# InfoNCE loss
|
|
|
|
loss = criterion(output, target)
|
|
|
|
|
|
|
|
# ProtoNCE loss
|
|
|
|
if output_proto is not None:
|
|
|
|
loss_proto = 0
|
|
|
|
for proto_out, proto_target in zip(output_proto, target_proto):
|
|
|
|
loss_proto += criterion(proto_out, proto_target)
|
|
|
|
accp = accuracy(proto_out, proto_target)[0]
|
|
|
|
acc_proto.update(accp[0], images[0].size(0))
|
|
|
|
|
|
|
|
# average loss across all sets of prototypes
|
|
|
|
loss_proto /= len(num_cluster)
|
|
|
|
loss += loss_proto
|
|
|
|
|
|
|
|
# Quality loss
|
2023-05-17 11:02:03 +08:00
|
|
|
im_q = torch.split(images[0], split_size_or_sections=1, dim=0)
|
|
|
|
im_q = [torch.squeeze(im, dim=0) for im in im_q]
|
|
|
|
im_k = torch.split(images[1], split_size_or_sections=1, dim=0)
|
|
|
|
im_k = [torch.squeeze(im, dim=0) for im in im_k]
|
|
|
|
|
|
|
|
l_psnr = []
|
|
|
|
l_ssim = []
|
|
|
|
for i in range(min(len(im_q), len(im_k))):
|
|
|
|
k = im_k[i]
|
|
|
|
q_index = random.randint(0,i-2)
|
|
|
|
if q_index >= i:
|
|
|
|
q_index += 1
|
|
|
|
q = im_q[q_index]
|
|
|
|
psnr_temp = PSNR(k,q)
|
|
|
|
if psnr_temp >= 50:
|
|
|
|
psnr_temp = 0
|
|
|
|
elif psnr_temp <= 30:
|
|
|
|
psnr_temp = 1
|
|
|
|
else:
|
|
|
|
psnr_temp = (50-psnr_temp)/20
|
|
|
|
l_psnr.append(psnr_temp)
|
|
|
|
l_ssim.append(1-SSIM(k,q))
|
|
|
|
|
|
|
|
loss += np.mean(l_psnr)+np.mean(l_ssim)
|
2023-05-17 09:21:06 +08:00
|
|
|
|
|
|
|
|
|
|
|
return loss
|
|
|
|
|
|
|
|
|
|
|
|
def accuracy(output, target, topk=(1,)):
|
|
|
|
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
|
|
|
with torch.no_grad():
|
|
|
|
maxk = max(topk)
|
|
|
|
batch_size = target.size(0)
|
|
|
|
|
|
|
|
_, pred = output.topk(maxk, 1, True, True)
|
|
|
|
pred = pred.t()
|
|
|
|
correct = pred.eq(target.view(1, -1).expand_as(pred))
|
|
|
|
|
|
|
|
res = []
|
|
|
|
for k in topk:
|
|
|
|
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
|
|
|
|
res.append(correct_k.mul_(100.0 / batch_size))
|
|
|
|
return res
|