打印推理权重精度信息

This commit is contained in:
2026-01-18 11:19:10 +08:00
parent c86c2be5ff
commit 7b499284bf
9 changed files with 256 additions and 143 deletions

View File

@@ -334,6 +334,15 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
@autocast
def forward(self, image, no_dropout=False):
if not hasattr(self, "_printed_autocast_info"):
print(
">>> 图像编码 autocast:",
torch.is_autocast_enabled(),
torch.get_autocast_gpu_dtype(),
"输入dtype:",
image.dtype,
)
self._printed_autocast_info = True
z = self.encode_with_vision_transformer(image)
if self.ucg_rate > 0. and not no_dropout:
z = torch.bernoulli(
@@ -407,6 +416,15 @@ class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder):
def forward(self, image, no_dropout=False):
## image: b c h w
if not hasattr(self, "_printed_autocast_info"):
print(
">>> 图像编码V2 autocast:",
torch.is_autocast_enabled(),
torch.get_autocast_gpu_dtype(),
"输入dtype:",
image.dtype,
)
self._printed_autocast_info = True
z = self.encode_with_vision_transformer(image)
return z