Fix bugs
This commit is contained in:
parent
57017a9e0e
commit
ab3bdb1591
27
main.py
27
main.py
@ -168,7 +168,9 @@ def main_worker(gpu, ngpus_per_node, args):
|
||||
|
||||
pre_train_dataset = pcl.loader.PreImager(pre_train_dir, eval_augmentation)
|
||||
|
||||
train_dataset = pcl.loader.ImageFolderInstance(train_dir, eval_augmentation)
|
||||
train_dataset = pcl.loader.ImageFolderInstance(
|
||||
train_dir,
|
||||
pcl.loader.TwoCropsTransform(eval_augmentation))
|
||||
eval_dataset = pcl.loader.ImageFolderInstance(
|
||||
train_dir,
|
||||
eval_augmentation)
|
||||
@ -195,7 +197,7 @@ def main_worker(gpu, ngpus_per_node, args):
|
||||
drop_last=True)
|
||||
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
|
||||
train_dataset, batch_size=args.batch_size//2, shuffle=(train_sampler is None),
|
||||
num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
|
||||
|
||||
# dataloader for center-cropped images, use larger batch size to increase speed
|
||||
@ -275,15 +277,18 @@ def train(train_loader, model, criterion, optimizer, epoch, args, cluster_result
|
||||
|
||||
im_q = []
|
||||
im_k = []
|
||||
class_number = len(images)
|
||||
class_len = len(images[0])
|
||||
for _i in range(0, class_len, 2):
|
||||
for c in range(class_number):
|
||||
im_q.append(images[c][_i])
|
||||
im_k.append(images[c][_i+1])
|
||||
|
||||
im_q = torch.stack(im_q)
|
||||
im_k = torch.stack(im_k)
|
||||
if cluster_result is None:
|
||||
class_number = len(images)
|
||||
class_len = len(images[0])
|
||||
for _i in range(0, class_len, 2):
|
||||
for c in range(class_number):
|
||||
im_q.append(images[c][_i])
|
||||
im_k.append(images[c][_i+1])
|
||||
im_q = torch.stack(im_q)
|
||||
im_k = torch.stack(im_k)
|
||||
else:
|
||||
im_q = images[0]
|
||||
im_k = images[1]
|
||||
|
||||
if args.gpu is not None:
|
||||
im_q = im_q.cuda(args.gpu, non_blocking=True)
|
||||
|
@ -22,11 +22,12 @@ class PreImager(tud.Dataset):
|
||||
|
||||
self.aug = aug
|
||||
self.images = img_class
|
||||
self.loader = data_meta.loader
|
||||
|
||||
def __getitem__(self, index):
|
||||
imgs = []
|
||||
for i in range(self.class_number):
|
||||
img = image.read_image(self.images[i][index]).float()
|
||||
img = self.loader(self.images[i][index])
|
||||
out = self.aug(img)
|
||||
imgs.append(out)
|
||||
return imgs, index
|
||||
|
Loading…
x
Reference in New Issue
Block a user