第一次完整测例跑完

This commit is contained in:
2026-01-18 00:30:10 +08:00
parent ca15cc593b
commit 25c6fc04db
180 changed files with 29305 additions and 0 deletions

View File

@@ -0,0 +1,26 @@
from abc import abstractmethod
from torch.utils.data import IterableDataset
class Txt2ImgIterableBaseDataset(IterableDataset):
    """Common interface for chainable text-to-image ``IterableDataset``s.

    Subclasses must implement ``__iter__``; this base class only stores the
    bookkeeping attributes shared by all text2img iterable datasets.
    """

    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        # Total number of examples reported by __len__.
        self.num_records = num_records
        self.valid_ids = valid_ids
        # sample_ids starts out identical to valid_ids; subclasses may reshuffle it.
        self.sample_ids = valid_ids
        self.size = size
        msg = f'{self.__class__.__name__} dataset contains {self.__len__()} examples.'
        print(msg)

    def __len__(self):
        """Return the declared number of records (not the length of valid_ids)."""
        return self.num_records

    @abstractmethod
    def __iter__(self):
        pass

View File

@@ -0,0 +1,230 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor, nn
from typing import Dict, List
def create_stats_buffers(
    shapes: Dict[str, List[int]],
    modes: Dict[str, str],
    stats: Dict[str, Dict[str, Tensor]] = None,
) -> Dict[str, Dict[str, nn.ParameterDict]]:
    """
    Create buffers per modality (e.g. "observation.image", "action") containing their mean, std,
    min, max statistics.

    Args:
        shapes: mapping from modality key to the shape of that modality's feature tensor.
        modes: mapping from modality key to its normalization mode ("mean_std" or "min_max").
        stats: optional mapping from modality key to its statistics tensors; when provided,
            the freshly created (infinite) buffers are overwritten with clones of these values.

    Returns:
        Dict: A dictionary where keys are modalities and values are `nn.ParameterDict` containing
        `nn.Parameter`s set to `requires_grad=False`, suitable to not be updated during
        backpropagation.
    """
    stats_buffers = {}
    for key, mode in modes.items():
        assert mode in ["mean_std", "min_max"]
        shape = tuple(shapes[key])

        if "image" in key:
            # Sanity checks: images must be channel-first (c, h, w).
            # (Fixed: the original assert message was missing its closing parenthesis.)
            assert len(shape) == 3, f"number of dimensions of {key} != 3 ({shape=})"
            c, h, w = shape
            assert c < h and c < w, f"{key} is not channel first ({shape=})"
            # Override image shape to be invariant to height and width.
            shape = (c, 1, 1)

        # Keys containing "action"/"state" read their stats from the canonical
        # "action"/"observation.state" entries of `stats`.
        if "action" in key:
            target_key = "action"
        elif "state" in key:
            target_key = "observation.state"
        else:
            target_key = key

        # Note: we initialize mean, std, min, max to infinity. They should be overwritten
        # downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
        # we assert they are not infinity anymore.
        buffer = {}
        if mode == "mean_std":
            mean = torch.full(shape, torch.inf, dtype=torch.float32)
            std = torch.full(shape, torch.inf, dtype=torch.float32)
            buffer = nn.ParameterDict({
                "mean": nn.Parameter(mean, requires_grad=False),
                "std": nn.Parameter(std, requires_grad=False),
            })
        elif mode == "min_max":
            # Renamed locals so the builtins `min`/`max` are not shadowed.
            min_val = torch.full(shape, torch.inf, dtype=torch.float32)
            max_val = torch.full(shape, torch.inf, dtype=torch.float32)
            buffer = nn.ParameterDict({
                "min": nn.Parameter(min_val, requires_grad=False),
                "max": nn.Parameter(max_val, requires_grad=False),
            })

        if stats is not None:
            # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see
            # duplicated tensors anywhere (for example, when we use the same stats for
            # normalization and unnormalization).
            if mode == "mean_std":
                buffer["mean"].data = stats[target_key]["mean"].clone()
                buffer["std"].data = stats[target_key]["std"].clone()
            elif mode == "min_max":
                buffer["min"].data = stats[target_key]["min"].clone()
                buffer["max"].data = stats[target_key]["max"].clone()

        stats_buffers[key] = buffer
    return stats_buffers
