tf32推理
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -129,4 +129,6 @@ Data/utils.py
|
|||||||
Experiment/checkpoint
|
Experiment/checkpoint
|
||||||
Experiment/log
|
Experiment/log
|
||||||
|
|
||||||
*.ckpt
|
*.ckpt
|
||||||
|
|
||||||
|
*.0
|
||||||
@@ -222,7 +222,7 @@ data:
|
|||||||
test:
|
test:
|
||||||
target: unifolm_wma.data.wma_data.WMAData
|
target: unifolm_wma.data.wma_data.WMAData
|
||||||
params:
|
params:
|
||||||
data_dir: '/mnt/ASC1637/unifolm-world-model-action/examples/world_model_interaction_prompts'
|
data_dir: '/home/qhy/unifolm-world-model-action/examples/world_model_interaction_prompts'
|
||||||
video_length: ${model.params.wma_config.params.temporal_length}
|
video_length: ${model.params.wma_config.params.temporal_length}
|
||||||
frame_stride: 2
|
frame_stride: 2
|
||||||
load_raw_resolution: True
|
load_raw_resolution: True
|
||||||
|
|||||||
@@ -16,6 +16,9 @@ from collections import OrderedDict
|
|||||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||||
from unifolm_wma.utils.utils import instantiate_from_config
|
from unifolm_wma.utils.utils import instantiate_from_config
|
||||||
|
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
|
||||||
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
|
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ from fastapi.responses import JSONResponse
|
|||||||
from typing import Any, Dict, Optional, Tuple, List
|
from typing import Any, Dict, Optional, Tuple, List
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
from unifolm_wma.utils.utils import instantiate_from_config
|
from unifolm_wma.utils.utils import instantiate_from_config
|
||||||
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
from unifolm_wma.models.samplers.ddim import DDIMSampler
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,9 @@ from collections import OrderedDict
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from eval_utils import populate_queues, log_to_tensorboard
|
from eval_utils import populate_queues, log_to_tensorboard
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|||||||
@@ -11,6 +11,9 @@ from unifolm_wma.utils.utils import instantiate_from_config
|
|||||||
from unifolm_wma.utils.train import get_trainer_callbacks, get_trainer_logger, get_trainer_strategy
|
from unifolm_wma.utils.train import get_trainer_callbacks, get_trainer_logger, get_trainer_strategy
|
||||||
from unifolm_wma.utils.train import set_logger, init_workspace, load_checkpoints, get_num_parameters
|
from unifolm_wma.utils.train import set_logger, init_workspace, load_checkpoints, get_num_parameters
|
||||||
|
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
torch.backends.cudnn.allow_tf32 = True
|
||||||
|
|
||||||
|
|
||||||
def get_parser(**parser_kwargs):
|
def get_parser(**parser_kwargs):
|
||||||
parser = argparse.ArgumentParser(**parser_kwargs)
|
parser = argparse.ArgumentParser(**parser_kwargs)
|
||||||
|
|||||||
Reference in New Issue
Block a user