减少循环里的张量形状重排和临时对象
This commit is contained in:
@@ -236,9 +236,13 @@ class MLP(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: (B*T, D)
|
||||
x: (..., D)
|
||||
"""
|
||||
return self.net(x)
|
||||
if x.ndim <= 2:
|
||||
return self.net(x)
|
||||
|
||||
output = self.net(x.reshape(-1, x.size(-1)))
|
||||
return output.reshape(*x.shape[:-1], output.size(-1))
|
||||
|
||||
|
||||
class ARPredictor(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user