paint-brush
वेरिएशनल ऑटोएनकोडर के साथ अव्यक्त स्थान से नमूना कैसे लेंद्वारा@owlgrey
8,205 रीडिंग
8,205 रीडिंग

वेरिएशनल ऑटोएनकोडर के साथ अव्यक्त स्थान से नमूना कैसे लें

द्वारा Dmitrii Matveichev 17m2024/02/29
Read on Terminal Reader

बहुत लंबा; पढ़ने के लिए

पारंपरिक एई मॉडल के विपरीत, वेरिएशनल ऑटोएन्कोडर्स (वीएई) एक बहुभिन्नरूपी सामान्य वितरण के लिए इनपुट को मैप करते हैं, जिससे विभिन्न नमूनाकरण विधियों के माध्यम से उपन्यास डेटा पीढ़ी की अनुमति मिलती है। इस आलेख में शामिल नमूनाकरण विधियां पश्च नमूनाकरण, पूर्व नमूनाकरण, दो वैक्टरों के बीच इंटरपोलेशन और अव्यक्त आयाम ट्रैवर्सल हैं।
featured image - वेरिएशनल ऑटोएनकोडर के साथ अव्यक्त स्थान से नमूना कैसे लें
Dmitrii Matveichev  HackerNoon profile picture

पारंपरिक ऑटोएनकोडर के समान, VAE आर्किटेक्चर के दो भाग होते हैं: एक एनकोडर और एक डिकोडर। पारंपरिक एई मॉडल इनपुट को एक अव्यक्त-स्पेस वेक्टर में मैप करते हैं और इस वेक्टर से आउटपुट का पुनर्निर्माण करते हैं।


VAE एक बहुभिन्नरूपी सामान्य वितरण में इनपुट को मैप करता है (एनकोडर प्रत्येक अव्यक्त आयाम के माध्य और विचरण को आउटपुट करता है)।


चूंकि वीएई एनकोडर एक वितरण उत्पन्न करता है, इसलिए इस वितरण से नमूना लेकर और नमूना किए गए अव्यक्त वेक्टर को डिकोडर में पास करके नया डेटा उत्पन्न किया जा सकता है। आउटपुट छवियों को उत्पन्न करने के लिए उत्पादित वितरण से नमूना लेने का मतलब है कि वीएई नए डेटा को उत्पन्न करने की अनुमति देता है जो इनपुट डेटा के समान है, लेकिन समान है।


यह आलेख वीएई वास्तुकला के घटकों की खोज करता है और वीएई मॉडल के साथ नई छवियां (नमूना) उत्पन्न करने के कई तरीके प्रदान करता है। सभी कोड Google Colab पर उपलब्ध हैं।

1 वीएई मॉडल कार्यान्वयन


एई मॉडल को पुनर्निर्माण हानि को कम करके प्रशिक्षित किया जाता है (उदाहरण के लिए बीसीई या एमएसई)


ऑटोएनकोडर और वेरिएशनल ऑटोएनकोडर दोनों के दो भाग होते हैं: एनकोडर और डिकोडर। एई का एनकोडर न्यूरल नेटवर्क प्रत्येक छवि को अव्यक्त स्थान में एक वेक्टर में मैप करना सीखता है और डिकोडर एन्कोडेड अव्यक्त वेक्टर से मूल छवि को फिर से बनाना सीखता है।


वीएई मॉडल को पुनर्निर्माण हानि और केएल-विचलन हानि को कम करके प्रशिक्षित किया जाता है


वीएई का एनकोडर तंत्रिका नेटवर्क पैरामीटर आउटपुट करता है जो अव्यक्त स्थान (बहुभिन्नरूपी वितरण) के प्रत्येक आयाम के लिए संभाव्यता वितरण को परिभाषित करता है। प्रत्येक इनपुट के लिए, एनकोडर अव्यक्त स्थान के प्रत्येक आयाम के लिए एक माध्य और एक भिन्नता उत्पन्न करता है।


