Browse Source

Fix bugs

main
Fiber 2 years ago
parent
commit
ab3bdb1591
  1. 11
      main.py
  2. 3
      pcl/loader.py

11
main.py

@ -168,7 +168,9 @@ def main_worker(gpu, ngpus_per_node, args):
pre_train_dataset = pcl.loader.PreImager(pre_train_dir, eval_augmentation) pre_train_dataset = pcl.loader.PreImager(pre_train_dir, eval_augmentation)
train_dataset = pcl.loader.ImageFolderInstance(train_dir, eval_augmentation) train_dataset = pcl.loader.ImageFolderInstance(
train_dir,
pcl.loader.TwoCropsTransform(eval_augmentation))
eval_dataset = pcl.loader.ImageFolderInstance( eval_dataset = pcl.loader.ImageFolderInstance(
train_dir, train_dir,
eval_augmentation) eval_augmentation)
@ -195,7 +197,7 @@ def main_worker(gpu, ngpus_per_node, args):
drop_last=True) drop_last=True)
train_loader = torch.utils.data.DataLoader( train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), train_dataset, batch_size=args.batch_size//2, shuffle=(train_sampler is None),
num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True) num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
# dataloader for center-cropped images, use larger batch size to increase speed # dataloader for center-cropped images, use larger batch size to increase speed
@ -275,15 +277,18 @@ def train(train_loader, model, criterion, optimizer, epoch, args, cluster_result
im_q = [] im_q = []
im_k = [] im_k = []
if cluster_result is None:
class_number = len(images) class_number = len(images)
class_len = len(images[0]) class_len = len(images[0])
for _i in range(0, class_len, 2): for _i in range(0, class_len, 2):
for c in range(class_number): for c in range(class_number):
im_q.append(images[c][_i]) im_q.append(images[c][_i])
im_k.append(images[c][_i+1]) im_k.append(images[c][_i+1])
im_q = torch.stack(im_q) im_q = torch.stack(im_q)
im_k = torch.stack(im_k) im_k = torch.stack(im_k)
else:
im_q = images[0]
im_k = images[1]
if args.gpu is not None: if args.gpu is not None:
im_q = im_q.cuda(args.gpu, non_blocking=True) im_q = im_q.cuda(args.gpu, non_blocking=True)

3
pcl/loader.py

@ -22,11 +22,12 @@ class PreImager(tud.Dataset):
self.aug = aug self.aug = aug
self.images = img_class self.images = img_class
self.loader = data_meta.loader
def __getitem__(self, index): def __getitem__(self, index):
imgs = [] imgs = []
for i in range(self.class_number): for i in range(self.class_number):
img = image.read_image(self.images[i][index]).float() img = self.loader(self.images[i][index])
out = self.aug(img) out = self.aug(img)
imgs.append(out) imgs.append(out)
return imgs, index return imgs, index

Loading…
Cancel
Save