import math

import torch
import torch.nn as nn


class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of scalar positions / timesteps.

    Each input value ``t`` is mapped to ``[sin(t * f_0), ..., sin(t * f_{h-1}),
    cos(t * f_0), ..., cos(t * f_{h-1})]`` with ``h = dim // 2`` and
    geometrically spaced frequencies ``f_i = exp(-i * log(max_period) / max(h - 1, 1))``.
    When ``dim`` is odd, a trailing zero column is appended so the output
    width always equals ``dim``.

    Args:
        dim: Width of the produced embedding (last output dimension).
        max_period: Controls the lowest frequency, ``1 / max_period`` when
            ``dim >= 4``. Defaults to 10000.
    """

    def __init__(self, dim, max_period=10000):
        super().__init__()
        self.dim = dim
        self.max_period = max_period
        # Dummy buffer so .to(dtype) / .half() / .double() on the module update
        # its dtype; forward() casts the output to match. persistent=False
        # keeps it out of state_dict.
        self.register_buffer('_dtype_buf', torch.zeros(1), persistent=False)

    def forward(self, x):
        """Embed ``x``.

        Args:
            x: Tensor of positions of any shape (commonly 1-D ``(batch,)``).

        Returns:
            Tensor of shape ``x.shape + (dim,)`` on ``x.device``, with the
            dtype of the module's ``_dtype_buf`` buffer.
        """
        half_dim = self.dim // 2
        # max(..., 1) avoids ZeroDivisionError when half_dim == 1 (dim 2 or 3).
        step = math.log(self.max_period) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=x.device) * -step)
        # [..., None] broadcasts every input element against all frequencies.
        args = x.float()[..., None] * freqs
        emb = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2:
            # sin/cos halves only fill 2 * (dim // 2) columns; pad to dim.
            emb = nn.functional.pad(emb, (0, 1))
        return emb.to(self._dtype_buf.dtype)