Fiber
2 years ago
9 changed files with 904 additions and 0 deletions
@ -0,0 +1,365 @@
import builtins
import math
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

from scripts.parser import parser
from scripts.meter import AverageMeter, ProgressMeter
import scripts.augmentation as aug
import scripts.momentum as momentum
import scripts.clustering as clustering
import pcl.builder
import pcl.loader


def main_loader():
    args = parser().parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if args.dist_url == "env://" and args.world_size == -1:
        args.world_size = int(os.environ["WORLD_SIZE"])

    args.distributed = args.world_size > 1 or args.multiprocessing_distributed

    args.num_cluster = args.num_cluster.split(',')

    if not os.path.exists(args.exp_dir):
        os.mkdir(args.exp_dir)

    ngpus_per_node = torch.cuda.device_count()
    if args.multiprocessing_distributed:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        args.world_size = ngpus_per_node * args.world_size
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
    else:
        # Simply call main_worker function
        main_worker(args.gpu, ngpus_per_node, args)


def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # suppress printing if not master
    if args.multiprocessing_distributed and args.gpu != 0:
        def print_pass(*print_args, **print_kwargs):
            pass
        builtins.print = print_pass

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)

    # create model
    print("=> creating model '{}'".format(args.arch))
    model = pcl.builder.MoCo(
        models.__dict__[args.arch],
        args.low_dim, args.pcl_r, args.moco_m, args.temperature, args.mlp)
    # print(model)

    if args.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model.cuda(args.gpu)
        # comment out the following line for debugging
        raise NotImplementedError("Only DistributedDataParallel is supported.")
    else:
        # AllGather implementation (batch shuffle, queue update, etc.) in
        # this code only supports DistributedDataParallel.
        raise NotImplementedError("Only DistributedDataParallel is supported.")

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = 'cuda:{}'.format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True
    cudnn.deterministic = True

    # Data loading code
    pre_train_dir = os.path.join(args.data, 'train')
    train_dir = os.path.join(args.data, 'train')

    if args.aug_plus:
        # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709
        augmentation = aug.moco_v2()
    else:
        # MoCo v1's aug: same as InstDisc https://arxiv.org/abs/1805.01978
        augmentation = aug.moco_v1()

    # center-crop augmentation
    eval_augmentation = aug.moco_eval()

    pre_train_dataset = pcl.loader.PreImager(pre_train_dir, eval_augmentation)

    train_dataset = pcl.loader.ImageFolderInstance(train_dir, eval_augmentation)
    eval_dataset = pcl.loader.ImageFolderInstance(
        train_dir,
        eval_augmentation)

    if args.distributed:
        pre_train_sampler = torch.utils.data.distributed.DistributedSampler(pre_train_dataset)
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset, shuffle=False)
    else:
        pre_train_sampler = None
        train_sampler = None
        eval_sampler = None

    if args.batch_size // pre_train_dataset.class_number < 2:
        raise NotImplementedError("Batch size must be at least twice the number of classes.")

    pre_train_loader = torch.utils.data.DataLoader(
        pre_train_dataset,
        batch_size=args.batch_size // pre_train_dataset.class_number,
        shuffle=(pre_train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=pre_train_sampler,
        drop_last=True)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)

    # dataloader for center-cropped images, use larger batch size to increase speed
    eval_loader = torch.utils.data.DataLoader(
        eval_dataset, batch_size=args.batch_size * 5, shuffle=False,
        sampler=eval_sampler, num_workers=args.workers, pin_memory=True)

    print("=> Pre-train")
    # main loop
    for epoch in range(args.start_epoch, args.epochs):

        cluster_result = None
        if epoch >= args.warmup_epoch:
            # compute momentum features for center-cropped images
            features = momentum.compute_features(eval_loader, model, args)

            # placeholder for clustering result
            cluster_result = {'im2cluster': [], 'centroids': [], 'density': []}
            for num_cluster in args.num_cluster:
                cluster_result['im2cluster'].append(torch.zeros(len(eval_dataset), dtype=torch.long).cuda())
                cluster_result['centroids'].append(torch.zeros(int(num_cluster), args.low_dim).cuda())
                cluster_result['density'].append(torch.zeros(int(num_cluster)).cuda())

            if args.gpu == 0:
                features[
                    torch.norm(features, dim=1) > 1.5] /= 2  # account for the few samples that are computed twice
                features = features.numpy()
                cluster_result = clustering.run_kmeans(features, args)  # run kmeans clustering on master node
                # save the clustering result
                # torch.save(cluster_result, os.path.join(args.exp_dir, 'clusters_%d' % epoch))

            dist.barrier()
            # broadcast clustering result
            for _k, data_list in cluster_result.items():
                for data_tensor in data_list:
                    dist.broadcast(data_tensor, 0, async_op=False)

        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        if epoch >= args.warmup_epoch:
            train(train_loader, model, criterion, optimizer, epoch, args, cluster_result)
        else:
            train(pre_train_loader, model, criterion, optimizer, epoch, args, cluster_result)

        if (epoch + 1) % 10 == 0 and (not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                                               and args.rank % ngpus_per_node == 0)):
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, is_best=False, filename='{}/checkpoint_{:04d}.pth.tar'.format(args.exp_dir, epoch))


def train(train_loader, model, criterion, optimizer, epoch, args, cluster_result=None):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    acc_inst = AverageMeter('Acc@Inst', ':6.2f')
    acc_proto = AverageMeter('Acc@Proto', ':6.2f')

    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, acc_inst, acc_proto],
        prefix="Epoch: [{}]".format(epoch))

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, index) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        im_q = []
        im_k = []
        class_number = len(images)
        class_len = len(images[0])
        for _i in range(0, class_len, 2):
            for c in range(class_number):
                im_q.append(images[c][_i])
                im_k.append(images[c][_i + 1])

        im_q = torch.stack(im_q)
        im_k = torch.stack(im_k)

        if args.gpu is not None:
            im_q = im_q.cuda(args.gpu, non_blocking=True)
            im_k = im_k.cuda(args.gpu, non_blocking=True)

        # compute output
        output, target, output_proto, target_proto = model(im_q=im_q, im_k=im_k,
                                                           cluster_result=cluster_result, index=index)

        # InfoNCE loss
        loss = criterion(output, target)

        # ProtoNCE loss
        if output_proto is not None:
            loss_proto = 0
            for proto_out, proto_target in zip(output_proto, target_proto):
                loss_proto += criterion(proto_out, proto_target)
                accp = accuracy(proto_out, proto_target)[0]
                acc_proto.update(accp[0], images[0].size(0))

            # average loss across all sets of prototypes
            loss_proto /= len(args.num_cluster)
            loss += loss_proto

        losses.update(loss.item(), images[0].size(0))
        acc = accuracy(output, target)[0]
        acc_inst.update(acc[0], images[0].size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


def adjust_learning_rate(optimizer, epoch, args):
    """Decay the learning rate based on schedule"""
    lr = args.lr
    if args.cos:  # cosine lr schedule
        lr *= 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
    else:  # stepwise lr schedule
        for milestone in args.schedule:
            lr *= 0.1 if epoch >= milestone else 1.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    print("Model saved as: " + filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


if __name__ == '__main__':
    main_loader()
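# Illustrative launch command (a sketch, not part of the training code): the
# script name and dataset path are placeholders, but every flag below comes
# from scripts/parser.py. Single-node, all-GPU training with the MoCo v2
# augmentations, an MLP head, and a cosine schedule would look roughly like:
#
#   python <this_script>.py /path/to/dataset \
#       -a resnet50 --lr 0.03 --batch-size 256 \
#       --mlp --aug-plus --cos \
#       --dist-url 'tcp://localhost:10001' --dist-backend nccl \
#       --multiprocessing-distributed --world-size 1 --rank 0 \
#       --warmup-epoch 20 --num-cluster 25000,50000,100000 \
#       --exp-dir experiment_pcl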
@ -0,0 +1 @@
@ -0,0 +1,217 @@
import torch
import torch.nn as nn
from random import sample


class MoCo(nn.Module):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a queue
    https://arxiv.org/abs/1911.05722
    """
    def __init__(self, base_encoder, dim=128, r=16384, m=0.999, T=0.1, mlp=False):
        """
        dim: feature dimension (default: 128)
        r: queue size; number of negative samples/prototypes (default: 16384)
        m: momentum for updating key encoder (default: 0.999)
        T: softmax temperature
        mlp: whether to use mlp projection
        """
        super(MoCo, self).__init__()

        self.r = r
        self.m = m
        self.T = T

        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)

        if mlp:  # hack: brute-force replacement
            dim_mlp = self.encoder_q.fc.weight.shape[1]
            self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
            self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not updated by gradient

        # create the queue
        self.register_buffer("queue", torch.randn(dim, r))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        # gather keys before updating queue
        keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.r % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.r  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only supports DistributedDataParallel (DDP) models. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only supports DistributedDataParallel (DDP) models. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def forward(self, im_q, im_k=None, is_eval=False, cluster_result=None, index=None):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images
            is_eval: return momentum embeddings (used for clustering)
            cluster_result: cluster assignments, centroids, and density
            index: indices for training samples
        Output:
            logits, targets, proto_logits, proto_targets
        """

        if is_eval:
            k = self.encoder_k(im_q)
            k = nn.functional.normalize(k, dim=1)
            return k

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            # shuffle for making use of BN
            im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: Nxr
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # logits: Nx(1+r)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators
        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

        # dequeue and enqueue
        self._dequeue_and_enqueue(k)

        # prototypical contrast
        if cluster_result is not None:
            proto_labels = []
            proto_logits = []
            for n, (im2cluster, prototypes, density) in enumerate(zip(cluster_result['im2cluster'],
                                                                      cluster_result['centroids'],
                                                                      cluster_result['density'])):
                # get positive prototypes
                pos_proto_id = im2cluster[index]
                pos_prototypes = prototypes[pos_proto_id]

                # sample negative prototypes
                all_proto_id = [i for i in range(im2cluster.max() + 1)]
                neg_proto_id = set(all_proto_id) - set(pos_proto_id.tolist())
                # random.sample needs a sequence, not a set, on recent Python versions
                neg_proto_id = sample(list(neg_proto_id), self.r)  # sample r negative prototypes
                neg_prototypes = prototypes[neg_proto_id]

                proto_selected = torch.cat([pos_prototypes, neg_prototypes], dim=0)

                # compute prototypical logits
                logits_proto = torch.mm(q, proto_selected.t())

                # targets for prototype assignment
                labels_proto = torch.linspace(0, q.size(0) - 1, steps=q.size(0)).long().cuda()

                # scaling temperatures for the selected prototypes
                temp_proto = density[torch.cat([pos_proto_id, torch.LongTensor(neg_proto_id).cuda()], dim=0)]
                logits_proto /= temp_proto

                proto_labels.append(labels_proto)
                proto_logits.append(logits_proto)
            return logits, labels, proto_logits, proto_labels
        else:
            return logits, labels, None, None


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
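# A minimal sketch (not part of the module) of how MoCo is constructed and how
# the is_eval path extracts momentum embeddings. The resnet18 backbone and the
# random input below are assumptions chosen so the snippet runs on CPU without
# torch.distributed; the full forward pass requires DDP because of
# concat_all_gather.
if __name__ == "__main__":
    import torchvision.models as tv_models

    moco = MoCo(tv_models.resnet18, dim=128, r=16384, m=0.999, T=0.1, mlp=True)
    dummy_batch = torch.randn(4, 3, 224, 224)  # 4 fake RGB images
    with torch.no_grad():
        emb = moco(dummy_batch, is_eval=True)  # momentum-encoder embeddings
    print(emb.shape)  # torch.Size([4, 128]); rows are L2-normalized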
@ -0,0 +1,68 @@
from PIL import ImageFilter
import random
import torch.utils.data as tud
import torchvision.datasets as datasets
from torchvision.io import image


class PreImager(tud.Dataset):
    def __init__(self, samples_dir, aug):
        data_meta = datasets.ImageFolder(samples_dir)
        images = data_meta.imgs
        self.classes = data_meta.classes
        self.class_number = len(self.classes)
        self.class_to_index = data_meta.class_to_idx
        img_class = [[] for _ in range(self.class_number)]

        for img in images:
            img_class[img[1]].append(img[0])
        lens = [len(c) for c in img_class]
        self.length = min(lens)
        self.samples_dir = samples_dir

        self.aug = aug
        self.images = img_class

    def __getitem__(self, index):
        imgs = []
        for i in range(self.class_number):
            img = image.read_image(self.images[i][index]).float()
            out = self.aug(img)
            imgs.append(out)
        return imgs, index

    def __len__(self):
        return self.length


class TwoCropsTransform:
    """Take two random crops of one image as the query and key."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        q = self.base_transform(x)
        k = self.base_transform(x)
        return [q, k]


class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""

    def __init__(self, sigma=[.1, 2.]):
        self.sigma = sigma

    def __call__(self, x):
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
        return x


class ImageFolderInstance(datasets.ImageFolder):
    def __getitem__(self, index):
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, index
@ -0,0 +1,39 @@
import torchvision.transforms as transforms
import pcl.loader


normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])


def moco_v2():
    # MoCo v2's augmentation: similar to SimCLR https://arxiv.org/abs/2002.05709
    return [
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply([pcl.loader.GaussianBlur([.1, 2.])], p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ]


def moco_v1():
    # MoCo v1's augmentation: same as InstDisc https://arxiv.org/abs/1805.01978
    return [
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        transforms.RandomGrayscale(p=0.2),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ]


def moco_eval():
    return transforms.Compose([
        transforms.Resize([512, 512]),
        # transforms.CenterCrop(512),
        transforms.ToTensor(),
        normalize
    ])
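# Sketch of how these helpers can be combined (an assumption mirroring the
# usual MoCo/PCL setup, not what the training script above currently does,
# since it builds its datasets with moco_eval()): moco_v1/moco_v2 return plain
# lists so a caller can wrap them in Compose and, optionally, in
# pcl.loader.TwoCropsTransform to get a (query, key) pair per image.
#
#   two_crop = pcl.loader.TwoCropsTransform(transforms.Compose(moco_v2()))
#   q, k = two_crop(pil_image)  # two independently augmented views of one PIL image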
@ -0,0 +1,74 @@
import faiss
import torch
import torch.nn as nn
import numpy as np


def run_kmeans(x, args):
    """
    Args:
        x: data to be clustered (float32 array, one row per sample)
        args: parsed arguments; uses args.num_cluster and args.temperature
    """

    print('-> Performing kmeans clustering')
    results = {'im2cluster': [], 'centroids': [], 'density': []}

    for seed, num_cluster in enumerate(args.num_cluster):
        # initialize faiss clustering parameters
        print("\tnum_cluster:" + str(num_cluster) + "...", end="")
        d = x.shape[1]
        k = int(num_cluster)

        # pass the clustering parameters to the constructor so they reach
        # faiss's underlying ClusteringParameters
        clus = faiss.Kmeans(d, k, gpu=True, verbose=True, niter=20, nredo=5,
                            seed=seed, max_points_per_centroid=100,
                            min_points_per_centroid=10)

        clus.train(x)

        D, I = clus.index.search(x, 1)  # for each sample, find cluster distance and assignment
        im2cluster = [int(n[0]) for n in I]

        # get cluster centroids
        # print(type(clus.centroids))
        centroids = clus.centroids.reshape(k, d)

        # sample-to-centroid distances for each cluster
        Dcluster = [[] for _ in range(k)]
        for im, i in enumerate(im2cluster):
            Dcluster[i].append(D[im][0])

        # concentration estimation (phi)
        density = np.zeros(k)
        for i, dist in enumerate(Dcluster):
            if len(dist) > 1:
                d = (np.asarray(dist) ** 0.5).mean() / np.log(len(dist) + 10)
                density[i] = d

        # if a cluster only has one point, use the max to estimate its concentration
        dmax = density.max(axis=None)
        for i, dist in enumerate(Dcluster):
            if len(dist) <= 1:
                density[i] = dmax

        density = density.clip(np.percentile(density, 10),
                               np.percentile(density, 90))  # clamp extreme values for stability
        density = args.temperature * density / density.mean()  # scale the mean to temperature

        # convert to cuda Tensors for broadcast
        centroids = torch.Tensor(centroids).cuda()
        centroids = nn.functional.normalize(centroids, p=2, dim=1)

        im2cluster = torch.LongTensor(im2cluster).cuda()
        density = torch.Tensor(density).cuda()

        results['centroids'].append(centroids)
        results['density'].append(density)
        results['im2cluster'].append(im2cluster)
        print("ok")

    return results
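# Minimal usage sketch (illustrative assumption, not how the repo calls it):
# run_kmeans only needs a float32 feature matrix and an object exposing
# num_cluster (a list of strings) and temperature, so a SimpleNamespace can
# stand in for the parsed arguments. Requires a GPU and faiss-gpu, since the
# results are moved to CUDA tensors.
#
#   from types import SimpleNamespace
#   feats = np.random.randn(10000, 128).astype('float32')
#   fake_args = SimpleNamespace(num_cluster=['100', '200'], temperature=0.2)
#   out = run_kmeans(feats, fake_args)
#   out['im2cluster'][0].shape  # torch.Size([10000]): one cluster id per sample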
@ -0,0 +1,40 @@
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
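# Small self-contained example (illustrative only; the loss values are made
# up) of how the meters above are meant to be used inside a training loop.
if __name__ == "__main__":
    loss_meter = AverageMeter('Loss', ':.4e')
    progress = ProgressMeter(num_batches=3, meters=[loss_meter], prefix="Epoch: [0]")
    for step, fake_loss in enumerate([0.9, 0.7, 0.5]):
        loss_meter.update(fake_loss, n=32)  # n = batch size
        progress.display(step)
    # last line printed: "Epoch: [0][2/3]\tLoss 5.0000e-01 (7.0000e-01)"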
@ -0,0 +1,17 @@
from tqdm import tqdm
import torch
import torch.distributed


def compute_features(eval_loader, model, args):
    print('-> Computing features')
    model.eval()
    features = torch.zeros(len(eval_loader.dataset), args.low_dim).cuda()
    for i, (images, index) in enumerate(tqdm(eval_loader)):
        with torch.no_grad():
            images = images.cuda(non_blocking=True)
            feat = model(images, is_eval=True)
            features[index] = feat
    torch.distributed.barrier()
    torch.distributed.all_reduce(features, op=torch.distributed.ReduceOp.SUM)
    return features.cpu()
@ -0,0 +1,83 @@
import argparse
import torchvision.models as models


def parser():
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__")
                         and callable(models.__dict__[name]))
    _parser = argparse.ArgumentParser(description='PyTorch ImageNet Training PCL')
    _parser.add_argument('data', metavar='DIR',
                         help='path to dataset')
    _parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50',
                         choices=model_names,
                         help='model architecture: ' +
                              ' | '.join(model_names) +
                              ' (default: resnet50)')
    _parser.add_argument('-j', '--workers', default=32, type=int, metavar='N',
                         help='number of data loading workers (default: 32)')
    _parser.add_argument('--epochs', default=200, type=int, metavar='N',
                         help='number of total epochs to run')
    _parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                         help='manual epoch number (useful on restarts)')
    _parser.add_argument('-b', '--batch-size', default=256, type=int,
                         metavar='N',
                         help='mini-batch size (default: 256), this is the total '
                              'batch size of all GPUs on the current node when '
                              'using Data Parallel or Distributed Data Parallel')
    _parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
                         metavar='LR', help='initial learning rate', dest='lr')
    _parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int,
                         help='learning rate schedule (when to drop lr by 10x)')
    _parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                         help='momentum of SGD solver')
    _parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                         metavar='W', help='weight decay (default: 1e-4)',
                         dest='weight_decay')
    _parser.add_argument('-p', '--print-freq', default=100, type=int,
                         metavar='N', help='print frequency (default: 100)')
    _parser.add_argument('--resume', default='', type=str, metavar='PATH',
                         help='path to latest checkpoint (default: none)')
    _parser.add_argument('--world-size', default=-1, type=int,
                         help='number of nodes for distributed training')
    _parser.add_argument('--rank', default=-1, type=int,
                         help='node rank for distributed training')
    _parser.add_argument('--dist-url', default='tcp://172.0.0.1:23456', type=str,
                         help='url used to set up distributed training')
    _parser.add_argument('--dist-backend', default='nccl', type=str,
                         help='distributed backend')
    _parser.add_argument('--seed', default=None, type=int,
                         help='seed for initializing training')
    _parser.add_argument('--gpu', default=None, type=int,
                         help='GPU id to use')
    _parser.add_argument('--multiprocessing-distributed', action='store_true',
                         help='Use multi-processing distributed training to launch '
                              'N processes per node, which has N GPUs. This is the '
                              'fastest way to use PyTorch for either single node or '
                              'multi node data parallel training')

    _parser.add_argument('--low-dim', default=128, type=int,
                         help='feature dimension (default: 128)')
    _parser.add_argument('--pcl-r', default=16384, type=int,
                         help='queue size; number of negative pairs; needs to be smaller than num_cluster '
                              '(default: 16384)')
    _parser.add_argument('--moco-m', default=0.999, type=float,
                         help='moco momentum for updating the key encoder (default: 0.999)')
    _parser.add_argument('--temperature', default=0.2, type=float,
                         help='softmax temperature')

    _parser.add_argument('--mlp', action='store_true',
                         help='use mlp head')
    _parser.add_argument('--aug-plus', action='store_true',
                         help='use moco-v2/SimCLR data augmentation')
    _parser.add_argument('--cos', action='store_true',
                         help='use cosine lr schedule')

    _parser.add_argument('--num-cluster', default='25000,50000,100000', type=str,
                         help='comma-separated numbers of clusters')
    _parser.add_argument('--warmup-epoch', default=20, type=int,
                         help='number of warm-up epochs to only train with InfoNCE loss')
    _parser.add_argument('--exp-dir', default='experiment_pcl', type=str,
                         help='experiment directory')

    return _parser
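# Quick check of the parser (illustrative; the dataset path is a placeholder):
# parse_args can also be driven programmatically, which is handy in tests.
if __name__ == "__main__":
    args = parser().parse_args(['/path/to/dataset', '--mlp', '--aug-plus', '--cos'])
    print(args.arch, args.batch_size, args.num_cluster)  # resnet50 256 25000,50000,100000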