diff --git a/tabgpt/model.py b/tabgpt/model.py
index 4f22141..c3605ee 100644
--- a/tabgpt/model.py
+++ b/tabgpt/model.py
@@ -29,9 +29,7 @@ class NewGELU(nn.Module):
 
 class CausalSelfAttention(nn.Module):
     """
-    A vanilla multi-head masked self-attention layer with a projection at the end.
-    It is possible to use torch.nn.MultiheadAttention here but I am including an
-    explicit implementation here to show that there is nothing too scary here.
+    A vanilla multi-head self-attention layer with a projection at the end.
     """
 
     def __init__(self, config):
@@ -44,9 +42,6 @@ class CausalSelfAttention(nn.Module):
         # regularization
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
-        # causal mask to ensure that attention is only applied to the left in the input sequence
-        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
-                                     .view(1, 1, config.block_size, config.block_size))
         self.n_head = config.n_head
         self.n_embd = config.n_embd
 
@@ -59,9 +54,8 @@ class CausalSelfAttention(nn.Module):
         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
 
-        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+        # self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
-        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
         att = F.softmax(att, dim=-1)
         att = self.attn_dropout(att)
         y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
@@ -243,10 +237,12 @@ class tabGPT(nn.Module):
 
         # create the pytorch optimizer object
         optim_groups = [
-            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
+            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": 0.0},
             {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
         ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
+        # optimizer = torch.optim.RMSprop(optim_groups, lr=train_config.learning_rate)
+        # optimizer = torch.optim.SGD(optim_groups, lr=train_config.learning_rate)
         return optimizer
 
     def forward(self, x, targets=None):
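
Note: the first three hunks drop the lower-triangular mask and the masked_fill, so attention in this layer is no longer causal; every position can attend to every other position. The last hunk zeroes weight decay for the "decay" parameter group (matching the "no_decay" group) and leaves RMSprop/SGD as commented-out alternatives to AdamW. Below is a minimal, self-contained sketch of what the attention layer computes after this patch; the class name BidirectionalSelfAttention and the standalone constructor arguments are illustrative assumptions, not the repo's actual API, which reads these fields from a config object.

# minimal sketch of the unmasked attention after this diff (illustrative, not the repo's class)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class BidirectionalSelfAttention(nn.Module):
    """Multi-head self-attention without a causal mask: every position can attend to every other."""
    def __init__(self, n_embd, n_head, attn_pdrop=0.1, resid_pdrop=0.1):
        super().__init__()
        assert n_embd % n_head == 0
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)   # joint q, k, v projection
        self.c_proj = nn.Linear(n_embd, n_embd)       # output projection
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)
        self.n_head = n_head
        self.n_embd = n_embd

    def forward(self, x):
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # (B, nh, T, T)
        att = self.attn_dropout(F.softmax(att, dim=-1))                  # no masked_fill: full attention
        y = (att @ v).transpose(1, 2).contiguous().view(B, T, C)         # re-assemble heads
        return self.resid_dropout(self.c_proj(y))

# quick shape check on a toy batch
attn = BidirectionalSelfAttention(n_embd=32, n_head=4)
print(attn(torch.randn(2, 5, 32)).shape)  # torch.Size([2, 5, 32])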