You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
74 lines
2.3 KiB
74 lines
2.3 KiB
import faiss
|
|
import torch
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
|
|
def _estimate_density(assignments, sq_dists, k, temperature):
    """Estimate a per-cluster concentration value (phi).

    Args:
        assignments: 1-D integer array, cluster index of every sample.
        sq_dists: 1-D array, distance of every sample to its assigned
            centroid as returned by the faiss index search (the square root
            is taken here before averaging).
        k: number of clusters.
        temperature: value the mean of the returned densities is scaled to.

    Returns:
        1-D numpy array of length ``k`` with one density value per cluster.
    """
    counts = np.bincount(assignments, minlength=k)
    # Clamp at zero so floating-point round-off in the distances cannot
    # produce NaN under the square root.
    root_dists = np.sqrt(np.maximum(sq_dists, 0))
    dist_sums = np.bincount(assignments, weights=root_dists, minlength=k)

    # Mean distance to the centroid, damped by log(cluster size + 10).
    density = np.zeros(k)
    populated = counts > 1
    density[populated] = (
        dist_sums[populated] / counts[populated]
        / np.log(counts[populated] + 10)
    )

    # If a cluster has at most one point, use the max to estimate its
    # concentration.
    density[~populated] = density.max()

    # Clamp extreme values for stability.
    density = density.clip(np.percentile(density, 10),
                           np.percentile(density, 90))
    # Scale the mean to temperature.
    return temperature * density / density.mean()


def run_kmeans(x, args):
    """Cluster ``x`` with faiss k-means once per entry of ``args.num_cluster``.

    Args:
        x: data to be clustered; a 2-D array whose second dimension
            (``x.shape[1]``) is the feature dimension.
        args: object providing ``num_cluster`` (iterable of cluster counts)
            and ``temperature`` (target mean of the density values).

    Returns:
        dict with keys ``'im2cluster'``, ``'centroids'`` and ``'density'``.
        Each value is a list holding one CUDA tensor per entry of
        ``args.num_cluster``: the cluster index of every sample, the
        L2-normalized centroids of shape ``(k, dim)``, and the per-cluster
        density of length ``k``.
    """
    print('-> Performing kmeans clustering')
    results = {'im2cluster': [], 'centroids': [], 'density': []}
    dim = x.shape[1]

    for seed, num_cluster in enumerate(args.num_cluster):
        print("\tnum_cluster:" + str(num_cluster) + "...", end="")
        k = int(num_cluster)

        # Clustering parameters must be given to the constructor: the
        # faiss.Kmeans wrapper forwards keyword arguments to its
        # ClusteringParameters, whereas attributes assigned on the wrapper
        # afterwards are never read by train().
        kmeans = faiss.Kmeans(
            dim,
            k,
            gpu=True,
            verbose=True,
            niter=20,
            nredo=5,
            seed=seed,
            max_points_per_centroid=100,
            min_points_per_centroid=10,
        )
        kmeans.train(x)

        # For each sample, find cluster distance and assignment.
        sq_dists, labels = kmeans.index.search(x, 1)
        assignments = labels.ravel()

        centroids = kmeans.centroids.reshape(k, dim)
        density = _estimate_density(
            assignments, sq_dists.ravel(), k, args.temperature)

        # Convert to cuda Tensors for broadcast.
        centroids = torch.Tensor(centroids).cuda()
        centroids = nn.functional.normalize(centroids, p=2, dim=1)
        im2cluster = torch.as_tensor(assignments, dtype=torch.long).cuda()
        density = torch.Tensor(density).cuda()

        results['centroids'].append(centroids)
        results['density'].append(density)
        results['im2cluster'].append(im2cluster)
        print("ok")

    return results
|
|
|