आउटपुट माध्य और विचरण का उपयोग बहुभिन्नरूपी गाऊसी वितरण को परिभाषित करने के लिए किया जाता है। डिकोडर न्यूरल नेटवर्क AE मॉडल के समान है।

1.1 वीएई हानियाँ

वीएई मॉडल को प्रशिक्षित करने का लक्ष्य प्रदान किए गए अव्यक्त वैक्टर से वास्तविक छवियां उत्पन्न करने की संभावना को अधिकतम करना है। प्रशिक्षण के दौरान, VAE मॉडल दो नुकसानों को कम करता है:


  • पुनर्निर्माण हानि - इनपुट छवियों और डिकोडर के आउटपुट के बीच का अंतर।


  • कुल्बैक-लीबलर विचलन हानि (केएल विचलन दो संभाव्यता वितरणों के बीच एक सांख्यिकीय दूरी) - एनकोडर के आउटपुट की संभाव्यता वितरण और पूर्व वितरण (एक मानक सामान्य वितरण) के बीच की दूरी, अव्यक्त स्थान को नियमित करने में मदद करती है।

1.2 पुनर्निर्माण हानि

सामान्य पुनर्निर्माण हानियाँ बाइनरी क्रॉस-एन्ट्रॉपी (बीसीई) और माध्य वर्ग त्रुटि (एमएसई) हैं। इस लेख में, मैं डेमो के लिए एमएनआईएसटी डेटासेट का उपयोग करूंगा। एमएनआईएसटी छवियों में केवल एक चैनल होता है, और पिक्सेल 0 और 1 के बीच मान लेते हैं।


इस मामले में, बीसीई हानि का उपयोग एमएनआईएसटी छवियों के पिक्सल को एक द्विआधारी यादृच्छिक चर के रूप में इलाज करने के लिए पुनर्निर्माण हानि के रूप में किया जा सकता है जो बर्नौली वितरण का अनुसरण करता है।

 reconstruction_loss = nn.BCELoss(reduction='sum')

1.3 कुल्बैक-लीब्लर डाइवर्जेंस

जैसा कि ऊपर बताया गया है - केएल विचलन दो वितरणों के बीच अंतर का मूल्यांकन करता है। ध्यान दें कि इसमें दूरी का सममित गुण नहीं है: KL(P‖Q)!=KL(Q‖P)।


जिन दो वितरणों की तुलना करने की आवश्यकता है वे हैं:

  • दिए गए इनपुट चित्र x: q(z|x) एनकोडर आउटपुट का गुप्त स्थान


  • p(z) से पहले अव्यक्त स्थान जिसे शून्य के माध्य और प्रत्येक अव्यक्त स्थान आयाम N(0, I ) में एक के मानक विचलन के साथ एक सामान्य वितरण माना जाता है।


    ऐसी धारणा केएल विचलन गणना को सरल बनाती है और अव्यक्त स्थान को ज्ञात, प्रबंधनीय वितरण का पालन करने के लिए प्रोत्साहित करती है।

 from torch.distributions.kl import kl_divergence def kl_divergence_loss(z_dist): return kl_divergence(z_dist, Normal(torch.zeros_like(z_dist.mean), torch.ones_like(z_dist.stddev)) ).sum(-1).sum()

