第一次完整测例跑完 (First complete test case run finished)
This commit is contained in:
26
src/unifolm_wma/data/base.py
Normal file
26
src/unifolm_wma/data/base.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from abc import abstractmethod
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
|
||||
class Txt2ImgIterableBaseDataset(IterableDataset):
    """Common interface that makes text-to-image ``IterableDataset``s chainable.

    The base class only stores bookkeeping (record count, valid/sampled ids,
    target image size); subclasses must provide ``__iter__``.
    """

    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        # ``sample_ids`` starts out identical to the full set of valid ids;
        # subclasses may later restrict it (e.g. for sharded workers).
        self.num_records = num_records
        self.valid_ids = valid_ids
        self.sample_ids = valid_ids
        self.size = size

        print(
            f'{self.__class__.__name__} dataset contains {self.__len__()} examples.'
        )

    def __len__(self):
        # Length is the declared record count, not len(valid_ids).
        return self.num_records

    @abstractmethod
    def __iter__(self):
        pass
|
||||
230
src/unifolm_wma/data/normolize.py
Normal file
230
src/unifolm_wma/data/normolize.py
Normal file
@@ -0,0 +1,230 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
def create_stats_buffers(
    shapes: Dict[str, List[int]],
    modes: Dict[str, str],
    stats: Dict[str, Dict[str, Tensor]] = None,
) -> Dict[str, Dict[str, nn.ParameterDict]]:
    """
    Create buffers per modality (e.g. "observation.image", "action") containing their mean, std,
    min, max statistics.

    Args:
        shapes: Maps each modality key to its tensor shape (e.g. ``[3, 96, 96]``).
            Image shapes are collapsed to ``(c, 1, 1)`` so the statistics broadcast
            over height and width.
        modes: Maps each modality key to its normalization mode, one of
            ``"mean_std"`` or ``"min_max"``.
        stats: Optional per-modality statistics used to initialize the buffers.
            If omitted, buffers are filled with ``inf`` and are expected to be
            overwritten later (e.g. by ``policy.load_state_dict``).

    Returns:
        Dict: A dictionary where keys are modalities and values are ``nn.ParameterDict``
        containing ``nn.Parameter``s set to ``requires_grad=False``, suitable to not be
        updated during backpropagation.
    """
    stats_buffers = {}

    for key, mode in modes.items():
        assert mode in ["mean_std", "min_max"]

        shape = tuple(shapes[key])

        if "image" in key:
            # Sanity checks: images must be channel-first (c, h, w).
            # (Fixed: original error message was missing its closing paren.)
            assert len(
                shape) == 3, f"number of dimensions of {key} != 3 ({shape=})"
            c, h, w = shape
            assert c < h and c < w, f"{key} is not channel first ({shape=})"
            # Override image shape to be invariant to height and width.
            shape = (c, 1, 1)

        # Note: we initialize mean, std, min, max to infinity. They should be
        # overwritten downstream by `stats` or `policy.load_state_dict`, as
        # expected. During forward, we assert they are not infinity anymore.

        # Several modality keys share one underlying statistics entry:
        # anything action-like reads "action", anything state-like reads
        # "observation.state".
        if "action" in key:
            target_key = "action"
        elif "state" in key:
            target_key = 'observation.state'
        else:
            target_key = key

        if mode == "mean_std":
            # Explicit names avoid shadowing the builtins min/max below.
            mean_stat = torch.ones(shape, dtype=torch.float32) * torch.inf
            std_stat = torch.ones(shape, dtype=torch.float32) * torch.inf
            buffer = nn.ParameterDict({
                "mean": nn.Parameter(mean_stat, requires_grad=False),
                "std": nn.Parameter(std_stat, requires_grad=False),
            })
        else:  # mode == "min_max" (guaranteed by the assert above)
            min_stat = torch.ones(shape, dtype=torch.float32) * torch.inf
            max_stat = torch.ones(shape, dtype=torch.float32) * torch.inf
            buffer = nn.ParameterDict({
                "min": nn.Parameter(min_stat, requires_grad=False),
                "max": nn.Parameter(max_stat, requires_grad=False),
            })

        if stats is not None:
            # Note: The clone is needed to make sure that the logic in
            # save_pretrained doesn't see duplicated tensors anywhere (for
            # example, when we use the same stats for normalization and
            # unnormalization).
            if mode == "mean_std":
                buffer["mean"].data = stats[target_key]["mean"].clone()
                buffer["std"].data = stats[target_key]["std"].clone()
            elif mode == "min_max":
                buffer["min"].data = stats[target_key]["min"].clone()
                buffer["max"].data = stats[target_key]["max"].clone()

        stats_buffers[key] = buffer
    return stats_buffers
