Fix bugs
This commit is contained in:
parent
57017a9e0e
commit
ab3bdb1591
27
main.py
27
main.py
@ -168,7 +168,9 @@ def main_worker(gpu, ngpus_per_node, args):
|
|||||||
|
|
||||||
pre_train_dataset = pcl.loader.PreImager(pre_train_dir, eval_augmentation)
|
pre_train_dataset = pcl.loader.PreImager(pre_train_dir, eval_augmentation)
|
||||||
|
|
||||||
train_dataset = pcl.loader.ImageFolderInstance(train_dir, eval_augmentation)
|
train_dataset = pcl.loader.ImageFolderInstance(
|
||||||
|
train_dir,
|
||||||
|
pcl.loader.TwoCropsTransform(eval_augmentation))
|
||||||
eval_dataset = pcl.loader.ImageFolderInstance(
|
eval_dataset = pcl.loader.ImageFolderInstance(
|
||||||
train_dir,
|
train_dir,
|
||||||
eval_augmentation)
|
eval_augmentation)
|
||||||
@ -195,7 +197,7 @@ def main_worker(gpu, ngpus_per_node, args):
|
|||||||
drop_last=True)
|
drop_last=True)
|
||||||
|
|
||||||
train_loader = torch.utils.data.DataLoader(
|
train_loader = torch.utils.data.DataLoader(
|
||||||
train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
|
train_dataset, batch_size=args.batch_size//2, shuffle=(train_sampler is None),
|
||||||
num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
|
num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
|
||||||
|
|
||||||
# dataloader for center-cropped images, use larger batch size to increase speed
|
# dataloader for center-cropped images, use larger batch size to increase speed
|
||||||
@ -275,15 +277,18 @@ def train(train_loader, model, criterion, optimizer, epoch, args, cluster_result
|
|||||||
|
|
||||||
im_q = []
|
im_q = []
|
||||||
im_k = []
|
im_k = []
|
||||||
class_number = len(images)
|
if cluster_result is None:
|
||||||
class_len = len(images[0])
|
class_number = len(images)
|
||||||
for _i in range(0, class_len, 2):
|
class_len = len(images[0])
|
||||||
for c in range(class_number):
|
for _i in range(0, class_len, 2):
|
||||||
im_q.append(images[c][_i])
|
for c in range(class_number):
|
||||||
im_k.append(images[c][_i+1])
|
im_q.append(images[c][_i])
|
||||||
|
im_k.append(images[c][_i+1])
|
||||||
im_q = torch.stack(im_q)
|
im_q = torch.stack(im_q)
|
||||||
im_k = torch.stack(im_k)
|
im_k = torch.stack(im_k)
|
||||||
|
else:
|
||||||
|
im_q = images[0]
|
||||||
|
im_k = images[1]
|
||||||
|
|
||||||
if args.gpu is not None:
|
if args.gpu is not None:
|
||||||
im_q = im_q.cuda(args.gpu, non_blocking=True)
|
im_q = im_q.cuda(args.gpu, non_blocking=True)
|
||||||
|
@ -22,11 +22,12 @@ class PreImager(tud.Dataset):
|
|||||||
|
|
||||||
self.aug = aug
|
self.aug = aug
|
||||||
self.images = img_class
|
self.images = img_class
|
||||||
|
self.loader = data_meta.loader
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
imgs = []
|
imgs = []
|
||||||
for i in range(self.class_number):
|
for i in range(self.class_number):
|
||||||
img = image.read_image(self.images[i][index]).float()
|
img = self.loader(self.images[i][index])
|
||||||
out = self.aug(img)
|
out = self.aug(img)
|
||||||
imgs.append(out)
|
imgs.append(out)
|
||||||
return imgs, index
|
return imgs, index
|
||||||
|
Loading…
x
Reference in New Issue
Block a user