1.4 एनकोडर

 class Encoder(nn.Module): def __init__(self, im_chan=1, output_chan=32, hidden_dim=16): super(Encoder, self).__init__() self.z_dim = output_chan self.encoder = nn.Sequential( self.init_conv_block(im_chan, hidden_dim), self.init_conv_block(hidden_dim, hidden_dim * 2), # double output_chan for mean and std with [output_chan] size self.init_conv_block(hidden_dim * 2, output_chan * 2, final_layer=True), ) def init_conv_block(self, input_channels, output_channels, kernel_size=4, stride=2, padding=0, final_layer=False): layers = [ nn.Conv2d(input_channels, output_channels, kernel_size=kernel_size, padding=padding, stride=stride) ] if not final_layer: layers += [ nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True) ] return nn.Sequential(*layers) def forward(self, image): encoder_pred = self.encoder(image) encoding = encoder_pred.view(len(encoder_pred), -1) mean = encoding[:, :self.z_dim] logvar = encoding[:, self.z_dim:] # encoding output representing standard deviation is interpreted as # the logarithm of the variance associated with the normal distribution # take the exponent to convert it to standard deviation return mean, torch.exp(logvar*0.5)

1.5 डिकोडर

 class Decoder(nn.Module): def __init__(self, z_dim=32, im_chan=1, hidden_dim=64): super(Decoder, self).__init__() self.z_dim = z_dim self.decoder = nn.Sequential( self.init_conv_block(z_dim, hidden_dim * 4), self.init_conv_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1), self.init_conv_block(hidden_dim * 2, hidden_dim), self.init_conv_block(hidden_dim, im_chan, kernel_size=4, final_layer=True), ) def init_conv_block(self, input_channels, output_channels, kernel_size=3, stride=2, padding=0, final_layer=False): layers = [ nn.ConvTranspose2d(input_channels, output_channels, kernel_size=kernel_size, stride=stride, padding=padding) ] if not final_layer: layers += [ nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True) ] else: layers += [nn.Sigmoid()] return nn.Sequential(*layers) def forward(self, z): # Ensure the input latent vector z is correctly reshaped for the decoder x = z.view(-1, self.z_dim, 1, 1) # Pass the reshaped input through the decoder network return self.decoder(x)

1.6 वीएई मॉडल

एक यादृच्छिक नमूने के माध्यम से बैक-प्रोपेगेट करने के लिए आपको मापदंडों के माध्यम से ग्रेडिएंट गणना की अनुमति देने के लिए यादृच्छिक नमूने ( μ और 𝝈) के मापदंडों को फ़ंक्शन के बाहर ले जाना होगा। इस चरण को "पुनरावर्तन चाल" भी कहा जाता है।


PyTorch में, आप एनकोडर के आउटपुट μ और 𝝈 के साथ एक सामान्य वितरण बना सकते हैं और उसमें से rsample() विधि का नमूना ले सकते हैं जो रिपैरामीटराइजेशन ट्रिक को लागू करता है: यह torch.randn(z_dim) * stddev + mean)

 class VAE(nn.Module): def __init__(self, z_dim=32, im_chan=1): super(VAE, self).__init__() self.z_dim = z_dim self.encoder = Encoder(im_chan, z_dim) self.decoder = Decoder(z_dim, im_chan) def forward(self, images): z_dist = Normal(self.encoder(images)) # sample from distribution with reparametarazation trick z = z_dist.rsample() decoding = self.decoder(z) return decoding, z_dist

1.7 वीएई का प्रशिक्षण

एमएनआईएसटी ट्रेन और परीक्षण डेटा लोड करें।

 transform = transforms.Compose([transforms.ToTensor()]) # Download and load the MNIST training data trainset = datasets.MNIST('.', download=True, train=True, transform=transform) train_loader = DataLoader(trainset, batch_size=64, shuffle=True) # Download and load the MNIST test data testset = datasets.MNIST('.', download=True, train=False, transform=transform) test_loader = DataLoader(testset, batch_size=64, shuffle=True) 

वीएई प्रशिक्षण चरण

एक प्रशिक्षण लूप बनाएं जो ऊपर चित्र में दिखाए गए वीएई प्रशिक्षण चरणों का पालन करता हो।

 def train_model(epochs=10, z_dim = 16): model = VAE(z_dim=z_dim).to(device) model_opt = torch.optim.Adam(model.parameters()) for epoch in range(epochs): print(f"Epoch {epoch}") for images, step in tqdm(train_loader): images = images.to(device) model_opt.zero_grad() recon_images, encoding = model(images) loss = reconstruction_loss(recon_images, images)+ kl_divergence_loss(encoding) loss.backward() model_opt.step() show_images_grid(images.cpu(), title=f'Input images') show_images_grid(recon_images.cpu(), title=f'Reconstructed images') return model
 z_dim = 8 vae = train_model(epochs=20, z_dim=z_dim) 

