You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
17 lines
596 B
17 lines
596 B
from tqdm import tqdm
|
|
import torch
|
|
import torch.distributed
|
|
|
|
|
|
def compute_features(eval_loader, model, args):
    """Compute one feature row per dataset item and merge the rows across ranks.

    The model is switched to eval mode and called as
    ``model(images, is_eval=True)`` on every batch yielded by ``eval_loader``.
    Each batch is expected to be an ``(images, index)`` pair; the returned
    features are stored at ``features[index]``.

    Args:
        eval_loader: Iterable of ``(images, index)`` batches that also exposes
            ``.dataset`` (only ``len(eval_loader.dataset)`` is used).
            Presumably a ``DataLoader`` whose dataset returns each sample's
            position alongside the image -- confirm against the caller.
        model: Callable accepting ``(images, is_eval=True)`` and returning a
            tensor assignable to ``features[index]``.
        args: Namespace providing ``low_dim``, the width of a feature row.

    Returns:
        A CPU tensor of shape ``(len(eval_loader.dataset), args.low_dim)``.
        Rows that no process wrote remain zero.

    Note:
        The model is left in eval mode; the caller is responsible for
        switching it back to train mode if needed.
    """
    print('-> Computing features')
    model.eval()

    num_items = len(eval_loader.dataset)
    features = torch.zeros(num_items, args.low_dim).cuda()
    # 1.0 for every row this process filled in; summed across ranks below so
    # that rows produced by several ranks can be averaged instead of added up.
    written = torch.zeros(num_items).cuda()

    with torch.no_grad():
        for images, index in tqdm(eval_loader):
            images = images.cuda(non_blocking=True)
            feat = model(images, is_eval=True)
            features[index] = feat
            written[index] = 1

    # Only synchronise when a process group actually exists, so the function
    # also works in a plain single-process run.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()
        torch.distributed.all_reduce(features, op=torch.distributed.ReduceOp.SUM)
        torch.distributed.all_reduce(written, op=torch.distributed.ReduceOp.SUM)
        # A row written by k > 1 ranks (e.g. if the sampler pads its index
        # list with repeated samples -- verify against the caller's sampler)
        # now holds the sum of k copies; divide to recover a single feature.
        # Rows with 0 or 1 writers are divided by 1 and therefore unchanged.
        features /= written.clamp_(min=1).unsqueeze(1)

    return features.cpu()
|
|
|