You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
256 lines
8.9 KiB
256 lines
8.9 KiB
from __future__ import print_function
|
|
|
|
import os
|
|
import torch
|
|
from copy import deepcopy
|
|
import numpy as np
|
|
|
|
from torchvision import transforms, datasets
|
|
import torchvision.models as models
|
|
|
|
from sklearn.svm import LinearSVC, SVC
|
|
from sklearn.multiclass import OneVsRestClassifier
|
|
from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score, precision_score, f1_score
|
|
|
|
from matplotlib import pyplot as plt
|
|
|
|
from scripts.parser import parser
|
|
import scripts.augmentation as aug
|
|
import pcl.loader
|
|
|
|
import matplotlib.font_manager
|
|
# Register a font with matplotlib directly from a font file. The path is
# relative, so 'simsun.ttc' is looked up from the current working directory.
# main() later selects 'simsun' as the sans-serif family for its plots.
matplotlib.font_manager.fontManager.addfont('simsun.ttc')
|
|
|
|
def calculate_ap(rec, prec):
    """
    Average precision: area under the interpolated precision-recall curve.

    ``rec`` and ``prec`` are NumPy arrays holding recall and precision at
    successive cut-offs. Both are reshaped to column vectors and padded with
    sentinels (recall 0 and 1, precision 0 and 0). Precision is then made
    non-increasing from left to right, and the area is accumulated over the
    positions where recall changes. The result keeps the (1,) shape of a row
    of the internal column vectors.
    """
    recall_col = rec.reshape(rec.size, 1)
    precision_col = prec.reshape(prec.size, 1)

    zero = np.zeros((1, 1))
    one = np.ones((1, 1))
    mrec = np.vstack((zero, recall_col, one))
    mpre = np.vstack((zero, precision_col, zero))

    # Sweep right-to-left so each entry becomes the best precision that is
    # reachable at that recall level or any higher one.
    for pos in reversed(range(len(mpre) - 1)):
        if mpre[pos + 1] > mpre[pos]:
            mpre[pos] = mpre[pos + 1]

    # Row positions at which the recall value differs from its predecessor.
    change_points = np.where(mrec[1:] != mrec[:-1])[0] + 1

    # Built-in sum starts from 0 and adds the rectangle areas in order.
    return sum(
        (mrec[pos] - mrec[pos - 1]) * mpre[pos] for pos in change_points
    )
|
|
|
|
|
|
def get_precision_recall(targets, preds):
    """
    [P, R, score, ap] = get_precision_recall(targets, preds)

    Input :
        targets : 1-D array-like; number of occurrences of this class in the
                  ith image (anything > 0 counts as a positive)
        preds   : 1-D array-like; score for this image
    Output :
        P, R    : precision and recall after each image, with images visited
                  in order of decreasing score
        score   : score which corresponds to the particular precision and
                  recall
        ap      : average precision, as returned by ``calculate_ap``

    When there are no positive targets recall is undefined; it is reported
    as all zeros (giving an AP of 0) instead of dividing by zero.
    """
    # Accept plain Python sequences as well as ndarrays.
    targets = np.asarray(targets)
    preds = np.asarray(preds)

    # binarize targets
    positives = (targets > 0).astype(np.float64)

    # Rank images from highest to lowest score.
    order = np.argsort(preds)[::-1]
    score = preds[order].astype(np.float64)

    # After binarisation every ranked image is either a hit or a miss.
    tp = positives[order]
    fp = 1.0 - tp

    cum_tp = np.cumsum(tp)
    cum_fp = np.cumsum(fp)
    # cum_tp + cum_fp is 1, 2, ..., n, so this division is always safe.
    P = cum_tp / (cum_tp + cum_fp)

    numinst = positives.sum()
    if numinst > 0:
        R = cum_tp / numinst
    else:
        # No positives at all: avoid 0/0 and the NaN AP it would produce.
        R = np.zeros_like(cum_tp)

    ap = calculate_ap(R, P)
    return P, R, score, ap
|
|
|
|
|
|
def _extract_features(model, loader):
    """Run ``model`` over ``loader``; return (L2-normalised features, labels) as NumPy arrays."""
    feats, labels = [], []
    # Inference only: no_grad avoids building autograd graphs, and moving each
    # batch to the CPU right away keeps GPU memory flat.
    with torch.no_grad():
        for images, target in loader:
            images = images.cuda(non_blocking=True)
            feats.append(model(images).cpu())
            labels.append(target)

    feats = torch.cat(feats, 0).numpy()
    labels = torch.cat(labels, 0).cpu().numpy()

    # Row-wise L2 normalisation; the epsilon guards against zero vectors.
    norms = np.linalg.norm(feats, axis=1)
    return feats / (norms + 1e-5)[:, np.newaxis], labels


def _class_score_matrix(decision):
    """Return decision scores as an (n_samples, n_classes) matrix."""
    decision = np.asarray(decision)
    if decision.ndim == 1:
        # Two-class problems yield one score per sample for the second class;
        # its negation serves as the score of the first class.
        return np.column_stack((-decision, decision))
    return decision


