# unifolm-world-model-action/src/unifolm_wma/utils/data.py
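"""Data utilities: a LightningDataModule that mixes multiple datasets via
weighted sampling, plus a DataLoader worker-init hook for per-worker RNG
seeding."""
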
import numpy as np
import torch
import pytorch_lightning as pl
from functools import partial
from torch.utils.data import (DataLoader, Dataset, ConcatDataset,
WeightedRandomSampler)
from unifolm_wma.data.base import Txt2ImgIterableBaseDataset
from unifolm_wma.utils.utils import instantiate_from_config


def worker_init_fn(_):
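    """Seed NumPy independently in each DataLoader worker; for iterable
    datasets, additionally shard the sample ids across workers."""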
worker_info = torch.utils.data.get_worker_info()
dataset = worker_info.dataset
worker_id = worker_info.id
if isinstance(dataset, Txt2ImgIterableBaseDataset):
split_size = dataset.num_records // worker_info.num_workers
        # Shard the valid ids across workers so each worker iterates a disjoint slice
dataset.sample_ids = dataset.valid_ids[worker_id *
split_size:(worker_id + 1) *
split_size]
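        # Re-seed NumPy from the inherited RNG state, offset by the worker id,
        # so each worker draws a distinct random stream.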
current_id = np.random.choice(len(np.random.get_state()[1]), 1)
return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
else:
return np.random.seed(np.random.get_state()[1][0] + worker_id)


class WrappedDataset(Dataset):
    """Wrap any object exposing __len__ and __getitem__ as a torch Dataset."""

    def __init__(self, dataset):
        self.data = dataset

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class DataModuleFromConfig(pl.LightningDataModule):
def __init__(self,
batch_size,
train=None,
validation=None,
test=None,
predict=None,
wrap=False,
num_workers=None,
shuffle_test_loader=False,
use_worker_init_fn=False,
shuffle_val_dataloader=True,
train_img=None,
dataset_and_weights=None):
super().__init__()
self.batch_size = batch_size
self.dataset_configs = dict()
self.num_workers = num_workers if num_workers is not None else batch_size * 2
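        # Default heuristic: two DataLoader workers per sample in a batch.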
self.use_worker_init_fn = use_worker_init_fn
if train is not None:
self.dataset_configs["train"] = train
self.train_dataloader = self._train_dataloader
if validation is not None:
self.dataset_configs["validation"] = validation
self.val_dataloader = partial(self._val_dataloader,
shuffle=shuffle_val_dataloader)
if test is not None:
self.dataset_configs["test"] = test
self.test_dataloader = partial(self._test_dataloader,
shuffle=shuffle_test_loader)
if predict is not None:
self.dataset_configs["predict"] = predict
self.predict_dataloader = self._predict_dataloader
self.img_loader = None
self.wrap = wrap
self.collate_fn = None
        self.dataset_weights = dataset_and_weights
        assert self.dataset_weights is not None, (
            "dataset_and_weights must map each dataset name to its weight")
        assert round(sum(self.dataset_weights.values()),
                     2) == 1.0, "The sum of dataset weights != 1.0"

    def prepare_data(self):
        pass

    def setup(self, stage=None):
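        # Setup training dataset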
if 'train' in self.dataset_configs:
self.train_datasets = dict()
for dataname in self.dataset_weights:
data_dir = self.dataset_configs['train']['params']['data_dir']
transition_dir = '/'.join([data_dir, 'transitions'])
csv_file = f'{dataname}.csv'
meta_path = '/'.join([data_dir, csv_file])
self.dataset_configs['train']['params'][
'meta_path'] = meta_path
self.dataset_configs['train']['params'][
'transition_dir'] = transition_dir
self.dataset_configs['train']['params'][
'dataset_name'] = dataname
self.train_datasets[dataname] = instantiate_from_config(
self.dataset_configs['train'])
# Setup validation dataset
if 'validation' in self.dataset_configs:
self.val_datasets = dict()
for dataname in self.dataset_weights:
data_dir = self.dataset_configs['validation']['params'][
'data_dir']
transition_dir = '/'.join([data_dir, 'transitions'])
csv_file = f'{dataname}.csv'
meta_path = '/'.join([data_dir, csv_file])
self.dataset_configs['validation']['params'][
'meta_path'] = meta_path
self.dataset_configs['validation']['params'][
'transition_dir'] = transition_dir
self.dataset_configs['validation']['params'][
'dataset_name'] = dataname
self.val_datasets[dataname] = instantiate_from_config(
self.dataset_configs['validation'])
# Setup test dataset
if 'test' in self.dataset_configs:
self.test_datasets = dict()
for dataname in self.dataset_weights:
data_dir = self.dataset_configs['test']['params']['data_dir']
transition_dir = '/'.join([data_dir, 'transitions'])
csv_file = f'{dataname}.csv'
meta_path = '/'.join([data_dir, csv_file])
self.dataset_configs['test']['params']['meta_path'] = meta_path
self.dataset_configs['test']['params'][
'transition_dir'] = transition_dir
self.dataset_configs['test']['params'][
'dataset_name'] = dataname
self.test_datasets[dataname] = instantiate_from_config(
self.dataset_configs['test'])
        if self.wrap:
            for split in ('train_datasets', 'val_datasets', 'test_datasets'):
                datasets = getattr(self, split, {})
                for k in datasets:
                    datasets[k] = WrappedDataset(datasets[k])

    def _train_dataloader(self):
        is_iterable_dataset = False  # NOTE: hard-coded; no iterable datasets are used here
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
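        # Weight each sample by (dataset weight / dataset size); sampling with
        # replacement then draws fraction w_d of every epoch from dataset d,
        # independent of its size.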
combined_dataset = []
sample_weights = []
for dataname, dataset in self.train_datasets.items():
combined_dataset.append(dataset)
sample_weights.append(
torch.full((len(dataset), ),
self.dataset_weights[dataname] / len(dataset)))
combined_dataset = ConcatDataset(combined_dataset)
sample_weights = torch.cat(sample_weights)
sampler = WeightedRandomSampler(sample_weights,
num_samples=len(combined_dataset),
replacement=True)
loader = DataLoader(combined_dataset,
sampler=sampler,
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
collate_fn=self.collate_fn,
drop_last=True
)
return loader

    def _val_dataloader(self, shuffle=False):
        # `shuffle` is ignored: the weighted sampler fixes the sampling order.
        is_iterable_dataset = False  # NOTE: hard-coded; no iterable datasets are used here
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
combined_dataset = []
sample_weights = []
for dataname, dataset in self.val_datasets.items():
combined_dataset.append(dataset)
sample_weights.append(
torch.full((len(dataset), ),
self.dataset_weights[dataname] / len(dataset)))
combined_dataset = ConcatDataset(combined_dataset)
sample_weights = torch.cat(sample_weights)
sampler = WeightedRandomSampler(sample_weights,
num_samples=len(combined_dataset),
replacement=True)
loader = DataLoader(combined_dataset,
sampler=sampler,
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
collate_fn=self.collate_fn)
return loader

    def _test_dataloader(self, shuffle=False):
        # `shuffle` is ignored: the weighted sampler fixes the sampling order.
        is_iterable_dataset = False  # NOTE: hard-coded; no iterable datasets are used here
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
combined_dataset = []
sample_weights = []
for dataname, dataset in self.test_datasets.items():
combined_dataset.append(dataset)
sample_weights.append(
torch.full((len(dataset), ),
self.dataset_weights[dataname] / len(dataset)))
combined_dataset = ConcatDataset(combined_dataset)
sample_weights = torch.cat(sample_weights)
sampler = WeightedRandomSampler(sample_weights,
num_samples=len(combined_dataset),
replacement=True)
loader = DataLoader(combined_dataset,
sampler=sampler,
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
collate_fn=self.collate_fn)
return loader

    def _predict_dataloader(self, shuffle=False):
        # `setup` only builds the train/val/test splits, so instantiate the
        # predict dataset directly from its config.
        predict_dataset = instantiate_from_config(self.dataset_configs['predict'])
        if isinstance(predict_dataset,
                      Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
            init_fn = worker_init_fn
        else:
            init_fn = None
        return DataLoader(
            predict_dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=init_fn,
            collate_fn=self.collate_fn,
        )

    def __len__(self):
        # Total number of training samples across all datasets.
        return sum(len(d) for d in self.train_datasets.values())
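

if __name__ == "__main__":
    # Minimal usage sketch. The dataset target and paths below are
    # hypothetical placeholders; in practice this module is constructed from
    # the experiment YAML via `instantiate_from_config`. Each dataset name in
    # `dataset_and_weights` is expected to have a matching
    # `<data_dir>/<name>.csv` metadata file.
    train_cfg = {
        'target': 'unifolm_wma.data.dataset.WMADataset',  # hypothetical target
        'params': {
            'data_dir': '/path/to/data',  # hypothetical path
        },
    }
    dm = DataModuleFromConfig(
        batch_size=4,
        train=train_cfg,
        dataset_and_weights={'dataset_a': 0.7, 'dataset_b': 0.3},  # sums to 1.0
    )
    dm.setup()
    print(f"{len(dm)} training samples across {len(dm.train_datasets)} datasets")
    loader = dm.train_dataloader()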