def _no_stats_error_str(name: str) -> str:
return (
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
"pretrained model.")
class Normalize(nn.Module):
    """Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""

    def __init__(
        self,
        shapes: Dict[str, List[int]],
        modes: Dict[str, str],
        stats: Dict[str, Dict[str, Tensor]] = None,
    ):
        """
        Args:
            shapes: mapping from modality (e.g. "observation.image") to its shape (e.g.
                `[3, 96, 96]`), used to allocate the stat buffers. Image shapes are reduced
                to (c, 1, 1) so the buffers are invariant to height and width (channel-first
                format is assumed).
            modes: mapping from modality to its normalization mode:
                - "mean_std": subtract the mean and divide by standard deviation.
                - "min_max": map to the [-1, 1] range.
            stats: optional mapping from modality to statistic tensors (e.g.
                `{"mean": torch.randn(3, 1, 1), "std": torch.randn(3, 1, 1)}`). Provide it when
                training a model for the first time; when finetuning or evaluating, the buffers
                are instead expected to be filled by `policy.load_state_dict(state_dict)`, so
                the dataset is not needed to obtain the stats.
        """
        super().__init__()
        self.shapes = shapes
        self.modes = modes
        self.stats = stats
        # Register one buffer per modality; dots are illegal in attribute names.
        for key, buffer in create_stats_buffers(shapes, modes, stats).items():
            setattr(self, "buffer_" + key.replace(".", "_"), buffer)

    @torch.no_grad()
    def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Normalize, in place, every entry of `batch` that has a configured mode."""
        for key, mode in self.modes.items():
            if key not in batch:
                continue
            buffer = getattr(self, "buffer_" + key.replace(".", "_"))
            if mode == "mean_std":
                mean, std = buffer["mean"], buffer["std"]
                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str("std")
                batch[key] = (batch[key] - mean) / (std + 1e-8)
            elif mode == "min_max":
                low, high = buffer["min"], buffer["max"]
                assert not torch.isinf(low).any(), _no_stats_error_str("min")
                assert not torch.isinf(high).any(), _no_stats_error_str("max")
                # Map to [0, 1], then stretch to [-1, 1].
                batch[key] = (batch[key] - low) / (high - low + 1e-8) * 2 - 1
            else:
                raise ValueError(mode)
        return batch
