第一次完整测例跑完

This commit is contained in:
2026-01-18 00:30:10 +08:00
parent ca15cc593b
commit 25c6fc04db
180 changed files with 29305 additions and 0 deletions

View File

@@ -0,0 +1,244 @@
"""
base_vision.py
Abstract class definition of a Vision Backbone (Visual Featurizer), with full annotations of class methods, utility
functions, and initialization logic.
We also define the generic TimmViTBackbone class here, providing a default interface for loading any TIMM Vision
Transformer model for feature extraction.
"""
import timm
import torch
import torch.nn as nn
import torchvision.transforms.functional as TVF
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
from PIL.Image import Image
from timm.models.vision_transformer import Block, VisionTransformer
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
from torchvision.transforms import Compose, Resize
# === Utility Functions for Monkey-Patching ===
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
def wrapper(*args: Any, **kwargs: Any) -> Any:
result = fn(*args, **kwargs)
return result[0] if isinstance(result, tuple) else result
return wrapper
# === Interface for an Image Transform ===
class ImageTransform(Protocol):
def __call__(
self, img: Image,
**kwargs: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
...
# === Custom Torchvision Image Transforms ===
@dataclass
class LetterboxPad:
padding_fill_value: Tuple[int, int, int]
def __call__(self, image: Image) -> Image:
"""Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
(w, h), max_wh = image.size, max(image.size)
horizontal_pad, vertical_pad = int((max_wh - w) / 2), int(
(max_wh - h) / 2)
padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
return TVF.pad(image,
padding,
fill=self.padding_fill_value,
padding_mode="constant")
# === Abstract Base Class for arbitrary Vision Backbones ===
class VisionBackbone(nn.Module, ABC):
def __init__(self,
vision_backbone_id: str,
image_resize_strategy: str,
default_image_size: int = 224) -> None:
super().__init__()
self.identifier: str = vision_backbone_id
self.image_resize_strategy: str = image_resize_strategy
self.default_image_size: int = default_image_size
# Instance attributes for a Vision Backbone
self.featurizer: nn.Module = None
self.image_transform: ImageTransform = None
def get_image_transform(self) -> ImageTransform:
return self.image_transform
@abstractmethod
def get_fsdp_wrapping_policy(self) -> Callable:
...
@abstractmethod
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""Run a forward pass through the featurizer given a set of processed images, returning patch/grid features."""
raise NotImplementedError
@property
@abstractmethod
def default_image_resolution(self) -> Tuple[int, int, int]:
...
@property
@abstractmethod
def embed_dim(self) -> int:
...
@property
@abstractmethod
def num_patches(self) -> int:
...
@property
@abstractmethod
def half_precision_dtype(self) -> torch.dtype:
...
# === Abstract Base Class for Arbitrary TIMM Vision Transformer Backbones ===
class TimmViTBackbone(VisionBackbone, ABC):
def __init__(
self,
vision_backbone_id: str,
timm_path_or_url: str,
image_resize_strategy: str,
default_image_size: int = 224,
override_act_layer: Optional[str] = None,
) -> None:
super().__init__(vision_backbone_id,
image_resize_strategy,
default_image_size=default_image_size)
self.timm_path_or_url = timm_path_or_url
self.override_act_layer = override_act_layer
self.dtype = torch.bfloat16
# Initialize Featurizer (ViT) by downloading from HF / TIMM Hub if necessary
if self.override_act_layer is None:
self.featurizer: VisionTransformer = timm.create_model(
self.timm_path_or_url,
pretrained=True,
num_classes=0,
img_size=self.default_image_size)
else:
self.featurizer: VisionTransformer = timm.create_model(
self.timm_path_or_url,
pretrained=True,
num_classes=0,
img_size=self.default_image_size,
act_layer=self.override_act_layer,
)
self.featurizer.eval()
# Monkey-Patch the `forward()` function of the featurizer to ensure FSDP-compatibility
# => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
# => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
self.featurizer.forward = unpack_tuple(
partial(self.featurizer.get_intermediate_layers,
n={len(self.featurizer.blocks) - 2}))
# Validation =>> for now, this class *only* supports TIMM Vision Transformers (but can be extended!)
assert isinstance(self.featurizer, VisionTransformer), (
"Featurizer is not a TIMM VisionTransformer; if you would like to support a new visual representation, "
"file an issue or implement the requisite logic (see `prismatic/models/backbones/vision/base_vision.py`)!"
)
# Get Config =>> Note :: Override default image size to ensure correct image transform
self.data_cfg = timm.data.resolve_model_data_config(self.featurizer)
self.data_cfg["input_size"] = (3, self.default_image_size,
self.default_image_size)
# Initialize Default Image Transform --> Modified by `self.image_resize_strategy`
default_image_transform = timm.data.create_transform(**self.data_cfg,
is_training=False)
# Fix =>> SigLIP & IN1K default transforms resize to *larger* than `self.default_image_size` (crops image)!
if "siglip" in self.timm_path_or_url or "in1k" in self.timm_path_or_url:
assert isinstance(default_image_transform,
Compose), "Unexpected `default_image_transform`!"
assert isinstance(default_image_transform.transforms[0], Resize)
default_image_transform = Compose([
Resize(self.default_image_size,
interpolation=default_image_transform.transforms[0].
interpolation),
*default_image_transform.transforms[1:],
])
# Switch on `image_resize_strategy`
if self.image_resize_strategy == "resize-naive":
assert isinstance(default_image_transform,
Compose), "Unexpected `default_image_transform`!"
assert isinstance(default_image_transform.transforms[0], Resize)
target_size = (self.default_image_size, self.default_image_size)
self.image_transform = Compose([
Resize(target_size,
interpolation=default_image_transform.transforms[0].
interpolation),
*default_image_transform.transforms[1:],
])
elif self.image_resize_strategy == "resize-crop":
self.image_transform = default_image_transform
elif self.image_resize_strategy == "letterbox":
assert isinstance(default_image_transform,
Compose), "Unexpected `default_image_transform`!"
assert "mean" in self.data_cfg, "TIMM `data_cfg` missing image normalization mean!"
# Compute Padding Fill Value (rescaled normalization mean if applicable)
fill = tuple([int(x * 255) for x in self.data_cfg["mean"]])
# Build New Transform
self.image_transform = Compose(
[LetterboxPad(fill), *default_image_transform.transforms])
else:
raise ValueError(
f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!"
)
def get_fsdp_wrapping_policy(self) -> Callable:
"""Return a simple FSDP policy that wraps each ViT block and then the _entire_ featurizer."""
vit_wrap_policy = partial(_module_wrap_policy,
module_classes={VisionTransformer})
transformer_block_policy = partial(transformer_auto_wrap_policy,
transformer_layer_cls={Block})
return partial(_or_policy,
policies=[vit_wrap_policy, transformer_block_policy])
def forward(
self, pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]]
) -> torch.Tensor:
"""Runs transformed image/pixel tensor through vision backbone, returning _all_ patch features."""
return self.featurizer(pixel_values)
@property
def default_image_resolution(self) -> Tuple[int, int, int]:
return self.data_cfg["input_size"]
@property
def embed_dim(self) -> int:
return self.featurizer.embed_dim
@property
def num_patches(self) -> int:
return self.featurizer.patch_embed.num_patches
@property
def half_precision_dtype(self) -> torch.dtype:
return self.dtype

