From 57017a9e0edbc8e4a650991b3428e29b48ca9974 Mon Sep 17 00:00:00 2001
From: Fiber
Date: Sun, 19 Mar 2023 20:05:53 +0800
Subject: [PATCH] Reproduce PCL, refactor the structure, change the data input
 method, remove data augmentation, and verify contrastive learning for image
 quality assessment
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 main.py                 | 365 ++++++++++++++++++++++++++++++++++++++++
 pcl/__init__.py         |   1 +
 pcl/builder.py          | 217 ++++++++++++++++++++++++
 pcl/loader.py           |  68 ++++++++
 scripts/augmentation.py |  39 +++++
 scripts/clustering.py   |  74 ++++++++
 scripts/meter.py        |  40 +++++
 scripts/momentum.py     |  17 ++
 scripts/parser.py       |  83 +++++++++
 9 files changed, 904 insertions(+)
 create mode 100644 main.py
 create mode 100644 pcl/__init__.py
 create mode 100644 pcl/builder.py
 create mode 100644 pcl/loader.py
 create mode 100644 scripts/augmentation.py
 create mode 100644 scripts/clustering.py
 create mode 100644 scripts/meter.py
 create mode 100644 scripts/momentum.py
 create mode 100644 scripts/parser.py

diff --git a/main.py b/main.py
new file mode 100644
index 0000000..f9a0a93
--- /dev/null
+++ b/main.py
@@ -0,0 +1,365 @@
+import builtins
+import math
+import os
+import random
+import shutil
+import time
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.backends.cudnn as cudnn
+import torch.distributed as dist
+import torch.optim
+import torch.multiprocessing as mp
+import torch.utils.data
+import torch.utils.data.distributed
+import torchvision.transforms as transforms
+import torchvision.datasets as datasets
+import torchvision.models as models
+
+from scripts.parser import parser
+from scripts.meter import AverageMeter, ProgressMeter
+import scripts.augmentation as aug
+import scripts.momentum as momentum
+import scripts.clustering as clustering
+import pcl.builder
+import pcl.loader
+
+
+def main_loader():
+    args = parser().parse_args()
+
+    if args.seed is not None:
+        random.seed(args.seed)
+        torch.manual_seed(args.seed)
+        warnings.warn('You have chosen to seed training. '
+                      'This will turn on the CUDNN deterministic setting, '
+                      'which can slow down your training considerably! '
+                      'You may see unexpected behavior when restarting '
+                      'from checkpoints.')
+
+    if args.gpu is not None:
+        warnings.warn('You have chosen a specific GPU. This will completely '
+                      'disable data parallelism.')
+
+    if args.dist_url == "env://" and args.world_size == -1:
+        args.world_size = int(os.environ["WORLD_SIZE"])
+
+    args.distributed = args.world_size > 1 or args.multiprocessing_distributed
+
+    args.num_cluster = args.num_cluster.split(',')
+
+    if not os.path.exists(args.exp_dir):
+        os.mkdir(args.exp_dir)
+
+    ngpus_per_node = torch.cuda.device_count()
+    if args.multiprocessing_distributed:
+        # Since we have ngpus_per_node processes per node, the total world_size
+        # needs to be adjusted accordingly
+        args.world_size = ngpus_per_node * args.world_size
+        # Use torch.multiprocessing.spawn to launch distributed processes: the
+        # main_worker process function
+        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
+    else:
+        # Simply call main_worker function
+        main_worker(args.gpu, ngpus_per_node, args)
+
+
+def main_worker(gpu, ngpus_per_node, args):
+    args.gpu = gpu
+
+    if args.gpu is not None:
+        print("Use GPU: {} for training".format(args.gpu))
+
+    # suppress printing if not master
+    if args.multiprocessing_distributed and args.gpu != 0:
+        def print_pass(*args):
+            pass
+        builtins.print = print_pass
+
+    if args.distributed:
+        if args.dist_url == "env://" and args.rank == -1:
+            args.rank = int(os.environ["RANK"])
+        if args.multiprocessing_distributed:
+            # For multiprocessing distributed training, rank needs to be the
+            # global rank among all the processes
+            args.rank = args.rank * ngpus_per_node + gpu
+        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                world_size=args.world_size, rank=args.rank)
+
+    # create model
+    print("=> creating model '{}'".format(args.arch))
+    model = pcl.builder.MoCo(
+        models.__dict__[args.arch],
+        args.low_dim, args.pcl_r, args.moco_m, args.temperature, args.mlp)
+    # print(model)
+
+    if args.distributed:
+        # For multiprocessing distributed, DistributedDataParallel constructor
+        # should always set the single device scope, otherwise,
+        # DistributedDataParallel will use all available devices.
+        if args.gpu is not None:
+            torch.cuda.set_device(args.gpu)
+            model.cuda(args.gpu)
+            # When using a single GPU per process and per
+            # DistributedDataParallel, we need to divide the batch size
+            # ourselves based on the total number of GPUs we have
+            args.batch_size = int(args.batch_size / ngpus_per_node)
+            args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
+            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        else:
+            model.cuda()
+            # DistributedDataParallel will divide and allocate batch_size to all
+            # available GPUs if device_ids are not set
+            model = torch.nn.parallel.DistributedDataParallel(model)
+    elif args.gpu is not None:
+        torch.cuda.set_device(args.gpu)
+        model.cuda(args.gpu)
+        # comment out the following line for debugging
+        raise NotImplementedError("Only DistributedDataParallel is supported.")
+    else:
+        # AllGather implementation (batch shuffle, queue update, etc.) in
+        # this code only supports DistributedDataParallel.
+        raise NotImplementedError("Only DistributedDataParallel is supported.")
+
+    # define loss function (criterion) and optimizer
+    criterion = nn.CrossEntropyLoss().cuda(args.gpu)
+
+    optimizer = torch.optim.SGD(model.parameters(), args.lr,
+                                momentum=args.momentum,
+                                weight_decay=args.weight_decay)
+
+    # optionally resume from a checkpoint
+    if args.resume:
+        if os.path.isfile(args.resume):
+            print("=> loading checkpoint '{}'".format(args.resume))
+            if args.gpu is None:
+                checkpoint = torch.load(args.resume)
+            else:
+                # Map model to be loaded to specified single gpu.
+                loc = 'cuda:{}'.format(args.gpu)
+                checkpoint = torch.load(args.resume, map_location=loc)
+            args.start_epoch = checkpoint['epoch']
+            model.load_state_dict(checkpoint['state_dict'])
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            print("=> loaded checkpoint '{}' (epoch {})"
+                  .format(args.resume, checkpoint['epoch']))
+        else:
+            print("=> no checkpoint found at '{}'".format(args.resume))
+
+    cudnn.benchmark = True
+    cudnn.deterministic = True
+
+    # Data loading code
+    pre_train_dir = os.path.join(args.data, 'train')
+    train_dir = os.path.join(args.data, 'train')
+
+    if args.aug_plus:
+        # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709
+        augmentation = aug.moco_v2()
+    else:
+        # MoCo v1's aug: same as InstDisc https://arxiv.org/abs/1805.01978
+        augmentation = aug.moco_v1()
+
+    # center-crop augmentation
+    eval_augmentation = aug.moco_eval()
+
+    pre_train_dataset = pcl.loader.PreImager(pre_train_dir, eval_augmentation)
+
+    train_dataset = pcl.loader.ImageFolderInstance(train_dir, eval_augmentation)
+    eval_dataset = pcl.loader.ImageFolderInstance(
+        train_dir,
+        eval_augmentation)
+
+    if args.distributed:
+        pre_train_sampler = torch.utils.data.distributed.DistributedSampler(pre_train_dataset)
+        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
+        eval_sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset, shuffle=False)
+    else:
+        pre_train_sampler = None
+        train_sampler = None
+        eval_sampler = None
+
+    if args.batch_size//pre_train_dataset.class_number < 2:
+        raise ValueError("Batch size must be at least twice the number of classes.")
+
+    pre_train_loader = torch.utils.data.DataLoader(
+        pre_train_dataset,
+        batch_size=args.batch_size//pre_train_dataset.class_number,
+        shuffle=(pre_train_sampler is None),
+        num_workers=args.workers,
+        pin_memory=True,
+        sampler=pre_train_sampler,
+        drop_last=True)
+
+    train_loader = torch.utils.data.DataLoader(
+        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
+        num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
+
+    # dataloader for center-cropped images, use larger batch size to increase speed
+    eval_loader = torch.utils.data.DataLoader(
+        eval_dataset, batch_size=args.batch_size * 5, shuffle=False,
+        sampler=eval_sampler, num_workers=args.workers, pin_memory=True)
+
+    print("=> Pre-train")
+    # main loop
+    for epoch in range(args.start_epoch, args.epochs):
+
+        cluster_result = None
+        if epoch >= args.warmup_epoch:
+            # compute momentum features for center-cropped images
+            features = momentum.compute_features(eval_loader, model, args)
+
+            # placeholder for clustering result
+            cluster_result = {'im2cluster': [], 'centroids': [], 'density': []}
+            for num_cluster in args.num_cluster:
+                cluster_result['im2cluster'].append(torch.zeros(len(eval_dataset), dtype=torch.long).cuda())
+                cluster_result['centroids'].append(torch.zeros(int(num_cluster), args.low_dim).cuda())
cluster_result['density'].append(torch.zeros(int(num_cluster)).cuda()) + + if args.gpu == 0: + features[ + torch.norm(features, dim=1) > 1.5] /= 2 # account for the few samples that are computed twice + features = features.numpy() + cluster_result = clustering.run_kmeans(features, args) # run kmeans clustering on master node + # save the clustering result + # torch.save(cluster_result,os.path.join(args.exp_dir, 'clusters_%d'%epoch)) + + dist.barrier() + # broadcast clustering result + for _k, data_list in cluster_result.items(): + for data_tensor in data_list: + dist.broadcast(data_tensor, 0, async_op=False) + + if args.distributed: + train_sampler.set_epoch(epoch) + adjust_learning_rate(optimizer, epoch, args) + + # train for one epoch + if epoch >= args.warmup_epoch: + train(train_loader, model, criterion, optimizer, epoch, args, cluster_result) + else: + train(pre_train_loader, model, criterion, optimizer, epoch, args, cluster_result) + + if (epoch + 1) % 10 == 0 and (not args.multiprocessing_distributed or (args.multiprocessing_distributed + and args.rank % ngpus_per_node == 0)): + save_checkpoint({ + 'epoch': epoch + 1, + 'arch': args.arch, + 'state_dict': model.state_dict(), + 'optimizer': optimizer.state_dict(), + }, is_best=False, filename='{}/checkpoint_{:04d}.pth.tar'.format(args.exp_dir, epoch)) + + +def train(train_loader, model, criterion, optimizer, epoch, args, cluster_result=None): + batch_time = AverageMeter('Time', ':6.3f') + data_time = AverageMeter('Data', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + acc_inst = AverageMeter('Acc@Inst', ':6.2f') + acc_proto = AverageMeter('Acc@Proto', ':6.2f') + + progress = ProgressMeter( + len(train_loader), + [batch_time, data_time, losses, acc_inst, acc_proto], + prefix="Epoch: [{}]".format(epoch)) + + # switch to train mode + model.train() + + end = time.time() + for i, (images, index) in enumerate(train_loader): + # measure data loading time + data_time.update(time.time() - end) + + im_q = [] + im_k = [] + class_number = len(images) + class_len = len(images[0]) + for _i in range(0, class_len, 2): + for c in range(class_number): + im_q.append(images[c][_i]) + im_k.append(images[c][_i+1]) + + im_q = torch.stack(im_q) + im_k = torch.stack(im_k) + + if args.gpu is not None: + im_q = im_q.cuda(args.gpu, non_blocking=True) + im_k = im_k.cuda(args.gpu, non_blocking=True) + + # compute output + output, target, output_proto, target_proto = model(im_q=im_q, im_k=im_k, + cluster_result=cluster_result, index=index) + + # InfoNCE loss + loss = criterion(output, target) + + # ProtoNCE loss + if output_proto is not None: + loss_proto = 0 + for proto_out, proto_target in zip(output_proto, target_proto): + loss_proto += criterion(proto_out, proto_target) + accp = accuracy(proto_out, proto_target)[0] + acc_proto.update(accp[0], images[0].size(0)) + + # average loss across all sets of prototypes + loss_proto /= len(args.num_cluster) + loss += loss_proto + + losses.update(loss.item(), images[0].size(0)) + acc = accuracy(output, target)[0] + acc_inst.update(acc[0], images[0].size(0)) + + # compute gradient and do SGD step + optimizer.zero_grad() + loss.backward() + optimizer.step() + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + progress.display(i) + + +def adjust_learning_rate(optimizer, epoch, args): + """Decay the learning rate based on schedule""" + lr = args.lr + if args.cos: # cosine lr schedule + lr *= 0.5 * (1. 
+ math.cos(math.pi * epoch / args.epochs)) + else: # stepwise lr schedule + for milestone in args.schedule: + lr *= 0.1 if epoch >= milestone else 1. + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): + torch.save(state, filename) + print("Model saved as:"+filename) + if is_best: + shutil.copyfile(filename, 'model_best.pth.tar') + + +if __name__ == '__main__': + main_loader() diff --git a/pcl/__init__.py b/pcl/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/pcl/__init__.py @@ -0,0 +1 @@ + diff --git a/pcl/builder.py b/pcl/builder.py new file mode 100644 index 0000000..c76590f --- /dev/null +++ b/pcl/builder.py @@ -0,0 +1,217 @@ +import torch +import torch.nn as nn +from random import sample + + +class MoCo(nn.Module): + """ + Build a MoCo model with: a query encoder, a key encoder, and a queue + https://arxiv.org/abs/1911.05722 + """ + def __init__(self, base_encoder, dim=128, r=16384, m=0.999, T=0.1, mlp=False): + """ + dim: feature dimension (default: 128) + r: queue size; number of negative samples/prototypes (default: 16384) + m: momentum for updating key encoder (default: 0.999) + T: softmax temperature + mlp: whether to use mlp projection + """ + super(MoCo, self).__init__() + + self.r = r + self.m = m + self.T = T + + # create the encoders + # num_classes is the output fc dimension + self.encoder_q = base_encoder(num_classes=dim) + self.encoder_k = base_encoder(num_classes=dim) + + if mlp: # hack: brute-force replacement + dim_mlp = self.encoder_q.fc.weight.shape[1] + self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc) + self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc) + + for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): + param_k.data.copy_(param_q.data) # initialize + param_k.requires_grad = False # not update by gradient + + # create the queue + self.register_buffer("queue", torch.randn(dim, r)) + self.queue = nn.functional.normalize(self.queue, dim=0) + + self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) + + @torch.no_grad() + def _momentum_update_key_encoder(self): + """ + Momentum update of the key encoder + """ + for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): + param_k.data = param_k.data * self.m + param_q.data * (1. - self.m) + + @torch.no_grad() + def _dequeue_and_enqueue(self, keys): + # gather keys before updating queue + keys = concat_all_gather(keys) + + batch_size = keys.shape[0] + + ptr = int(self.queue_ptr) + assert self.r % batch_size == 0 # for simplicity + + # replace the keys at ptr (dequeue and enqueue) + self.queue[:, ptr:ptr + batch_size] = keys.T + ptr = (ptr + batch_size) % self.r # move pointer + + self.queue_ptr[0] = ptr + + @torch.no_grad() + def _batch_shuffle_ddp(self, x): + """ + Batch shuffle, for making use of BatchNorm. + *** Only support DistributedDataParallel (DDP) model. 
*** + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = concat_all_gather(x) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # random shuffle index + idx_shuffle = torch.randperm(batch_size_all).cuda() + + # broadcast to all gpus + torch.distributed.broadcast(idx_shuffle, src=0) + + # index for restoring + idx_unshuffle = torch.argsort(idx_shuffle) + + # shuffled index for this gpu + gpu_idx = torch.distributed.get_rank() + idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this], idx_unshuffle + + @torch.no_grad() + def _batch_unshuffle_ddp(self, x, idx_unshuffle): + """ + Undo batch shuffle. + *** Only support DistributedDataParallel (DDP) model. *** + """ + # gather from all gpus + batch_size_this = x.shape[0] + x_gather = concat_all_gather(x) + batch_size_all = x_gather.shape[0] + + num_gpus = batch_size_all // batch_size_this + + # restored index for this gpu + gpu_idx = torch.distributed.get_rank() + idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] + + return x_gather[idx_this] + + def forward(self, im_q, im_k=None, is_eval=False, cluster_result=None, index=None): + """ + Input: + im_q: a batch of query images + im_k: a batch of key images + is_eval: return momentum embeddings (used for clustering) + cluster_result: cluster assignments, centroids, and density + index: indices for training samples + Output: + logits, targets, proto_logits, proto_targets + """ + + if is_eval: + k = self.encoder_k(im_q) + k = nn.functional.normalize(k, dim=1) + return k + + # compute key features + with torch.no_grad(): # no gradient to keys + self._momentum_update_key_encoder() # update the key encoder + + # shuffle for making use of BN + im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k) + + k = self.encoder_k(im_k) # keys: NxC + k = nn.functional.normalize(k, dim=1) + + # undo shuffle + k = self._batch_unshuffle_ddp(k, idx_unshuffle) + + # compute query features + q = self.encoder_q(im_q) # queries: NxC + q = nn.functional.normalize(q, dim=1) + + # compute logits + # Einstein sum is more intuitive + # positive logits: Nx1 + l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1) + # negative logits: Nxr + l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()]) + + # logits: Nx(1+r) + logits = torch.cat([l_pos, l_neg], dim=1) + + # apply temperature + logits /= self.T + + # labels: positive key indicators + labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda() + + # dequeue and enqueue + self._dequeue_and_enqueue(k) + + # prototypical contrast + if cluster_result is not None: + proto_labels = [] + proto_logits = [] + for n, (im2cluster, prototypes, density) in enumerate(zip(cluster_result['im2cluster'], + cluster_result['centroids'], + cluster_result['density'])): + # get positive prototypes + pos_proto_id = im2cluster[index] + pos_prototypes = prototypes[pos_proto_id] + + # sample negative prototypes + all_proto_id = [i for i in range(im2cluster.max()+1)] + neg_proto_id = set(all_proto_id)-set(pos_proto_id.tolist()) + neg_proto_id = sample(neg_proto_id,self.r) # sample r negative prototypes + neg_prototypes = prototypes[neg_proto_id] + + proto_selected = torch.cat([pos_prototypes,neg_prototypes],dim=0) + + # compute prototypical logits + logits_proto = torch.mm(q,proto_selected.t()) + + # targets for prototype assignment + labels_proto = torch.linspace(0, q.size(0)-1, steps=q.size(0)).long().cuda() + + # scaling temperatures for the selected prototypes + temp_proto = 
density[torch.cat([pos_proto_id,torch.LongTensor(neg_proto_id).cuda()],dim=0)] + logits_proto /= temp_proto + + proto_labels.append(labels_proto) + proto_logits.append(logits_proto) + return logits, labels, proto_logits, proto_labels + else: + return logits, labels, None, None + + +# utils +@torch.no_grad() +def concat_all_gather(tensor): + """ + Performs all_gather operation on the provided tensors. + *** Warning ***: torch.distributed.all_gather has no gradient. + """ + tensors_gather = [torch.ones_like(tensor) + for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather(tensors_gather, tensor, async_op=False) + + output = torch.cat(tensors_gather, dim=0) + return output diff --git a/pcl/loader.py b/pcl/loader.py new file mode 100644 index 0000000..b834d67 --- /dev/null +++ b/pcl/loader.py @@ -0,0 +1,68 @@ +from PIL import ImageFilter +import random +import torch.utils.data as tud +import torchvision.datasets as datasets +from torchvision.io import image + + +class PreImager(tud.Dataset): + def __init__(self, samples_dir, aug): + data_meta = datasets.ImageFolder(samples_dir) + images = data_meta.imgs + self.classes = data_meta.classes + self.class_number = len(self.classes) + self.class_to_index = data_meta.class_to_idx + img_class = [[] for i in range(self.class_number)] + + for img in images: + img_class[img[1]].append(img[0]) + lens = [len(c) for c in img_class] + self.length = min(lens) + self.samples_dir = samples_dir + + self.aug = aug + self.images = img_class + + def __getitem__(self, index): + imgs = [] + for i in range(self.class_number): + img = image.read_image(self.images[i][index]).float() + out = self.aug(img) + imgs.append(out) + return imgs, index + + def __len__(self): + return self.length + + +class TwoCropsTransform: + """Take two random crops of one image as the query and key.""" + + def __init__(self, base_transform): + self.base_transform = base_transform + + def __call__(self, x): + q = self.base_transform(x) + k = self.base_transform(x) + return [q, k] + + +class GaussianBlur(object): + """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" + + def __init__(self, sigma=[.1, 2.]): + self.sigma = sigma + + def __call__(self, x): + sigma = random.uniform(self.sigma[0], self.sigma[1]) + x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) + return x + + +class ImageFolderInstance(datasets.ImageFolder): + def __getitem__(self, index): + path, target = self.samples[index] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + return sample, index \ No newline at end of file diff --git a/scripts/augmentation.py b/scripts/augmentation.py new file mode 100644 index 0000000..9dc0974 --- /dev/null +++ b/scripts/augmentation.py @@ -0,0 +1,39 @@ +import torchvision.transforms as transforms +import pcl.loader + +normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + + +def moco_v2(): + return [ + transforms.RandomResizedCrop(224, scale=(0.2, 1.)), + transforms.RandomApply([ + transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened + ], p=0.8), + transforms.RandomGrayscale(p=0.2), + transforms.RandomApply([pcl.loader.GaussianBlur([.1, 2.])], p=0.5), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize + ] + + +def moco_v1(): + return [ + transforms.RandomResizedCrop(224, scale=(0.2, 1.)), + transforms.RandomGrayscale(p=0.2), + transforms.ColorJitter(0.4, 0.4, 0.4, 0.4), + transforms.RandomHorizontalFlip(), + 
transforms.ToTensor(), + normalize + ] + + +def moco_eval(): + return transforms.Compose([ + transforms.Resize([512, 512]), + # transforms.CenterCrop(512), + transforms.ToTensor(), + normalize + ]) diff --git a/scripts/clustering.py b/scripts/clustering.py new file mode 100644 index 0000000..2675695 --- /dev/null +++ b/scripts/clustering.py @@ -0,0 +1,74 @@ +import faiss +import torch +import torch.nn as nn +import numpy as np + + +def run_kmeans(x, args): + """ + Args: + x: data to be clustered + args: + """ + + print('-> Performing kmeans clustering') + results = {'im2cluster': [], 'centroids': [], 'density': []} + + for seed, num_cluster in enumerate(args.num_cluster): + # intialize faiss clustering parameters + print("\tnum_cluster:" + str(num_cluster) + "...", end="") + d = x.shape[1] + k = int(num_cluster) + + clus = faiss.Kmeans(d, k, gpu=True) + clus.verbose = True + clus.niter = 20 + clus.nredo = 5 + clus.seed = seed + clus.max_points_per_centroid = 100 + clus.min_points_per_centroid = 10 + + clus.train(x) + + D, I = clus.index.search(x, 1) # for each sample, find cluster distance and assignments + im2cluster = [int(n[0]) for n in I] + + # get cluster centroids + # print(type(clus.centroids)) + centroids = clus.centroids.reshape(k, d) + + # sample-to-centroid distances for each cluster + Dcluster = [[] for c in range(k)] + for im, i in enumerate(im2cluster): + Dcluster[i].append(D[im][0]) + + # concentration estimation (phi) + density = np.zeros(k) + for i, dist in enumerate(Dcluster): + if len(dist) > 1: + d = (np.asarray(dist) ** 0.5).mean() / np.log(len(dist) + 10) + density[i] = d + + # if cluster only has one point, use the max to estimate its concentration + dmax = density.max(axis=None) + for i, dist in enumerate(Dcluster): + if len(dist) <= 1: + density[i] = dmax + + density = density.clip(np.percentile(density, 10), + np.percentile(density, 90)) # clamp extreme values for stability + density = args.temperature * density / density.mean() # scale the mean to temperature + + # convert to cuda Tensors for broadcast + centroids = torch.Tensor(centroids).cuda() + centroids = nn.functional.normalize(centroids, p=2, dim=1) + + im2cluster = torch.LongTensor(im2cluster).cuda() + density = torch.Tensor(density).cuda() + + results['centroids'].append(centroids) + results['density'].append(density) + results['im2cluster'].append(im2cluster) + print("ok") + + return results diff --git a/scripts/meter.py b/scripts/meter.py new file mode 100644 index 0000000..73ad146 --- /dev/null +++ b/scripts/meter.py @@ -0,0 +1,40 @@ +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=':f'): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' + return fmtstr.format(**self.__dict__) + + +class ProgressMeter(object): + def __init__(self, num_batches, meters, prefix=""): + self.batch_fmtstr = self._get_batch_fmtstr(num_batches) + self.meters = meters + self.prefix = prefix + + def display(self, batch): + entries = [self.prefix + self.batch_fmtstr.format(batch)] + entries += [str(meter) for meter in self.meters] + print('\t'.join(entries)) + + def _get_batch_fmtstr(self, num_batches): + num_digits = len(str(num_batches // 1)) + fmt = '{:' + 
str(num_digits) + 'd}' + return '[' + fmt + '/' + fmt.format(num_batches) + ']' \ No newline at end of file diff --git a/scripts/momentum.py b/scripts/momentum.py new file mode 100644 index 0000000..5a4c91c --- /dev/null +++ b/scripts/momentum.py @@ -0,0 +1,17 @@ +from tqdm import tqdm +import torch +import torch.distributed + + +def compute_features(eval_loader, model, args): + print('-> Computing features') + model.eval() + features = torch.zeros(len(eval_loader.dataset), args.low_dim).cuda() + for i, (images, index) in enumerate(tqdm(eval_loader)): + with torch.no_grad(): + images = images.cuda(non_blocking=True) + feat = model(images, is_eval=True) + features[index] = feat + torch.distributed.barrier() + torch.distributed.all_reduce(features, op=torch.distributed.ReduceOp.SUM) + return features.cpu() diff --git a/scripts/parser.py b/scripts/parser.py new file mode 100644 index 0000000..d65726c --- /dev/null +++ b/scripts/parser.py @@ -0,0 +1,83 @@ +import argparse +import torchvision.models as models + + +def parser(): + model_names = sorted(name for name in models.__dict__ + if name.islower() and not name.startswith("__") + and callable(models.__dict__[name])) + _parser = argparse.ArgumentParser(description='PyTorch ImageNet Training PCL') + _parser.add_argument('data', metavar='DIR', + help='path to dataset') + _parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet50', + choices=model_names, + help='model architecture: ' + + ' | '.join(model_names) + + ' (default: resnet50)') + _parser.add_argument('-j', '--workers', default=32, type=int, metavar='N', + help='number of data loading workers (default: 32)') + _parser.add_argument('--epochs', default=200, type=int, metavar='N', + help='number of total epochs to run') + _parser.add_argument('--start-epoch', default=0, type=int, metavar='N', + help='manual epoch number (useful on restarts)') + _parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', + help='mini-batch size (default: 256), this is the total ' + 'batch size of all GPUs on the current node when ' + 'using Data Parallel or Distributed Data Parallel') + _parser.add_argument('--lr', '--learning-rate', default=0.03, type=float, + metavar='LR', help='initial learning rate', dest='lr') + _parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int, + help='learning rate schedule (when to drop lr by 10x)') + _parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum of SGD solver') + _parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)', + dest='weight_decay') + _parser.add_argument('-p', '--print-freq', default=100, type=int, + metavar='N', help='print frequency (default: 10)') + _parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') + _parser.add_argument('--world-size', default=-1, type=int, + help='number of nodes for distributed training') + _parser.add_argument('--rank', default=-1, type=int, + help='node rank for distributed training') + _parser.add_argument('--dist-url', default='tcp://172.0.0.1:23456', type=str, + help='url used to set up distributed training') + _parser.add_argument('--dist-backend', default='nccl', type=str, + help='distributed backend') + _parser.add_argument('--seed', default=None, type=int, + help='seed for initializing training. 
') + _parser.add_argument('--gpu', default=None, type=int, + help='GPU id to use.') + _parser.add_argument('--multiprocessing-distributed', action='store_true', + help='Use multi-processing distributed training to launch ' + 'N processes per node, which has N GPUs. This is the ' + 'fastest way to use PyTorch for either single node or ' + 'multi node data parallel training') + + _parser.add_argument('--low-dim', default=128, type=int, + help='feature dimension (default: 128)') + _parser.add_argument('--pcl-r', default=16384, type=int, + help='queue size; number of negative pairs; needs to be smaller than num_cluster (default: ' + '16384)') + _parser.add_argument('--moco-m', default=0.999, type=float, + help='moco momentum of updating key encoder (default: 0.999)') + _parser.add_argument('--temperature', default=0.2, type=float, + help='softmax temperature') + + _parser.add_argument('--mlp', action='store_true', + help='use mlp head') + _parser.add_argument('--aug-plus', action='store_true', + help='use moco-v2/SimCLR data augmentation') + _parser.add_argument('--cos', action='store_true', + help='use cosine lr schedule') + + _parser.add_argument('--num-cluster', default='25000,50000,100000', type=str, + help='number of clusters') + _parser.add_argument('--warmup-epoch', default=20, type=int, + help='number of warm-up epochs to only train with InfoNCE loss') + _parser.add_argument('--exp-dir', default='experiment_pcl', type=str, + help='experiment directory') + + return _parser
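-- 
Quick sanity check for scripts/clustering.run_kmeans: a minimal sketch, assuming a CUDA
device and a faiss-gpu build are available; the cluster count, feature shape, and
temperature below are placeholder values, not settings used by the patch itself.

import argparse

import numpy as np

import scripts.clustering as clustering

# Hypothetical stand-in for the parsed command-line arguments; only the two
# fields that run_kmeans reads (num_cluster, temperature) are provided.
args = argparse.Namespace(num_cluster=['10'], temperature=0.2)

# Random features standing in for the momentum-encoder embeddings
# (faiss clusters float32 numpy arrays of shape [num_samples, low_dim]).
features = np.random.rand(1000, 128).astype('float32')

results = clustering.run_kmeans(features, args)
print(results['centroids'][0].shape)   # torch.Size([10, 128])
print(results['im2cluster'][0].shape)  # torch.Size([1000])
print(results['density'][0].shape)     # torch.Size([10])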