You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

257 lines
8.9 KiB

2 years ago
from __future__ import print_function
import os
import torch
2 years ago
from copy import deepcopy
2 years ago
import numpy as np
from torchvision import transforms, datasets
import torchvision.models as models
2 years ago
from sklearn.svm import LinearSVC, SVC
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score, precision_score, f1_score
from matplotlib import pyplot as plt
2 years ago
from scripts.parser import parser
import scripts.augmentation as aug
import pcl.loader
2 years ago
import matplotlib.font_manager

# Register the SimSun font file so the Chinese axis labels in the PR/ROC plots
# can render. Look in the working directory first (the original behaviour),
# then next to this script. This is best-effort: a missing or unreadable font
# file should only affect how the plot labels look, so warn instead of
# aborting the whole evaluation at import time.
_FONT_NAME = 'simsun.ttc'
for _font_path in (_FONT_NAME,
                   os.path.join(os.path.dirname(os.path.abspath(__file__)), _FONT_NAME)):
    try:
        matplotlib.font_manager.fontManager.addfont(_font_path)
        break
    except (OSError, RuntimeError):
        # OSError covers a missing file; RuntimeError covers a file that
        # FreeType cannot parse.
        continue
else:
    print("=> warning: could not load font file '{}'; "
          "Chinese plot labels may not render".format(_FONT_NAME))
def calculate_ap(rec, prec):
    """
    Return the average precision, i.e. the area under a precision/recall curve.

    ``rec`` and ``prec`` are numpy arrays of equal size holding recall and
    precision in ranked order. The curve is padded with a leading point
    (recall 0, precision 0) and a trailing point (recall 1, precision 0).
    Precision is then replaced by its running maximum taken from the right,
    which gives the interpolated envelope. Finally, rectangle areas are summed
    at every position where recall changes value.

    The result is a numpy array of shape (1,), not a Python float, because the
    computation works on (n, 1) column vectors.
    """
    zero, one = np.zeros((1, 1)), np.ones((1, 1))
    recall = np.vstack((zero, rec.reshape(rec.size, 1), one))
    envelope = np.vstack((zero, prec.reshape(prec.size, 1), zero))

    # Sweep right-to-left so every entry holds the best precision achievable
    # at that recall or any higher recall.
    for pos in reversed(range(len(envelope) - 1)):
        envelope[pos] = max(envelope[pos], envelope[pos + 1])

    # Rows where recall moves to a new value; each adds one rectangle.
    steps = np.where(recall[1:] != recall[:-1])[0] + 1
    return sum(((recall[s] - recall[s - 1]) * envelope[s] for s in steps), 0)
def get_precision_recall(targets, preds):
    """
    Compute a precision/recall curve and its average precision for one class.

    [P, R, score, ap] = get_precision_recall(targets, preds)

    Input :
        targets : 1-D array-like. Entries > 0 count as positives (e.g. the
                  number of occurrences of this class in the i-th image, or a
                  +1/-1 label); everything else counts as a negative.
        preds   : 1-D array-like of the same length, holding the score for
                  each sample. Higher scores are ranked first.
    Output :
        P, R  : precision and recall after each sample, with samples ranked by
                descending score. Tied scores keep whatever order
                ``np.argsort`` produces (not guaranteed stable).
        score : the scores in that ranked order, as float64.
        ap    : average precision, as returned by ``calculate_ap``.

    If ``targets`` contains no positives, recall is undefined: R is all NaN,
    and so is ap for non-empty input.

    Raises:
        ValueError: if ``targets`` is not 1-D or ``preds`` has a different
        shape.
    """
    # Binarise: anything > 0 is a positive. asarray also accepts plain lists.
    positive = np.asarray(targets) > 0
    preds = np.asarray(preds)
    if positive.ndim != 1 or preds.shape != positive.shape:
        raise ValueError(
            "targets and preds must be 1-D with equal length, got shapes "
            "{} and {}".format(positive.shape, preds.shape))

    # Rank samples by descending score.
    order = np.argsort(preds)[::-1]
    score = preds[order].astype(np.float64)
    tp = positive[order].astype(np.float64)
    fp = 1.0 - tp  # every ranked sample is either a hit or a false alarm

    tp_cum = np.cumsum(tp)
    fp_cum = np.cumsum(fp)
    P = tp_cum / (tp_cum + fp_cum)

    numinst = tp.sum()
    if numinst > 0:
        R = tp_cum / numinst
    else:
        # No positives: recall is undefined. Return NaN explicitly instead of
        # dividing 0/0 (which produced the same NaNs plus a RuntimeWarning).
        R = np.full_like(tp_cum, np.nan)

    ap = calculate_ap(R, P)
    return P, R, score, ap