class Unnormalize(nn.Module):
    """
    Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
    original range used by the environment.
    """

    def __init__(
        self,
        shapes: Dict[str, List[int]],
        modes: Dict[str, str],
        stats: Dict[str, Dict[str, Tensor]] = None,
    ):
        """
        Args:
            shapes: mapping from modality (e.g. "observation.image") to its shape (e.g.
                `[3, 96, 96]`), used to allocate the stat buffers. Image shapes are reduced
                to (c, 1, 1) so the buffers are invariant to height and width (channel-first
                format is assumed).
            modes: mapping from modality to the normalization mode being inverted:
                - "mean_std": multiply by standard deviation and add the mean.
                - "min_max": map from [-1, 1] back to the original [min, max] range.
            stats: optional mapping from modality to statistic tensors. Provide it when
                training a model for the first time; when finetuning or evaluating, the buffers
                are instead expected to be filled by `policy.load_state_dict(state_dict)`, so
                the dataset is not needed to obtain the stats.
        """
        super().__init__()
        self.shapes = shapes
        self.modes = modes
        self.stats = stats
        # Register one buffer per modality; dots are illegal in attribute names.
        for key, buffer in create_stats_buffers(shapes, modes, stats).items():
            setattr(self, "buffer_" + key.replace(".", "_"), buffer)

    @torch.no_grad()
    def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Unnormalize, in place, every entry of `batch` that has a configured mode."""
        for key, mode in self.modes.items():
            if key not in batch:
                continue
            buffer = getattr(self, "buffer_" + key.replace(".", "_"))
            if mode == "mean_std":
                mean, std = buffer["mean"], buffer["std"]
                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str("std")
                batch[key] = batch[key] * std + mean
            elif mode == "min_max":
                low, high = buffer["min"], buffer["max"]
                assert not torch.isinf(low).any(), _no_stats_error_str("min")
                assert not torch.isinf(high).any(), _no_stats_error_str("max")
                # Map from [-1, 1] back to [0, 1], then to the original range.
                batch[key] = (batch[key] + 1) / 2
                batch[key] = batch[key] * (high - low) + low
            else:
                raise ValueError(mode)
        return batch

View File

@@ -0,0 +1,60 @@
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from typing import Dict, List, Union
from pathlib import Path
from safetensors.torch import load_file
def unflatten_dict(d, sep="/"):
    """Rebuild a nested dict from a flat one whose keys are `sep`-joined paths.

    E.g. `{"a/b": 1}` becomes `{"a": {"b": 1}}`.
    """
    nested = {}
    for flat_key, value in d.items():
        *ancestors, leaf = flat_key.split(sep)
        node = nested
        # Walk/create the intermediate dicts down to the leaf's parent.
        for part in ancestors:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested
def load_episode_data_index(repo_id, version, root) -> Dict[str, torch.Tensor]:
    """episode_data_index contains the range of indices for each episode

    Example:
    ```python
    from_id = episode_data_index["from"][episode_id].item()
    to_id = episode_data_index["to"][episode_id].item()
    episode_frames = [dataset[i] for i in range(from_id, to_id)]
    ```
    """
    if root is None:
        # No local root given: fetch the file from the Hugging Face Hub.
        path = hf_hub_download(repo_id,
                               "meta_data/episode_data_index.safetensors",
                               repo_type="dataset",
                               revision=version)
    else:
        path = Path(root) / repo_id / "meta_data" / "episode_data_index.safetensors"
    return load_file(path)
def load_stats(repo_id, version, root) -> Dict[str, Dict[str, torch.Tensor]]:
    """stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std

    Example:
    ```python
    normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
    ```
    """
    if root is None:
        # No local root given: fetch the file from the Hugging Face Hub.
        path = hf_hub_download(repo_id,
                               "meta_data/stats.safetensors",
                               repo_type="dataset",
                               revision=version)
    else:
        path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
    # Stats are stored flat (e.g. "action/mean"); rebuild the nested mapping.
    return unflatten_dict(load_file(path))

View File

@@ -0,0 +1,408 @@
import torch
import os
import random
import pandas as pd
import h5py
from decord import VideoReader, cpu
from torch.utils.data import Dataset
from torchvision import transforms
from pathlib import Path
from unifolm_wma.data.utils import load_stats
from unifolm_wma.data.normolize import Normalize, Unnormalize
class WMAData(Dataset):
    """Map-style dataset pairing video clips with robot transition data (actions/states).

    Assuming the following dataset structure:
    dataset_dir/
    ├── videos
    │   ├── dataset_name
    │   │   ├── camera_view_dir
    │   │   ├── 0.mp4
    │   │   ├── 1.mp4
    │   │   └── ...
    │   └── ...
    ├── transitions
    │   ├── dataset_name
    │   ├── meta_data
    │   ├── 0.h5
    │   ├── 1.h5
    │   └── ...
    └── dataset_name.csv
    """

    def __init__(
        self,
        meta_path,
        data_dir,
        subsample=None,
        video_length=16,
        resolution=[256, 512],  # NOTE(review): mutable default argument — safe only because it is never mutated; verify
        frame_stride=1,
        frame_stride_min=1,
        spatial_transform=None,
        crop_resolution=None,
        fps_max=None,
        load_raw_resolution=False,
        fixed_fps=None,
        random_fs=False,
        cond_robot_label_prob=0.0,
        transition_dir=None,
        dataset_name=None,
        normalization_mode='min_max',
        individual_normalization=False,
        n_obs_steps=1,
        max_action_dim=7,
        max_state_dim=7,
    ):
        """
        Args:
            meta_path: path to the CSV metadata file (read by `_load_metadata`).
            data_dir: dataset root containing the `videos/` and `transitions/` subdirs.
            subsample: if not None, randomly keep only this many metadata rows (seeded).
            video_length: number of frames per returned clip.
            resolution: target (h, w); an int is expanded to a square.
            frame_stride / frame_stride_min: max/min temporal stride between frames;
                a random stride in [min, max] is drawn per sample when `random_fs` is True.
            spatial_transform: one of "random_crop", "center_crop",
                "resize_center_crop", "resize", or None.
            crop_resolution: crop size used by "random_crop".
            fps_max: cap on the fps value reported in the returned sample.
            load_raw_resolution: if True, decode videos at native resolution
                (otherwise decord resizes to 530x300).
            fixed_fps: if set, rescale the stride so clips approximate this fps.
            random_fs: enable per-sample random frame stride.
            cond_robot_label_prob: probability of prepending the embodiment label
                ("<embodiment> [SEP] <instruction>") to the instruction.
            transition_dir: root passed to `load_stats` as the local stats root.
            dataset_name: dataset identifier used for stats lookup and h5 path layout.
            normalization_mode: mode given to Normalize/Unnormalize (e.g. 'min_max').
            individual_normalization: if True, normalize actions/states per dataset.
            n_obs_steps: number of observed past steps (states/actions/frames).
            max_action_dim / max_state_dim: unified vector sizes; shorter vectors
                are zero-padded and masked.
        """
        self.meta_path = meta_path
        self.data_dir = data_dir
        self.subsample = subsample
        self.video_length = video_length
        # Accept either an int (square) or an explicit [h, w] pair.
        self.resolution = [resolution, resolution] if isinstance(
            resolution, int) else resolution
        self.fps_max = fps_max
        self.frame_stride = frame_stride
        self.frame_stride_min = frame_stride_min
        self.fixed_fps = fixed_fps
        self.load_raw_resolution = load_raw_resolution
        self.random_fs = random_fs
        self.cond_robot_label_prob = cond_robot_label_prob
        self.transition_dir = transition_dir
        self.dataset_name = dataset_name
        self.max_action_dim = max_action_dim
        self.max_state_dim = max_state_dim
        self._load_metadata()
        if spatial_transform is not None:
            if spatial_transform == "random_crop":
                self.spatial_transform = transforms.RandomCrop(crop_resolution)
            elif spatial_transform == "center_crop":
                self.spatial_transform = transforms.Compose([
                    transforms.CenterCrop(resolution),
                ])
            elif spatial_transform == "resize_center_crop":
                # Resize the short side, then center-crop to the exact resolution.
                self.spatial_transform = transforms.Compose([
                    transforms.Resize(min(self.resolution)),
                    transforms.CenterCrop(self.resolution),
                ])
            elif spatial_transform == "resize":
                self.spatial_transform = transforms.Resize(self.resolution)
            else:
                raise NotImplementedError
        else:
            self.spatial_transform = None
        self.normalization_mode = normalization_mode
        self.individual_normalization = individual_normalization
        self.n_obs_steps = n_obs_steps
        self._load_stats()
        if individual_normalization:
            self._init_normalizers()

    def _load_metadata(self):
        """Read the CSV at `meta_path` into `self.metadata`, subsampling and dropping NaN rows."""
        metadata = pd.read_csv(self.meta_path, dtype=str)
        if self.subsample is not None:
            # Fixed random_state keeps the subsample reproducible across runs.
            metadata = metadata.sample(self.subsample, random_state=0)
        self.metadata = metadata
        # drop the rows contain NaN values
        # NOTE(review): dropna runs after subsampling, so fewer than `subsample` rows may remain.
        self.metadata.dropna(inplace=True)
        print(
            f">>> {metadata['data_dir'].iloc[0]}: {len(metadata)} data samples loaded."
        )

    def _load_stats(self):
        """Load per-modality statistics for `dataset_name` from `transition_dir` into `self.stats`."""
        self.stats = load_stats(self.dataset_name, None, self.transition_dir)
        print(f">>> {self.metadata['data_dir'].iloc[0]}: data stats loaded.")

    def _init_normalizers(self):
        """Build Normalize/Unnormalize modules for actions and states from the loaded stats."""
        # Feature dims are inferred from the last axis of the loaded "max" stats.
        shape_dict = {
            'pre_action': [self.stats['action']['max'].shape[-1]],
            'action': [self.stats['action']['max'].shape[-1]],
            'observation.state':
            [self.stats['observation.state']['max'].shape[-1]],
            'next.state': [self.stats['observation.state']['max'].shape[-1]]
        }
        normalization_mode_dict = {
            'pre_action': self.normalization_mode,
            'action': self.normalization_mode,
            'observation.state': self.normalization_mode,
            'next.state': self.normalization_mode
        }
        self.normalizer = Normalize(shape_dict, normalization_mode_dict,
                                    self.stats)
        self.unnormalizer = Unnormalize(shape_dict, normalization_mode_dict,
                                        self.stats)
        print(
            f">>> {self.metadata['data_dir'].iloc[0]}: normalizer initiated.")

    def _get_video_path(self, sample):
        """Return the absolute path `<data_dir>/videos/<sample.data_dir>/<videoid>.mp4`."""
        rel_video_fp = os.path.join(sample['data_dir'],
                                    str(sample['videoid']) + '.mp4')
        full_video_fp = os.path.join(self.data_dir, 'videos', rel_video_fp)
        return full_video_fp

    def _get_transition_path(self, sample):
        """Return the absolute path of the `.h5` transition file matching `sample`.

        When the sample's data_dir leaf equals `dataset_name`, the h5 lives directly
        under it; otherwise (e.g. an extra camera-view subdir) under its parent.
        """
        data_dir = Path(sample['data_dir'])
        if self.dataset_name == data_dir.name:
            rel_transition_fp = os.path.join(str(data_dir),
                                             str(sample['videoid']) + '.h5')
        else:
            rel_transition_fp = os.path.join(str(data_dir.parent),
                                             str(sample['videoid']) + '.h5')
        full_transition_fp = os.path.join(self.data_dir, 'transitions',
                                          rel_transition_fp)
        return full_transition_fp

    def get_uni_vec(self, action_state_dict, action_type, state_type):
        """Pad actions/states to the unified max dims, adding validity masks.

        Mutates `action_state_dict` in place: replaces each present key with its
        padded version and adds 'action_mask' / 'state_mask' entries.
        """
        if 'pre_action' in action_state_dict:
            action_state_dict['pre_action'], _ = self._map_to_uni_action(
                action_state_dict['pre_action'], action_type)
        if 'action' in action_state_dict:
            action_state_dict['action'], action_state_dict[
                'action_mask'] = self._map_to_uni_action(
                    action_state_dict['action'], action_type)
        if 'observation.state' in action_state_dict:
            action_state_dict['observation.state'], _ = self._map_to_uni_state(
                action_state_dict['observation.state'], state_type)
        if 'next.state' in action_state_dict:
            action_state_dict['next.state'], action_state_dict[
                'state_mask'] = self._map_to_uni_state(
                    action_state_dict['next.state'], state_type)
        return action_state_dict

    def _map_to_uni_action(self, action, action_type):
        """Zero-pad `action` (t, d) to (t, max_action_dim); mask marks real dims with 1.

        NOTE(review): `action_type` is currently unused — presumably reserved for
        type-specific mappings; confirm before removing.
        """
        action_dim = action.shape[-1]
        uni_action = torch.nn.functional.pad(
            action, (0, self.max_action_dim - action_dim),
            mode='constant',
            value=0)
        uni_action_mask = torch.zeros_like(uni_action)
        uni_action_mask[:, :action_dim] = 1
        return uni_action, uni_action_mask

    def _map_to_uni_state(self, state, state_type):
        """Zero-pad `state` (t, d) to (t, max_state_dim); mask marks real dims with 1.

        NOTE(review): `state_type` is currently unused — see `_map_to_uni_action`.
        """
        state_dim = state.shape[-1]
        uni_state = torch.nn.functional.pad(
            state, (0, self.max_state_dim - state_dim),
            mode='constant',
            value=0)
        uni_state_mask = torch.zeros_like(uni_state)
        uni_state_mask[:, :state_dim] = 1
        return uni_state, uni_state_mask

    def __getitem__(self, index):
        """Return one training sample.

        The returned dict contains: 'video' (c, t, h, w in [-1, 1]), 'instruction',
        'path', 'fps', 'frame_stride', 'observation.image' (the n_obs_steps observed
        frames), plus the padded 'pre_action'/'action'/'observation.state'/'next.state'
        tensors and their masks from `get_uni_vec`.
        """
        # Choose the temporal stride, optionally at random per sample.
        if self.random_fs:
            frame_stride = random.randint(self.frame_stride_min,
                                          self.frame_stride)
        else:
            frame_stride = self.frame_stride
        # Get frames until success: on any decode/length failure, advance to the
        # next metadata row (wrapping around) and retry.
        while True:
            index = index % len(self.metadata)
            sample = self.metadata.iloc[index]
            video_path = self._get_video_path(sample)
            instruction = sample['instruction']
            # With probability cond_robot_label_prob, prepend the embodiment label
            # (unless it is the placeholder 'x').
            if self.cond_robot_label_prob > 0.0 and random.random(
            ) < self.cond_robot_label_prob:
                if sample['embodiment'] != 'x':
                    instruction = sample['embodiment'] + ' [SEP] ' + sample[
                        'instruction']
            try:
                if self.load_raw_resolution:
                    video_reader = VideoReader(video_path, ctx=cpu(0))
                else:
                    # Decode at a fixed 530x300 to bound memory/IO cost.
                    video_reader = VideoReader(video_path,
                                               ctx=cpu(0),
                                               width=530,
                                               height=300)
                if len(video_reader) < self.video_length:
                    print(
                        f">>> Video length ({len(video_reader)}) is smaller than target length({self.video_length})"
                    )
                    index += 1
                    continue
                else:
                    pass
            # NOTE(review): bare except also swallows KeyboardInterrupt/SystemExit;
            # consider `except Exception`.
            except:
                index += 1
                print(f">>> Error: load video failed! path = {video_path}")
                continue
            fps_ori = video_reader.get_avg_fps()
            if self.fixed_fps is not None:
                # Rescale the stride so the sampled clip approximates fixed_fps.
                frame_stride = int(frame_stride *
                                   (1.0 * fps_ori / self.fixed_fps))
            # To avoid extreme cases when fixed_fps is used
            frame_stride = max(frame_stride, 1)
            # Get valid range (adapting case by case)
            required_frame_num = frame_stride * (self.video_length - 1) + 1
            frame_num = len(video_reader)
            if frame_num < required_frame_num:
                # Drop extra samples if fixed fps is required
                if self.fixed_fps is not None and frame_num < required_frame_num * 0.5:
                    index += 1
                    continue
                else:
                    # Shrink the stride so the clip fits in the video.
                    frame_stride = frame_num // self.video_length
                    required_frame_num = frame_stride * (self.video_length -
                                                         1) + 1
            # Select a random clip
            random_range = frame_num - required_frame_num
            # Leave frame_stride frames of headroom so the shifted clip below stays in range.
            start_idx = random.randint(
                0, random_range -
                frame_stride) if random_range - frame_stride > 0 else 0
            # Calculate frame indices
            frame_indices = [
                start_idx + frame_stride * i for i in range(self.video_length)
            ]
            try:
                # Decoded frames are the clip shifted one stride ahead, i.e. the
                # "next" frames relative to frame_indices.
                next_frame_indices = [
                    idx + frame_stride for idx in frame_indices
                ]
                frames = video_reader.get_batch(next_frame_indices)
                break
            # NOTE(review): bare except — see note above.
            except:
                print(
                    f">>> Error: Get frames failed! path = {video_path}; [max_ind vs frame_total:{max(frame_indices)} / {frame_num}]"
                )
                index += 1
                continue
        # Load transition data
        transition_path = self._get_transition_path(sample)
        with h5py.File(transition_path, 'r') as h5f:
            transition_dict = {}
            # Datasets become tensors; file attributes are kept as raw values.
            # Assumes the h5 provides 'observation.state', 'action', and the
            # 'action_type'/'state_type' attrs — TODO confirm against the data pipeline.
            for key in h5f.keys():
                transition_dict[key] = torch.tensor(h5f[key][()])
            for key in h5f.attrs.keys():
                transition_dict[key] = h5f.attrs[key]
        # Load observable states: the n_obs_steps states ending at start_idx,
        # left-padded by repeating the first state when history is too short.
        if start_idx < self.n_obs_steps - 1:
            state_indices = list(range(0, start_idx + 1))
            states = transition_dict['observation.state'][state_indices, :]
            num_padding = self.n_obs_steps - 1 - start_idx
            first_slice = states[0:1, :]  # (t, d)
            padding = first_slice.repeat(num_padding, 1)
            states = torch.cat((padding, states), dim=0)
        else:
            state_indices = list(
                range(start_idx - self.n_obs_steps + 1, start_idx + 1))
            states = transition_dict['observation.state'][state_indices, :]
        assert states.shape[
            0] == self.n_obs_steps, '>>> Do not have enough previous states as observation.'
        # Load observable actions: the n_obs_steps actions before start_idx,
        # left-padded with zeros (unlike states, which repeat) when history is short.
        if start_idx < self.n_obs_steps:
            pre_action_indices = list(range(0, start_idx))
            pre_actions = transition_dict['action'][pre_action_indices, :]
            num_padding = self.n_obs_steps - start_idx
            first_slice = torch.zeros_like(transition_dict['action'][:1, :])
            padding = first_slice.repeat(num_padding, 1)
            pre_actions = torch.cat((padding, pre_actions), dim=0)
        else:
            pre_action_indices = list(
                range(start_idx - self.n_obs_steps, start_idx))
            pre_actions = transition_dict['action'][pre_action_indices, :]
        assert pre_actions.shape[
            0] == self.n_obs_steps, ">>> Do not have enough previous actions as observation"
        # Load future actions
        actions = transition_dict['action'][frame_indices, :]
        # Load future states
        next_state_indices = [idx + frame_stride for idx in frame_indices]
        next_states = transition_dict['observation.state'][
            next_state_indices, :]
        frames_action_state_dict = {
            'pre_action': pre_actions,
            'action': actions,
            'observation.state': states,
            'next.state': next_states
        }
        if self.individual_normalization:
            frames_action_state_dict = self.normalizer(
                frames_action_state_dict)
        # Update action and states to unified vector
        frames_action_state_dict = self.get_uni_vec(
            frames_action_state_dict,
            transition_dict['action_type'],
            transition_dict['state_type'],
        )
        # Load observable images: the n_obs_steps frames ending at start_idx,
        # left-padded by repeating the first frame when history is too short.
        if start_idx < self.n_obs_steps - 1:
            action_net_frame_indices = list(range(0, start_idx + 1))
            action_net_frames = video_reader.get_batch(
                action_net_frame_indices)
            # decord NDArray -> tensor, (t, h, w, c) -> (t, c, h, w).
            action_net_frames = torch.tensor(
                action_net_frames.asnumpy()).permute(0, 3, 1, 2).float()
            first_slice = action_net_frames[0:1, :]
            num_padding = self.n_obs_steps - 1 - start_idx
            padding = first_slice.repeat(num_padding, 1, 1, 1)
            action_net_frames = torch.cat((padding, action_net_frames), dim=0)
            assert (
                action_net_frames.shape[0] == self.n_obs_steps
            ), f'{len(action_net_frames)}, self.n_obs_steps={self.n_obs_steps}'
            # (t, c, h, w) -> (c, t, h, w) to match the other branch's layout.
            action_net_frames = action_net_frames.permute(1, 0, 2, 3)
        else:
            action_net_frame_indices = list(
                range(start_idx - self.n_obs_steps + 1, start_idx + 1))
            action_net_frames = video_reader.get_batch(
                action_net_frame_indices)
            assert (
                action_net_frames.shape[0] == self.n_obs_steps
            ), f'{len(action_net_frames)}, self.n_obs_steps={self.n_obs_steps}'
            # decord NDArray -> tensor, (t, h, w, c) -> (c, t, h, w).
            action_net_frames = torch.tensor(
                action_net_frames.asnumpy()).permute(3, 0, 1, 2).float()
        assert (frames.shape[0] == self.video_length
                ), f'{len(frames)}, self.video_length={self.video_length}'
        # Clip frames: (t, h, w, c) -> (c, t, h, w).
        frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()
        if self.spatial_transform is not None:
            frames = self.spatial_transform(frames)
            action_net_frames = self.spatial_transform(action_net_frames)
        if self.resolution is not None:
            assert (frames.shape[2], frames.shape[3]) == (
                self.resolution[0], self.resolution[1]
            ), f'frames={frames.shape}, self.resolution={self.resolution}'
            assert (
                action_net_frames.shape[2], action_net_frames.shape[3]
            ) == (
                self.resolution[0], self.resolution[1]
            ), f'action_net_frames={action_net_frames.shape}, self.resolution={self.resolution}'
        # Normalize frames tensors to [-1,1]
        frames = (frames / 255 - 0.5) * 2
        action_net_frames = (action_net_frames / 255 - 0.5) * 2
        # Effective fps of the sampled clip, capped by fps_max if set.
        fps_clip = fps_ori // frame_stride
        if self.fps_max is not None and fps_clip > self.fps_max:
            fps_clip = self.fps_max
        data = {
            'video': frames,
            'instruction': instruction,
            'path': video_path,
            'fps': fps_clip,
            'frame_stride': frame_stride,
            'observation.image': action_net_frames,
        }
        data.update(frames_action_state_dict)
        return data

    def __len__(self) -> int:
        """Number of (non-NaN) metadata rows."""
        return len(self.metadata)