preprocess.py
import os

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.utils.data as data

import clip.embed as clip
# from torch.utils.data import Subset  # only needed for the debugging subset below


def get_data(device='cuda' if torch.cuda.is_available() else 'cpu',
             mode='standard', dataset='mnist', val_split=0.2, batch_size=64,
             save_embedding=True, shuffle=True):
    train_dataset, test_dataset = None, None
    if dataset == 'mnist':
        # Download the MNIST dataset
        train_dataset = datasets.MNIST(
            root='dataset/',
            train=True,
            transform=None,
            download=True
        )
        # Create a subset of the training set that contains only the first
        # 1000 data points. This is for debugging purposes.
        '''
        indices_range = 1000
        subset_indices = list(range(indices_range))
        train_dataset = Subset(train_dataset, subset_indices)
        '''
        test_dataset = datasets.MNIST(
            root='dataset/',
            train=False,
            transform=None,
            download=True
        )
    elif dataset == 'cifar10':
        # Download the CIFAR-10 dataset
        train_dataset = datasets.CIFAR10(
            root='dataset/',
            train=True,
            transform=None,
            download=True
        )
        test_dataset = datasets.CIFAR10(
            root='dataset/',
            train=False,
            transform=None,
            download=True
        )
    else:
        raise ValueError(f"Unsupported dataset: {dataset}")

    if mode == 'standard':
        to_tensor = transforms.ToTensor()
        train_dataset = [(to_tensor(img), label) for img, label in train_dataset]
        test_dataset = [(to_tensor(img), label) for img, label in test_dataset]
    elif mode == 'clip':
        class PreprocessedImagesDataset(data.Dataset):
            """Dataset that applies the CLIP preprocessing transform per item."""

            def __init__(self, images, labels, preprocess):
                self.images = images
                self.labels = labels
                self.preprocess = preprocess

            def __getitem__(self, index):
                preprocessed_image = self.preprocess(self.images[index])
                label = self.labels[index]
                return preprocessed_image, label

            def __len__(self):
                return len(self.images)

        with torch.no_grad():
            # Embed the images using CLIP with pretrained weights
            model, preprocess = clip.create_model_and_transforms()
            model = model.to(device).eval()  # inference only

            # Get the images and labels from the train and test datasets
            train_images, train_labels = zip(*train_dataset)
            test_images, test_labels = zip(*test_dataset)

            # Wrap the images and labels in datasets that apply the CLIP preprocessing
            train_dataset = PreprocessedImagesDataset(train_images, train_labels, preprocess)
            test_dataset = PreprocessedImagesDataset(test_images, test_labels, preprocess)

            train_image_embeds = []
            train_text_embeds = []
            train_labels = []
            train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
            print("Embedding training images...")
            # TODO: Implement a separate loader per modality, e.g. obtain the
            # text through LLM/template mining, add it to a text dataset, then
            # load the text and image datasets separately.
            class_names = get_labels(dataset)
            for batch_images, batch_labels in train_dataloader:
                batch_images = batch_images.float().to(device)
                batch_image_embeds = model.encode_image(batch_images).float()
                batch_image_embeds /= batch_image_embeds.norm(dim=-1, keepdim=True)
                # Map label indices to class names; the raw indices only read
                # naturally for MNIST digits, not for CIFAR-10.
                texts = [class_names[int(label)] for label in batch_labels]
                # Fixed template for text embedding
                text_tokens = clip.tokenizer.tokenize(
                    ["This is a photo of " + text for text in texts]).to(device)
                batch_text_embeds = model.encode_text(text_tokens).float()
                batch_text_embeds /= batch_text_embeds.norm(dim=-1, keepdim=True)
                train_image_embeds.append(batch_image_embeds)
                train_text_embeds.append(batch_text_embeds)
                train_labels.append(batch_labels)
            train_image_embeds = torch.cat(train_image_embeds, dim=0).to(device)
            train_text_embeds = torch.cat(train_text_embeds, dim=0).to(device)
            train_labels = torch.cat(train_labels, dim=0).to(device)
            # Concatenate the image and text embeddings along the batch
            # dimension, so every label appears twice: once for its image
            # embedding and once for its text embedding.
            # TODO: I suspect this concatenation causes accuracy to be
            # reported in an unintended and weird way.
            train_vectors = torch.cat((train_image_embeds, train_text_embeds), dim=0).to(device)
            train_labels_vector = torch.cat((train_labels, train_labels), dim=0).to(device)
            train_dataset = data.TensorDataset(train_vectors, train_labels_vector)

            test_image_embeds = []
            test_labels = []
            test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
            print("Embedding test images...")
            for batch_images, batch_labels in test_dataloader:
                batch_images = batch_images.float().to(device)
                batch_image_embeds = model.encode_image(batch_images).float()
                batch_image_embeds /= batch_image_embeds.norm(dim=-1, keepdim=True)
                test_image_embeds.append(batch_image_embeds)
                test_labels.append(batch_labels)
            test_image_embeds = torch.cat(test_image_embeds, dim=0).to(device)
            test_labels = torch.cat(test_labels, dim=0).to(device)
            test_dataset = data.TensorDataset(test_image_embeds, test_labels)

            # Save the embedded datasets to disk
            if save_embedding:
                save_dir = 'embeddings/'
                os.makedirs(save_dir, exist_ok=True)  # ensure the directory exists
                torch.save(train_dataset, f"{save_dir}/{dataset}_train_dataset_embedded.pt")
                torch.save(test_dataset, f"{save_dir}/{dataset}_test_dataset_embedded.pt")
    else:
        raise ValueError(f"Invalid mode: {mode}")

    train_loader, val_loader = split_train_val(train_dataset, val_split,
                                               batch_size=batch_size, shuffle=shuffle)
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        shuffle=False  # keep evaluation order deterministic
    )
    return train_dataset, test_dataset, train_loader, val_loader, test_loader


def load_data(dataset='mnist', val_split=0.2, batch_size=64, shuffle=True):
    print(f"Files found, loading {dataset} dataset...")
    # Note: on PyTorch >= 2.6 these calls may need weights_only=False, since
    # the files contain pickled TensorDataset objects, not plain tensors.
    train_dataset = torch.load(f'embeddings/{dataset}_train_dataset_embedded.pt')
    test_dataset = torch.load(f'embeddings/{dataset}_test_dataset_embedded.pt')
    train_loader, val_loader = split_train_val(train_dataset, val_split,
                                               batch_size=batch_size, shuffle=shuffle)
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        shuffle=False
    )
    return train_dataset, test_dataset, train_loader, val_loader, test_loader


def split_train_val(dataset, val_split, batch_size=64, shuffle=True):
    """Randomly split a dataset and return train/validation loaders."""
    train_size = int((1 - val_split) * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = data.random_split(dataset, [train_size, val_size])
    train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
    val_loader = data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader


# TODO: This is a crude way to get the labels for the datasets; it should be
# generalized to work for any dataset (torchvision datasets expose .classes).
def get_labels(dataset):
    if dataset == 'mnist':
        labels = [str(i) for i in range(10)]  # MNIST's 10 classes are the digits 0-9
    elif dataset == 'cifar10':
        labels = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                  'dog', 'frog', 'horse', 'ship', 'truck']
    else:
        raise ValueError(f"Unsupported dataset: {dataset}")
    return labels
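

# A minimal usage sketch (illustrative, not part of the original pipeline):
# embed once with get_data(..., save_embedding=True), then reuse the cached
# embeddings via load_data() on later runs. The 'embeddings/' path and the
# dataset name below simply mirror the defaults used above.
if __name__ == '__main__':
    if os.path.exists('embeddings/mnist_train_dataset_embedded.pt'):
        _, _, train_loader, val_loader, test_loader = load_data(dataset='mnist')
    else:
        _, _, train_loader, val_loader, test_loader = get_data(mode='clip',
                                                               dataset='mnist')
    print(f"train batches: {len(train_loader)}, "
          f"val batches: {len(val_loader)}, "
          f"test batches: {len(test_loader)}")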