Reproduce PCL; restructure the code; change the data input scheme; remove data augmentation; validate contrastive learning as an image-quality assessment signal
This commit is contained in:
parent d1b868bbd4
commit 57017a9e0e
365  main.py  Normal file
@@ -0,0 +1,365 @@
import builtins
import math
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

from scripts.parser import parser
from scripts.meter import AverageMeter, ProgressMeter
import scripts.augmentation as aug
import scripts.momentum as momentum
import scripts.clustering as clustering
import pcl.builder
import pcl.loader

def main_loader():
    args = parser().parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    args.num_cluster = args.num_cluster.split(',')

    if not os.path.exists(args.exp_dir):
        os.mkdir(args.exp_dir)

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)


def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:
        def print_pass(*print_args, **kwargs):
            pass
        builtins.print = print_pass

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # create model
    print("=> creating model '{}'".format(args.arch))
    model = pcl.builder.MoCo(
        models.__dict__[args.arch],
        args.low_dim, args.pcl_r, args.moco_m, args.temperature, args.mlp)
    # print(model)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model.cuda(args.gpu)
        # comment out the following line for debugging
        raise NotImplementedError("Only DistributedDataParallel is supported.")
    else:
        # AllGather implementation (batch shuffle, queue update, etc.) in
        # this code only supports DistributedDataParallel.
        raise NotImplementedError("Only DistributedDataParallel is supported.")

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # note: benchmark autotunes convolution algorithms, which can undermine the
    # determinism requested above; both flags are enabled here as committed
    cudnn.benchmark = True
    cudnn.deterministic = True

    # Data loading code
    pre_train_dir = os.path.join(args.data, 'train')
    train_dir = os.path.join(args.data, 'train')

    if args.aug_plus:
        # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709
        augmentation = aug.moco_v2()
    else:
        # MoCo v1's aug: same as InstDisc https://arxiv.org/abs/1805.01978
        augmentation = aug.moco_v1()

    # center-crop augmentation
    eval_augmentation = aug.moco_eval()

    # note: all three datasets below use eval_augmentation; the moco_v1/v2
    # pipelines above stay unused because this commit disables data augmentation
    # to test contrastive learning as an image-quality signal
    pre_train_dataset = pcl.loader.PreImager(pre_train_dir, eval_augmentation)

    train_dataset = pcl.loader.ImageFolderInstance(train_dir, eval_augmentation)
    eval_dataset = pcl.loader.ImageFolderInstance(
        train_dir,
        eval_augmentation)

    if args.distributed:
        pre_train_sampler = torch.utils.data.distributed.DistributedSampler(pre_train_dataset)
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset, shuffle=False)
    else:
        pre_train_sampler = None
        train_sampler = None
        eval_sampler = None

    if args.batch_size // pre_train_dataset.class_number < 2:
        raise ValueError("Batch size must be at least twice the number of classes.")

    pre_train_loader = torch.utils.data.DataLoader(
        pre_train_dataset,
        batch_size=args.batch_size // pre_train_dataset.class_number,
        shuffle=(pre_train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=pre_train_sampler,
        drop_last=True)
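    # Per-class batch arithmetic (illustrative figures, not from the commit):
    # each PreImager item holds one image per class, so with batch_size=256 and
    # 8 classes the loader batches 256 // 8 = 32 items, i.e. 32 images from
    # every class and 256 images overall -- the even class coverage that
    # train() below interleaves into query/key pairs.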

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)

    # dataloader for center-cropped images, use larger batch size to increase speed
    eval_loader = torch.utils.data.DataLoader(
        eval_dataset, batch_size=args.batch_size * 5, shuffle=False,
        sampler=eval_sampler, num_workers=args.workers, pin_memory=True)

print("=> Pre-train")
|
||||
# main loop
|
||||
for epoch in range(args.start_epoch, args.epochs):
|
||||
|
||||
cluster_result = None
|
||||
if epoch >= args.warmup_epoch:
|
||||
# compute momentum features for center-cropped images
|
||||
features = momentum.compute_features(eval_loader, model, args)
|
||||
|
||||
# placeholder for clustering result
|
||||
cluster_result = {'im2cluster': [], 'centroids': [], 'density': []}
|
||||
for num_cluster in args.num_cluster:
|
||||
cluster_result['im2cluster'].append(torch.zeros(len(eval_dataset), dtype=torch.long).cuda())
|
||||
cluster_result['centroids'].append(torch.zeros(int(num_cluster), args.low_dim).cuda())
|
||||
cluster_result['density'].append(torch.zeros(int(num_cluster)).cuda())
|
||||
|
||||
if args.gpu == 0:
|
||||
features[
|
||||
torch.norm(features, dim=1) > 1.5] /= 2 # account for the few samples that are computed twice
|
||||
features = features.numpy()
|
||||
cluster_result = clustering.run_kmeans(features, args) # run kmeans clustering on master node
|
||||
# save the clustering result
|
||||
# torch.save(cluster_result,os.path.join(args.exp_dir, 'clusters_%d'%epoch))
|
||||
|
||||
dist.barrier()
|
||||
# broadcast clustering result
|
||||
for _k, data_list in cluster_result.items():
|
||||
for data_tensor in data_list:
|
||||
dist.broadcast(data_tensor, 0, async_op=False)
|
||||
|
||||
if args.distributed:
|
||||
train_sampler.set_epoch(epoch)
|
||||
adjust_learning_rate(optimizer, epoch, args)
|
||||
|
||||
# train for one epoch
|
||||
if epoch >= args.warmup_epoch:
|
||||
train(train_loader, model, criterion, optimizer, epoch, args, cluster_result)
|
||||
else:
|
||||
train(pre_train_loader, model, criterion, optimizer, epoch, args, cluster_result)
|
||||
|
||||
if (epoch + 1) % 10 == 0 and (not args.multiprocessing_distributed or (args.multiprocessing_distributed
|
||||
and args.rank % ngpus_per_node == 0)):
|
||||
save_checkpoint({
|
||||
'epoch': epoch + 1,
|
||||
'arch': args.arch,
|
||||
'state_dict': model.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
}, is_best=False, filename='{}/checkpoint_{:04d}.pth.tar'.format(args.exp_dir, epoch))
|
||||
|
||||
|
||||
def train(train_loader, model, criterion, optimizer, epoch, args, cluster_result=None):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    acc_inst = AverageMeter('Acc@Inst', ':6.2f')
    acc_proto = AverageMeter('Acc@Proto', ':6.2f')

    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, acc_inst, acc_proto],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, index) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # interleave the per-class image lists into query/key batches:
        # consecutive images of the same class serve as the two views
        # (augmentation disabled); assumes an even count per class
        im_q = []
        im_k = []
        class_number = len(images)
        class_len = len(images[0])
        for _i in range(0, class_len, 2):
            for c in range(class_number):
                im_q.append(images[c][_i])
                im_k.append(images[c][_i + 1])

        im_q = torch.stack(im_q)
        im_k = torch.stack(im_k)

        if args.gpu is not None:
            im_q = im_q.cuda(args.gpu, non_blocking=True)
            im_k = im_k.cuda(args.gpu, non_blocking=True)

        # compute output
        output, target, output_proto, target_proto = model(im_q=im_q, im_k=im_k,
                                                           cluster_result=cluster_result, index=index)

        # InfoNCE loss
        loss = criterion(output, target)

        # ProtoNCE loss
        if output_proto is not None:
            loss_proto = 0
            for proto_out, proto_target in zip(output_proto, target_proto):
                loss_proto += criterion(proto_out, proto_target)
                accp = accuracy(proto_out, proto_target)[0]
                acc_proto.update(accp[0], images[0].size(0))

            # average loss across all sets of prototypes
            loss_proto /= len(args.num_cluster)
            loss += loss_proto

        losses.update(loss.item(), images[0].size(0))
        acc = accuracy(output, target)[0]
        acc_inst.update(acc[0], images[0].size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate based on schedule"""
    lr = args.lr
    if args.cos:  # cosine lr schedule
        lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
    else:  # stepwise lr schedule
        for milestone in args.schedule:
            lr *= 0.1 if epoch >= milestone else 1.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
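# Worked example of the cosine schedule (values follow from the formula above):
# with --cos and --epochs 200, epoch 0 trains at the full lr, epoch 100 at
# 0.5 * lr, and epoch 199 at roughly 0.0000617 * lr, decaying smoothly to ~0.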


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape (not view): rows sliced from the transposed prediction
            # matrix are not contiguous
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    print("Model saved as: " + filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


if __name__ == '__main__':
    main_loader()
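# A minimal single-node launch sketch (illustrative values, not part of the commit):
#   python main.py /path/to/dataset --arch resnet50 --mlp --cos \
#       --dist-url tcp://127.0.0.1:23456 --multiprocessing-distributed \
#       --world-size 1 --rank 0 --pcl-r 1024 --num-cluster 2500,5000,10000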
1  pcl/__init__.py  Normal file
@@ -0,0 +1 @@
217  pcl/builder.py  Normal file
@@ -0,0 +1,217 @@
import torch
import torch.nn as nn
from random import sample

class MoCo(nn.Module):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a queue
    https://arxiv.org/abs/1911.05722
    """
    def __init__(self, base_encoder, dim=128, r=16384, m=0.999, T=0.1, mlp=False):
        """
        dim: feature dimension (default: 128)
        r: queue size; number of negative samples/prototypes (default: 16384)
        m: momentum for updating key encoder (default: 0.999)
        T: softmax temperature
        mlp: whether to use mlp projection
        """
        super(MoCo, self).__init__()

        self.r = r
        self.m = m
        self.T = T

        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)

        if mlp:  # hack: brute-force replacement
            dim_mlp = self.encoder_q.fc.weight.shape[1]
            self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
            self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not updated by gradient

        # create the queue
        self.register_buffer("queue", torch.randn(dim, r))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        # gather keys before updating queue
        keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.r % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.r  # move pointer

        self.queue_ptr[0] = ptr
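
    # Queue arithmetic (illustrative, assuming defaults): with r=16384 and a
    # global key batch of 256, each step overwrites one 256-column slice and
    # the pointer wraps every 16384 / 256 = 64 steps, so the queue always holds
    # the most recent 16384 keys. The assert above enforces that r divides
    # evenly by the batch size.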

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def forward(self, im_q, im_k=None, is_eval=False, cluster_result=None, index=None):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images
            is_eval: return momentum embeddings (used for clustering)
            cluster_result: cluster assignments, centroids, and density
            index: indices for training samples
        Output:
            logits, targets, proto_logits, proto_targets
        """

        if is_eval:
            k = self.encoder_k(im_q)
            k = nn.functional.normalize(k, dim=1)
            return k

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            # shuffle for making use of BN
            im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: Nxr
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # logits: Nx(1+r)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators (column 0)
        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

        # dequeue and enqueue
        self._dequeue_and_enqueue(k)

        # prototypical contrast
        if cluster_result is not None:
            proto_labels = []
            proto_logits = []
            for im2cluster, prototypes, density in zip(cluster_result['im2cluster'],
                                                       cluster_result['centroids'],
                                                       cluster_result['density']):
                # get positive prototypes
                pos_proto_id = im2cluster[index]
                pos_prototypes = prototypes[pos_proto_id]

                # sample negative prototypes
                all_proto_id = [i for i in range(im2cluster.max() + 1)]
                neg_proto_id = set(all_proto_id) - set(pos_proto_id.tolist())
                # random.sample no longer accepts sets, so materialize a list
                neg_proto_id = sample(list(neg_proto_id), self.r)  # sample r negative prototypes
                neg_prototypes = prototypes[neg_proto_id]

                proto_selected = torch.cat([pos_prototypes, neg_prototypes], dim=0)

                # compute prototypical logits
                logits_proto = torch.mm(q, proto_selected.t())

                # targets: row i's positive prototype sits in column i
                labels_proto = torch.linspace(0, q.size(0) - 1, steps=q.size(0)).long().cuda()

                # scaling temperatures for the selected prototypes
                temp_proto = density[torch.cat([pos_proto_id, torch.LongTensor(neg_proto_id).cuda()], dim=0)]
                logits_proto /= temp_proto

                proto_labels.append(labels_proto)
                proto_logits.append(logits_proto)
            return logits, labels, proto_logits, proto_labels
        else:
            return logits, labels, None, None


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
68  pcl/loader.py  Normal file
@@ -0,0 +1,68 @@
from PIL import ImageFilter
import random
import torch.utils.data as tud
import torchvision.datasets as datasets
from torchvision.datasets.folder import default_loader


class PreImager(tud.Dataset):
    def __init__(self, samples_dir, aug):
        data_meta = datasets.ImageFolder(samples_dir)
        images = data_meta.imgs
        self.classes = data_meta.classes
        self.class_number = len(self.classes)
        self.class_to_index = data_meta.class_to_idx
        img_class = [[] for _ in range(self.class_number)]

        for img in images:
            img_class[img[1]].append(img[0])
        lens = [len(c) for c in img_class]
        # truncate every class to the shortest class so indexing stays aligned
        self.length = min(lens)
        self.samples_dir = samples_dir

        self.aug = aug
        self.images = img_class

    def __getitem__(self, index):
        imgs = []
        for i in range(self.class_number):
            # load as PIL so the torchvision transforms (ToTensor etc.) apply cleanly
            img = default_loader(self.images[i][index])
            out = self.aug(img)
            imgs.append(out)
        return imgs, index

    def __len__(self):
        return self.length
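
# Item layout sketch: PreImager.__getitem__ returns one transformed image per
# class, so the default collate turns a batch of size B into a list of
# class_number tensors, each shaped [B, C, H, W]. train() in main.py walks this
# list to pair two same-class images as query and key.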


class TwoCropsTransform:
    """Take two random crops of one image as the query and key."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        q = self.base_transform(x)
        k = self.base_transform(x)
        return [q, k]


class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""

    def __init__(self, sigma=[.1, 2.]):
        self.sigma = sigma

    def __call__(self, x):
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
        return x


class ImageFolderInstance(datasets.ImageFolder):
    def __getitem__(self, index):
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, index
39  scripts/augmentation.py  Normal file
@@ -0,0 +1,39 @@
import torchvision.transforms as transforms
import pcl.loader

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])


def moco_v2():
    # MoCo v2's augmentation: similar to SimCLR
    return transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([pcl.loader.GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])


def moco_v1():
    # MoCo v1's augmentation: same as InstDisc
    return transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomGrayscale(p=0.2),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])


def moco_eval():
    # deterministic resize-only pipeline used throughout this commit
    return transforms.Compose([
        transforms.Resize([512, 512]),
        # transforms.CenterCrop(512),
        transforms.ToTensor(),
        normalize
    ])
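# Usage sketch (hypothetical here, mirroring upstream PCL; this commit itself
# only calls moco_eval for training and evaluation):
#   two_crop = pcl.loader.TwoCropsTransform(moco_v2())
#   q, k = two_crop(pil_image)  # two independently augmented views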
74  scripts/clustering.py  Normal file
@@ -0,0 +1,74 @@
import faiss
import torch
import torch.nn as nn
import numpy as np


def run_kmeans(x, args):
    """
    Args:
        x: data to be clustered (numpy array of features)
        args: parsed arguments; uses num_cluster and temperature
    """

    print('-> Performing kmeans clustering')
    results = {'im2cluster': [], 'centroids': [], 'density': []}

    for seed, num_cluster in enumerate(args.num_cluster):
        # initialize faiss clustering parameters
        print("\tnum_cluster:" + str(num_cluster) + "...", end="")
        d = x.shape[1]
        k = int(num_cluster)

        # pass the clustering parameters as keyword arguments; faiss.Kmeans
        # copies them into its ClusteringParameters at construction time, so
        # attributes assigned afterwards would be ignored
        clus = faiss.Kmeans(d, k, gpu=True, verbose=True, niter=20, nredo=5,
                            seed=seed, max_points_per_centroid=100,
                            min_points_per_centroid=10)

        clus.train(x)

        D, I = clus.index.search(x, 1)  # for each sample, find cluster distance and assignment
        im2cluster = [int(n[0]) for n in I]

        # get cluster centroids
        centroids = clus.centroids.reshape(k, d)

        # sample-to-centroid distances for each cluster
        Dcluster = [[] for _ in range(k)]
        for im, i in enumerate(im2cluster):
            Dcluster[i].append(D[im][0])

        # concentration estimation (phi)
        density = np.zeros(k)
        for i, dist in enumerate(Dcluster):
            if len(dist) > 1:
                d = (np.asarray(dist) ** 0.5).mean() / np.log(len(dist) + 10)
                density[i] = d

        # if cluster only has one point, use the max to estimate its concentration
        dmax = density.max(axis=None)
        for i, dist in enumerate(Dcluster):
            if len(dist) <= 1:
                density[i] = dmax

        density = density.clip(np.percentile(density, 10),
                               np.percentile(density, 90))  # clamp extreme values for stability
        density = args.temperature * density / density.mean()  # scale the mean to temperature
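
        # The loop above implements the concentration estimate from the PCL
        # paper, phi = sum(||z - c||) / (Z * log(Z + alpha)), realized here as
        # mean(sqrt(D)) / log(n + 10): D are squared L2 distances from faiss
        # and n is the cluster size, so a tight, well-populated cluster gets a
        # small phi and therefore sharper prototype logits.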

        # convert to cuda Tensors for broadcast
        centroids = torch.Tensor(centroids).cuda()
        centroids = nn.functional.normalize(centroids, p=2, dim=1)

        im2cluster = torch.LongTensor(im2cluster).cuda()
        density = torch.Tensor(density).cuda()

        results['centroids'].append(centroids)
        results['density'].append(density)
        results['im2cluster'].append(im2cluster)
        print("ok")

    return results
40  scripts/meter.py  Normal file
@@ -0,0 +1,40 @@
|
||||
class AverageMeter(object):
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self, name, fmt=':f'):
|
||||
self.name = name
|
||||
self.fmt = fmt
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
def __str__(self):
|
||||
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
|
||||
return fmtstr.format(**self.__dict__)
|
||||
|
||||
|
||||
class ProgressMeter(object):
|
||||
def __init__(self, num_batches, meters, prefix=""):
|
||||
self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
|
||||
self.meters = meters
|
||||
self.prefix = prefix
|
||||
|
||||
def display(self, batch):
|
||||
entries = [self.prefix + self.batch_fmtstr.format(batch)]
|
||||
entries += [str(meter) for meter in self.meters]
|
||||
print('\t'.join(entries))
|
||||
|
||||
def _get_batch_fmtstr(self, num_batches):
|
||||
num_digits = len(str(num_batches // 1))
|
||||
fmt = '{:' + str(num_digits) + 'd}'
|
||||
return '[' + fmt + '/' + fmt.format(num_batches) + ']'
|
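# Sample line produced by ProgressMeter.display (illustrative numbers):
#   Epoch: [5][ 40/391]	Time  0.213 ( 0.224)	Data  0.001 ( 0.012)	Loss 5.4321e+00 (5.6789e+00)	Acc@Inst  62.50 ( 58.43)	Acc@Proto   0.00 (  0.00)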
17  scripts/momentum.py  Normal file
@@ -0,0 +1,17 @@
from tqdm import tqdm
import torch
import torch.distributed


def compute_features(eval_loader, model, args):
    print('-> Computing features')
    model.eval()
    features = torch.zeros(len(eval_loader.dataset), args.low_dim).cuda()
    for images, index in tqdm(eval_loader):
        with torch.no_grad():
            images = images.cuda(non_blocking=True)
            feat = model(images, is_eval=True)
            features[index] = feat
    torch.distributed.barrier()
    torch.distributed.all_reduce(features, op=torch.distributed.ReduceOp.SUM)
    return features.cpu()
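# Note on the SUM reduction: DistributedSampler pads the dataset so every rank
# gets equal-length shards, so a few samples are embedded on two ranks and
# summed twice here. main.py detects those rows by their ~2.0 norm (normalized
# features have norm 1.0) and halves them before clustering.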
83  scripts/parser.py  Normal file
@@ -0,0 +1,83 @@
import argparse
import torchvision.models as models


def parser():
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__")
                         and callable(models.__dict__[name]))
    _parser = argparse.ArgumentParser(description='PyTorch ImageNet Training PCL')
    _parser.add_argument('data', metavar='DIR',
                         help='path to dataset')
    _parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                         choices=model_names,
                         help='model architecture: ' +
                              ' | '.join(model_names) +
                              ' (default: resnet50)')
    _parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
                         help='number of data loading workers (default: 32)')
    _parser.add_argument('--epochs', default=200, type=int, metavar='N',
                         help='number of total epochs to run')
    _parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                         help='manual epoch number (useful on restarts)')
    _parser.add_argument('-b', '--batch-size', default=256, type=int,
                         metavar='N',
                         help='mini-batch size (default: 256), this is the total '
                              'batch size of all GPUs on the current node when '
                              'using Data Parallel or Distributed Data Parallel')
    _parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
                         metavar='LR', help='initial learning rate', dest='lr')
    _parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int,
                         help='learning rate schedule (when to drop lr by 10x)')
    _parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                         help='momentum of SGD solver')
    _parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                         metavar='W', help='weight decay (default: 1e-4)',
                         dest='weight_decay')
    _parser.add_argument('-p', '--print-freq', default=100, type=int,
                         metavar='N', help='print frequency (default: 100)')
    _parser.add_argument('--resume', default='', type=str, metavar='PATH',
                         help='path to latest checkpoint (default: none)')
    _parser.add_argument('--world-size', default=-1, type=int,
                         help='number of nodes for distributed training')
    _parser.add_argument('--rank', default=-1, type=int,
                         help='node rank for distributed training')
    _parser.add_argument('--dist-url', default='tcp://127.0.0.1:23456', type=str,
                         help='url used to set up distributed training')
    _parser.add_argument('--dist-backend', default='nccl', type=str,
                         help='distributed backend')
    _parser.add_argument('--seed', default=None, type=int,
                         help='seed for initializing training')
    _parser.add_argument('--gpu', default=None, type=int,
                         help='GPU id to use')
    _parser.add_argument('--multiprocessing-distributed', action='store_true',
                         help='Use multi-processing distributed training to launch '
                              'N processes per node, which has N GPUs. This is the '
                              'fastest way to use PyTorch for either single node or '
                              'multi node data parallel training')

    _parser.add_argument('--low-dim', default=128, type=int,
                         help='feature dimension (default: 128)')
    _parser.add_argument('--pcl-r', default=16384, type=int,
                         help='queue size; number of negative pairs; needs to be smaller than '
                              'num_cluster (default: 16384)')
    _parser.add_argument('--moco-m', default=0.999, type=float,
                         help='moco momentum of updating key encoder (default: 0.999)')
    _parser.add_argument('--temperature', default=0.2, type=float,
                         help='softmax temperature')

    _parser.add_argument('--mlp', action='store_true',
                         help='use mlp head')
    _parser.add_argument('--aug-plus', action='store_true',
                         help='use moco-v2/SimCLR data augmentation')
    _parser.add_argument('--cos', action='store_true',
                         help='use cosine lr schedule')

    _parser.add_argument('--num-cluster', default='25000,50000,100000', type=str,
                         help='number of clusters')
    _parser.add_argument('--warmup-epoch', default=20, type=int,
                         help='number of warm-up epochs to only train with InfoNCE loss')
    _parser.add_argument('--exp-dir', default='experiment_pcl', type=str,
                         help='experiment directory')

    return _parser