import argparse

import torchvision.models as models


def parser():
    """Build the argparse.ArgumentParser holding all PCL ImageNet training options."""
    model_names = sorted(name for name in models.__dict__
                         if name.islower() and not name.startswith("__")
                         and callable(models.__dict__[name]))

    _parser = argparse.ArgumentParser(description='PyTorch ImageNet Training PCL')
    _parser.add_argument('data', metavar='DIR',
                         help='path to dataset')
    _parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                         choices=model_names,
                         help='model architecture: ' + ' | '.join(model_names) +
                              ' (default: resnet18)')
    _parser.add_argument('-j', '--workers', default=8, type=int, metavar='N',
                         help='number of data loading workers (default: 8)')
    _parser.add_argument('--epochs', default=200, type=int, metavar='N',
                         help='number of total epochs to run')
    _parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                         help='manual epoch number (useful on restarts)')
    _parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N',
                         help='mini-batch size (default: 256), this is the total '
                              'batch size of all GPUs on the current node when '
                              'using Data Parallel or Distributed Data Parallel')
    _parser.add_argument('--lr', '--learning-rate', default=0.03, type=float,
                         metavar='LR', help='initial learning rate', dest='lr')
    _parser.add_argument('--schedule', default=[120, 160], nargs='*', type=int,
                         help='learning rate schedule (when to drop lr by 10x)')
    _parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                         help='momentum of SGD solver')
    _parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                         metavar='W', help='weight decay (default: 1e-4)',
                         dest='weight_decay')
    _parser.add_argument('-p', '--print-freq', default=10, type=int, metavar='N',
                         help='print frequency (default: 10)')
    _parser.add_argument('--save-freq', default=10, type=int, metavar='N',
                         help='save frequency (default: 10)')
    _parser.add_argument('--resume', default='', type=str, metavar='PATH',
                         help='path to latest checkpoint (default: none)')
    _parser.add_argument('--world-size', default=-1, type=int,
                         help='number of nodes for distributed training')
    _parser.add_argument('--rank', default=-1, type=int,
                         help='node rank for distributed training')
    _parser.add_argument('--dist-url', default='tcp://127.0.0.1:23456', type=str,
                         help='url used to set up distributed training')
    _parser.add_argument('--dist-backend', default='nccl', type=str,
                         help='distributed backend')
    _parser.add_argument('--seed', default=None, type=int,
                         help='seed for initializing training')
    _parser.add_argument('--gpu', default=None, type=int,
                         help='GPU id to use')
    _parser.add_argument('--multiprocessing-distributed', action='store_true',
                         help='Use multi-processing distributed training to launch '
                              'N processes per node, which has N GPUs. This is the '
                              'fastest way to use PyTorch for either single node or '
                              'multi node data parallel training')
    _parser.add_argument('--low-dim', default=128, type=int,
                         help='feature dimension (default: 128)')
    _parser.add_argument('--pcl-r', default=16, type=int,
                         help='queue size; number of negative pairs; needs to be '
                              'smaller than num_cluster (default: 16)')
    _parser.add_argument('--moco-m', default=0.999, type=float,
                         help='moco momentum of updating key encoder (default: 0.999)')
    _parser.add_argument('--temperature', default=0.2, type=float,
                         help='softmax temperature')
    _parser.add_argument('--mlp', action='store_true',
                         help='use mlp head')
    _parser.add_argument('--aug-plus', action='store_true',
                         help='use moco-v2/SimCLR data augmentation')
    _parser.add_argument('--cos', action='store_true',
                         help='use cosine lr schedule')
    _parser.add_argument('--num-cluster', default='20,25,30', type=str,
                         help='number of clusters')
    _parser.add_argument('--warmup-epoch', default=20, type=int,
                         help='number of warm-up epochs to only train with InfoNCE loss')
    _parser.add_argument('--exp-dir', default='experiment', type=str,
                         help='experiment directory')
    _parser.add_argument('--cost', type=str, default='0.5')
    _parser.add_argument('--num-class', type=int, default=20)
    _parser.add_argument('--pretrained', default='', type=str,
                         help='path to pretrained checkpoint')
    return _parser
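# A minimal usage sketch (an illustrative assumption, not part of the original
# module): build the parser and parse a sample command line. The dataset path
# and option values below are hypothetical placeholders.
if __name__ == '__main__':
    args = parser().parse_args(['/path/to/imagenet',
                                '--arch', 'resnet18',
                                '--num-cluster', '20,25,30'])
    # num_cluster is a comma-separated string; downstream code would split it.
    print(args.arch, args.lr, args.num_cluster.split(','))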