减少循环里的张量形状重排和临时对象

This commit is contained in:
qihuanye
2026-04-09 10:14:58 +00:00
parent 3a94829eac
commit 006102d00c
2 changed files with 14 additions and 12 deletions

View File

@@ -236,9 +236,13 @@ class MLP(nn.Module):
def forward(self, x):
"""
x: (B*T, D)
x: (..., D)
"""
return self.net(x)
if x.ndim <= 2:
return self.net(x)
output = self.net(x.reshape(-1, x.size(-1)))
return output.reshape(*x.shape[:-1], output.size(-1))
class ARPredictor(nn.Module):