def main():
    """
    Evaluate a pre-trained encoder with a one-vs-rest linear SVM.

    Steps:
      1. Parse arguments and build the 'pre_train' and 'test' datasets found
         under ``args.data``.
      2. Build ``args.arch`` and, if ``args.pretrained`` is a file, load the
         ``module.encoder_q`` weights from it (dropping the fc head).
      3. Extract L2-normalised features for both datasets, once.
      4. For every cost in the comma-separated ``args.cost`` and every run in
         ``args.n_run``, fit the SVM on the training features and score the
         test features. Per-class PR and ROC curves are saved to "PR.png" and
         "ROC.png" (overwritten on every run), and the confusion matrix, mAP,
         weighted precision, mean AUC and weighted F1 are printed.
      5. Print the best mean mAP over all costs.

    Requires a CUDA device.
    """
    args = parser().parse_args()

    eval_augmentation = aug.moco_eval()

    pre_train_dir = os.path.join(args.data, 'pre_train')
    eval_dir = os.path.join(args.data, 'test')

    train_dataset = pcl.loader.TenserImager(pre_train_dir, eval_augmentation)
    val_dataset = pcl.loader.TenserImager(eval_dir, eval_augmentation)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # create model
    print("=> creating model '{}'".format(args.arch))
    model = models.__dict__[args.arch](num_classes=128)

    # load from pre-trained
    if args.pretrained:
        if os.path.isfile(args.pretrained):
            print("=> loading checkpoint '{}'".format(args.pretrained))
            checkpoint = torch.load(args.pretrained, map_location="cpu")
            state_dict = checkpoint['state_dict']
            # rename pre-trained keys
            for k in list(state_dict.keys()):
                if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                    # remove prefix
                    state_dict[k[len("module.encoder_q."):]] = state_dict[k]
                # delete renamed or unused k
                del state_dict[k]
            model.load_state_dict(state_dict, strict=False)
            # Use the backbone output directly as the feature vector.
            model.fc = torch.nn.Identity()
            print("=> loaded pre-trained model '{}'".format(args.pretrained))
        else:
            print("=> no checkpoint found at '{}'".format(args.pretrained))

    model.cuda()
    model.eval()

    print('==> calculate test features')
    test_feats, test_labels = _extract_features(model, val_loader)

    # The model is frozen and the loader is not shuffled, so the training
    # features are identical for every cost and run: compute them once.
    print(len(train_dataset))
    print('==> calculate train features')
    train_feats, train_labels = _extract_features(model, train_loader)

    classes_name = train_dataset.classes
    classes = len(classes_name)

    # rcParams is global, so a single assignment covers both figures.
    plt.rcParams['font.sans-serif'] = ['simsun']

    result = {}
    k_list = ['full']

    for k in k_list:
        cost_list = args.cost.split(',')
        result_k = np.zeros(len(cost_list))
        for i, cost in enumerate(cost_list):
            cost = float(cost)
            avg_map = []
            for run in range(args.n_run):
                print('==> training SVM Classifier')
                clf = OneVsRestClassifier(LinearSVC(
                    C=cost,
                    intercept_scaling=1.0,
                    penalty='l2', loss='squared_hinge', tol=1e-4,
                    dual=True, max_iter=2000, random_state=0))
                clf.fit(train_feats, train_labels)

                # Continuous scores drive the PR/ROC curves; hard labels
                # drive the confusion matrix and the summary metrics.
                scores = _class_score_matrix(clf.decision_function(test_feats))
                predict = clf.predict(test_feats)

                # Start every run from empty figures so curves from earlier
                # runs do not pile up under a legend that no longer matches.
                for fig_id in (1, 2):
                    plt.figure(fig_id)
                    plt.clf()

                list_ap = []
                list_auc = []
                for cl in range(classes):
                    # One-vs-rest ground truth for this class.
                    cl_targets = (test_labels == cl).astype(np.int64)
                    # Scores come from the classifier fitted on the TRAINING
                    # features; nothing is ever fitted on the test set.
                    # NOTE(review): assumes the training labels cover
                    # 0..classes-1, so column ``cl`` belongs to class ``cl``.
                    cl_scores = scores[:, cl]

                    P, R, score, ap = get_precision_recall(cl_targets, cl_scores)
                    fpr, tpr, thres = roc_curve(cl_targets, cl_scores)
                    auc = roc_auc_score(cl_targets, cl_scores)
                    list_ap.append(ap)
                    list_auc.append(auc)

                    plt.figure(1)
                    plt.plot(R, P)
                    plt.figure(2)
                    plt.plot(fpr, tpr)

                plt.figure(1)
                plt.xlabel('召回率', fontsize=14)
                plt.ylabel('精准率', fontsize=14)
                plt.legend(classes_name)
                plt.savefig("PR.png")

                plt.figure(2)
                plt.xlabel('假正率', fontsize=14)
                plt.ylabel('真正率', fontsize=14)
                plt.legend(classes_name)
                plt.savefig("ROC.png")

                print(classes_name)
                confusion = confusion_matrix(test_labels, predict)
                print(confusion)
                mean_ap = np.mean(list_ap) * 100
                print('==> Run%d\nmAP is %f ' % (run, mean_ap))
                print("AP: " + str(precision_score(test_labels, predict, average="weighted")))
                print("mAUC: "+ str(np.mean(list_auc)))
                print("F1_score: " + str(f1_score(test_labels, predict, average="weighted")))
                avg_map.append(mean_ap)

            avg_map = np.asarray(avg_map)
            result_k[i] = avg_map.mean()
        result[k] = result_k.max()
    print(result)
|
|
|
|
|
|
# Run the evaluation only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()
|
|
|
|
|