You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

74 lines
2.3 KiB

import faiss
import torch
import torch.nn as nn
import numpy as np
def _estimate_density(distances, assignments, k, temperature):
    """Estimate the per-cluster concentration (phi) of a k-means assignment.

    Args:
        distances: 1-D array holding, for each sample, the distance value
            returned by the index search for its assigned centroid. The
            square root is taken here (faiss L2 indexes report squared
            distances).
        assignments: 1-D integer array with the cluster index of each sample.
        k: number of clusters.
        temperature: value the mean of the returned densities is scaled to.

    Returns:
        A float64 numpy array of length ``k``.
    """
    # Guard against marginally negative values (floating-point round-off),
    # which would otherwise turn into NaN under the square root.
    sample_dist = np.sqrt(
        np.maximum(np.asarray(distances, dtype=np.float64), 0.0))

    # Per-cluster sample count and summed sample-to-centroid distance.
    counts = np.bincount(assignments, minlength=k)
    dist_sum = np.bincount(assignments, weights=sample_dist, minlength=k)

    density = np.zeros(k)
    populated = counts > 1
    density[populated] = (
        dist_sum[populated] / counts[populated]
    ) / np.log(counts[populated] + 10)

    # If a cluster has at most one point, use the max to estimate its
    # concentration.
    density[~populated] = density.max()

    # Clamp extreme values for stability.
    low, high = np.percentile(density, [10, 90])
    density = density.clip(low, high)

    # Scale so that the mean equals the temperature.
    return temperature * density / density.mean()


def run_kmeans(x, args):
    """Cluster ``x`` with faiss k-means once per entry of ``args.num_cluster``.

    Args:
        x: data to be clustered; a 2-D array whose second axis is the
            feature dimension (``x.shape[1]`` is read).
        args: object providing
            ``num_cluster`` - iterable of cluster counts, one k-means run
            is performed for each entry (its position is used as the seed);
            ``temperature`` - target mean of the per-cluster density.

    Returns:
        A dict with the keys ``'im2cluster'``, ``'centroids'`` and
        ``'density'``. Each value is a list with one CUDA tensor per entry
        of ``args.num_cluster``: the sample-to-cluster assignments
        (LongTensor), the L2-normalized centroids, and the per-cluster
        density.
    """
    print('-> Performing kmeans clustering')
    results = {'im2cluster': [], 'centroids': [], 'density': []}
    dim = x.shape[1]

    for seed, num_cluster in enumerate(args.num_cluster):
        print("\tnum_cluster:" + str(num_cluster) + "...", end="")
        k = int(num_cluster)

        # Initialize faiss clustering parameters.
        # faiss.Kmeans.train() reads its settings from the
        # ClusteringParameters object stored in `clus.cp`; attributes
        # assigned directly on the wrapper are silently ignored, so the
        # settings have to be written to `cp`.
        clus = faiss.Kmeans(dim, k, gpu=True)
        params = clus.cp
        params.verbose = True
        params.niter = 20
        params.nredo = 5
        params.seed = seed
        params.max_points_per_centroid = 100
        params.min_points_per_centroid = 10
        clus.train(x)

        # For each sample, find cluster distance and assignment.
        D, I = clus.index.search(x, 1)
        assignments = I[:, 0]

        # Concentration estimation (phi).
        density = _estimate_density(D[:, 0], assignments, k, args.temperature)

        # Convert to cuda Tensors for broadcast.
        centroids = torch.Tensor(clus.centroids.reshape(k, dim)).cuda()
        centroids = nn.functional.normalize(centroids, p=2, dim=1)
        im2cluster = torch.LongTensor(assignments.tolist()).cuda()
        density = torch.Tensor(density).cuda()

        results['centroids'].append(centroids)
        results['density'].append(density)
        results['im2cluster'].append(im2cluster)
        print("ok")

    return results