1.8 अव्यक्त स्थान की कल्पना करें

 def visualize_latent_space(model, data_loader, device, method='TSNE', num_samples=10000): model.eval() latents = [] labels = [] with torch.no_grad(): for i, (data, label) in enumerate(data_loader): if len(latents) > num_samples: break mu, _ = model.encoder(data.to(device)) latents.append(mu.cpu()) labels.append(label.cpu()) latents = torch.cat(latents, dim=0).numpy() labels = torch.cat(labels, dim=0).numpy() assert method in ['TSNE', 'UMAP'], 'method should be TSNE or UMAP' if method == 'TSNE': tsne = TSNE(n_components=2, verbose=1) tsne_results = tsne.fit_transform(latents) fig = px.scatter(tsne_results, x=0, y=1, color=labels, labels={'color': 'label'}) fig.update_layout(title='VAE Latent Space with TSNE', width=600, height=600) elif method == 'UMAP': reducer = umap.UMAP() embedding = reducer.fit_transform(latents) fig = px.scatter(embedding, x=0, y=1, color=labels, labels={'color': 'label'}) fig.update_layout(title='VAE Latent Space with UMAP', width=600, height=600 ) fig.show()
 visualize_latent_space(vae, train_loader, device='cuda' if torch.cuda.is_available() else 'cpu', method='UMAP', num_samples=10000) 

वीएई मॉडल अव्यक्त स्थान को यूएमएपी के साथ देखा गया

2 वीएई के साथ नमूनाकरण

वेरिएशनल ऑटोएन्कोडर (वीएई) से नमूनाकरण नए डेटा की पीढ़ी को सक्षम बनाता है जो प्रशिक्षण के दौरान देखे गए डेटा के समान है और यह एक अनूठा पहलू है जो वीएई को पारंपरिक एई आर्किटेक्चर से अलग करता है।


VAE से नमूना लेने के कई तरीके हैं:

  • पश्च नमूनाकरण: दिए गए इनपुट को देखते हुए पश्च वितरण से नमूना लेना।


  • पूर्व नमूनाकरण: एक मानक सामान्य बहुभिन्नरूपी वितरण मानकर अव्यक्त स्थान से नमूनाकरण। यह इस धारणा (वीएई प्रशिक्षण के दौरान प्रयुक्त) के कारण संभव है कि अव्यक्त चर सामान्य रूप से वितरित होते हैं। यह विधि विशिष्ट गुणों के साथ डेटा उत्पन्न करने की अनुमति नहीं देती है (उदाहरण के लिए, किसी विशिष्ट वर्ग से डेटा उत्पन्न करना)।


  • प्रक्षेप : अव्यक्त स्थान में दो बिंदुओं के बीच प्रक्षेप से पता चल सकता है कि अव्यक्त स्थान चर में परिवर्तन उत्पन्न डेटा में परिवर्तन से कैसे मेल खाते हैं।


  • अव्यक्त आयामों का ट्रैवर्सल : वीएई के अव्यक्त आयामों का ट्रैवर्सिंग डेटा का अव्यक्त स्थान विचरण प्रत्येक आयाम पर निर्भर करता है। ट्रैवर्सल एक चुने हुए आयाम को छोड़कर अव्यक्त वेक्टर के सभी आयामों और उसकी सीमा में चुने गए आयाम के अलग-अलग मूल्यों को तय करके किया जाता है। अव्यक्त स्थान के कुछ आयाम डेटा की विशिष्ट विशेषताओं के अनुरूप हो सकते हैं (वीएई के पास उस व्यवहार को मजबूर करने के लिए विशिष्ट तंत्र नहीं हैं लेकिन ऐसा हो सकता है)।


    उदाहरण के लिए, अव्यक्त स्थान में एक आयाम किसी चेहरे की भावनात्मक अभिव्यक्ति या किसी वस्तु के अभिविन्यास को नियंत्रित कर सकता है।


प्रत्येक नमूनाकरण विधि वीएई के अव्यक्त स्थान द्वारा कैप्चर किए गए डेटा गुणों की खोज और समझने का एक अलग तरीका प्रदान करती है।

