First complete test-case run finished

2026-01-18 00:30:10 +08:00
parent ca15cc593b
commit 25c6fc04db
180 changed files with 29305 additions and 0 deletions

133
.gitignore vendored Normal file

@@ -0,0 +1,133 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.venv
venv/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
*.pdf
.pdf
plot_test/
plot/
performance/
localTest/
fig/
figure/
*.mp4
*.json
Data/ControlVAE.yml
Data/Misc
Data/Pretrained
Data/utils.py
Experiment/checkpoint
Experiment/log
*.ckpt
*.STL
*.gif

3
.gitmodules vendored Normal file

@@ -0,0 +1,3 @@
[submodule "external/dlimp"]
path = external/dlimp
url = https://github.com/kvablack/dlimp

439
LICENSE Normal file

@@ -0,0 +1,439 @@
Attribution-NonCommercial-ShareAlike 4.0 International
Copyright (c) 2016-2025 HangZhou YuShu TECHNOLOGY CO.,LTD. ("Unitree Robotics")
=======================================================================
Creative Commons Corporation ("Creative Commons") is not a law firm and
does not provide legal services or legal advice. Distribution of
Creative Commons public licenses does not create a lawyer-client or
other relationship. Creative Commons makes its licenses and related
information available on an "as-is" basis. Creative Commons gives no
warranties regarding its licenses, any material licensed under their
terms and conditions, or any related information. Creative Commons
disclaims all liability for damages resulting from their use to the
fullest extent possible.
Using Creative Commons Public Licenses
Creative Commons public licenses provide a standard set of terms and
conditions that creators and other rights holders may use to share
original works of authorship and other material subject to copyright
and certain other rights specified in the public license below. The
following considerations are for informational purposes only, are not
exhaustive, and do not form part of our licenses.
Considerations for licensors: Our public licenses are
intended for use by those authorized to give the public
permission to use material in ways otherwise restricted by
copyright and certain other rights. Our licenses are
irrevocable. Licensors should read and understand the terms
and conditions of the license they choose before applying it.
Licensors should also secure all rights necessary before
applying our licenses so that the public can reuse the
material as expected. Licensors should clearly mark any
material not subject to the license. This includes other CC-
licensed material, or material used under an exception or
limitation to copyright. More considerations for licensors:
wiki.creativecommons.org/Considerations_for_licensors
Considerations for the public: By using one of our public
licenses, a licensor grants the public permission to use the
licensed material under specified terms and conditions. If
the licensor's permission is not necessary for any reason--for
example, because of any applicable exception or limitation to
copyright--then that use is not regulated by the license. Our
licenses grant only permissions under copyright and certain
other rights that a licensor has authority to grant. Use of
the licensed material may still be restricted for other
reasons, including because others have copyright or other
rights in the material. A licensor may make special requests,
such as asking that all changes be marked or described.
Although not required by our licenses, you are encouraged to
respect those requests where reasonable. More considerations
for the public:
wiki.creativecommons.org/Considerations_for_licensees
=======================================================================
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
Public License
By exercising the Licensed Rights (defined below), You accept and agree
to be bound by the terms and conditions of this Creative Commons
Attribution-NonCommercial-ShareAlike 4.0 International Public License
("Public License"). To the extent this Public License may be
interpreted as a contract, You are granted the Licensed Rights in
consideration of Your acceptance of these terms and conditions, and the
Licensor grants You such rights in consideration of benefits the
Licensor receives from making the Licensed Material available under
these terms and conditions.
Section 1 -- Definitions.
a. Adapted Material means material subject to Copyright and Similar
Rights that is derived from or based upon the Licensed Material
and in which the Licensed Material is translated, altered,
arranged, transformed, or otherwise modified in a manner requiring
permission under the Copyright and Similar Rights held by the
Licensor. For purposes of this Public License, where the Licensed
Material is a musical work, performance, or sound recording,
Adapted Material is always produced where the Licensed Material is
synched in timed relation with a moving image.
b. Adapter's License means the license You apply to Your Copyright
and Similar Rights in Your contributions to Adapted Material in
accordance with the terms and conditions of this Public License.
c. BY-NC-SA Compatible License means a license listed at
creativecommons.org/compatiblelicenses, approved by Creative
Commons as essentially the equivalent of this Public License.
d. Copyright and Similar Rights means copyright and/or similar rights
closely related to copyright including, without limitation,
performance, broadcast, sound recording, and Sui Generis Database
Rights, without regard to how the rights are labeled or
categorized. For purposes of this Public License, the rights
specified in Section 2(b)(1)-(2) are not Copyright and Similar
Rights.
e. Effective Technological Measures means those measures that, in the
absence of proper authority, may not be circumvented under laws
fulfilling obligations under Article 11 of the WIPO Copyright
Treaty adopted on December 20, 1996, and/or similar international
agreements.
f. Exceptions and Limitations means fair use, fair dealing, and/or
any other exception or limitation to Copyright and Similar Rights
that applies to Your use of the Licensed Material.
g. License Elements means the license attributes listed in the name
of a Creative Commons Public License. The License Elements of this
Public License are Attribution, NonCommercial, and ShareAlike.
h. Licensed Material means the artistic or literary work, database,
or other material to which the Licensor applied this Public
License.
i. Licensed Rights means the rights granted to You subject to the
terms and conditions of this Public License, which are limited to
all Copyright and Similar Rights that apply to Your use of the
Licensed Material and that the Licensor has authority to license.
j. Licensor means the individual(s) or entity(ies) granting rights
under this Public License.
k. NonCommercial means not primarily intended for or directed towards
commercial advantage or monetary compensation. For purposes of
this Public License, the exchange of the Licensed Material for
other material subject to Copyright and Similar Rights by digital
file-sharing or similar means is NonCommercial provided there is
no payment of monetary compensation in connection with the
exchange.
l. Share means to provide material to the public by any means or
process that requires permission under the Licensed Rights, such
as reproduction, public display, public performance, distribution,
dissemination, communication, or importation, and to make material
available to the public including in ways that members of the
public may access the material from a place and at a time
individually chosen by them.
m. Sui Generis Database Rights means rights other than copyright
resulting from Directive 96/9/EC of the European Parliament and of
the Council of 11 March 1996 on the legal protection of databases,
as amended and/or succeeded, as well as other essentially
equivalent rights anywhere in the world.
n. You means the individual or entity exercising the Licensed Rights
under this Public License. Your has a corresponding meaning.
Section 2 -- Scope.
a. License grant.
1. Subject to the terms and conditions of this Public License,
the Licensor hereby grants You a worldwide, royalty-free,
non-sublicensable, non-exclusive, irrevocable license to
exercise the Licensed Rights in the Licensed Material to:
a. reproduce and Share the Licensed Material, in whole or
in part, for NonCommercial purposes only; and
b. produce, reproduce, and Share Adapted Material for
NonCommercial purposes only.
2. Exceptions and Limitations. For the avoidance of doubt, where
Exceptions and Limitations apply to Your use, this Public
License does not apply, and You do not need to comply with
its terms and conditions.
3. Term. The term of this Public License is specified in Section
6(a).
4. Media and formats; technical modifications allowed. The
Licensor authorizes You to exercise the Licensed Rights in
all media and formats whether now known or hereafter created,
and to make technical modifications necessary to do so. The
Licensor waives and/or agrees not to assert any right or
authority to forbid You from making technical modifications
necessary to exercise the Licensed Rights, including
technical modifications necessary to circumvent Effective
Technological Measures. For purposes of this Public License,
simply making modifications authorized by this Section 2(a)
(4) never produces Adapted Material.
5. Downstream recipients.
a. Offer from the Licensor -- Licensed Material. Every
recipient of the Licensed Material automatically
receives an offer from the Licensor to exercise the
Licensed Rights under the terms and conditions of this
Public License.
b. Additional offer from the Licensor -- Adapted Material.
Every recipient of Adapted Material from You
automatically receives an offer from the Licensor to
exercise the Licensed Rights in the Adapted Material
under the conditions of the Adapter's License You apply.
c. No downstream restrictions. You may not offer or impose
any additional or different terms or conditions on, or
apply any Effective Technological Measures to, the
Licensed Material if doing so restricts exercise of the
Licensed Rights by any recipient of the Licensed
Material.
6. No endorsement. Nothing in this Public License constitutes or
may be construed as permission to assert or imply that You
are, or that Your use of the Licensed Material is, connected
with, or sponsored, endorsed, or granted official status by,
the Licensor or others designated to receive attribution as
provided in Section 3(a)(1)(A)(i).
b. Other rights.
1. Moral rights, such as the right of integrity, are not
licensed under this Public License, nor are publicity,
privacy, and/or other similar personality rights; however, to
the extent possible, the Licensor waives and/or agrees not to
assert any such rights held by the Licensor to the limited
extent necessary to allow You to exercise the Licensed
Rights, but not otherwise.
2. Patent and trademark rights are not licensed under this
Public License.
3. To the extent possible, the Licensor waives any right to
collect royalties from You for the exercise of the Licensed
Rights, whether directly or through a collecting society
under any voluntary or waivable statutory or compulsory
licensing scheme. In all other cases the Licensor expressly
reserves any right to collect such royalties, including when
the Licensed Material is used other than for NonCommercial
purposes.
Section 3 -- License Conditions.
Your exercise of the Licensed Rights is expressly made subject to the
following conditions.
a. Attribution.
1. If You Share the Licensed Material (including in modified
form), You must:
a. retain the following if it is supplied by the Licensor
with the Licensed Material:
i. identification of the creator(s) of the Licensed
Material and any others designated to receive
attribution, in any reasonable manner requested by
the Licensor (including by pseudonym if
designated);
ii. a copyright notice;
iii. a notice that refers to this Public License;
iv. a notice that refers to the disclaimer of
warranties;
v. a URI or hyperlink to the Licensed Material to the
extent reasonably practicable;
b. indicate if You modified the Licensed Material and
retain an indication of any previous modifications; and
c. indicate the Licensed Material is licensed under this
Public License, and include the text of, or the URI or
hyperlink to, this Public License.
2. You may satisfy the conditions in Section 3(a)(1) in any
reasonable manner based on the medium, means, and context in
which You Share the Licensed Material. For example, it may be
reasonable to satisfy the conditions by providing a URI or
hyperlink to a resource that includes the required
information.
3. If requested by the Licensor, You must remove any of the
information required by Section 3(a)(1)(A) to the extent
reasonably practicable.
b. ShareAlike.
In addition to the conditions in Section 3(a), if You Share
Adapted Material You produce, the following conditions also apply.
1. The Adapter's License You apply must be a Creative Commons
license with the same License Elements, this version or
later, or a BY-NC-SA Compatible License.
2. You must include the text of, or the URI or hyperlink to, the
Adapter's License You apply. You may satisfy this condition
in any reasonable manner based on the medium, means, and
context in which You Share Adapted Material.
3. You may not offer or impose any additional or different terms
or conditions on, or apply any Effective Technological
Measures to, Adapted Material that restrict exercise of the
rights granted under the Adapter's License You apply.
Section 4 -- Sui Generis Database Rights.
Where the Licensed Rights include Sui Generis Database Rights that
apply to Your use of the Licensed Material:
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
to extract, reuse, reproduce, and Share all or a substantial
portion of the contents of the database for NonCommercial purposes
only;
b. if You include all or a substantial portion of the database
contents in a database in which You have Sui Generis Database
Rights, then the database in which You have Sui Generis Database
Rights (but not its individual contents) is Adapted Material,
including for purposes of Section 3(b); and
c. You must comply with the conditions in Section 3(a) if You Share
all or a substantial portion of the contents of the database.
For the avoidance of doubt, this Section 4 supplements and does not
replace Your obligations under this Public License where the Licensed
Rights include other Copyright and Similar Rights.
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
c. The disclaimer of warranties and limitation of liability provided
above shall be interpreted in a manner that, to the extent
possible, most closely approximates an absolute disclaimer and
waiver of all liability.
Section 6 -- Term and Termination.
a. This Public License applies for the term of the Copyright and
Similar Rights licensed here. However, if You fail to comply with
this Public License, then Your rights under this Public License
terminate automatically.
b. Where Your right to use the Licensed Material has terminated under
Section 6(a), it reinstates:
1. automatically as of the date the violation is cured, provided
it is cured within 30 days of Your discovery of the
violation; or
2. upon express reinstatement by the Licensor.
For the avoidance of doubt, this Section 6(b) does not affect any
right the Licensor may have to seek remedies for Your violations
of this Public License.
c. For the avoidance of doubt, the Licensor may also offer the
Licensed Material under separate terms or conditions or stop
distributing the Licensed Material at any time; however, doing so
will not terminate this Public License.
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
License.
Section 7 -- Other Terms and Conditions.
a. The Licensor shall not be bound by any additional or different
terms or conditions communicated by You unless expressly agreed.
b. Any arrangements, understandings, or agreements regarding the
Licensed Material not stated herein are separate from and
independent of the terms and conditions of this Public License.
Section 8 -- Interpretation.
a. For the avoidance of doubt, this Public License does not, and
shall not be interpreted to, reduce, limit, restrict, or impose
conditions on any use of the Licensed Material that could lawfully
be made without permission under this Public License.
b. To the extent possible, if any provision of this Public License is
deemed unenforceable, it shall be automatically reformed to the
minimum extent necessary to make it enforceable. If the provision
cannot be reformed, it shall be severed from this Public License
without affecting the enforceability of the remaining terms and
conditions.
c. No term or condition of this Public License will be waived and no
failure to comply consented to unless expressly agreed to by the
Licensor.
d. Nothing in this Public License constitutes or may be interpreted
as a limitation upon, or waiver of, any privileges and immunities
that apply to the Licensor or You, including from the legal
processes of any jurisdiction or authority.
=======================================================================
Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the “Licensor.” The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.
Creative Commons may be contacted at creativecommons.org.

228
README.md Normal file

@@ -0,0 +1,228 @@
# UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family
<p style="font-size: 1.2em;">
<a href="https://unigen-x.github.io/unifolm-world-model-action.github.io"><strong>Project Page</strong></a> |
<a href="https://huggingface.co/collections/unitreerobotics/unifolm-wma-0-68ca23027310c0ca0f34959c"><strong>Models</strong></a> |
<a href="https://huggingface.co/unitreerobotics/datasets"><strong>Dataset</strong></a>
</p>
<div align="center">
<p align="right">
<span> 🌎English </span> | <a href="README_cn.md"> 🇨🇳中文 </a>
</p>
</div>
<div align="justify">
<b>UnifoLM-WMA-0</b> is Unitree's open-source world-model-action architecture spanning multiple types of robotic embodiments, designed specifically for general-purpose robot learning. Its core component is a world model capable of understanding the physical interactions between robots and their environments. This world model provides two key functions: (a) <b>Simulation Engine</b>: it operates as an interactive simulator to generate synthetic data for robot learning; (b) <b>Policy Enhancement</b>: it connects with an action head and further optimizes decision-making performance by predicting future interactions with the world.
</div>
## 🦾 Real-Robot Demonstrations
| <img src="assets/gifs/real_z1_stackbox.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> | <img src="assets/gifs/real_dual_stackbox.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> |
|:---:|:---:|
| <img src="assets/gifs/real_cleanup_pencils.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> | <img src="assets/gifs/real_g1_pack_camera.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> |
**Note: the top-right window shows the world model's prediction of future action videos.**
## 🔥 News
* Sep 22, 2025: 🚀 We released the deployment code for real-robot experiments with [Unitree](https://www.unitree.com/) robots.
* Sep 15, 2025: 🚀 We released the training and inference code along with the model weights of [**UnifoLM-WMA-0**](https://huggingface.co/collections/unitreerobotics/unifolm-wma-0-68ca23027310c0ca0f34959c).
## 📑 Open-Source Plan
- [x] Training
- [x] Inference
- [x] Checkpoints
- [x] Deployment
## ⚙️ Installation
```
conda create -n unifolm-wma python==3.10.18
conda activate unifolm-wma
conda install pinocchio=3.2.0 -c conda-forge -y
conda install ffmpeg=7.1.1 -c conda-forge
git clone --recurse-submodules https://github.com/unitreerobotics/unifolm-world-model-action.git
# If you already downloaded the repo:
cd unifolm-world-model-action
git submodule update --init --recursive
pip install -e .
cd external/dlimp
pip install -e .
```
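To verify the two editable installs, it may help to import both packages from the activated environment; the package name ```unitree_worldmodel``` is inferred from the codebase overview below, so adjust it if your checkout differs:
```
conda activate unifolm-wma
python -c "import unitree_worldmodel, dlimp; print('imports OK')"
```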
## 🧰 Model Checkpoints
| Model | Description | Link |
|---------|-------|------|
|$\text{UnifoLM-WMA-0}_{Base}$| Fine-tuned on the [Open-X](https://robotics-transformer-x.github.io/) dataset. | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Base)|
|$\text{UnifoLM-WMA-0}_{Dual}$| Fine-tuned on five [Unitree open-source datasets](https://huggingface.co/collections/unitreerobotics/g1-dex1-datasets-68bae98bf0a26d617f9983ab) in both decision-making and simulation modes. | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Dual)|
## 🛢️ Dataset
In our experiments, we consider the following five open-source datasets:
| Dataset | Robot | Link |
|---------|-------|------|
|Z1_StackBox| [Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_StackBox_Dataset/tree/v2.1)|
|Z1_DualArm_StackBox|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset/tree/v2.1)|
|Z1_DualArm_StackBox_V2|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset_V2/tree/v2.1)|
|Z1_DualArm_Cleanup_Pencils|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_CleanupPencils_Dataset/tree/v2.1)|
|G1_Pack_Camera|[Unitree G1](https://www.unitree.com/g1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/G1_Dex1_MountCameraRedGripper_Dataset/tree/v2.1)|
To train on your own dataset, first make sure the data follows the [Huggingface LeRobot V2.1](https://github.com/huggingface/lerobot) dataset format. Assume the dataset's source directory structure is as follows:
```
source_dir/
├── dataset1_name
├── dataset2_name
├── dataset3_name
└── ...
```
Then, convert a dataset to the required format using the command below:
```
cd prepare_data
python prepare_training_data.py \
--source_dir /path/to/your/source_dir \
--target_dir /path/to/save/the/converted/data \
--dataset_name "dataset1_name" \
    --robot_name "a tag of the robot in the dataset" # e.g., Unitree Z1 Robot Arm or Unitree G1 Robot with Gripper.
```
The resulting data structure is as follows (note: model training only supports input from the main-view camera; if the dataset includes multiple views, remove the corresponding values from the ```data_dir``` column in the CSV file — a pruning sketch follows the tree below):
```
target_dir/
├── videos
│ ├──dataset1_name
│ │ ├──camera_view_dir
│ │ ├── 0.mp4
│ │ ├── 1.mp4
│ │ └── ...
│ └── ...
├── transitions
│   ├── dataset1_name
│   │   ├── meta_data
│   │   ├── 0.h5
│   │   ├── 1.h5
│   │   └── ...
│   └── ...
└── dataset1_name.csv
```
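If your dataset was recorded with multiple camera views, a short script along the following lines can prune the non-main-view entries. This is a hypothetical sketch: it assumes the CSV stores one video path per row in the ```data_dir``` column, and ```camera_view_dir``` stands in for your actual main-view directory name:
```python
# Hypothetical sketch: keep only main-view rows in dataset1_name.csv.
# Assumes a "data_dir" column with one video path per row (see note above);
# MAIN_VIEW is a placeholder for your main-view camera directory name.
import pandas as pd

MAIN_VIEW = "camera_view_dir"
csv_path = "target_dir/dataset1_name.csv"

df = pd.read_csv(csv_path)
df = df[df["data_dir"].astype(str).str.contains(MAIN_VIEW)]
df.to_csv(csv_path, index=False)
```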
## 🚴‍♂️ Training
A. Our training strategy is outlined as follows:
- **Step 1**: Fine-tune a video generation model as the world model using the [Open-X](https://robotics-transformer-x.github.io/) dataset;
- **Step 2**: Post-train $\text{UnifoLM-WMA}$ in decision-making mode on the downstream task dataset;
<div align="left">
<img src="assets/pngs/dm_mode.png" width="600">
</div>
- **Step 3**: Post-train $\text{UnifoLM-WMA}$ in simulation mode on the downstream task dataset.
<div align="left">
<img src="assets/pngs/sim_mode.png" width="600">
</div>
**Note**: If you only require $\text{UnifoLM-WMA}$ to operate in a single mode, you may skip the corresponding step.
B. To conduct training on a single or multiple datasets, please follow the steps below:
- **Step 1**: The maximum DoF is assumed to be 16; if you have more than 16 DoFs, update ```agent_state_dim``` and ```agent_action_dim``` in [configs/train/config.yaml](https://github.com/unitreerobotics/unifolm-wma/blob/working/configs/train/config.yaml);
- **Step 2**: Set up the input shapes for each modality in [configs/train/meta.json](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/meta.json);
- **Step 3**: Configure the training parameters in [configs/train/config.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/config.yaml). For the ```pretrained_checkpoint```, we recommend the $\text{UnifoLM-WMA-0}_{Base}$ checkpoint fine-tuned on the [Open-X](https://robotics-transformer-x.github.io/) dataset (a quick weight-sum check is sketched after Step 5);
```yaml
model:
  pretrained_checkpoint: /path/to/pretrained/checkpoint
  ...
  decision_making_only: True  # Train the world model only in decision-making mode. If False, jointly train it in both decision-making and simulation modes.
  ...
data:
  ...
  train:
    ...
    data_dir: /path/to/training/dataset/directory
    dataset_and_weights:  # List the name of each dataset below and make sure the weights sum to 1.0
      dataset1_name: 0.2
      dataset2_name: 0.2
      dataset3_name: 0.2
      dataset4_name: 0.2
      dataset5_name: 0.2
```
- **Step 4**: Set up the ```experiment_name``` and ```save_root``` variables in [scripts/train.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/train.sh);
- **Step 5**: Launch the training with the command:
```
bash scripts/train.sh
```
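Since training expects the dataset weights to form a distribution, a quick pre-flight check such as the sketch below can catch typos before a long run. The key nesting mirrors the config excerpt in Step 3 and is an assumption; verify it against your actual file:
```python
# Minimal sketch: check that dataset_and_weights sums to 1.0.
# The key path mirrors the config excerpt above and is an assumption.
import math
import yaml

with open("configs/train/config.yaml") as f:
    cfg = yaml.safe_load(f)

weights = cfg["data"]["train"]["dataset_and_weights"]
total = sum(weights.values())
assert math.isclose(total, 1.0, rel_tol=1e-6), f"weights sum to {total}, not 1.0"
print("dataset weights OK:", weights)
```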
## 🌏 Inference under Interactive Simulation Mode
To run the world model in an interactive simulation mode, follow these steps:
- **Step 1**: (skip this step if you only want to test with the examples we provide) Prepare your own prompts following the format used in [examples/world_model_interaction_prompts](https://github.com/unitreerobotics/unitree-world-model/tree/main/examples/world_model_interaction_prompts) (a quick layout check is sketched after Step 3):
```
world_model_interaction_prompts/
├── images
│ ├── dataset1_name
│ │ ├── 0.png # Image prompt
│ │ └── ...
│ └── ...
├── transitions
│ ├── dataset1_name
│ │ ├── meta_data # Used for normalization
│   │   ├── 0.h5          # Robot state and action data; in interaction mode,
│ │ │ # only used to retrieve the robot state corresponding
│ │ │ # to the image prompt
│ │ └── ...
│ └── ...
├── dataset1_name.csv     # File for loading image prompts, text instructions, and corresponding robot states
└── ...
```
- **Step 2**: Specify the correct paths for ```pretrained_checkpoint``` (e.g., $\text{UnifoLM-WMA-0}_{Dual}$) and ```data_dir``` in [configs/inference/world_model_interaction.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/inference/world_model_interaction.yaml);
- **Step 3**: Set the paths for ```checkpoint```, ```res_dir```, and ```prompt_dir``` in [scripts/run_world_model_interaction.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/run_world_model_interaction.sh), and specify all the dataset names in ```datasets=(...)```. Then launch the inference with the command:
```
bash scripts/run_world_model_interaction.sh
```
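Before launching, it may be worth sanity-checking that a custom prompt directory matches the layout shown in Step 1. The sketch below is a minimal, hypothetical check derived from that tree; adapt the root path to your own ```prompt_dir```:
```python
# Minimal sketch: verify a prompt directory matches the layout in Step 1.
# The root path is an example; point it at your own prompt_dir.
from pathlib import Path

root = Path("examples/world_model_interaction_prompts")
for csv_file in root.glob("*.csv"):
    name = csv_file.stem
    assert (root / "images" / name).is_dir(), f"missing images/{name}/"
    assert (root / "transitions" / name / "meta_data").exists(), f"missing transitions/{name}/meta_data"
    print(f"{name}: prompt layout OK")
```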
## 🧠 Inference and Deployment under Decision-Making Mode
In this setup, inference is performed on a server, while a robot client gathers observations from the real robot and sends them to the server to query actions (an illustrative query sketch follows the client setup steps). The process unfolds through the following steps:
### Server Setup
- **Step-1**: Specify ```ckpt```, ```res_dir```, and ```datasets``` in [scripts/run_real_eval_server.sh](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/scripts/run_real_eval_server.sh);
- **Step-2**: Configure ```data_dir``` and ```dataset_and_weights``` in [configs/inference/world_model_decision_making.yaml](https://github.com/unitreerobotics/unifolm-world-model-action/blob/f12b4782652ca00452941d851b17446e4ee7124a/configs/inference/world_model_decision_making.yaml#L225);
- **Step-3**: Launch the server:
```
conda activate unifolm-wma
cd unifolm-world-model-action
bash scripts/run_real_eval_server.sh
```
### Client Setup
- **Step-1**: Follow the instructions in [unitree_deploy/README.md](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/unitree_deploy/README.md) to create the ```unitree_deploy``` conda environment, install the required packages, and launch the controllers or services on the real robot.
- **Step-2**: Open a new terminal and establish a tunnel connection from the client to the server:
```
ssh user_name@remote_server_IP -CNg -L 8000:127.0.0.1:8000
```
- **Step-3**: Run the ```unitree_deploy/robot_client.py``` script to start inference:
```
cd unitree_deploy
python scripts/robot_client.py --robot_type "g1_dex1" --action_horizon 16 --exe_steps 16 --observation_horizon 2 --language_instruction "pack black camera into box" --output_dir ./results --control_freq 15
```
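For reference, each client-server round trip is an observations-in, actions-out query over the tunnelled port. The sketch below is illustrative only: the ```/act``` endpoint, payload fields, and response key are hypothetical placeholders for exposition, not the actual protocol implemented by ```robot_client.py```:
```python
# Hypothetical sketch of one query to the action server; the endpoint,
# payload fields, and response key are assumed, not taken from robot_client.py.
import requests

SERVER = "http://127.0.0.1:8000"  # reached through the SSH tunnel above

def query_actions(observation: dict) -> list:
    """Send one observation bundle and return the predicted action chunk."""
    resp = requests.post(f"{SERVER}/act", json=observation, timeout=10.0)
    resp.raise_for_status()
    return resp.json()["actions"]  # assumed response schema

obs = {
    "images": ["<base64 main-view frame>"] * 2,  # observation_horizon = 2
    "state": [0.0] * 16,                         # placeholder joint state
    "instruction": "pack black camera into box",
}
actions = query_actions(obs)  # expect action_horizon = 16 steps
```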
## 📝 Codebase Architecture
Here's a high-level overview of the project's code structure and core components:
```
unitree-world-model/
├── assets # Media assets such as GIFs, images, and demo videos
├── configs # Configuration files for training and inference
│ ├── inference
│ └── train
├── examples # Example inputs and prompts for running inference
├── external # External packages
├── prepare_data # Scripts for dataset preprocessing and format conversion
├── scripts # Main scripts for training, evaluation, and deployment
├── src
│ ├──unitree_worldmodel # Core Python package for the Unitree world model
│ │ ├── data # Dataset loading, transformations, and dataloaders
│ │ ├── models # Model architectures and backbone definitions
│ │ ├── modules # Custom model modules and components
│ │ └── utils # Utility functions and common helpers
└── unitree_deploy # Deployment code
```
## 🙏 Acknowledgement
Much of the code is inherited from [DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter), [Diffusion Policy](https://github.com/real-stanford/diffusion_policy), [ACT](https://github.com/MarkFzp/act-plus-plus), and [HPT](https://github.com/liruiw/HPT).
## 📝 Citation
```
@misc{unifolm-wma-0,
author = {Unitree},
title = {UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family},
year = {2025},
}
```

216
README_cn.md Normal file

@@ -0,0 +1,216 @@
# UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family
<p style="font-size: 1.2em;">
<a href="https://unigen-x.github.io/unifolm-world-model-action.github.io"><strong>项目主页</strong></a> |
<a href="https://huggingface.co/collections/unitreerobotics/unifolm-wma-0-68ca23027310c0ca0f34959c"><strong>开源模型</strong></a> |
<a href="https://huggingface.co/unitreerobotics/datasets"><strong>开源数据</strong></a>
</p>
<div align="center">
<p align="right">
<a href="README.md"> 🌎English </a> | <span> 🇨🇳中文 </span>
</p>
</div>
**UnifoLM-WMA-0** is Unitree's open-source world-model-action architecture spanning multiple types of robotic embodiments, designed specifically for general-purpose robot learning. Its core component is a world model that understands the physical laws of robot-environment interaction. The world model provides two key functions: (1) **Simulation Engine**: it runs as an interactive simulator to provide synthetic data for robot learning; (2) **Policy Enhancement**: it connects with an action head and further optimizes decision-making performance by predicting the future course of interaction with the physical world. Real-robot deployment results are shown below; the small window in the top-right corner is the world model's prediction of future environment changes, which assists control-command generation.
## 🦾 Real-Robot Demonstrations
| <img src="assets/gifs/real_z1_stackbox.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> | <img src="assets/gifs/real_dual_stackbox.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> |
|:---:|:---:|
| <img src="assets/gifs/real_cleanup_pencils.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> | <img src="assets/gifs/real_g1_pack_camera.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> |
**Note: the small window in the top-right corner shows the world model's prediction of future action videos.**
## News
* Sep 22, 2025: 🚀 We released the deployment code for real-robot experiments with Unitree robots.
* Sep 15, 2025: 🚀 We released the training and inference code of **UnifoLM-WMA-0**, along with the corresponding model weights.
## 📑 Open-Source Plan
- [x] Training code
- [x] Inference code
- [x] Model checkpoints
- [x] Real-robot deployment code
## ⚙️ Installation
```
conda create -n unifolm-wma python==3.10.18
conda activate unifolm-wma
conda install pinocchio=3.2.0 -c conda-forge -y
conda install ffmpeg=7.1.1 -c conda-forge
git clone --recurse-submodules https://github.com/unitreerobotics/unifolm-world-model-action.git
# If you already downloaded the repo:
cd unifolm-world-model-action
git submodule update --init --recursive
pip install -e .
cd external/dlimp
pip install -e .
```
## 🧰 Model Checkpoints
| Model | Description | Link |
|---------|-------|------|
|$\text{UnifoLM-WMA-0}_{Base}$| Model fine-tuned on the [Open-X](https://robotics-transformer-x.github.io/) dataset. | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Base)|
|$\text{UnifoLM-WMA-0}_{Dual}$| Model jointly fine-tuned in both decision-making and simulation modes on five [Unitree open-source datasets](https://huggingface.co/collections/unitreerobotics/g1-dex1-datasets-68bae98bf0a26d617f9983ab). | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Dual)|
## 🛢️ Dataset
In our experiments, we trained and tested on the following five open-source datasets:
| Dataset | Robot | Link |
|---------|-------|------|
|Z1_StackBox| [Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_StackBox_Dataset/tree/v2.1)|
|Z1_DualArm_StackBox|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset/tree/v2.1)|
|Z1_DualArm_StackBox_V2|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset_V2/tree/v2.1)|
|Z1_DualArm_Cleanup_Pencils|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_CleanupPencils_Dataset/tree/v2.1)|
|G1_Pack_Camera|[Unitree G1](https://www.unitree.com/g1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/G1_Dex1_MountCameraRedGripper_Dataset/tree/v2.1)|
To train on a custom dataset, first make sure the data follows the [Huggingface LeRobot V2.1](https://github.com/huggingface/lerobot) dataset format. Assume the downloaded data directory structure is as follows:
```
source_dir/
├── dataset1_name
├── dataset2_name
├── dataset3_name
└── ...
```
Then run the following command to convert the format:
```
cd prepare_data
python prepare_training_data.py \
--source_dir /path/to/your/source_dir \
--target_dir /path/to/save/the/converted/data/directory \
--dataset_name "dataset1_name" \
    --robot_name "a tag of the robot in the dataset" # e.g., Unitree Z1 Robot Arm or Unitree G1 Robot with Gripper.
```
The converted data structure is as follows (note: model training only supports main-view camera input; if the data contains wrist views, delete the corresponding video paths from the ```data_dir``` column of the CSV file):
```
target_dir/
├── videos
│ ├──dataset1_name
│ │ ├──camera_view_dir
│ │ ├── 0.mp4
│ │ ├── 1.mp4
│ │ └── ...
│ └── ...
├── transitions
│ ├── dataset1_name
│ │ ├── meta_data
│ │ ├── 0.h5
│ │ ├── 1.h5
│ │ └── ...
│ └── ...
└── dataset1_name.csv
```
## 🚴‍♂️ Training
A. Our training strategy is outlined as follows:
- **Step 1**: Fine-tune a video generation model on the [Open-X](https://robotics-transformer-x.github.io/) dataset so that it serves as the world model;
- **Step 2**: Post-train $\text{UnifoLM-WMA}$ in decision-making mode on the downstream task dataset;
<div align="left">
<img src="assets/pngs/dm_mode.png" width="600">
</div>
- **Step 3**: Post-train $\text{UnifoLM-WMA}$ in simulation mode on the downstream task dataset.
<div align="left">
<img src="assets/pngs/sim_mode.png" width="600">
</div>
**Note**: If you only need $\text{UnifoLM-WMA}$ to run in a single mode, you can skip the corresponding step.
B. To train on one or more datasets, follow the steps below:
- **Step 1**: The default maximum DoF is 16; if you need more, update the values of ```agent_state_dim``` and ```agent_action_dim``` in [configs/train/config.yaml](https://github.com/unitreerobotics/unifolm-wma/blob/working/configs/train/config.yaml);
- **Step 2**: Set the input dimensions for each modality in [configs/train/meta.json](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/meta.json);
- **Step 3**: Configure the training parameters and paths in [configs/train/config.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/config.yaml). For the pretrained model, we recommend $\text{UnifoLM-WMA-0}_{Base}$, which was fine-tuned on the [Open-X](https://robotics-transformer-x.github.io/) dataset;
```yaml
model:
  pretrained_checkpoint: /path/to/pretrained/checkpoint
  ...
  decision_making_only: True  # Train only the decision-making mode of the world model. If False, jointly train the decision-making and simulation modes.
  ...
data:
  ...
  train:
    ...
    data_dir: /path/to/training/dataset/directory
    dataset_and_weights:  # List every dataset name and its weight; make sure the weights sum to 1.0
      dataset1_name: 0.2
      dataset2_name: 0.2
      dataset3_name: 0.2
      dataset4_name: 0.2
      dataset5_name: 0.2
```
- **Step 4**: Set the ```experiment_name``` and ```save_root``` variables in [scripts/train.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/train.sh);
- **Step 5**: Launch training with the following command:
```
bash scripts/train.sh
```
## 🌏 Inference under Interactive Simulation Mode
To run the world model in interactive mode, follow these steps:
- **Step 1**: (skip this step if you only want to test with the provided examples) Prepare your own prompt directory following the format of [examples/world_model_interaction_prompts](https://github.com/unitreerobotics/unitree-world-model/tree/main/examples/world_model_interaction_prompts):
```
world_model_interaction_prompts/
├── images
│ ├── dataset1_name
│   │   ├── 0.png         # Image prompt
│ │ └── ...
│ └── ...
├── transitions
│ ├── dataset1_name
│   │   ├── meta_data     # Used for normalization
│   │   ├── 0.h5          # Robot state and action data; in interaction mode, only used to retrieve the robot state corresponding to the image prompt
│ │ └── ...
│ └── ...
├── dataset1_name.csv     # File for loading the corresponding image prompts, text instructions, and robot states
└── ...
```
- **Step 2**: Specify the correct paths for ```pretrained_checkpoint``` (e.g., $\text{UnifoLM-WMA-0}_{Dual}$) and ```data_dir``` in [configs/inference/world_model_interaction.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/inference/world_model_interaction.yaml);
- **Step 3**: Set the correct paths for ```checkpoint```, ```res_dir```, and ```prompt_dir``` in [scripts/run_world_model_interaction.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/run_world_model_interaction.sh), list the datasets to test in ```datasets=(...)```, and then launch inference with the command:
```
bash scripts/run_world_model_interaction.sh
```
## 🧠 Inference and Deployment under Decision-Making Mode
In our system, inference runs on the server side; the robot client collects observations from the real robot and sends them to the server for video and action inference. The whole process is realized through the following steps:
### Server Setup
- **Step 1**: Specify ```ckpt```, ```res_dir```, and ```datasets``` in [scripts/run_real_eval_server.sh](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/scripts/run_real_eval_server.sh);
- **Step 2**: Configure ```data_dir``` and ```dataset_and_weights``` in [configs/inference/world_model_decision_making.yaml](https://github.com/unitreerobotics/unifolm-world-model-action/blob/f12b4782652ca00452941d851b17446e4ee7124a/configs/inference/world_model_decision_making.yaml#L225);
- **Step 3**: Start the server:
```
conda activate unifolm-wma
cd unifolm-world-model-action
bash scripts/run_real_eval_server.sh
```
### Client Setup
- **Step 1**: Follow [unitree_deploy/README.md](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/unitree_deploy/README.md) to create the ```unitree_deploy``` conda environment, install the required dependencies, and launch the controllers or services on the real robot;
- **Step 2**: Open a new terminal and establish a tunnel connection from the client to the server:
```
ssh user_name@remote_server_IP -CNg -L 8000:127.0.0.1:8000
```
- **Step 3**: Run the ```unitree_deploy/robot_client.py``` script to start inference:
```
cd unitree_deploy
python scripts/robot_client.py --robot_type "g1_dex1" --action_horizon 16 --exe_steps 16 --observation_horizon 2 --language_instruction "pack black camera into box" --output_dir ./results --control_freq 15
```
## 📝 Codebase Architecture
Below is an overview of the project's code structure and core components:
```
unitree-world-model/
├── assets              # Media assets such as GIFs, images, and demo videos
├── configs             # Configuration files
│   ├── inference
│   └── train
├── examples            # Example data
├── external            # External packages
├── prepare_data        # Data processing
├── scripts             # Main scripts
├── src
│   ├── unitree_worldmodel    # Core library
│   │   ├── data              # Data loading
│   │   ├── models            # Model architectures
│   │   ├── modules           # Custom modules
│   │   └── utils             # Utility functions
```
## 🙏 Acknowledgements
This project builds on the following excellent open-source projects, to which we are grateful: [DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter), [Diffusion Policy](https://github.com/real-stanford/diffusion_policy), [ACT](https://github.com/MarkFzp/act-plus-plus) and [HPT](https://github.com/liruiw/HPT).

BIN
assets/pngs/dm_mode.png Normal file

Binary file not shown. Size: 1.2 MiB

BIN
assets/pngs/sim_mode.png Normal file

Binary file not shown. Size: 1.3 MiB

54
ckpts/.gitattributes vendored Normal file

@@ -0,0 +1,54 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bin.* filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zstandard filter=lfs diff=lfs merge=lfs -text
*.tfevents* filter=lfs diff=lfs merge=lfs -text
*.db* filter=lfs diff=lfs merge=lfs -text
*.ark* filter=lfs diff=lfs merge=lfs -text
**/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
**/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
**/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
*.gguf* filter=lfs diff=lfs merge=lfs -text
*.ggml filter=lfs diff=lfs merge=lfs -text
*.llamafile* filter=lfs diff=lfs merge=lfs -text
*.pt2 filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
assets/real_cleanup_pencils.gif filter=lfs diff=lfs merge=lfs -text
assets/world_model_interaction.gif filter=lfs diff=lfs merge=lfs -text
assets/real_dual_stackbox.gif filter=lfs diff=lfs merge=lfs -text
assets/real_g1_pack_camera.gif filter=lfs diff=lfs merge=lfs -text
assets/real_z1_stackbox.gif filter=lfs diff=lfs merge=lfs -text
unifolm_wma_dual.ckpt filter=lfs diff=lfs merge=lfs -text

439
ckpts/LICENSE Normal file

@@ -0,0 +1,439 @@
(The content of ckpts/LICENSE is identical to the top-level LICENSE above: Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International.)
For the avoidance of doubt, this Section 6(b) does not affect any
right the Licensor may have to seek remedies for Your violations
of this Public License.
c. For the avoidance of doubt, the Licensor may also offer the
Licensed Material under separate terms or conditions or stop
distributing the Licensed Material at any time; however, doing so
will not terminate this Public License.
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
License.
Section 7 -- Other Terms and Conditions.
a. The Licensor shall not be bound by any additional or different
terms or conditions communicated by You unless expressly agreed.
b. Any arrangements, understandings, or agreements regarding the
Licensed Material not stated herein are separate from and
independent of the terms and conditions of this Public License.
Section 8 -- Interpretation.
a. For the avoidance of doubt, this Public License does not, and
shall not be interpreted to, reduce, limit, restrict, or impose
conditions on any use of the Licensed Material that could lawfully
be made without permission under this Public License.
b. To the extent possible, if any provision of this Public License is
deemed unenforceable, it shall be automatically reformed to the
minimum extent necessary to make it enforceable. If the provision
cannot be reformed, it shall be severed from this Public License
without affecting the enforceability of the remaining terms and
conditions.
c. No term or condition of this Public License will be waived and no
failure to comply consented to unless expressly agreed to by the
Licensor.
d. Nothing in this Public License constitutes or may be interpreted
as a limitation upon, or waiver of, any privileges and immunities
that apply to the Licensor or You, including from the legal
processes of any jurisdiction or authority.
=======================================================================
Creative Commons is not a party to its public
licenses. Notwithstanding, Creative Commons may elect to apply one of
its public licenses to material it publishes and in those instances
will be considered the “Licensor.” The text of the Creative Commons
public licenses is dedicated to the public domain under the CC0 Public
Domain Dedication. Except for the limited purpose of indicating that
material is shared under a Creative Commons public license or as
otherwise permitted by the Creative Commons policies published at
creativecommons.org/policies, Creative Commons does not authorize the
use of the trademark "Creative Commons" or any other trademark or logo
of Creative Commons without its prior written consent including,
without limitation, in connection with any unauthorized modifications
to any of its public licenses or any other arrangements,
understandings, or agreements concerning use of licensed material. For
the avoidance of doubt, this paragraph does not form part of the
public licenses.
Creative Commons may be contacted at creativecommons.org.

38
ckpts/README.md Normal file
View File

@@ -0,0 +1,38 @@
---
tags:
- robotics
---
# UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family
<p style="font-size: 1.2em;">
<a href="https://unigen-x.github.io/unifolm-world-model-action.github.io"><strong>Project Page</strong></a> |
<a href="https://github.com/unitreerobotics/unifolm-world-model-action"><strong>Code</strong></a> |
<a href="https://huggingface.co/unitreerobotics/datasets"><strong>Dataset</strong></a>
</p>
<div align="center">
<div align="justify">
<b>UnifoLM-WMA-0</b> is Unitree's first open-source world-model-action architecture spanning multiple types of robotic embodiments, designed specifically for general-purpose robot learning. Its core component is a world model capable of understanding the physical interactions between robots and their environments. This world model provides two key functions: (a) <b>Simulation Engine</b>: operates as an interactive simulator to generate synthetic data for robot learning; (b) <b>Policy Enhancement</b>: connects to an action head and, by predicting future interactions with the world model, further optimizes decision-making performance.
</div>
</div>
## 🦾 Real Robot Deployment
| <img src="assets/real_z1_stackbox.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> | <img src="assets/real_dual_stackbox.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> |
|:---:|:---:|
| <img src="assets/real_cleanup_pencils.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> | <img src="assets/real_g1_pack_camera.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> |
**Note: the top-right window shows the world model's prediction of future environmental changes.**
## License
The model is released under the CC BY-NC-SA 4.0 license as found in the [LICENSE](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0/blob/main/LICENSE). You are responsible for ensuring that your use of Unitree AI Models complies with all applicable laws.
## Model Architecture
![Demo](assets/world_model_interaction.gif)
## Citation
```
@misc{unifolm-wma-0,
author = {Unitree},
title = {UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family},
year = {2025},
}
```

View File

@@ -0,0 +1,213 @@
model:
target: unifolm_wma.models.ddpms.LatentVisualDiffusion
params:
rescale_betas_zero_snr: True
parameterization: "v"
linear_start: 0.00085
linear_end: 0.012
num_timesteps_cond: 1
timesteps: 1000
first_stage_key: video
cond_stage_key: instruction
cond_stage_trainable: False
conditioning_key: hybrid
image_size: [40, 64]
channels: 4
scale_by_std: False
scale_factor: 0.18215
use_ema: False
uncond_type: 'empty_seq'
use_dynamic_rescale: true
base_scale: 0.7
fps_condition_type: 'fps'
perframe_ae: True
freeze_embedder: True
n_obs_steps_imagen: 1
n_obs_steps_acting: 1
agent_state_dim: 16
agent_action_dim: 16
###################### DP Related
input_pertub: 0.1
lr_scheduler: cosine
lr_warmup_steps: 2000
num_epochs: 30000
gradient_accumulate_every: 1
use_scheduler: True
dp_use_ema: True
dp_ema_config:
target: unifolm_wma.models.diffusion_head.ema_model.EMAModel
params:
update_after_step: 0
inv_gamma: 1.0
power: 0.75
min_value: 0.0
max_value: 0.9999
noise_scheduler_config:
target: diffusers.DDIMScheduler
params:
num_train_timesteps: 1000
beta_start: 0.0001
beta_end: 0.02
beta_schedule: squaredcos_cap_v2
clip_sample: True
set_alpha_to_one: True
steps_offset: 0
prediction_type: epsilon
dp_optimizer_config:
target: torch.optim.AdamW
params:
lr: 1.0e-4
betas: [0.95, 0.999]
eps: 1.0e-8
weight_decay: 1.0e-6
wma_config:
target: unifolm_wma.modules.networks.wma_model.WMAModel
params:
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions:
- 4
- 2
- 1
num_res_blocks: 2
channel_mult:
- 1
- 2
- 4
- 4
dropout: 0.1
num_head_channels: 64
transformer_depth: 1
context_dim: 1024
use_linear: true
use_checkpoint: True
temporal_conv: True
temporal_attention: True
temporal_selfatt_only: True
use_relative_position: False
use_causal_attention: False
temporal_length: 16
addition_attention: True
image_cross_attention: True
default_fs: 10
fs_condition: True
cross_attention_scale_learnable: False
n_obs_steps: ${model.params.n_obs_steps_imagen}
num_stem_token: 16
base_model_gen_only: True
unet_head_config:
target: unifolm_wma.models.diffusion_head.conditional_unet1d.ConditionalUnet1D
params:
input_dim: ${model.params.agent_action_dim}
n_obs_steps: ${model.params.n_obs_steps_acting}
diffusion_step_embed_dim: 128
down_dims: [256, 512, 1024, 2048]
kernel_size: 5
n_groups: 8
cond_predict_scale: True
num_head_channels: ${model.params.wma_config.params.num_head_channels}
horizon: ${model.params.wma_config.params.temporal_length}
use_linear_attn: ${model.params.wma_config.params.use_linear}
use_linear_act_proj: True
act_proj_dim: 32
cond_cross_attention: False
context_dims: []
image_size: ${model.params.image_size}
imagen_cond_gradient: True
last_frame_only: False
use_imagen_mid_only: False
use_z_only: False
obs_encoder_config:
target: unifolm_wma.models.diffusion_head.vision.multi_image_obs_encoder.MultiImageObsEncoder
params:
rgb_model_config:
target: unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet
params:
name: resnet18
weights: null
resize_shape: null
crop_shape: null
random_crop: False
use_group_norm: True
share_rgb_model: False
imagenet_norm: True
use_spatial_softmax: True
spatial_softmax_kp: 128
###################### Action Tokenization
stem_process_config:
target: unifolm_wma.modules.encoders.condition.SATokenProjector
params:
dim: 1024
depth: 1
dim_head: 64
heads: 16
num_queries: ${model.params.wma_config.params.num_stem_token}
output_dim: 1024
ff_mult: 4
chunk_size: ${model.params.wma_config.params.temporal_length}
first_stage_config:
target: unifolm_wma.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"
img_cond_stage_config:
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
params:
freeze: true
image_proj_stage_config:
target: unifolm_wma.modules.encoders.resampler.Resampler
params:
dim: 1024
depth: 4
dim_head: 64
heads: 12
num_queries: 16
embedding_dim: 1280
output_dim: 1024
ff_mult: 4
video_length: ${model.params.wma_config.params.temporal_length}
normalization_config:
input_shapes:
observation.state: ${model.params.wma_config.params.unet_head_config.params.input_dim}
input_normalization_modes:
observation.state: 'min_max'
output_shapes:
action: ${model.params.wma_config.params.unet_head_config.params.input_dim}
output_normalization_modes:
action: 'min_max'
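The `${...}` entries above are OmegaConf interpolations that resolve against the same config tree (for example, `${model.params.agent_action_dim}` resolves to 16). Below is a minimal sketch of how such a YAML file is consumed, mirroring the `OmegaConf` + `instantiate_from_config` pattern used by the inference script later in this commit; the config path is a hypothetical placeholder:

```python
# Minimal sketch (config path is a placeholder): load a config and build the
# model, following the pattern used by the inference script in this commit.
from omegaconf import OmegaConf

from unifolm_wma.utils.utils import instantiate_from_config

config = OmegaConf.load("configs/inference/wma.yaml")  # hypothetical path
# Interpolations like ${model.params.wma_config.params.temporal_length}
# resolve lazily against the same config tree:
print(config.model.params.wma_config.params.temporal_length)  # -> 16
# `target` is a dotted import path; `params` are its constructor kwargs.
model = instantiate_from_config(config.model)
```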

View File

@@ -0,0 +1,240 @@
model:
target: unifolm_wma.models.ddpms.LatentVisualDiffusion
params:
rescale_betas_zero_snr: True
parameterization: "v"
linear_start: 0.00085
linear_end: 0.012
num_timesteps_cond: 1
timesteps: 1000
first_stage_key: video
cond_stage_key: instruction
cond_stage_trainable: False
conditioning_key: hybrid
image_size: [40, 64]
channels: 4
scale_by_std: False
scale_factor: 0.18215
use_ema: False
uncond_type: 'empty_seq'
use_dynamic_rescale: true
base_scale: 0.7
fps_condition_type: 'fps'
perframe_ae: True
freeze_embedder: True
n_obs_steps_imagen: 2
n_obs_steps_acting: 2
agent_state_dim: 16
agent_action_dim: 16
decision_making_only: True
###################### DP Related
input_pertub: 0.1
lr_scheduler: cosine
lr_warmup_steps: 2000
num_epochs: 30000
gradient_accumulate_every: 1
use_scheduler: True
dp_use_ema: True
dp_ema_config:
target: unifolm_wma.models.diffusion_head.ema_model.EMAModel
params:
update_after_step: 0
inv_gamma: 1.0
power: 0.75
min_value: 0.0
max_value: 0.9999
noise_scheduler_config:
target: diffusers.DDIMScheduler
params:
num_train_timesteps: 1000
beta_start: 0.0001
beta_end: 0.02
beta_schedule: squaredcos_cap_v2
clip_sample: True
set_alpha_to_one: True
steps_offset: 0
prediction_type: epsilon
dp_optimizer_config:
target: torch.optim.AdamW
params:
lr: 1.0e-4
betas: [0.95, 0.999]
eps: 1.0e-8
weight_decay: 1.0e-6
wma_config:
target: unifolm_wma.modules.networks.wma_model.WMAModel
params:
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions:
- 4
- 2
- 1
num_res_blocks: 2
channel_mult:
- 1
- 2
- 4
- 4
dropout: 0.1
num_head_channels: 64
transformer_depth: 1
context_dim: 1024
use_linear: true
use_checkpoint: True
temporal_conv: True
temporal_attention: True
temporal_selfatt_only: True
use_relative_position: False
use_causal_attention: False
temporal_length: 16
addition_attention: True
image_cross_attention: True
default_fs: 10
fs_condition: True
cross_attention_scale_learnable: False
n_obs_steps: ${model.params.n_obs_steps_imagen}
num_stem_token: 16
base_model_gen_only: False
unet_head_config:
target: unifolm_wma.models.diffusion_head.conditional_unet1d.ConditionalUnet1D
params:
input_dim: ${model.params.agent_action_dim}
n_obs_steps: ${model.params.n_obs_steps_acting}
diffusion_step_embed_dim: 128
down_dims: [256, 512, 1024, 2048]
kernel_size: 5
n_groups: 8
cond_predict_scale: True
num_head_channels: ${model.params.wma_config.params.num_head_channels}
horizon: ${model.params.wma_config.params.temporal_length}
use_linear_attn: ${model.params.wma_config.params.use_linear}
use_linear_act_proj: True
act_proj_dim: 32
cond_cross_attention: False
context_dims: []
image_size: ${model.params.image_size}
imagen_cond_gradient: True
last_frame_only: False
use_imagen_mid_only: False
use_z_only: False
obs_encoder_config:
target: unifolm_wma.models.diffusion_head.vision.multi_image_obs_encoder.MultiImageObsEncoder
params:
rgb_model_config:
target: unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet
params:
name: resnet18
weights: null
resize_shape: null
crop_shape: null
random_crop: False
use_group_norm: True
share_rgb_model: False
imagenet_norm: True
use_spatial_softmax: True
spatial_softmax_kp: 128
###################### Action Tokenization
stem_process_config:
target: unifolm_wma.modules.encoders.condition.SATokenProjector
params:
dim: 1024
depth: 1
dim_head: 64
heads: 16
num_queries: ${model.params.wma_config.params.num_stem_token}
output_dim: 1024
ff_mult: 4
chunk_size: ${model.params.wma_config.params.temporal_length}
first_stage_config:
target: unifolm_wma.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"
img_cond_stage_config:
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
params:
freeze: true
image_proj_stage_config:
target: unifolm_wma.modules.encoders.resampler.Resampler
params:
dim: 1024
depth: 4
dim_head: 64
heads: 12
num_queries: 16
embedding_dim: 1280
output_dim: 1024
ff_mult: 4
video_length: ${model.params.wma_config.params.temporal_length}
normalization_config:
input_shapes:
observation.state: ${model.params.wma_config.params.unet_head_config.params.input_dim}
input_normalization_modes:
observation.state: 'min_max'
output_shapes:
action: ${model.params.wma_config.params.unet_head_config.params.input_dim}
output_normalization_modes:
action: 'min_max'
data:
target: unifolm_wma.utils.data.DataModuleFromConfig
params:
batch_size: 6
num_workers: 12
wrap: False
test:
target: unifolm_wma.data.wma_data.WMAData
params:
data_dir: '/path/to/the/dataset/directory/that/contains/the/meta/folder/of/the/testing/case/under/a/transitions/folder' # e.g., /path/to/unifolm-world-model-action/examples/world_model_interaction_prompts
video_length: ${model.params.wma_config.params.temporal_length}
frame_stride: 2
load_raw_resolution: True
resolution: [320, 512]
spatial_transform: resize_center_crop
crop_resolution: [320, 512]
random_fs: False
cond_robot_label_prob: 0.0
normalization_mode: 'min_max'
individual_normalization: True
n_obs_steps: ${model.params.n_obs_steps_imagen}
max_action_dim: ${model.params.agent_action_dim}
max_state_dim: ${model.params.agent_state_dim}
dataset_and_weights:
unitree_g1_pack_camera: 1.0

View File

@@ -0,0 +1,244 @@
model:
target: unifolm_wma.models.ddpms.LatentVisualDiffusion
params:
rescale_betas_zero_snr: True
parameterization: "v"
linear_start: 0.00085
linear_end: 0.012
num_timesteps_cond: 1
timesteps: 1000
first_stage_key: video
cond_stage_key: instruction
cond_stage_trainable: False
conditioning_key: hybrid
image_size: [40, 64]
channels: 4
scale_by_std: False
scale_factor: 0.18215
use_ema: False
uncond_type: 'empty_seq'
use_dynamic_rescale: true
base_scale: 0.7
fps_condition_type: 'fps'
perframe_ae: True
freeze_embedder: True
n_obs_steps_imagen: 2
n_obs_steps_acting: 2
agent_state_dim: 16
agent_action_dim: 16
decision_making_only: False
###################### DP Related
input_pertub: 0.1
lr_scheduler: cosine
lr_warmup_steps: 2000
num_epochs: 30000
gradient_accumulate_every: 1
use_scheduler: True
dp_use_ema: True
dp_ema_config:
target: unifolm_wma.models.diffusion_head.ema_model.EMAModel
params:
update_after_step: 0
inv_gamma: 1.0
power: 0.75
min_value: 0.0
max_value: 0.9999
noise_scheduler_config:
target: diffusers.DDIMScheduler
params:
num_train_timesteps: 1000
beta_start: 0.0001
beta_end: 0.02
beta_schedule: squaredcos_cap_v2
clip_sample: True
set_alpha_to_one: True
steps_offset: 0
prediction_type: epsilon
dp_optimizer_config:
target: torch.optim.AdamW
params:
lr: 1.0e-4
betas: [0.95, 0.999]
eps: 1.0e-8
weight_decay: 1.0e-6
wma_config:
target: unifolm_wma.modules.networks.wma_model.WMAModel
params:
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions:
- 4
- 2
- 1
num_res_blocks: 2
channel_mult:
- 1
- 2
- 4
- 4
dropout: 0.1
num_head_channels: 64
transformer_depth: 1
context_dim: 1024
use_linear: true
use_checkpoint: True
temporal_conv: True
temporal_attention: True
temporal_selfatt_only: True
use_relative_position: False
use_causal_attention: False
temporal_length: 16
addition_attention: True
image_cross_attention: True
default_fs: 10
fs_condition: True
cross_attention_scale_learnable: False
n_obs_steps: ${model.params.n_obs_steps_imagen}
num_stem_token: 16
base_model_gen_only: False
unet_head_config:
target: unifolm_wma.models.diffusion_head.conditional_unet1d.ConditionalUnet1D
params:
input_dim: ${model.params.agent_action_dim}
n_obs_steps: ${model.params.n_obs_steps_acting}
diffusion_step_embed_dim: 128
down_dims: [256, 512, 1024, 2048]
kernel_size: 5
n_groups: 8
cond_predict_scale: True
num_head_channels: ${model.params.wma_config.params.num_head_channels}
horizon: ${model.params.wma_config.params.temporal_length}
use_linear_attn: ${model.params.wma_config.params.use_linear}
use_linear_act_proj: True
act_proj_dim: 32
cond_cross_attention: False
context_dims: []
image_size: ${model.params.image_size}
imagen_cond_gradient: True
last_frame_only: False
use_imagen_mid_only: False
use_z_only: False
obs_encoder_config:
target: unifolm_wma.models.diffusion_head.vision.multi_image_obs_encoder.MultiImageObsEncoder
params:
rgb_model_config:
target: unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet
params:
name: resnet18
weights: null
resize_shape: null
crop_shape: null
random_crop: False
use_group_norm: True
share_rgb_model: False
imagenet_norm: True
use_spatial_softmax: True
spatial_softmax_kp: 128
###################### Action Tokenization
stem_process_config:
target: unifolm_wma.modules.encoders.condition.SATokenProjector
params:
dim: 1024
depth: 1
dim_head: 64
heads: 16
num_queries: ${model.params.wma_config.params.num_stem_token}
output_dim: 1024
ff_mult: 4
chunk_size: ${model.params.wma_config.params.temporal_length}
first_stage_config:
target: unifolm_wma.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"
img_cond_stage_config:
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
params:
freeze: true
image_proj_stage_config:
target: unifolm_wma.modules.encoders.resampler.Resampler
params:
dim: 1024
depth: 4
dim_head: 64
heads: 12
num_queries: 16
embedding_dim: 1280
output_dim: 1024
ff_mult: 4
video_length: ${model.params.wma_config.params.temporal_length}
normalization_config:
input_shapes:
observation.state: ${model.params.wma_config.params.unet_head_config.params.input_dim}
input_normalization_modes:
observation.state: 'min_max'
output_shapes:
action: ${model.params.wma_config.params.unet_head_config.params.input_dim}
output_normalization_modes:
action: 'min_max'
data:
target: unifolm_wma.utils.data.DataModuleFromConfig
params:
batch_size: 6
num_workers: 12
wrap: False
test:
target: unifolm_wma.data.wma_data.WMAData
params:
data_dir: '/home/dyz/unifolm-world-model-action/examples/world_model_interaction_prompts'
video_length: ${model.params.wma_config.params.temporal_length}
frame_stride: 2
load_raw_resolution: True
resolution: [320, 512]
spatial_transform: resize_center_crop
crop_resolution: [320, 512]
random_fs: False
cond_robot_label_prob: 0.0
normalization_mode: 'min_max'
individual_normalization: True
n_obs_steps: ${model.params.n_obs_steps_imagen}
max_action_dim: ${model.params.agent_action_dim}
max_state_dim: ${model.params.agent_state_dim}
dataset_and_weights:
unitree_z1_stackbox: 0.2
unitree_z1_dual_arm_stackbox: 0.2
unitree_z1_dual_arm_stackbox_v2: 0.2
unitree_z1_dual_arm_cleanup_pencils: 0.2
unitree_g1_pack_camera: 0.2
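The `dataset_and_weights` mapping assigns each sub-dataset a sampling weight. As a hedged illustration only (the actual mixing logic lives in `unifolm_wma.data.wma_data.WMAData` and may differ), per-dataset weights are typically consumed like this:

```python
# Hedged sketch: one common way per-dataset weights are used. This is an
# assumption for illustration; WMAData's real sampling logic may differ.
import random

dataset_and_weights = {
    "unitree_z1_stackbox": 0.2,
    "unitree_z1_dual_arm_stackbox": 0.2,
    "unitree_z1_dual_arm_stackbox_v2": 0.2,
    "unitree_z1_dual_arm_cleanup_pencils": 0.2,
    "unitree_g1_pack_camera": 0.2,
}
names, weights = zip(*dataset_and_weights.items())
# Each draw decides which sub-dataset provides the next training sample.
print(random.choices(names, weights=weights, k=4))
```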

287
configs/train/config.yaml Normal file
View File

@@ -0,0 +1,287 @@
model:
pretrained_checkpoint: /path/to/pretrained/checkpoint
base_learning_rate: 1.0e-05
scale_lr: False
target: unifolm_wma.models.ddpms.LatentVisualDiffusion
params:
rescale_betas_zero_snr: True
parameterization: "v"
linear_start: 0.00085
linear_end: 0.012
num_timesteps_cond: 1
log_every_t: 200
timesteps: 1000
first_stage_key: video
cond_stage_key: instruction
cond_stage_trainable: False
image_proj_model_trainable: True
conditioning_key: hybrid
image_size: [40, 64]
channels: 4
scale_by_std: False
scale_factor: 0.18215
use_ema: False
uncond_prob: 0.05
uncond_type: 'empty_seq'
rand_cond_frame: false
use_dynamic_rescale: true
base_scale: 0.7
fps_condition_type: 'fps'
perframe_ae: True
freeze_embedder: True
n_obs_steps_imagen: 2
n_obs_steps_acting: 2
agent_state_dim: 16
agent_action_dim: 16
decision_making_only: True
###################### DP Related
input_pertub: 0.1
lr_scheduler: cosine
lr_warmup_steps: 2000
num_epochs: 60000
gradient_accumulate_every: 1
use_scheduler: True
dp_use_ema: True
dp_ema_config:
target: unifolm_wma.models.diffusion_head.ema_model.EMAModel
params:
update_after_step: 0
inv_gamma: 1.0
power: 0.75
min_value: 0.0
max_value: 0.9999
noise_scheduler_config:
target: diffusers.DDIMScheduler
params:
num_train_timesteps: 1000
beta_start: 0.0001
beta_end: 0.02
beta_schedule: squaredcos_cap_v2
clip_sample: True
set_alpha_to_one: True
steps_offset: 0
prediction_type: epsilon
dp_optimizer_config:
target: torch.optim.AdamW
params:
lr: 1.0e-4
betas: [0.95, 0.999]
eps: 1.0e-8
weight_decay: 1.0e-6
wma_config:
target: unifolm_wma.modules.networks.wma_model.WMAModel
params:
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions:
- 4
- 2
- 1
num_res_blocks: 2
channel_mult:
- 1
- 2
- 4
- 4
dropout: 0.1
num_head_channels: 64
transformer_depth: 1
context_dim: 1024
use_linear: true
use_checkpoint: True
temporal_conv: True
temporal_attention: True
temporal_selfatt_only: True
use_relative_position: False
use_causal_attention: False
temporal_length: 16
addition_attention: True
image_cross_attention: True
default_fs: 10
fs_condition: True
cross_attention_scale_learnable: False
n_obs_steps: ${model.params.n_obs_steps_imagen}
num_stem_token: 16
base_model_gen_only: False
unet_head_config:
target: unifolm_wma.models.diffusion_head.conditional_unet1d.ConditionalUnet1D
params:
input_dim: ${model.params.agent_action_dim}
n_obs_steps: ${model.params.n_obs_steps_acting}
diffusion_step_embed_dim: 128
down_dims: [256, 512, 1024, 2048]
kernel_size: 5
n_groups: 8
cond_predict_scale: True
num_head_channels: ${model.params.wma_config.params.num_head_channels}
horizon: ${model.params.wma_config.params.temporal_length}
use_linear_attn: ${model.params.wma_config.params.use_linear}
use_linear_act_proj: True
act_proj_dim: 32
cond_cross_attention: False
context_dims: []
image_size: ${model.params.image_size}
imagen_cond_gradient: True
last_frame_only: False
use_imagen_mid_only: False
use_z_only: False
obs_encoder_config:
target: unifolm_wma.models.diffusion_head.vision.multi_image_obs_encoder.MultiImageObsEncoder
params:
rgb_model_config:
target: unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet
params:
name: resnet18
weights: null
resize_shape: null
crop_shape: null
random_crop: False
use_group_norm: True
share_rgb_model: False
imagenet_norm: True
use_spatial_softmax: True
spatial_softmax_kp: 128
###################### Action Tokenization
stem_process_config:
target: unifolm_wma.modules.encoders.condition.SATokenProjector
params:
dim: 1024
depth: 1
dim_head: 64
heads: 16
num_queries: ${model.params.wma_config.params.num_stem_token}
output_dim: 1024
ff_mult: 4
chunk_size: ${model.params.wma_config.params.temporal_length}
first_stage_config:
target: unifolm_wma.models.autoencoder.AutoencoderKL
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult:
- 1
- 2
- 4
- 4
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
cond_stage_config:
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPEmbedder
params:
freeze: True
layer: "penultimate"
img_cond_stage_config:
target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
params:
freeze: true
image_proj_stage_config:
target: unifolm_wma.modules.encoders.resampler.Resampler
params:
dim: 1024
depth: 4
dim_head: 64
heads: 12
num_queries: 16
embedding_dim: 1280
output_dim: 1024
ff_mult: 4
video_length: ${model.params.wma_config.params.temporal_length}
normalization_config:
input_shapes:
observation.state: ${model.params.wma_config.params.unet_head_config.params.input_dim}
input_normalization_modes:
observation.state: 'min_max'
output_shapes:
action: ${model.params.wma_config.params.unet_head_config.params.input_dim}
output_normalization_modes:
action: 'min_max'
data:
target: unifolm_wma.utils.data.DataModuleFromConfig
params:
batch_size: 8
num_workers: 12
wrap: False
train:
target: unifolm_wma.data.wma_data.WMAData
params:
data_dir: '/path/to/training/dataset/directory'
video_length: ${model.params.wma_config.params.temporal_length}
frame_stride: 2
load_raw_resolution: True
resolution: [320, 512]
spatial_transform: resize_center_crop
crop_resolution: [320, 512]
random_fs: False
cond_robot_label_prob: 0.0
normalization_mode: 'min_max'
individual_normalization: True
n_obs_steps: ${model.params.n_obs_steps_imagen}
max_action_dim: ${model.params.agent_action_dim}
max_state_dim: ${model.params.agent_state_dim}
dataset_and_weights:
unitree_z1_stackbox: 0.2
unitree_z1_dual_arm_stackbox: 0.2
unitree_z1_dual_arm_stackbox_v2: 0.2
unitree_z1_dual_arm_cleanup_pencils: 0.2
unitree_g1_pack_camera: 0.2
lightning:
precision: 16
trainer:
benchmark: True
accumulate_grad_batches: 2
max_steps: 300000
log_every_n_steps: 50
val_check_interval: 1.0
gradient_clip_algorithm: 'norm'
gradient_clip_val: 0.5
enable_model_summary: False
callbacks:
model_checkpoint:
target: pytorch_lightning.callbacks.ModelCheckpoint
params:
every_n_train_steps: 1000
filename: "{epoch}-{step}"
save_weights_only: True
metrics_over_trainsteps_checkpoint:
target: pytorch_lightning.callbacks.ModelCheckpoint
params:
filename: '{epoch}-{step}'
save_weights_only: True
every_n_train_steps: 10000
batch_logger:
target: unifolm_wma.utils.callbacks.ImageLogger
params:
batch_frequency: 20000
to_local: False
max_images: 8
log_images_kwargs:
ddim_steps: 16
unconditional_guidance_scale: 1.0
timestep_spacing: uniform_trailing
guidance_rescale: 0.7

(Binary files not shown: 16 image assets added in this commit, 39 KiB to 258 KiB each; previews omitted.)

View File

@@ -0,0 +1,17 @@
videoid,instruction,fps,start_idx,fs,num_gen
0,wash the pan,16.0,152.0,2.0,6.0
1,pick up the blue cup and put it into the brown cup. ,5.0,0.0,2.0,4.0
2,close top drawer,3.0,0.0,2.0,1.0
3,Close the laptop.,10.0,0.0,2.0,4.0
4,destack cube,5.0,0.0,2.0,3.0
5,arrange plate and fork,20.0,40.0,2.0,4.0
6,Place the lid on the teapot,15.0,30.0,1.0,4.0
7,Pick up the green object and insert it.,10.0,0.0,2.0,4.0
8,place the burger meat in the oven,10.0,0.0,2.0,2.0
9,make a cup of coffee with the keurig machine,10.0,0.0,2.0,4.0
10,assemble one_leg,10.0,0.0,2.0,7.0
11,get the cloth and wipe up the spill under the wine glass,8.0,669.0,2.0,3.0
12,place dishes in the dish rack,10.0,0.0,2.0,4.0
13,move redbull can near green can,3.0,3.0,2.0,1.0
14,open the drawer,5.0,5.0,1.0,2.0
15,sweep the green cloth to the left side of the table,5.0,0.0,2.0,3.0
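Each row of this prompt file drives one generation job in the inference script later in this commit: the sampler is conditioned on `fps // fs`, and `num_gen` chains generations of `temporal_length` (16) frames, each re-seeded with the previous clip's last frame. A small sketch of that bookkeeping, using values from the first row:

```python
# Sketch of how one prompt row maps to generation settings (values taken from
# the first row above; 16 frames per generation matches temporal_length).
row = {"instruction": "wash the pan", "fps": 16.0, "fs": 2.0, "num_gen": 6.0}
sampler_fps = int(row["fps"]) // int(row["fs"])   # fps conditioning -> 8
total_frames = int(row["num_gen"]) * 16           # chained rollout -> 96 frames
print(row["instruction"], sampler_fps, total_frames)
```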

(Binary files not shown: 5 image assets added in this commit, 101 KiB to 286 KiB each; previews omitted.)

View File

@@ -0,0 +1,2 @@
videoid,contentUrl,duration,data_dir,instruction,dynamic_confidence,dynamic_wording,dynamic_source_category,embodiment,fps
0,x,x,unitree_g1_pack_camera,Pack black camera into box.,x,x,x,Unitree G1 Robot with Gripper,30

View File

@@ -0,0 +1,2 @@
videoid,contentUrl,duration,data_dir,instruction,dynamic_confidence,dynamic_wording,dynamic_source_category,embodiment,fps
0,x,x,unitree_z1_dual_arm_cleanup_pencils,clean up eraser and pencils,x,x,x,Unitree Z1 Robot Dual-Arm,30

View File

@@ -0,0 +1,2 @@
videoid,contentUrl,duration,data_dir,instruction,dynamic_confidence,dynamic_wording,dynamic_source_category,embodiment,fps
0,x,x,unitree_z1_dual_arm_stackbox,"Stack the blocks in the rectangular block: red at the bottom, yellow in the middle, green on top.",x,x,x,Unitree Z1 Robot Dual-Arm,30

View File

@@ -0,0 +1,2 @@
videoid,contentUrl,duration,data_dir,instruction,dynamic_confidence,dynamic_wording,dynamic_source_category,embodiment,fps
0,x,x,unitree_z1_dual_arm_stackbox_v2,"Stack the blocks in the rectangular block: red at the bottom, yellow in the middle, green on top",x,x,x,Unitree Z1 Robot Dual-Arm,30

View File

@@ -0,0 +1,2 @@
videoid,contentUrl,duration,data_dir,instruction,dynamic_confidence,dynamic_wording,dynamic_source_category,embodiment,fps
0,x,x,unitree_z1_stackbox,"Stack the blocks in the rectangular block: red at the bottom, yellow in the middle, green on top.",x,x,x,Unitree Z1 Robot Arm,30

1
external/dlimp vendored Submodule

Submodule external/dlimp added at 5edaa46915

File diff suppressed because it is too large.

View File

@@ -0,0 +1,199 @@
import json
import os
import shutil
import h5py
import argparse
import pandas as pd
import torch
import subprocess
from pathlib import Path
from safetensors.torch import save_file
from tqdm import tqdm
def flatten_dict(d, parent_key="", sep="/"):
"""Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
For example:
```
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}`
>>> print(flatten_dict(dct))
{"a/b": 1, "a/c/d": 2, "e": 3}
"""
items = []
for k, v in d.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
if isinstance(v, dict):
items.extend(flatten_dict(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
def is_av1(file_path):
try:
result = subprocess.run([
"ffprobe", "-v", "error", "-select_streams", "v:0",
"-show_entries", "stream=codec_name", "-of", "csv=p=0",
str(file_path)
],
capture_output=True,
text=True,
check=True)
return result.stdout.strip() == "av1"
except subprocess.CalledProcessError:
return False
def convert_to_h264(input_path, output_path):
subprocess.run([
"ffmpeg", "-i",
str(input_path), "-c:v", "libx264", "-preset", "slow", "-crf", "23",
"-c:a", "copy",
str(output_path)
],
check=True)
def main(args):
source_dir = Path(args.source_dir)
source_data_dir = source_dir / args.dataset_name / "data" / "chunk-000"
source_meta_dir = source_dir / args.dataset_name / "meta"
source_videos_dir = source_dir / args.dataset_name / "videos" / "chunk-000"
target_dir = Path(args.target_dir)
target_videos_dir = target_dir / "videos" / args.dataset_name
target_transitions_dir = target_dir / "transitions" / args.dataset_name
target_meta_dir = target_dir / "transitions" / args.dataset_name / "meta_data"
target_dir.mkdir(parents=True, exist_ok=True)
target_videos_dir.mkdir(parents=True, exist_ok=True)
target_transitions_dir.mkdir(parents=True, exist_ok=True)
target_meta_dir.mkdir(parents=True, exist_ok=True)
csv_file = target_dir / f"{args.dataset_name}.csv"
COLUMNS = [
'videoid', 'contentUrl', 'duration', 'data_dir', 'instruction',
'dynamic_confidence', 'dynamic_wording', 'dynamic_source_category',
'embodiment'
]
df = pd.DataFrame(columns=COLUMNS)
# Load info.json from source dir
info_json_path = source_meta_dir / "info.json"
with open(str(info_json_path), "r") as f:
info = json.load(f)
total_episodes = info['total_episodes']
# Load tasks.jsonl to get the language instruction
tasks_jsonl_path = source_meta_dir / "tasks.jsonl"
with open(str(tasks_jsonl_path), "r") as f:
tasks = [json.loads(line) for line in f]
instruction = tasks[0]['task']
source_video_views = [d for d in source_videos_dir.iterdir()]
for v_idx, source_view_dir in enumerate(source_video_views):
view_name = source_view_dir.name
target_videos_view_dir = target_videos_dir / view_name
target_videos_view_dir.mkdir(parents=True, exist_ok=True)
if v_idx == 0:
all_actions = []
all_states = []
for idx in tqdm(range(total_episodes)):
# Copy the source video into the target video dir, converting AV1 to H.264
source_video = source_view_dir / f"episode_{idx:06d}.mp4"
output_video = str(target_videos_view_dir / f"{idx}.mp4")
if is_av1(source_video):
print(f"Converting episode_{idx:06d}.mp4 to H.264...")
convert_to_h264(source_video, output_video)
else:
# Not AV1 encoded; copy the file as-is (this is what shutil is imported for)
print(f"Copying episode_{idx:06d}.mp4: not AV1 encoded.")
shutil.copy(str(source_video), output_video)
# Load parquet file
episode_parquet_file = source_data_dir / f"episode_{idx:06d}.parquet"
episode_data = pd.read_parquet(episode_parquet_file)
actions = torch.tensor(episode_data['action'].tolist())
states = torch.tensor(episode_data['observation.state'].tolist())
# Save action and state into a h5 file
if v_idx == 0:
target_h5_file = target_transitions_dir / f"{idx}.h5"
with h5py.File(str(target_h5_file), 'w') as h5f:
h5f.create_dataset('observation.state', data=states)
h5f.create_dataset('action', data=actions)
h5f.attrs['action_type'] = 'joint position'
h5f.attrs['state_type'] = 'joint position'
h5f.attrs['robot_type'] = args.robot_name
# Update df
df = pd.concat([
df,
pd.DataFrame([{
'videoid': idx,
'contentUrl': 'x',
'duration': 'x',
'data_dir': args.dataset_name + f"/{view_name}",
'instruction': instruction,
'dynamic_confidence': 'x',
'dynamic_wording': 'x',
'dynamic_source_category': 'x',
'embodiment': args.robot_name
}])
],
ignore_index=True)
# Collect action and state
if v_idx == 0:
all_actions.append(actions)
all_states.append(states)
# Create stats.safetensors
actions = torch.cat(all_actions, dim=0)
states = torch.cat(all_states, dim=0)
stats = {'action': {}, 'observation.state': {}}
stats['action']['max'] = actions.max(dim=0).values
stats['action']['min'] = actions.min(dim=0).values
stats['action']['mean'] = actions.mean(dim=0)
stats['action']['std'] = actions.std(dim=0)
stats['observation.state']['max'] = states.max(dim=0).values
stats['observation.state']['min'] = states.min(dim=0).values
stats['observation.state']['mean'] = states.mean(dim=0)
stats['observation.state']['std'] = states.std(dim=0)
flattened_stats = flatten_dict(stats)
target_stats_file = target_meta_dir / "stats.safetensors"
save_file(flattened_stats, target_stats_file)
df.to_csv(csv_file, index=False)
print(f">>> Finished create {args.dataset_name} dataset ...")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--source_dir',
action='store',
type=str,
help='Source dataset directory in the LeRobot 2.0 data format.',
required=True)
parser.add_argument('--target_dir',
action='store',
type=str,
default='./data',
help='Target directory for the newly formatted dataset.')
parser.add_argument('--dataset_name',
action='store',
type=str,
help='dataset name',
required=True)
parser.add_argument('--robot_name',
action='store',
type=str,
help='robot name',
required=True)
main(parser.parse_args())
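After a run, the flattened statistics written by `flatten_dict()` + `save_file()` can be inspected directly. A small sketch, with an output path assumed from the default `--target_dir` and `dataset_name=unitree_g1_pack_camera`:

```python
# Read back meta_data/stats.safetensors to see the flattened key layout:
#   action/max, action/min, action/mean, action/std,
#   observation.state/max, observation.state/min, ...
from safetensors.torch import load_file

stats = load_file(
    "./data/transitions/unitree_g1_pack_camera/meta_data/stats.safetensors")
for key, tensor in stats.items():
    print(key, tuple(tensor.shape))
```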

53
pyproject.toml Executable file
View File

@@ -0,0 +1,53 @@
[project]
name = "unifolm_wma"
version = "0.0.1"
description = "UnifoLM-WMA-0"
license = { text = "BSD-3-Clause" }
authors = [
{name="Unitree Embodied AI R&D Team", email="rd_xyc@unitree.com" }
]
requires-python = "==3.10.18"
dependencies = [
"decord==0.6.0",
"einops==0.8.0",
"imageio==2.35.1",
"numpy==1.24.2",
"omegaconf==2.3.0",
"opencv-python==4.10.0.84",
"pandas==2.0.0",
"pillow==9.5.0",
"pytorch-lightning==1.9.3",
"pyyaml==6.0",
"setuptools==65.6.3",
"torch==2.3.1",
"torchvision==0.18.1",
"tqdm==4.66.5",
"transformers==4.40.1",
"moviepy==1.0.3",
"av==12.3.0",
"xformers==0.0.27",
"gradio==4.39.0",
"timm==0.9.10",
"scikit-learn==1.5.1",
"open-clip-torch==2.22.0",
"kornia==0.7.3",
"diffusers==0.30.2",
"termcolor==2.4.0",
"draccus==0.11.5",
"accelerate==1.7.0",
"tensorflow-metadata==1.16.1",
"protobuf==3.20.3",
"datasets==3.6.0",
"tensorflow-graphics==2021.12.3",
"fairscale==0.4.13"
]
[build-system]
requires = ["setuptools>=65.6.3", "wheel"]
build-backend = "setuptools.build_meta"
[tool.setuptools]
package-dir = { "" = "src" }
[tool.setuptools.packages.find]
where = ["src"]

114
run_all_cases.sh Executable file
View File

@@ -0,0 +1,114 @@
#!/bin/bash
# Run all cases for all scenarios automatically
# 5 scenarios in total, 4 cases per scenario, 20 cases overall
# Set environment variables (offline mode)
export HF_HUB_OFFLINE=1
export TRANSFORMERS_OFFLINE=1
# Color definitions
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
# Define all scenarios
SCENARIOS=(
"unitree_g1_pack_camera"
"unitree_z1_dual_arm_cleanup_pencils"
"unitree_z1_dual_arm_stackbox"
"unitree_z1_dual_arm_stackbox_v2"
"unitree_z1_stackbox"
)
# Define the case indices
CASES=(1 2 3 4)
# Record the start time
START_TIME=$(date +%s)
LOG_FILE="run_all_cases_$(date +%Y%m%d_%H%M%S).log"
echo -e "${BLUE}========================================${NC}"
echo -e "${BLUE}开始执行所有场景的case${NC}"
echo -e "${BLUE}总共: ${#SCENARIOS[@]} 个场景 x ${#CASES[@]} 个case = $((${#SCENARIOS[@]} * ${#CASES[@]})) 个任务${NC}"
echo -e "${BLUE}日志文件: ${LOG_FILE}${NC}"
echo -e "${BLUE}========================================${NC}"
echo ""
# Initialize counters
TOTAL_CASES=$((${#SCENARIOS[@]} * ${#CASES[@]}))
CURRENT_CASE=0
SUCCESS_COUNT=0
FAIL_COUNT=0
# Track failed cases
declare -a FAILED_CASES
# Iterate over all scenarios
for scenario in "${SCENARIOS[@]}"; do
echo -e "${YELLOW}>>> 场景: ${scenario}${NC}"
# 遍历所有case
for case_num in "${CASES[@]}"; do
CURRENT_CASE=$((CURRENT_CASE + 1))
case_dir="${scenario}/case${case_num}"
script_path="${case_dir}/run_world_model_interaction.sh"
echo -e "${BLUE}[${CURRENT_CASE}/${TOTAL_CASES}] 执行: ${case_dir}${NC}"
# 检查脚本是否存在
if [ ! -f "${script_path}" ]; then
echo -e "${RED}错误: 脚本不存在 ${script_path}${NC}"
FAIL_COUNT=$((FAIL_COUNT + 1))
FAILED_CASES+=("${case_dir} (脚本不存在)")
continue
fi
# Run the script
echo "Start time: $(date '+%Y-%m-%d %H:%M:%S')"
if bash "${script_path}" >> "${LOG_FILE}" 2>&1; then
echo -e "${GREEN}✓ 成功: ${case_dir}${NC}"
SUCCESS_COUNT=$((SUCCESS_COUNT + 1))
else
echo -e "${RED}✗ 失败: ${case_dir}${NC}"
FAIL_COUNT=$((FAIL_COUNT + 1))
FAILED_CASES+=("${case_dir}")
fi
echo "结束时间: $(date '+%Y-%m-%d %H:%M:%S')"
echo ""
done
echo ""
done
# Compute the total elapsed time
END_TIME=$(date +%s)
DURATION=$((END_TIME - START_TIME))
HOURS=$((DURATION / 3600))
MINUTES=$(((DURATION % 3600) / 60))
SECS=$((DURATION % 60))  # use SECS to avoid clobbering bash's builtin SECONDS
# Print a summary
echo -e "${BLUE}========================================${NC}"
echo -e "${BLUE}All runs finished!${NC}"
echo -e "${BLUE}========================================${NC}"
echo -e "Total tasks: ${TOTAL_CASES}"
echo -e "${GREEN}Succeeded: ${SUCCESS_COUNT}${NC}"
echo -e "${RED}Failed: ${FAIL_COUNT}${NC}"
echo -e "Total time: ${HOURS}h ${MINUTES}m ${SECS}s"
echo -e "Detailed log: ${LOG_FILE}"
echo ""
# If any cases failed, list them
if [ ${FAIL_COUNT} -gt 0 ]; then
echo -e "${RED}失败的case列表:${NC}"
for failed_case in "${FAILED_CASES[@]}"; do
echo -e "${RED} - ${failed_case}${NC}"
done
echo ""
fi
echo -e "${BLUE}========================================${NC}"

View File

@@ -0,0 +1,541 @@
import argparse, os, glob
import datetime, time
import pandas as pd
import torch
import torchvision
import torchvision.transforms as transforms
import random
from pytorch_lightning import seed_everything
from PIL import Image
from omegaconf import OmegaConf
from tqdm import tqdm
from einops import rearrange, repeat
from collections import OrderedDict
from unifolm_wma.models.samplers.ddim import DDIMSampler
from unifolm_wma.utils.utils import instantiate_from_config
def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
"""
Get list of files in `data_dir` with extensions in `postfixes`.
Args:
data_dir (str): Directory path.
postfixes (list[str]): List of file extensions (e.g., ['csv', 'jpg']).
Returns:
list[str]: Sorted list of matched file paths.
"""
patterns = [
os.path.join(data_dir, f"*.{postfix}") for postfix in postfixes
]
file_list = []
for pattern in patterns:
file_list.extend(glob.glob(pattern))
file_list.sort()
return file_list
def load_model_checkpoint(model: torch.nn.Module,
ckpt: str) -> torch.nn.Module:
"""
Load model weights from checkpoint file.
Args:
model (torch.nn.Module): The model to load weights into.
ckpt (str): Path to the checkpoint file.
Returns:
torch.nn.Module: Model with weights loaded.
"""
state_dict = torch.load(ckpt, map_location="cpu")
if "state_dict" in list(state_dict.keys()):
state_dict = state_dict["state_dict"]
try:
loaded = model.load_state_dict(state_dict, strict=False)
print("Missing keys:")
for k in loaded.missing_keys:
print(f" {k}")
print("Unexpected keys:")
for k in loaded.unexpected_keys:
print(f" {k}")
except Exception:
# Rename keys for the 256x256 model checkpoint format
new_pl_sd = OrderedDict()
for k, v in state_dict.items():
new_pl_sd[k] = v
for k in list(new_pl_sd.keys()):
if "framestride_embed" in k:
new_key = k.replace("framestride_embed", "fps_embedding")
new_pl_sd[new_key] = new_pl_sd[k]
del new_pl_sd[k]
model.load_state_dict(new_pl_sd, strict=False)
else:
new_pl_sd = OrderedDict()
for key in state_dict['module'].keys():
new_pl_sd[key[16:]] = state_dict['module'][key]
model.load_state_dict(new_pl_sd)
print('>>> model checkpoint loaded.')
return model
def load_prompts(prompt_file: str) -> list[str]:
"""
Load prompts from a text file, one per line.
Args:
prompt_file (str): Path to the prompt file.
Returns:
list[str]: List of prompt strings.
"""
prompt_list = []
with open(prompt_file, 'r') as f:
for line in f:
l = line.strip()
if len(l) != 0:
prompt_list.append(l)
return prompt_list
def load_data_prompts(
data_dir: str,
savedir: str,
video_size: tuple[int, int] = (256, 256),
video_frames: int = 16
) -> tuple[list[str], list[torch.Tensor], list[str], list[float], list[float],
list[int]]:
"""
Load image prompts, process them into video format, and retrieve metadata.
Args:
data_dir (str): Directory containing images and CSV file.
savedir (str): Output directory to check if inference was already done.
video_size (tuple[int, int], optional): Target size of video frames.
video_frames (int, optional): Number of frames in each video.
Returns:
tuple: (filenames, video tensors, prompts, fps values, fs values, num_generations)
"""
transform = transforms.Compose([
transforms.Resize(min(video_size)),
transforms.CenterCrop(video_size),
transforms.ToTensor(),
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
# Load prompt csv
prompt_file = get_filelist(data_dir, ['csv'])
assert len(prompt_file) > 0, "Error: found NO image prompt file!"
# Load image prompts
file_list = get_filelist(data_dir, ['jpg', 'png', 'jpeg', 'JPEG', 'PNG'])
data_list = []
filename_list = []
prompt_list = []
fps_list = []
fs_list = []
num_gen_list = []
prompt_csv = pd.read_csv(prompt_file[0])
n_samples = len(file_list)
for idx in range(n_samples):
image = Image.open(file_list[idx]).convert('RGB')
image_tensor = transform(image).unsqueeze(1)
frame_tensor = repeat(image_tensor,
'c t h w -> c (repeat t) h w',
repeat=video_frames)
_, filename = os.path.split(file_list[idx])
if not is_inferenced(savedir, filename):
video_id = filename[:-4]
prompt_csv['videoid'] = prompt_csv['videoid'].map(str)
if not (prompt_csv['videoid'] == video_id).any():
continue
data_list.append(frame_tensor)
filename_list.append(filename)
ins = prompt_csv[prompt_csv['videoid'] ==
video_id]['instruction'].values[0]
prompt_list.append(ins)
fps = prompt_csv[prompt_csv['videoid'] ==
video_id]['fps'].values[0]
fps_list.append(fps)
fs = prompt_csv[prompt_csv['videoid'] == video_id]['fs'].values[0]
fs_list.append(fs)
num_gen = prompt_csv[prompt_csv['videoid'] ==
video_id]['num_gen'].values[0]
num_gen_list.append(int(num_gen))
return filename_list, data_list, prompt_list, fps_list, fs_list, num_gen_list
def is_inferenced(save_dir: str, filename: str) -> bool:
"""
Check if a result video already exists.
Args:
save_dir (str): Directory where results are saved.
filename (str): Base filename to check.
Returns:
bool: True if file exists, else False.
"""
video_file = os.path.join(save_dir, f"{filename[:-4]}.mp4")
return os.path.exists(video_file)
def save_results_seperate(prompt: str | list[str],
samples: torch.Tensor,
filename: str,
fakedir: str,
fps: int = 8) -> None:
"""
Save generated video samples as .mp4 files.
Args:
prompt (str | list[str]): The prompt text.
samples (torch.Tensor): Generated video tensor of shape [B, C, T, H, W].
filename (str): Output filename.
fakedir (str): Directory to save output videos.
fps (int, optional): Frames per second.
Returns:
None
"""
prompt = prompt[0] if isinstance(prompt, list) else prompt
# Save video
videos = [samples]
savedirs = [fakedir]
for idx, video in enumerate(videos):
if video is None:
continue
video = video.detach().cpu()
video = torch.clamp(video.float(), -1., 1.)
n = video.shape[0]
for i in range(n):
grid = video[i, ...]
grid = (grid + 1.0) / 2.0
grid = (grid * 255).to(torch.uint8).permute(1, 2, 3, 0)
path = os.path.join(savedirs[idx], f'{filename.split(".")[0]}.mp4')
torchvision.io.write_video(path,
grid,
fps=fps,
video_codec='h264',
options={'crf': '0'})
def get_latent_z(model: torch.nn.Module, videos: torch.Tensor) -> torch.Tensor:
"""
Encode videos to latent space.
Args:
model (torch.nn.Module): Model with encode_first_stage function.
videos (torch.Tensor): Video tensor of shape [B, C, T, H, W].
Returns:
torch.Tensor: Latent representation of shape [B, C, T, H, W].
"""
b, c, t, h, w = videos.shape
x = rearrange(videos, 'b c t h w -> (b t) c h w')
z = model.encode_first_stage(x)
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
return z
def image_guided_synthesis(model: torch.nn.Module,
prompts: list[str],
videos: torch.Tensor,
noise_shape: list[int],
ddim_steps: int = 50,
ddim_eta: float = 1.0,
unconditional_guidance_scale: float = 1.0,
fs: int | None = None,
text_input: bool = False,
timestep_spacing: str = 'uniform',
guidance_rescale: float = 0.0,
**kwargs) -> torch.Tensor:
"""
Run DDIM-based image-to-video synthesis with hybrid/text+image guidance.
Args:
model (torch.nn.Module): Diffusion model.
prompts (list[str]): Text prompts.
videos (torch.Tensor): Input images/videos of shape [B, C, T, H, W].
noise_shape (list[int]): Latent noise shape [B, C, T, H, W].
ddim_steps (int, optional): Number of DDIM steps.
ddim_eta (float, optional): Eta value for DDIM.
unconditional_guidance_scale (float, optional): Guidance scale.
fs (int | None, optional): FPS input for sampler.
text_input (bool, optional): If True, use text guidance.
timestep_spacing (str, optional): Timestep schedule spacing.
guidance_rescale (float, optional): Rescale guidance effect.
**kwargs: Additional sampler args.
Returns:
torch.Tensor: Synthesized videos of shape [B, 1, C, T, H, W].
"""
ddim_sampler = DDIMSampler(model)
batch_size = noise_shape[0]
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
if not text_input:
prompts = [""] * batch_size
b, c, t, h, w = videos.shape
img = videos[:, :, 0]
img_emb = model.embedder(img)
img_emb = model.image_proj_model(img_emb)
img_emb = rearrange(img_emb, 'b (t l) c -> (b t) l c', t=t)
cond_emb = model.get_learned_conditioning(prompts)
cond_emb = cond_emb.repeat_interleave(repeats=t, dim=0)
cond = {"c_crossattn": [torch.cat([cond_emb, img_emb], dim=1)]}
if model.model.conditioning_key == 'hybrid':
z = get_latent_z(model, videos)
img_cat_cond = z[:, :, :1, :, :]
img_cat_cond = repeat(img_cat_cond,
'b c t h w -> b c (repeat t) h w',
repeat=z.shape[2])
cond["c_concat"] = [img_cat_cond]
uc = None
cond_mask = None
kwargs.update({"unconditional_conditioning_img_nonetext": None})
batch_variants = []
if ddim_sampler is not None:
samples, _, _, _ = ddim_sampler.sample(
S=ddim_steps,
batch_size=batch_size,
shape=noise_shape[1:],
conditioning=cond,
eta=ddim_eta,
mask=cond_mask,
x0=None,
verbose=False,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc,
fs=fs,
timestep_spacing=timestep_spacing,
guidance_rescale=guidance_rescale,
**kwargs)
# Reconstruct from latent to pixel space
batch_images = model.decode_first_stage(samples)
batch_variants.append(batch_images)
batch_variants = torch.stack(batch_variants)
return batch_variants.permute(1, 0, 2, 3, 4, 5)
def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
"""
Run inference pipeline on prompts and image inputs.
Args:
args (argparse.Namespace): Parsed command-line arguments.
gpu_num (int): Number of GPUs.
gpu_no (int): Index of the current GPU.
Returns:
None
"""
# Load config
config = OmegaConf.load(args.config)
# Set use_checkpoint to False: with activation checkpointing enabled, inference fails with "deepspeed backend not set"
config['model']['params']['wma_config']['params'][
'use_checkpoint'] = False
model = instantiate_from_config(config.model)
model = model.cuda(gpu_no)
model.perframe_ae = args.perframe_ae
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
model = load_model_checkpoint(model, args.ckpt_path)
model.eval()
# Run over data
assert (args.height % 16 == 0) and (
args.width % 16
== 0), "Error: image size [h,w] should be multiples of 16!"
assert args.bs == 1, "Current implementation only supports [batch size = 1]!"
# Get latent noise shape
h, w = args.height // 8, args.width // 8
channels = model.model.diffusion_model.out_channels
n_frames = args.video_length
print(f'>>> Generating {n_frames} frames per generation pass ...')
noise_shape = [args.bs, channels, n_frames, h, w]
fakedir = os.path.join(args.savedir, "samples")
os.makedirs(fakedir, exist_ok=True)
# Prompt file setting
assert os.path.exists(args.prompt_dir), "Error: prompt file Not Found!"
filename_list, data_list, prompt_list, fps_list, fs_list, num_gen_list = load_data_prompts(
args.prompt_dir,
args.savedir,
video_size=(args.height, args.width),
video_frames=n_frames)
num_samples = len(prompt_list)
samples_split = num_samples // gpu_num
print('>>> Prompts testing [rank:%d] %d/%d samples loaded.' %
(gpu_no, samples_split, num_samples))
indices = list(range(samples_split * gpu_no, samples_split * (gpu_no + 1)))
fps_list_rank = [fps_list[i] for i in indices]
fs_list_rank = [fs_list[i] for i in indices]
prompt_list_rank = [prompt_list[i] for i in indices]
data_list_rank = [data_list[i] for i in indices]
filename_list_rank = [filename_list[i] for i in indices]
with torch.no_grad(), torch.cuda.amp.autocast():
# Iterate over prompt batches
for start in tqdm(range(0, len(prompt_list_rank), args.bs),
                  desc='Sample batch'):
fps = fps_list_rank[start:start + args.bs]
fs = fs_list_rank[start:start + args.bs]
prompts = prompt_list_rank[start:start + args.bs]
num_gen = num_gen_list[start:start + args.bs]
videos = data_list_rank[start:start + args.bs]
filenames = filename_list_rank[start:start + args.bs]
if isinstance(videos, list):
videos = torch.stack(videos, dim=0).to("cuda")
else:
videos = videos.unsqueeze(0).to("cuda")
results = []
print(
f">>> {prompts[0]}, frame_stride:{fs[0]}, and {num_gen[0]} generation ..."
)
for _ in range(num_gen[0]):
batch_samples = image_guided_synthesis(
model, prompts, videos, noise_shape, args.ddim_steps,
args.ddim_eta, args.unconditional_guidance_scale,
fps[0] // fs[0], args.text_input, args.timestep_spacing,
args.guidance_rescale)
results.extend(batch_samples)
videos = repeat(batch_samples[0][:, :, -1, :, :].unsqueeze(2),
'b c t h w -> b c (repeat t) h w',
repeat=batch_samples[0].shape[2])
batch_samples = [torch.cat(results, dim=2)]
# Save each example individually
for nn, samples in enumerate(batch_samples):
prompt = prompts[nn]
filename = filenames[nn]
save_results_seperate(prompt,
samples,
filename,
fakedir,
fps=8)
def get_parser() -> argparse.ArgumentParser:
"""
Create and return the argument parser.
Returns:
argparse.ArgumentParser: Parser for command-line arguments.
"""
parser = argparse.ArgumentParser()
parser.add_argument("--savedir",
type=str,
default=None,
help="Path to save the results.")
parser.add_argument("--ckpt_path",
type=str,
default=None,
help="Path to the model checkpoint.")
parser.add_argument("--config",
type=str,
help="Path to the YAML configuration file.")
parser.add_argument(
"--prompt_dir",
type=str,
default=None,
help="Directory containing videos and corresponding prompts.")
parser.add_argument(
"--ddim_steps",
type=int,
default=50,
help="Number of DDIM steps. If non-positive, DDPM is used instead.")
parser.add_argument(
"--ddim_eta",
type=float,
default=1.0,
help="Eta for DDIM sampling. Set to 0.0 for deterministic results.")
parser.add_argument("--bs",
type=int,
default=1,
help="Batch size for inference. Must be 1.")
parser.add_argument("--height",
type=int,
default=320,
help="Height of the generated images in pixels.")
parser.add_argument("--width",
type=int,
default=512,
help="Width of the generated images in pixels.")
parser.add_argument(
"--unconditional_guidance_scale",
type=float,
default=1.0,
help="Scale for classifier-free guidance during sampling.")
parser.add_argument("--seed",
type=int,
default=123,
help="Random seed for reproducibility.")
parser.add_argument("--video_length",
type=int,
default=16,
help="Number of frames in the generated video.")
parser.add_argument(
"--text_input",
action='store_true',
default=False,
help=
"Whether to provide a text prompt as input to the image-to-video model."
)
parser.add_argument(
"--timestep_spacing",
type=str,
default="uniform",
help=
"Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
)
parser.add_argument(
"--guidance_rescale",
type=float,
default=0.0,
help=
"Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
)
parser.add_argument(
"--perframe_ae",
action='store_true',
default=False,
help=
"Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
)
return parser
if __name__ == '__main__':
parser = get_parser()
args = parser.parse_args()
seed = args.seed
if seed < 0:
seed = random.randint(0, 2**31)
seed_everything(seed)
rank, gpu_num = 0, 1
run_inference(args, gpu_num, rank)

View File

@@ -0,0 +1,77 @@
import torch
import warnings
import torchvision
import sys
import pyarrow as pa
import logging
from dataclasses import dataclass, field
from typing import Dict, Any, ClassVar, Deque, Mapping, Union
from datasets.features.features import register_feature
from torch.utils.tensorboard.writer import SummaryWriter
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
@dataclass
class VideoFrame:
"""
Provides a type for a dataset containing video frames.
Example:
```python
data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
features = {"image": VideoFrame()}
Dataset.from_dict(data_dict, features=Features(features))
```
"""
pa_type: ClassVar[Any] = pa.struct({
"path": pa.string(),
"timestamp": pa.float32()
})
_type: str = field(default="VideoFrame", init=False, repr=False)
def __call__(self):
return self.pa_type
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
"'register_feature' is experimental and might be subject to breaking changes in the future.",
category=UserWarning,
)
register_feature(VideoFrame, "VideoFrame")
def populate_queues(
        queues: Dict[str, Deque[Any]],
        batch: Mapping[str, Any]) -> Dict[str, Deque[Any]]:
"""Append each batch value to its queue, padding a fresh queue up to maxlen with the first value."""
for key in batch:
if key not in queues:
continue
if len(queues[key]) != queues[key].maxlen:
while len(queues[key]) != queues[key].maxlen:
queues[key].append(batch[key])
else:
queues[key].append(batch[key])
return queues
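For reference, a small self-contained sketch of `populate_queues`'s behavior: keys absent from `queues` are ignored, a fresh queue is padded to `maxlen` with the first value it sees, and afterwards values rotate through FIFO-style:
```python
from collections import deque

queues = {"observation.state": deque(maxlen=3)}
queues = populate_queues(queues, {"observation.state": 1, "ignored_key": 0})
print(list(queues["observation.state"]))  # [1, 1, 1]: padded to maxlen on first call
queues = populate_queues(queues, {"observation.state": 2})
print(list(queues["observation.state"]))  # [1, 1, 2]: oldest entry dropped afterwards
```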
def log_to_tensorboard(
        writer: SummaryWriter,
        data: Union[torch.Tensor, Any],
        tag: str,
        fps: int = 10) -> None:
"""Log a [B, C, T, H, W] video tensor with values in [-1, 1] to TensorBoard as a frame grid."""
if isinstance(data, torch.Tensor) and data.dim() == 5:
video = data
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video]
grid = torch.stack(frame_grids, dim=0)
grid = (grid + 1.0) / 2.0
grid = grid.unsqueeze(dim=0)
writer.add_video(tag, grid, fps=fps)
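A hedged usage sketch for `log_to_tensorboard` (the log directory is hypothetical): it expects a 5-D tensor in [B, C, T, H, W] layout with values in [-1, 1], since the `(grid + 1.0) / 2.0` rescale assumes that range:
```python
import torch
from torch.utils.tensorboard.writer import SummaryWriter

writer = SummaryWriter("runs/demo")          # hypothetical log directory
video = torch.rand(2, 3, 8, 64, 64) * 2 - 1  # [B, C, T, H, W] in [-1, 1]
log_to_tensorboard(writer, video, tag="samples/video", fps=8)
writer.close()
```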

View File

@@ -0,0 +1,463 @@
import argparse, os, sys
import torch
import torchvision
import warnings
import imageio
import logging
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import traceback
import uvicorn
from omegaconf import OmegaConf
from einops import rearrange, repeat
from collections import OrderedDict
from pytorch_lightning import seed_everything
from torch import nn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from typing import Any, Dict, Optional, Tuple, List
from datetime import datetime
from unifolm_wma.utils.utils import instantiate_from_config
from unifolm_wma.models.samplers.ddim import DDIMSampler
def get_device_from_parameters(module: nn.Module) -> torch.device:
"""Get a module's device by checking one of its parameters.
Args:
module (nn.Module): PyTorch module.
Returns:
torch.device: The device where the module's parameters are stored.
"""
return next(iter(module.parameters())).device
def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
"""Load model weights from checkpoint file.
Args:
model (nn.Module): Model to load weights into.
ckpt (str): Path to checkpoint file.
Returns:
nn.Module: Model with loaded weights.
"""
state_dict = torch.load(ckpt, map_location="cpu")
if "state_dict" in list(state_dict.keys()):
state_dict = state_dict["state_dict"]
try:
model.load_state_dict(state_dict, strict=False)
except Exception:
new_pl_sd = OrderedDict()
for k, v in state_dict.items():
new_pl_sd[k] = v
for k in list(new_pl_sd.keys()):
if "framestride_embed" in k:
new_key = k.replace("framestride_embed", "fps_embedding")
new_pl_sd[new_key] = new_pl_sd[k]
del new_pl_sd[k]
model.load_state_dict(new_pl_sd, strict=False)
else:
new_pl_sd = OrderedDict()
# Strip the 16-character '_forward_module.' prefix added by DeepSpeed.
for key in state_dict['module'].keys():
new_pl_sd[key[16:]] = state_dict['module'][key]
model.load_state_dict(new_pl_sd)
print('>>> model checkpoint loaded.')
return model
def write_video(video_path: str, stacked_frames: List[Any], fps: int) -> None:
"""Write a video to disk using imageio.
Args:
video_path (str): Path to save the video.
stacked_frames (List[Any]): Frames to write.
fps (int): Frames per second.
"""
with warnings.catch_warnings():
warnings.filterwarnings("ignore",
"pkg_resources is deprecated as an API",
category=DeprecationWarning)
imageio.mimsave(video_path, stacked_frames, fps=fps)
def save_results(video: torch.Tensor, filename: str, fps: int = 8) -> None:
"""Save a video tensor as an MP4 file.
Args:
video (torch.Tensor): Video tensor of shape (B, C, T, H, W).
filename (str): Path to save video.
fps (int, optional): Frame rate. Defaults to 8.
"""
video = video.detach().cpu()
video = torch.clamp(video.float(), -1., 1.)
n = video.shape[0]
video = video.permute(2, 0, 1, 3, 4)
frame_grids = [
torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
for framesheet in video
]
grid = torch.stack(frame_grids, dim=0)
grid = (grid + 1.0) / 2.0
grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
torchvision.io.write_video(filename,
grid,
fps=fps,
video_codec='h264',
options={'crf': '10'})
def get_latent_z(model: nn.Module, videos: torch.Tensor) -> torch.Tensor:
"""Encode videos into latent space.
Args:
model (nn.Module): Model with `encode_first_stage` method.
videos (torch.Tensor): Input videos (B, C, T, H, W).
Returns:
torch.Tensor: Latent representation (B, C, T, H, W).
"""
b, c, t, h, w = videos.shape
x = rearrange(videos, 'b c t h w -> (b t) c h w')
z = model.encode_first_stage(x)
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
return z
def image_guided_synthesis(
model: torch.nn.Module,
prompts: list[str],
observation: Dict[str, torch.Tensor],
noise_shape: tuple[int, int, int, int, int],
ddim_steps: int = 50,
ddim_eta: float = 1.0,
unconditional_guidance_scale: float = 1.0,
fs: int | None = None,
timestep_spacing: str = 'uniform',
guidance_rescale: float = 0.0,
**kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Run inference with DDIM sampling.
Args:
model (nn.Module): Diffusion model.
prompts (list[str]): Conditioning text prompts.
observation (Dict[str, torch.Tensor]): Observation dictionary.
noise_shape (List[int]): Shape of noise tensor.
ddim_steps (int, optional): Number of DDIM steps. Defaults to 50.
ddim_eta (float, optional): Sampling eta. Defaults to 1.0.
unconditional_guidance_scale (float, optional): Guidance scale. Defaults to 1.0.
fs (Optional[int], optional): Frame stride or FPS. Defaults to None.
timestep_spacing (str, optional): Spacing strategy. Defaults to "uniform".
guidance_rescale (float, optional): Guidance rescale. Defaults to 0.0.
**kwargs (Any): Additional arguments.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Decoded videos, predicted actions, and predicted states.
"""
ddim_sampler = DDIMSampler(model)
batch_size = noise_shape[0]
fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
img = observation['observation.images.top']
cond_img = img[:, -1, ...]
cond_img_emb = model.embedder(cond_img)
cond_img_emb = model.image_proj_model(cond_img_emb)
if model.model.conditioning_key == 'hybrid':
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
img_cat_cond = z[:, :, -1:, :, :]
img_cat_cond = repeat(img_cat_cond,
'b c t h w -> b c (repeat t) h w',
repeat=noise_shape[2])
cond = {"c_concat": [img_cat_cond]}
cond_ins_emb = model.get_learned_conditioning(prompts)
cond_state = model.state_projector(observation['observation.state'])
cond_state_emb = model.agent_state_pos_emb + cond_state
cond_action = model.action_projector(observation['action'])
cond_action_emb = model.agent_action_pos_emb + cond_action
# Zero out the action embedding: the payload's actions are placeholders at inference time.
cond_action_emb = torch.zeros_like(cond_action_emb)
cond["c_crossattn"] = [
torch.cat([cond_state_emb, cond_ins_emb, cond_img_emb], dim=1)
]
cond["c_crossattn_action"] = [
observation['observation.images.top'].permute(
0, 2, 1, 3, 4)[:, :, -model.n_obs_steps_acting:],
observation['observation.state'][:, -model.n_obs_steps_acting:]
]
uc = None
kwargs.update({"unconditional_conditioning_img_nonetext": None})
cond_mask = None
cond_z0 = None
if ddim_sampler is not None:
samples, actions, states, intermedia = ddim_sampler.sample(
S=ddim_steps,
conditioning=cond,
batch_size=batch_size,
shape=noise_shape[1:],
verbose=False,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc,
eta=ddim_eta,
cfg_img=None,
mask=cond_mask,
x0=cond_z0,
fs=fs,
timestep_spacing=timestep_spacing,
guidance_rescale=guidance_rescale,
**kwargs)
# Reconstruct from latent to pixel space
batch_images = model.decode_first_stage(samples)
batch_variants = batch_images
return batch_variants, actions, states
def run_inference(args: argparse.Namespace, gpu_num: int,
gpu_no: int) -> Tuple[nn.Module, List[int], Any]:
"""
Run inference pipeline on prompts and image inputs.
Args:
args (argparse.Namespace): Parsed command-line arguments.
gpu_num (int): Number of GPUs.
gpu_no (int): Index of the current GPU.
Returns:
None
"""
# Load config
config = OmegaConf.load(args.config)
# Set use_checkpoint to False: with DeepSpeed it otherwise fails with "deepspeed backend not set"
config['model']['params']['wma_config']['params']['use_checkpoint'] = False
model = instantiate_from_config(config.model)
model.perframe_ae = args.perframe_ae
assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
model = load_model_checkpoint(model, args.ckpt_path)
model = model.cuda(gpu_no)
model.eval()
print(">>> Model is successfully loaded ...")
# Build unnormalizer
logging.info("***** Configing Data *****")
data = instantiate_from_config(config.data)
data.setup()
print(">>> Dataset is successfully loaded ...")
# Run over data
assert (args.height % 16 == 0) and (
    args.width % 16
    == 0), "Error: image size [h,w] should be multiples of 16!"
assert args.bs == 1, "Current implementation only supports [batch size = 1]!"
# Get latent noise shape
h, w = args.height // 8, args.width // 8
channels = model.model.diffusion_model.out_channels
n_frames = args.video_length
print(f'>>> Generating {n_frames} frames per generation pass ...')
noise_shape = [args.bs, channels, n_frames, h, w]
return model, noise_shape, data
def get_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("--savedir",
type=str,
default=None,
help="Path to save the results.")
parser.add_argument("--ckpt_path",
type=str,
default=None,
help="Path to the model checkpoint.")
parser.add_argument("--config", type=str, help="Path to the config file.")
parser.add_argument(
"--ddim_steps",
type=int,
default=50,
help="Number of DDIM steps. If non-positive, DDPM is used instead.")
parser.add_argument(
"--ddim_eta",
type=float,
default=1.0,
help="Eta for DDIM sampling. Set to 0.0 for deterministic results.")
parser.add_argument("--bs",
type=int,
default=1,
help="Batch size for inference. Must be 1.")
parser.add_argument("--height",
type=int,
default=320,
help="Height of the generated images in pixels.")
parser.add_argument("--width",
type=int,
default=512,
help="Width of the generated images in pixels.")
parser.add_argument(
"--frame_stride",
type=int,
default=3,
help=
"frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)"
)
parser.add_argument(
"--unconditional_guidance_scale",
type=float,
default=1.0,
help="Scale for classifier-free guidance during sampling.")
parser.add_argument("--seed",
type=int,
default=123,
help="Random seed for reproducibility.")
parser.add_argument("--video_length",
type=int,
default=16,
help="Number of frames in the generated video.")
parser.add_argument(
"--timestep_spacing",
type=str,
default="uniform",
help=
"Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
)
parser.add_argument(
"--guidance_rescale",
type=float,
default=0.0,
help=
"Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
)
parser.add_argument(
"--perframe_ae",
action='store_true',
default=False,
help=
"Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
)
return parser
class Server:
def __init__(self, args: argparse.Namespace) -> None:
self.model_, self.noise_shape_, self.data_ = run_inference(args, 1, 0)
self.args_ = args
self.dataset_name = self.data_.dataset_configs['test']['params'][
'dataset_name']
self.device_ = get_device_from_parameters(self.model_)
def normalize_image(self, image: torch.Tensor) -> torch.Tensor:
return (image / 255 - 0.5) * 2
def predict_action(self, payload: Dict[str, Any]) -> Any:
try:
images = payload['observation.images.top']
states = payload['observation.state']
actions = payload['action'] # Should be all zeros
language_instruction = payload['language_instruction']
images = torch.tensor(images).cuda()
images = self.data_.test_datasets[
self.dataset_name].spatial_transform(images).unsqueeze(0)
images = self.normalize_image(images)
print(f"images shape: {images.shape} ...")
states = torch.tensor(states)
states = self.data_.test_datasets[self.dataset_name].normalizer(
{'observation.state': states})['observation.state']
states, _ = self.data_.test_datasets[
self.dataset_name]._map_to_uni_state(states, "joint position")
print(f"states shape: {states.shape} ...")
actions = torch.tensor(actions)
actions, action_mask = self.data_.test_datasets[
self.dataset_name]._map_to_uni_action(actions,
"joint position")
print(f"actions shape: {actions.shape} ...")
print("=" * 20)
states = states.unsqueeze(0).cuda()
actions = actions.unsqueeze(0).cuda()
observation = {
'observation.images.top': images,
'observation.state': states,
'action': actions
}
observation = {
key: observation[key].to(self.device_, non_blocking=True)
for key in observation
}
args = self.args_
pred_videos, pred_action, _ = image_guided_synthesis(
self.model_,
language_instruction,
observation,
self.noise_shape_,
ddim_steps=args.ddim_steps,
ddim_eta=args.ddim_eta,
unconditional_guidance_scale=args.unconditional_guidance_scale,
fs=30 / args.frame_stride,
timestep_spacing=args.timestep_spacing,
guidance_rescale=args.guidance_rescale)
pred_action = pred_action[..., action_mask[0] == 1.0][0].cpu()
pred_action = self.data_.test_datasets[
self.dataset_name].unnormalizer({'action':
pred_action})['action']
os.makedirs(args.savedir, exist_ok=True)
current_time = datetime.now().strftime("%H:%M:%S")
video_file = f'{args.savedir}/{current_time}.mp4'
save_results(pred_videos.cpu(), video_file)
response = {
'result': 'ok',
'action': pred_action.tolist(),
'desc': 'success'
}
return JSONResponse(response)
except Exception:
logging.error(traceback.format_exc())
logging.warning(
    "Your request threw an error; make sure it complies with the expected format:\n"
    "{'observation.images.top': np.ndarray, 'observation.state': np.ndarray, "
    "'action': np.ndarray, 'language_instruction': list[str]}")
return {'result': 'error', 'desc': traceback.format_exc()}
def run(self, host: str = "127.0.0.1", port: int = 8000) -> None:
self.app = FastAPI()
self.app.post("/predict_action")(self.predict_action)
print(">>> Inference server is ready ... ")
uvicorn.run(self.app, host=host, port=port)
print(">>> Inference server stops ... ")
return
if __name__ == '__main__':
parser = get_parser()
args = parser.parse_args()
seed = args.seed
seed_everything(seed)
rank, gpu_num = 0, 1
print(">>> Launch inference server ... ")
server = Server(args)
server.run()
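A hedged client-side sketch for the `/predict_action` endpoint above. All array shapes below are hypothetical and must match what the configured dataset's transforms expect (n_obs_steps stacked frames, state/action dimensions, and so on); FastAPI maps the single dict parameter to the JSON request body:
```python
import numpy as np
import requests

payload = {
    "observation.images.top": np.zeros((2, 3, 480, 640), dtype=np.float32).tolist(),
    "observation.state": np.zeros((2, 14), dtype=np.float32).tolist(),
    "action": np.zeros((16, 14), dtype=np.float32).tolist(),  # placeholder zeros
    "language_instruction": ["pick up the box"],              # hypothetical prompt
}
resp = requests.post("http://127.0.0.1:8000/predict_action", json=payload)
print(resp.json()["result"])
```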

File diff suppressed because it is too large

View File

@@ -0,0 +1,23 @@
#!/bin/bash
model_name=base_model
ckpt=/path/to/base/model
config=configs/inference/base_model_inference.yaml
res_dir="/path/to/result/directory"
seed=123
CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/base_model_inference.py \
--seed ${seed} \
--ckpt_path $ckpt \
--config $config \
--savedir "${res_dir}/videos" \
--bs 1 --height 320 --width 512 \
--unconditional_guidance_scale 1.0 \
--ddim_steps 16 \
--ddim_eta 1.0 \
--prompt_dir "/path/to/examples/base_model_prompts" \
--text_input \
--video_length 16 \
--timestep_spacing 'uniform_trailing' \
--guidance_rescale 0.7 \
--perframe_ae

View File

@@ -0,0 +1,26 @@
model_name=testing
ckpt=/path/to/model/checkpoint
config=configs/inference/world_model_decision_making.yaml
seed=123
res_dir="path/to/results/directory"
datasets=(
"unitree_g1_pack_camera"
)
for dataset in "${datasets[@]}"; do
CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/real_eval_server.py \
--seed ${seed} \
--ckpt_path $ckpt \
--config $config \
--savedir "${res_dir}/${dataset}/${model_name}/videos" \
--bs 1 --height 320 --width 512 \
--unconditional_guidance_scale 1.0 \
--ddim_steps 16 \
--ddim_eta 1.0 \
--video_length 16 \
--frame_stride 2 \
--timestep_spacing 'uniform_trailing' \
--guidance_rescale 0.7 \
--perframe_ae
done

View File

@@ -0,0 +1,42 @@
model_name=testing
ckpt=/path/to/model/checkpoint
config=configs/inference/world_model_interaction.yaml
seed=123
res_dir="/path/to/result/directory"
datasets=(
"unitree_z1_stackbox"
"unitree_z1_dual_arm_stackbox"
"unitree_z1_dual_arm_stackbox_v2"
"unitree_z1_dual_arm_cleanup_pencils"
"unitree_g1_pack_camera"
)
n_iters=(12 7 11 8 11)
fses=(4 4 4 4 6)
for i in "${!datasets[@]}"; do
dataset=${datasets[$i]}
n_iter=${n_iters[$i]}
fs=${fses[$i]}
CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
--seed ${seed} \
--ckpt_path $ckpt \
--config $config \
--savedir "${res_dir}/${model_name}/${dataset}" \
--bs 1 --height 320 --width 512 \
--unconditional_guidance_scale 1.0 \
--ddim_steps 50 \
--ddim_eta 1.0 \
--prompt_dir "/path/to/unifolm-world-model-action/examples/world_model_interaction_prompts" \
--dataset ${dataset} \
--video_length 16 \
--frame_stride ${fs} \
--n_action_steps 16 \
--exe_steps 16 \
--n_iter ${n_iter} \
--timestep_spacing 'uniform_trailing' \
--guidance_rescale 0.7 \
--perframe_ae
done

32
scripts/train.sh Normal file
View File

@@ -0,0 +1,32 @@
# NCCL configuration
# export NCCL_DEBUG=INFO
# export NCCL_IB_DISABLE=0
# export NCCL_IB_GID_INDEX=3
# export NCCL_NET_GDR_LEVEL=3
# export CUDA_LAUNCH_BLOCKING=1
# export NCCL_TOPO_FILE=/tmp/topo.txt
# export MASTER_ADDR="master.ip."
# export MASTER_PORT=12366
# args
name="experiment_name"
config_file=configs/train/config.yaml
# save root dir for logs, checkpoints, tensorboard record, etc.
save_root="/path/to/savedir"
mkdir -p $save_root/$name
## run
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
--nproc_per_node=8 --nnodes=1 --master_addr=127.0.0.1 --master_port=12366 --node_rank=0 \
./scripts/trainer.py \
--base $config_file \
--train \
--name $name \
--logdir $save_root \
--devices 8 \
--total_gpus=8 \
lightning.trainer.num_nodes=1

214
scripts/trainer.py Normal file
View File

@@ -0,0 +1,214 @@
import argparse, os, datetime
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf
from transformers import logging as transf_logging
from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from unifolm_wma.utils.utils import instantiate_from_config
from unifolm_wma.utils.train import get_trainer_callbacks, get_trainer_logger, get_trainer_strategy
from unifolm_wma.utils.train import set_logger, init_workspace, load_checkpoints, get_num_parameters
def get_parser(**parser_kwargs):
parser = argparse.ArgumentParser(**parser_kwargs)
parser.add_argument("--seed",
"-s",
type=int,
default=20250912,
help="seed for seed_everything")
parser.add_argument("--name",
"-n",
type=str,
default="",
help="experiment name, as saving folder")
parser.add_argument(
"--base",
"-b",
nargs="*",
metavar="base_config.yaml",
help="paths to base configs. Loaded from left-to-right.",
default=list())
parser.add_argument("--train",
"-t",
action='store_true',
default=False,
help='train')
parser.add_argument("--val",
"-v",
action='store_true',
default=False,
help='val')
parser.add_argument("--test",
action='store_true',
default=False,
help='test')
parser.add_argument("--logdir",
"-l",
type=str,
default="logs",
help="directory for logging dat shit")
parser.add_argument("--auto_resume",
action='store_true',
default=False,
help="resume from full-info checkpoint")
parser.add_argument("--auto_resume_weight_only",
action='store_true',
default=False,
help="resume from weight-only checkpoint")
parser.add_argument("--debug",
"-d",
action='store_true',
default=False,
help="enable post-mortem debugging")
return parser
def get_nondefault_trainer_args(args):
parser = argparse.ArgumentParser()
parser = Trainer.add_argparse_args(parser)
default_trainer_args = parser.parse_args([])
return sorted(k for k in vars(default_trainer_args)
if getattr(args, k) != getattr(default_trainer_args, k))
if __name__ == "__main__":
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
local_rank = int(os.environ.get('LOCAL_RANK', 0))
global_rank = int(os.environ.get('RANK', 0))
num_rank = int(os.environ.get('WORLD_SIZE', 1))
parser = get_parser()
# Extends existing argparse by default Trainer attributes
parser = Trainer.add_argparse_args(parser)
args, unknown = parser.parse_known_args()
transf_logging.set_verbosity_error()
seed_everything(args.seed)
configs = [OmegaConf.load(cfg) for cfg in args.base]
cli = OmegaConf.from_dotlist(unknown)
config = OmegaConf.merge(*configs, cli)
lightning_config = config.pop("lightning", OmegaConf.create())
trainer_config = lightning_config.get("trainer", OmegaConf.create())
# Setup workspace directories
workdir, ckptdir, cfgdir, loginfo = init_workspace(args.name, args.logdir,
config,
lightning_config,
global_rank)
logger = set_logger(
logfile=os.path.join(loginfo, 'log_%d:%s.txt' % (global_rank, now)))
logger.info("@lightning version: %s [>=1.8 required]" % (pl.__version__))
logger.info("***** Configing Model *****")
config.model.params.logdir = workdir
model = instantiate_from_config(config.model)
# Load checkpoints
model = load_checkpoints(model, config.model)
# Register_schedule again to make ZTSNR work
if model.rescale_betas_zero_snr:
model.register_schedule(given_betas=model.given_betas,
beta_schedule=model.beta_schedule,
timesteps=model.timesteps,
linear_start=model.linear_start,
linear_end=model.linear_end,
cosine_s=model.cosine_s)
# Update trainer config
for k in get_nondefault_trainer_args(args):
trainer_config[k] = getattr(args, k)
num_nodes = trainer_config.num_nodes
ngpu_per_node = trainer_config.devices
logger.info(f"Running on {num_rank}={num_nodes}x{ngpu_per_node} GPUs")
# Setup learning rate
base_lr = config.model.base_learning_rate
bs = config.data.params.batch_size
if getattr(config.model, 'scale_lr', True):
model.learning_rate = num_rank * bs * base_lr
else:
model.learning_rate = base_lr
logger.info("***** Configing Data *****")
data = instantiate_from_config(config.data)
data.setup()
for k in data.train_datasets:
logger.info(
f"{k}, {data.train_datasets[k].__class__.__name__}, {len(data.train_datasets[k])}"
)
if hasattr(data, 'val_datasets'):
for k in data.val_datasets:
logger.info(
f"{k}, {data.val_datasets[k].__class__.__name__}, {len(data.val_datasets[k])}"
)
for item in unknown:
if item.startswith('--total_gpus'):
num_gpus = int(item.split('=')[-1])
break
model.datasets_len = len(data)
logger.info("***** Configing Trainer *****")
if "accelerator" not in trainer_config:
trainer_config["accelerator"] = "gpu"
# Setup trainer args: pl-logger and callbacks
trainer_kwargs = dict()
trainer_kwargs["num_sanity_val_steps"] = 0
logger_cfg = get_trainer_logger(lightning_config, workdir, args.debug)
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
# Setup callbacks
callbacks_cfg = get_trainer_callbacks(lightning_config, config, workdir,
ckptdir, logger)
trainer_kwargs["callbacks"] = [
instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
]
strategy_cfg = get_trainer_strategy(lightning_config)
trainer_kwargs["strategy"] = strategy_cfg if type(
strategy_cfg) == str else instantiate_from_config(strategy_cfg)
trainer_kwargs['precision'] = lightning_config.get('precision', 32)
trainer_kwargs["sync_batchnorm"] = False
# Trainer config: others
trainer_args = argparse.Namespace(**trainer_config)
trainer = Trainer.from_argparse_args(trainer_args, **trainer_kwargs)
# Allow checkpointing via USR1
def melk(*args, **kwargs):
if trainer.global_rank == 0:
print("Summoning checkpoint.")
ckpt_path = os.path.join(ckptdir, "last_summoning.ckpt")
trainer.save_checkpoint(ckpt_path)
def divein(*args, **kwargs):
if trainer.global_rank == 0:
import pudb
pudb.set_trace()
import signal
signal.signal(signal.SIGUSR1, melk)
signal.signal(signal.SIGUSR2, divein)
# List the key model sizes
total_params = get_num_parameters(model)
logger.info("***** Running the Loop *****")
if args.train:
try:
if "strategy" in lightning_config and lightning_config[
'strategy'].startswith('deepspeed'):
logger.info("<Training in DeepSpeed Mode>")
if trainer_kwargs['precision'] == 16:
with torch.cuda.amp.autocast():
trainer.fit(model, data)
else:
trainer.fit(model, data)
else:
logger.info("<Training in DDPSharded Mode>")
trainer.fit(model, data)
except Exception:
raise
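The `melk`/`divein` handlers registered above make a running training job controllable from outside: SIGUSR1 saves a checkpoint on rank 0 and SIGUSR2 opens a pudb shell. A minimal sketch, assuming you know the rank-0 process id:
```python
import os
import signal

trainer_pid = 12345                     # hypothetical PID of the rank-0 trainer
os.kill(trainer_pid, signal.SIGUSR1)    # writes ckptdir/last_summoning.ckpt
# os.kill(trainer_pid, signal.SIGUSR2)  # would drop rank 0 into a pudb shell
```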

View File

@@ -0,0 +1,26 @@
from abc import abstractmethod
from torch.utils.data import IterableDataset
class Txt2ImgIterableBaseDataset(IterableDataset):
'''
Define an interface to make the IterableDatasets for text2img data chainable
'''
def __init__(self, num_records=0, valid_ids=None, size=256):
super().__init__()
self.num_records = num_records
self.valid_ids = valid_ids
self.sample_ids = valid_ids
self.size = size
print(
f'{self.__class__.__name__} dataset contains {self.__len__()} examples.'
)
def __len__(self):
return self.num_records
@abstractmethod
def __iter__(self):
pass

View File

@@ -0,0 +1,230 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor, nn
from typing import Dict, List
def create_stats_buffers(
shapes: Dict[str, List[int]],
modes: Dict[str, str],
stats: Dict[str, Dict[str, Tensor]] = None,
) -> Dict[str, Dict[str, nn.ParameterDict]]:
"""
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
statistics.
Args: (see Normalize and Unnormalize)
Returns:
Dict: A Dictionary where keys are modalities and values are `nn.ParameterDict` containing
`nn.Parameters` set to `requires_grad=False`, suitable to not be updated during backpropagation.
"""
stats_buffers = {}
for key, mode in modes.items():
assert mode in ["mean_std", "min_max"]
shape = tuple(shapes[key])
if "image" in key:
# sanity checks
assert len(
    shape) == 3, f"number of dimensions of {key} != 3 ({shape=})"
c, h, w = shape
assert c < h and c < w, f"{key} is not channel first ({shape=})"
# override image shape to be invariant to height and width
shape = (c, 1, 1)
# Note: we initialize mean, std, min, max to infinity. They should be overwritten
# downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
# we assert they are not infinity anymore.
if "action" in key:
target_key = "action"
elif "state" in key:
target_key = 'observation.state'
else:
target_key = key
buffer = {}
if mode == "mean_std":
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
std = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict({
"mean":
nn.Parameter(mean, requires_grad=False),
"std":
nn.Parameter(std, requires_grad=False),
})
elif mode == "min_max":
min = torch.ones(shape, dtype=torch.float32) * torch.inf
max = torch.ones(shape, dtype=torch.float32) * torch.inf
buffer = nn.ParameterDict({
"min":
nn.Parameter(min, requires_grad=False),
"max":
nn.Parameter(max, requires_grad=False),
})
if stats is not None:
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
if mode == "mean_std":
buffer["mean"].data = stats[target_key]["mean"].clone()
buffer["std"].data = stats[target_key]["std"].clone()
elif mode == "min_max":
buffer["min"].data = stats[target_key]["min"].clone()
buffer["max"].data = stats[target_key]["max"].clone()
stats_buffers[key] = buffer
return stats_buffers
def _no_stats_error_str(name: str) -> str:
return (
f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
"pretrained model.")
class Normalize(nn.Module):
"""Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
def __init__(
self,
shapes: Dict[str, List[int]],
modes: Dict[str, str],
stats: Dict[str, Dict[str, Tensor]] = None,
):
"""
Args:
shapes (Dict): A Dictionary where keys are input modalities (e.g. "observation.image") and values
are their shapes (e.g. `[3,96,96]`). These shapes are used to create the tensor buffer containing
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
modes (Dict): A Dictionary where keys are output modalities (e.g. "observation.image") and values
are their normalization modes among:
- "mean_std": subtract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (Dict, optional): A Dictionary where keys are output modalities (e.g. "observation.image")
and values are Dictionaries of statistic types and their values (e.g.
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
training the model for the first time, these statistics will overwrite the default buffers. If
not provided, as expected for finetuning or evaluation, the default buffers should to be
overwritten by a call to `policy.load_state_Dict(state_Dict)`. That way, initializing the
dataset is not needed to get the stats, since they are already in the policy state_Dict.
"""
super().__init__()
self.shapes = shapes
self.modes = modes
self.stats = stats
stats_buffers = create_stats_buffers(shapes, modes, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
@torch.no_grad()
def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
for key, mode in self.modes.items():
if key not in batch:
continue
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std":
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = (batch[key] - mean) / (std + 1e-8)
elif mode == "min_max":
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(min).any(), _no_stats_error_str("min")
assert not torch.isinf(max).any(), _no_stats_error_str("max")
# normalize to [0,1]
batch[key] = (batch[key] - min) / (max - min + 1e-8)
# normalize to [-1, 1]
batch[key] = batch[key] * 2 - 1
else:
raise ValueError(mode)
return batch
class Unnormalize(nn.Module):
"""
Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
original range used by the environment.
"""
def __init__(
self,
shapes: Dict[str, List[int]],
modes: Dict[str, str],
stats: Dict[str, Dict[str, Tensor]] = None,
):
"""
Args:
shapes (Dict): A Dictionary where keys are input modalities (e.g. "observation.image") and values
are their shapes (e.g. `[3,96,96]`). These shapes are used to create the tensor buffer containing
mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
modes (Dict): A Dictionary where keys are output modalities (e.g. "observation.image") and values
are their normalization modes among:
- "mean_std": subtract the mean and divide by standard deviation.
- "min_max": map to [-1, 1] range.
stats (Dict, optional): A Dictionary where keys are output modalities (e.g. "observation.image")
and values are Dictionaries of statistic types and their values (e.g.
`{"mean": torch.randn(3,1,1)}, "std": torch.randn(3,1,1)}`). If provided, as expected for
training the model for the first time, these statistics will overwrite the default buffers. If
not provided, as expected for finetuning or evaluation, the default buffers should to be
overwritten by a call to `policy.load_state_Dict(state_Dict)`. That way, initializing the
dataset is not needed to get the stats, since they are already in the policy state_Dict.
"""
super().__init__()
self.shapes = shapes
self.modes = modes
self.stats = stats
stats_buffers = create_stats_buffers(shapes, modes, stats)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
@torch.no_grad()
def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
for key, mode in self.modes.items():
if key not in batch:
continue
buffer = getattr(self, "buffer_" + key.replace(".", "_"))
if mode == "mean_std":
mean = buffer["mean"]
std = buffer["std"]
assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
assert not torch.isinf(std).any(), _no_stats_error_str("std")
batch[key] = batch[key] * std + mean
elif mode == "min_max":
min = buffer["min"]
max = buffer["max"]
assert not torch.isinf(min).any(), _no_stats_error_str("min")
assert not torch.isinf(max).any(), _no_stats_error_str("max")
batch[key] = (batch[key] + 1) / 2
batch[key] = batch[key] * (max - min) + min
else:
raise ValueError(mode)
return batch
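A round-trip sketch showing how `Normalize` and `Unnormalize` compose (the shapes and stats below are made up): `min_max` maps into roughly [-1, 1] and back, so unnormalizing a normalized batch recovers the input up to the 1e-8 epsilon:
```python
import torch

shapes = {"action": [7]}
modes = {"action": "min_max"}
stats = {"action": {"min": torch.zeros(7), "max": torch.ones(7)}}

normalize = Normalize(shapes, modes, stats)
unnormalize = Unnormalize(shapes, modes, stats)

batch = {"action": torch.rand(4, 7)}   # raw values in [0, 1]
normed = normalize(dict(batch))        # mapped to roughly [-1, 1]
restored = unnormalize(normed)
print(torch.allclose(restored["action"], batch["action"], atol=1e-5))  # True
```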

View File

@@ -0,0 +1,60 @@
import torch
from huggingface_hub import hf_hub_download, snapshot_download
from typing import Dict, List, Union
from pathlib import Path
from safetensors.torch import load_file
def unflatten_dict(d, sep="/"):
outdict = {}
for key, value in d.items():
parts = key.split(sep)
d = outdict
for part in parts[:-1]:
if part not in d:
d[part] = {}
d = d[part]
d[parts[-1]] = value
return outdict
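For reference, a tiny sketch of what `unflatten_dict` does with the "/"-joined keys stored in the safetensors stat files:
```python
flat = {"action/mean": 0.1, "action/std": 0.9, "observation.state/mean": 0.0}
print(unflatten_dict(flat))
# {'action': {'mean': 0.1, 'std': 0.9}, 'observation.state': {'mean': 0.0}}
```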
def load_episode_data_index(repo_id, version, root) -> Dict[str, torch.Tensor]:
"""episode_data_index contains the range of indices for each episode
Example:
```python
from_id = episode_data_index["from"][episode_id].item()
to_id = episode_data_index["to"][episode_id].item()
episode_frames = [dataset[i] for i in range(from_id, to_id)]
```
"""
if root is not None:
path = Path(
root) / repo_id / "meta_data" / "episode_data_index.safetensors"
else:
path = hf_hub_download(repo_id,
"meta_data/episode_data_index.safetensors",
repo_type="dataset",
revision=version)
return load_file(path)
def load_stats(repo_id, version, root) -> Dict[str, Dict[str, torch.Tensor]]:
"""stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
Example:
```python
normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
```
"""
if root is not None:
path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
else:
path = hf_hub_download(repo_id,
"meta_data/stats.safetensors",
repo_type="dataset",
revision=version)
stats = load_file(path)
return unflatten_dict(stats)

View File

@@ -0,0 +1,408 @@
import torch
import os
import random
import pandas as pd
import h5py
from decord import VideoReader, cpu
from torch.utils.data import Dataset
from torchvision import transforms
from pathlib import Path
from unifolm_wma.data.utils import load_stats
from unifolm_wma.data.normolize import Normalize, Unnormalize
class WMAData(Dataset):
"""
Assuming the following dataset structure:
dataset_dir/
├── videos
│ ├──dataset_name
│ │ ├──camera_view_dir
│ │ ├── 0.mp4
│ │ ├── 1.mp4
│ │ └── ...
│ └── ...
├── transitions
│ ├── dataset_name
│ ├── meta_data
│ ├── 0.h5
│ ├── 1.h5
│ └── ...
└── dataset_name.csv
"""
def __init__(
self,
meta_path,
data_dir,
subsample=None,
video_length=16,
resolution=[256, 512],
frame_stride=1,
frame_stride_min=1,
spatial_transform=None,
crop_resolution=None,
fps_max=None,
load_raw_resolution=False,
fixed_fps=None,
random_fs=False,
cond_robot_label_prob=0.0,
transition_dir=None,
dataset_name=None,
normalization_mode='min_max',
individual_normalization=False,
n_obs_steps=1,
max_action_dim=7,
max_state_dim=7,
):
self.meta_path = meta_path
self.data_dir = data_dir
self.subsample = subsample
self.video_length = video_length
self.resolution = [resolution, resolution] if isinstance(
resolution, int) else resolution
self.fps_max = fps_max
self.frame_stride = frame_stride
self.frame_stride_min = frame_stride_min
self.fixed_fps = fixed_fps
self.load_raw_resolution = load_raw_resolution
self.random_fs = random_fs
self.cond_robot_label_prob = cond_robot_label_prob
self.transition_dir = transition_dir
self.dataset_name = dataset_name
self.max_action_dim = max_action_dim
self.max_state_dim = max_state_dim
self._load_metadata()
if spatial_transform is not None:
if spatial_transform == "random_crop":
self.spatial_transform = transforms.RandomCrop(crop_resolution)
elif spatial_transform == "center_crop":
self.spatial_transform = transforms.Compose([
transforms.CenterCrop(resolution),
])
elif spatial_transform == "resize_center_crop":
self.spatial_transform = transforms.Compose([
transforms.Resize(min(self.resolution)),
transforms.CenterCrop(self.resolution),
])
elif spatial_transform == "resize":
self.spatial_transform = transforms.Resize(self.resolution)
else:
raise NotImplementedError
else:
self.spatial_transform = None
self.normalization_mode = normalization_mode
self.individual_normalization = individual_normalization
self.n_obs_steps = n_obs_steps
self._load_stats()
if individual_normalization:
self._init_normalizers()
def _load_metadata(self):
metadata = pd.read_csv(self.meta_path, dtype=str)
if self.subsample is not None:
metadata = metadata.sample(self.subsample, random_state=0)
self.metadata = metadata
# Drop rows that contain NaN values
self.metadata.dropna(inplace=True)
print(
f">>> {metadata['data_dir'].iloc[0]}: {len(metadata)} data samples loaded."
)
def _load_stats(self):
self.stats = load_stats(self.dataset_name, None, self.transition_dir)
print(f">>> {self.metadata['data_dir'].iloc[0]}: data stats loaded.")
def _init_normalizers(self):
shape_dict = {
'pre_action': [self.stats['action']['max'].shape[-1]],
'action': [self.stats['action']['max'].shape[-1]],
'observation.state':
[self.stats['observation.state']['max'].shape[-1]],
'next.state': [self.stats['observation.state']['max'].shape[-1]]
}
normalization_mode_dict = {
'pre_action': self.normalization_mode,
'action': self.normalization_mode,
'observation.state': self.normalization_mode,
'next.state': self.normalization_mode
}
self.normalizer = Normalize(shape_dict, normalization_mode_dict,
self.stats)
self.unnormalizer = Unnormalize(shape_dict, normalization_mode_dict,
self.stats)
print(
f">>> {self.metadata['data_dir'].iloc[0]}: normalizer initiated.")
def _get_video_path(self, sample):
rel_video_fp = os.path.join(sample['data_dir'],
str(sample['videoid']) + '.mp4')
full_video_fp = os.path.join(self.data_dir, 'videos', rel_video_fp)
return full_video_fp
def _get_transition_path(self, sample):
data_dir = Path(sample['data_dir'])
if self.dataset_name == data_dir.name:
rel_transition_fp = os.path.join(str(data_dir),
str(sample['videoid']) + '.h5')
else:
rel_transition_fp = os.path.join(str(data_dir.parent),
str(sample['videoid']) + '.h5')
full_transition_fp = os.path.join(self.data_dir, 'transitions',
rel_transition_fp)
return full_transition_fp
def get_uni_vec(self, action_state_dict, action_type, state_type):
if 'pre_action' in action_state_dict:
action_state_dict['pre_action'], _ = self._map_to_uni_action(
action_state_dict['pre_action'], action_type)
if 'action' in action_state_dict:
action_state_dict['action'], action_state_dict[
'action_mask'] = self._map_to_uni_action(
action_state_dict['action'], action_type)
if 'observation.state' in action_state_dict:
action_state_dict['observation.state'], _ = self._map_to_uni_state(
action_state_dict['observation.state'], state_type)
if 'next.state' in action_state_dict:
action_state_dict['next.state'], action_state_dict[
'state_mask'] = self._map_to_uni_state(
action_state_dict['next.state'], state_type)
return action_state_dict
def _map_to_uni_action(self, action, action_type):
action_dim = action.shape[-1]
uni_action = torch.nn.functional.pad(
action, (0, self.max_action_dim - action_dim),
mode='constant',
value=0)
uni_action_mask = torch.zeros_like(uni_action)
uni_action_mask[:, :action_dim] = 1
return uni_action, uni_action_mask
def _map_to_uni_state(self, state, state_type):
state_dim = state.shape[-1]
uni_state = torch.nn.functional.pad(
state, (0, self.max_state_dim - state_dim),
mode='constant',
value=0)
uni_state_mask = torch.zeros_like(uni_state)
uni_state_mask[:, :state_dim] = 1
return uni_state, uni_state_mask
def __getitem__(self, index):
if self.random_fs:
frame_stride = random.randint(self.frame_stride_min,
self.frame_stride)
else:
frame_stride = self.frame_stride
# Get frames until success
while True:
index = index % len(self.metadata)
sample = self.metadata.iloc[index]
video_path = self._get_video_path(sample)
instruction = sample['instruction']
if self.cond_robot_label_prob > 0.0 and random.random(
) < self.cond_robot_label_prob:
if sample['embodiment'] != 'x':
instruction = sample['embodiment'] + ' [SEP] ' + sample[
'instruction']
try:
if self.load_raw_resolution:
video_reader = VideoReader(video_path, ctx=cpu(0))
else:
video_reader = VideoReader(video_path,
ctx=cpu(0),
width=530,
height=300)
if len(video_reader) < self.video_length:
print(
f">>> Video length ({len(video_reader)}) is smaller than target length ({self.video_length})"
)
index += 1
continue
except Exception:
index += 1
print(f">>> Error: load video failed! path = {video_path}")
continue
fps_ori = video_reader.get_avg_fps()
if self.fixed_fps is not None:
frame_stride = int(frame_stride *
(1.0 * fps_ori / self.fixed_fps))
# To avoid extreme cases when fixed_fps is used
frame_stride = max(frame_stride, 1)
# Get valid range (adapting case by case)
required_frame_num = frame_stride * (self.video_length - 1) + 1
frame_num = len(video_reader)
if frame_num < required_frame_num:
# Drop extra samples if fixed fps is required
if self.fixed_fps is not None and frame_num < required_frame_num * 0.5:
index += 1
continue
else:
frame_stride = frame_num // self.video_length
required_frame_num = frame_stride * (self.video_length -
1) + 1
# Select a random clip
random_range = frame_num - required_frame_num
start_idx = random.randint(
0, random_range -
frame_stride) if random_range - frame_stride > 0 else 0
# Calculate frame indices
frame_indices = [
start_idx + frame_stride * i for i in range(self.video_length)
]
try:
next_frame_indices = [
idx + frame_stride for idx in frame_indices
]
frames = video_reader.get_batch(next_frame_indices)
break
except Exception:
print(
f">>> Error: Get frames failed! path = {video_path}; [max_ind vs frame_total:{max(frame_indices)} / {frame_num}]"
)
index += 1
continue
# Load transition data
transition_path = self._get_transition_path(sample)
with h5py.File(transition_path, 'r') as h5f:
transition_dict = {}
for key in h5f.keys():
transition_dict[key] = torch.tensor(h5f[key][()])
for key in h5f.attrs.keys():
transition_dict[key] = h5f.attrs[key]
# Load observable states
if start_idx < self.n_obs_steps - 1:
state_indices = list(range(0, start_idx + 1))
states = transition_dict['observation.state'][state_indices, :]
num_padding = self.n_obs_steps - 1 - start_idx
first_slice = states[0:1, :] # (t, d)
padding = first_slice.repeat(num_padding, 1)
states = torch.cat((padding, states), dim=0)
else:
state_indices = list(
range(start_idx - self.n_obs_steps + 1, start_idx + 1))
states = transition_dict['observation.state'][state_indices, :]
assert states.shape[
0] == self.n_obs_steps, '>>> Do not have enough previous states as observation.'
# Load observable actions
if start_idx < self.n_obs_steps:
pre_action_indices = list(range(0, start_idx))
pre_actions = transition_dict['action'][pre_action_indices, :]
num_padding = self.n_obs_steps - start_idx
first_slice = torch.zeros_like(transition_dict['action'][:1, :])
padding = first_slice.repeat(num_padding, 1)
pre_actions = torch.cat((padding, pre_actions), dim=0)
else:
pre_action_indices = list(
range(start_idx - self.n_obs_steps, start_idx))
pre_actions = transition_dict['action'][pre_action_indices, :]
assert pre_actions.shape[
0] == self.n_obs_steps, ">>> Do not have enough previous actions as observation"
# Load future actions
actions = transition_dict['action'][frame_indices, :]
# Load future states
next_state_indices = [idx + frame_stride for idx in frame_indices]
next_states = transition_dict['observation.state'][
next_state_indices, :]
frames_action_state_dict = {
'pre_action': pre_actions,
'action': actions,
'observation.state': states,
'next.state': next_states
}
if self.individual_normalization:
frames_action_state_dict = self.normalizer(
frames_action_state_dict)
# Update action and states to unified vector
frames_action_state_dict = self.get_uni_vec(
frames_action_state_dict,
transition_dict['action_type'],
transition_dict['state_type'],
)
# Load observable images
if start_idx < self.n_obs_steps - 1:
action_net_frame_indices = list(range(0, start_idx + 1))
action_net_frames = video_reader.get_batch(
action_net_frame_indices)
action_net_frames = torch.tensor(
action_net_frames.asnumpy()).permute(0, 3, 1, 2).float()
first_slice = action_net_frames[0:1, :]
num_padding = self.n_obs_steps - 1 - start_idx
padding = first_slice.repeat(num_padding, 1, 1, 1)
action_net_frames = torch.cat((padding, action_net_frames), dim=0)
assert (
action_net_frames.shape[0] == self.n_obs_steps
), f'{len(action_net_frames)}, self.n_obs_steps={self.n_obs_steps}'
action_net_frames = action_net_frames.permute(1, 0, 2, 3)
else:
action_net_frame_indices = list(
range(start_idx - self.n_obs_steps + 1, start_idx + 1))
action_net_frames = video_reader.get_batch(
action_net_frame_indices)
assert (
action_net_frames.shape[0] == self.n_obs_steps
), f'{len(action_net_frames)}, self.n_obs_steps={self.n_obs_steps}'
action_net_frames = torch.tensor(
action_net_frames.asnumpy()).permute(3, 0, 1, 2).float()
assert (frames.shape[0] == self.video_length
), f'{len(frames)}, self.video_length={self.video_length}'
frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()
if self.spatial_transform is not None:
frames = self.spatial_transform(frames)
action_net_frames = self.spatial_transform(action_net_frames)
if self.resolution is not None:
assert (frames.shape[2], frames.shape[3]) == (
self.resolution[0], self.resolution[1]
), f'frames={frames.shape}, self.resolution={self.resolution}'
assert (
action_net_frames.shape[2], action_net_frames.shape[3]
) == (
self.resolution[0], self.resolution[1]
), f'action_net_frames={action_net_frames.shape}, self.resolution={self.resolution}'
# Normalize frames tensors to [-1,1]
frames = (frames / 255 - 0.5) * 2
action_net_frames = (action_net_frames / 255 - 0.5) * 2
fps_clip = fps_ori // frame_stride
if self.fps_max is not None and fps_clip > self.fps_max:
fps_clip = self.fps_max
data = {
'video': frames,
'instruction': instruction,
'path': video_path,
'fps': fps_clip,
'frame_stride': frame_stride,
'observation.image': action_net_frames,
}
data.update(frames_action_state_dict)
return data
def __len__(self):
return len(self.metadata)
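A hedged construction sketch (all paths and names below are hypothetical) matching the directory layout from the class docstring; it needs the videos, transitions, and metadata CSV on disk to actually run:
```python
from torch.utils.data import DataLoader

dataset = WMAData(
    meta_path="/data/robot/unitree_g1_pack_camera.csv",  # hypothetical metadata csv
    data_dir="/data/robot",                              # contains videos/ and transitions/
    transition_dir="/data/robot/transitions",
    dataset_name="unitree_g1_pack_camera",
    video_length=16,
    resolution=[320, 512],
    spatial_transform="resize_center_crop",
    n_obs_steps=2,
)
loader = DataLoader(dataset, batch_size=1, shuffle=True)
batch = next(iter(loader))  # keys include 'video', 'instruction', 'action', ...
```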

View File

@@ -0,0 +1,267 @@
import os
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from einops import rearrange
from unifolm_wma.modules.networks.ae_modules import Encoder, Decoder
from unifolm_wma.utils.distributions import DiagonalGaussianDistribution
from unifolm_wma.utils.utils import instantiate_from_config
class AutoencoderKL(pl.LightningModule):
def __init__(
self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
test=False,
logdir=None,
input_dim=4,
test_args=None,
):
super().__init__()
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"],
2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim,
ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
self.input_dim = input_dim
self.test = test
self.test_args = test_args
self.logdir = logdir
if colorize_nlabels is not None:
assert type(colorize_nlabels) == int
self.register_buffer("colorize",
torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
if self.test:
self.init_test()
def init_test(self):
self.test = True
save_dir = os.path.join(self.logdir, "test")
if 'ckpt' in self.test_args:
ckpt_name = os.path.basename(self.test_args.ckpt).split(
'.ckpt')[0] + f'_epoch{self._cur_epoch}'
self.root = os.path.join(save_dir, ckpt_name)
else:
self.root = save_dir
if 'test_subdir' in self.test_args:
self.root = os.path.join(save_dir, self.test_args.test_subdir)
self.root_zs = os.path.join(self.root, "zs")
self.root_dec = os.path.join(self.root, "reconstructions")
self.root_inputs = os.path.join(self.root, "inputs")
os.makedirs(self.root, exist_ok=True)
if self.test_args.save_z:
os.makedirs(self.root_zs, exist_ok=True)
if self.test_args.save_reconstruction:
os.makedirs(self.root_dec, exist_ok=True)
if self.test_args.save_input:
os.makedirs(self.root_inputs, exist_ok=True)
assert (self.test_args is not None)
self.test_maximum = getattr(self.test_args, 'test_maximum', None)
self.count = 0
self.eval_metrics = {}
self.decodes = []
self.save_decode_samples = 2048
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")
try:
self._cur_epoch = sd['epoch']
sd = sd["state_dict"]
except Exception:
self._cur_epoch = 'null'
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
def encode(self, x, **kwargs):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z, **kwargs):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_input(self, batch, k):
x = batch[k]
if x.dim() == 5 and self.input_dim == 4:
b, c, t, h, w = x.shape
self.b = b
self.t = t
x = rearrange(x, 'b c t h w -> (b t) c h w')
return x
def training_step(self, batch, batch_idx, optimizer_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train")
self.log("aeloss",
aeloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True)
self.log_dict(log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=False)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split="train")
self.log("discloss",
discloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True)
self.log_dict(log_dict_disc,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=False)
return discloss
def validation_step(self, batch, batch_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs,
reconstructions,
posterior,
0,
self.global_step,
last_layer=self.get_last_layer(),
split="val")
discloss, log_dict_disc = self.loss(inputs,
reconstructions,
posterior,
1,
self.global_step,
last_layer=self.get_last_layer(),
split="val")
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
list(self.decoder.parameters()) +
list(self.quant_conv.parameters()) +
list(self.post_quant_conv.parameters()),
lr=lr,
betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr,
betas=(0.5, 0.9))
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
@torch.no_grad()
def log_images(self, batch, only_inputs=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
log["inputs"] = x
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize",
torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
return x
class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
super().__init__()
self.vq_interface = vq_interface  # TODO: should default to True, but check that it does not break older configs
def encode(self, x, *args, **kwargs):
return x
def decode(self, x, *args, **kwargs):
return x
def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def forward(self, x, *args, **kwargs):
return x
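# Illustration (not part of the original commit): a minimal round-trip through
# AutoencoderKL. `ddconfig` and `lossconfig` are placeholders for real configs
# valid for the Encoder/Decoder and loss classes, so the calls stay commented.
# import torch
# vae = AutoencoderKL(ddconfig, lossconfig, embed_dim=4, image_key="image")
# x = torch.randn(2, 3, 256, 256)                     # (B, C, H, W)
# posterior = vae.encode(x)                            # DiagonalGaussianDistribution
# z = posterior.sample()                               # (B, embed_dim, H/f, W/f)
# x_rec = vae.decode(z)                                # (B, C, H, W)
# x_rec2, posterior = vae(x, sample_posterior=False)   # decode the mode instead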

File diff suppressed because it is too large

View File

@@ -0,0 +1,217 @@
"""
Contains torch Modules that correspond to basic network building blocks, like
MLP, RNN, and CNN backbones.
"""
import abc
import numpy as np
import torch
import torch.nn.functional as F
class Module(torch.nn.Module):
"""
Base class for networks. The only difference from torch.nn.Module is that it
requires implementing @output_shape.
"""
@abc.abstractmethod
def output_shape(self, input_shape=None):
"""
Function to compute output shape from inputs to this module.
Args:
input_shape (iterable of int): shape of input. Does not include batch dimension.
Some modules may not need this argument, if their output does not depend
on the size of the input, or if they assume fixed size input.
Returns:
out_shape ([int]): list of integers corresponding to output shape
"""
raise NotImplementedError
"""
================================================
Visual Backbone Networks
================================================
"""
class ConvBase(Module):
"""
Base class for ConvNets.
"""
def __init__(self):
super(ConvBase, self).__init__()
# dirty hack - re-implement to pass the buck onto subclasses from ABC parent
def output_shape(self, input_shape):
"""
Function to compute output shape from inputs to this module.
Args:
input_shape (iterable of int): shape of input. Does not include batch dimension.
Some modules may not need this argument, if their output does not depend
on the size of the input, or if they assume fixed size input.
Returns:
out_shape ([int]): list of integers corresponding to output shape
"""
raise NotImplementedError
def forward(self, inputs):
x = self.nets(inputs)
if list(self.output_shape(list(inputs.shape)[1:])) != list(
x.shape)[1:]:
raise ValueError('Size mismatch: expect size %s, but got size %s' %
(str(self.output_shape(list(
inputs.shape)[1:])), str(list(x.shape)[1:])))
return x
class SpatialSoftmax(ConvBase):
"""
Spatial Softmax Layer.
Based on Deep Spatial Autoencoders for Visuomotor Learning by Finn et al.
https://rll.berkeley.edu/dsae/dsae.pdf
"""
def __init__(
self,
input_shape,
num_kp=32,
temperature=1.,
learnable_temperature=False,
output_variance=False,
noise_std=0.0,
):
"""
Args:
input_shape (list): shape of the input feature (C, H, W)
num_kp (int): number of keypoints (None for not using spatialsoftmax)
temperature (float): temperature term for the softmax.
learnable_temperature (bool): whether to learn the temperature
output_variance (bool): treat attention as a distribution, and compute second-order statistics to return
noise_std (float): add random spatial noise to the predicted keypoints
"""
super(SpatialSoftmax, self).__init__()
assert len(input_shape) == 3
self._in_c, self._in_h, self._in_w = input_shape # (C, H, W)
if num_kp is not None:
self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
self._num_kp = num_kp
else:
self.nets = None
self._num_kp = self._in_c
self.learnable_temperature = learnable_temperature
self.output_variance = output_variance
self.noise_std = noise_std
if self.learnable_temperature:
# temperature will be learned
temperature = torch.nn.Parameter(torch.ones(1) * temperature,
requires_grad=True)
self.register_parameter('temperature', temperature)
else:
# temperature held constant after initialization
temperature = torch.nn.Parameter(torch.ones(1) * temperature,
requires_grad=False)
self.register_buffer('temperature', temperature)
pos_x, pos_y = np.meshgrid(np.linspace(-1., 1., self._in_w),
np.linspace(-1., 1., self._in_h))
pos_x = torch.from_numpy(pos_x.reshape(1, self._in_h *
self._in_w)).float()
pos_y = torch.from_numpy(pos_y.reshape(1, self._in_h *
self._in_w)).float()
self.register_buffer('pos_x', pos_x)
self.register_buffer('pos_y', pos_y)
self.kps = None
def __repr__(self):
"""Pretty print network."""
header = str(self.__class__.__name__)
return header + '(num_kp={}, temperature={}, noise={})'.format(
self._num_kp, self.temperature.item(), self.noise_std)
def output_shape(self, input_shape):
"""
Function to compute output shape from inputs to this module.
Args:
input_shape (iterable of int): shape of input. Does not include batch dimension.
Some modules may not need this argument, if their output does not depend
on the size of the input, or if they assume fixed size input.
Returns:
out_shape ([int]): list of integers corresponding to output shape
"""
assert (len(input_shape) == 3)
assert (input_shape[0] == self._in_c)
return [self._num_kp, 2]
def forward(self, feature):
"""
Forward pass through spatial softmax layer. For each keypoint, a 2D spatial
probability distribution is created using a softmax, where the support is the
pixel locations. This distribution is used to compute the expected value of
the pixel location, which becomes a keypoint of dimension 2. K such keypoints
are created.
Returns:
out (torch.Tensor or tuple): mean keypoints of shape [B, K, 2], and possibly
keypoint variance of shape [B, K, 2, 2] corresponding to the covariance
under the 2D spatial softmax distribution
"""
assert (feature.shape[1] == self._in_c)
assert (feature.shape[2] == self._in_h)
assert (feature.shape[3] == self._in_w)
if self.nets is not None:
feature = self.nets(feature)
# [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
feature = feature.reshape(-1, self._in_h * self._in_w)
# 2d softmax normalization
attention = F.softmax(feature / self.temperature, dim=-1)
# [1, H * W] x [B * K, H * W] -> [B * K, 1] for spatial coordinate mean in x and y dimensions
expected_x = torch.sum(self.pos_x * attention, dim=1, keepdim=True)
expected_y = torch.sum(self.pos_y * attention, dim=1, keepdim=True)
# stack to [B * K, 2]
expected_xy = torch.cat([expected_x, expected_y], 1)
# reshape to [B, K, 2]
feature_keypoints = expected_xy.view(-1, self._num_kp, 2)
if self.training:
noise = torch.randn_like(feature_keypoints) * self.noise_std
feature_keypoints += noise
if self.output_variance:
# treat attention as a distribution, and compute second-order statistics to return
expected_xx = torch.sum(self.pos_x * self.pos_x * attention,
dim=1,
keepdim=True)
expected_yy = torch.sum(self.pos_y * self.pos_y * attention,
dim=1,
keepdim=True)
expected_xy = torch.sum(self.pos_x * self.pos_y * attention,
dim=1,
keepdim=True)
var_x = expected_xx - expected_x * expected_x
var_y = expected_yy - expected_y * expected_y
var_xy = expected_xy - expected_x * expected_y
# stack to [B * K, 4] and then reshape to [B, K, 2, 2] where last 2 dims are covariance matrix
feature_covar = torch.cat([var_x, var_xy, var_xy, var_y],
1).reshape(-1, self._num_kp, 2, 2)
feature_keypoints = (feature_keypoints, feature_covar)
if isinstance(feature_keypoints, tuple):
self.kps = (feature_keypoints[0].detach(),
feature_keypoints[1].detach())
else:
self.kps = feature_keypoints.detach()
return feature_keypoints
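# Illustration (not part of the original commit): a quick shape check for the
# SpatialSoftmax layer above; the input sizes below are arbitrary.
def _spatial_softmax_example():
    # (B, C, H, W) features -> (B, num_kp, 2) expected keypoint coordinates in [-1, 1]
    layer = SpatialSoftmax(input_shape=[64, 8, 8], num_kp=32)
    feat = torch.randn(4, 64, 8, 8)
    kps = layer(feat)
    assert list(kps.shape) == [4, 32, 2]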

View File

@@ -0,0 +1,83 @@
import torch
from typing import Optional, Union
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
def get_scheduler(name: Union[str, SchedulerType],
optimizer: Optimizer,
num_warmup_steps: Optional[int] = None,
num_training_steps: Optional[int] = None,
**kwargs):
"""
Unified API to get any scheduler from its name. Identical to the diffusers
implementation, except that extra **kwargs are forwarded to the scheduler.
Args:
name (`str` or `SchedulerType`):
The name of the scheduler to use.
optimizer (`torch.optim.Optimizer`):
The optimizer that will be used during training.
num_warmup_steps (`int`, *optional*):
The number of warmup steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
num_training_steps (`int`, *optional*):
The number of training steps to do. This is not required by all schedulers (hence the argument being
optional), the function will raise an error if it's unset and the scheduler type requires it.
"""
name = SchedulerType(name)
schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
if name == SchedulerType.CONSTANT:
return schedule_func(optimizer, **kwargs)
# All other schedulers require `num_warmup_steps`
if num_warmup_steps is None:
raise ValueError(
f"{name} requires `num_warmup_steps`, please provide that argument."
)
if name == SchedulerType.CONSTANT_WITH_WARMUP:
return schedule_func(optimizer,
num_warmup_steps=num_warmup_steps,
**kwargs)
# All other schedulers require `num_training_steps`
if num_training_steps is None:
raise ValueError(
f"{name} requires `num_training_steps`, please provide that argument."
)
return schedule_func(optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
**kwargs)
class SelectiveLRScheduler(_LRScheduler):
def __init__(self,
optimizer,
base_scheduler,
group_indices,
default_lr=[1e-5, 1e-4],
last_epoch=-1):
self.base_scheduler = base_scheduler
self.group_indices = group_indices # Indices of parameter groups to update
self.default_lr = default_lr
super().__init__(optimizer, last_epoch)
def step(self, epoch=None):
self.base_scheduler.step()
base_lrs = self.base_scheduler.get_last_lr()
for idx, group in enumerate(self.optimizer.param_groups):
if idx in self.group_indices:
group['lr'] = base_lrs[idx]
else:
# Reset the learning rate to its initial value
group['lr'] = self.default_lr[idx]
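# Illustration (not part of the original commit); the model and parameter
# groups below are made up. A diffusers cosine schedule drives group 0 while
# group 1 is held at default_lr[1] on every step.
def _selective_scheduler_example():
    model = torch.nn.Linear(8, 8)
    opt = torch.optim.AdamW([
        {"params": [model.weight], "lr": 1e-4},  # group 0: follows the schedule
        {"params": [model.bias], "lr": 1e-4},    # group 1: pinned to default_lr[1]
    ])
    base = get_scheduler("cosine", opt,
                         num_warmup_steps=10, num_training_steps=100)
    sched = SelectiveLRScheduler(opt, base, group_indices=[0],
                                 default_lr=[1e-4, 1e-4])
    for _ in range(5):
        opt.step()
        sched.step()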

View File

@@ -0,0 +1,16 @@
import torch.nn as nn
class ModuleAttrMixin(nn.Module):
def __init__(self):
super().__init__()
self._dummy_variable = nn.Parameter()
@property
def device(self):
return next(iter(self.parameters())).device
@property
def dtype(self):
return next(iter(self.parameters())).dtype
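# Illustration (not part of the original commit): thanks to the dummy
# parameter, any subclass can report its device and dtype even before real
# parameters are registered.
def _mixin_example():
    class Head(ModuleAttrMixin):
        pass

    h = Head()
    print(h.device, h.dtype)  # e.g. cpu torch.float32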

View File

@@ -0,0 +1,91 @@
import collections
import torch
import torch.nn as nn
from typing import Dict, Callable, List
def dict_apply(
x: Dict[str, torch.Tensor],
func: Callable[[torch.Tensor],
torch.Tensor]) -> Dict[str, torch.Tensor]:
result = dict()
for key, value in x.items():
if isinstance(value, dict):
result[key] = dict_apply(value, func)
else:
result[key] = func(value)
return result
def pad_remaining_dims(x, target):
assert x.shape == target.shape[:len(x.shape)]
return x.reshape(x.shape + (1, ) * (len(target.shape) - len(x.shape)))
def dict_apply_split(
x: Dict[str, torch.Tensor], split_func: Callable[[torch.Tensor],
Dict[str, torch.Tensor]]
) -> Dict[str, torch.Tensor]:
results = collections.defaultdict(dict)
for key, value in x.items():
result = split_func(value)
for k, v in result.items():
results[k][key] = v
return results
def dict_apply_reduce(
x: List[Dict[str,
torch.Tensor]], reduce_func: Callable[[List[torch.Tensor]],
torch.Tensor]
) -> Dict[str, torch.Tensor]:
result = dict()
for key in x[0].keys():
result[key] = reduce_func([x_[key] for x_ in x])
return result
def replace_submodules(root_module: nn.Module, predicate: Callable[[nn.Module],
bool],
func: Callable[[nn.Module], nn.Module]) -> nn.Module:
"""
predicate: Return true if the module is to be replaced.
func: Return new module to use.
"""
if predicate(root_module):
return func(root_module)
bn_list = [
k.split('.')
for k, m in root_module.named_modules(remove_duplicate=True)
if predicate(m)
]
for *parent, k in bn_list:
parent_module = root_module
if len(parent) > 0:
parent_module = root_module.get_submodule('.'.join(parent))
if isinstance(parent_module, nn.Sequential):
src_module = parent_module[int(k)]
else:
src_module = getattr(parent_module, k)
tgt_module = func(src_module)
if isinstance(parent_module, nn.Sequential):
parent_module[int(k)] = tgt_module
else:
setattr(parent_module, k, tgt_module)
# verify that all BN are replaced
bn_list = [
k.split('.')
for k, m in root_module.named_modules(remove_duplicate=True)
if predicate(m)
]
assert len(bn_list) == 0
return root_module
def optimizer_to(optimizer, device):
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(device=device)
return optimizer
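# Illustration (not part of the original commit): a typical call to
# replace_submodules, sketched with a toy backbone. Swap every BatchNorm2d for
# a GroupNorm, as diffusion-policy-style training code often does.
def _replace_bn_example():
    net = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())
    net = replace_submodules(
        root_module=net,
        predicate=lambda m: isinstance(m, nn.BatchNorm2d),
        func=lambda m: nn.GroupNorm(num_groups=4,
                                    num_channels=m.num_features))
    assert not any(isinstance(m, nn.BatchNorm2d) for m in net.modules())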

View File

@@ -0,0 +1,960 @@
"""
A collection of utilities for working with nested tensor structures consisting
of numpy arrays and torch tensors.
"""
import collections
import numpy as np
import torch
def recursive_dict_list_tuple_apply(x, type_func_dict):
"""
Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of
{data_type: function_to_apply}.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
type_func_dict (dict): a mapping from data types to the functions to be
applied for each data type.
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
assert (list not in type_func_dict)
assert (tuple not in type_func_dict)
assert (dict not in type_func_dict)
if isinstance(x, (dict, collections.OrderedDict)):
new_x = collections.OrderedDict() if isinstance(
x, collections.OrderedDict) else dict()
for k, v in x.items():
new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict)
return new_x
elif isinstance(x, (list, tuple)):
ret = [recursive_dict_list_tuple_apply(v, type_func_dict) for v in x]
if isinstance(x, tuple):
ret = tuple(ret)
return ret
else:
for t, f in type_func_dict.items():
if isinstance(x, t):
return f(x)
else:
raise NotImplementedError('Cannot handle data type %s' %
str(type(x)))
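# Illustration (not part of the original commit): double every tensor in a
# nested structure while keeping the structure, including None leaves, intact.
def _recursive_apply_example():
    nested = {"a": torch.ones(2), "b": [torch.zeros(3), (torch.ones(1), None)]}
    doubled = recursive_dict_list_tuple_apply(
        nested, {torch.Tensor: lambda t: t * 2, type(None): lambda v: v})
    assert doubled["b"][1][0].item() == 2.0 and doubled["b"][1][1] is None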
def map_tensor(x, func):
"""
Apply function @func to torch.Tensor objects in a nested dictionary or
list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
func (function): function to apply to each tensor
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(x, {
torch.Tensor: func,
type(None): lambda x: x,
})
def map_ndarray(x, func):
"""
Apply function @func to np.ndarray objects in a nested dictionary or
list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
func (function): function to apply to each array
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(x, {
np.ndarray: func,
type(None): lambda x: x,
})
def map_tensor_ndarray(x, tensor_func, ndarray_func):
"""
Apply function @tensor_func to torch.Tensor objects and @ndarray_func to
np.ndarray objects in a nested dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
tensor_func (function): function to apply to each tensor
ndarray_func (function): function to apply to each array
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: tensor_func,
np.ndarray: ndarray_func,
type(None): lambda x: x,
})
def clone(x):
"""
Clones all torch tensors and numpy arrays in nested dictionary or list
or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x.clone(),
np.ndarray: lambda x: x.copy(),
type(None): lambda x: x,
})
def detach(x):
"""
Detaches all torch tensors in nested dictionary or list
or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(x, {
torch.Tensor: lambda x: x.detach(),
})
def to_batch(x):
"""
Introduces a leading batch dimension of 1 for all torch tensors and numpy
arrays in nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x[None, ...],
np.ndarray: lambda x: x[None, ...],
type(None): lambda x: x,
})
def to_sequence(x):
"""
Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy
arrays in nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x[:, None, ...],
np.ndarray: lambda x: x[:, None, ...],
type(None): lambda x: x,
})
def index_at_time(x, ind):
"""
Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in
nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
ind (int): index
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x[:, ind, ...],
np.ndarray: lambda x: x[:, ind, ...],
type(None): lambda x: x,
})
def unsqueeze(x, dim):
"""
Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays
in nested dictionary or list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
dim (int): dimension
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x.unsqueeze(dim=dim),
np.ndarray: lambda x: np.expand_dims(x, axis=dim),
type(None): lambda x: x,
})
def contiguous(x):
"""
Makes all torch tensors and numpy arrays contiguous in nested dictionary or
list or tuple and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x.contiguous(),
np.ndarray: lambda x: np.ascontiguousarray(x),
type(None): lambda x: x,
})
def to_device(x, device):
"""
Sends all torch tensors in nested dictionary or list or tuple to device
@device, and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
device (torch.Device): device to send tensors to
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x, d=device: x.to(d),
type(None): lambda x: x,
})
def to_tensor(x):
"""
Converts all numpy arrays in nested dictionary or list or tuple to
torch tensors (and leaves existing torch Tensors as-is), and returns
a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x,
np.ndarray: lambda x: torch.from_numpy(x),
type(None): lambda x: x,
})
def to_numpy(x):
"""
Converts all torch tensors in nested dictionary or list or tuple to
numpy (and leaves existing numpy arrays as-is), and returns
a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
def f(tensor):
if tensor.is_cuda:
return tensor.detach().cpu().numpy()
else:
return tensor.detach().numpy()
return recursive_dict_list_tuple_apply(x, {
torch.Tensor: f,
np.ndarray: lambda x: x,
type(None): lambda x: x,
})
def to_list(x):
"""
Converts all torch tensors and numpy arrays in nested dictionary or list
or tuple to a list, and returns a new nested structure. Useful for
json encoding.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
def f(tensor):
if tensor.is_cuda:
return tensor.detach().cpu().numpy().tolist()
else:
return tensor.detach().numpy().tolist()
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: f,
np.ndarray: lambda x: x.tolist(),
type(None): lambda x: x,
})
def to_float(x):
"""
Converts all torch tensors and numpy arrays in nested dictionary or list
or tuple to float type entries, and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x.float(),
np.ndarray: lambda x: x.astype(np.float32),
type(None): lambda x: x,
})
def to_uint8(x):
"""
Converts all torch tensors and numpy arrays in nested dictionary or list
or tuple to uint8 type entries, and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x.byte(),
np.ndarray: lambda x: x.astype(np.uint8),
type(None): lambda x: x,
})
def to_torch(x, device):
"""
Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to
torch tensors on device @device and returns a new nested structure.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
device (torch.Device): device to send tensors to
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return to_device(to_float(to_tensor(x)), device)
def to_one_hot_single(tensor, num_class):
"""
Convert tensor to one-hot representation, assuming a certain number of total class labels.
Args:
tensor (torch.Tensor): tensor containing integer labels
num_class (int): number of classes
Returns:
x (torch.Tensor): tensor containing one-hot representation of labels
"""
x = torch.zeros(tensor.size() + (num_class, )).to(tensor.device)
x.scatter_(-1, tensor.unsqueeze(-1), 1)
return x
def to_one_hot(tensor, num_class):
"""
Convert all tensors in nested dictionary or list or tuple to one-hot representation,
assuming a certain number of total class labels.
Args:
tensor (dict or list or tuple): a possibly nested dictionary or list or tuple
num_class (int): number of classes
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(tensor,
func=lambda x, nc=num_class: to_one_hot_single(x, nc))
def flatten_single(x, begin_axis=1):
"""
Flatten a tensor in all dimensions from @begin_axis onwards.
Args:
x (torch.Tensor): tensor to flatten
begin_axis (int): which axis to flatten from
Returns:
y (torch.Tensor): flattened tensor
"""
fixed_size = x.size()[:begin_axis]
_s = list(fixed_size) + [-1]
return x.reshape(*_s)
def flatten(x, begin_axis=1):
"""
Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
begin_axis (int): which axis to flatten from
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(x, {
torch.Tensor:
lambda x, b=begin_axis: flatten_single(x, begin_axis=b),
})
def reshape_dimensions_single(x, begin_axis, end_axis, target_dims):
"""
Reshape selected dimensions in a tensor to a target dimension.
Args:
x (torch.Tensor): tensor to reshape
begin_axis (int): begin dimension
end_axis (int): end dimension
target_dims (tuple or list): target shape for the range of dimensions
(@begin_axis, @end_axis)
Returns:
y (torch.Tensor): reshaped tensor
"""
assert (begin_axis <= end_axis)
assert (begin_axis >= 0)
assert (end_axis < len(x.shape))
assert (isinstance(target_dims, (tuple, list)))
s = x.shape
final_s = []
for i in range(len(s)):
if i == begin_axis:
final_s.extend(target_dims)
elif i < begin_axis or i > end_axis:
final_s.append(s[i])
return x.reshape(*final_s)
def reshape_dimensions(x, begin_axis, end_axis, target_dims):
"""
Reshape selected dimensions for all tensors in nested dictionary or list or tuple
to a target dimension.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
begin_axis (int): begin dimension
end_axis (int): end dimension
target_dims (tuple or list): target shape for the range of dimensions
(@begin_axis, @end_axis)
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor:
lambda x, b=begin_axis, e=end_axis, t=target_dims:
reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=t),
np.ndarray:
lambda x, b=begin_axis, e=end_axis, t=target_dims:
reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=t),
type(None):
lambda x: x,
})
def join_dimensions(x, begin_axis, end_axis):
"""
Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for
all tensors in nested dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
begin_axis (int): begin dimension
end_axis (int): end dimension
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor:
lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=[-1]),
np.ndarray:
lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
x, begin_axis=b, end_axis=e, target_dims=[-1]),
type(None):
lambda x: x,
})
def expand_at_single(x, size, dim):
"""
Expand a tensor at a single dimension @dim by @size
Args:
x (torch.Tensor): input tensor
size (int): size to expand
dim (int): dimension to expand
Returns:
y (torch.Tensor): expanded tensor
"""
assert dim < x.ndimension()
assert x.shape[dim] == 1
expand_dims = [-1] * x.ndimension()
expand_dims[dim] = size
return x.expand(*expand_dims)
def expand_at(x, size, dim):
"""
Expand all tensors in nested dictionary or list or tuple at a single
dimension @dim by @size.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
size (int): size to expand
dim (int): dimension to expand
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d))
def unsqueeze_expand_at(x, size, dim):
"""
Unsqueeze and expand a tensor at a dimension @dim by @size.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
size (int): size to expand
dim (int): dimension to unsqueeze and expand
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
x = unsqueeze(x, dim)
return expand_at(x, size, dim)
def repeat_by_expand_at(x, repeats, dim):
"""
Repeat a dimension by combining expand and reshape operations.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
repeats (int): number of times to repeat the target dimension
dim (int): dimension to repeat on
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
x = unsqueeze_expand_at(x, repeats, dim + 1)
return join_dimensions(x, dim, dim + 1)
def named_reduce_single(x, reduction, dim):
"""
Reduce tensor at a dimension by named reduction functions.
Args:
x (torch.Tensor): tensor to be reduced
reduction (str): one of ["sum", "max", "mean", "flatten"]
dim (int): dimension to be reduced (or begin axis for flatten)
Returns:
y (torch.Tensor): reduced tensor
"""
assert x.ndimension() > dim
assert reduction in ["sum", "max", "mean", "flatten"]
if reduction == "flatten":
x = flatten(x, begin_axis=dim)
elif reduction == "max":
x = torch.max(x, dim=dim)[0] # [B, D]
elif reduction == "sum":
x = torch.sum(x, dim=dim)
else:
x = torch.mean(x, dim=dim)
return x
def named_reduce(x, reduction, dim):
"""
Reduces all tensors in nested dictionary or list or tuple at a dimension
using a named reduction function.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
reduction (str): one of ["sum", "max", "mean", "flatten"]
dim (int): dimension to be reduced (or begin axis for flatten)
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(
x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d))
def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices):
"""
This function indexes out a target dimension of a tensor in a structured way,
by allowing a different value to be selected for each member of a flat index
tensor (@indices) corresponding to a source dimension. This can be interpreted
as moving along the source dimension, using the corresponding index value
in @indices to select values for all other dimensions outside of the
source and target dimensions. A common use case is to gather values
in target dimension 1 for each batch member (target dimension 0).
Args:
x (torch.Tensor): tensor to gather values for
target_dim (int): dimension to gather values along
source_dim (int): dimension to hold constant and use for gathering values
from the other dimensions
indices (torch.Tensor): flat index tensor with same shape as tensor @x along
@source_dim
Returns:
y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out
"""
assert len(indices.shape) == 1
assert x.shape[source_dim] == indices.shape[0]
# unsqueeze in all dimensions except the source dimension
new_shape = [1] * x.ndimension()
new_shape[source_dim] = -1
indices = indices.reshape(*new_shape)
# repeat in all dimensions - but preserve shape of source dimension,
# and make sure target_dimension has singleton dimension
expand_shape = list(x.shape)
expand_shape[source_dim] = -1
expand_shape[target_dim] = 1
indices = indices.expand(*expand_shape)
out = x.gather(dim=target_dim, index=indices)
return out.squeeze(target_dim)
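# Illustration (not part of the original commit): pick one time index per
# batch member from a (B, T, D) tensor, yielding (B, D).
def _gather_example():
    x = torch.arange(24.).reshape(2, 3, 4)  # (B=2, T=3, D=4)
    idx = torch.tensor([0, 2])              # one time index per batch member
    y = gather_along_dim_with_dim_single(x, target_dim=1, source_dim=0,
                                         indices=idx)
    assert y.shape == (2, 4) and torch.equal(y[1], x[1, 2])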
def gather_along_dim_with_dim(x, target_dim, source_dim, indices):
"""
Apply @gather_along_dim_with_dim_single to all tensors in a nested
dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
target_dim (int): dimension to gather values along
source_dim (int): dimension to hold constant and use for gathering values
from the other dimensions
indices (torch.Tensor): flat index tensor with same shape as tensor @x along
@source_dim
Returns:
y (dict or list or tuple): new nested dict-list-tuple
"""
return map_tensor(x,
lambda y, t=target_dim, s=source_dim, i=indices:
gather_along_dim_with_dim_single(y, t, s, i))
def gather_sequence_single(seq, indices):
"""
Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in
the batch given an index for each sequence.
Args:
seq (torch.Tensor): tensor with leading dimensions [B, T, ...]
indices (torch.Tensor): tensor indices of shape [B]
Return:
y (torch.Tensor): indexed tensor of shape [B, ....]
"""
return gather_along_dim_with_dim_single(seq,
target_dim=1,
source_dim=0,
indices=indices)
def gather_sequence(seq, indices):
"""
Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch
for tensors with leading dimensions [B, T, ...].
Args:
seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
of leading dimensions [B, T, ...]
indices (torch.Tensor): tensor indices of shape [B]
Returns:
y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...]
"""
return gather_along_dim_with_dim(seq,
target_dim=1,
source_dim=0,
indices=indices)
def pad_sequence_single(seq,
padding,
batched=False,
pad_same=True,
pad_values=None):
"""
Pad input tensor or array @seq in the time dimension (dimension 1).
Args:
seq (np.ndarray or torch.Tensor): sequence to be padded
padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
batched (bool): if sequence has the batch dimension
pad_same (bool): if pad by duplicating
pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same
Returns:
padded sequence (np.ndarray or torch.Tensor)
"""
assert isinstance(seq, (np.ndarray, torch.Tensor))
assert pad_same or pad_values is not None
if pad_values is not None:
assert isinstance(pad_values, float)
repeat_func = np.repeat if isinstance(
seq, np.ndarray) else torch.repeat_interleave
concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat
ones_like_func = np.ones_like if isinstance(
seq, np.ndarray) else torch.ones_like
seq_dim = 1 if batched else 0
begin_pad = []
end_pad = []
if padding[0] > 0:
pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values
begin_pad.append(repeat_func(pad, padding[0], seq_dim))
if padding[1] > 0:
pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values
end_pad.append(repeat_func(pad, padding[1], seq_dim))
return concat_func(begin_pad + [seq] + end_pad, seq_dim)
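# Illustration (not part of the original commit): pad an unbatched (T, D)
# sequence by repeating its first frame once and its last frame twice.
def _pad_sequence_example():
    seq = torch.arange(6.).reshape(3, 2)  # (T=3, D=2)
    padded = pad_sequence_single(seq, padding=(1, 2), batched=False,
                                 pad_same=True)
    assert padded.shape == (6, 2)
    assert torch.equal(padded[0], seq[0]) and torch.equal(padded[-1], seq[-1])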
def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None):
"""
Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1).
Args:
seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
of leading dimensions [B, T, ...]
padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
batched (bool): if sequence has the batch dimension
pad_same (bool): if pad by duplicating
pad_values (scalar or (ndarray, Tensor)): values to be padded if not pad_same
Returns:
padded sequence (dict or list or tuple)
"""
return recursive_dict_list_tuple_apply(
seq, {
torch.Tensor:
lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values:
pad_sequence_single(x, p, b, ps, pv),
np.ndarray:
lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values:
pad_sequence_single(x, p, b, ps, pv),
type(None):
lambda x: x,
})
def assert_size_at_dim_single(x, size, dim, msg):
"""
Ensure that array or tensor @x has size @size in dim @dim.
Args:
x (np.ndarray or torch.Tensor): input array or tensor
size (int): size that tensors should have at @dim
dim (int): dimension to check
msg (str): text to display if assertion fails
"""
assert x.shape[dim] == size, msg
def assert_size_at_dim(x, size, dim, msg):
"""
Ensure that arrays and tensors in nested dictionary or list or tuple have
size @size in dim @dim.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
size (int): size that tensors should have at @dim
dim (int): dimension to check
"""
map_tensor(
x,
lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m))
def get_shape(x):
"""
Get all shapes of arrays and tensors in nested dictionary or list or tuple.
Args:
x (dict or list or tuple): a possibly nested dictionary or list or tuple
Returns:
y (dict or list or tuple): new nested dict-list-tuple that contains each array or
tensor's shape
"""
return recursive_dict_list_tuple_apply(
x, {
torch.Tensor: lambda x: x.shape,
np.ndarray: lambda x: x.shape,
type(None): lambda x: x,
})
def list_of_flat_dict_to_dict_of_list(list_of_dict):
"""
Helper function to go from a list of flat dictionaries to a dictionary of lists.
By "flat" we mean that none of the values are dictionaries, but are numpy arrays,
floats, etc.
Args:
list_of_dict (list): list of flat dictionaries
Returns:
dict_of_list (dict): dictionary of lists
"""
assert isinstance(list_of_dict, list)
dic = collections.OrderedDict()
for i in range(len(list_of_dict)):
for k in list_of_dict[i]:
if k not in dic:
dic[k] = []
dic[k].append(list_of_dict[i][k])
return dic
def flatten_nested_dict_list(d, parent_key='', sep='_', item_key=''):
"""
Flatten a nested dict or list to a list.
For example, given a dict
{
a: 1
b: {
c: 2
}
c: 3
}
the function would return [(a, 1), (b_c, 2), (c, 3)]
Args:
d (dict, list): a nested dict or list to be flattened
parent_key (str): recursion helper
sep (str): separator for nesting keys
item_key (str): recursion helper
Returns:
list: a list of (key, value) tuples
"""
items = []
if isinstance(d, (tuple, list)):
new_key = parent_key + sep + item_key if len(
parent_key) > 0 else item_key
for i, v in enumerate(d):
items.extend(
flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i)))
return items
elif isinstance(d, dict):
new_key = parent_key + sep + item_key if len(
parent_key) > 0 else item_key
for k, v in d.items():
assert isinstance(k, str)
items.extend(
flatten_nested_dict_list(v, new_key, sep=sep, item_key=k))
return items
else:
new_key = parent_key + sep + item_key if len(
parent_key) > 0 else item_key
return [(new_key, d)]
def time_distributed(inputs,
op,
activation=None,
inputs_as_kwargs=False,
inputs_as_args=False,
**kwargs):
"""
Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the
batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...].
Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping
outputs to [B, T, ...].
Args:
inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors
of leading dimensions [B, T, ...]
op: a layer op that accepts inputs
activation: activation to apply at the output
inputs_as_kwargs (bool): whether to feed the inputs to the op as a kwargs dict
inputs_as_args (bool): whether to feed the inputs to the op as an args list
kwargs (dict): other kwargs to supply to the op
Returns:
outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T].
"""
batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2]
inputs = join_dimensions(inputs, 0, 1)
if inputs_as_kwargs:
outputs = op(**inputs, **kwargs)
elif inputs_as_args:
outputs = op(*inputs, **kwargs)
else:
outputs = op(inputs, **kwargs)
if activation is not None:
outputs = map_tensor(outputs, activation)
outputs = reshape_dimensions(outputs,
begin_axis=0,
end_axis=0,
target_dims=(batch_size, seq_len))
return outputs
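# Illustration (not part of the original commit): apply a per-frame linear
# layer to a (B, T, D) batch; time_distributed folds T into the batch
# dimension, applies the op, and unfolds the result.
def _time_distributed_example():
    op = torch.nn.Linear(4, 8)
    frames = {"obs": torch.randn(2, 5, 4)}  # (B=2, T=5, D=4)
    out = time_distributed(frames, lambda d: op(d["obs"]))
    assert out.shape == (2, 5, 8)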

View File

@@ -0,0 +1,701 @@
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
import xformers
import xformers.ops
from einops import rearrange, repeat
from typing import Union
from unifolm_wma.models.diffusion_head.conv1d_components import (
Downsample1d, Upsample1d, Conv1dBlock)
from unifolm_wma.models.diffusion_head.positional_embedding import SinusoidalPosEmb
from unifolm_wma.models.diffusion_head.base_nets import SpatialSoftmax
from unifolm_wma.utils.basics import zero_module
from unifolm_wma.utils.common import (
checkpoint,
exists,
default,
)
from unifolm_wma.utils.utils import instantiate_from_config
logger = logging.getLogger(__name__)
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(nn.Linear(
dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.net(x)
class CrossAttention(nn.Module):
def __init__(self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.,
relative_position=False):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout))
def forward(self, x, context=None, mask=None):
# NOTE: only `efficient_forward` exists in this file; delegate here so that
# nn.Module.__call__ works (mask is accepted for API compatibility but
# unused by the memory-efficient attention path)
return self.efficient_forward(x, context=context)
def efficient_forward(self, x, context=None):
spatial_self_attn = (context is None)
k_ip, v_ip, out_ip = None, None, None
q = self.to_q(x)
if spatial_self_attn:
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(q,
k,
v,
attn_bias=None,
op=None)
out = (out.unsqueeze(0).reshape(
b, self.heads, out.shape[1],
self.dim_head).permute(0, 2, 1,
3).reshape(b, out.shape[1],
self.heads * self.dim_head))
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
def __init__(self,
dim,
n_heads,
d_head,
dropout=0.,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False,
attention_cls=None):
super().__init__()
attn_cls = CrossAttention if attention_cls is None else attention_cls
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None)
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = attn_cls(query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None, **kwargs):
## implementation trick: checkpointing doesn't support non-tensor (e.g. None or scalar) arguments
input_tuple = (
x,
)  ## must be (x,), not (x); otherwise *input_tuple would unpack x itself into multiple arguments
if context is not None:
input_tuple = (x, context)
return checkpoint(self._forward, input_tuple, self.parameters(),
self.checkpoint)
def _forward(self, x, context=None, mask=None):
x = self.attn1(self.norm1(x),
context=context if self.disable_self_attn else None,
mask=mask) + x
x = self.attn2(self.norm2(x), context=context, mask=mask) + x
x = self.ff(self.norm3(x)) + x
return x
class ActionLatentImageCrossAttention(nn.Module):
def __init__(self,
in_channels,
in_dim,
n_heads,
d_head,
depth=1,
dropout=0.,
context_dim=None,
use_checkpoint=True,
disable_self_attn=False,
use_linear=True):
super().__init__()
# in_channels: action input dim
self.in_channels = in_channels
self.in_dim = in_dim
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(num_groups=8,
num_channels=in_channels,
eps=1e-6,
affine=True)
self.proj_in_action = nn.Linear(in_dim, inner_dim)
self.proj_in_cond = nn.Linear(context_dim, inner_dim)
self.proj_out = zero_module(nn.Linear(inner_dim, in_dim))
self.use_linear = use_linear
attention_cls = None
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim,
disable_self_attn=disable_self_attn,
checkpoint=use_checkpoint,
attention_cls=attention_cls)
for d in range(depth)
])
def forward(self, x, context=None, **kwargs):
ba, ca, da = x.shape
b, t, c, h, w = context.shape
context = rearrange(context, 'b t c h w -> b (t h w) c').contiguous()
x_in = x
x = self.norm(x) # ba x ja x d_in
if self.use_linear:
x = self.proj_in_action(x)
context = self.proj_in_cond(context)
for i, block in enumerate(self.transformer_blocks):
x = block(x, context=context, **kwargs)
if self.use_linear:
x = self.proj_out(x)
return x + x_in
class ConditionalResidualBlock1D(nn.Module):
def __init__(self,
in_channels,
out_channels,
cond_dim,
kernel_size=3,
n_groups=8,
cond_predict_scale=True,
use_linear_act_proj=False):
super().__init__()
self.blocks = nn.ModuleList([
Conv1dBlock(in_channels,
out_channels,
kernel_size,
n_groups=n_groups),
Conv1dBlock(out_channels,
out_channels,
kernel_size,
n_groups=n_groups),
])
self.cond_predict_scale = cond_predict_scale
self.use_linear_act_proj = use_linear_act_proj
self.out_channels = out_channels
# FiLM modulation https://arxiv.org/abs/1709.07871
# predicts per-channel scale and bias
cond_channels = out_channels
if cond_predict_scale and use_linear_act_proj:
cond_channels = out_channels * 2
self.cond_encoder = nn.Sequential(
nn.Mish(),
nn.Linear(cond_dim, cond_channels),
)
# make sure dimensions compatible
self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
if in_channels != out_channels else nn.Identity()
def forward(self, x, cond=None):
'''
x : [ batch_size x in_channels x horizon ]
cond : [ batch_size x n_steps x cond_dim ] (one FiLM conditioning vector per step)
returns:
out : [ batch_size x out_channels x horizon ]
'''
B, T, _ = cond.shape
out = self.blocks[0](x)
if self.cond_predict_scale:
embed = self.cond_encoder(cond)
if self.use_linear_act_proj:
embed = embed.reshape(B * T, -1)
embed = embed.reshape(-1, 2, self.out_channels, 1)
else:
embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
scale = embed[:, 0, ...]
bias = embed[:, 1, ...]
out = scale * out + bias
out = self.blocks[1](out)
out = out + self.residual_conv(x)
return out
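# Illustration (not part of the original commit): a FiLM shape check following
# the use_linear_act_proj=True path, where cond is (B, T, cond_dim) and x is
# (B*T, in_channels, horizon).
def _film_block_example():
    blk = ConditionalResidualBlock1D(in_channels=8, out_channels=16,
                                     cond_dim=32, cond_predict_scale=True,
                                     use_linear_act_proj=True)
    x = torch.randn(2 * 5, 8, 12)   # (B*T, C_in, horizon)
    cond = torch.randn(2, 5, 32)    # (B, T, cond_dim)
    out = blk(x, cond)
    assert out.shape == (2 * 5, 16, 12)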
class ConditionalUnet1D(nn.Module):
def __init__(self,
input_dim,
n_obs_steps=1,
local_cond_dim=None,
global_cond_dim=None,
diffusion_step_embed_dim=256,
down_dims=[256, 512, 1024],
kernel_size=3,
n_groups=8,
cond_predict_scale=False,
horizon=16,
num_head_channels=64,
use_linear_attn=True,
use_linear_act_proj=True,
act_proj_dim=32,
cond_cross_attention=False,
context_dims=None,
image_size=None,
imagen_cond_gradient=False,
last_frame_only=False,
use_imagen_mid_only=False,
use_z_only=False,
spatial_num_kp=32,
obs_encoder_config=None):
super().__init__()
self.n_obs_steps = n_obs_steps
self.obs_encoder = instantiate_from_config(obs_encoder_config)
all_dims = [input_dim] + list(down_dims)
start_dim = down_dims[0]
dsed = diffusion_step_embed_dim
diffusion_step_encoder = nn.Sequential(
SinusoidalPosEmb(dsed),
nn.Linear(dsed, dsed * 4),
nn.Mish(),
nn.Linear(dsed * 4, dsed),
)
cond_dim = dsed + self.obs_encoder.output_shape()[-1] * self.n_obs_steps
in_out = list(zip(all_dims[:-1], all_dims[1:]))
local_cond_encoder = None
down_modules = nn.ModuleList([])
dim_a_list = []
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (len(in_out) - 1)
if ind == 0:
dim_a = horizon
else:
dim_a = horizon // 2 * ind
dim_a_list.append(dim_a)
# for attention
num_heads = dim_out // num_head_channels
dim_head = num_head_channels
if use_linear_act_proj:
if use_imagen_mid_only:
cur_cond_dim = cond_dim + 2 * context_dims[-1]
elif use_z_only:
cur_cond_dim = cond_dim + 2 * spatial_num_kp
else:
cur_cond_dim = cond_dim + 2 * context_dims[ind]
else:
cur_cond_dim = cond_dim + horizon * context_dims[ind]
down_modules.append(
nn.ModuleList([
ConditionalResidualBlock1D(
dim_in,
dim_out,
cond_dim=cur_cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
use_linear_act_proj=use_linear_act_proj),
ConditionalResidualBlock1D(
dim_out,
dim_out,
cond_dim=cur_cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
use_linear_act_proj=use_linear_act_proj),
ActionLatentImageCrossAttention(
dim_out,
dim_a,
num_heads,
dim_head,
context_dim=context_dims[ind],
use_linear=use_linear_attn)
if cond_cross_attention else nn.Identity(),
Downsample1d(dim_out) if not is_last else nn.Identity()
]))
mid_dim = all_dims[-1]
self.mid_modules = nn.ModuleList([
ConditionalResidualBlock1D(
mid_dim,
mid_dim,
cond_dim=cur_cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
use_linear_act_proj=use_linear_act_proj),
ConditionalResidualBlock1D(
mid_dim,
mid_dim,
cond_dim=cur_cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
use_linear_act_proj=use_linear_act_proj),
ActionLatentImageCrossAttention(mid_dim,
dim_a_list[-1],
num_heads,
dim_head,
context_dim=context_dims[-1],
use_linear=use_linear_attn)
if cond_cross_attention else nn.Identity(),
])
up_modules = nn.ModuleList([])
context_dims = context_dims[::-1]
for ind, (dim_in, dim_out) in enumerate(
reversed(in_out[1:] + [(down_dims[-1], down_dims[-1])])):
is_last = ind >= (len(in_out) - 1)
if use_linear_act_proj:
if use_imagen_mid_only:
cur_cond_dim = cond_dim + 2 * context_dims[0]
elif use_z_only:
cur_cond_dim = cond_dim + 2 * spatial_num_kp
else:
cur_cond_dim = cond_dim + 2 * context_dims[ind]
else:
cur_cond_dim = cond_dim + horizon * context_dims[ind]
up_modules.append(
nn.ModuleList([
ConditionalResidualBlock1D(
dim_out + dim_in,
dim_in,
cond_dim=cur_cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
use_linear_act_proj=use_linear_act_proj),
ConditionalResidualBlock1D(
dim_in,
dim_in,
cond_dim=cur_cond_dim,
kernel_size=kernel_size,
n_groups=n_groups,
cond_predict_scale=cond_predict_scale,
use_linear_act_proj=use_linear_act_proj),
ActionLatentImageCrossAttention(
dim_in,
dim_a_list.pop(),
num_heads,
dim_head,
context_dim=context_dims[ind],
use_linear=use_linear_attn)
if cond_cross_attention else nn.Identity(),
Upsample1d(dim_in) if not is_last else nn.Identity()
]))
final_conv = nn.Sequential(
Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
nn.Conv1d(start_dim, input_dim, 1),
)
if use_z_only:
h, w = image_size
self.spatial_softmax_blocks = nn.ModuleList(
[SpatialSoftmax((4, h, w), spatial_num_kp)])
else:
self.spatial_softmax_blocks = nn.ModuleList([])
context_dims = context_dims[::-1]
for ind, context_dim in enumerate(context_dims):
h, w = image_size
if ind != 0:
h //= 2**ind
w //= 2**ind
net = SpatialSoftmax((context_dim, h, w), context_dim)
self.spatial_softmax_blocks.append(net)
self.spatial_softmax_blocks.append(net)
self.spatial_softmax_blocks += self.spatial_softmax_blocks[
0:4][::-1]
self.diffusion_step_encoder = diffusion_step_encoder
self.local_cond_encoder = local_cond_encoder
self.up_modules = up_modules
self.down_modules = down_modules
self.final_conv = final_conv
self.cond_cross_attention = cond_cross_attention
self.use_linear_act_proj = use_linear_act_proj
self.proj_in_action = nn.Sequential(nn.Linear(1, act_proj_dim),
nn.LayerNorm(act_proj_dim))
self.proj_in_horizon = nn.Sequential(nn.Linear(horizon, act_proj_dim),
nn.LayerNorm(act_proj_dim))
self.proj_out_action = nn.Sequential(nn.LayerNorm(act_proj_dim),
nn.Linear(act_proj_dim, 1))
self.proj_out_horizon = nn.Sequential(nn.LayerNorm(act_proj_dim),
nn.Linear(act_proj_dim, horizon))
logger.info("number of parameters: %e",
sum(p.numel() for p in self.parameters()))
self.imagen_cond_gradient = imagen_cond_gradient
self.use_imagen_mid_only = use_imagen_mid_only
self.use_z_only = use_z_only
self.spatial_num_kp = spatial_num_kp
self.last_frame_only = last_frame_only
self.horizon = horizon
def forward(self,
sample: torch.Tensor,
timestep: Union[torch.Tensor, float, int],
imagen_cond=None,
cond=None,
**kwargs):
"""
sample: (B,T,input_dim)
timestep: (B,) or int, diffusion step
imagen_cond: a list of hidden info from video gen unet
cond: dict:
image: (B, 3, To, h, w)
agent_pos: (B, Ta, d)
output: (B,T,input_dim)
"""
if not self.imagen_cond_gradient:
imagen_cond = [c.detach() for c in imagen_cond]
cond = {'image': cond[0], 'agent_pos': cond[1]}
cond['image'] = cond['image'].permute(0, 2, 1, 3, 4)  # (B, 3, To, h, w) -> (B, To, 3, h, w)
cond['image'] = rearrange(cond['image'], 'b t c h w -> (b t) c h w')
cond['agent_pos'] = rearrange(cond['agent_pos'], 'b t d -> (b t) d')
B, T, D = sample.shape
if self.use_linear_act_proj:
sample = self.proj_in_action(sample.unsqueeze(-1))
global_cond = self.obs_encoder(cond)
global_cond = rearrange(global_cond,
'(b t) d -> b 1 (t d)',
b=B,
t=self.n_obs_steps)
global_cond = repeat(global_cond,
'b c d -> b (repeat c) d',
repeat=T)
else:
sample = einops.rearrange(sample, 'b h t -> b t h')
sample = self.proj_in_horizon(sample)
# NOTE: assumed fix: the original code referenced an undefined
# `robo_state_cond` here; encode the observations as in the branch above
# so that `global_cond` is defined for the torch.cat calls below
global_cond = self.obs_encoder(cond)
global_cond = rearrange(global_cond, '(b t) d -> b 1 (t d)',
b=B, t=self.n_obs_steps)
global_cond = repeat(global_cond, 'b c d -> b (repeat c) d', repeat=2)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps],
dtype=torch.long,
device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# Broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
global_feature = self.diffusion_step_encoder(timesteps)
(imagen_cond_down, imagen_cond_mid, imagen_cond_up
) = imagen_cond[0:4], imagen_cond[4], imagen_cond[5:]  # NOTE: hand-coded split; indices assume the current UNet depth
x = sample if not self.use_linear_act_proj else sample.reshape(
B * T, D, -1)
h = []
for idx, modules in enumerate(self.down_modules):
if self.cond_cross_attention:
(resnet, resnet2, crossatten, downsample) = modules
else:
(resnet, resnet2, _, downsample) = modules
# Access the cond from the unet embeds from video unet
if self.use_imagen_mid_only:
imagen_cond = imagen_cond_mid
elif self.use_z_only:
imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
else:
imagen_cond = imagen_cond_down[idx]
if self.last_frame_only:
imagen_cond = imagen_cond[:, -1].unsqueeze(1)
imagen_cond = repeat(imagen_cond,
'b t c h w -> b (repeat t) c h w',
repeat=self.horizon)
imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
if self.use_imagen_mid_only:
imagen_cond = self.spatial_softmax_blocks[len(
self.spatial_softmax_blocks) // 2](imagen_cond)
elif self.use_z_only:
imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
else:
imagen_cond = self.spatial_softmax_blocks[idx](imagen_cond)
imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
if self.use_linear_act_proj:
imagen_cond = imagen_cond.reshape(B, T, -1)
cur_global_feature = global_feature.unsqueeze(
1).repeat_interleave(repeats=T, dim=1)
else:
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
imagen_cond = imagen_cond.reshape(B, 2, -1)
cur_global_feature = global_feature.unsqueeze(
1).repeat_interleave(repeats=2, dim=1)
cur_global_feature = torch.cat(
[cur_global_feature, global_cond, imagen_cond], axis=-1)
x = resnet(x, cur_global_feature)
x = resnet2(x, cur_global_feature)
h.append(x)
x = downsample(x)
# >>> mid blocks
resnet, resnet2, _ = self.mid_modules
# Access the cond from the unet embeds from video unet
if self.use_z_only:
imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
else:
imagen_cond = imagen_cond_mid
if self.last_frame_only:
imagen_cond = imagen_cond[:, -1].unsqueeze(1)
imagen_cond = repeat(imagen_cond,
'b t c h w -> b (repeat t) c h w',
repeat=self.horizon)
imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
idx += 1
if self.use_z_only:
imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
else:
imagen_cond = self.spatial_softmax_blocks[idx](imagen_cond)
imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
if self.use_linear_act_proj:
imagen_cond = imagen_cond.reshape(B, T, -1)
cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
repeats=T, dim=1)
else:
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
imagen_cond = imagen_cond.reshape(B, 2, -1)
cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
repeats=2, dim=1)
cur_global_feature = torch.cat(
[cur_global_feature, global_cond, imagen_cond], axis=-1)
x = resnet(x, cur_global_feature)
x = resnet2(x, cur_global_feature)
# >>> up blocks
idx += 1
for jdx, modules in enumerate(self.up_modules):
if self.cond_cross_attention:
(resnet, resnet2, crossatten, upsample) = modules
else:
(resnet, resnet2, _, upsample) = modules
# Select the conditioning features from the video U-Net embeddings
if self.use_imagen_mid_only:
imagen_cond = imagen_cond_mid
elif self.use_z_only:
imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
else:
imagen_cond = imagen_cond_up[jdx]
if self.last_frame_only:
imagen_cond = imagen_cond[:, -1].unsqueeze(1)
imagen_cond = repeat(imagen_cond,
'b t c h w -> b (repeat t) c h w',
repeat=self.horizon)
imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
if self.use_imagen_mid_only:
imagen_cond = self.spatial_softmax_blocks[len(
self.spatial_softmax_blocks) // 2](imagen_cond)
elif self.use_z_only:
imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
else:
imagen_cond = self.spatial_softmax_blocks[jdx +
idx](imagen_cond)
imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
if self.use_linear_act_proj:
imagen_cond = imagen_cond.reshape(B, T, -1)
cur_global_feature = global_feature.unsqueeze(
1).repeat_interleave(repeats=T, dim=1)
else:
imagen_cond = imagen_cond.permute(0, 3, 1, 2)
imagen_cond = imagen_cond.reshape(B, 2, -1)
cur_global_feature = global_feature.unsqueeze(
1).repeat_interleave(repeats=2, dim=1)
cur_global_feature = torch.cat(
[cur_global_feature, global_cond, imagen_cond], axis=-1)
x = torch.cat((x, h.pop()), dim=1)
x = resnet(x, cur_global_feature)
x = resnet2(x, cur_global_feature)
x = upsample(x)
x = self.final_conv(x)
if self.use_linear_act_proj:
x = x.reshape(B, T, D, -1)
x = self.proj_out_action(x)
x = x.reshape(B, T, D)
else:
x = self.proj_out_horizon(x)
x = einops.rearrange(x, 'b t h -> b h t')
return x

View File

@@ -0,0 +1,52 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class Downsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
def forward(self, x):
return self.conv(x)
class Upsample1d(nn.Module):
def __init__(self, dim):
super().__init__()
self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
def forward(self, x):
return self.conv(x)
class Conv1dBlock(nn.Module):
'''
Conv1d --> GroupNorm --> Mish
'''
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(inp_channels,
out_channels,
kernel_size,
padding=kernel_size // 2),
# Rearrange('batch channels horizon -> batch channels 1 horizon'),
nn.GroupNorm(n_groups, out_channels),
# Rearrange('batch channels 1 horizon -> batch channels horizon'),
nn.Mish(),
)
def forward(self, x):
return self.block(x)
def test():
cb = Conv1dBlock(256, 128, kernel_size=3)
x = torch.zeros((1, 256, 16))
o = cb(x)
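
A quick shape sanity check for the three primitives above (a minimal sketch with illustrative sizes, not part of the original file):

def _shape_check():
    x = torch.zeros((1, 64, 16))  # (batch, channels, horizon)
    down, up = Downsample1d(64), Upsample1d(64)
    block = Conv1dBlock(64, 128, kernel_size=5)
    assert down(x).shape == (1, 64, 8)  # stride-2 conv halves the horizon
    assert up(down(x)).shape == (1, 64, 16)  # transposed conv restores it
    assert block(x).shape == (1, 128, 16)  # length-preserving conv block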

View File

@@ -0,0 +1,80 @@
import copy
import torch
from torch.nn.modules.batchnorm import _BatchNorm
class EMAModel:
"""
Exponential Moving Average of model weights
"""
def __init__(self,
model,
update_after_step=0,
inv_gamma=1.0,
power=2 / 3,
min_value=0.0,
max_value=0.9999):
"""
@crowsonkb's notes on EMA Warmup:
If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
at 215.4k steps).
Args:
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
power (float): Exponential factor of EMA warmup. Default: 2/3.
min_value (float): The minimum EMA decay rate. Default: 0.
"""
self.averaged_model = model
self.averaged_model.eval()
self.averaged_model.requires_grad_(False)
self.update_after_step = update_after_step
self.inv_gamma = inv_gamma
self.power = power
self.min_value = min_value
self.max_value = max_value
self.decay = 0.0
self.optimization_step = 0
def get_decay(self, optimization_step):
"""
Compute the decay factor for the exponential moving average.
"""
step = max(0, optimization_step - self.update_after_step - 1)
value = 1 - (1 + step / self.inv_gamma)**-self.power
if step <= 0:
return 0.0
return max(self.min_value, min(value, self.max_value))
@torch.no_grad()
def step(self, new_model):
self.decay = self.get_decay(self.optimization_step)
for module, ema_module in zip(new_model.modules(),
self.averaged_model.modules()):
for param, ema_param in zip(module.parameters(recurse=False),
ema_module.parameters(recurse=False)):
# iterate over immediate parameters only.
if isinstance(param, dict):
raise RuntimeError('Dict parameter not supported')
if isinstance(module, _BatchNorm):
# skip batchnorms
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
elif not param.requires_grad:
ema_param.copy_(param.to(dtype=ema_param.dtype).data)
else:
ema_param.mul_(self.decay)
ema_param.add_(param.data.to(dtype=ema_param.dtype),
alpha=1 - self.decay)
self.optimization_step += 1
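
To make the warmup schedule above concrete, a small preview of the decay curve with the default inv_gamma=1.0 and power=2/3 (a sketch mirroring the core of get_decay; values rounded, not part of the original file):

def _decay_preview(step, inv_gamma=1.0, power=2 / 3):
    # mirrors the core formula of EMAModel.get_decay
    value = 1 - (1 + step / inv_gamma)**-power
    return max(0.0, min(value, 0.9999))

# _decay_preview(1_000) ~= 0.990, _decay_preview(31_600) ~= 0.999 (matching the
# note above), _decay_preview(1_000_000) == 0.9999 (clipped at max_value).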

View File

@@ -0,0 +1,19 @@
import math
import torch
import torch.nn as nn
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = x[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
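
Illustrative usage (not part of the original file): a batch of scalar timesteps maps to dim-wide embeddings, sin terms in the first half and cos terms in the second.

pos_emb = SinusoidalPosEmb(dim=128)
out = pos_emb(torch.tensor([0, 10, 100]))
assert out.shape == (3, 128)  # (batch, dim) -> 64 sin + 64 cos features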

View File

@@ -0,0 +1,322 @@
import torch
import torch.nn as nn
import torchvision.transforms.functional as ttf
import unifolm_wma.models.diffusion_head.common.tensor_util as tu
class CropRandomizer(nn.Module):
"""
Randomly sample crops at input, and then average across crop features at output.
"""
def __init__(
self,
input_shape,
crop_height,
crop_width,
num_crops=1,
pos_enc=False,
):
"""
Args:
input_shape (tuple, list): shape of input (not including batch dimension)
crop_height (int): crop height
crop_width (int): crop width
num_crops (int): number of random crops to take
pos_enc (bool): if True, add 2 channels to the output to encode the spatial
location of the cropped pixels in the source image
"""
super().__init__()
assert len(input_shape) == 3 # (C, H, W)
assert crop_height < input_shape[1]
assert crop_width < input_shape[2]
self.input_shape = input_shape
self.crop_height = crop_height
self.crop_width = crop_width
self.num_crops = num_crops
self.pos_enc = pos_enc
def output_shape_in(self, input_shape=None):
"""
Function to compute output shape from inputs to this module. Corresponds to
the @forward_in operation, where raw inputs (usually observation modalities)
are passed in.
Args:
input_shape (iterable of int): shape of input. Does not include batch dimension.
Some modules may not need this argument, if their output does not depend
on the size of the input, or if they assume fixed size input.
Returns:
out_shape ([int]): list of integers corresponding to output shape
"""
# Outputs have shape (C, CH, CW), with C + 2 channels when position encoding
# is used. The N crops are folded into the batch dimension, increasing the
# batch size from B to B * N.
out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0]
return [out_c, self.crop_height, self.crop_width]
def output_shape_out(self, input_shape=None):
"""
Function to compute output shape from inputs to this module. Corresponds to
the @forward_out operation, where processed inputs (usually encoded observation
modalities) are passed in.
Args:
input_shape (iterable of int): shape of input. Does not include batch dimension.
Some modules may not need this argument, if their output does not depend
on the size of the input, or if they assume fixed size input.
Returns:
out_shape ([int]): list of integers corresponding to output shape
"""
# since the forward_out operation splits [B * N, ...] -> [B, N, ...]
# and then pools to result in [B, ...], only the batch dimension changes,
# and so the other dimensions retain their shape.
return list(input_shape)
def forward_in(self, inputs):
"""
Samples N random crops for each input in the batch, and then reshapes
inputs to [B * N, ...].
"""
assert len(
inputs.shape) >= 3 # must have at least (C, H, W) dimensions
if self.training:
# generate random crops
out, _ = sample_random_image_crops(
images=inputs,
crop_height=self.crop_height,
crop_width=self.crop_width,
num_crops=self.num_crops,
pos_enc=self.pos_enc,
)
# [B, N, ...] -> [B * N, ...]
return tu.join_dimensions(out, 0, 1)
else:
# take center crop during eval
out = ttf.center_crop(img=inputs,
output_size=(self.crop_height,
self.crop_width))
if self.num_crops > 1:
B, C, H, W = out.shape
out = out.unsqueeze(1).expand(B, self.num_crops, C, H,
W).reshape(-1, C, H, W)
# [B * N, ...]
return out
def forward_out(self, inputs):
"""
Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then average across N
to result in shape [B, ...] to make sure the network output is consistent with
what would have happened if there were no randomization.
"""
if self.num_crops <= 1:
return inputs
else:
batch_size = (inputs.shape[0] // self.num_crops)
out = tu.reshape_dimensions(inputs,
begin_axis=0,
end_axis=0,
target_dims=(batch_size,
self.num_crops))
return out.mean(dim=1)
def forward(self, inputs):
return self.forward_in(inputs)
def __repr__(self):
"""Pretty print network."""
header = '{}'.format(str(self.__class__.__name__))
msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format(
self.input_shape, self.crop_height, self.crop_width,
self.num_crops)
return msg
def crop_image_from_indices(images, crop_indices, crop_height, crop_width):
"""
Crops images at the locations specified by @crop_indices. Crops will be
taken across all channels.
Args:
images (torch.Tensor): batch of images of shape [..., C, H, W]
crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where
N is the number of crops to take per image and each entry corresponds
to the pixel height and width of where to take the crop. Note that
the indices can also be of shape [..., 2] if only 1 crop should
be taken per image. Leading dimensions must be consistent with
@images argument. Each index specifies the top left of the crop.
Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where
H and W are the height and width of @images and CH and CW are
@crop_height and @crop_width.
crop_height (int): height of crop to take
crop_width (int): width of crop to take
Returns:
crops (torch.Tensor): cropped images of shape [..., C, @crop_height, @crop_width]
"""
# make sure length of input shapes is consistent
assert crop_indices.shape[-1] == 2
ndim_im_shape = len(images.shape)
ndim_indices_shape = len(crop_indices.shape)
assert (ndim_im_shape == ndim_indices_shape +
1) or (ndim_im_shape == ndim_indices_shape + 2)
# maybe pad so that @crop_indices is shape [..., N, 2]
is_padded = False
if ndim_im_shape == ndim_indices_shape + 2:
crop_indices = crop_indices.unsqueeze(-2)
is_padded = True
# make sure leading dimensions between images and indices are consistent
assert images.shape[:-3] == crop_indices.shape[:-2]
device = images.device
image_c, image_h, image_w = images.shape[-3:]
num_crops = crop_indices.shape[-2]
# make sure @crop_indices are in valid range
assert (crop_indices[..., 0] >= 0).all().item()
assert (crop_indices[..., 0] < (image_h - crop_height)).all().item()
assert (crop_indices[..., 1] >= 0).all().item()
assert (crop_indices[..., 1] < (image_w - crop_width)).all().item()
# convert each crop index (ch, cw) into a list of pixel indices that correspond to the entire window.
# 2D index array with columns [0, 1, ..., CH - 1] and shape [CH, CW]
crop_ind_grid_h = torch.arange(crop_height).to(device)
crop_ind_grid_h = tu.unsqueeze_expand_at(crop_ind_grid_h,
size=crop_width,
dim=-1)
# 2D index array with rows [0, 1, ..., CW - 1] and shape [CH, CW]
crop_ind_grid_w = torch.arange(crop_width).to(device)
crop_ind_grid_w = tu.unsqueeze_expand_at(crop_ind_grid_w,
size=crop_height,
dim=0)
# combine into shape [CH, CW, 2]
crop_in_grid = torch.cat(
(crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1)
# Add above grid with the offset index of each sampled crop to get 2d indices for each crop.
# After broadcasting, this will be shape [..., N, CH, CW, 2] and each crop has a [CH, CW, 2]
# shape array that tells us which pixels from the corresponding source image to grab.
grid_reshape = [1] * len(crop_indices.shape[:-1]) + [
crop_height, crop_width, 2
]
all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(
-2) + crop_in_grid.reshape(grid_reshape)
# For using @torch.gather, convert to flat indices from 2D indices, and also
# repeat across the channel dimension. To get flat index of each pixel to grab for
# each sampled crop, we just use the mapping: ind = h_ind * @image_w + w_ind
all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[
..., 1] # shape [..., N, CH, CW]
all_crop_inds = tu.unsqueeze_expand_at(all_crop_inds, size=image_c,
dim=-3) # shape [..., N, C, CH, CW]
all_crop_inds = tu.flatten(all_crop_inds,
begin_axis=-2) # shape [..., N, C, CH * CW]
# Repeat and flatten the source images -> [..., N, C, H * W] and then use gather to index with crop pixel inds
images_to_crop = tu.unsqueeze_expand_at(images, size=num_crops, dim=-4)
images_to_crop = tu.flatten(images_to_crop, begin_axis=-2)
crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds)
# [..., N, C, CH * CW] -> [..., N, C, CH, CW]
reshape_axis = len(crops.shape) - 1
crops = tu.reshape_dimensions(crops,
begin_axis=reshape_axis,
end_axis=reshape_axis,
target_dims=(crop_height, crop_width))
if is_padded:
# undo padding -> [..., C, CH, CW]
crops = crops.squeeze(-4)
return crops
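
A small check of the gather-based cropping above (illustrative sizes, not part of the original file): the result matches direct slicing for a known top-left index.

imgs = torch.arange(2 * 3 * 8 * 8, dtype=torch.float32).reshape(2, 3, 8, 8)
inds = torch.tensor([[1, 2], [0, 3]])  # one (h, w) top-left index per image
crops = crop_image_from_indices(imgs, inds, crop_height=4, crop_width=4)
assert crops.shape == (2, 3, 4, 4)
assert torch.equal(crops[0], imgs[0, :, 1:5, 2:6])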
def sample_random_image_crops(images,
crop_height,
crop_width,
num_crops,
pos_enc=False):
"""
For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from
@images.
Args:
images (torch.Tensor): batch of images of shape [..., C, H, W]
crop_height (int): height of crop to take
crop_width (int): width of crop to take
num_crops (n): number of crops to sample
pos_enc (bool): if True, also add 2 channels to the outputs that gives a spatial
encoding of the original source pixel locations. This means that the
output crops will contain information about where in the source image
it was sampled from.
Returns:
crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width)
if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width)
crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2)
"""
device = images.device
# maybe add 2 channels of spatial encoding to the source image
source_im = images
if pos_enc:
# spatial encoding [y, x] in [0, 1]
h, w = source_im.shape[-2:]
pos_y, pos_x = torch.meshgrid(torch.arange(h), torch.arange(w))
pos_y = pos_y.float().to(device) / float(h)
pos_x = pos_x.float().to(device) / float(w)
position_enc = torch.stack((pos_y, pos_x)) # shape [C, H, W]
# unsqueeze and expand to match leading dimensions -> shape [..., C, H, W]
leading_shape = source_im.shape[:-3]
position_enc = position_enc[(None, ) * len(leading_shape)]
position_enc = position_enc.expand(*leading_shape, -1, -1, -1)
# concat across channel dimension with input
source_im = torch.cat((source_im, position_enc), dim=-3)
# make sure sample boundaries ensure crops are fully within the images
image_c, image_h, image_w = source_im.shape[-3:]
max_sample_h = image_h - crop_height
max_sample_w = image_w - crop_width
# Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W].
# Each gets @num_crops samples - typically this will just be the batch dimension (B), so
# we will sample [B, N] indices, but this supports having more than one leading dimension,
# or possibly no leading dimension.
#
# Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to long to get sampled ints
crop_inds_h = (
max_sample_h *
torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
crop_inds_w = (
max_sample_w *
torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
crop_inds = torch.cat(
(crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)),
dim=-1) # shape [..., N, 2]
crops = crop_image_from_indices(
images=source_im,
crop_indices=crop_inds,
crop_height=crop_height,
crop_width=crop_width,
)
return crops, crop_inds
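
And a usage sketch for the random sampler (illustrative sizes, not part of the original file):

batch = torch.rand(4, 3, 96, 96)  # (B, C, H, W)
crops, inds = sample_random_image_crops(batch,
                                        crop_height=84,
                                        crop_width=84,
                                        num_crops=2)
assert crops.shape == (4, 2, 3, 84, 84)  # (B, N, C, CH, CW)
assert inds.shape == (4, 2, 2)  # sampled (h, w) top-left per crop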

View File

@@ -0,0 +1,30 @@
import torch
import torchvision
def get_resnet(name, weights=None, **kwargs):
"""
name: resnet18, resnet34, resnet50
weights: "IMAGENET1K_V1", "r3m"
"""
# load r3m weights
if (weights == "r3m") or (weights == "R3M"):
return get_r3m(name=name, **kwargs)
func = getattr(torchvision.models, name)
resnet = func(weights=weights, **kwargs)
resnet.fc = torch.nn.Identity()
return resnet
def get_r3m(name, **kwargs):
"""
name: resnet18, resnet34, resnet50
"""
import r3m
r3m.device = 'cpu'
model = r3m.load_r3m(name)
r3m_model = model.module
resnet_model = r3m_model.convnet
resnet_model = resnet_model.to('cpu')
return resnet_model
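
Illustrative usage (not part of the original file); the r3m branch additionally requires the r3m package to be installed:

backbone = get_resnet('resnet18', weights=None)  # random init, fc stripped
feat = backbone(torch.zeros(1, 3, 224, 224))
assert feat.shape == (1, 512)  # resnet18 global feature dimension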

View File

@@ -0,0 +1,247 @@
import copy
import torch
import torch.nn as nn
import torchvision
import json
import os
from unifolm_wma.models.diffusion_head.vision.crop_randomizer import CropRandomizer
from unifolm_wma.models.diffusion_head.base_nets import SpatialSoftmax
from unifolm_wma.models.diffusion_head.common.module_attr_mixin import ModuleAttrMixin
from unifolm_wma.models.diffusion_head.common.pytorch_util import dict_apply, replace_submodules
from unifolm_wma.utils.utils import instantiate_from_config
from einops import rearrange, repeat
from typing import Dict, Tuple, Union
from pathlib import Path
class MultiImageObsEncoder(ModuleAttrMixin):
def __init__(
self,
rgb_model_config: Dict,
shape_meta_path: str | None = None,
resize_shape: Union[Tuple[int, int], Dict[str, tuple],
None] = None,
crop_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
random_crop: bool = True,
# replace BatchNorm with GroupNorm
use_group_norm: bool = False,
# use single rgb model for all rgb inputs
share_rgb_model: bool = False,
# renormalize rgb input with imagenet normalization
# assuming input in [0,1]
imagenet_norm: bool = False,
use_spatial_softmax=False,
spatial_softmax_kp=32,
use_dinoSiglip=False):
"""
Assumes rgb input: B,C,H,W
Assumes low_dim input: B,D
"""
super().__init__()
if not shape_meta_path:
shape_meta_path = str(Path(os.getcwd()) / "configs/train/meta.json")
with open(shape_meta_path, 'r') as file:
shape_meta = json.load(file)
rgb_model = instantiate_from_config(rgb_model_config)
rgb_keys = list()
low_dim_keys = list()
key_model_map = nn.ModuleDict()
key_transform_map = nn.ModuleDict()
key_shape_map = dict()
# handle sharing vision backbone
if share_rgb_model:
assert isinstance(rgb_model, nn.Module)
key_model_map['rgb'] = rgb_model
obs_shape_meta = shape_meta['obs']
for key, attr in obs_shape_meta.items():
shape = tuple(attr['shape'])
type = attr.get('type', 'low_dim')
key_shape_map[key] = shape
if type == 'rgb':
rgb_keys.append(key)
if not use_dinoSiglip:
# configure model for this key
this_model = None
if not share_rgb_model:
if isinstance(rgb_model, dict):
# have provided model for each key
this_model = rgb_model[key]
else:
assert isinstance(rgb_model, nn.Module)
# have a copy of the rgb model
this_model = copy.deepcopy(rgb_model)
if this_model is not None:
if use_group_norm:
this_model = replace_submodules(
root_module=this_model,
predicate=lambda x: isinstance(
x, nn.BatchNorm2d),
func=lambda x: nn.GroupNorm(
num_groups=x.num_features // 16,
num_channels=x.num_features))
key_model_map[key] = this_model
# configure resize
input_shape = shape
this_resizer = nn.Identity()
if resize_shape is not None:
if isinstance(resize_shape, dict):
h, w = resize_shape[key]
else:
h, w = resize_shape
this_resizer = torchvision.transforms.Resize(size=(h,
w))
input_shape = (shape[0], h, w)
# configure randomizer
this_randomizer = nn.Identity()
if crop_shape is not None:
if isinstance(crop_shape, dict):
h, w = crop_shape[key]
else:
h, w = crop_shape
if random_crop:
this_randomizer = CropRandomizer(
input_shape=input_shape,
crop_height=h,
crop_width=w,
num_crops=1,
pos_enc=False)
else:
    # deterministic center crop when random cropping is disabled
    this_randomizer = torchvision.transforms.CenterCrop(
        size=(h, w))
# configure normalizer
this_normalizer = nn.Identity()
if imagenet_norm:
this_normalizer = torchvision.transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
this_transform = nn.Sequential(this_resizer,
this_randomizer,
this_normalizer)
key_transform_map[key] = this_transform
else:
key_model_map[key] = rgb_model
elif type == 'low_dim':
low_dim_keys.append(key)
else:
raise RuntimeError(f"Unsupported obs type: {type}")
rgb_keys = sorted(rgb_keys)
low_dim_keys = sorted(low_dim_keys)
self.shape_meta = shape_meta
self.key_model_map = key_model_map
self.key_transform_map = key_transform_map
self.share_rgb_model = share_rgb_model
self.rgb_keys = rgb_keys
self.low_dim_keys = low_dim_keys
self.key_shape_map = key_shape_map
self.use_dinoSiglip = use_dinoSiglip
# NOTE: optionally stack a spatial-softmax head on the ResNet trunk
self.use_spatial_softmax = use_spatial_softmax
if use_spatial_softmax and not use_dinoSiglip:
model = nn.Sequential(
key_model_map['image'].conv1,
key_model_map['image'].bn1,
key_model_map['image'].relu,
key_model_map['image'].maxpool,
key_model_map['image'].layer1,
key_model_map['image'].layer2,
key_model_map['image'].layer3,
key_model_map['image'].layer4,
)
key_model_map['image'] = model
input_shape = self.output_shape(resnet_output_shape=True)
self.spatial_softmax = SpatialSoftmax(input_shape,
num_kp=spatial_softmax_kp)
def forward(self, obs_dict, resnet_output_shape=False):
batch_size = None
features = list()
# process rgb input
if self.share_rgb_model:
# pass all rgb obs to rgb model
imgs = list()
for key in self.rgb_keys:
img = obs_dict[key]
if batch_size is None:
batch_size = img.shape[0]
else:
assert batch_size == img.shape[0]
assert img.shape[1:] == self.key_shape_map[key]
img = self.key_transform_map[key](img)
imgs.append(img)
# (N*B,C,H,W)
imgs = torch.cat(imgs, dim=0)
# (N*B,D)
feature = self.key_model_map['rgb'](imgs)
# (N,B,D)
feature = feature.reshape(-1, batch_size, *feature.shape[1:])
# (B,N,D)
feature = torch.moveaxis(feature, 0, 1)
# (B,N*D)
feature = feature.reshape(batch_size, -1)
features.append(feature)
else:
# run each rgb obs through its own independent model
for key in self.rgb_keys:
img = obs_dict[key]
if batch_size is None:
batch_size = img.shape[0]
else:
assert batch_size == img.shape[0]
assert img.shape[1:] == self.key_shape_map[key]
if not self.use_dinoSiglip:
img = self.key_transform_map[key](img)
feature = self.key_model_map[key](img)
else:
feature = self.key_model_map[key](img)[:, :1, :]
if resnet_output_shape:
return feature
if not self.use_dinoSiglip and self.use_spatial_softmax:
feature = self.spatial_softmax(feature)
feature = feature.reshape(batch_size, -1)
features.append(feature)
# process lowdim input
for key in self.low_dim_keys:
data = obs_dict[key]
if batch_size is None:
batch_size = data.shape[0]
else:
assert batch_size == data.shape[0]
assert data.shape[1:] == self.key_shape_map[key]
features.append(data)
# concatenate all features
result = torch.cat(features, dim=-1)
return result
@torch.no_grad()
def output_shape(self, resnet_output_shape=False):
example_obs_dict = dict()
obs_shape_meta = self.shape_meta['obs']
batch_size = 1
for key, attr in obs_shape_meta.items():
shape = tuple(attr['shape'])
this_obs = torch.zeros((batch_size, ) + shape,
dtype=self.dtype,
device=self.device)
example_obs_dict[key] = this_obs
example_output = self.forward(example_obs_dict,
resnet_output_shape=resnet_output_shape)
output_shape = example_output.shape[1:]
return output_shape
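
For reference, a hypothetical shape_meta layout consistent with the constructor above; the real keys and sizes live in configs/train/meta.json and may differ:

# Hypothetical example only; actual content comes from configs/train/meta.json.
shape_meta = {
    "obs": {
        "image": {"shape": [3, 224, 224], "type": "rgb"},
        "agent_state": {"shape": [14], "type": "low_dim"},
    }
}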

View File

@@ -0,0 +1,473 @@
import numpy as np
import torch
import copy
from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg
from unifolm_wma.utils.common import noise_like
from unifolm_wma.utils.common import extract_into_tensor
from tqdm import tqdm
class DDIMSampler(object):
def __init__(self, model, schedule="linear", **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.counter = 0
def register_buffer(self, name, attr):
if isinstance(attr, torch.Tensor):
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
setattr(self, name, attr)
def make_schedule(self,
ddim_num_steps,
ddim_discretize="uniform",
ddim_eta=0.,
verbose=True):
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[
0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model
.device)
if self.model.use_dynamic_rescale:
self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps]
self.ddim_scale_arr_prev = torch.cat(
[self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]])
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev',
to_torch(self.model.alphas_cumprod_prev))
# Calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod',
to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod',
to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# DDIM sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas',
np.sqrt(1. - ddim_alphas))
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
self.register_buffer('ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
schedule_verbose=False,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
precision=None,
fs=None,
timestep_spacing='uniform',  # 'uniform_trailing' starts from the last timestep
guidance_rescale=0.0,
**kwargs):
# Check condition bs
if conditioning is not None:
if isinstance(conditioning, dict):
try:
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
except:
cbs = conditioning[list(
conditioning.keys())[0]][0].shape[0]
if cbs != batch_size:
print(
f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
)
else:
if conditioning.shape[0] != batch_size:
print(
f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
)
self.make_schedule(ddim_num_steps=S,
ddim_discretize=timestep_spacing,
ddim_eta=eta,
verbose=schedule_verbose)
# Make shape
if len(shape) == 3:
C, H, W = shape
size = (batch_size, C, H, W)
elif len(shape) == 4:
C, T, H, W = shape
size = (batch_size, C, T, H, W)
samples, actions, states, intermediates = self.ddim_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
verbose=verbose,
precision=precision,
fs=fs,
guidance_rescale=guidance_rescale,
**kwargs)
return samples, actions, states, intermediates
@torch.no_grad()
def ddim_sampling(self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
verbose=True,
precision=None,
fs=None,
guidance_rescale=0.0,
**kwargs):
device = self.model.betas.device
dp_ddim_scheduler_action = self.model.dp_noise_scheduler_action
dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
action = torch.randn((b, 16, self.model.agent_action_dim),
device=device)
state = torch.randn((b, 16, self.model.agent_state_dim),
device=device)
else:
img = x_T
action = torch.randn((b, 16, self.model.agent_action_dim),
device=device)
state = torch.randn((b, 16, self.model.agent_state_dim),
device=device)
if precision is not None:
if precision == 16:
img = img.to(dtype=torch.float16)
action = action.to(dtype=torch.float16)
state = state.to(dtype=torch.float16)
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(
min(timesteps / self.ddim_timesteps.shape[0], 1) *
self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {
'x_inter': [img],
'pred_x0': [img],
'x_inter_action': [action],
'pred_x0_action': [action],
'x_inter_state': [state],
'pred_x0_state': [state],
}
time_range = reversed(range(
0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
0]
if verbose:
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
else:
iterator = time_range
clean_cond = kwargs.pop("clean_cond", False)
dp_ddim_scheduler_action.set_timesteps(len(timesteps))
dp_ddim_scheduler_state.set_timesteps(len(timesteps))
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b, ), step, device=device, dtype=torch.long)
# Use mask to blend noised original latent (img_orig) & new sampled latent (img)
if mask is not None:
assert x0 is not None
if clean_cond:
img_orig = x0
else:
img_orig = self.model.q_sample(x0, ts)
img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_ddim(
img,
action,
state,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
mask=mask,
x0=x0,
fs=fs,
guidance_rescale=guidance_rescale,
**kwargs)
img, pred_x0, model_output_action, model_output_state = outs
action = dp_ddim_scheduler_action.step(
model_output_action,
step,
action,
generator=None,
).prev_sample
state = dp_ddim_scheduler_state.step(
model_output_state,
step,
state,
generator=None,
).prev_sample
if callback: callback(i)
if img_callback: img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
intermediates['x_inter_action'].append(action)
intermediates['x_inter_state'].append(state)
return img, action, state, intermediates
@torch.no_grad()
def p_sample_ddim(self,
x,
x_action,
x_state,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
uc_type=None,
conditional_guidance_scale_temporal=None,
mask=None,
x0=None,
guidance_rescale=0.0,
**kwargs):
b, *_, device = *x.shape, x.device
if x.dim() == 5:
is_video = True
else:
is_video = False
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output, model_output_action, model_output_state = self.model.apply_model(
x, x_action, x_state, t, c, **kwargs) # unet denoiser
else:
# do_classifier_free_guidance
if isinstance(c, torch.Tensor) or isinstance(c, dict):
e_t_cond, e_t_cond_action, e_t_cond_state = self.model.apply_model(
x, x_action, x_state, t, c, **kwargs)
e_t_uncond, e_t_uncond_action, e_t_uncond_state = self.model.apply_model(
x, x_action, x_state, t, unconditional_conditioning,
**kwargs)
else:
raise NotImplementedError
model_output = e_t_uncond + unconditional_guidance_scale * (
e_t_cond - e_t_uncond)
model_output_action = e_t_uncond_action + unconditional_guidance_scale * (
e_t_cond_action - e_t_uncond_action)
model_output_state = e_t_uncond_state + unconditional_guidance_scale * (
e_t_cond_state - e_t_uncond_state)
if guidance_rescale > 0.0:
model_output = rescale_noise_cfg(
model_output, e_t_cond, guidance_rescale=guidance_rescale)
model_output_action = rescale_noise_cfg(
model_output_action,
e_t_cond_action,
guidance_rescale=guidance_rescale)
model_output_state = rescale_noise_cfg(
model_output_state,
e_t_cond_state,
guidance_rescale=guidance_rescale)
if self.model.parameterization == "v":
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
else:
e_t = model_output
if score_corrector is not None:
assert self.model.parameterization == "eps", 'not implemented'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
**corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
if is_video:
size = (b, 1, 1, 1, 1)
else:
size = (b, 1, 1, 1)
a_t = torch.full(size, alphas[index], device=device)
a_prev = torch.full(size, alphas_prev[index], device=device)
sigma_t = torch.full(size, sigmas[index], device=device)
sqrt_one_minus_at = torch.full(size,
sqrt_one_minus_alphas[index],
device=device)
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if self.model.use_dynamic_rescale:
scale_t = torch.full(size,
self.ddim_scale_arr[index],
device=device)
prev_scale_t = torch.full(size,
self.ddim_scale_arr_prev[index],
device=device)
rescale = (prev_scale_t / scale_t)
pred_x0 *= rescale
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device,
repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0, model_output_action, model_output_state
@torch.no_grad()
def decode(self,
x_latent,
cond,
t_start,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
callback=None):
timesteps = np.arange(self.ddpm_num_timesteps
) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f"Running DDIM Sampling with {total_steps} timesteps")
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0], ),
step,
device=x_latent.device,
dtype=torch.long)
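# NOTE: legacy helper kept from the stock DDIM sampler; p_sample_ddim in this
# file also expects x_action/x_state, so this call path needs updating before use.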
x_dec, _ = self.p_sample_ddim(
x_dec,
cond,
ts,
index=index,
use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
if callback: callback(i)
return x_dec
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) *
noise)
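
As a usage sketch (hedged: `sampler` is assumed to be a configured DDIMSampler and `z0` a clean latent batch), stochastic_encode noises a latent to a chosen DDIM step, e.g. for partial re-sampling:

# Hedged sketch, not part of the original file.
sampler.make_schedule(ddim_num_steps=50, ddim_eta=0.0, verbose=False)
t = torch.full((z0.shape[0], ), 25, device=z0.device, dtype=torch.long)
z_noisy = sampler.stochastic_encode(z0, t)  # same shape as z0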

View File

View File

@@ -0,0 +1,806 @@
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from functools import partial
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except:
XFORMERS_IS_AVAILBLE = False
from unifolm_wma.utils.common import (
checkpoint,
exists,
default,
)
from unifolm_wma.utils.basics import zero_module
class RelativePosition(nn.Module):
""" https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py """
def __init__(self, num_units, max_relative_position):
super().__init__()
self.num_units = num_units
self.max_relative_position = max_relative_position
self.embeddings_table = nn.Parameter(
torch.Tensor(max_relative_position * 2 + 1, num_units))
nn.init.xavier_uniform_(self.embeddings_table)
def forward(self, length_q, length_k):
device = self.embeddings_table.device
range_vec_q = torch.arange(length_q, device=device)
range_vec_k = torch.arange(length_k, device=device)
distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
distance_mat_clipped = torch.clamp(distance_mat,
-self.max_relative_position,
self.max_relative_position)
final_mat = distance_mat_clipped + self.max_relative_position
final_mat = final_mat.long()
embeddings = self.embeddings_table[final_mat]
return embeddings
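
Illustrative usage (not part of the original file):

rel_pos = RelativePosition(num_units=64, max_relative_position=16)
emb = rel_pos(length_q=16, length_k=16)
assert emb.shape == (16, 16, 64)  # one num_units vector per (q, k) offset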
class CrossAttention(nn.Module):
def __init__(self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.,
relative_position=False,
temporal_length=None,
video_length=None,
agent_state_context_len=2,
agent_action_context_len=16,
image_cross_attention=False,
image_cross_attention_scale=1.0,
agent_state_cross_attention_scale=1.0,
agent_action_cross_attention_scale=1.0,
cross_attention_scale_learnable=False,
text_context_len=77):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout))
self.relative_position = relative_position
if self.relative_position:
assert (temporal_length is not None)
self.relative_position_k = RelativePosition(
num_units=dim_head, max_relative_position=temporal_length)
self.relative_position_v = RelativePosition(
num_units=dim_head, max_relative_position=temporal_length)
else:
## only used for spatial attention, while NOT for temporal attention
if XFORMERS_IS_AVAILBLE and temporal_length is None:
self.forward = self.efficient_forward
self.video_length = video_length
self.image_cross_attention = image_cross_attention
self.image_cross_attention_scale = image_cross_attention_scale
self.agent_state_cross_attention_scale = agent_state_cross_attention_scale
self.agent_action_cross_attention_scale = agent_action_cross_attention_scale
self.text_context_len = text_context_len
self.agent_state_context_len = agent_state_context_len
self.agent_action_context_len = agent_action_context_len
self.cross_attention_scale_learnable = cross_attention_scale_learnable
if self.image_cross_attention:
self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
self.to_k_as = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v_as = nn.Linear(context_dim, inner_dim, bias=False)
self.to_k_aa = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v_aa = nn.Linear(context_dim, inner_dim, bias=False)
if cross_attention_scale_learnable:
self.register_parameter('alpha_ctx',
nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_cas',
nn.Parameter(torch.tensor(0.)))
self.register_parameter('alpha_caa',
nn.Parameter(torch.tensor(0.)))
def forward(self, x, context=None, mask=None):
spatial_self_attn = (context is None)
k_ip, v_ip, out_ip = None, None, None
k_as, v_as, out_as = None, None, None
k_aa, v_aa, out_aa = None, None, None
h = self.heads
q = self.to_q(x)
context = default(context, x)
if self.image_cross_attention and not spatial_self_attn:
raise RuntimeError(">>> ERROR: set up xformers and use efficient_forward for image cross-attention")
context_agent_state = context[:, :self.agent_state_context_len, :]
context_agent_action = context[:,
self.agent_state_context_len:self.
agent_state_context_len +
self.agent_action_context_len, :]
context_ins = context[:, self.agent_state_context_len +
self.agent_action_context_len:self.
agent_state_context_len +
self.agent_action_context_len +
self.text_context_len, :]
context_image = context[:, self.agent_state_context_len +
self.agent_action_context_len +
self.text_context_len:, :]
k = self.to_k(context_ins)
v = self.to_v(context_ins)
k_ip = self.to_k_ip(context_image)
v_ip = self.to_v_ip(context_image)
k_as = self.to_k_as(context_agent_state)
v_as = self.to_v_as(context_agent_state)
k_aa = self.to_k_aa(context_agent_action)
v_aa = self.to_v_aa(context_agent_action)
else:
if not spatial_self_attn:
context = context[:, :self.text_context_len, :]
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(q, k, v))
sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
if self.relative_position:
len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
k2 = self.relative_position_k(len_q, len_k)
sim2 = einsum('b t d, t s d -> b t s', q,
k2) * self.scale # TODO check
sim += sim2
del k
if exists(mask):
## feasible for causal attention mask only
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b i j -> (b h) i j', h=h)
sim.masked_fill_(~(mask > 0.5), max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = torch.einsum('b i j, b j d -> b i d', sim, v)
if self.relative_position:
v2 = self.relative_position_v(len_q, len_v)
out2 = einsum('b t s, t s d -> b t d', sim, v2) # TODO check
out += out2
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
if k_ip is not None and k_as is not None and k_aa is not None:
## for image cross-attention
k_ip, v_ip = map(
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(k_ip, v_ip))
sim_ip = torch.einsum('b i d, b j d -> b i j', q,
k_ip) * self.scale
del k_ip
sim_ip = sim_ip.softmax(dim=-1)
out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip)
out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
## for agent state cross-attention
k_as, v_as = map(
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(k_as, v_as))
sim_as = torch.einsum('b i d, b j d -> b i j', q,
k_as) * self.scale
del k_as
sim_as = sim_as.softmax(dim=-1)
out_as = torch.einsum('b i j, b j d -> b i d', sim_as, v_as)
out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
## for agent action cross-attention
k_aa, v_aa = map(
lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(k_aa, v_aa))
sim_aa = torch.einsum('b i d, b j d -> b i j', q,
k_aa) * self.scale
del k_aa
sim_aa = sim_aa.softmax(dim=-1)
out_aa = torch.einsum('b i j, b j d -> b i d', sim_aa, v_aa)
out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)
if out_ip is not None and out_as is not None and out_aa is not None:
if self.cross_attention_scale_learnable:
out = out + \
self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \
self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \
self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1)
else:
out = out + \
self.image_cross_attention_scale * out_ip + \
self.agent_state_cross_attention_scale * out_as + \
self.agent_action_cross_attention_scale * out_aa
return self.to_out(out)
def efficient_forward(self, x, context=None, mask=None):
spatial_self_attn = (context is None)
k, v, out = None, None, None
k_ip, v_ip, out_ip = None, None, None
k_as, v_as, out_as = None, None, None
k_aa, v_aa, out_aa = None, None, None
q = self.to_q(x)
context = default(context, x)
if self.image_cross_attention and not spatial_self_attn:
    # Dispatch on the context length: text+image, +agent state, or +state+action.
    s = self.agent_state_context_len
    a = self.agent_action_context_len
    t_len = self.text_context_len
    if context.shape[1] == t_len + self.video_length:
        context_ins = context[:, :t_len, :]
        context_image = context[:, t_len:, :]
        k = self.to_k(context_ins)
        v = self.to_v(context_ins)
        k_ip = self.to_k_ip(context_image)
        v_ip = self.to_v_ip(context_image)
    elif context.shape[1] == s + t_len + self.video_length:
        context_agent_state = context[:, :s, :]
        context_ins = context[:, s:s + t_len, :]
        context_image = context[:, s + t_len:, :]
        k = self.to_k(context_ins)
        v = self.to_v(context_ins)
        k_ip = self.to_k_ip(context_image)
        v_ip = self.to_v_ip(context_image)
        k_as = self.to_k_as(context_agent_state)
        v_as = self.to_v_as(context_agent_state)
    else:
        context_agent_state = context[:, :s, :]
        context_agent_action = context[:, s:s + a, :]
        context_ins = context[:, s + a:s + a + t_len, :]
        context_image = context[:, s + a + t_len:, :]
        k = self.to_k(context_ins)
        v = self.to_v(context_ins)
        k_ip = self.to_k_ip(context_image)
        v_ip = self.to_v_ip(context_image)
        k_as = self.to_k_as(context_agent_state)
        v_as = self.to_v_as(context_agent_state)
        k_aa = self.to_k_aa(context_agent_action)
        v_aa = self.to_v_aa(context_agent_action)
        attn_mask_aa = self._get_attn_mask_aa(
            x.shape[0], q.shape[1], k_aa.shape[1],
            block_size=16).to(k_aa.device)
else:
if not spatial_self_attn:
raise RuntimeError(">>> ERROR: this branch should be unreachable")
context = context[:, :self.text_context_len, :]
k = self.to_k(context)
v = self.to_v(context)
b, _, _ = q.shape
q = q.unsqueeze(3).reshape(b, q.shape[1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(b * self.heads, q.shape[1], self.dim_head).contiguous()
if k is not None:
k, v = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(),
(k, v),
)
out = xformers.ops.memory_efficient_attention(q,
k,
v,
attn_bias=None,
op=None)
out = (out.unsqueeze(0).reshape(
b, self.heads, out.shape[1],
self.dim_head).permute(0, 2, 1,
3).reshape(b, out.shape[1],
self.heads * self.dim_head))
if k_ip is not None:
# For image cross-attention
k_ip, v_ip = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(
),
(k_ip, v_ip),
)
out_ip = xformers.ops.memory_efficient_attention(q,
k_ip,
v_ip,
attn_bias=None,
op=None)
out_ip = (out_ip.unsqueeze(0).reshape(
b, self.heads, out_ip.shape[1],
self.dim_head).permute(0, 2, 1,
3).reshape(b, out_ip.shape[1],
self.heads * self.dim_head))
if k_as is not None:
# For agent state cross-attention
k_as, v_as = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(
),
(k_as, v_as),
)
out_as = xformers.ops.memory_efficient_attention(q,
k_as,
v_as,
attn_bias=None,
op=None)
out_as = (out_as.unsqueeze(0).reshape(
b, self.heads, out_as.shape[1],
self.dim_head).permute(0, 2, 1,
3).reshape(b, out_as.shape[1],
self.heads * self.dim_head))
if k_aa is not None:
# For agent action cross-attention
k_aa, v_aa = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(
),
(k_aa, v_aa),
)
attn_mask_aa = attn_mask_aa.unsqueeze(1).repeat(
    1, self.heads, 1, 1).reshape(b * self.heads, attn_mask_aa.shape[1],
                                 attn_mask_aa.shape[2])
attn_mask_aa = attn_mask_aa.to(q.dtype)
out_aa = xformers.ops.memory_efficient_attention(
q, k_aa, v_aa, attn_bias=attn_mask_aa, op=None)
out_aa = (out_aa.unsqueeze(0).reshape(
b, self.heads, out_aa.shape[1],
self.dim_head).permute(0, 2, 1,
3).reshape(b, out_aa.shape[1],
self.heads * self.dim_head))
if exists(mask):
raise NotImplementedError
out = 0.0 if out is None else out
out_ip = 0.0 if out_ip is None else out_ip
out_as = 0.0 if out_as is None else out_as
out_aa = 0.0 if out_aa is None else out_aa
if self.cross_attention_scale_learnable:
out = out + \
self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \
self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \
self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1)
else:
out = out + \
self.image_cross_attention_scale * out_ip + \
self.agent_state_cross_attention_scale * out_as + \
self.agent_action_cross_attention_scale * out_aa
return self.to_out(out)
def _get_attn_mask_aa(self, b, l1, l2, block_size=16):
num_token = l2 // block_size
start_positions = ((torch.arange(b) % block_size) + 1) * num_token
col_indices = torch.arange(l2)
mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
attn_mask = torch.zeros_like(mask, dtype=torch.float)
attn_mask[mask] = float('-inf')
return attn_mask
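
A small worked example of the block-causal action mask above (illustrative sizes, not part of the original file): with block_size=2 and l2=4 there are 2 action tokens per frame, so batch element 0 (frame 0) may attend only to the first 2 tokens, while element 1 (frame 1) sees all 4.

attn = CrossAttention(query_dim=8)
m = attn._get_attn_mask_aa(b=2, l1=3, l2=4, block_size=2)
assert m.shape == (2, 3, 4)
assert (m[0, 0] == torch.tensor([0., 0., float('-inf'), float('-inf')])).all()
assert (m[1, 0] == 0).all()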
class BasicTransformerBlock(nn.Module):
def __init__(self,
dim,
n_heads,
d_head,
dropout=0.,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False,
attention_cls=None,
video_length=None,
agent_state_context_len=2,
agent_action_context_len=16,
image_cross_attention=False,
image_cross_attention_scale=1.0,
cross_attention_scale_learnable=False,
text_context_len=77):
super().__init__()
attn_cls = CrossAttention if attention_cls is None else attention_cls
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None)
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = attn_cls(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
video_length=video_length,
agent_state_context_len=agent_state_context_len,
agent_action_context_len=agent_action_context_len,
image_cross_attention=image_cross_attention,
image_cross_attention_scale=image_cross_attention_scale,
cross_attention_scale_learnable=cross_attention_scale_learnable,
text_context_len=text_context_len)
self.image_cross_attention = image_cross_attention
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None, mask=None, **kwargs):
# Implementation trick: activation checkpointing does not support
# non-tensor (e.g. None or scalar) arguments, so tensors are packed into a tuple.
input_tuple = (
    x,
)  # must be (x,) not (x): a bare (x) is just x, and *input_tuple would then unpack the tensor itself
if context is not None:
input_tuple = (x, context)
if mask is not None:
forward_mask = partial(self._forward, mask=mask)
return checkpoint(forward_mask, (x, ), self.parameters(),
self.checkpoint)
return checkpoint(self._forward, input_tuple, self.parameters(),
self.checkpoint)
def _forward(self, x, context=None, mask=None):
x = self.attn1(self.norm1(x),
context=context if self.disable_self_attn else None,
mask=mask) + x
x = self.attn2(self.norm2(x), context=context, mask=mask) + x
x = self.ff(self.norm3(x)) + x
return x
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data in spatial axis.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer attention.
Finally, reshape to image
NEW: use_linear for more efficiency instead of the 1x1 convs
"""
def __init__(self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.,
context_dim=None,
use_checkpoint=True,
disable_self_attn=False,
use_linear=False,
video_length=None,
agent_state_context_len=2,
agent_action_context_len=16,
image_cross_attention=False,
cross_attention_scale_learnable=False):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(num_groups=32,
num_channels=in_channels,
eps=1e-6,
affine=True)
if not use_linear:
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
attention_cls = None
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim,
disable_self_attn=disable_self_attn,
checkpoint=use_checkpoint,
attention_cls=attention_cls,
video_length=video_length,
agent_state_context_len=agent_state_context_len,
agent_action_context_len=agent_action_context_len,
image_cross_attention=image_cross_attention,
cross_attention_scale_learnable=cross_attention_scale_learnable,
) for d in range(depth)
])
if not use_linear:
self.proj_out = zero_module(
nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
else:
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
self.use_linear = use_linear
def forward(self, x, context=None, **kwargs):
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
x = block(x, context=context, **kwargs)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
if not self.use_linear:
x = self.proj_out(x)
return x + x_in
class TemporalTransformer(nn.Module):
"""
Transformer block for image-like data in temporal axis.
First, reshape to b, t, d.
Then apply standard transformer attention.
Finally, reshape to image
"""
def __init__(self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.,
context_dim=None,
use_checkpoint=True,
use_linear=False,
only_self_att=True,
causal_attention=False,
causal_block_size=1,
relative_position=False,
temporal_length=None):
super().__init__()
self.only_self_att = only_self_att
self.relative_position = relative_position
self.causal_attention = causal_attention
self.causal_block_size = causal_block_size
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(num_groups=32,
num_channels=in_channels,
eps=1e-6,
affine=True)
if not use_linear:
self.proj_in = nn.Conv1d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
if relative_position:
assert (temporal_length is not None)
attention_cls = partial(CrossAttention,
relative_position=True,
temporal_length=temporal_length)
else:
attention_cls = partial(CrossAttention,
temporal_length=temporal_length)
if self.causal_attention:
assert (temporal_length is not None)
self.mask = torch.tril(
torch.ones([1, temporal_length, temporal_length]))
if self.only_self_att:
context_dim = None
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim,
attention_cls=attention_cls,
checkpoint=use_checkpoint)
for d in range(depth)
])
if not use_linear:
self.proj_out = zero_module(
nn.Conv1d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
else:
self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
self.use_linear = use_linear
def forward(self, x, context=None):
b, c, t, h, w = x.shape
x_in = x
x = self.norm(x)
x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous()
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, 'bhw c t -> bhw t c').contiguous()
if self.use_linear:
x = self.proj_in(x)
temp_mask = None
if self.causal_attention:
# Slice the causal mask to the current temporal length
temp_mask = self.mask[:, :t, :t].to(x.device)
if temp_mask is not None:
mask = temp_mask.to(x.device)
mask = repeat(mask, 'l i j -> (l bhw) i j', bhw=b * h * w)
else:
mask = None
if self.only_self_att:
# NOTE: if no context is given, cross-attention defaults to self-attention
for i, block in enumerate(self.transformer_blocks):
x = block(x, mask=mask)
x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
else:
x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
context = rearrange(context, '(b t) l con -> b t l con',
t=t).contiguous()
for i, block in enumerate(self.transformer_blocks):
            # Process each batch element separately (some backends cap a tensor dimension at 65,535)
for j in range(b):
context_j = repeat(context[j],
't l con -> (t r) l con',
r=(h * w) // t,
t=t).contiguous()
                # Note: the causal mask is not applied in the cross-attention case
x[j] = block(x[j], context=context_j)
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous()
if not self.use_linear:
x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous()
x = self.proj_out(x)
x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h,
w=w).contiguous()
return x + x_in
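# Illustrative usage (editor's sketch, hypothetical dims): TemporalTransformer
# attends along the time axis and preserves the (b, c, t, h, w) shape.
#   tt = TemporalTransformer(320, 8, 40, temporal_length=16)
#   v = torch.randn(2, 320, 16, 8, 8)
#   assert tt(v).shape == v.shape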
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
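# Editor's note: GEGLU(x) = proj_value(x) * GELU(proj_gate(x)); the single Linear
# above computes both projections at once and `chunk` splits value from gate.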
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(nn.Linear(
dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.net(x)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv,
'b (qkv heads c) h w -> qkv b heads c (h w)',
heads=self.heads,
qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out,
'b heads c (h w) -> b (heads c) h w',
heads=self.heads,
h=h,
w=w)
return self.to_out(out)
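# Editor's note: this is linear attention -- softmax is taken over the keys, then
# `context = einsum(k, v)` forms a (d x d) summary, so cost grows as O(N * d^2) in
# the token count N = h * w rather than the O(N^2 * d) of standard attention.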
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=32,
num_channels=in_channels,
eps=1e-6,
affine=True)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# Compute attention
b, c, h, w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# Attend to values
v = rearrange(v, 'b c h w -> b c (h w)')
w_ = rearrange(w_, 'b i j -> b j i')
h_ = torch.einsum('bij,bjk->bik', v, w_)
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_)
return x + h_


@@ -0,0 +1,630 @@
import torch
import torch.nn as nn
import kornia
import open_clip
import math
from torch.utils.checkpoint import checkpoint
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
from unifolm_wma.utils.common import autocast
from unifolm_wma.utils.utils import count_params
from unifolm_wma.modules.encoders.resampler import reshape_tensor
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
class IdentityEncoder(AbstractEncoder):
def encode(self, x):
return x
class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
super().__init__()
self.key = key
self.embedding = nn.Embedding(n_classes, embed_dim)
self.n_classes = n_classes
self.ucg_rate = ucg_rate
def forward(self, batch, key=None, disable_dropout=False):
if key is None:
key = self.key
        # this is for use in cross-attention
c = batch[key][:, None]
if self.ucg_rate > 0. and not disable_dropout:
mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes -
1)
c = c.long()
c = self.embedding(c)
return c
def get_unconditional_conditioning(self, bs, device="cuda"):
uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
uc = torch.ones((bs, ), device=device) * uc_class
uc = {self.key: uc}
return uc
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class FrozenT5Embedder(AbstractEncoder):
"""Uses the T5 transformer encoder for text"""
def __init__(self,
version="google/t5-v1_1-xxl",
device="cuda",
max_length=77,
freeze=True
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__()
self.tokenizer = T5Tokenizer.from_pretrained(version)
self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device
self.max_length = max_length # TODO: typical value?
if freeze:
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
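# Illustrative usage (editor's sketch, not part of the original file):
#   encoder = FrozenT5Embedder(device="cpu")       # loads google/t5-v1_1-xxl
#   z = encoder.encode(["pick up the red cube"])   # (1, max_length, hidden_dim)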
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = ["last", "pooled", "hidden"]
def __init__(self,
version="openai/clip-vit-large-patch14",
device="cuda",
max_length=77,
freeze=True,
layer="last",
layer_idx=None): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
self.layer_idx = layer_idx
if layer == "hidden":
assert layer_idx is not None
assert 0 <= abs(layer_idx) <= 12
def freeze(self):
self.transformer = self.transformer.eval()
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens,
output_hidden_states=self.layer == "hidden")
if self.layer == "last":
z = outputs.last_hidden_state
elif self.layer == "pooled":
z = outputs.pooler_output[:, None, :]
else:
z = outputs.hidden_states[self.layer_idx]
return z
def encode(self, text):
return self(text)
class ClipImageEmbedder(nn.Module):
def __init__(self,
model,
jit=False,
device='cuda' if torch.cuda.is_available() else 'cpu',
antialias=True,
ucg_rate=0.):
super().__init__()
from clip import load as load_clip
self.model, _ = load_clip(name=model, device=device, jit=jit)
self.antialias = antialias
self.register_buffer('mean',
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
persistent=False)
self.register_buffer('std',
torch.Tensor([0.26862954, 0.26130258,
0.27577711]),
persistent=False)
self.ucg_rate = ucg_rate
def preprocess(self, x):
# normalize to [0,1]
x = kornia.geometry.resize(x, (224, 224),
interpolation='bicubic',
align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
# re-normalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def forward(self, x, no_dropout=False):
# x is assumed to be in range [-1,1]
out = self.model.encode_image(self.preprocess(x))
out = out.to(x.dtype)
if self.ucg_rate > 0. and not no_dropout:
out = torch.bernoulli(
(1. - self.ucg_rate) *
torch.ones(out.shape[0], device=out.device))[:, None] * out
return out
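# Editor's note: the ucg branch above implements unconditional-guidance dropout --
# each sample's embedding is zeroed with probability `ucg_rate`:
#   keep = torch.bernoulli((1. - ucg_rate) * torch.ones(b))   # 1 = keep, 0 = drop
#   out = keep[:, None] * out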
class FrozenOpenCLIPEmbedder(AbstractEncoder):
"""
Uses the OpenCLIP transformer encoder for text
"""
LAYERS = [
# "pooled",
"last",
"penultimate"
]
def __init__(self,
arch="ViT-H-14",
version="laion2b_s32b_b79k",
device="cuda",
max_length=77,
freeze=True,
layer="last"):
super().__init__()
assert layer in self.LAYERS
model, _, _ = open_clip.create_model_and_transforms(
arch, device=torch.device('cpu'), pretrained=version)
del model.visual
self.model = model
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
if self.layer == "last":
self.layer_idx = 0
elif self.layer == "penultimate":
self.layer_idx = 1
else:
raise NotImplementedError()
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
tokens = open_clip.tokenize(
text) ## all clip models use 77 as context length
z = self.encode_with_transformer(tokens.to(self.device))
return z
def encode_with_transformer(self, text):
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.model.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.model.ln_final(x)
return x
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
for i, r in enumerate(self.model.transformer.resblocks):
if i == len(self.model.transformer.resblocks) - self.layer_idx:
break
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
):
x = checkpoint(r, x, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
return x
def encode(self, text):
return self(text)
class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
"""
Uses the OpenCLIP vision transformer encoder for images
"""
def __init__(self,
arch="ViT-H-14",
version="laion2b_s32b_b79k",
device="cuda",
max_length=77,
freeze=True,
layer="pooled",
antialias=True,
ucg_rate=0.):
super().__init__()
model, _, _ = open_clip.create_model_and_transforms(
arch,
device=torch.device('cpu'),
pretrained=version,
)
del model.transformer
self.model = model
# self.mapper = torch.nn.Linear(1280, 1024)
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
if self.layer == "penultimate":
raise NotImplementedError()
self.layer_idx = 1
self.antialias = antialias
self.register_buffer('mean',
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
persistent=False)
self.register_buffer('std',
torch.Tensor([0.26862954, 0.26130258,
0.27577711]),
persistent=False)
self.ucg_rate = ucg_rate
def preprocess(self, x):
# normalize to [0,1]
x = kornia.geometry.resize(x, (224, 224),
interpolation='bicubic',
align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def freeze(self):
self.model = self.model.eval()
for param in self.model.parameters():
param.requires_grad = False
@autocast
def forward(self, image, no_dropout=False):
z = self.encode_with_vision_transformer(image)
if self.ucg_rate > 0. and not no_dropout:
z = torch.bernoulli(
(1. - self.ucg_rate) *
torch.ones(z.shape[0], device=z.device))[:, None] * z
return z
def encode_with_vision_transformer(self, img):
img = self.preprocess(img)
x = self.model.visual(img)
return x
def encode(self, text):
return self(text)
class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder):
"""
Uses the OpenCLIP vision transformer encoder for images
"""
def __init__(self,
arch="ViT-H-14",
version="laion2b_s32b_b79k",
device="cuda",
freeze=True,
layer="pooled",
antialias=True):
super().__init__()
model, _, _ = open_clip.create_model_and_transforms(
arch,
device=torch.device('cpu'),
pretrained=version,
)
del model.transformer
self.model = model
self.device = device
if freeze:
self.freeze()
self.layer = layer
if self.layer == "penultimate":
raise NotImplementedError()
self.layer_idx = 1
self.antialias = antialias
self.register_buffer('mean',
torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
persistent=False)
self.register_buffer('std',
torch.Tensor([0.26862954, 0.26130258,
0.27577711]),
persistent=False)
def preprocess(self, x):
# normalize to [0,1]
x = kornia.geometry.resize(x, (224, 224),
interpolation='bicubic',
align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def freeze(self):
self.model = self.model.eval()
for param in self.model.parameters():
param.requires_grad = False
def forward(self, image, no_dropout=False):
## image: b c h w
z = self.encode_with_vision_transformer(image)
return z
def encode_with_vision_transformer(self, x):
x = self.preprocess(x)
# to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
if self.model.visual.input_patchnorm:
# einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
x = x.reshape(x.shape[0], x.shape[1],
self.model.visual.grid_size[0],
self.model.visual.patch_size[0],
self.model.visual.grid_size[1],
self.model.visual.patch_size[1])
x = x.permute(0, 2, 4, 1, 3, 5)
x = x.reshape(
x.shape[0], self.model.visual.grid_size[0] *
self.model.visual.grid_size[1], -1)
x = self.model.visual.patchnorm_pre_ln(x)
x = self.model.visual.conv1(x)
else:
x = self.model.visual.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1],
-1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
# class embeddings and positional embeddings
x = torch.cat([
self.model.visual.class_embedding.to(x.dtype) + torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
],
dim=1) # shape = [*, grid ** 2 + 1, width]
x = x + self.model.visual.positional_embedding.to(x.dtype)
# a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
x = self.model.visual.patch_dropout(x)
x = self.model.visual.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.model.visual.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
return x
class FrozenCLIPT5Encoder(AbstractEncoder):
def __init__(self,
clip_version="openai/clip-vit-large-patch14",
t5_version="google/t5-v1_1-xl",
device="cuda",
clip_max_length=77,
t5_max_length=77):
super().__init__()
self.clip_encoder = FrozenCLIPEmbedder(clip_version,
device,
max_length=clip_max_length)
self.t5_encoder = FrozenT5Embedder(t5_version,
device,
max_length=t5_max_length)
print(
f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
)
def encode(self, text):
return self(text)
def forward(self, text):
clip_z = self.clip_encoder.encode(text)
t5_z = self.t5_encoder.encode(text)
return [clip_z, t5_z]
class LinearProjector(nn.Module):
def __init__(self, input_dim: int, output_dim: int) -> None:
super().__init__()
self.projector = nn.Linear(input_dim, output_dim, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.projector(x)
class MLPProjector(nn.Module):
def __init__(self,
input_dim: int,
output_dim: int,
mlp_type: str = "gelu-mlp") -> None:
super().__init__()
if mlp_type == "gelu-mlp":
self.projector = nn.Sequential(
nn.Linear(input_dim, output_dim, bias=True),
nn.GELU(approximate='tanh'),
nn.Linear(output_dim, output_dim, bias=True),
)
elif mlp_type == "silu-mlp":
self.projector = nn.Sequential(
nn.Linear(input_dim, output_dim, bias=True),
nn.SiLU(),
nn.Linear(output_dim, output_dim, bias=True),
)
else:
raise ValueError(
f"Projector with `{mlp_type = }` is not supported!")
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.projector(x)
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latent (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(
-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
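# Editor's note: scaling q and k each by 1/sqrt(sqrt(dim_head)) is equivalent to
# the usual (q @ k^T) / sqrt(dim_head), but keeps intermediate magnitudes smaller,
# which is more stable in fp16 than dividing after the matmul.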
def FeedForward(dim, mult=4, ffd_type="gelu-ffd"):
inner_dim = int(dim * mult)
if ffd_type == "gelu-ffd":
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(approximate='tanh'),
nn.Linear(inner_dim, dim, bias=False),
)
elif ffd_type == "silu-ffd":
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.SiLU(),
nn.Linear(inner_dim, dim, bias=False),
)
else:
raise ValueError(f"Projector with `{mlp_type = }` is not supported!")
class SATokenProjector(nn.Module):
def __init__(self,
dim=1024,
depth=1,
dim_head=64,
heads=16,
num_queries=16,
output_dim=1024,
ff_mult=4,
chunk_size=None):
super().__init__()
self.num_queries = num_queries
self.chunk_size = chunk_size
if chunk_size is not None:
num_queries = num_queries * chunk_size
self.latents = nn.Parameter(
torch.randn(1, num_queries, dim) / dim**0.5)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList([
PerceiverAttention(dim=dim, dim_head=dim_head,
heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]))
def forward(self, x):
latents = self.latents.repeat(x.size(0), 1, 1)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
latents = self.norm_out(latents)
return latents


@@ -0,0 +1,153 @@
# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
# and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
import math
import torch
import torch.nn as nn
class ImageProjModel(nn.Module):
"""Projection Model"""
def __init__(self,
cross_attention_dim=1024,
clip_embeddings_dim=1024,
clip_extra_context_tokens=4):
super().__init__()
self.cross_attention_dim = cross_attention_dim
self.clip_extra_context_tokens = clip_extra_context_tokens
self.proj = nn.Linear(
clip_embeddings_dim,
self.clip_extra_context_tokens * cross_attention_dim)
self.norm = nn.LayerNorm(cross_attention_dim)
def forward(self, image_embeds):
#embeds = image_embeds
embeds = image_embeds.type(list(self.proj.parameters())[0].dtype)
clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim)
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
return clip_extra_context_tokens
# FFN
def FeedForward(dim, mult=4):
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False),
nn.GELU(),
nn.Linear(inner_dim, dim, bias=False),
)
def reshape_tensor(x, heads):
bs, length, width = x.shape
#(bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
    # Make the (bs, n_heads, length, dim_per_head) layout contiguous after the transpose
x = x.reshape(bs, heads, length, -1)
return x
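# Worked example (editor's note): with bs=2, length=10, width=512, heads=8 the
# shapes go (2, 10, 512) -> (2, 10, 8, 64) -> (2, 8, 10, 64); the final reshape
# keeps that (bs, heads, length, dim_per_head) layout contiguous.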
class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8):
super().__init__()
self.scale = dim_head**-0.5
self.dim_head = dim_head
self.heads = heads
inner_dim = dim_head * heads
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents):
"""
Args:
x (torch.Tensor): image features
shape (b, n1, D)
latent (torch.Tensor): latent features
shape (b, n2, D)
"""
x = self.norm1(x)
latents = self.norm2(latents)
b, l, _ = latents.shape
q = self.to_q(latents)
kv_input = torch.cat((x, latents), dim=-2)
k, v = self.to_kv(kv_input).chunk(2, dim=-1)
q = reshape_tensor(q, self.heads)
k = reshape_tensor(k, self.heads)
v = reshape_tensor(v, self.heads)
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(
-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
return self.to_out(out)
class Resampler(nn.Module):
def __init__(
self,
dim=1024,
depth=8,
dim_head=64,
heads=16,
num_queries=8,
embedding_dim=768,
output_dim=1024,
ff_mult=4,
video_length=None, # using frame-wise version or not
):
super().__init__()
## queries for a single frame / image
self.num_queries = num_queries
self.video_length = video_length
## <num_queries> queries for each frame
if video_length is not None:
num_queries = num_queries * video_length
self.latents = nn.Parameter(
torch.randn(1, num_queries, dim) / dim**0.5)
self.proj_in = nn.Linear(embedding_dim, dim)
self.proj_out = nn.Linear(dim, output_dim)
self.norm_out = nn.LayerNorm(output_dim)
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(
nn.ModuleList([
PerceiverAttention(dim=dim, dim_head=dim_head,
heads=heads),
FeedForward(dim=dim, mult=ff_mult),
]))
def forward(self, x):
latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x)
for attn, ff in self.layers:
latents = attn(x, latents) + latents
latents = ff(latents) + latents
latents = self.proj_out(latents)
latents = self.norm_out(latents)
return latents
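# Illustrative usage (editor's sketch, hypothetical dims): with video_length set,
# the Resampler keeps `num_queries` learned latents per frame.
#   rs = Resampler(dim=1024, depth=4, num_queries=16, embedding_dim=1280,
#                  output_dim=1024, video_length=16)   # 16 * 16 = 256 latents
#   feats = torch.randn(2, 257, 1280)                  # e.g. CLIP patch tokens
#   out = rs(feats)                                    # -> (2, 256, 1024)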

File diff suppressed because it is too large


@@ -0,0 +1,848 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from functools import partial
from abc import abstractmethod
from einops import rearrange
from omegaconf import OmegaConf
from typing import Optional, Sequence, Any, Tuple, Union, List, Dict
from collections.abc import Mapping, Iterable, Callable
from unifolm_wma.utils.diffusion import timestep_embedding
from unifolm_wma.utils.common import checkpoint
from unifolm_wma.utils.basics import (zero_module, conv_nd, linear,
avg_pool_nd, normalization)
from unifolm_wma.modules.attention import SpatialTransformer, TemporalTransformer
from unifolm_wma.utils.utils import instantiate_from_config
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(self, x, emb, context=None, batch_size=None):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb, batch_size=batch_size)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context)
elif isinstance(layer, TemporalTransformer):
x = rearrange(x, '(b f) c h w -> b c f h w', b=batch_size)
x = layer(x, context)
x = rearrange(x, 'b c f h w -> (b f) c h w')
else:
x = layer(x)
return x
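# Editor's note: this dispatch lets one Sequential mix heterogeneous layers --
# ResBlocks receive the timestep embedding, SpatialTransformers the cross-attention
# context, and TemporalTransformers are additionally wrapped in a (b f) <-> b
# reshape so they can attend across frames.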
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self,
channels,
use_conv,
dims=2,
out_channels=None,
padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(dims,
self.channels,
self.out_channels,
3,
stride=stride,
padding=padding)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self,
channels,
use_conv,
dims=2,
out_channels=None,
padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims,
self.channels,
self.out_channels,
3,
padding=padding)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
mode='nearest')
else:
x = F.interpolate(x, scale_factor=2, mode='nearest')
if self.use_conv:
x = self.conv(x)
return x
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param up: if True, use this block for upsampling.
:param down: if True, use this block for downsampling.
    :param use_temporal_conv: if True, apply a temporal convolution after the block.
"""
def __init__(self,
channels,
emb_channels,
dropout,
out_channels=None,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
use_conv=False,
up=False,
down=False,
use_temporal_conv=False,
tempspatial_aware=False):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.use_temporal_conv = use_temporal_conv
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
else:
self.h_upd = self.x_upd = nn.Identity()
self.emb_layers = nn.Sequential(
nn.SiLU(),
nn.Linear(
emb_channels,
2 * self.out_channels
if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(dims,
channels,
self.out_channels,
3,
padding=1)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels,
1)
if self.use_temporal_conv:
            self.temporal_conv = TemporalConvBlock(
self.out_channels,
self.out_channels,
dropout=0.1,
spatial_aware=tempspatial_aware)
def forward(self, x, emb, batch_size=None):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
input_tuple = (x, emb)
if batch_size:
forward_batchsize = partial(self._forward, batch_size=batch_size)
return checkpoint(forward_batchsize, input_tuple,
self.parameters(), self.use_checkpoint)
return checkpoint(self._forward, input_tuple, self.parameters(),
self.use_checkpoint)
def _forward(self, x, emb, batch_size=None):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = torch.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
h = self.skip_connection(x) + h
if self.use_temporal_conv and batch_size:
h = rearrange(h, '(b t) c h w -> b c t h w', b=batch_size)
            h = self.temporal_conv(h)
h = rearrange(h, 'b c t h w -> (b t) c h w')
return h
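# Editor's note: with use_scale_shift_norm the timestep embedding acts as a FiLM
# modulation, h = norm(h) * (1 + scale) + shift, where (scale, shift) are the two
# halves of emb_out; otherwise the embedding is simply added to the features.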
class TemporalConvBlock(nn.Module):
"""
Adapted from modelscope: https://github.com/modelscope/modelscope/blob/master/modelscope/models/multi_modal/video_synthesis/unet_sd.py
"""
def __init__(self,
in_channels,
out_channels=None,
dropout=0.0,
spatial_aware=False):
super(TemporalConvBlock, self).__init__()
if out_channels is None:
out_channels = in_channels
self.in_channels = in_channels
self.out_channels = out_channels
th_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 3, 1)
th_padding_shape = (1, 0, 0) if not spatial_aware else (1, 1, 0)
tw_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 1, 3)
tw_padding_shape = (1, 0, 0) if not spatial_aware else (1, 0, 1)
# conv layers
self.conv1 = nn.Sequential(
nn.GroupNorm(32, in_channels), nn.SiLU(),
nn.Conv3d(in_channels,
out_channels,
th_kernel_shape,
padding=th_padding_shape))
self.conv2 = nn.Sequential(
nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
nn.Conv3d(out_channels,
in_channels,
tw_kernel_shape,
padding=tw_padding_shape))
self.conv3 = nn.Sequential(
nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
nn.Conv3d(out_channels,
in_channels,
th_kernel_shape,
padding=th_padding_shape))
self.conv4 = nn.Sequential(
nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
nn.Conv3d(out_channels,
in_channels,
tw_kernel_shape,
padding=tw_padding_shape))
        # Zero out the last layer's params so the conv block starts as an identity mapping
nn.init.zeros_(self.conv4[-1].weight)
nn.init.zeros_(self.conv4[-1].bias)
def forward(self, x):
identity = x
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
return identity + x
class WMAModel(nn.Module):
"""
The full World-Model-Action model.
"""
def __init__(self,
in_channels: int,
model_channels: int,
out_channels: int,
num_res_blocks: int,
attention_resolutions: Sequence[int],
dropout: float = 0.0,
channel_mult: Sequence[int] = (1, 2, 4, 8),
conv_resample: bool = True,
dims: int = 2,
context_dim: int | None = None,
use_scale_shift_norm: bool = False,
resblock_updown: bool = False,
num_heads: int = -1,
num_head_channels: int = -1,
transformer_depth: int = 1,
use_linear: bool = False,
use_checkpoint: bool = False,
temporal_conv: bool = False,
tempspatial_aware: bool = False,
temporal_attention: bool = True,
use_relative_position: bool = True,
use_causal_attention: bool = False,
temporal_length: int | None = None,
use_fp16: bool = False,
addition_attention: bool = False,
temporal_selfatt_only: bool = True,
image_cross_attention: bool = False,
cross_attention_scale_learnable: bool = False,
default_fs: int = 4,
fs_condition: bool = False,
n_obs_steps: int = 1,
num_stem_token: int = 1,
unet_head_config: OmegaConf | None = None,
stem_process_config: OmegaConf | None = None,
base_model_gen_only: bool = False):
"""
Initialize the World-Model-Action network.
Args:
in_channels: Number of input channels to the backbone.
model_channels: Base channel width for the UNet/backbone.
out_channels: Number of output channels.
num_res_blocks: Number of residual blocks per resolution stage.
attention_resolutions: Resolutions at which to enable attention.
dropout: Dropout probability used inside residual/attention blocks.
channel_mult: Multipliers for channels at each resolution level.
conv_resample: If True, use convolutional resampling for up/down sampling.
dims: Spatial dimensionality of the backbone (1/2/3).
context_dim: Optional context embedding dimension (for cross-attention).
use_scale_shift_norm: Enable scale-shift (FiLM-style) normalization in blocks.
resblock_updown: Use residual blocks for up/down sampling (instead of plain conv).
num_heads: Number of attention heads (if >= 0). If -1, derive from num_head_channels.
num_head_channels: Channels per attention head (if >= 0). If -1, derive from num_heads.
transformer_depth: Number of transformer/attention blocks per stage.
use_linear: Use linear attention variants where applicable.
use_checkpoint: Enable gradient checkpointing in blocks to save memory.
temporal_conv: Include temporal convolution along the time dimension.
            tempspatial_aware: If True, use temporal-spatial-aware convolution blocks.
temporal_attention: Enable temporal self-attention.
use_relative_position: Use relative position encodings in attention.
use_causal_attention: Use causal (uni-directional) attention along time.
temporal_length: Optional maximum temporal length expected by the model.
use_fp16: Enable half-precision layers/normalization where supported.
addition_attention: Add auxiliary attention modules.
temporal_selfatt_only: Restrict attention to temporal-only (no spatial) if True.
image_cross_attention: Enable cross-attention with image embeddings.
cross_attention_scale_learnable: Make cross-attention scaling a learnable parameter.
default_fs: Default frame-stride / fps.
fs_condition: If True, condition on frame-stride/fps features.
n_obs_steps: Number of observed steps used in conditioning heads.
num_stem_token: Number of stem tokens for action tokenization.
unet_head_config: OmegaConf for UNet heads (e.g., action/state heads).
stem_process_config: OmegaConf for stem/preprocessor module.
            base_model_gen_only: If True, run generation with the base model only, without action and state outputs.
"""
super(WMAModel, self).__init__()
if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.temporal_attention = temporal_attention
time_embed_dim = model_channels * 4
self.use_checkpoint = use_checkpoint
self.dtype = torch.float16 if use_fp16 else torch.float32
temporal_self_att_only = True
self.addition_attention = addition_attention
self.temporal_length = temporal_length
self.image_cross_attention = image_cross_attention
self.cross_attention_scale_learnable = cross_attention_scale_learnable
self.default_fs = default_fs
self.fs_condition = fs_condition
self.n_obs_steps = n_obs_steps
self.num_stem_token = num_stem_token
self.base_model_gen_only = base_model_gen_only
# Time embedding blocks
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
if fs_condition:
self.fps_embedding = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
nn.init.zeros_(self.fps_embedding[-1].weight)
nn.init.zeros_(self.fps_embedding[-1].bias)
# Input Block
self.input_blocks = nn.ModuleList([
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1))
])
if self.addition_attention:
self.init_attn = TimestepEmbedSequential(
TemporalTransformer(model_channels,
n_heads=8,
d_head=num_head_channels,
depth=transformer_depth,
context_dim=context_dim,
use_checkpoint=use_checkpoint,
only_self_att=temporal_selfatt_only,
causal_attention=False,
relative_position=use_relative_position,
temporal_length=temporal_length))
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResBlock(ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
tempspatial_aware=tempspatial_aware,
use_temporal_conv=temporal_conv)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
layers.append(
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
use_linear=use_linear,
use_checkpoint=use_checkpoint,
disable_self_attn=False,
video_length=temporal_length,
agent_state_context_len=self.n_obs_steps,
agent_action_context_len=self.temporal_length *
num_stem_token,
image_cross_attention=self.image_cross_attention,
cross_attention_scale_learnable=self.
cross_attention_scale_learnable,
))
if self.temporal_attention:
layers.append(
TemporalTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
use_linear=use_linear,
use_checkpoint=use_checkpoint,
only_self_att=temporal_self_att_only,
causal_attention=use_causal_attention,
relative_position=use_relative_position,
temporal_length=temporal_length))
self.input_blocks.append(TimestepEmbedSequential(*layers))
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True)
if resblock_updown else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch))
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
layers = [
ResBlock(ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
tempspatial_aware=tempspatial_aware,
use_temporal_conv=temporal_conv),
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
use_linear=use_linear,
use_checkpoint=use_checkpoint,
disable_self_attn=False,
video_length=temporal_length,
agent_state_context_len=self.n_obs_steps,
agent_action_context_len=self.temporal_length * num_stem_token,
image_cross_attention=self.image_cross_attention,
cross_attention_scale_learnable=self.
cross_attention_scale_learnable)
]
if self.temporal_attention:
layers.append(
TemporalTransformer(ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
use_linear=use_linear,
use_checkpoint=use_checkpoint,
only_self_att=temporal_self_att_only,
causal_attention=use_causal_attention,
relative_position=use_relative_position,
temporal_length=temporal_length))
layers.append(
ResBlock(ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
tempspatial_aware=tempspatial_aware,
use_temporal_conv=temporal_conv))
# Middle Block
self.middle_block = TimestepEmbedSequential(*layers)
# Output Block
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(ch + ich,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
tempspatial_aware=tempspatial_aware,
use_temporal_conv=temporal_conv)
]
ch = model_channels * mult
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
layers.append(
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
use_linear=use_linear,
use_checkpoint=use_checkpoint,
disable_self_attn=False,
video_length=temporal_length,
agent_state_context_len=self.n_obs_steps,
image_cross_attention=self.image_cross_attention,
cross_attention_scale_learnable=self.
cross_attention_scale_learnable))
if self.temporal_attention:
layers.append(
TemporalTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
use_linear=use_linear,
use_checkpoint=use_checkpoint,
only_self_att=temporal_self_att_only,
causal_attention=use_causal_attention,
relative_position=use_relative_position,
temporal_length=temporal_length))
if level and i == num_res_blocks:
out_ch = ch
layers.append(
ResBlock(ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
up=True)
if resblock_updown else Upsample(
ch, conv_resample, dims=dims, out_channels=out_ch))
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(
conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)
# Action and state prediction unet
unet_head_config['params']['context_dims'] = [
mult * model_channels for mult in channel_mult
]
self.action_unet = instantiate_from_config(unet_head_config)
self.state_unet = instantiate_from_config(unet_head_config)
# Initialize action token_projector
self.action_token_projector = instantiate_from_config(
stem_process_config)
def forward(self,
x: Tensor,
x_action: Tensor,
x_state: Tensor,
timesteps: Tensor,
context: Tensor | None = None,
context_action: Tensor | None = None,
features_adapter: Any = None,
fs: Tensor | None = None,
**kwargs) -> Tensor | tuple[Tensor, ...]:
"""
Forward pass of the World-Model-Action backbone.
Args:
x: Input tensor (latent video), shape (B, C,...).
x_action: action stream input.
x_state: state stream input.
timesteps: Diffusion timesteps, shape (B,) or scalar Tensor.
context: conditioning context for cross-attention.
context_action: conditioning context specific to action/state (implementation-specific).
features_adapter: module or dict to adapt intermediate features.
fs: frame-stride / fps conditioning.
Returns:
            Tuple of Tensors: (video prediction, action prediction, state prediction).
"""
b, _, t, _, _ = x.shape
t_emb = timestep_embedding(timesteps,
self.model_channels,
repeat_only=False).type(x.dtype)
emb = self.time_embed(t_emb)
bt, l_context, _ = context.shape
if self.base_model_gen_only:
            assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR: unexpected context length"  # NOTE: hard-coded context layout
else:
if l_context == self.n_obs_steps + 77 + t * 16:
context_agent_state = context[:, :self.n_obs_steps]
context_text = context[:, self.n_obs_steps:self.n_obs_steps +
77, :]
context_img = context[:, self.n_obs_steps + 77:, :]
context_agent_state = context_agent_state.repeat_interleave(
repeats=t, dim=0)
context_text = context_text.repeat_interleave(repeats=t, dim=0)
context_img = rearrange(context_img,
'b (t l) c -> (b t) l c',
t=t)
context = torch.cat(
[context_agent_state, context_text, context_img], dim=1)
elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
context_agent_state = context[:, :self.n_obs_steps]
context_agent_action = context[:, self.
n_obs_steps:self.n_obs_steps +
16, :]
context_agent_action = rearrange(
context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
context_agent_action = self.action_token_projector(
context_agent_action)
context_agent_action = rearrange(context_agent_action,
'(b o) l d -> b o l d',
o=t)
context_agent_action = rearrange(context_agent_action,
'b o (t l) d -> b o t l d',
t=t)
context_agent_action = context_agent_action.permute(
0, 2, 1, 3, 4)
context_agent_action = rearrange(context_agent_action,
'b t o l d -> (b t) (o l) d')
context_text = context[:, self.n_obs_steps +
16:self.n_obs_steps + 16 + 77, :]
context_text = context_text.repeat_interleave(repeats=t, dim=0)
context_img = context[:, self.n_obs_steps + 16 + 77:, :]
context_img = rearrange(context_img,
'b (t l) c -> (b t) l c',
t=t)
context_agent_state = context_agent_state.repeat_interleave(
repeats=t, dim=0)
context = torch.cat([
context_agent_state, context_agent_action, context_text,
context_img
],
dim=1)
emb = emb.repeat_interleave(repeats=t, dim=0)
x = rearrange(x, 'b c t h w -> (b t) c h w')
        # Combine the timestep embedding with the optional fps embedding
if self.fs_condition:
if fs is None:
fs = torch.tensor([self.default_fs] * b,
dtype=torch.long,
device=x.device)
fs_emb = timestep_embedding(fs,
self.model_channels,
repeat_only=False).type(x.dtype)
fs_embed = self.fps_embedding(fs_emb)
fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
emb = emb + fs_embed
h = x.type(self.dtype)
adapter_idx = 0
hs = []
hs_a = []
for id, module in enumerate(self.input_blocks):
h = module(h, emb, context=context, batch_size=b)
if id == 0 and self.addition_attention:
h = self.init_attn(h, emb, context=context, batch_size=b)
# plug-in adapter features
if ((id + 1) % 3 == 0) and features_adapter is not None:
h = h + features_adapter[adapter_idx]
adapter_idx += 1
if id != 0:
if isinstance(module[0], Downsample):
hs_a.append(
rearrange(hs[-1], '(b t) c h w -> b t c h w', t=t))
hs.append(h)
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
if features_adapter is not None:
assert len(
features_adapter) == adapter_idx, 'Wrong features_adapter'
h = self.middle_block(h, emb, context=context, batch_size=b)
hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
hs_out = []
for module in self.output_blocks:
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb, context=context, batch_size=b)
if isinstance(module[-1], Upsample):
hs_a.append(
rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
hs_out.append(h)
h = h.type(x.dtype)
hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
y = self.out(h)
y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
if not self.base_model_gen_only:
ba, _, _ = x_action.shape
a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
context_action[:2], **kwargs)
# Predict state
if b > 1:
s_y = self.state_unet(x_state, timesteps[:ba], hs_a,
context_action[:2], **kwargs)
else:
s_y = self.state_unet(x_state, timesteps, hs_a,
context_action[:2], **kwargs)
else:
a_y = torch.zeros_like(x_action)
s_y = torch.zeros_like(x_state)
return y, a_y, s_y
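# Worked example (editor's note, hypothetical numbers): with n_obs_steps=2, t=16
# frames and 16 image tokens per frame, the second branch above expects
# l_context = 2 + 16 + 77 + 16 * 16 = 351 tokens, laid out as
# [agent state | agent action | text (77) | per-frame image tokens].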


@@ -0,0 +1,244 @@
"""
base_vision.py
Abstract class definition of a Vision Backbone (Visual Featurizer), with full annotations of class methods, utility
functions, and initialization logic.
We also define the generic TimmViTBackbone class here, providing a default interface for loading any TIMM Vision
Transformer model for feature extraction.
"""
import timm
import torch
import torch.nn as nn
import torchvision.transforms.functional as TVF
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
from PIL.Image import Image
from timm.models.vision_transformer import Block, VisionTransformer
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
from torchvision.transforms import Compose, Resize
# === Utility Functions for Monkey-Patching ===
def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
def wrapper(*args: Any, **kwargs: Any) -> Any:
result = fn(*args, **kwargs)
return result[0] if isinstance(result, tuple) else result
return wrapper
# === Interface for an Image Transform ===
class ImageTransform(Protocol):
def __call__(
self, img: Image,
**kwargs: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
...
# === Custom Torchvision Image Transforms ===
@dataclass
class LetterboxPad:
padding_fill_value: Tuple[int, int, int]
def __call__(self, image: Image) -> Image:
"""Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
(w, h), max_wh = image.size, max(image.size)
horizontal_pad, vertical_pad = int((max_wh - w) / 2), int(
(max_wh - h) / 2)
padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
return TVF.pad(image,
padding,
fill=self.padding_fill_value,
padding_mode="constant")
# === Abstract Base Class for arbitrary Vision Backbones ===
class VisionBackbone(nn.Module, ABC):
def __init__(self,
vision_backbone_id: str,
image_resize_strategy: str,
default_image_size: int = 224) -> None:
super().__init__()
self.identifier: str = vision_backbone_id
self.image_resize_strategy: str = image_resize_strategy
self.default_image_size: int = default_image_size
# Instance attributes for a Vision Backbone
self.featurizer: nn.Module = None
self.image_transform: ImageTransform = None
def get_image_transform(self) -> ImageTransform:
return self.image_transform
@abstractmethod
def get_fsdp_wrapping_policy(self) -> Callable:
...
@abstractmethod
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""Run a forward pass through the featurizer given a set of processed images, returning patch/grid features."""
raise NotImplementedError
@property
@abstractmethod
def default_image_resolution(self) -> Tuple[int, int, int]:
...
@property
@abstractmethod
def embed_dim(self) -> int:
...
@property
@abstractmethod
def num_patches(self) -> int:
...
@property
@abstractmethod
def half_precision_dtype(self) -> torch.dtype:
...
# === Abstract Base Class for Arbitrary TIMM Vision Transformer Backbones ===
class TimmViTBackbone(VisionBackbone, ABC):
def __init__(
self,
vision_backbone_id: str,
timm_path_or_url: str,
image_resize_strategy: str,
default_image_size: int = 224,
override_act_layer: Optional[str] = None,
) -> None:
super().__init__(vision_backbone_id,
image_resize_strategy,
default_image_size=default_image_size)
self.timm_path_or_url = timm_path_or_url
self.override_act_layer = override_act_layer
self.dtype = torch.bfloat16
# Initialize Featurizer (ViT) by downloading from HF / TIMM Hub if necessary
if self.override_act_layer is None:
self.featurizer: VisionTransformer = timm.create_model(
self.timm_path_or_url,
pretrained=True,
num_classes=0,
img_size=self.default_image_size)
else:
self.featurizer: VisionTransformer = timm.create_model(
self.timm_path_or_url,
pretrained=True,
num_classes=0,
img_size=self.default_image_size,
act_layer=self.override_act_layer,
)
self.featurizer.eval()
# Monkey-Patch the `forward()` function of the featurizer to ensure FSDP-compatibility
# => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
# => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
self.featurizer.forward = unpack_tuple(
partial(self.featurizer.get_intermediate_layers,
n={len(self.featurizer.blocks) - 2}))
# Validation =>> for now, this class *only* supports TIMM Vision Transformers (but can be extended!)
assert isinstance(self.featurizer, VisionTransformer), (
"Featurizer is not a TIMM VisionTransformer; if you would like to support a new visual representation, "
"file an issue or implement the requisite logic (see `prismatic/models/backbones/vision/base_vision.py`)!"
)
# Get Config =>> Note :: Override default image size to ensure correct image transform
self.data_cfg = timm.data.resolve_model_data_config(self.featurizer)
self.data_cfg["input_size"] = (3, self.default_image_size,
self.default_image_size)
# Initialize Default Image Transform --> Modified by `self.image_resize_strategy`
default_image_transform = timm.data.create_transform(**self.data_cfg,
is_training=False)
# Fix =>> SigLIP & IN1K default transforms resize to *larger* than `self.default_image_size` (crops image)!
if "siglip" in self.timm_path_or_url or "in1k" in self.timm_path_or_url:
assert isinstance(default_image_transform,
Compose), "Unexpected `default_image_transform`!"
assert isinstance(default_image_transform.transforms[0], Resize)
default_image_transform = Compose([
Resize(self.default_image_size,
interpolation=default_image_transform.transforms[0].
interpolation),
*default_image_transform.transforms[1:],
])
# Switch on `image_resize_strategy`
if self.image_resize_strategy == "resize-naive":
assert isinstance(default_image_transform,
Compose), "Unexpected `default_image_transform`!"
assert isinstance(default_image_transform.transforms[0], Resize)
target_size = (self.default_image_size, self.default_image_size)
self.image_transform = Compose([
Resize(target_size,
interpolation=default_image_transform.transforms[0].
interpolation),
*default_image_transform.transforms[1:],
])
elif self.image_resize_strategy == "resize-crop":
self.image_transform = default_image_transform
elif self.image_resize_strategy == "letterbox":
assert isinstance(default_image_transform,
Compose), "Unexpected `default_image_transform`!"
assert "mean" in self.data_cfg, "TIMM `data_cfg` missing image normalization mean!"
# Compute Padding Fill Value (rescaled normalization mean if applicable)
fill = tuple([int(x * 255) for x in self.data_cfg["mean"]])
# Build New Transform
self.image_transform = Compose(
[LetterboxPad(fill), *default_image_transform.transforms])
else:
raise ValueError(
f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!"
)
def get_fsdp_wrapping_policy(self) -> Callable:
"""Return a simple FSDP policy that wraps each ViT block and then the _entire_ featurizer."""
vit_wrap_policy = partial(_module_wrap_policy,
module_classes={VisionTransformer})
transformer_block_policy = partial(transformer_auto_wrap_policy,
transformer_layer_cls={Block})
return partial(_or_policy,
policies=[vit_wrap_policy, transformer_block_policy])
def forward(
self, pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]]
) -> torch.Tensor:
"""Runs transformed image/pixel tensor through vision backbone, returning _all_ patch features."""
return self.featurizer(pixel_values)
@property
def default_image_resolution(self) -> Tuple[int, int, int]:
return self.data_cfg["input_size"]
@property
def embed_dim(self) -> int:
return self.featurizer.embed_dim
@property
def num_patches(self) -> int:
return self.featurizer.patch_embed.num_patches
@property
def half_precision_dtype(self) -> torch.dtype:
return self.dtype


@@ -0,0 +1,273 @@
"""
dinosiglip_vit.py
Vision backbone that returns concatenated features from both DINOv2 and SigLIP.
"""
import timm
import torch
import torchvision.transforms as transforms
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, Tuple
from PIL import Image
from timm.models.vision_transformer import Block, VisionTransformer
from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
from torchvision.transforms import Compose, Resize, Normalize
from unifolm_wma.modules.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple
from unifolm_wma.utils.nn_utils import FusedMLPProjector, LinearProjector, MLPProjector
# Registry =>> Supported DinoSigLIP Pairs (as TIMM identifiers)
DINOSigLIP_VISION_BACKBONES = {
"dinosiglip-vit-so-224px": {
"dino": "vit_large_patch14_reg4_dinov2.lvd142m",
"siglip": "vit_so400m_patch14_siglip_224",
},
"dinosiglip-vit-so-384px": {
"dino": "vit_large_patch14_reg4_dinov2.lvd142m",
"siglip": "vit_so400m_patch14_siglip_384",
},
}
@dataclass
class DinoSigLIPImageTransform:
dino_image_transform: ImageTransform
siglip_image_transform: ImageTransform
is_prismatic: bool = True
def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]:
return {
"dino": self.dino_image_transform(img, **kwargs),
"siglip": self.siglip_image_transform(img, **kwargs)
}
class DinoSigLIPViTBackbone(VisionBackbone):
def __init__(self,
vision_backbone_id: str,
image_resize_strategy: str,
arch_specifier: str,
output_dim: int,
pretrained_checkpoint=None,
freeze=True,
default_image_size: int = 224) -> None:
super().__init__(vision_backbone_id,
image_resize_strategy,
default_image_size=default_image_size)
self.dino_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[
vision_backbone_id]["dino"]
self.siglip_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[
vision_backbone_id]["siglip"]
# Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary
self.dino_featurizer: VisionTransformer = timm.create_model(
self.dino_timm_path_or_url,
pretrained=True,
num_classes=0,
img_size=self.default_image_size)
if pretrained_checkpoint:
ckpt = pretrained_checkpoint + '/openvla_dino.pt'
self.dino_featurizer.load_state_dict(
torch.load(ckpt, weights_only=True))
print('>>> load dino weights')
if freeze:
self.dino_featurizer.eval()
for param in self.dino_featurizer.parameters():
param.requires_grad = False
self.siglip_featurizer: VisionTransformer = timm.create_model(
self.siglip_timm_path_or_url,
pretrained=True,
num_classes=0,
img_size=self.default_image_size)
if pretrained_checkpoint:
ckpt = pretrained_checkpoint + '/openvla_siglip.pt'
self.siglip_featurizer.load_state_dict(
torch.load(ckpt, weights_only=True))
print('>>> load siglip weights')
if freeze:
self.siglip_featurizer.eval()
for param in self.siglip_featurizer.parameters():
param.requires_grad = False
# Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility
# => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
# => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
self.dino_featurizer.forward = unpack_tuple(
partial(self.dino_featurizer.get_intermediate_layers,
n={len(self.dino_featurizer.blocks) - 2}))
self.siglip_featurizer.forward = unpack_tuple(
partial(self.siglip_featurizer.get_intermediate_layers,
n={len(self.siglip_featurizer.blocks) - 2}))
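        # `get_intermediate_layers` returns a tuple of per-layer feature maps;
        # the imported `unpack_tuple` helper presumably unwraps the singleton,
        # i.e. roughly:
        #
        #   def unpack_tuple(fn):
        #       def wrapper(*args, **kwargs):
        #           result = fn(*args, **kwargs)
        #           return result[0] if isinstance(result, (tuple, list)) else result
        #       return wrapper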
# Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models
self.dino_data_cfg = timm.data.resolve_model_data_config(
self.dino_featurizer)
self.dino_data_cfg["input_size"] = (3, self.default_image_size,
self.default_image_size)
self.siglip_data_cfg = timm.data.resolve_model_data_config(
self.siglip_featurizer)
self.siglip_data_cfg["input_size"] = (3, self.default_image_size,
self.default_image_size)
# Initialize *both* Transforms
self.default_dino_transform = timm.data.create_transform(
**self.dino_data_cfg, is_training=False)
self.default_siglip_transform = timm.data.create_transform(
**self.siglip_data_cfg, is_training=False)
# Fix =>> SigLIP default transform resizes to *larger* than `self.default_image_size` (crops image)!!
assert isinstance(self.default_siglip_transform,
Compose), "Unexpected `default_image_transform`!"
assert isinstance(self.default_siglip_transform.transforms[0], Resize)
self.default_siglip_transform = Compose([
Resize(self.default_image_size,
interpolation=self.default_siglip_transform.transforms[0].
interpolation),
*self.default_siglip_transform.transforms[1:],
])
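        # (TIMM eval transforms with `crop_pct` < 1.0 resize *above* the target
        # resolution and then center-crop; swapping in an exact
        # `Resize(self.default_image_size)` keeps the full image content.)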
if self.image_resize_strategy == "resize-naive":
assert isinstance(
self.default_dino_transform,
Compose), "Unexpected `default_dino_image_transform`!"
assert isinstance(
self.default_siglip_transform,
Compose), "Unexpected `default_siglip_image_transform`!"
assert isinstance(self.default_dino_transform.transforms[0],
Resize)
assert isinstance(self.default_siglip_transform.transforms[0],
Resize)
self.target_size = (self.default_image_size,
self.default_image_size)
dino_transform = Compose([
Resize(self.target_size,
interpolation=self.default_dino_transform.transforms[0].
interpolation),
*self.default_dino_transform.transforms[1:],
])
siglip_transform = Compose([
Resize(self.target_size,
interpolation=self.default_siglip_transform.
transforms[0].interpolation),
*self.default_siglip_transform.transforms[1:],
])
self.image_transform = DinoSigLIPImageTransform(
dino_transform, siglip_transform)
elif self.image_resize_strategy == "resize-crop":
self.image_transform = DinoSigLIPImageTransform(
self.default_dino_transform, self.default_siglip_transform)
elif self.image_resize_strategy == "letterbox":
assert isinstance(self.default_dino_transform,
Compose), "Unexpected `default_dino_transform`!"
assert isinstance(
self.default_siglip_transform,
Compose), "Unexpected `default_siglip_transform`!"
assert ("mean" in self.dino_data_cfg
and "mean" in self.siglip_data_cfg
), "DinoSigLIP `data_cfg` missing `mean`!"
# Compute Padding Fill Value(s) (rescaled normalization mean if applicable)
dino_fill = tuple(
[int(x * 255) for x in self.dino_data_cfg["mean"]])
siglip_fill = tuple(
[int(x * 255) for x in self.siglip_data_cfg["mean"]])
# Build New Transform
self.image_transform = DinoSigLIPImageTransform(
Compose([
LetterboxPad(dino_fill),
*self.default_dino_transform.transforms
]),
Compose([
LetterboxPad(siglip_fill),
*self.default_siglip_transform.transforms
]),
)
else:
raise ValueError(
f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!"
)

        # Initialize the patch-feature projector specified by `arch_specifier`.
        self.arch_specifier = arch_specifier
        if arch_specifier == "linear":
            self.projector = LinearProjector(self.embed_dim, output_dim)
        elif arch_specifier.endswith("fused-gelu-mlp"):
            self.projector = FusedMLPProjector(self.embed_dim, output_dim)
        elif arch_specifier.endswith("gelu-mlp"):
            self.projector = MLPProjector(self.embed_dim, output_dim)
        else:
            raise ValueError(
                f"DinoSigLIPViTBackbone with `{arch_specifier = }` is not supported!"
            )
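        # Note: the projectors above are imported from
        # `unifolm_wma.utils.nn_utils`; a "gelu-mlp" projector is presumably a
        # small MLP along the lines of
        #   nn.Sequential(nn.Linear(embed_dim, output_dim), nn.GELU(),
        #                 nn.Linear(output_dim, output_dim))
        # while "linear" is a single nn.Linear(embed_dim, output_dim).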

    def get_fsdp_wrapping_policy(self) -> Callable:
"""Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers."""
vit_wrap_policy = partial(_module_wrap_policy,
module_classes={VisionTransformer})
transformer_block_policy = partial(transformer_auto_wrap_policy,
transformer_layer_cls={Block})
return partial(_or_policy,
policies=[vit_wrap_policy, transformer_block_policy])
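        # Usage sketch (assuming a standard torch.distributed/FSDP setup):
        #
        #   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        #   model = FSDP(backbone, auto_wrap_policy=backbone.get_fsdp_wrapping_policy())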

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches."""
        # Map inputs from [-1, 1] into [0, 1]; the TIMM normalization stats
        # below are defined on [0, 1] inputs, so the original extra rescale to
        # [0, 255] would have thrown the pretrained featurizers off scale.
        img = torch.clamp(img.float(), -1.0, 1.0)
        img = (img + 1.0) / 2.0
        # `self.target_size` is only set under the "resize-naive" strategy;
        # fall back to the default square resolution for the other strategies.
        target_size = getattr(self, "target_size",
                              (self.default_image_size,
                               self.default_image_size))
        resize = transforms.Resize(min(target_size),
                                   interpolation=self.default_dino_transform.
                                   transforms[0].interpolation,
                                   max_size=None,
                                   antialias=True)
        center_crop = transforms.CenterCrop(target_size)
        img = center_crop(resize(img))
        # Per-backbone normalization (ImageNet stats for DINOv2; 0.5/0.5 for SigLIP).
        dino_normalizer = Normalize(mean=torch.tensor([0.4850, 0.4560,
                                                       0.4060]),
                                    std=torch.tensor([0.2290, 0.2240, 0.2250]))
        siglip_normalizer = Normalize(
            mean=torch.tensor([0.5000, 0.5000, 0.5000]),
            std=torch.tensor([0.5000, 0.5000, 0.5000]))
        pixel_values = {
            "dino": dino_normalizer(img),
            "siglip": siglip_normalizer(img)
        }
        # Move pixel values onto the featurizers' device (the original lazy
        # `on_gpu` flag skipped this move on the first GPU forward pass).
        device = next(self.dino_featurizer.parameters()).device
        pixel_values = {k: v.to(device) for k, v in pixel_values.items()}
        dino_patches = self.dino_featurizer(pixel_values["dino"])
        siglip_patches = self.siglip_featurizer(pixel_values["siglip"])
        return self.projector(torch.cat([dino_patches, siglip_patches], dim=2))
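        # Shape sketch for the 224px variant: DINOv2 ViT-L yields (B, 256, 1024)
        # patches and SigLIP so400m yields (B, 256, 1152), so the concatenation
        # is (B, 256, 2176) before projection to (B, 256, output_dim).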

    @property
    def default_image_resolution(self) -> Tuple[int, int, int]:
        return self.dino_data_cfg["input_size"]

    @property
    def embed_dim(self) -> int:
        return self.dino_featurizer.embed_dim + self.siglip_featurizer.embed_dim

    @property
    def num_patches(self) -> int:
        assert self.dino_featurizer.patch_embed.num_patches == self.siglip_featurizer.patch_embed.num_patches
        return self.dino_featurizer.patch_embed.num_patches

    @property
    def half_precision_dtype(self) -> torch.dtype:
        return torch.bfloat16
