Browse Source

Fix bugs

main
Fiber 2 years ago
parent
commit
ab3bdb1591
  1. 27
      main.py
  2. 3
      pcl/loader.py

27
main.py

@@ -168,7 +168,9 @@ def main_worker(gpu, ngpus_per_node, args):
pre_train_dataset = pcl.loader.PreImager(pre_train_dir, eval_augmentation)
train_dataset = pcl.loader.ImageFolderInstance(train_dir, eval_augmentation)
train_dataset = pcl.loader.ImageFolderInstance(
train_dir,
pcl.loader.TwoCropsTransform(eval_augmentation))
eval_dataset = pcl.loader.ImageFolderInstance(
train_dir,
eval_augmentation)
@@ -195,7 +197,7 @@ def main_worker(gpu, ngpus_per_node, args):
drop_last=True)
train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
train_dataset, batch_size=args.batch_size//2, shuffle=(train_sampler is None),
num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
# dataloader for center-cropped images, use larger batch size to increase speed
@@ -275,15 +277,18 @@ def train(train_loader, model, criterion, optimizer, epoch, args, cluster_result
im_q = []
im_k = []
class_number = len(images)
class_len = len(images[0])
for _i in range(0, class_len, 2):
for c in range(class_number):
im_q.append(images[c][_i])
im_k.append(images[c][_i+1])
im_q = torch.stack(im_q)
im_k = torch.stack(im_k)
if cluster_result is None:
class_number = len(images)
class_len = len(images[0])
for _i in range(0, class_len, 2):
for c in range(class_number):
im_q.append(images[c][_i])
im_k.append(images[c][_i+1])
im_q = torch.stack(im_q)
im_k = torch.stack(im_k)
else:
im_q = images[0]
im_k = images[1]
if args.gpu is not None:
im_q = im_q.cuda(args.gpu, non_blocking=True)

3
pcl/loader.py

@@ -22,11 +22,12 @@ class PreImager(tud.Dataset):
self.aug = aug
self.images = img_class
self.loader = data_meta.loader
def __getitem__(self, index):
imgs = []
for i in range(self.class_number):
img = image.read_image(self.images[i][index]).float()
img = self.loader(self.images[i][index])
out = self.aug(img)
imgs.append(out)
return imgs, index

Loading…
Cancel
Save