|
||||
|
||||
|
||||
def _no_stats_error_str(name: str) -> str:
|
||||
return (
|
||||
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
|
||||
"pretrained model.")
|
||||
|
||||
|
||||
class Normalize(nn.Module):
    """Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""

    def __init__(
        self,
        shapes: Dict[str, List[int]],
        modes: Dict[str, str],
        stats: Dict[str, Dict[str, Tensor]] = None,
    ):
        """
        Args:
            shapes: Maps input modality keys (e.g. "observation.image") to their
                shapes (e.g. ``[3, 96, 96]``), used to build the statistics
                buffers. Image shapes are made invariant to height and width
                (channel-first format is assumed).
            modes: Maps modality keys to their normalization mode:
                - "mean_std": subtract the mean and divide by standard deviation.
                - "min_max": map to the [-1, 1] range.
            stats: Optional per-modality statistics (e.g.
                ``{"mean": torch.randn(3, 1, 1), "std": torch.randn(3, 1, 1)}``)
                that overwrite the default (infinite) buffers, as expected when
                training from scratch. When omitted, the buffers are expected to
                be filled by a later ``load_state_dict`` call.
        """
        super().__init__()
        self.shapes = shapes
        self.modes = modes
        self.stats = stats
        # One buffer dict per modality; "observation.image" becomes the
        # attribute "buffer_observation_image".
        for key, buffer in create_stats_buffers(shapes, modes, stats).items():
            setattr(self, "buffer_" + key.replace(".", "_"), buffer)

    @torch.no_grad()
    def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Normalize, in place, every modality of ``batch`` listed in ``self.modes``."""
        for key, mode in self.modes.items():
            if key not in batch:
                continue

            buffer = getattr(self, "buffer_" + key.replace(".", "_"))

            if mode == "mean_std":
                mean, std = buffer["mean"], buffer["std"]
                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str("std")
                batch[key] = (batch[key] - mean) / (std + 1e-8)
            elif mode == "min_max":
                low, high = buffer["min"], buffer["max"]
                assert not torch.isinf(low).any(), _no_stats_error_str("min")
                assert not torch.isinf(high).any(), _no_stats_error_str("max")
                # Map to [0, 1] first, then stretch to [-1, 1].
                batch[key] = (batch[key] - low) / (high - low + 1e-8)
                batch[key] = batch[key] * 2 - 1
            else:
                raise ValueError(mode)
        return batch