2.1 पिछला नमूना (दी गई इनपुट छवि से)

एनकोडर अव्यक्त स्थान में एक वितरण (सामान्य वितरण का μ_x और 𝝈_x) आउटपुट करता है। सामान्य वितरण N(μ_x, 𝝈_x) से नमूना लेने और नमूना वेक्टर को डिकोडर में पास करने से दी गई इनपुट छवि के समान एक छवि उत्पन्न होती है।

 def posterior_sampling(model, data_loader, n_samples=25): model.eval() images, _ = next(iter(data_loader)) images = images[:n_samples] with torch.no_grad(): _, encoding_dist = model(images.to(device)) input_sample=encoding_dist.sample() recon_images = model.decoder(input_sample) show_images_grid(images, title=f'input samples') show_images_grid(recon_images, title=f'generated posterior samples')
 posterior_sampling(vae, train_loader, n_samples=30)

पश्च नमूनाकरण यथार्थवादी डेटा नमूने उत्पन्न करने की अनुमति देता है लेकिन कम परिवर्तनशीलता के साथ: आउटपुट डेटा इनपुट डेटा के समान होता है।

2.2 पूर्व नमूनाकरण (एक यादृच्छिक अव्यक्त अंतरिक्ष वेक्टर से)

वितरण से नमूना लेना और नमूना वेक्टर को डिकोडर में पास करने से नए डेटा उत्पन्न करने की अनुमति मिलती है

 def prior_sampling(model, z_dim=32, n_samples = 25): model.eval() input_sample=torch.randn(n_samples, z_dim).to(device) with torch.no_grad(): sampled_images = model.decoder(input_sample) show_images_grid(sampled_images, title=f'generated prior samples')
 prior_sampling(vae, z_dim, n_samples=40)

N(0, I ) के साथ पूर्व नमूनाकरण हमेशा विश्वसनीय डेटा उत्पन्न नहीं करता है लेकिन इसमें उच्च परिवर्तनशीलता होती है।

2.3 कक्षा केन्द्रों से नमूनाकरण

प्रत्येक वर्ग के माध्य एन्कोडिंग को संपूर्ण डेटासेट से संचित किया जा सकता है और बाद में नियंत्रित (सशर्त पीढ़ी) के लिए उपयोग किया जा सकता है।

 def get_data_predictions(model, data_loader): model.eval() latents_mean = [] latents_std = [] labels = [] with torch.no_grad(): for i, (data, label) in enumerate(data_loader): mu, std = model.encoder(data.to(device)) latents_mean.append(mu.cpu()) latents_std.append(std.cpu()) labels.append(label.cpu()) latents_mean = torch.cat(latents_mean, dim=0) latents_std = torch.cat(latents_std, dim=0) labels = torch.cat(labels, dim=0) return latents_mean, latents_std, labels
 def get_classes_mean(class_to_idx, labels, latents_mean, latents_std): classes_mean = {} for class_name in train_loader.dataset.class_to_idx: class_id = train_loader.dataset.class_to_idx[class_name] labels_class = labels[labels==class_id] latents_mean_class = latents_mean[labels==class_id] latents_mean_class = latents_mean_class.mean(dim=0, keepdims=True) latents_std_class = latents_std[labels==class_id] latents_std_class = latents_std_class.mean(dim=0, keepdims=True) classes_mean[class_id] = [latents_mean_class, latents_std_class] return classes_mean
 latents_mean, latents_stdvar, labels = get_data_predictions(vae, train_loader) classes_mean = get_classes_mean(train_loader.dataset.class_to_idx, labels, latents_mean, latents_stdvar) n_samples = 20 for class_id in classes_mean.keys(): latents_mean_class, latents_stddev_class = classes_mean[class_id] # create normal distribution of the current class class_dist = Normal(latents_mean_class, latents_stddev_class) percentiles = torch.linspace(0.05, 0.95, n_samples) # get samples from different parts of the distribution using icdf # https://pytorch.org/docs/stable/distributions.html#torch.distributions.distribution.Distribution.icdf class_z_sample = class_dist.icdf(percentiles[:, None].repeat(1, z_dim)) with torch.no_grad(): # generate image directly from mean class_image_prototype = vae.decoder(latents_mean_class.to(device)) # generate images sampled from Normal(class mean, class std) class_images = vae.decoder(class_z_sample.to(device)) show_image(class_image_prototype[0].cpu(), title=f'Class {class_id} prototype image') show_images_grid(class_images.cpu(), title=f'Class {class_id} images')

