Mirror of https://github.com/karpathy/minGPT, synced 2024-11-15 19:10:39 +01:00
drop the causal mask to enable full permutation invariance of columns
This commit is contained in:
parent b189b465d3
commit 9ad506d53f
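For context on the commit message: once the causal masked_fill is gone (and assuming no column-order-dependent positional signal), plain softmax attention is permutation-equivariant over the column/token axis, which is what makes the column ordering irrelevant. A minimal, self-contained check of that property; the shapes, weights, and the attend helper below are illustrative, not this repo's code:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C = 2, 5, 8                          # batch, columns/tokens, embedding dim (illustrative)
x = torch.randn(B, T, C)
Wq, Wk, Wv = (torch.randn(C, C) for _ in range(3))

def attend(x):
    # Single-head, unmasked attention: same math as the forward pass once masked_fill is removed.
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    att = (q @ k.transpose(-2, -1)) * (1.0 / (k.size(-1) ** 0.5))
    att = F.softmax(att, dim=-1)           # no causal mask applied
    return att @ v

perm = torch.randperm(T)
out_then_permute = attend(x)[:, perm, :]   # permute the outputs
permute_then_out = attend(x[:, perm, :])   # permute the inputs
print(torch.allclose(out_then_permute, permute_then_out, atol=1e-5))  # True

With the tril mask in place this check fails, because restricting attention to j <= i hard-codes a left-to-right ordering of the columns.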
@@ -29,9 +29,7 @@ class NewGELU(nn.Module):
 
 class CausalSelfAttention(nn.Module):
     """
-    A vanilla multi-head masked self-attention layer with a projection at the end.
-    It is possible to use torch.nn.MultiheadAttention here but I am including an
-    explicit implementation here to show that there is nothing too scary here.
+    A vanilla multi-head self-attention layer with a projection at the end.
     """
 
     def __init__(self, config):
@@ -44,9 +42,6 @@ class CausalSelfAttention(nn.Module):
         # regularization
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
-        # causal mask to ensure that attention is only applied to the left in the input sequence
-        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
-                                     .view(1, 1, config.block_size, config.block_size))
         self.n_head = config.n_head
         self.n_embd = config.n_embd
 
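For reference, a small sketch of what the removed buffer contained; block_size=4 is chosen only for illustration:

import torch

block_size = 4  # illustrative stand-in for config.block_size
bias = torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size)
print(bias[0, 0])
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
# Position i may only attend to positions j <= i; deleting the buffer (and the
# masked_fill that consumed it in forward) removes that left-to-right constraint.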
@@ -59,9 +54,8 @@ class CausalSelfAttention(nn.Module):
         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
 
-        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+        # self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
-        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
         att = F.softmax(att, dim=-1)
         att = self.attn_dropout(att)
         y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
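Side note, not part of the diff: in PyTorch 2.x the simplified, unmasked path could equivalently use the fused scaled_dot_product_attention kernel with is_causal=False. A sketch under that assumption (shapes illustrative, attention dropout omitted):

import torch
import torch.nn.functional as F

B, nh, T, hs = 2, 3, 5, 8                       # illustrative shapes
q, k, v = (torch.randn(B, nh, T, hs) for _ in range(3))

# Manual unmasked attention, as in the hunk above once masked_fill is dropped.
att = (q @ k.transpose(-2, -1)) * (1.0 / (k.size(-1) ** 0.5))
y_manual = F.softmax(att, dim=-1) @ v

# Fused equivalent (PyTorch >= 2.0); is_causal=False matches the removed mask.
y_fused = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

print(torch.allclose(y_manual, y_fused, atol=1e-4))  # True up to kernel numerics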
@@ -243,10 +237,12 @@ class tabGPT(nn.Module):
 
         # create the pytorch optimizer object
         optim_groups = [
-            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
+            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.0},
             {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
         ]
         optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
+        # optimizer = torch.optim.RMSprop(optim_groups, lr=train_config.learning_rate)
+        # optimizer = torch.optim.SGD(optim_groups, lr=train_config.learning_rate)
         return optimizer
 
     def forward(self, x, targets=None):
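On the optimizer hunk: as reconstructed above, the decay group's weight_decay is switched from train_config.weight_decay to 0.0, so both parameter groups now run without decoupled weight decay and the decay/no_decay split no longer changes the update. A standalone sketch of the resulting grouping; the toy model, the split rule, and the hyperparameter values are stand-ins, not this repo's:

import torch
import torch.nn as nn

# Toy stand-in model, just to have parameters to group.
model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))
param_dict = {pn: p for pn, p in model.named_parameters()}

# Assumed minGPT-style rule: matrix weights would normally decay, biases/norms would not.
decay = {pn for pn, p in param_dict.items() if pn.endswith("weight") and p.dim() >= 2}
no_decay = set(param_dict) - decay

learning_rate, betas = 3e-4, (0.9, 0.95)   # stand-ins for train_config values

optim_groups = [
    {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": 0.0},
    {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)

# With both groups at 0.0, AdamW applies no decoupled weight decay at all;
# train_config.weight_decay is effectively unused after this change.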