|
||||
|
||||
|
||||
class Unnormalize(nn.Module):
    """
    Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
    original range used by the environment.
    """

    def __init__(
        self,
        shapes: Dict[str, List[int]],
        modes: Dict[str, str],
        stats: Dict[str, Dict[str, Tensor]] = None,
    ):
        """
        Args:
            shapes: Maps input modality keys (e.g. "observation.image") to their
                shapes (e.g. ``[3, 96, 96]``), used to build the statistics
                buffers. Image shapes are made invariant to height and width
                (channel-first format is assumed).
            modes: Maps modality keys to their normalization mode:
                - "mean_std": subtract the mean and divide by standard deviation.
                - "min_max": map to the [-1, 1] range.
            stats: Optional per-modality statistics (e.g.
                ``{"mean": torch.randn(3, 1, 1), "std": torch.randn(3, 1, 1)}``)
                that overwrite the default (infinite) buffers, as expected when
                training from scratch. When omitted, the buffers are expected to
                be filled by a later ``load_state_dict`` call.
        """
        super().__init__()
        self.shapes = shapes
        self.modes = modes
        self.stats = stats
        # One buffer dict per modality; "observation.image" becomes the
        # attribute "buffer_observation_image".
        for key, buffer in create_stats_buffers(shapes, modes, stats).items():
            setattr(self, "buffer_" + key.replace(".", "_"), buffer)

    @torch.no_grad()
    def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
        """Invert the normalization, in place, for every modality in ``self.modes``."""
        for key, mode in self.modes.items():
            if key not in batch:
                continue

            buffer = getattr(self, "buffer_" + key.replace(".", "_"))

            if mode == "mean_std":
                mean, std = buffer["mean"], buffer["std"]
                assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
                assert not torch.isinf(std).any(), _no_stats_error_str("std")
                batch[key] = batch[key] * std + mean
            elif mode == "min_max":
                low, high = buffer["min"], buffer["max"]
                assert not torch.isinf(low).any(), _no_stats_error_str("min")
                assert not torch.isinf(high).any(), _no_stats_error_str("max")
                # [-1, 1] -> [0, 1] -> original range.
                batch[key] = (batch[key] + 1) / 2
                batch[key] = batch[key] * (high - low) + low
            else:
                raise ValueError(mode)
        return batch
|
||||
60
src/unifolm_wma/data/utils.py
Normal file
60
src/unifolm_wma/data/utils.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
|
||||
from huggingface_hub import hf_hub_download, snapshot_download
|
||||
from typing import Dict, List, Union
|
||||
from pathlib import Path
|
||||
from safetensors.torch import load_file
|
||||
|
||||
def unflatten_dict(d, sep="/"):
    """Convert a flat dict with separator-joined keys into a nested dict.

    Example:
        >>> unflatten_dict({"a/b": 1, "a/c": 2})
        {'a': {'b': 1, 'c': 2}}

    Args:
        d: Flat dictionary whose keys encode nesting with ``sep``.
        sep: Key separator (default ``"/"``).

    Returns:
        A new nested dictionary; the input dict is not modified.
    """
    outdict = {}
    for key, value in d.items():
        parts = key.split(sep)
        # Walk/create the intermediate dicts. (The original implementation
        # rebound the parameter `d` here, shadowing the input.)
        node = outdict
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = value
    return outdict
|
||||
|
||||
|
||||
def load_episode_data_index(repo_id, version, root) -> Dict[str, torch.Tensor]:
    """episode_data_index contains the range of indices for each episode

    Example:
        ```python
        from_id = episode_data_index["from"][episode_id].item()
        to_id = episode_data_index["to"][episode_id].item()
        episode_frames = [dataset[i] for i in range(from_id, to_id)]
        ```
    """
    if root is None:
        # No local root given: fetch the file from the hub at the pinned revision.
        path = hf_hub_download(repo_id,
                               "meta_data/episode_data_index.safetensors",
                               repo_type="dataset",
                               revision=version)
    else:
        path = (Path(root) / repo_id / "meta_data" /
                "episode_data_index.safetensors")

    return load_file(path)
|
||||
|
||||
|
||||
def load_stats(repo_id, version, root) -> Dict[str, Dict[str, torch.Tensor]]:
    """stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std

    Example:
        ```python
        normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
        ```
    """
    if root is None:
        # No local root given: fetch the file from the hub at the pinned revision.
        path = hf_hub_download(repo_id,
                               "meta_data/stats.safetensors",
                               repo_type="dataset",
                               revision=version)
    else:
        path = Path(root) / repo_id / "meta_data" / "stats.safetensors"

    # Stats are stored flat with "/"-joined keys; restore the nesting.
    return unflatten_dict(load_file(path))
