from __future__ import print_function
import os
from copy import deepcopy

import numpy as np
import torch
import torchvision.models as models
from torchvision import transforms
from sklearn.svm import LinearSVC
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score, precision_score, f1_score
from matplotlib import pyplot as plt
import matplotlib.font_manager

from scripts.parser import parser
import scripts.augmentation as aug
import pcl.loader

# Register the SimSun font from its font file (used for the Chinese plot labels).
matplotlib.font_manager.fontManager.addfont('simsun.ttc')


def calculate_ap(rec, prec):
    """
    Computes the AP (area under the interpolated precision-recall curve).
    """
    rec, prec = rec.reshape(rec.size, 1), prec.reshape(prec.size, 1)
    z, o = np.zeros((1, 1)), np.ones((1, 1))
    mrec, mpre = np.vstack((z, rec, o)), np.vstack((z, prec, z))
    # Make the precision envelope monotonically non-increasing.
    for i in range(len(mpre) - 2, -1, -1):
        mpre[i] = max(mpre[i], mpre[i + 1])
    # Sum precision over the points where recall changes.
    indices = np.where(mrec[1:] != mrec[0:-1])[0] + 1
    ap = 0
    for i in indices:
        ap = ap + (mrec[i] - mrec[i - 1]) * mpre[i]
    return ap
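

# Worked example for calculate_ap (illustrative only; not executed by this
# script): with rec = np.array([0.5, 1.0]) and prec = np.array([1.0, 0.5]),
# the interpolated precision envelope is [1.0, 0.5], so the result is
# 0.5 * 1.0 + 0.5 * 0.5 = 0.75.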


def get_precision_recall(targets, preds):
    """
    [P, R, score, ap] = get_precision_recall(targets, preds)
    Input:
        targets : number of occurrences of this class in the ith image
        preds : confidence score for the ith image
    Output:
        P, R : precision and recall at each threshold
        score : the score corresponding to each precision/recall point
        ap : average precision
    """
    # Binarize targets: positive if the class occurs at least once.
    targets = np.array(targets > 0, dtype=np.float32)
    tog = np.hstack((
        targets[:, np.newaxis].astype(np.float64),
        preds[:, np.newaxis].astype(np.float64)
    ))
    # Sort by descending score.
    ind = np.argsort(preds)
    ind = ind[::-1]
    score = np.array([tog[i, 1] for i in ind])
    sortcounts = np.array([tog[i, 0] for i in ind])
    # In sorted order, positives count as true positives and negatives as
    # false positives.
    tp = sortcounts
    fp = sortcounts.copy()
    for i in range(sortcounts.shape[0]):
        if sortcounts[i] >= 1:
            fp[i] = 0.
        else:
            fp[i] = 1.
    P = np.cumsum(tp) / (np.cumsum(tp) + np.cumsum(fp))
    numinst = np.sum(targets)
    R = np.cumsum(tp) / numinst
    ap = calculate_ap(R, P)
    return P, R, score, ap
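

# Worked example for get_precision_recall (illustrative only): with
# targets = np.array([1, 0, 1, 0]) and preds = np.array([0.9, 0.8, 0.7, 0.1]),
# sorting by score gives P = [1.0, 0.5, 0.667, 0.5], R = [0.5, 0.5, 1.0, 1.0]
# and ap ≈ 0.833.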


def main():
    args = parser().parse_args()
    # if args.seed is not None:
    #     random.seed(args.seed)
    #     np.random.seed(args.seed)

    # ImageNet normalization statistics.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    normalize = transforms.Normalize(mean=mean, std=std)
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    eval_augmentation = aug.moco_eval()

    pre_train_dir = os.path.join(args.data, 'pre_train')
    eval_dir = os.path.join(args.data, 'test')
    train_dataset = pcl.loader.TenserImager(pre_train_dir, eval_augmentation)
    val_dataset = pcl.loader.TenserImager(eval_dir, eval_augmentation)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # create model
    print("=> creating model '{}'".format(args.arch))
    model = models.__dict__[args.arch](num_classes=128)

    # load from pre-trained
    if args.pretrained:
        if os.path.isfile(args.pretrained):
            print("=> loading checkpoint '{}'".format(args.pretrained))
            checkpoint = torch.load(args.pretrained, map_location="cpu")
            state_dict = checkpoint['state_dict']
            # rename pre-trained keys: keep only the encoder_q backbone weights
            for k in list(state_dict.keys()):
                if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
                    # remove prefix
                    state_dict[k[len("module.encoder_q."):]] = state_dict[k]
                # delete renamed or unused key
                del state_dict[k]
            model.load_state_dict(state_dict, strict=False)
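            # strict=False is required because the projection-head (fc) keys
            # were deliberately dropped from the state dict above.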
            model.fc = torch.nn.Identity()
            print("=> loaded pre-trained model '{}'".format(args.pretrained))
        else:
            print("=> no checkpoint found at '{}'".format(args.pretrained))

    model.cuda()
    model.eval()

    test_feats = []
    test_labels = []
    print('==> calculate test features')
    for idx, (images, target) in enumerate(val_loader):
        images = images.cuda(non_blocking=True)
        feat = model(images)
        feat = feat.detach().cpu()
        test_feats.append(feat)
        test_labels.append(target)
    test_feats = torch.cat(test_feats, 0).numpy()
    test_labels = torch.cat(test_labels, 0).numpy()
    test_feats_norm = np.linalg.norm(test_feats, axis=1)
    test_feats = test_feats / (test_feats_norm + 1e-5)[:, np.newaxis]
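    # The features are L2-normalized to (near) unit length, so the linear SVM
    # below effectively compares embeddings by cosine similarity.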

    result = {}
    k_list = ['full']
    for k in k_list:
        cost_list = args.cost.split(',')
        result_k = np.zeros(len(cost_list))
        for i, cost in enumerate(cost_list):
            cost = float(cost)
            avg_map = []
            for run in range(args.n_run):
                print(len(train_dataset))
                train_loader = torch.utils.data.DataLoader(
                    train_dataset, batch_size=args.batch_size, shuffle=False,
                    num_workers=args.workers, pin_memory=True)
                classes = len(train_dataset.classes)
                classes_name = train_dataset.classes
                train_feats = []
                train_labels = []
                print('==> calculate train features')
                for idx, (images, target) in enumerate(train_loader):
                    images = images.cuda(non_blocking=True)
                    feat = model(images)
                    feat = feat.detach()
                    train_feats.append(feat)
                    train_labels.append(target)
                train_feats = torch.cat(train_feats, 0).cpu().numpy()
                train_labels = torch.cat(train_labels, 0).cpu().numpy()
                train_feats_norm = np.linalg.norm(train_feats, axis=1)
                train_feats = train_feats / (train_feats_norm + 1e-5)[:, np.newaxis]
                print('==> training SVM Classifier')
                # test_labels[test_labels == 0] = -1
                # train_labels[train_labels == 0] = -1
                clf = OneVsRestClassifier(LinearSVC(
                    C=cost,  # class_weight={1: 2, -1: 1},
                    intercept_scaling=1.0,
                    penalty='l2', loss='squared_hinge', tol=1e-4,
                    dual=True, max_iter=2000, random_state=0))
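                # OneVsRestClassifier fits one binary LinearSVC per class;
                # decision_function then returns one signed margin per class.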
                clf.fit(train_feats, train_labels)
                prediction = clf.decision_function(test_feats)
                predict = clf.predict(test_feats)
                # Use the registered SimSun font for the Chinese axis labels.
                plt.rcParams['font.sans-serif'] = ['simsun']
                plt.figure(1)
                plt.figure(2)
                list_ap = []
                list_auc = []
                for cl in range(classes):
                    # Binarize the labels for class cl: cl -> 1, the rest -> -1.
                    t_labels = deepcopy(test_labels)
                    t_labels[t_labels != cl] = -1
                    t_labels[t_labels == cl] = 1
                    # Per-class SVM, fit and evaluated on the test features.
                    clf_cl = LinearSVC()
                    clf_cl.fit(test_feats, t_labels)
                    t_pre = clf_cl.predict(test_feats)
                    P, R, score, ap = get_precision_recall(t_labels, t_pre)
                    fpr, tpr, thres = roc_curve(t_labels, t_pre)
                    auc = roc_auc_score(t_labels, t_pre)
                    list_ap.append(ap)
                    list_auc.append(auc)
                    plt.figure(1)
                    plt.plot(R, P)
                    plt.figure(2)
                    plt.plot(fpr, tpr)
                plt.figure(1)
                plt.xlabel('召回率', fontsize=14)  # 'Recall'
                plt.ylabel('精准率', fontsize=14)  # 'Precision'
                plt.legend(classes_name)
                plt.savefig("PR.png")
                plt.figure(2)
                plt.xlabel('假正率', fontsize=14)  # 'False positive rate'
                plt.ylabel('真正率', fontsize=14)  # 'True positive rate'
                plt.legend(classes_name)
                plt.savefig("ROC.png")
                print(classes_name)
                confusion = confusion_matrix(test_labels, predict)
                print(confusion)
                mean_ap = np.mean(list_ap) * 100
                print('==> Run%d\nmAP is %f' % (run, mean_ap))
                print("Weighted precision: " + str(precision_score(test_labels, predict, average="weighted")))
                print("mAUC: " + str(np.mean(list_auc)))
                print("F1_score: " + str(f1_score(test_labels, predict, average="weighted")))
                avg_map.append(mean_ap)
            avg_map = np.asarray(avg_map)
            # print('Cost:%.2f - Average ap is: %.2f' % (cost, avg_map.mean()))
            # print('Cost:%.2f - Std is: %.2f' % (cost, avg_map.std()))
            result_k[i] = avg_map.mean()
        result[k] = result_k.max()
    print(result)


if __name__ == '__main__':
    main()