import torch
import torch.nn as nn
from random import sample


class MoCo(nn.Module):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a queue
    https://arxiv.org/abs/1911.05722
    """
    def __init__(self, base_encoder, dim=128, r=16384, m=0.999, T=0.1, mlp=False):
        """
        dim: feature dimension (default: 128)
        r: queue size; number of negative samples/prototypes (default: 16384)
        m: momentum for updating key encoder (default: 0.999)
        T: softmax temperature (default: 0.1)
        mlp: whether to use an MLP projection head
        """
        super(MoCo, self).__init__()

        self.r = r
        self.m = m
        self.T = T

        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)

        if mlp:  # hack: brute-force replacement of the fc layer
            dim_mlp = self.encoder_q.fc.weight.shape[1]
            self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
            self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize key encoder with the query encoder's weights
            param_k.requires_grad = False  # not updated by gradient

        # create the queue
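        # the queue is registered as a buffer: it is saved in the state_dict but
        # never updated by the optimizer; each of its r columns holds one
        # L2-normalized key used as a negative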
self.register_buffer("queue", torch.randn(dim, r))
|
|
self.queue = nn.functional.normalize(self.queue, dim=0)
|
|
|
|
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
|
|
|
|
@torch.no_grad()
|
|
def _momentum_update_key_encoder(self):
|
|
"""
|
|
Momentum update of the key encoder
|
|
"""
|
|
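        # exponential moving average of the query encoder's weights:
        # theta_k <- m * theta_k + (1 - m) * theta_q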
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        # gather keys before updating queue
        keys = concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.r % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
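        # the queue acts as a ring buffer: the batch_size oldest columns are
        # overwritten and the pointer advances modulo the queue size r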
        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.r  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm: shuffling the key batch across
        GPUs ensures that a query and its positive key are encoded with BN statistics
        from different per-GPU subsets, preventing information leakage through BN.
        *** Only supports DistributedDataParallel (DDP) models. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only supports DistributedDataParallel (DDP) models. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]

    def forward(self, im_q, im_k=None, is_eval=False, cluster_result=None, index=None):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images (ignored when is_eval is True)
            is_eval: return momentum (key-encoder) embeddings, used for clustering
            cluster_result: cluster assignments, centroids, and density
            index: indices of the training samples in the current batch
        Output:
            logits, targets, proto_logits, proto_targets
        """

        if is_eval:
            k = self.encoder_k(im_q)
            k = nn.functional.normalize(k, dim=1)
            return k

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            # shuffle for making use of BN
            im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # compute logits
        # Einstein sum is more intuitive
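        # 'nc,nc->n' is a per-sample dot product q_i . k_i (the positive pair);
        # 'nc,ck->nk' multiplies each query with every key column stored in the queue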
        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: Nxr
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # logits: Nx(1+r)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators (the positive key sits at index 0 for every query)
        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

        # dequeue and enqueue
        self._dequeue_and_enqueue(k)

        # prototypical contrast
        if cluster_result is not None:
            proto_labels = []
            proto_logits = []
            for n, (im2cluster, prototypes, density) in enumerate(zip(cluster_result['im2cluster'],
                                                                      cluster_result['centroids'],
                                                                      cluster_result['density'])):
                # get positive prototypes
                pos_proto_id = im2cluster[index]
                pos_prototypes = prototypes[pos_proto_id]

                # sample negative prototypes
                all_proto_id = [i for i in range(im2cluster.max() + 1)]
                neg_proto_id = set(all_proto_id) - set(pos_proto_id.tolist())
                neg_proto_id = sample(list(neg_proto_id), self.r)  # sample r negative prototypes
                neg_prototypes = prototypes[neg_proto_id]

                proto_selected = torch.cat([pos_prototypes, neg_prototypes], dim=0)

                # compute prototypical logits
                logits_proto = torch.mm(q, proto_selected.t())

                # targets for prototype assignment
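                # the i-th query's positive prototype is row i of proto_selected,
                # so the target index for query i is simply i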
                labels_proto = torch.linspace(0, q.size(0) - 1, steps=q.size(0)).long().cuda()

                # scaling temperatures for the selected prototypes
                temp_proto = density[torch.cat([pos_proto_id, torch.LongTensor(neg_proto_id).cuda()], dim=0)]
                logits_proto /= temp_proto

                proto_labels.append(labels_proto)
                proto_logits.append(logits_proto)
            return logits, labels, proto_logits, proto_labels
        else:
            return logits, labels, None, None


# utils
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
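

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes torchvision's resnet18 as the base encoder; the hyperparameters
# below are arbitrary. The full training forward pass requires an initialised
# torch.distributed process group (for concat_all_gather and the batch
# shuffle), so only the single-process is_eval path is exercised here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torchvision.models import resnet18

    model = MoCo(resnet18, dim=128, r=1024, m=0.999, T=0.1, mlp=True)
    images = torch.randn(4, 3, 224, 224)  # a toy batch of 4 RGB images
    with torch.no_grad():
        embeddings = model(images, is_eval=True)  # momentum embeddings: 4x128
    print(embeddings.shape)  # torch.Size([4, 128])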