Changes

This commit is contained in: ce7e190d5d (parent: 2ec28bbeba)

README.md (54 changed lines)
@@ -1 +1,53 @@
 # myPCL
+
+## Training:
+
+#### Arguments
+
+| Short flag | Long flag | Description |
+|:----:|:---------------:|:---------------------------------------|
+| -a | --arch | Backbone network architecture, e.g. resnet-18, resnet-50 |
+| -j | --workers | Number of data-loading worker threads; default 4 |
+| | --epochs | Total number of training epochs; default 200 |
+| | --warmup-epoch | Number of supervised (warm-up) epochs; default 100 |
+| | --exp-dir | Output directory; default experiment |
+| -b | --batch-size | Batch size; default 8; must be a multiple of the number of labels |
+| -lr | --learning-rate | Learning rate; default 0.03 |
+| | --cos | Use a cosine learning-rate schedule |
+| | --schedule | Epochs at which the learning rate drops; default [120,160]; only effective when --cos is not set |
+| | --momentum | Momentum parameter of the optimizer; default 0.9 |
+| --wd | --weight-decay | Weight decay of the SSPCL model; default 1e-4 |
+| | --low-dim | Output feature dimension; default 128 |
+| | --num-cluster | Numbers of clusters; default '20,25,30' |
+| | --pcl-r | Number of negative pairs; must be smaller than the number of clusters; default 16 |
+| | --moco-m | Momentum for the momentum-encoder (ME) update in SSPCL; default 0.999 |
+| | --mlp | Use an MLP head; flag, takes no value; see the PCL model |
+| | --temperature | Temperature of the softmax layer; default 0.2 |
+| -p | --print-freq | Print frequency; default every 10 batches |
+| | --save-freq | Checkpoint save frequency; default every 10 epochs |
+| | --world-size | Total number of training processes; default 1 |
+| | --rank | Index of this training process; default 0 |
+| | --dist-url | Connection URL for multi-process training; see the PyTorch distributed-training documentation |
+| | --dist-backend | Distributed backend; default nccl |
+| | --gpu | Index of the GPU to train on |
+| | --seed | Random seed; auto-generated by default |
+| | --resume | Path of the checkpoint to load |
+| | --start-epoch | Starting epoch; used together with --resume |
+
+#### Example
+
+<pre>
+python main.py -a resnet18 --lr 0.03 --batch-size 8 --workers 4 --temperature 0.2 --mlp --aug-plus --cos --dist-url "tcp://localhost:10001" --world-size 1 --rank 0 --warmup-epoch 100 --epochs 100 --exp-dir exp images
+</pre>
+
+## Testing:
+
+#### Arguments
+
+**If any of the defaults above were changed for training, the same values must also be passed at test time.**
+
+The following argument is required:
+
+| Short flag | Long flag | Description |
+|:----:|:------------:|:----------|
+| | --pretrained | Path of the model checkpoint to load |
+
+#### Example
+
+<pre>
+python test_svm.py --pretrained exp/checkpoint_0199.pth.tar
+</pre>
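To resume an interrupted run, combine `--resume` and `--start-epoch` from the table above; the checkpoint name below is illustrative (it follows the `checkpoint_{epoch:04d}.pth.tar` pattern used by the training loop):

<pre>
python main.py -a resnet18 --resume exp/checkpoint_0099.pth.tar --start-epoch 100 --epochs 200 --exp-dir exp images
</pre>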
main.py (86 changed lines)

@@ -15,8 +15,6 @@ import torch.optim
 import torch.multiprocessing as mp
 import torch.utils.data
 import torch.utils.data.distributed
-import torchvision.transforms as transforms
-import torchvision.datasets as datasets
 import torchvision.models as models

 from scripts.parser import parser
@@ -24,11 +22,13 @@ from scripts.meter import AverageMeter, ProgressMeter
 import scripts.augmentation as aug
 import scripts.momentum as momentum
 import scripts.clustering as clustering
+import scripts.loss as loss_script
+from scripts.data_process import data_process
 import pcl.builder
 import pcl.loader


-def main_loader():
+def worker_loader():
     args = parser().parse_args()

     if args.seed is not None:
@@ -61,14 +61,15 @@ def main_loader():
         args.world_size = ngpus_per_node * args.world_size
         # Use torch.multiprocessing.spawn to launch distributed processes: the
         # main_worker process function
-        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
+        mp.spawn(worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
     else:
         # Simply call main_worker function
-        main_worker(args.gpu, ngpus_per_node, args)
+        worker(args.gpu, ngpus_per_node, args)


-def main_worker(gpu, ngpus_per_node, args):
+def worker(gpu, ngpus_per_node, args):
     args.gpu = gpu
+    args.multiprocessing_distributed = True

     if args.gpu is not None:
         print("Use GPU: {} for training".format(args.gpu))
@@ -157,13 +158,6 @@ def main_worker(gpu, ngpus_per_node, args):
     train_dir = os.path.join(args.data, 'train')
     eval_dir = os.path.join(args.data, 'train')

-    if args.aug_plus:
-        # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709
-        augmentation = aug.moco_v2()
-    else:
-        # MoCo v1's aug: same as InstDisc https://arxiv.org/abs/1805.01978
-        augmentation = aug.moco_v1()
-
     # center-crop augmentation
     eval_augmentation = aug.moco_eval()

@@ -248,12 +242,12 @@ def main_worker(gpu, ngpus_per_node, args):

         if (epoch + 1) % args.save_freq == 0 and (not args.multiprocessing_distributed or (args.multiprocessing_distributed
                                                   and args.rank % ngpus_per_node == 0)):
-            save_checkpoint({
+            torch.save({
                 'epoch': epoch + 1,
                 'arch': args.arch,
                 'state_dict': model.state_dict(),
                 'optimizer': optimizer.state_dict(),
-            }, is_best=False, filename='{}/checkpoint_{:04d}.pth.tar'.format(args.exp_dir, epoch))
+            }, '{}/checkpoint_{:04d}.pth.tar'.format(args.exp_dir, epoch))


 def train(train_loader, model, criterion, optimizer, epoch, args, cluster_result=None):
@@ -276,46 +270,18 @@ def train(train_loader, model, criterion, optimizer, epoch, args, cluster_result
         # measure data loading time
         data_time.update(time.time() - end)

-        im_q = []
-        im_k = []
-        if cluster_result is None:
-            class_number = len(images)
-            class_len = len(images[0])
-            for _i in range(0, class_len, 2):
-                for c in range(class_number):
-                    im_q.append(images[c][_i])
-                    im_k.append(images[c][_i+1])
-            im_q = torch.stack(im_q)
-            im_k = torch.stack(im_k)
-        else:
-            im_q = images[0]
-            im_k = images[1]
-
-        if args.gpu is not None:
-            im_q = im_q.cuda(args.gpu, non_blocking=True)
-            im_k = im_k.cuda(args.gpu, non_blocking=True)
+        im_q, im_k = data_process(cluster_result, images, args.gpu)

         # compute output
         output, target, output_proto, target_proto = model(im_q=im_q, im_k=im_k,
                                                            cluster_result=cluster_result, index=index)

-        # InfoNCE loss
-        loss = criterion(output, target)
-
-        # ProtoNCE loss
-        if output_proto is not None:
-            loss_proto = 0
-            for proto_out, proto_target in zip(output_proto, target_proto):
-                loss_proto += criterion(proto_out, proto_target)
-                accp = accuracy(proto_out, proto_target)[0]
-                acc_proto.update(accp[0], images[0].size(0))
-
-            # average loss across all sets of prototypes
-            loss_proto /= len(args.num_cluster)
-            loss += loss_proto
+        loss = loss_script.proto_with_quality(output, target, output_proto, target_proto, criterion, acc_proto, images,
+                                              args.num_cluster)

         losses.update(loss.item(), images[0].size(0))
-        acc = accuracy(output, target)[0]
+        acc = loss_script.accuracy(output, target)[0]
         acc_inst.update(acc[0], images[0].size(0))

         # compute gradient and do SGD step
@@ -343,29 +309,5 @@ def adjust_learning_rate(optimizer, epoch, args):
         param_group['lr'] = lr


-def accuracy(output, target, topk=(1,)):
-    """Computes the accuracy over the k top predictions for the specified values of k"""
-    with torch.no_grad():
-        maxk = max(topk)
-        batch_size = target.size(0)
-
-        _, pred = output.topk(maxk, 1, True, True)
-        pred = pred.t()
-        correct = pred.eq(target.view(1, -1).expand_as(pred))
-
-        res = []
-        for k in topk:
-            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
-            res.append(correct_k.mul_(100.0 / batch_size))
-        return res
-
-
-def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
-    torch.save(state, filename)
-    print("Model saved as:"+filename)
-    if is_best:
-        shutil.copyfile(filename, 'model_best.pth.tar')
-
-
 if __name__ == '__main__':
-    main_loader()
+    worker_loader()
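With `save_checkpoint` removed, checkpoints are now plain dicts written by `torch.save`. A minimal sketch for inspecting one (hypothetical path; the keys are the ones passed to `torch.save` above):

<pre>
import torch

ckpt = torch.load('experiment/checkpoint_0199.pth.tar', map_location='cpu')  # hypothetical path
print(ckpt['epoch'], ckpt['arch'])    # metadata saved alongside the weights
print(list(ckpt['state_dict'])[:5])   # first few parameter names
</pre>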
pcl/loader.py

@@ -35,6 +35,13 @@ class PreImager(tud.Dataset):
     def __len__(self):
         return self.length

+class TenserImager(datasets.ImageFolder):
+    def __getitem__(self, index):
+        path, target = self.samples[index]
+        sample = self.loader(path)
+        sample = self.transform(sample)
+        return sample, target
+

 class TwoCropsTransform:
     """Take two random crops of one image as the query and key."""
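`TenserImager` behaves like torchvision's `ImageFolder` but applies the transform inside `__getitem__` and returns only `(sample, target)`. A usage sketch, assuming an `ImageFolder`-style directory layout:

<pre>
import scripts.augmentation as aug
import pcl.loader

# hypothetical layout: images/pre_train/&lt;class_name&gt;/&lt;image&gt;.jpg
dataset = pcl.loader.TenserImager('images/pre_train', aug.moco_eval())
sample, target = dataset[0]
print(sample.shape, dataset.classes[target])
</pre>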
scripts/data_process.py (new file, 22 lines)

@@ -0,0 +1,22 @@
+import torch
+
+def data_process(cluster_result, images, gpu):
+    im_q = []
+    im_k = []
+    if cluster_result is None:
+        class_number = len(images)
+        class_len = len(images[0])
+        for _i in range(0, class_len, 2):
+            for c in range(class_number):
+                im_q.append(images[c][_i])
+                im_k.append(images[c][_i + 1])
+        im_q = torch.stack(im_q)
+        im_k = torch.stack(im_k)
+    else:
+        im_q = images[0]
+        im_k = images[1]
+
+    if gpu is not None:
+        im_q = im_q.cuda(gpu, non_blocking=True)
+        im_k = im_k.cuda(gpu, non_blocking=True)
+
+    return im_q, im_k
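The helper reproduces the block deleted from `train`: during warm-up (`cluster_result is None`) `images` is a list of per-class view lists whose entries alternate query/key, and the pairs are interleaved into two stacked batches; after clustering, `images` is already `[query_batch, key_batch]`. A minimal sketch with dummy tensors:

<pre>
import torch
from scripts.data_process import data_process

# 2 classes, 4 alternating query/key views each, dummy 3x32x32 images
images = [[torch.randn(3, 32, 32) for _ in range(4)] for _ in range(2)]
im_q, im_k = data_process(None, images, gpu=None)   # warm-up branch
print(im_q.shape, im_k.shape)                       # torch.Size([4, 3, 32, 32]) each

# any non-None cluster_result selects the plain unpack branch
im_q2, im_k2 = data_process({}, [im_q, im_k], gpu=None)
</pre>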
scripts/loss.py (new file, 42 lines)

@@ -0,0 +1,42 @@
+import torch
+import numpy as np
+import math
+
+def proto_with_quality(output, target, output_proto, target_proto, criterion, acc_proto, images, num_cluster):
+    # InfoNCE loss
+    loss = criterion(output, target)
+
+    # ProtoNCE loss
+    if output_proto is not None:
+        loss_proto = 0
+        for proto_out, proto_target in zip(output_proto, target_proto):
+            loss_proto += criterion(proto_out, proto_target)
+            accp = accuracy(proto_out, proto_target)[0]
+            acc_proto.update(accp[0], images[0].size(0))
+
+        # average loss across all sets of prototypes
+        loss_proto /= len(num_cluster)
+        loss += loss_proto
+
+    # Quality loss
+    mse = np.mean((images[1]/255.0-images[2]/255.0)**2)
+    psnr = 20 * math.log10(1 / math.sqrt(mse))
+
+    return loss
+
+
+def accuracy(output, target, topk=(1,)):
+    """Computes the accuracy over the k top predictions for the specified values of k"""
+    with torch.no_grad():
+        maxk = max(topk)
+        batch_size = target.size(0)
+
+        _, pred = output.topk(maxk, 1, True, True)
+        pred = pred.t()
+        correct = pred.eq(target.view(1, -1).expand_as(pred))
+
+        res = []
+        for k in topk:
+            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
+            res.append(correct_k.mul_(100.0 / batch_size))
+        return res
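The quality term is a peak signal-to-noise ratio between the two extra views, PSNR = 20 * log10(1/sqrt(MSE)) with intensities scaled to [0,1]; note that as written the computed `psnr` never enters the returned loss. A self-contained check of the arithmetic on dummy 8-bit values:

<pre>
import math
import numpy as np

a = np.full((4, 4), 200.0)                   # dummy 8-bit image
b = np.full((4, 4), 190.0)                   # dummy second view
mse = np.mean((a / 255.0 - b / 255.0) ** 2)  # ~0.00154
psnr = 20 * math.log10(1 / math.sqrt(mse))   # ~28.1 dB
print(round(psnr, 1))
</pre>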
scripts/parser.py

@@ -14,35 +14,52 @@ def parser():
                          help='model architecture: ' +
                               ' | '.join(model_names) +
                               ' (default: resnet50)')
-    _parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
-                         help='number of data loading workers (default: 8)')
+    _parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
+                         help='number of data loading workers (default: 4)')
     _parser.add_argument('--epochs', default=200, type=int, metavar='N',
                          help='number of total epochs to run')
-    _parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
-                         help='manual epoch number (useful on restarts)')
-    _parser.add_argument('-b', '--batch-size', default=256, type=int,
+    _parser.add_argument('--warmup-epoch', default=100, type=int,
+                         help='number of warm-up epochs to only train with InfoNCE loss')
+    _parser.add_argument('--exp-dir', default='experiment', type=str,
+                         help='experiment directory')
+
+    _parser.add_argument('-b', '--batch-size', default=8, type=int,
                          metavar='N',
-                         help='mini-batch size (default: 256), this is the total '
+                         help='mini-batch size (default: 8), this is the total '
                               'batch size of all GPUs on the current node when '
                               'using Data Parallel or Distributed Data Parallel')
-    _parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
+    _parser.add_argument('-lr', '--learning-rate', default=0.03, type=float,
                          metavar='LR', help='initial learning rate', dest='lr')
+    _parser.add_argument('--cos', action='store_true',
+                         help='use cosine lr schedule')
     _parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int,
                          help='learning rate schedule (when to drop lr by 10x)')

     _parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                          help='momentum of SGD solver')
     _parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                          metavar='W', help='weight decay (default: 1e-4)',
                          dest='weight_decay')
+    _parser.add_argument('--low-dim', default=128, type=int,
+                         help='feature dimension (default: 128)')
+    _parser.add_argument('--num-cluster', default='20,25,30', type=str,
+                         help='number of clusters')
+    _parser.add_argument('--pcl-r', default=16, type=int,
+                         help='queue size; number of negative pairs; needs to be smaller than num_cluster (default: '
+                              '16)')
+    _parser.add_argument('--moco-m', default=0.999, type=float,
+                         help='moco momentum of updating key encoder (default: 0.999)')
+    _parser.add_argument('--temperature', default=0.2, type=float,
+                         help='softmax temperature')

     _parser.add_argument('-p', '--print-freq', default=10, type=int,
                          metavar='N', help='print frequency (default: 10)')
     _parser.add_argument('--save-freq', default=10, type=int,
                          metavar='N', help='save frequency (default: 10)')
-    _parser.add_argument('--resume', default='', type=str, metavar='PATH',
-                         help='path to latest checkpoint (default: none)')
-    _parser.add_argument('--world-size', default=-1, type=int,
+    _parser.add_argument('--world-size', default=1, type=int,
                          help='number of nodes for distributed training')
-    _parser.add_argument('--rank', default=-1, type=int,
+    _parser.add_argument('--rank', default=0, type=int,
                          help='node rank for distributed training')
     _parser.add_argument('--dist-url', default='tcp://172.0.0.1:23456', type=str,
                          help='url used to set up distributed training')
@@ -58,33 +75,16 @@ def parser():
                               'fastest way to use PyTorch for either single node or '
                               'multi node data parallel training')

-    _parser.add_argument('--low-dim', default=128, type=int,
-                         help='feature dimension (default: 128)')
-    _parser.add_argument('--pcl-r', default=16, type=int,
-                         help='queue size; number of negative pairs; needs to be smaller than num_cluster (default: '
-                              '16384)')
-    _parser.add_argument('--moco-m', default=0.999, type=float,
-                         help='moco momentum of updating key encoder (default: 0.999)')
-    _parser.add_argument('--temperature', default=0.2, type=float,
-                         help='softmax temperature')
+    _parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
+                         help='manual epoch number (useful on restarts)')
+    _parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                         help='path to latest checkpoint (default: none)')

     _parser.add_argument('--mlp', action='store_true',
                          help='use mlp head')
-    _parser.add_argument('--aug-plus', action='store_true',
-                         help='use moco-v2/SimCLR data augmentation')
-    _parser.add_argument('--cos', action='store_true',
-                         help='use cosine lr schedule')
-
-    _parser.add_argument('--num-cluster', default='20,25,30', type=str,
-                         help='number of clusters')
-    _parser.add_argument('--warmup-epoch', default=20, type=int,
-                         help='number of warm-up epochs to only train with InfoNCE loss')
-    _parser.add_argument('--exp-dir', default='experiment', type=str,
-                         help='experiment directory')

     _parser.add_argument('--cost', type=str, default='0.5')
-    _parser.add_argument('--num-class', type=int, default=20)
+    _parser.add_argument('--n-run', type=int, default=1)
     _parser.add_argument('--pretrained', default='', type=str,
                          help='path to pretrained checkpoint')
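A quick check of the new defaults (a sketch: it assumes the positional `data` argument implied by the trailing `images` in the README usage example):

<pre>
from scripts.parser import parser

args = parser().parse_args(['-a', 'resnet18', 'images'])
print(args.batch_size, args.lr, args.num_cluster)   # expected: 8 0.03 20,25,30
</pre>

Note that `--num-cluster` stays a comma-separated string at parse time; it is presumably split downstream before `len(num_cluster)` is taken in `scripts/loss.py`.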
test_svm.py (106 changed lines)

@@ -1,25 +1,27 @@
 from __future__ import print_function

 import os
-import sys
-import time
 import torch
-import torch.optim as optim
-import torch.backends.cudnn as cudnn
-import torch.nn.functional as F
-import argparse
-import random
+from copy import deepcopy
 import numpy as np

 from torchvision import transforms, datasets
 import torchvision.models as models

-from sklearn.svm import LinearSVC
+from sklearn.svm import LinearSVC, SVC
+from sklearn.multiclass import OneVsRestClassifier
+from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score, precision_score, f1_score
+
+from matplotlib import pyplot as plt

 from scripts.parser import parser
 import scripts.augmentation as aug
 import pcl.loader

+import matplotlib.font_manager
+# register the SimSun font from its font file
+matplotlib.font_manager.fontManager.addfont('simsun.ttc')

 def calculate_ap(rec, prec):
     """
     Computes the AP under the precision recall curve.
@@ -76,9 +78,9 @@ def get_precision_recall(targets, preds):
 def main():
     args = parser().parse_args()

-    if not args.seed is None:
-        random.seed(args.seed)
-        np.random.seed(args.seed)
+    # if not args.seed is None:
+    #     random.seed(args.seed)
+    #     np.random.seed(args.seed)

     mean = [0.485, 0.456, 0.406]
     std = [0.229, 0.224, 0.225]
@@ -93,10 +95,10 @@ def main():
     eval_augmentation = aug.moco_eval()

     pre_train_dir = os.path.join(args.data, 'pre_train')
-    eval_dir = os.path.join(args.data, 'eval')
+    eval_dir = os.path.join(args.data, 'test')

-    train_dataset = pcl.loader.PreImager(pre_train_dir, eval_augmentation)
-    val_dataset = pcl.loader.ImageFolderInstance(
+    train_dataset = pcl.loader.TenserImager(pre_train_dir, eval_augmentation)
+    val_dataset = pcl.loader.TenserImager(
         eval_dir,
         eval_augmentation)

@@ -162,7 +164,8 @@ def main():
         train_loader = torch.utils.data.DataLoader(
             train_dataset, batch_size=args.batch_size, shuffle=False,
             num_workers=args.workers, pin_memory=True)
-
+        classes = len(train_dataset.classes)
+        classes_name = train_dataset.classes
         train_feats = []
         train_labels = []
         print('==> calculate train features')
@@ -181,27 +184,68 @@ def main():
         train_feats = train_feats / (train_feats_norm + 1e-5)[:, np.newaxis]

         print('==> training SVM Classifier')
-        cls_ap = np.zeros((args.num_class, 1))
-        test_labels[test_labels == 0] = -1
-        train_labels[train_labels == 0] = -1
-        for cls in range(args.num_class):
-            clf = LinearSVC(
-                C=cost, class_weight={1: 2, -1: 1}, intercept_scaling=1.0,
-                penalty='l2', loss='squared_hinge', tol=1e-4,
-                dual=True, max_iter=2000, random_state=0)
-            clf.fit(train_feats, train_labels[:, cls])
+        # test_labels[test_labels == 0] = -1
+        # train_labels[train_labels == 0] = -1
+        clf = OneVsRestClassifier(LinearSVC(
+            C=cost,  # class_weight={1: 2, -1: 1},
+            intercept_scaling=1.0,
+            penalty='l2', loss='squared_hinge', tol=1e-4,
+            dual=True, max_iter=2000, random_state=0))
+        clf.fit(train_feats, train_labels)

         prediction = clf.decision_function(test_feats)
-        P, R, score, ap = get_precision_recall(test_labels[:, cls], prediction)
-        cls_ap[cls][0] = ap * 100
-        mean_ap = np.mean(cls_ap, axis=0)
-
-        print('==> Run%d mAP is %.2f: ' % (run, mean_ap))
+        predict = clf.predict(test_feats)
+
+        plt.figure(1)
+        plt.rcParams['font.sans-serif'] = ['simsun']
+        plt.figure(2)
+        plt.rcParams['font.sans-serif'] = ['simsun']
+
+        list_ap = []
+        list_auc = []
+        for cl in range(classes):
+            t_labels, t_pre = deepcopy(test_labels), deepcopy(predict)
+            t_labels[t_labels != cl] = -1
+            t_labels[t_labels == cl] = 1
+            clf = LinearSVC()
+            clf.fit(test_feats, t_labels)
+            t_pre = clf.predict(test_feats)
+
+            P, R, score, ap = get_precision_recall(t_labels, t_pre)
+            fpr, tpr, thres = roc_curve(t_labels, t_pre)
+            auc = roc_auc_score(t_labels, t_pre)
+            list_ap.append(ap)
+            list_auc.append(auc)
+            plt.figure(1)
+            plt.plot(R, P)
+            plt.figure(2)
+            plt.plot(fpr, tpr)
+
+        plt.figure(1)
+        plt.xlabel('Recall', fontsize=14)
+        plt.ylabel('Precision', fontsize=14)
+        plt.legend(classes_name)
+        plt.savefig("PR.png")
+
+        plt.figure(2)
+        plt.xlabel('False positive rate', fontsize=14)
+        plt.ylabel('True positive rate', fontsize=14)
+        plt.legend(classes_name)
+        plt.savefig("ROC.png")
+        print(classes_name)
+        confusion = confusion_matrix(test_labels, predict)
+        print(confusion)
+        mean_ap = np.mean(list_ap) * 100
+        print('==> Run%d\nmAP is %f ' % (run, mean_ap))
+        print("AP: " + str(precision_score(test_labels, predict, average="weighted")))
+        print("mAUC: " + str(np.mean(list_auc)))
+        print("F1_score: " + str(f1_score(test_labels, predict, average="weighted")))
         avg_map.append(mean_ap)

     avg_map = np.asarray(avg_map)
-    print('Cost:%.2f - Average ap is: %.2f' % (cost, avg_map.mean()))
-    print('Cost:%.2f - Std is: %.2f' % (cost, avg_map.std()))
+    # print('Cost:%.2f - Average ap is: %.2f' % (cost, avg_map.mean()))
+    # print('Cost:%.2f - Std is: %.2f' % (cost, avg_map.std()))
     result_k[i] = avg_map.mean()
     result[k] = result_k.max()
     print(result)
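The patched evaluation fits a single multi-class one-vs-rest SVM for the headline metrics, then re-binarizes the labels per class for the PR and ROC curves. A minimal sketch of that flow on synthetic features (shapes and class count are assumptions, not the repository's data; the per-class AUC here is computed from the multi-class predictions, whereas the commit refits a per-class LinearSVC):

<pre>
import numpy as np
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC
from sklearn.metrics import f1_score, roc_auc_score

rng = np.random.default_rng(0)
train_feats = rng.normal(size=(60, 128))   # 128-d features (--low-dim default)
test_feats = rng.normal(size=(30, 128))
train_labels = np.arange(60) % 3           # 3 dummy classes
test_labels = np.arange(30) % 3

clf = OneVsRestClassifier(LinearSVC(C=0.5, max_iter=2000, random_state=0))
clf.fit(train_feats, train_labels)
predict = clf.predict(test_feats)
print("F1_score:", f1_score(test_labels, predict, average="weighted"))

# per-class one-vs-rest AUC
for cl in range(3):
    t = np.where(test_labels == cl, 1, -1)
    p = np.where(predict == cl, 1, -1)
    print(cl, roc_auc_score(t, p))
</pre>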