|
||||
408
src/unifolm_wma/data/wma_data.py
Normal file
408
src/unifolm_wma/data/wma_data.py
Normal file
@@ -0,0 +1,408 @@
|
||||
import torch
|
||||
import os
|
||||
import random
|
||||
import pandas as pd
|
||||
import h5py
|
||||
|
||||
from decord import VideoReader, cpu
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from pathlib import Path
|
||||
|
||||
from unifolm_wma.data.utils import load_stats
|
||||
from unifolm_wma.data.normolize import Normalize, Unnormalize
|
||||
|
||||
|
||||
class WMAData(Dataset):
    """
    Map-style dataset pairing video clips with action/state transitions.

    Assuming the following dataset structure:
    dataset_dir/
    ├── videos
    │   ├── dataset_name
    │   │   ├── camera_view_dir
    │   │   ├── 0.mp4
    │   │   ├── 1.mp4
    │   │   └── ...
    │   └── ...
    ├── transitions
    │   ├── dataset_name
    │   ├── meta_data
    │   ├── 0.h5
    │   ├── 1.h5
    │   └── ...
    └── dataset_name.csv
    """

    def __init__(
        self,
        meta_path,
        data_dir,
        subsample=None,
        video_length=16,
        # NOTE(review): mutable default argument; harmless as long as callers
        # never mutate it, but a tuple or None-sentinel would be safer.
        resolution=[256, 512],
        frame_stride=1,
        frame_stride_min=1,
        spatial_transform=None,
        crop_resolution=None,
        fps_max=None,
        load_raw_resolution=False,
        fixed_fps=None,
        random_fs=False,
        cond_robot_label_prob=0.0,
        transition_dir=None,
        dataset_name=None,
        normalization_mode='min_max',
        individual_normalization=False,
        n_obs_steps=1,
        max_action_dim=7,
        max_state_dim=7,
    ):
        """
        Args:
            meta_path: Path to the ``dataset_name.csv`` metadata file. From the
                usage below it must contain at least ``data_dir``, ``videoid``,
                ``instruction`` and ``embodiment`` columns.
            data_dir: Root directory containing ``videos/`` and ``transitions/``.
            subsample: If given, randomly keep only this many metadata rows
                (seeded with ``random_state=0`` for reproducibility).
            video_length: Number of frames per returned clip.
            resolution: Target (height, width); a single int is expanded to a
                square resolution.
            frame_stride: Stride between sampled frames (upper bound when
                ``random_fs`` is True).
            frame_stride_min: Lower bound for the stride when ``random_fs``.
            spatial_transform: One of "random_crop", "center_crop",
                "resize_center_crop", "resize", or None.
            crop_resolution: Crop size used by "random_crop".
            fps_max: Upper bound for the clip fps reported in the sample dict.
            load_raw_resolution: If True, decode videos at native resolution.
            fixed_fps: If set, rescale the stride so clips approximate this fps.
            random_fs: If True, sample the stride uniformly in
                [frame_stride_min, frame_stride].
            cond_robot_label_prob: Probability of prefixing the instruction
                with the embodiment label ("<embodiment> [SEP] <instruction>").
            transition_dir: Root of the transitions tree (also where the
                normalization stats are loaded from).
            dataset_name: Dataset name used to locate stats and .h5 files.
            normalization_mode: "min_max" or "mean_std" for action/state stats.
            individual_normalization: If True, normalize actions/states with
                this dataset's own stats in ``__getitem__``.
            n_obs_steps: Number of observed past steps (states/actions/frames).
            max_action_dim: Unified action vector size (inputs zero-padded).
            max_state_dim: Unified state vector size (inputs zero-padded).
        """
        self.meta_path = meta_path
        self.data_dir = data_dir
        self.subsample = subsample
        self.video_length = video_length
        # Normalize an int resolution to [h, w].
        self.resolution = [resolution, resolution] if isinstance(
            resolution, int) else resolution
        self.fps_max = fps_max
        self.frame_stride = frame_stride
        self.frame_stride_min = frame_stride_min
        self.fixed_fps = fixed_fps
        self.load_raw_resolution = load_raw_resolution
        self.random_fs = random_fs
        self.cond_robot_label_prob = cond_robot_label_prob
        self.transition_dir = transition_dir
        self.dataset_name = dataset_name
        self.max_action_dim = max_action_dim
        self.max_state_dim = max_state_dim

        self._load_metadata()
        # Build the frame-level spatial transform; None means frames are
        # returned at decode resolution.
        if spatial_transform is not None:
            if spatial_transform == "random_crop":
                self.spatial_transform = transforms.RandomCrop(crop_resolution)
            elif spatial_transform == "center_crop":
                self.spatial_transform = transforms.Compose([
                    transforms.CenterCrop(resolution),
                ])
            elif spatial_transform == "resize_center_crop":
                # Resize the short side first, then crop to the exact target.
                self.spatial_transform = transforms.Compose([
                    transforms.Resize(min(self.resolution)),
                    transforms.CenterCrop(self.resolution),
                ])
            elif spatial_transform == "resize":
                self.spatial_transform = transforms.Resize(self.resolution)
            else:
                raise NotImplementedError
        else:
            self.spatial_transform = None

        self.normalization_mode = normalization_mode
        self.individual_normalization = individual_normalization
        self.n_obs_steps = n_obs_steps
        self._load_stats()
        if individual_normalization:
            self._init_normalizers()

    def _load_metadata(self):
        """Read the CSV metadata, optionally subsample, and drop rows with NaNs."""
        metadata = pd.read_csv(self.meta_path, dtype=str)
        if self.subsample is not None:
            metadata = metadata.sample(self.subsample, random_state=0)

        self.metadata = metadata
        # drop the rows contain NaN values
        self.metadata.dropna(inplace=True)
        print(
            f">>> {metadata['data_dir'].iloc[0]}: {len(metadata)} data samples loaded."
        )

    def _load_stats(self):
        """Load the per-modality normalization statistics for this dataset."""
        self.stats = load_stats(self.dataset_name, None, self.transition_dir)
        print(f">>> {self.metadata['data_dir'].iloc[0]}: data stats loaded.")

    def _init_normalizers(self):
        """Build Normalize/Unnormalize modules for actions and states from the loaded stats."""
        # Shapes are taken from the trailing dimension of the stored stats, so
        # they match the raw (un-padded) action/state dimensionality.
        shape_dict = {
            'pre_action': [self.stats['action']['max'].shape[-1]],
            'action': [self.stats['action']['max'].shape[-1]],
            'observation.state':
            [self.stats['observation.state']['max'].shape[-1]],
            'next.state': [self.stats['observation.state']['max'].shape[-1]]
        }
        normalization_mode_dict = {
            'pre_action': self.normalization_mode,
            'action': self.normalization_mode,
            'observation.state': self.normalization_mode,
            'next.state': self.normalization_mode
        }
        self.normalizer = Normalize(shape_dict, normalization_mode_dict,
                                    self.stats)
        self.unnormalizer = Unnormalize(shape_dict, normalization_mode_dict,
                                        self.stats)
        print(
            f">>> {self.metadata['data_dir'].iloc[0]}: normalizer initiated.")

    def _get_video_path(self, sample):
        """Return the absolute .mp4 path for a metadata row (data_dir/videos/<rel>/<videoid>.mp4)."""
        rel_video_fp = os.path.join(sample['data_dir'],
                                    str(sample['videoid']) + '.mp4')
        full_video_fp = os.path.join(self.data_dir, 'videos', rel_video_fp)
        return full_video_fp

    def _get_transition_path(self, sample):
        """Return the absolute .h5 transition path for a metadata row.

        When the row's data_dir already names the dataset, the .h5 lives in it;
        otherwise the last path component (presumably a camera-view directory —
        TODO confirm against the dataset layout) is stripped first.
        """
        data_dir = Path(sample['data_dir'])
        if self.dataset_name == data_dir.name:
            rel_transition_fp = os.path.join(str(data_dir),
                                             str(sample['videoid']) + '.h5')
        else:
            rel_transition_fp = os.path.join(str(data_dir.parent),
                                             str(sample['videoid']) + '.h5')
        full_transition_fp = os.path.join(self.data_dir, 'transitions',
                                          rel_transition_fp)
        return full_transition_fp

    def get_uni_vec(self, action_state_dict, action_type, state_type):
        """Pad every action/state entry to the unified dimension and attach validity masks.

        Only 'action' and 'next.state' contribute masks ('action_mask',
        'state_mask'); 'pre_action' and 'observation.state' are padded only.
        """
        if 'pre_action' in action_state_dict:
            action_state_dict['pre_action'], _ = self._map_to_uni_action(
                action_state_dict['pre_action'], action_type)
        if 'action' in action_state_dict:
            action_state_dict['action'], action_state_dict[
                'action_mask'] = self._map_to_uni_action(
                    action_state_dict['action'], action_type)
        if 'observation.state' in action_state_dict:
            action_state_dict['observation.state'], _ = self._map_to_uni_state(
                action_state_dict['observation.state'], state_type)
        if 'next.state' in action_state_dict:
            action_state_dict['next.state'], action_state_dict[
                'state_mask'] = self._map_to_uni_state(
                    action_state_dict['next.state'], state_type)
        return action_state_dict

    def _map_to_uni_action(self, action, action_type):
        """Zero-pad ``action`` (t, d) to (t, max_action_dim) and return it with a 0/1 mask.

        NOTE(review): ``action_type`` is currently unused — confirm whether
        type-specific remapping was intended here.
        """
        action_dim = action.shape[-1]
        uni_action = torch.nn.functional.pad(
            action, (0, self.max_action_dim - action_dim),
            mode='constant',
            value=0)
        # Mask marks which of the unified dims carry real data.
        uni_action_mask = torch.zeros_like(uni_action)
        uni_action_mask[:, :action_dim] = 1
        return uni_action, uni_action_mask

    def _map_to_uni_state(self, state, state_type):
        """Zero-pad ``state`` (t, d) to (t, max_state_dim) and return it with a 0/1 mask.

        NOTE(review): ``state_type`` is currently unused — confirm whether
        type-specific remapping was intended here.
        """
        state_dim = state.shape[-1]
        uni_state = torch.nn.functional.pad(
            state, (0, self.max_state_dim - state_dim),
            mode='constant',
            value=0)
        uni_state_mask = torch.zeros_like(uni_state)
        uni_state_mask[:, :state_dim] = 1
        return uni_state, uni_state_mask

    def __getitem__(self, index):
        """Return one training sample: a video clip plus aligned (padded) actions/states.

        Retries with the next metadata row until a video decodes successfully
        and has enough frames, so the effective index may differ from ``index``.
        """
        # Optionally randomize the temporal stride per sample.
        if self.random_fs:
            frame_stride = random.randint(self.frame_stride_min,
                                          self.frame_stride)
        else:
            frame_stride = self.frame_stride

        # Get frames until success
        while True:
            index = index % len(self.metadata)
            sample = self.metadata.iloc[index]
            video_path = self._get_video_path(sample)

            instruction = sample['instruction']
            # With probability cond_robot_label_prob, condition the text on the
            # embodiment label ('x' appears to mean "no label" — TODO confirm).
            if self.cond_robot_label_prob > 0.0 and random.random(
            ) < self.cond_robot_label_prob:
                if sample['embodiment'] != 'x':
                    instruction = sample['embodiment'] + ' [SEP] ' + sample[
                        'instruction']
            try:
                if self.load_raw_resolution:
                    video_reader = VideoReader(video_path, ctx=cpu(0))
                else:
                    # NOTE(review): decode at a fixed 530x300 before any crop /
                    # resize — confirm this matches the expected aspect ratio.
                    video_reader = VideoReader(video_path,
                                               ctx=cpu(0),
                                               width=530,
                                               height=300)
                if len(video_reader) < self.video_length:
                    print(
                        f">>> Video length ({len(video_reader)}) is smaller than target length({self.video_length})"
                    )
                    index += 1
                    continue
                else:
                    pass
            # NOTE(review): bare except swallows every error (including
            # KeyboardInterrupt); consider narrowing to decord/IO errors.
            except:
                index += 1
                print(f">>> Error: load video failed! path = {video_path}")
                continue

            fps_ori = video_reader.get_avg_fps()
            # Rescale stride so the sampled clip approximates fixed_fps.
            if self.fixed_fps is not None:
                frame_stride = int(frame_stride *
                                   (1.0 * fps_ori / self.fixed_fps))

            # To avoid extreme cases when fixed_fps is used
            frame_stride = max(frame_stride, 1)

            # Get valid range (adapting case by case)
            required_frame_num = frame_stride * (self.video_length - 1) + 1
            frame_num = len(video_reader)
            if frame_num < required_frame_num:
                # Drop extra samples if fixed fps is required
                if self.fixed_fps is not None and frame_num < required_frame_num * 0.5:
                    index += 1
                    continue
                else:
                    # Shrink the stride so the clip fits inside the video.
                    frame_stride = frame_num // self.video_length
                    required_frame_num = frame_stride * (self.video_length -
                                                         1) + 1

            # Select a random clip
            random_range = frame_num - required_frame_num
            start_idx = random.randint(
                0, random_range -
                frame_stride) if random_range - frame_stride > 0 else 0

            # Calculate frame indices
            frame_indices = [
                start_idx + frame_stride * i for i in range(self.video_length)
            ]
            try:
                # NOTE(review): the decoded frames are read one stride AHEAD of
                # frame_indices (i.e. the frames that result from the sampled
                # actions) — confirm this offset is intentional.
                next_frame_indices = [
                    idx + frame_stride for idx in frame_indices
                ]
                frames = video_reader.get_batch(next_frame_indices)
                break
            # NOTE(review): bare except again; an out-of-range index is the
            # expected failure mode here.
            except:
                print(
                    f">>> Error: Get frames failed! path = {video_path}; [max_ind vs frame_total:{max(frame_indices)} / {frame_num}]"
                )
                index += 1
                continue

        # Load transition data
        transition_path = self._get_transition_path(sample)
        with h5py.File(transition_path, 'r') as h5f:
            transition_dict = {}
            # Datasets become tensors; HDF5 attributes (e.g. action_type,
            # state_type) are copied through as-is.
            for key in h5f.keys():
                transition_dict[key] = torch.tensor(h5f[key][()])
            for key in h5f.attrs.keys():
                transition_dict[key] = h5f.attrs[key]

        # Load observable states: the n_obs_steps states ending at start_idx,
        # left-padded by repeating the first state when the clip starts early.
        if start_idx < self.n_obs_steps - 1:
            state_indices = list(range(0, start_idx + 1))
            states = transition_dict['observation.state'][state_indices, :]
            num_padding = self.n_obs_steps - 1 - start_idx
            first_slice = states[0:1, :]  # (t, d)
            padding = first_slice.repeat(num_padding, 1)
            states = torch.cat((padding, states), dim=0)
        else:
            state_indices = list(
                range(start_idx - self.n_obs_steps + 1, start_idx + 1))
            states = transition_dict['observation.state'][state_indices, :]
        assert states.shape[
            0] == self.n_obs_steps, '>>> Do not have enough previous states as observation.'

        # Load observable actions: the n_obs_steps actions BEFORE start_idx,
        # left-padded with zeros (unlike states, which repeat the first entry).
        if start_idx < self.n_obs_steps:
            pre_action_indices = list(range(0, start_idx))
            pre_actions = transition_dict['action'][pre_action_indices, :]
            num_padding = self.n_obs_steps - start_idx
            first_slice = torch.zeros_like(transition_dict['action'][:1, :])
            padding = first_slice.repeat(num_padding, 1)
            pre_actions = torch.cat((padding, pre_actions), dim=0)
        else:
            pre_action_indices = list(
                range(start_idx - self.n_obs_steps, start_idx))
            pre_actions = transition_dict['action'][pre_action_indices, :]
        assert pre_actions.shape[
            0] == self.n_obs_steps, ">>> Do not have enough previous actions as observation"

        # Load future actions
        actions = transition_dict['action'][frame_indices, :]
        # Load future states (one stride after each action frame)
        next_state_indices = [idx + frame_stride for idx in frame_indices]
        next_states = transition_dict['observation.state'][
            next_state_indices, :]
        frames_action_state_dict = {
            'pre_action': pre_actions,
            'action': actions,
            'observation.state': states,
            'next.state': next_states
        }
        # Optionally normalize with this dataset's own stats before padding.
        if self.individual_normalization:
            frames_action_state_dict = self.normalizer(
                frames_action_state_dict)

        # Update action and states to unified vector
        frames_action_state_dict = self.get_uni_vec(
            frames_action_state_dict,
            transition_dict['action_type'],
            transition_dict['state_type'],
        )

        # Load observable images: the n_obs_steps frames ending at start_idx,
        # left-padded by repeating the first frame; result is (c, t, h, w).
        if start_idx < self.n_obs_steps - 1:
            action_net_frame_indices = list(range(0, start_idx + 1))
            action_net_frames = video_reader.get_batch(
                action_net_frame_indices)
            action_net_frames = torch.tensor(
                action_net_frames.asnumpy()).permute(0, 3, 1, 2).float()
            first_slice = action_net_frames[0:1, :]
            num_padding = self.n_obs_steps - 1 - start_idx
            padding = first_slice.repeat(num_padding, 1, 1, 1)
            action_net_frames = torch.cat((padding, action_net_frames), dim=0)
            assert (
                action_net_frames.shape[0] == self.n_obs_steps
            ), f'{len(action_net_frames)}, self.n_obs_steps={self.n_obs_steps}'
            # (t, c, h, w) -> (c, t, h, w) to match the other branch.
            action_net_frames = action_net_frames.permute(1, 0, 2, 3)
        else:
            action_net_frame_indices = list(
                range(start_idx - self.n_obs_steps + 1, start_idx + 1))
            action_net_frames = video_reader.get_batch(
                action_net_frame_indices)
            assert (
                action_net_frames.shape[0] == self.n_obs_steps
            ), f'{len(action_net_frames)}, self.n_obs_steps={self.n_obs_steps}'
            action_net_frames = torch.tensor(
                action_net_frames.asnumpy()).permute(3, 0, 1, 2).float()

        assert (frames.shape[0] == self.video_length
                ), f'{len(frames)}, self.video_length={self.video_length}'
        # (t, h, w, c) -> (c, t, h, w)
        frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()

        if self.spatial_transform is not None:
            frames = self.spatial_transform(frames)
            action_net_frames = self.spatial_transform(action_net_frames)

        if self.resolution is not None:
            assert (frames.shape[2], frames.shape[3]) == (
                self.resolution[0], self.resolution[1]
            ), f'frames={frames.shape}, self.resolution={self.resolution}'
            assert (
                action_net_frames.shape[2], action_net_frames.shape[3]
            ) == (
                self.resolution[0], self.resolution[1]
            ), f'action_net_frames={action_net_frames.shape}, self.resolution={self.resolution}'

        # Normalize frames tensors to [-1,1]
        frames = (frames / 255 - 0.5) * 2
        action_net_frames = (action_net_frames / 255 - 0.5) * 2
        # fps_ori is a float, so this is a float floor-division.
        fps_clip = fps_ori // frame_stride
        if self.fps_max is not None and fps_clip > self.fps_max:
            fps_clip = self.fps_max

        data = {
            'video': frames,
            'instruction': instruction,
            'path': video_path,
            'fps': fps_clip,
            'frame_stride': frame_stride,
            'observation.image': action_net_frames,
        }
        # Merge in (padded) actions/states and their masks.
        data.update(frames_action_state_dict)

        return data

    def __len__(self):
        # One sample per (cleaned) metadata row.
        return len(self.metadata)
|
||||
Reference in New Issue
Block a user