第一次完整测例跑完
This commit is contained in:
26
src/unifolm_wma/data/base.py
Normal file
26
src/unifolm_wma/data/base.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from abc import abstractmethod
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
|
||||
class Txt2ImgIterableBaseDataset(IterableDataset):
|
||||
'''
|
||||
Define an interface to make the IterableDatasets for text2img data chainable
|
||||
'''
|
||||
|
||||
def __init__(self, num_records=0, valid_ids=None, size=256):
|
||||
super().__init__()
|
||||
self.num_records = num_records
|
||||
self.valid_ids = valid_ids
|
||||
self.sample_ids = valid_ids
|
||||
self.size = size
|
||||
|
||||
print(
|
||||
f'{self.__class__.__name__} dataset contains {self.__len__()} examples.'
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_records
|
||||
|
||||
@abstractmethod
|
||||
def __iter__(self):
|
||||
pass
|
||||
Reference in New Issue
Block a user