打印推理权重精度信息
This commit is contained in:
@@ -334,6 +334,15 @@ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
|
||||
|
||||
@autocast
|
||||
def forward(self, image, no_dropout=False):
|
||||
if not hasattr(self, "_printed_autocast_info"):
|
||||
print(
|
||||
">>> 图像编码 autocast:",
|
||||
torch.is_autocast_enabled(),
|
||||
torch.get_autocast_gpu_dtype(),
|
||||
"输入dtype:",
|
||||
image.dtype,
|
||||
)
|
||||
self._printed_autocast_info = True
|
||||
z = self.encode_with_vision_transformer(image)
|
||||
if self.ucg_rate > 0. and not no_dropout:
|
||||
z = torch.bernoulli(
|
||||
@@ -407,6 +416,15 @@ class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder):
|
||||
|
||||
def forward(self, image, no_dropout=False):
|
||||
## image: b c h w
|
||||
if not hasattr(self, "_printed_autocast_info"):
|
||||
print(
|
||||
">>> 图像编码V2 autocast:",
|
||||
torch.is_autocast_enabled(),
|
||||
torch.get_autocast_gpu_dtype(),
|
||||
"输入dtype:",
|
||||
image.dtype,
|
||||
)
|
||||
self._printed_autocast_info = True
|
||||
z = self.encode_with_vision_transformer(image)
|
||||
return z
|
||||
|
||||
|
||||
Reference in New Issue
Block a user