# aaaabbbb
```python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import trange
```
The procedurally generated environment consists of a small set of \(K\) input symbols and a set of \(K\) output symbols. Each "alphabet" fixes a random mapping from input symbols to output symbols; the meta-learner has to remember the input-output mappings it observes within a sequence, and it is expected to generalize over the sets of symbols, i.e. to alphabets it never saw during training.
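To make the setup concrete, here is a minimal sketch of a single task instance, assuming the embedding-table construction used below; the names `enc`, `dec` and the symbol index are illustrative, not part of the original code.

```python
import torch
import torch.nn as nn

K = 8                      # symbols per alphabet, as in the setting below
enc = nn.Embedding(K, 16)  # one alphabet's random input-symbol vectors
dec = nn.Embedding(K, 32)  # the same alphabet's random output-symbol vectors

s = torch.tensor([3])      # a symbol index
x, y = enc(s), dec(s)      # observation and target for that symbol
# The learner sees x (plus, autoregressively, earlier targets) and must predict y.
# A fresh alphabet means fresh random tables, so y is unpredictable until the
# symbol has appeared at least once earlier in the current sequence.
```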
```python
I, O = 16, 32                 # input / output embedding dimensions
K = 8                         # symbols per alphabet
T = 24                        # sequence length
Nbatch = 31                   # sequences per batch
Nalphabets_train, Nalphabets_valid = 10000, 100

train_inputs = [nn.Embedding(K, I) for _ in range(Nalphabets_train)]
train_outputs = [nn.Embedding(K, O) for _ in range(Nalphabets_train)]

valid_inputs = [nn.Embedding(K, I) for _ in range(Nalphabets_valid)]
valid_outputs = [nn.Embedding(K, O) for _ in range(Nalphabets_valid)]

pad = torch.zeros(Nbatch, O)  # zero "previous target" for the first timestep
```
```python
def make_sequences(inputs, generator=None):
    # One random alphabet per sequence, then T random symbols from it.
    alphabets = torch.randint(0, len(inputs), (Nbatch,), generator=generator)
    sequences = torch.randint(0, K, (Nbatch, T), generator=generator)
    return alphabets, sequences
```
```python
@torch.inference_mode()  # the embedding lookups are fixed data here, not trainable
def make_batch(inputs, outputs, generator=None):
    alphabets, sequences = make_sequences(inputs, generator)

    # (T, Nbatch, I): input-symbol embeddings, one alphabet per sequence.
    all_inputs = torch.stack([
        torch.stack([inputs[i](s) for s in seq])
        for i, seq in zip(alphabets, sequences)
    ]).transpose(0, 1)
    # (T+1, Nbatch, O): output-symbol embeddings, with a leading zero pad.
    all_targets = torch.stack([
        torch.stack([pad[0]] + [outputs[i](s) for s in seq])
        for i, seq in zip(alphabets, sequences)
    ]).transpose(0, 1)

    # Autoregressive input carries the previous ground-truth target (teacher
    # forcing); the non-autoregressive variant gets zeros in that channel.
    nar_x = torch.cat([all_inputs, torch.zeros_like(all_targets[:-1])], dim=-1)
    x = torch.cat([all_inputs, all_targets[:-1]], dim=-1)
    return x, nar_x, all_targets[1:]
```
```python
# magic sequence: the first validation sequence, fixed by seeding the generator
make_sequences(valid_inputs, generator=torch.Generator().manual_seed(0))[1][0]
```
```
tensor([4, 3, 0, 3, 5, 6, 7, 7, 0, 2, 3, 0, 1, 3, 5, 3, 3, 6, 7, 0, 1, 1, 1, 7])
```
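The point of fixing this sequence: an in-context learner can only predict a symbol's output vector after that symbol has already appeared. A small sketch (the `first_occurrences` helper is mine, not from the original) that marks which positions are first sightings versus repeats:

```python
def first_occurrences(seq):
    """Boolean list: True where a symbol appears for the first time."""
    seen, marks = set(), []
    for s in seq.tolist():
        marks.append(s not in seen)
        seen.add(s)
    return marks

seq = make_sequences(valid_inputs, generator=torch.Generator().manual_seed(0))[1][0]
print(first_occurrences(seq))
```

Positions marked `False` are repeats; those are exactly where the autoregressive model's stepwise loss should be able to drop below chance in the plots below.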
```python
class Model(nn.Module):
    def __init__(self, hidden=512):
        super().__init__()
        self.readin = nn.Linear(I + O, hidden)
        self.readout = nn.Linear(hidden, O)
        self.hidden = hidden
        self.rnn = nn.LSTM(hidden, hidden, num_layers=1)

    def forward(self, x):
        # x: (T, Nbatch, I+O) -> predicted output embeddings (T, Nbatch, O)
        return self.readout(self.rnn(self.readin(x))[0])
```
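A forward-pass sanity check (a sketch, reusing the batch shapes established above):

```python
m = Model()
demo = torch.zeros(T, Nbatch, I + O)  # dummy time-major input
print(m(demo).shape)                  # expect torch.Size([24, 31, 32])
```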
```python
device = 'cuda:1'
model = Model().to(device)
nar_model = Model().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
nar_optimizer = torch.optim.Adam(nar_model.parameters(), lr=0.001)
```
```python
Nsteps = 10000
train_losses = torch.zeros(Nsteps)
valid_losses = torch.zeros(Nsteps)
valid_stepwise_losses = torch.zeros(Nsteps, T, Nbatch)

nar_train_losses = torch.zeros(Nsteps)
nar_valid_losses = torch.zeros(Nsteps)
nar_valid_stepwise_losses = torch.zeros(Nsteps, T, Nbatch)
```
```python
with trange(Nsteps) as t:
    for step in t:
        optimizer.zero_grad()
        nar_optimizer.zero_grad()

        x, nar_x, y_targets = make_batch(train_inputs, train_outputs)

        y_targets = y_targets.to(device)
        y = model(x.to(device))

        loss = (y - y_targets.detach()).pow(2).mean()

        loss.backward()
        optimizer.step()
        train_losses[step] = loss.item()

        nar_y = nar_model(nar_x.to(device))

        nar_loss = (nar_y - y_targets.detach()).pow(2).mean()

        nar_loss.backward()
        nar_optimizer.step()
        nar_train_losses[step] = nar_loss.item()

        t.set_postfix(loss=loss.item(), nar_loss=nar_loss.item())

        # Validate on a fixed, seeded batch of held-out alphabets every step.
        with torch.inference_mode():
            sequences = make_sequences(valid_inputs, generator=torch.Generator().manual_seed(0))[1]
            x_valid, nar_x_valid, y_targets_valid = make_batch(valid_inputs, valid_outputs, generator=torch.Generator().manual_seed(0))
            y_targets_valid = y_targets_valid.to(device)
            y_valid = model(x_valid.to(device))
            nar_y_valid = nar_model(nar_x_valid.to(device))

            loss_valid = (y_valid - y_targets_valid).pow(2).mean()
            nar_loss_valid = (nar_y_valid - y_targets_valid).pow(2).mean()

            loss_valid_stepwise = (y_valid - y_targets_valid).pow(2).mean(dim=-1)
            valid_losses[step] = loss_valid.item()
            valid_stepwise_losses[step, :, :] = loss_valid_stepwise.cpu()

            nar_loss_valid_stepwise = (nar_y_valid - y_targets_valid).pow(2).mean(dim=-1)
            nar_valid_losses[step] = nar_loss_valid.item()
            nar_valid_stepwise_losses[step, :, :] = nar_loss_valid_stepwise.cpu()

        if step and step % 100 == 0:
            fig, (axl, axc, axr) = plt.subplots(1, 3, figsize=(16, 4))

            # Left panel: loss curves (solid: train, dashed overlay: non-autoregressive).
            axl.plot(train_losses[:step], color='blue')
            axl.plot(valid_losses[:step], color='orange')
            axl.plot(nar_train_losses[:step], linestyle='--', color='blue')
            axl.plot(nar_valid_losses[:step], linestyle='--', color='orange')

            # Center/right panels: per-position validation loss for the first two
            # sequences, annotated with the symbol at each position.
            axc.plot(torch.arange(T), valid_stepwise_losses[step, :, 0], color='blue', alpha=0.5, marker='x')
            for (i, j, text) in zip(torch.arange(T), valid_stepwise_losses[step, :, 0], sequences[0]):
                axc.text(i, j, f'{text}', color='blue')
            axc.plot(torch.arange(T), nar_valid_stepwise_losses[step, :, 0], color='blue', alpha=0.2, marker='x', linestyle='--')
            axc.set_title(''.join([str(x) for x in make_sequences(valid_inputs, generator=torch.Generator().manual_seed(0))[1][0].tolist()]))

            axr.plot(torch.arange(T), valid_stepwise_losses[step, :, 1], color='orange', alpha=0.5, marker='x')
            for (i, j, text) in zip(torch.arange(T), valid_stepwise_losses[step, :, 1], sequences[1]):
                axr.text(i, j, f'{text}', color='orange')
            axr.plot(torch.arange(T), nar_valid_stepwise_losses[step, :, 1], color='orange', alpha=0.2, marker='x', linestyle='--')
            axr.set_title(''.join([str(x) for x in make_sequences(valid_inputs, generator=torch.Generator().manual_seed(0))[1][1].tolist()]))
            plt.show()
```
Training progress (abridged): the autoregressive model's loss falls steadily while the non-autoregressive baseline stays near 1.0.

```
  1%|          | 100/10000   [00:03<04:56, 33.39it/s, loss=0.95,  nar_loss=1.01]
 50%|████▉     | 4999/10000  [02:55<02:34, 32.35it/s, loss=0.602, nar_loss=0.955]
100%|██████████| 10000/10000 [06:04<00:00, 27.47it/s, loss=0.586, nar_loss=0.964]
```
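To quantify the gap beyond the progress bar, one can compare stepwise validation losses at first occurrences of a symbol against repeats: if the autoregressive model is doing in-context learning, only the repeats should be cheap. A hedged sketch using the tensors recorded above and the `first_occurrences` helper from earlier (the masking logic is mine, not from the original):

```python
# Split the final stepwise losses into first-occurrence vs. repeat positions.
seqs = make_sequences(valid_inputs, generator=torch.Generator().manual_seed(0))[1]
mask = torch.tensor([first_occurrences(s) for s in seqs]).T  # (T, Nbatch), True = first sighting

last, nar_last = valid_stepwise_losses[-1], nar_valid_stepwise_losses[-1]  # (T, Nbatch)

print('AR  first/repeat:', last[mask].mean().item(), last[~mask].mean().item())
print('NAR first/repeat:', nar_last[mask].mean().item(), nar_last[~mask].mean().item())
```

Under this reading, the final numbers (loss=0.586 vs. nar_loss=0.964) are consistent with the autoregressive model recalling mappings it saw earlier in the sequence, while the non-autoregressive baseline has nothing to recall from and stays near chance.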