औसत वर्ग μ के साथ सामान्य वितरण से नमूनाकरण उसी वर्ग से नए डेटा की पीढ़ी की गारंटी देता है।

कक्षा 3 केंद्र से उत्पन्न छवि

कक्षा 4 केंद्र से उत्पन्न छवि

आईसीडीएफ में उपयोग किए जाने वाले निम्न और उच्च प्रतिशतक मानों के परिणामस्वरूप उच्च डेटा भिन्नता होती है

2.4 अंतर्वेशन

 def linear_interpolation(start, end, steps): # Create a linear path from start to end z = torch.linspace(0, 1, steps)[:, None].to(device) * (end - start) + start # Decode the samples along the path vae.eval() with torch.no_grad(): samples = vae.decoder(z) return samples

2.4.1 दो यादृच्छिक अव्यक्त वेक्टरों के बीच अंतर्वेशन

 start = torch.randn(1, z_dim).to(device) end = torch.randn(1, z_dim).to(device) interpolated_samples = linear_interpolation(start, end, steps = 24) show_images_grid(interpolated_samples, title=f'Linear interpolation between two random latent vectors') 

2.4.2 दो वर्ग केन्द्रों के बीच अंतर्वेशन

 for start_class_id in range(1,10): start = classes_mean[start_class_id][0].to(device) for end_class_id in range(1, 10): if end_class_id == start_class_id: continue end = classes_mean[end_class_id][0].to(device) interpolated_samples = linear_interpolation(start, end, steps = 20) show_images_grid(interpolated_samples, title=f'Linear interpolation between classes {start_class_id} and {end_class_id}') 

2.5 अव्यक्त अंतरिक्ष ट्रैवर्सल

अव्यक्त वेक्टर का प्रत्येक आयाम एक सामान्य वितरण का प्रतिनिधित्व करता है; आयाम के मानों की सीमा को आयाम के माध्य और विचरण द्वारा नियंत्रित किया जाता है। मूल्यों की सीमा को पार करने का एक सरल तरीका सामान्य वितरण के व्युत्क्रम सीडीएफ (संचयी वितरण फ़ंक्शन) का उपयोग करना होगा।


ICDF 0 और 1 के बीच मान लेता है (संभावना का प्रतिनिधित्व करता है) और वितरण से एक मान लौटाता है। किसी दी गई प्रायिकता p के लिए ICDF एक p_icdf मान इस प्रकार आउटपुट करता है कि एक यादृच्छिक चर के <= p_icdf होने की प्रायिकता दी गई प्रायिकता p के बराबर होती है?


यदि आपके पास सामान्य वितरण है, तो icdf(0.5) को वितरण का माध्य लौटाना चाहिए। icdf(0.95) को वितरण से 95% डेटा से बड़ा मान लौटाना चाहिए।

सीडीएफ का विज़ुअलाइज़ेशन और आईसीडीएफ द्वारा दिए गए मान दिए गए संभावनाएं 0.025, 0.5, 0.975