def _extract_features(model, loader, device):
    """
    Run ``model`` over every batch of ``loader`` and return numpy features and labels.

    Each feature row is divided by (its L2 norm + 1e-5) before being returned.

    Args:
        model: network already moved to ``device`` and put in eval mode.
        loader: iterable of ``(images, target)`` batches.
        device: torch device that the images are moved to.

    Returns:
        ``(feats, labels)`` as numpy arrays, one row / entry per sample.

    Raises:
        RuntimeError: if ``loader`` yields no batches.
    """
    feats, labels = [], []
    # Inference only: with no autograd graph, activations are not retained,
    # and each batch's features are moved to the CPU right away instead of
    # piling up on the GPU.
    with torch.no_grad():
        for images, target in loader:
            images = images.to(device, non_blocking=True)
            feats.append(model(images).cpu())
            labels.append(target)
    if not feats:
        raise RuntimeError("data loader produced no batches")
    feats = torch.cat(feats, 0).numpy()
    labels = torch.cat(labels, 0).numpy()
    norms = np.linalg.norm(feats, axis=1)
    return feats / (norms + 1e-5)[:, np.newaxis], labels


def _load_pretrained(model, path):
    """
    Load the query-encoder weights from a checkpoint into ``model``.

    Keys starting with 'module.encoder_q' (except its 'fc' head) are copied
    into ``model`` with the prefix stripped, using ``strict=False``. After
    that, ``model.fc`` is replaced by an identity so the model outputs
    backbone features. If ``path`` is not a file, a message is printed and
    ``model`` is left untouched.
    """
    if not os.path.isfile(path):
        print("=> no checkpoint found at '{}'".format(path))
        return
    print("=> loading checkpoint '{}'".format(path))
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(path, map_location="cpu")
    state_dict = checkpoint['state_dict']
    prefix = "module.encoder_q."
    renamed = {
        k[len(prefix):]: v
        for k, v in state_dict.items()
        if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc')
    }
    msg = model.load_state_dict(renamed, strict=False)
    # strict=False hides mismatches; surface any missing backbone weights
    # (the fc head is expected to be missing and is replaced below).
    missing = [k for k in getattr(msg, 'missing_keys', []) if not k.startswith('fc.')]
    if missing:
        print("=> warning: keys missing from checkpoint: {}".format(missing))
    model.fc = torch.nn.Identity()
    print("=> loaded pre-trained model '{}'".format(path))


def _class_scores(clf, decision, label):
    """
    Return the one-vs-rest decision score of every sample for class ``label``.

    ``decision`` is ``clf.decision_function(X)``. Columns follow
    ``clf.classes_``. For a two-class problem, sklearn returns a single 1-D
    score for ``clf.classes_[1]``, so ``clf.classes_[0]`` gets its negation.
    """
    if decision.ndim == 1:
        return decision if label == clf.classes_[1] else -decision
    column = int(np.flatnonzero(clf.classes_ == label)[0])
    return decision[:, column]


def _evaluate_run(clf, test_feats, test_labels, classes_name):
    """
    Score the test set with a fitted one-vs-rest SVM and save PR/ROC plots.

    For each class index ``cl`` (its position in ``classes_name``), test
    samples labelled ``cl`` are positives and all others are negatives.
    Curves, AP and AUC use the SVM's continuous decision score for that class,
    which comes from the classifier trained on the training features.

    Classes are skipped with a message, because AUC is undefined for them, if
    either:
    - they are absent from ``clf.classes_``, or
    - their test labels are all positive or all negative.

    Writes "PR.png" and "ROC.png" to the current directory, overwriting any
    previous files.

    Returns:
        ``(list_ap, list_auc, predict)``: per-class AP and ROC AUC for the
        evaluated classes, and the hard predictions for ``test_feats``.
    """
    decision = clf.decision_function(test_feats)
    predict = clf.predict(test_feats)
    # Fresh figures per run, so curves from earlier runs/costs don't pile up.
    pr_fig, pr_ax = plt.subplots()
    roc_fig, roc_ax = plt.subplots()
    list_ap, list_auc, plotted = [], [], []
    for cl, name in enumerate(classes_name):
        t_labels = np.where(test_labels == cl, 1, -1)
        n_pos = int(np.sum(t_labels == 1))
        if cl not in clf.classes_ or n_pos in (0, t_labels.size):
            print("=> skipping class '{}': not in training labels or "
                  "single-class test labels".format(name))
            continue
        scores = _class_scores(clf, decision, cl)
        P, R, _, ap = get_precision_recall(t_labels, scores)
        fpr, tpr, _ = roc_curve(t_labels, scores)
        list_ap.append(ap)
        list_auc.append(roc_auc_score(t_labels, scores))
        pr_ax.plot(R, P)
        roc_ax.plot(fpr, tpr)
        plotted.append(name)
    # Axis labels: recall / precision for PR, FPR / TPR for ROC.
    for fig, ax, fname, xlabel, ylabel in (
            (pr_fig, pr_ax, "PR.png", '召回率', '精准率'),
            (roc_fig, roc_ax, "ROC.png", '假正率', '真正率')):
        ax.set_xlabel(xlabel, fontsize=14)
        ax.set_ylabel(ylabel, fontsize=14)
        if plotted:
            ax.legend(plotted)
        fig.savefig(fname)
        plt.close(fig)
    return list_ap, list_auc, predict


