mirror of https://github.com/karpathy/minGPT synced 2024-11-15 19:10:39 +01:00

drop the causal mask to enable full permutation invariance of columns

Felix Wick 2024-09-12 21:22:16 +02:00
parent b189b465d3
commit 9ad506d53f
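The property named in the commit message can be illustrated outside the diff. Below is a minimal sketch (illustrative shapes, single head, no projections or positional information; not code from this repository) showing that unmasked scaled dot-product attention is permutation-equivariant over sequence positions, i.e. the table columns: permuting the inputs just permutes the outputs the same way, which a causal mask would break.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C = 2, 5, 8  # batch, sequence length (columns), embedding dim -- illustrative values
x = torch.randn(B, T, C)

def full_attention(x):
    # single head, no projections: every position attends to every other position
    att = (x @ x.transpose(-2, -1)) * (1.0 / C ** 0.5)
    att = F.softmax(att, dim=-1)
    return att @ x

perm = torch.randperm(T)
out = full_attention(x)
out_perm = full_attention(x[:, perm, :])

# without a causal mask, permuting the columns only permutes the outputs the same way
print(torch.allclose(out[:, perm, :], out_perm, atol=1e-6))  # expected: True

With the lower-triangular mask still applied, the same check fails, because position t would only be allowed to attend to positions <= t and the result would depend on column order.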

@@ -29,9 +29,7 @@ class NewGELU(nn.Module):
 class CausalSelfAttention(nn.Module):
     """
-    A vanilla multi-head masked self-attention layer with a projection at the end.
-    It is possible to use torch.nn.MultiheadAttention here but I am including an
-    explicit implementation here to show that there is nothing too scary here.
+    A vanilla multi-head self-attention layer with a projection at the end.
     """

     def __init__(self, config):
@@ -44,9 +42,6 @@ class CausalSelfAttention(nn.Module):
         # regularization
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
-        # causal mask to ensure that attention is only applied to the left in the input sequence
-        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
-                                     .view(1, 1, config.block_size, config.block_size))
         self.n_head = config.n_head
         self.n_embd = config.n_embd
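For reference, a small sketch (block_size chosen arbitrarily for illustration) of what the dropped "bias" buffer held: a lower-triangular matrix, broadcast over batch and heads, that restricted position t to attending only to positions <= t.

import torch

block_size = 4  # illustrative value, not taken from this repository
mask = torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)
print(mask[0, 0])
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
# row t has ones only up to column t; scores at the zero entries were filled with -inf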
@@ -59,9 +54,8 @@ class CausalSelfAttention(nn.Module):
         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

-        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+        # self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
-        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
         att = F.softmax(att, dim=-1)
         att = self.attn_dropout(att)
         y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
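A cross-check of the edited attention path (shapes and seed are illustrative, dropout omitted so the comparison is deterministic): without the masked_fill, the hand-rolled attention matches PyTorch's built-in F.scaled_dot_product_attention with its default is_causal=False, i.e. plain bidirectional attention.

import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, nh, T, hs = 2, 3, 5, 8  # batch, heads, sequence length, head size -- illustrative values
q, k, v = torch.randn(3, B, nh, T, hs).unbind(0)

# the edited path: scores, softmax, weighted sum, with no masked_fill in between
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = F.softmax(att, dim=-1)
y_manual = att @ v

# reference: built-in scaled dot-product attention (requires PyTorch >= 2.0); is_causal defaults to False
y_ref = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(y_manual, y_ref, atol=1e-5))  # expected: True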
@@ -243,10 +237,12 @@ class tabGPT(nn.Module):

         # create the pytorch optimizer object
         optim_groups = [
-            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
+            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.0},
             {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
         ]
         optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
+        # optimizer = torch.optim.RMSprop(optim_groups, lr=train_config.learning_rate)
+        # optimizer = torch.optim.SGD(optim_groups, lr=train_config.learning_rate)
         return optimizer

     def forward(self, x, targets=None):
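For reference, a self-contained sketch (parameter tensors, learning rate, and betas are placeholders, not values from this repository) of the optimizer setup the hunk arrives at: AdamW reads weight_decay per parameter group, so setting the decay group to 0.0 disables weight decay for those parameters as well, and the commented-out lines keep RMSprop/SGD as drop-in alternatives since all three accept the same param-group list.

import torch

# tiny stand-ins for the "decay" (weights) and "no_decay" (biases, LayerNorm) groups
w_decay = torch.nn.Parameter(torch.randn(4, 4))
w_no_decay = torch.nn.Parameter(torch.randn(4))

optim_groups = [
    {"params": [w_decay], "weight_decay": 0.0},     # decay group, weight decay now disabled
    {"params": [w_no_decay], "weight_decay": 0.0},  # no_decay group, unchanged
]
optimizer = torch.optim.AdamW(optim_groups, lr=3e-4, betas=(0.9, 0.95))  # placeholder hyperparameters

# the commented-out alternatives in the hunk would slot in the same way:
# optimizer = torch.optim.RMSprop(optim_groups, lr=3e-4)
# optimizer = torch.optim.SGD(optim_groups, lr=3e-4)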