2.5.1 एकल आयाम अव्यक्त अंतरिक्ष ट्रैवर्सल

 def latent_space_traversal(model, input_sample, norm_dist, dim_to_traverse, n_samples, latent_dim, device): # Create a range of values to traverse assert input_sample.shape[0] == 1, 'input sample shape should be [1, latent_dim]' # Generate linearly spaced percentiles between 0.05 and 0.95 percentiles = torch.linspace(0.1, 0.9, n_samples) # Get the quantile values corresponding to the percentiles traversed_values = norm_dist.icdf(percentiles[:, None].repeat(1, z_dim)) # Initialize a latent space vector with zeros z = input_sample.repeat(n_samples, 1) # Assign the traversed values to the specified dimension z[:, dim_to_traverse] = traversed_values[:, dim_to_traverse] # Decode the latent vectors with torch.no_grad(): samples = model.decoder(z.to(device)) return samples
 for class_id in range(0,10): mu, std = classes_mean[class_id] with torch.no_grad(): recon_images = vae.decoder(mu.to(device)) show_image(recon_images[0], title=f'class {class_id} mean sample') for i in range(z_dim): interpolated_samples = latent_space_traversal(vae, mu, norm_dist=Normal(mu, torch.ones_like(mu)), dim_to_traverse=i, n_samples=20, latent_dim=z_dim, device=device) show_images_grid(interpolated_samples, title=f'Class {class_id} dim={i} traversal')

किसी एकल आयाम को पार करने से अंक शैली या नियंत्रण अंक अभिविन्यास में परिवर्तन हो सकता है।

2.5.3 दो आयाम अव्यक्त अंतरिक्ष ट्रैवर्सल

 def traverse_two_latent_dimensions(model, input_sample, z_dist, n_samples=25, z_dim=16, dim_1=0, dim_2=1, title='plot'): digit_size=28 percentiles = torch.linspace(0.10, 0.9, n_samples) grid_x = z_dist.icdf(percentiles[:, None].repeat(1, z_dim)) grid_y = z_dist.icdf(percentiles[:, None].repeat(1, z_dim)) figure = np.zeros((digit_size * n_samples, digit_size * n_samples)) z_sample_def = input_sample.clone().detach() # select two dimensions to vary (dim_1 and dim_2) and keep the rest fixed for yi in range(n_samples): for xi in range(n_samples): with torch.no_grad(): z_sample = z_sample_def.clone().detach() z_sample[:, dim_1] = grid_x[xi, dim_1] z_sample[:, dim_2] = grid_y[yi, dim_2] x_decoded = model.decoder(z_sample.to(device)).cpu() digit = x_decoded[0].reshape(digit_size, digit_size) figure[yi * digit_size: (yi + 1) * digit_size, xi * digit_size: (xi + 1) * digit_size] = digit.numpy() plt.figure(figsize=(6, 6)) plt.imshow(figure, cmap='Greys_r') plt.title(title) plt.show()
 for class_id in range(10): mu, std = classes_mean[class_id] with torch.no_grad(): recon_images = vae.decoder(mu.to(device)) show_image(recon_images[0], title=f'class {class_id} mean sample') traverse_two_latent_dimensions(vae, mu, z_dist=Normal(mu, torch.ones_like(mu)), n_samples=8, z_dim=z_dim, dim_1=3, dim_2=6, title=f'Class {class_id} traversing dimensions {(3, 6)}')

एक साथ कई आयामों को पार करना उच्च परिवर्तनशीलता के साथ डेटा उत्पन्न करने का एक नियंत्रणीय तरीका प्रदान करता है।

2.6 बोनस - अव्यक्त स्थान से अंकों का 2डी मैनिफोल्ड

यदि VAE मॉडल को z_dim =2 के साथ प्रशिक्षित किया जाता है, तो इसके अव्यक्त स्थान से अंकों का 2D मैनिफोल्ड प्रदर्शित करना संभव है। ऐसा करने के लिए, मैं dim_1 =0 और dim_2 =2 के साथ travers_two_latent_dimensions फ़ंक्शन का उपयोग करूंगा।

 vae_2d = train_model(epochs=10, z_dim=2)
 z_dist = Normal(torch.zeros(1, 2), torch.ones(1, 2)) input_sample = torch.zeros(1, 2) with torch.no_grad(): decoding = vae_2d.decoder(input_sample.to(device)) traverse_two_latent_dimensions(vae_2d, input_sample, z_dist, n_samples=20, dim_1=0, dim_2=1, z_dim=2, title=f'traversing 2D latent space') 

2D अव्यक्त स्थान