diff --git a/main.py b/main.py index f9a0a93..553ab5c 100644 --- a/main.py +++ b/main.py @@ -168,7 +168,9 @@ def main_worker(gpu, ngpus_per_node, args): pre_train_dataset = pcl.loader.PreImager(pre_train_dir, eval_augmentation) - train_dataset = pcl.loader.ImageFolderInstance(train_dir, eval_augmentation) + train_dataset = pcl.loader.ImageFolderInstance( + train_dir, + pcl.loader.TwoCropsTransform(eval_augmentation)) eval_dataset = pcl.loader.ImageFolderInstance( train_dir, eval_augmentation) @@ -195,7 +197,7 @@ def main_worker(gpu, ngpus_per_node, args): drop_last=True) train_loader = torch.utils.data.DataLoader( - train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), + train_dataset, batch_size=args.batch_size//2, shuffle=(train_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True) # dataloader for center-cropped images, use larger batch size to increase speed @@ -275,15 +277,18 @@ def train(train_loader, model, criterion, optimizer, epoch, args, cluster_result im_q = [] im_k = [] - class_number = len(images) - class_len = len(images[0]) - for _i in range(0, class_len, 2): - for c in range(class_number): - im_q.append(images[c][_i]) - im_k.append(images[c][_i+1]) - - im_q = torch.stack(im_q) - im_k = torch.stack(im_k) + if cluster_result is None: + class_number = len(images) + class_len = len(images[0]) + for _i in range(0, class_len, 2): + for c in range(class_number): + im_q.append(images[c][_i]) + im_k.append(images[c][_i+1]) + im_q = torch.stack(im_q) + im_k = torch.stack(im_k) + else: + im_q = images[0] + im_k = images[1] if args.gpu is not None: im_q = im_q.cuda(args.gpu, non_blocking=True) diff --git a/pcl/loader.py b/pcl/loader.py index b834d67..35b4638 100644 --- a/pcl/loader.py +++ b/pcl/loader.py @@ -22,11 +22,12 @@ class PreImager(tud.Dataset): self.aug = aug self.images = img_class + self.loader = data_meta.loader def __getitem__(self, index): imgs = [] for i in range(self.class_number): - img = image.read_image(self.images[i][index]).float() + img = self.loader(self.images[i][index]) out = self.aug(img) imgs.append(out) return imgs, index