def main():
    """
    Evaluate frozen encoder features with one-vs-rest linear SVMs.

    Reads options from ``parser().parse_args()``: ``data``, ``arch``,
    ``pretrained``, ``batch_size``, ``workers``, ``cost`` (comma-separated C
    values) and ``n_run``.

    For every cost and run, it:
    - trains a LinearSVC per class on features from ``<data>/pre_train``;
    - scores ``<data>/test``;
    - prints the confusion matrix, mAP, weighted precision, mAUC and weighted
      F1;
    - saves PR.png / ROC.png.

    Finally it prints the best run-averaged mAP over the costs.
    """
    args = parser().parse_args()

    eval_augmentation = aug.moco_eval()
    pre_train_dir = os.path.join(args.data, 'pre_train')
    eval_dir = os.path.join(args.data, 'test')
    train_dataset = pcl.loader.TenserImager(pre_train_dir, eval_augmentation)
    val_dataset = pcl.loader.TenserImager(eval_dir, eval_augmentation)
    print(len(train_dataset))
    # NOTE(review): assumes train class indices 0..N-1 line up with the test
    # labels (ImageFolder-style); confirm for pcl.loader.TenserImager.
    classes_name = train_dataset.classes

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    loader_kwargs = dict(batch_size=args.batch_size, shuffle=False,
                         num_workers=args.workers,
                         pin_memory=device.type == 'cuda')
    val_loader = torch.utils.data.DataLoader(val_dataset, **loader_kwargs)
    train_loader = torch.utils.data.DataLoader(train_dataset, **loader_kwargs)

    # create model
    print("=> creating model '{}'".format(args.arch))
    model = models.__dict__[args.arch](num_classes=128)
    if args.pretrained:
        _load_pretrained(model, args.pretrained)
    model.to(device)
    model.eval()

    print('==> calculate test features')
    test_feats, test_labels = _extract_features(model, val_loader, device)

    # Use SimSun (registered via fontManager.addfont at module level) for
    # the Chinese axis labels.
    plt.rcParams['font.sans-serif'] = ['simsun']

    result = {}
    costs = [float(c) for c in args.cost.split(',')]
    for k in ['full']:
        result_k = np.zeros(len(costs))
        for i, cost in enumerate(costs):
            avg_map = []
            for run in range(args.n_run):
                # NOTE(review): train features are re-extracted every run, as
                # before, in case eval_augmentation is random; if it is
                # deterministic, this can be hoisted out of both loops.
                print('==> calculate train features')
                train_feats, train_labels = _extract_features(model, train_loader, device)

                print('==> training SVM Classifier')
                clf = OneVsRestClassifier(LinearSVC(
                    C=cost, intercept_scaling=1.0,
                    penalty='l2', loss='squared_hinge', tol=1e-4,
                    dual=True, max_iter=2000, random_state=0))
                clf.fit(train_feats, train_labels)

                list_ap, list_auc, predict = _evaluate_run(
                    clf, test_feats, test_labels, classes_name)

                print(classes_name)
                print(confusion_matrix(test_labels, predict))
                mean_ap = np.mean(list_ap) * 100
                print('==> Run%d\nmAP is %f ' % (run, mean_ap))
                # This is weighted precision of the hard predictions, not
                # average precision (that is the mAP above).
                print("Precision (weighted): "
                      + str(precision_score(test_labels, predict, average="weighted")))
                print("mAUC: " + str(np.mean(list_auc)))
                print("F1_score: " + str(f1_score(test_labels, predict, average="weighted")))
                avg_map.append(mean_ap)
            result_k[i] = np.mean(avg_map)
        result[k] = result_k.max()
    print(result)


if __name__ == '__main__':
    main()