View File

@@ -0,0 +1,273 @@
"""
dinosiglip_vit.py
Vision backbone that returns concatenated features from both DINOv2 and SigLIP.
"""
import timm
import torch
import torchvision.transforms as transforms
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, Tuple
from PIL import Image
from timm.models.vision_transformer import Block, VisionTransformer
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
from torchvision.transforms import Compose, Resize, Normalize
from unifolm_wma.modules.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple
from unifolm_wma.utils.nn_utils import FusedMLPProjector, LinearProjector, MLPProjector
# Registry =>> Supported DinoSigLIP Pairs (as TIMM identifiers)
DINOSigLIP_VISION_BACKBONES = {
"dinosiglip-vit-so-224px": {
"dino": "vit_large_patch14_reg4_dinov2.lvd142m",
"siglip": "vit_so400m_patch14_siglip_224",
},
"dinosiglip-vit-so-384px": {
"dino": "vit_large_patch14_reg4_dinov2.lvd142m",
"siglip": "vit_so400m_patch14_siglip_384",
},
}
@dataclass
class DinoSigLIPImageTransform:
dino_image_transform: ImageTransform
siglip_image_transform: ImageTransform
is_prismatic: bool = True
def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]:
return {
"dino": self.dino_image_transform(img, **kwargs),
"siglip": self.siglip_image_transform(img, **kwargs)
}
class DinoSigLIPViTBackbone(VisionBackbone):
def __init__(self,
vision_backbone_id: str,
image_resize_strategy: str,
arch_specifier: str,
output_dim: int,
pretrained_checkpoint=None,
freeze=True,
default_image_size: int = 224) -> None:
super().__init__(vision_backbone_id,
image_resize_strategy,
default_image_size=default_image_size)
self.dino_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[
vision_backbone_id]["dino"]
self.siglip_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[
vision_backbone_id]["siglip"]
# Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary
self.dino_featurizer: VisionTransformer = timm.create_model(
self.dino_timm_path_or_url,
pretrained=True,
num_classes=0,
img_size=self.default_image_size)
if pretrained_checkpoint:
ckpt = pretrained_checkpoint + '/openvla_dino.pt'
self.dino_featurizer.load_state_dict(
torch.load(ckpt, weights_only=True))
print('>>> load dino weights')
if freeze:
self.dino_featurizer.eval()
for param in self.dino_featurizer.parameters():
param.requires_grad = False
self.siglip_featurizer: VisionTransformer = timm.create_model(
self.siglip_timm_path_or_url,
pretrained=True,
num_classes=0,
img_size=self.default_image_size)
if pretrained_checkpoint:
ckpt = pretrained_checkpoint + '/openvla_siglip.pt'
self.siglip_featurizer.load_state_dict(
torch.load(ckpt, weights_only=True))
print('>>> load siglip weights')
if freeze:
self.siglip_featurizer.eval()
for param in self.siglip_featurizer.parameters():
param.requires_grad = False
# Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility
# => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
# => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
self.dino_featurizer.forward = unpack_tuple(
partial(self.dino_featurizer.get_intermediate_layers,
n={len(self.dino_featurizer.blocks) - 2}))
self.siglip_featurizer.forward = unpack_tuple(
partial(self.siglip_featurizer.get_intermediate_layers,
n={len(self.siglip_featurizer.blocks) - 2}))
# Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models
self.dino_data_cfg = timm.data.resolve_model_data_config(
self.dino_featurizer)
self.dino_data_cfg["input_size"] = (3, self.default_image_size,
self.default_image_size)
self.siglip_data_cfg = timm.data.resolve_model_data_config(
self.siglip_featurizer)
self.siglip_data_cfg["input_size"] = (3, self.default_image_size,
self.default_image_size)
# Initialize *both* Transforms
self.default_dino_transform = timm.data.create_transform(
**self.dino_data_cfg, is_training=False)
self.default_siglip_transform = timm.data.create_transform(
**self.siglip_data_cfg, is_training=False)
# Fix =>> SigLIP default transform resizes to *larger* than `self.default_image_size` (crops image)!!
assert isinstance(self.default_siglip_transform,
Compose), "Unexpected `default_image_transform`!"
assert isinstance(self.default_siglip_transform.transforms[0], Resize)
self.default_siglip_transform = Compose([
Resize(self.default_image_size,
interpolation=self.default_siglip_transform.transforms[0].
interpolation),
*self.default_siglip_transform.transforms[1:],
])
if self.image_resize_strategy == "resize-naive":
assert isinstance(
self.default_dino_transform,
Compose), "Unexpected `default_dino_image_transform`!"
assert isinstance(
self.default_siglip_transform,
Compose), "Unexpected `default_siglip_image_transform`!"
assert isinstance(self.default_dino_transform.transforms[0],
Resize)
assert isinstance(self.default_siglip_transform.transforms[0],
Resize)
self.target_size = (self.default_image_size,
self.default_image_size)
dino_transform = Compose([
Resize(self.target_size,
interpolation=self.default_dino_transform.transforms[0].
interpolation),
*self.default_dino_transform.transforms[1:],
])
siglip_transform = Compose([
Resize(self.target_size,
interpolation=self.default_siglip_transform.
transforms[0].interpolation),
*self.default_siglip_transform.transforms[1:],
])
self.image_transform = DinoSigLIPImageTransform(
dino_transform, siglip_transform)
elif self.image_resize_strategy == "resize-crop":
self.image_transform = DinoSigLIPImageTransform(
self.default_dino_transform, self.default_siglip_transform)
elif self.image_resize_strategy == "letterbox":
assert isinstance(self.default_dino_transform,
Compose), "Unexpected `default_dino_transform`!"
assert isinstance(
self.default_siglip_transform,
Compose), "Unexpected `default_siglip_transform`!"
assert ("mean" in self.dino_data_cfg
and "mean" in self.siglip_data_cfg
), "DinoSigLIP `data_cfg` missing `mean`!"
# Compute Padding Fill Value(s) (rescaled normalization mean if applicable)
dino_fill = tuple(
[int(x * 255) for x in self.dino_data_cfg["mean"]])
siglip_fill = tuple(
[int(x * 255) for x in self.siglip_data_cfg["mean"]])
# Build New Transform
self.image_transform = DinoSigLIPImageTransform(
Compose([
LetterboxPad(dino_fill),
*self.default_dino_transform.transforms
]),
Compose([
LetterboxPad(siglip_fill),
*self.default_siglip_transform.transforms
]),
)
else:
raise ValueError(
f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!"
)
self.arch_specifier = arch_specifier
if arch_specifier == "linear":
self.projector = LinearProjector(self.embed_dim, output_dim)
elif arch_specifier.endswith("fused-gelu-mlp"):
self.projector = FusedMLPProjector(self.embed_dim, output_dim)
elif arch_specifier.endswith("gelu-mlp"):
self.projector = MLPProjector(self.embed_dim, output_dim)
else:
raise ValueError(
f"PrismaticVLM with `{arch_specifier = }` is not supported!")
self.on_gpu = False
def get_fsdp_wrapping_policy(self) -> Callable:
"""Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers."""
vit_wrap_policy = partial(_module_wrap_policy,
module_classes={VisionTransformer})
transformer_block_policy = partial(transformer_auto_wrap_policy,
transformer_layer_cls={Block})
return partial(_or_policy,
policies=[vit_wrap_policy, transformer_block_policy])
def forward(self, img) -> torch.Tensor:
img = torch.clamp(img.float(), -1., 1.)
img = (img + 1.0) / 2.0
img = img * 255
resize = transforms.Resize(min(self.target_size),
interpolation=self.default_dino_transform.
transforms[0].interpolation,
max_size=None,
antialias=True)
center_crop = transforms.CenterCrop(self.target_size)
img = center_crop(resize(img))
dino_normalizer = Normalize(mean=torch.tensor([0.4850, 0.4560,
0.4060]),
std=torch.tensor([0.2290, 0.2240, 0.2250]))
siglip_normalizer = Normalize(
mean=torch.tensor([0.5000, 0.5000, 0.5000]),
std=torch.tensor([0.5000, 0.5000, 0.5000]))
pixel_values = {
'dino': dino_normalizer(img),
'siglip': siglip_normalizer(img)
}
if self.on_gpu:
pixel_values = {k: v.cuda() for k, v in pixel_values.items()}
elif next(self.dino_featurizer.parameters()).device.type != 'cpu':
self.on_gpu = True
"""Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches."""
dino_patches = self.dino_featurizer(pixel_values["dino"])
siglip_patches = self.siglip_featurizer(pixel_values["siglip"])
return self.projector(torch.cat([dino_patches, siglip_patches], dim=2))
@property
def default_image_resolution(self) -> Tuple[int, int, int]:
return self.dino_data_cfg["input_size"]
@property
def embed_dim(self) -> int:
return self.dino_featurizer.embed_dim + self.siglip_featurizer.embed_dim
@property
def num_patches(self) -> int:
assert self.dino_featurizer.patch_embed.num_patches == self.siglip_featurizer.patch_embed.num_patches
return self.dino_featurizer.patch_embed.num_patches
@property
def half_precision_dtype(self) -> torch.dtype:
return torch.bfloat16