diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..43988b6
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,133 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.venv
+venv/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+*.pdf
+.pdf
+plot_test/
+plot/
+performance/
+localTest/
+fig/
+figure/
+*.mp4
+*.json
+Data/ControlVAE.yml
+Data/Misc
+Data/Pretrained
+Data/utils.py
+Experiment/checkpoint
+Experiment/log
+*.ckpt
+*.STL
+*.gif
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..040d7d9
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "external/dlimp"]
+ path = external/dlimp
+ url = https://github.com/kvablack/dlimp
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..5522eea
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,439 @@
+Attribution-NonCommercial-ShareAlike 4.0 International
+
+Copyright (c) 2016-2025 HangZhou YuShu TECHNOLOGY CO.,LTD. ("Unitree Robotics")
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
+Public License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial-ShareAlike 4.0 International Public License
+("Public License"). To the extent this Public License may be
+interpreted as a contract, You are granted the Licensed Rights in
+consideration of Your acceptance of these terms and conditions, and the
+Licensor grants You such rights in consideration of benefits the
+Licensor receives from making the Licensed Material available under
+these terms and conditions.
+
+
+Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+
+ c. BY-NC-SA Compatible License means a license listed at
+ creativecommons.org/compatiblelicenses, approved by Creative
+ Commons as essentially the equivalent of this Public License.
+
+ d. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+
+ e. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ f. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ g. License Elements means the license attributes listed in the name
+ of a Creative Commons Public License. The License Elements of this
+ Public License are Attribution, NonCommercial, and ShareAlike.
+
+ h. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ i. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ j. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ k. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ l. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ m. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ n. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+
+Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce, reproduce, and Share Adapted Material for
+ NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. Additional offer from the Licensor -- Adapted Material.
+ Every recipient of Adapted Material from You
+ automatically receives an offer from the Licensor to
+ exercise the Licensed Rights in the Adapted Material
+ under the conditions of the Adapter's License You apply.
+
+ c. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+ b. ShareAlike.
+
+ In addition to the conditions in Section 3(a), if You Share
+ Adapted Material You produce, the following conditions also apply.
+
+ 1. The Adapter's License You apply must be a Creative Commons
+ license with the same License Elements, this version or
+ later, or a BY-NC-SA Compatible License.
+
+ 2. You must include the text of, or the URI or hyperlink to, the
+ Adapter's License You apply. You may satisfy this condition
+ in any reasonable manner based on the medium, means, and
+ context in which You Share Adapted Material.
+
+ 3. You may not offer or impose any additional or different terms
+ or conditions on, or apply any Effective Technological
+ Measures to, Adapted Material that restrict exercise of the
+ rights granted under the Adapter's License You apply.
+
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material,
+ including for purposes of Section 3(b); and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+
+Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+
+Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+
+Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the “Licensor.” The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..15aa4be
--- /dev/null
+++ b/README.md
@@ -0,0 +1,228 @@
+# UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family
+
+UnifoLM-WMA-0 is Unitree's open-source world-model-action architecture spanning multiple types of robotic embodiments, designed specifically for general-purpose robot learning. Its core component is a world model capable of understanding the physical interactions between robots and their environments. The world model provides two key functions: (a) Simulation Engine: it operates as an interactive simulator to generate synthetic data for robot learning; (b) Policy Enhancement: it connects with an action head and, by predicting future interactions with the world, further optimizes decision-making performance.
+
+
+## 🦾 Real-Robot Demonstrations
+| | |
+|:---:|:---:|
+| | |
+
+**Note: the top-right window shows the world model's prediction of future action videos.**
+
+## 🔥 News
+
+* Sep 22, 2025: 🚀 We released the deployment code for running experiments with [Unitree](https://www.unitree.com/) robots.
+* Sep 15, 2025: 🚀 We released the training and inference code along with the model weights of [**UnifoLM-WMA-0**](https://huggingface.co/collections/unitreerobotics/unifolm-wma-0-68ca23027310c0ca0f34959c).
+
+## 📑 Open-source Plan
+- [x] Training
+- [x] Inference
+- [x] Checkpoints
+- [x] Deployment
+
+## ⚙️ Installation
+```
+conda create -n unifolm-wma python==3.10.18
+conda activate unifolm-wma
+
+conda install pinocchio=3.2.0 -c conda-forge -y
+conda install ffmpeg=7.1.1 -c conda-forge
+
+git clone --recurse-submodules https://github.com/unitreerobotics/unifolm-world-model-action.git
+
+# If you already downloaded the repo:
+cd unifolm-world-model-action
+git submodule update --init --recursive
+
+pip install -e .
+
+cd external/dlimp
+pip install -e .
+```
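+
+After the two editable installs above, a quick import check can confirm the environment is wired up. A minimal sketch, assuming the package names ```unitree_worldmodel``` (listed under ```src/``` in the codebase overview at the end of this README) and ```dlimp``` (installed from ```external/dlimp```):
+```python
+# Minimal post-install sanity check; package names are taken from the codebase
+# overview (src/unitree_worldmodel) and the external/dlimp submodule.
+import importlib
+
+for pkg in ("unitree_worldmodel", "dlimp"):
+    module = importlib.import_module(pkg)  # raises ImportError if the install failed
+    print(f"{pkg} imported from {module.__file__}")
+```
+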
+## 🧰 Model Checkpoints
+| Model | Description | Link|
+|---------|-------|------|
+|$\text{UnifoLM-WMA-0}_{Base}$| Fine-tuned on [Open-X](https://robotics-transformer-x.github.io/) dataset. | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Base)|
+|$\text{UnifoLM-WMA-0}_{Dual}$| Fine-tuned on five [Unitree open-source datasets](https://huggingface.co/collections/unitreerobotics/g1-dex1-datasets-68bae98bf0a26d617f9983ab) in both decision-making and simulation modes. | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Dual)|
+
+## 🛢️ Dataset
+In our experiments, we consider the following five open-source datasets (a download sketch follows the table):
+| Dataset | Robot | Link |
+|---------|-------|------|
+|Z1_StackBox| [Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_StackBox_Dataset/tree/v2.1)|
+|Z1_DualArm_StackBox|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset/tree/v2.1)|
+|Z1_DualArm_StackBox_V2|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset_V2/tree/v2.1)|
+|Z1_DualArm_Cleanup_Pencils|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_CleanupPencils_Dataset/tree/v2.1)|
+|G1_Pack_Camera|[Unitree G1](https://www.unitree.com/g1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/G1_Dex1_MountCameraRedGripper_Dataset/tree/v2.1)|
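+
+A dataset from the table can be pulled from the Hugging Face Hub before conversion. A minimal sketch, assuming the ```huggingface_hub``` package is available; the repository id and the ```v2.1``` revision come from the links above:
+```python
+# Download one of the datasets listed above from the Hugging Face Hub.
+# The repo_id and revision follow the v2.1 links in the table.
+from huggingface_hub import snapshot_download
+
+local_dir = snapshot_download(
+    repo_id="unitreerobotics/Z1_StackBox_Dataset",
+    repo_type="dataset",
+    revision="v2.1",
+)
+print(f"Dataset downloaded to {local_dir}")
+```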
+
+To train on your own dataset, first prepare the data in the [Huggingface LeRobot V2.1](https://github.com/huggingface/lerobot) dataset format. Assume the dataset's source directory structure is as follows:
+```
+source_dir/
+ ├── dataset1_name
+ ├── dataset2_name
+ ├── dataset3_name
+ └── ...
+```
+Then, convert a dataset to the required format using the command below:
+```
+cd prepare_data
+python prepare_training_data.py \
+ --source_dir /path/to/your/source_dir \
+ --target_dir /path/to/save/the/converted/data \
+ --dataset_name "dataset1_name" \
+    --robot_name "a tag of the robot in the dataset"  # e.g., Unitree Z1 Robot Arm or Unitree G1 Robot with Gripper.
+```
+The resulting data structure is shown below. (Note: model training only supports input from the main-view camera; if the dataset includes multiple views, remove the corresponding values from the ```data_dir``` column in the CSV file. A clean-up sketch follows the layout below.)
+```
+target_dir/
+ ├── videos
+ │ ├──dataset1_name
+ │ │ ├──camera_view_dir
+ │ │ ├── 0.mp4
+ │ │ ├── 1.mp4
+ │ │ └── ...
+ │ └── ...
+ ├── transitions
+ │ ├── dataset1_name
+ │ ├── meta_data
+ │ ├── 0.h5
+ │ ├── 1.h5
+ │ └── ...
+ └── dataset1_name.csv
+```
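+
+As noted above, when a dataset was recorded with more than one camera view, only the main-view entries should remain in the ```data_dir``` column of ```dataset1_name.csv```. A minimal clean-up sketch, assuming each CSV row stores one video path and that the main-view directory is named ```camera_view_dir``` as in the layout above; adjust both to your dataset:
+```python
+# Keep only main-view rows in the converted dataset's CSV (the `data_dir` column
+# holds the video locations). The column name comes from the note above; the
+# main-view directory name is dataset-specific.
+import pandas as pd
+
+MAIN_VIEW = "camera_view_dir"  # replace with your main-view camera directory name
+csv_path = "target_dir/dataset1_name.csv"
+
+df = pd.read_csv(csv_path)
+df = df[df["data_dir"].astype(str).str.contains(MAIN_VIEW)]
+df.to_csv(csv_path, index=False)
+print(f"Kept {len(df)} main-view entries in {csv_path}")
+```
+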
+## 🚴♂️ Training
+A. Our training strategy is outlined as follows:
+- **Step 1**: Fine-tune a video generation model as the world model using the [Open-X](https://robotics-transformer-x.github.io/) dataset;
+- **Step 2**: Post-train $\text{UnifoLM-WMA}$ in decision-making mode on the downstream task dataset;
+
+
+
+- **Step 3**: Post-train $\text{UnifoLM-WMA}$ in simulation mode on the downstream task dataset.
+
+
+
+**Note**: If you only require $\text{UnifoLM-WMA}$ to operate in a single mode, you may skip the corresponding step.
+
+B. To train on a single dataset or multiple datasets, please follow the steps below:
+- **Step 1**: The maximum DoF is assumed to be 16; if you have more than 16 DoFs, update ```agent_state_dim``` and ```agent_action_dim``` in [configs/train/config.yaml](https://github.com/unitreerobotics/unifolm-wma/blob/working/configs/train/config.yaml);
+- **Step 2**: Set up the input shapes for each modality in [configs/train/meta.json](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/meta.json);
+- **Step 3**: Configure the training parameters in [configs/train/config.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/config.yaml). For ```pretrained_checkpoint```, we recommend the $\text{UnifoLM-WMA-0}_{Base}$ checkpoint fine-tuned on the [Open-X](https://robotics-transformer-x.github.io/) dataset;
+ ```yaml
+ model:
+      pretrained_checkpoint: /path/to/pretrained/checkpoint
+ ...
+ decision_making_only: True # Train the world model only in decision-making mode. If False, jointly train it in both decision-making and simulation modes.
+ ...
+ data:
+ ...
+ train:
+ ...
+ data_dir: /path/to/training/dataset/directory
+          dataset_and_weights: # list the name of each dataset below; the weights must sum to 1.0 (see the sanity-check sketch after Step 5)
+ dataset1_name: 0.2
+ dataset2_name: 0.2
+ dataset3_name: 0.2
+ dataset4_name: 0.2
+ dataset5_name: 0.2
+ ```
+- **Step 4**: Set the ```experiment_name``` and ```save_root``` variables in [scripts/train.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/train.sh);
+- **Step 5**: Launch the training with the command:
+```
+bash scripts/train.sh
+```
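+
+As a quick sanity check before launching (see the comment on ```dataset_and_weights``` in Step 3), the sketch below verifies that the dataset weights in the training config sum to 1.0. The exact nesting of the full config file is not shown in this README, so the sketch simply searches for the key:
+```python
+# Verify that the dataset sampling weights in configs/train/config.yaml sum to 1.0.
+import yaml
+
+def find_key(node, key):
+    """Return the first value stored under `key` anywhere in a nested dict/list."""
+    if isinstance(node, dict):
+        if key in node:
+            return node[key]
+        children = node.values()
+    elif isinstance(node, list):
+        children = node
+    else:
+        return None
+    for child in children:
+        found = find_key(child, key)
+        if found is not None:
+            return found
+    return None
+
+with open("configs/train/config.yaml") as f:
+    cfg = yaml.safe_load(f)
+
+weights = find_key(cfg, "dataset_and_weights") or {}
+total = sum(weights.values())
+print(f"{len(weights)} datasets, total weight = {total:.4f}")
+assert abs(total - 1.0) < 1e-6, "dataset_and_weights must sum to 1.0"
+```
+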
+## 🌏 Inference under Interactive Simulation Mode
+To run the world model in an interactive simulation mode, follow these steps:
+- **Step 1**: (Skip this step if you only want to test with the examples we provide.) Prepare your own prompts following the format used in [examples/world_model_interaction_prompts](https://github.com/unitreerobotics/unitree-world-model/tree/main/examples/world_model_interaction_prompts):
+ ```
+ world_model_interaction_prompts/
+ ├── images
+ │ ├── dataset1_name
+ │ │ ├── 0.png # Image prompt
+ │ │ └── ...
+ │ └── ...
+ ├── transitions
+ │ ├── dataset1_name
+ │ │ ├── meta_data # Used for normalization
+  │   │   ├── 0.h5             # Robot state and action data; in interaction mode,
+  │   │   │                    # only used to retrieve the robot state corresponding
+  │   │   │                    # to the image prompt
+ │ │ └── ...
+ │ └── ...
+  ├── dataset1_name.csv        # File for loading image prompts, text instructions, and corresponding robot states
+ └── ...
+ ```
+- **Step 2**: Specify the correct paths for ```pretrained_checkpoint``` (e.g., $\text{UnifoLM-WMA-0}_{Dual}$) and ```data_dir``` in [configs/inference/world_model_interaction.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/inference/world_model_interaction.yaml);
+- **Step 3**: Set the paths for ```checkpoint```, ```res_dir``` and ```prompt_dir``` in [scripts/run_world_model_interaction.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/run_world_model_interaction.sh), and list all dataset names in ```datasets=(...)```. Then, launch the inference with the command:
+ ```
+ bash scripts/run_world_model_interaction.sh
+ ```
+
+## 🧠 Inference and Deployment under Decision-Making Mode
+
+In this setup, inference is performed on a server, while a robot client gathers observations from the real robot and sends them to the server to query actions (a hypothetical sketch of this request loop is given at the end of this section). The process unfolds through the following steps:
+
+### Server Setup
+- **Step-1**: Specify ```ckpt```, ```res_dir```, ```datasets``` in [scripts/run_real_eval_server.sh](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/scripts/run_real_eval_server.sh);
+- **Step-2**: Configure ```data_dir``` and ```dataset_and_weights``` in [configs/inference/world_model_decision_making.yaml](https://github.com/unitreerobotics/unifolm-world-model-action/blob/f12b4782652ca00452941d851b17446e4ee7124a/configs/inference/world_model_decision_making.yaml#L225);
+- **Step-3**: Launch the server:
+```
+conda activate unifolm-wma
+cd unifolm-world-model-action
+bash scripts/run_real_eval_server.sh
+```
+
+### Client Setup
+- **Step-1**: Follow the instructions in [unitree_deploy/README.md](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/unitree_deploy/README.md) to create the ```unitree_deploy``` conda environment, install the required packages, and launch the controllers or services on the real robot.
+- **Step-2**: Open a new terminal and establish a tunnel connection from the client to the server:
+```
+ssh user_name@remote_server_IP -CNg -L 8000:127.0.0.1:8000
+```
+- **Step-3**: Run the ```unitree_deploy/robot_client.py``` script to start inference:
+```
+cd unitree_deploy
+python scripts/robot_client.py --robot_type "g1_dex1" --action_horizon 16 --exe_steps 16 --observation_horizon 2 --language_instruction "pack black camera into box" --output_dir ./results --control_freq 15
+```
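+
+For orientation, the exchange between ```robot_client.py``` and the server boils down to a loop of sending the latest observation and executing the returned action chunk. The sketch below is only a hypothetical illustration of that loop; the endpoint name, payload keys, and response format are assumptions, and the actual protocol is defined by the server and client scripts above:
+```python
+# Hypothetical illustration of the observation -> action query loop run by the client.
+# The /predict endpoint and the payload/response keys are assumptions for illustration;
+# the real API is defined by scripts/run_real_eval_server.sh and unitree_deploy/robot_client.py.
+import time
+import requests
+
+SERVER_URL = "http://127.0.0.1:8000/predict"  # reached through the SSH tunnel above
+CONTROL_FREQ = 15                             # matches --control_freq in the client command
+EXE_STEPS = 16                                # matches --exe_steps in the client command
+
+def get_observation():
+    """Placeholder: the real client reads camera images and joint states from the robot."""
+    return {"image": None, "state": None, "instruction": "pack black camera into box"}
+
+while True:
+    obs = get_observation()
+    actions = requests.post(SERVER_URL, json=obs, timeout=30).json()["actions"]
+    for action in actions[:EXE_STEPS]:
+        # Placeholder: forward `action` to the robot controller here.
+        time.sleep(1.0 / CONTROL_FREQ)
+```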
+
+## 📝 Codebase Architecture
+Here's a high-level overview of the project's code structure and core components:
+```
+unitree-world-model/
+ ├── assets # Media assets such as GIFs, images, and demo videos
+ ├── configs # Configuration files for training and inference
+ │ ├── inference
+ │ └── train
+ ├── examples # Example inputs and prompts for running inference
+ ├── external # External packages
+ ├── prepare_data # Scripts for dataset preprocessing and format conversion
+ ├── scripts # Main scripts for training, evaluation, and deployment
+ ├── src
+ │ ├──unitree_worldmodel # Core Python package for the Unitree world model
+ │ │ ├── data # Dataset loading, transformations, and dataloaders
+ │ │ ├── models # Model architectures and backbone definitions
+ │ │ ├── modules # Custom model modules and components
+ │ │ └── utils # Utility functions and common helpers
+ └── unitree_deploy # Deployment code
+```
+
+## 🙏 Acknowledgement
+Much of the code is inherited from [DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter), [Diffusion Policy](https://github.com/real-stanford/diffusion_policy), [ACT](https://github.com/MarkFzp/act-plus-plus) and [HPT](https://github.com/liruiw/HPT).
+
+## 📝 Citation
+```
+@misc{unifolm-wma-0,
+ author = {Unitree},
+ title = {UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family},
+ year = {2025},
+}
+```
diff --git a/README_cn.md b/README_cn.md
new file mode 100644
index 0000000..143f998
--- /dev/null
+++ b/README_cn.md
@@ -0,0 +1,216 @@
+# UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family
+
+**Note**: If you only need $\text{UnifoLM-WMA}$ to run in a single mode, you may skip the corresponding step.
+
+B. To train on a single dataset or multiple datasets, follow the steps below:
+- **Step 1**: The maximum DoF defaults to 16; if you need more DoFs, update the values of ```agent_state_dim``` and ```agent_action_dim``` in [configs/train/config.yaml](https://github.com/unitreerobotics/unifolm-wma/blob/working/configs/train/config.yaml);
+- **Step 2**: Set the input dimensions for each modality in [configs/train/meta.json](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/meta.json);
+- **Step 3**: Configure the training parameters and paths in [configs/train/config.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/config.yaml). For the pretrained model, we recommend $\text{UnifoLM-WMA-0}_{Base}$, which was fine-tuned on the [Open-X](https://robotics-transformer-x.github.io/) dataset;
+ ```yaml
+ model:
+ pretrained_checkpoint: /path/to/pretrained/checkpoint
+ ...
+      decision_making_only: True # Train the world model only in decision-making mode. If False, jointly train it in both decision-making and simulation modes.
+ ...
+ data:
+ ...
+ train:
+ ...
+ data_dir: /path/to/training/dataset/directory
+          dataset_and_weights: # List the name and weight of each dataset below; the weights must sum to 1.0
+ dataset1_name: 0.2
+ dataset2_name: 0.2
+ dataset3_name: 0.2
+ dataset4_name: 0.2
+ dataset5_name: 0.2
+ ```
+- **Step 4**: Set the ```experiment_name``` and ```save_root``` variables in [scripts/train.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/train.sh);
+- **Step 5**: Launch training with the following command:
+```
+bash scripts/train.sh
+```
+## 🌏 Inference under Interactive Simulation Mode
+To enable the world model's interactive mode, follow these steps:
+- **Step 1**: (Skip this step if you only want to test with the provided examples.) Prepare your own prompt directory following the format of the [examples/world_model_interaction_prompts](https://github.com/unitreerobotics/unitree-world-model/tree/main/examples/world_model_interaction_prompts) directory:
+```
+world_model_interaction_prompts/
+  ├── images
+  │   ├── dataset1_name
+  │   │   ├── 0.png            # Image prompt
+  │   │   └── ...
+  │   └── ...
+  ├── transitions
+  │   ├── dataset1_name
+  │   │   ├── meta_data        # Used for normalization
+  │   │   ├── 0.h5             # Robot state and action data; in interaction mode, only used to
+  │   │   │                    # retrieve the robot state corresponding to the image prompt
+  │   │   └── ...
+  │   └── ...
+  ├── dataset1_name.csv        # File for loading the corresponding image prompts, text instructions, and robot states
+  └── ...
+```
+- **Step 2**: Specify the correct paths for ```pretrained_checkpoint``` (e.g., $\text{UnifoLM-WMA-0}_{Dual}$) and ```data_dir``` in [configs/inference/world_model_interaction.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/inference/world_model_interaction.yaml);
+- **Step 3**: Set the correct paths for ```checkpoint```, ```res_dir``` and ```prompt_dir``` in [scripts/run_world_model_interaction.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/run_world_model_interaction.sh), list the datasets to test in ```datasets=(...)```, and then launch inference with the following command:
+ ```
+ bash scripts/run_world_model_interaction.sh
+ ```
+
+## 🧠 Inference and Deployment under Decision-Making Mode
+In our system, inference runs on the server side; the robot client collects observations from the real robot and sends them to the server for video and action inference. The whole process can be carried out through the following steps:
+
+### Server Setup
+- **Step 1**: Specify ```ckpt```, ```res_dir``` and ```datasets``` in [scripts/run_real_eval_server.sh](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/scripts/run_real_eval_server.sh);
+- **Step 2**: Configure ```data_dir``` and ```dataset_and_weights``` in [configs/inference/world_model_decision_making.yaml](https://github.com/unitreerobotics/unifolm-world-model-action/blob/f12b4782652ca00452941d851b17446e4ee7124a/configs/inference/world_model_decision_making.yaml#L225);
+- **Step 3**: Launch the server:
+```
+conda activate unifolm-wma
+cd unifolm-world-model-action
+bash scripts/run_real_eval_server.sh
+```
+
+### Client Setup
+- **Step 1**: Follow [unitree_deploy/README.md](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/unitree_deploy/README.md) to create the ```unitree_deploy``` conda environment, install the required packages, and launch the controllers or services on the real robot;
+- **Step 2**: Open a new terminal and establish a tunnel connection from the client to the server:
+```
+ssh user_name@remote_server_IP -CNg -L 8000:127.0.0.1:8000
+```
+- **Step 3**: Run the ```unitree_deploy/robot_client.py``` script to start inference:
+```
+cd unitree_deploy
+python scripts/robot_client.py --robot_type "g1_dex1" --action_horizon 16 --exe_steps 16 --observation_horizon 2 --language_instruction "pack black camera into box" --output_dir ./results --control_freq 15
+```
+
+## 📝 Codebase Architecture
+Here is a high-level overview of the project's code structure and core components:
+```
+unitree-world-model/
+  ├── assets                  # Media assets such as GIFs, images, and demo videos
+  ├── configs                 # Configuration files
+  │   ├── inference
+  │   └── train
+  ├── examples                # Example data
+  ├── external                # External packages
+  ├── prepare_data            # Data preprocessing
+  ├── scripts                 # Main scripts
+  ├── src
+  │   ├── unitree_worldmodel  # Core package
+  │   │   ├── data            # Data loading
+  │   │   ├── models          # Model architectures
+  │   │   ├── modules         # Custom modules
+  │   │   └── utils           # Utility functions
+```
+
+## 🙏 Acknowledgement
+This project builds on the following excellent open-source projects, which we gratefully acknowledge: [DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter), [Diffusion Policy](https://github.com/real-stanford/diffusion_policy), [ACT](https://github.com/MarkFzp/act-plus-plus) and [HPT](https://github.com/liruiw/HPT).
diff --git a/assets/pngs/dm_mode.png b/assets/pngs/dm_mode.png
new file mode 100644
index 0000000..b377a3d
Binary files /dev/null and b/assets/pngs/dm_mode.png differ
diff --git a/assets/pngs/sim_mode.png b/assets/pngs/sim_mode.png
new file mode 100644
index 0000000..6c96b09
Binary files /dev/null and b/assets/pngs/sim_mode.png differ
diff --git a/ckpts/.gitattributes b/ckpts/.gitattributes
new file mode 100644
index 0000000..c65e9b2
--- /dev/null
+++ b/ckpts/.gitattributes
@@ -0,0 +1,54 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bin.* filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zstandard filter=lfs diff=lfs merge=lfs -text
+*.tfevents* filter=lfs diff=lfs merge=lfs -text
+*.db* filter=lfs diff=lfs merge=lfs -text
+*.ark* filter=lfs diff=lfs merge=lfs -text
+**/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
+**/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
+**/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+
+*.gguf* filter=lfs diff=lfs merge=lfs -text
+*.ggml filter=lfs diff=lfs merge=lfs -text
+*.llamafile* filter=lfs diff=lfs merge=lfs -text
+*.pt2 filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+
+assets/real_cleanup_pencils.gif filter=lfs diff=lfs merge=lfs -text
+assets/world_model_interaction.gif filter=lfs diff=lfs merge=lfs -text
+assets/real_dual_stackbox.gif filter=lfs diff=lfs merge=lfs -text
+assets/real_g1_pack_camera.gif filter=lfs diff=lfs merge=lfs -text
+assets/real_z1_stackbox.gif filter=lfs diff=lfs merge=lfs -text
+unifolm_wma_dual.ckpt filter=lfs diff=lfs merge=lfs -text
\ No newline at end of file
diff --git a/ckpts/LICENSE b/ckpts/LICENSE
new file mode 100644
index 0000000..5522eea
--- /dev/null
+++ b/ckpts/LICENSE
@@ -0,0 +1,439 @@
+Attribution-NonCommercial-ShareAlike 4.0 International
+
+Copyright (c) 2016-2025 HangZhou YuShu TECHNOLOGY CO.,LTD. ("Unitree Robotics")
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
+Public License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial-ShareAlike 4.0 International Public License
+("Public License"). To the extent this Public License may be
+interpreted as a contract, You are granted the Licensed Rights in
+consideration of Your acceptance of these terms and conditions, and the
+Licensor grants You such rights in consideration of benefits the
+Licensor receives from making the Licensed Material available under
+these terms and conditions.
+
+
+Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+
+ c. BY-NC-SA Compatible License means a license listed at
+ creativecommons.org/compatiblelicenses, approved by Creative
+ Commons as essentially the equivalent of this Public License.
+
+ d. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+ Rights.
+
+ e. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ f. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ g. License Elements means the license attributes listed in the name
+ of a Creative Commons Public License. The License Elements of this
+ Public License are Attribution, NonCommercial, and ShareAlike.
+
+ h. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ i. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ j. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ k. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ l. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ m. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ n. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+
+Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce, reproduce, and Share Adapted Material for
+ NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. Additional offer from the Licensor -- Adapted Material.
+ Every recipient of Adapted Material from You
+ automatically receives an offer from the Licensor to
+ exercise the Licensed Rights in the Adapted Material
+ under the conditions of the Adapter's License You apply.
+
+ c. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+ b. ShareAlike.
+
+ In addition to the conditions in Section 3(a), if You Share
+ Adapted Material You produce, the following conditions also apply.
+
+ 1. The Adapter's License You apply must be a Creative Commons
+ license with the same License Elements, this version or
+ later, or a BY-NC-SA Compatible License.
+
+ 2. You must include the text of, or the URI or hyperlink to, the
+ Adapter's License You apply. You may satisfy this condition
+ in any reasonable manner based on the medium, means, and
+ context in which You Share Adapted Material.
+
+ 3. You may not offer or impose any additional or different terms
+ or conditions on, or apply any Effective Technological
+ Measures to, Adapted Material that restrict exercise of the
+ rights granted under the Adapter's License You apply.
+
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material,
+ including for purposes of Section 3(b); and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+
+Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+
+Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+
+Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the “Licensor.” The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
diff --git a/ckpts/README.md b/ckpts/README.md
new file mode 100644
index 0000000..038ff9b
--- /dev/null
+++ b/ckpts/README.md
@@ -0,0 +1,38 @@
+---
+tags:
+- robotics
+---
+
+# UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family
+
+UnifoLM-WMA-0 is Unitree's first open-source world-model-action architecture spanning multiple types of robotic embodiments, designed for general-purpose robot learning. Its core component is a world model that captures the physical interactions between robots and their environments. The world model provides two key functions: (a) Simulation Engine – it operates as an interactive simulator that generates synthetic data for robot learning; (b) Policy Enhancement – it connects to an action head and, by predicting future interactions with the environment, further improves decision-making performance.
+
+
+
+## 🦾 Real Robot Deployment
+| | |
+|:---:|:---:|
+| | |
+
+**Note: the top-right window shows the world model’s prediction of future environmental changes.**
+
+## License
+The model is released under the CC BY-NC-SA 4.0 license as found in the [LICENSE](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0/blob/main/LICENSE). You are responsible for ensuring that your use of Unitree AI Models complies with all applicable laws.
+
+## Model Architecture
+
+
+## Citation
+```
+@misc{unifolm-wma-0,
+ author = {Unitree},
+ title = {UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family},
+ year = {2025},
+}
+```
\ No newline at end of file
diff --git a/configs/inference/base_model_inference.yaml b/configs/inference/base_model_inference.yaml
new file mode 100644
index 0000000..157000a
--- /dev/null
+++ b/configs/inference/base_model_inference.yaml
@@ -0,0 +1,213 @@
+model:
+ target: unifolm_wma.models.ddpms.LatentVisualDiffusion
+ params:
+ rescale_betas_zero_snr: True
+ parameterization: "v"
+ linear_start: 0.00085
+ linear_end: 0.012
+ num_timesteps_cond: 1
+ timesteps: 1000
+ first_stage_key: video
+ cond_stage_key: instruction
+ cond_stage_trainable: False
+ conditioning_key: hybrid
+ image_size: [40, 64]
+ channels: 4
+ scale_by_std: False
+ scale_factor: 0.18215
+ use_ema: False
+ uncond_type: 'empty_seq'
+ use_dynamic_rescale: true
+ base_scale: 0.7
+ fps_condition_type: 'fps'
+ perframe_ae: True
+ freeze_embedder: True
+ n_obs_steps_imagen: 1
+ n_obs_steps_acting: 1
+ agent_state_dim: 16
+ agent_action_dim: 16
+
+ ###################### DP Related
+ input_pertub: 0.1
+ lr_scheduler: cosine
+ lr_warmup_steps: 2000
+ num_epochs: 30000
+ gradient_accumulate_every: 1
+ use_scheduler: True
+ dp_use_ema: True
+
+ dp_ema_config:
+ target: unifolm_wma.models.diffusion_head.ema_model.EMAModel
+ params:
+ update_after_step: 0
+ inv_gamma: 1.0
+ power: 0.75
+ min_value: 0.0
+ max_value: 0.9999
+
+ noise_scheduler_config:
+ target: diffusers.DDIMScheduler
+ params:
+ num_train_timesteps: 1000
+ beta_start: 0.0001
+ beta_end: 0.02
+ beta_schedule: squaredcos_cap_v2
+ clip_sample: True
+ set_alpha_to_one: True
+ steps_offset: 0
+ prediction_type: epsilon
+
+ dp_optimizer_config:
+ target: torch.optim.AdamW
+ params:
+ lr: 1.0e-4
+ betas: [0.95, 0.999]
+ eps: 1.0e-8
+ weight_decay: 1.0e-6
+
+ wma_config:
+ target: unifolm_wma.modules.networks.wma_model.WMAModel
+ params:
+ in_channels: 8
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions:
+ - 4
+ - 2
+ - 1
+ num_res_blocks: 2
+ channel_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ dropout: 0.1
+ num_head_channels: 64
+ transformer_depth: 1
+ context_dim: 1024
+ use_linear: true
+ use_checkpoint: True
+ temporal_conv: True
+ temporal_attention: True
+ temporal_selfatt_only: True
+ use_relative_position: False
+ use_causal_attention: False
+ temporal_length: 16
+ addition_attention: True
+ image_cross_attention: True
+ default_fs: 10
+ fs_condition: True
+ cross_attention_scale_learnable: False
+ n_obs_steps: ${model.params.n_obs_steps_imagen}
+ num_stem_token: 16
+ base_model_gen_only: True
+
+ unet_head_config:
+ target: unifolm_wma.models.diffusion_head.conditional_unet1d.ConditionalUnet1D
+ params:
+ input_dim: ${model.params.agent_action_dim}
+ n_obs_steps: ${model.params.n_obs_steps_acting}
+ diffusion_step_embed_dim: 128
+ down_dims: [256, 512, 1024, 2048]
+ kernel_size: 5
+ n_groups: 8
+ cond_predict_scale: True
+ num_head_channels: ${model.params.wma_config.params.num_head_channels}
+ horizon: ${model.params.wma_config.params.temporal_length}
+ use_linear_attn: ${model.params.wma_config.params.use_linear}
+ use_linear_act_proj: True
+ act_proj_dim: 32
+ cond_cross_attention: False
+ context_dims: []
+ image_size: ${model.params.image_size}
+ imagen_cond_gradient: True
+ last_frame_only: False
+ use_imagen_mid_only: False
+ use_z_only: False
+
+ obs_encoder_config:
+ target: unifolm_wma.models.diffusion_head.vision.multi_image_obs_encoder.MultiImageObsEncoder
+ params:
+ rgb_model_config:
+ target: unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet
+ params:
+ name: resnet18
+ weights: null
+ resize_shape: null
+ crop_shape: null
+ random_crop: False
+ use_group_norm: True
+ share_rgb_model: False
+ imagenet_norm: True
+ use_spatial_softmax: True
+ spatial_softmax_kp: 128
+
+ ###################### Action Tokenization
+ stem_process_config:
+ target: unifolm_wma.modules.encoders.condition.SATokenProjector
+ params:
+ dim: 1024
+ depth: 1
+ dim_head: 64
+ heads: 16
+ num_queries: ${model.params.wma_config.params.num_stem_token}
+ output_dim: 1024
+ ff_mult: 4
+ chunk_size: ${model.params.wma_config.params.temporal_length}
+
+ first_stage_config:
+ target: unifolm_wma.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: True
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPEmbedder
+ params:
+ freeze: True
+ layer: "penultimate"
+
+ img_cond_stage_config:
+ target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
+ params:
+ freeze: true
+
+ image_proj_stage_config:
+ target: unifolm_wma.modules.encoders.resampler.Resampler
+ params:
+ dim: 1024
+ depth: 4
+ dim_head: 64
+ heads: 12
+ num_queries: 16
+ embedding_dim: 1280
+ output_dim: 1024
+ ff_mult: 4
+ video_length: ${model.params.wma_config.params.temporal_length}
+
+ normalization_config:
+ input_shapes:
+ observation.state: ${model.params.wma_config.params.unet_head_config.params.input_dim}
+ input_normalization_modes:
+ observation.state: 'min_max'
+ output_shapes:
+ action: ${model.params.wma_config.params.unet_head_config.params.input_dim}
+ output_normalization_modes:
+ action: 'min_max'
diff --git a/configs/inference/world_model_decision_making.yaml b/configs/inference/world_model_decision_making.yaml
new file mode 100644
index 0000000..2f51dce
--- /dev/null
+++ b/configs/inference/world_model_decision_making.yaml
@@ -0,0 +1,240 @@
+model:
+ target: unifolm_wma.models.ddpms.LatentVisualDiffusion
+ params:
+ rescale_betas_zero_snr: True
+ parameterization: "v"
+ linear_start: 0.00085
+ linear_end: 0.012
+ num_timesteps_cond: 1
+ timesteps: 1000
+ first_stage_key: video
+ cond_stage_key: instruction
+ cond_stage_trainable: False
+ conditioning_key: hybrid
+ image_size: [40, 64]
+ channels: 4
+ scale_by_std: False
+ scale_factor: 0.18215
+ use_ema: False
+ uncond_type: 'empty_seq'
+ use_dynamic_rescale: true
+ base_scale: 0.7
+ fps_condition_type: 'fps'
+ perframe_ae: True
+ freeze_embedder: True
+ n_obs_steps_imagen: 2
+ n_obs_steps_acting: 2
+ agent_state_dim: 16
+ agent_action_dim: 16
+ decision_making_only: True
+
+ ###################### DP Related
+ input_pertub: 0.1
+ lr_scheduler: cosine
+ lr_warmup_steps: 2000
+ num_epochs: 30000
+ gradient_accumulate_every: 1
+ use_scheduler: True
+ dp_use_ema: True
+
+ dp_ema_config:
+ target: unifolm_wma.models.diffusion_head.ema_model.EMAModel
+ params:
+ update_after_step: 0
+ inv_gamma: 1.0
+ power: 0.75
+ min_value: 0.0
+ max_value: 0.9999
+
+ noise_scheduler_config:
+ target: diffusers.DDIMScheduler
+ params:
+ num_train_timesteps: 1000
+ beta_start: 0.0001
+ beta_end: 0.02
+ beta_schedule: squaredcos_cap_v2
+ clip_sample: True
+ set_alpha_to_one: True
+ steps_offset: 0
+ prediction_type: epsilon
+
+ dp_optimizer_config:
+ target: torch.optim.AdamW
+ params:
+ lr: 1.0e-4
+ betas: [0.95, 0.999]
+ eps: 1.0e-8
+ weight_decay: 1.0e-6
+
+ wma_config:
+ target: unifolm_wma.modules.networks.wma_model.WMAModel
+ params:
+ in_channels: 8
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions:
+ - 4
+ - 2
+ - 1
+ num_res_blocks: 2
+ channel_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ dropout: 0.1
+ num_head_channels: 64
+ transformer_depth: 1
+ context_dim: 1024
+ use_linear: true
+ use_checkpoint: True
+ temporal_conv: True
+ temporal_attention: True
+ temporal_selfatt_only: True
+ use_relative_position: False
+ use_causal_attention: False
+ temporal_length: 16
+ addition_attention: True
+ image_cross_attention: True
+ default_fs: 10
+ fs_condition: True
+ cross_attention_scale_learnable: False
+ n_obs_steps: ${model.params.n_obs_steps_imagen}
+ num_stem_token: 16
+ base_model_gen_only: False
+
+ unet_head_config:
+ target: unifolm_wma.models.diffusion_head.conditional_unet1d.ConditionalUnet1D
+ params:
+ input_dim: ${model.params.agent_action_dim}
+ n_obs_steps: ${model.params.n_obs_steps_acting}
+ diffusion_step_embed_dim: 128
+ down_dims: [256, 512, 1024, 2048]
+ kernel_size: 5
+ n_groups: 8
+ cond_predict_scale: True
+ num_head_channels: ${model.params.wma_config.params.num_head_channels}
+ horizon: ${model.params.wma_config.params.temporal_length}
+ use_linear_attn: ${model.params.wma_config.params.use_linear}
+ use_linear_act_proj: True
+ act_proj_dim: 32
+ cond_cross_attention: False
+ context_dims: []
+ image_size: ${model.params.image_size}
+ imagen_cond_gradient: True
+ last_frame_only: False
+ use_imagen_mid_only: False
+ use_z_only: False
+
+ obs_encoder_config:
+ target: unifolm_wma.models.diffusion_head.vision.multi_image_obs_encoder.MultiImageObsEncoder
+ params:
+ rgb_model_config:
+ target: unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet
+ params:
+ name: resnet18
+ weights: null
+ resize_shape: null
+ crop_shape: null
+ random_crop: False
+ use_group_norm: True
+ share_rgb_model: False
+ imagenet_norm: True
+ use_spatial_softmax: True
+ spatial_softmax_kp: 128
+
+ ###################### Action Tokenization
+ stem_process_config:
+ target: unifolm_wma.modules.encoders.condition.SATokenProjector
+ params:
+ dim: 1024
+ depth: 1
+ dim_head: 64
+ heads: 16
+ num_queries: ${model.params.wma_config.params.num_stem_token}
+ output_dim: 1024
+ ff_mult: 4
+ chunk_size: ${model.params.wma_config.params.temporal_length}
+
+ first_stage_config:
+ target: unifolm_wma.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: True
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPEmbedder
+ params:
+ freeze: True
+ layer: "penultimate"
+
+ img_cond_stage_config:
+ target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
+ params:
+ freeze: true
+
+ image_proj_stage_config:
+ target: unifolm_wma.modules.encoders.resampler.Resampler
+ params:
+ dim: 1024
+ depth: 4
+ dim_head: 64
+ heads: 12
+ num_queries: 16
+ embedding_dim: 1280
+ output_dim: 1024
+ ff_mult: 4
+ video_length: ${model.params.wma_config.params.temporal_length}
+
+ normalization_config:
+ input_shapes:
+ observation.state: ${model.params.wma_config.params.unet_head_config.params.input_dim}
+ input_normalization_modes:
+ observation.state: 'min_max'
+ output_shapes:
+ action: ${model.params.wma_config.params.unet_head_config.params.input_dim}
+ output_normalization_modes:
+ action: 'min_max'
+
+data:
+ target: unifolm_wma.utils.data.DataModuleFromConfig
+ params:
+ batch_size: 6
+ num_workers: 12
+ wrap: False
+ test:
+ target: unifolm_wma.data.wma_data.WMAData
+ params:
+ data_dir: '/path/to/the/dataset/directory/that/contains/the/meta/folder/of/the/testing/case/under/a/transitions/folder' # e.g., /path/to/unifolm-world-model-action/examples/world_model_interaction_prompts
+ video_length: ${model.params.wma_config.params.temporal_length}
+ frame_stride: 2
+ load_raw_resolution: True
+ resolution: [320, 512]
+ spatial_transform: resize_center_crop
+ crop_resolution: [320, 512]
+ random_fs: False
+ cond_robot_label_prob: 0.0
+ normalization_mode: 'min_max'
+ individual_normalization: True
+ n_obs_steps: ${model.params.n_obs_steps_imagen}
+ max_action_dim: ${model.params.agent_action_dim}
+ max_state_dim: ${model.params.agent_state_dim}
+ dataset_and_weights:
+ unitree_g1_pack_camera: 1.0
diff --git a/configs/inference/world_model_interaction.yaml b/configs/inference/world_model_interaction.yaml
new file mode 100644
index 0000000..2d69e09
--- /dev/null
+++ b/configs/inference/world_model_interaction.yaml
@@ -0,0 +1,244 @@
+model:
+ target: unifolm_wma.models.ddpms.LatentVisualDiffusion
+ params:
+ rescale_betas_zero_snr: True
+ parameterization: "v"
+ linear_start: 0.00085
+ linear_end: 0.012
+ num_timesteps_cond: 1
+ timesteps: 1000
+ first_stage_key: video
+ cond_stage_key: instruction
+ cond_stage_trainable: False
+ conditioning_key: hybrid
+ image_size: [40, 64]
+ channels: 4
+ scale_by_std: False
+ scale_factor: 0.18215
+ use_ema: False
+ uncond_type: 'empty_seq'
+ use_dynamic_rescale: true
+ base_scale: 0.7
+ fps_condition_type: 'fps'
+ perframe_ae: True
+ freeze_embedder: True
+ n_obs_steps_imagen: 2
+ n_obs_steps_acting: 2
+ agent_state_dim: 16
+ agent_action_dim: 16
+ decision_making_only: False
+
+ ###################### DP Related
+ input_pertub: 0.1
+ lr_scheduler: cosine
+ lr_warmup_steps: 2000
+ num_epochs: 30000
+ gradient_accumulate_every: 1
+ use_scheduler: True
+ dp_use_ema: True
+
+ dp_ema_config:
+ target: unifolm_wma.models.diffusion_head.ema_model.EMAModel
+ params:
+ update_after_step: 0
+ inv_gamma: 1.0
+ power: 0.75
+ min_value: 0.0
+ max_value: 0.9999
+
+ noise_scheduler_config:
+ target: diffusers.DDIMScheduler
+ params:
+ num_train_timesteps: 1000
+ beta_start: 0.0001
+ beta_end: 0.02
+ beta_schedule: squaredcos_cap_v2
+ clip_sample: True
+ set_alpha_to_one: True
+ steps_offset: 0
+ prediction_type: epsilon
+
+ dp_optimizer_config:
+ target: torch.optim.AdamW
+ params:
+ lr: 1.0e-4
+ betas: [0.95, 0.999]
+ eps: 1.0e-8
+ weight_decay: 1.0e-6
+
+ wma_config:
+ target: unifolm_wma.modules.networks.wma_model.WMAModel
+ params:
+ in_channels: 8
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions:
+ - 4
+ - 2
+ - 1
+ num_res_blocks: 2
+ channel_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ dropout: 0.1
+ num_head_channels: 64
+ transformer_depth: 1
+ context_dim: 1024
+ use_linear: true
+ use_checkpoint: True
+ temporal_conv: True
+ temporal_attention: True
+ temporal_selfatt_only: True
+ use_relative_position: False
+ use_causal_attention: False
+ temporal_length: 16
+ addition_attention: True
+ image_cross_attention: True
+ default_fs: 10
+ fs_condition: True
+ cross_attention_scale_learnable: False
+ n_obs_steps: ${model.params.n_obs_steps_imagen}
+ num_stem_token: 16
+ base_model_gen_only: False
+
+ unet_head_config:
+ target: unifolm_wma.models.diffusion_head.conditional_unet1d.ConditionalUnet1D
+ params:
+ input_dim: ${model.params.agent_action_dim}
+ n_obs_steps: ${model.params.n_obs_steps_acting}
+ diffusion_step_embed_dim: 128
+ down_dims: [256, 512, 1024, 2048]
+ kernel_size: 5
+ n_groups: 8
+ cond_predict_scale: True
+ num_head_channels: ${model.params.wma_config.params.num_head_channels}
+ horizon: ${model.params.wma_config.params.temporal_length}
+ use_linear_attn: ${model.params.wma_config.params.use_linear}
+ use_linear_act_proj: True
+ act_proj_dim: 32
+ cond_cross_attention: False
+ context_dims: []
+ image_size: ${model.params.image_size}
+ imagen_cond_gradient: True
+ last_frame_only: False
+ use_imagen_mid_only: False
+ use_z_only: False
+
+ obs_encoder_config:
+ target: unifolm_wma.models.diffusion_head.vision.multi_image_obs_encoder.MultiImageObsEncoder
+ params:
+ rgb_model_config:
+ target: unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet
+ params:
+ name: resnet18
+ weights: null
+ resize_shape: null
+ crop_shape: null
+ random_crop: False
+ use_group_norm: True
+ share_rgb_model: False
+ imagenet_norm: True
+ use_spatial_softmax: True
+ spatial_softmax_kp: 128
+
+ ###################### Action Tokenization
+ stem_process_config:
+ target: unifolm_wma.modules.encoders.condition.SATokenProjector
+ params:
+ dim: 1024
+ depth: 1
+ dim_head: 64
+ heads: 16
+ num_queries: ${model.params.wma_config.params.num_stem_token}
+ output_dim: 1024
+ ff_mult: 4
+ chunk_size: ${model.params.wma_config.params.temporal_length}
+
+ first_stage_config:
+ target: unifolm_wma.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: True
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPEmbedder
+ params:
+ freeze: True
+ layer: "penultimate"
+
+ img_cond_stage_config:
+ target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
+ params:
+ freeze: true
+
+ image_proj_stage_config:
+ target: unifolm_wma.modules.encoders.resampler.Resampler
+ params:
+ dim: 1024
+ depth: 4
+ dim_head: 64
+ heads: 12
+ num_queries: 16
+ embedding_dim: 1280
+ output_dim: 1024
+ ff_mult: 4
+ video_length: ${model.params.wma_config.params.temporal_length}
+
+ normalization_config:
+ input_shapes:
+ observation.state: ${model.params.wma_config.params.unet_head_config.params.input_dim}
+ input_normalization_modes:
+ observation.state: 'min_max'
+ output_shapes:
+ action: ${model.params.wma_config.params.unet_head_config.params.input_dim}
+ output_normalization_modes:
+ action: 'min_max'
+
+data:
+ target: unifolm_wma.utils.data.DataModuleFromConfig
+ params:
+ batch_size: 6
+ num_workers: 12
+ wrap: False
+ test:
+ target: unifolm_wma.data.wma_data.WMAData
+ params:
+ data_dir: '/path/to/unifolm-world-model-action/examples/world_model_interaction_prompts' # e.g., the examples/world_model_interaction_prompts folder in this repo
+ video_length: ${model.params.wma_config.params.temporal_length}
+ frame_stride: 2
+ load_raw_resolution: True
+ resolution: [320, 512]
+ spatial_transform: resize_center_crop
+ crop_resolution: [320, 512]
+ random_fs: False
+ cond_robot_label_prob: 0.0
+ normalization_mode: 'min_max'
+ individual_normalization: True
+ n_obs_steps: ${model.params.n_obs_steps_imagen}
+ max_action_dim: ${model.params.agent_action_dim}
+ max_state_dim: ${model.params.agent_state_dim}
+ dataset_and_weights:
+ unitree_z1_stackbox: 0.2
+ unitree_z1_dual_arm_stackbox: 0.2
+ unitree_z1_dual_arm_stackbox_v2: 0.2
+ unitree_z1_dual_arm_cleanup_pencils: 0.2
+ unitree_g1_pack_camera: 0.2
diff --git a/configs/train/config.yaml b/configs/train/config.yaml
new file mode 100644
index 0000000..9e7bb06
--- /dev/null
+++ b/configs/train/config.yaml
@@ -0,0 +1,287 @@
+model:
+ pretrained_checkpoint: /path/to/pretrained/checkpoint
+ base_learning_rate: 1.0e-05
+ scale_lr: False
+ target: unifolm_wma.models.ddpms.LatentVisualDiffusion
+ params:
+ rescale_betas_zero_snr: True
+ parameterization: "v"
+ linear_start: 0.00085
+ linear_end: 0.012
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: video
+ cond_stage_key: instruction
+ cond_stage_trainable: False
+ image_proj_model_trainable: True
+ conditioning_key: hybrid
+ image_size: [40, 64]
+ channels: 4
+ scale_by_std: False
+ scale_factor: 0.18215
+ use_ema: False
+ uncond_prob: 0.05
+ uncond_type: 'empty_seq'
+ rand_cond_frame: false
+ use_dynamic_rescale: true
+ base_scale: 0.7
+ fps_condition_type: 'fps'
+ perframe_ae: True
+ freeze_embedder: True
+ n_obs_steps_imagen: 2
+ n_obs_steps_acting: 2
+ agent_state_dim: 16
+ agent_action_dim: 16
+ decision_making_only: True
+
+ ###################### DP Related
+ input_pertub: 0.1
+ lr_scheduler: cosine
+ lr_warmup_steps: 2000
+ num_epochs: 60000
+ gradient_accumulate_every: 1
+ use_scheduler: True
+ dp_use_ema: True
+
+ dp_ema_config:
+ target: unifolm_wma.models.diffusion_head.ema_model.EMAModel
+ params:
+ update_after_step: 0
+ inv_gamma: 1.0
+ power: 0.75
+ min_value: 0.0
+ max_value: 0.9999
+
+ noise_scheduler_config:
+ target: diffusers.DDIMScheduler
+ params:
+ num_train_timesteps: 1000
+ beta_start: 0.0001
+ beta_end: 0.02
+ beta_schedule: squaredcos_cap_v2
+ clip_sample: True
+ set_alpha_to_one: True
+ steps_offset: 0
+ prediction_type: epsilon
+
+ dp_optimizer_config:
+ target: torch.optim.AdamW
+ params:
+ lr: 1.0e-4
+ betas: [0.95, 0.999]
+ eps: 1.0e-8
+ weight_decay: 1.0e-6
+
+ wma_config:
+ target: unifolm_wma.modules.networks.wma_model.WMAModel
+ params:
+ in_channels: 8
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions:
+ - 4
+ - 2
+ - 1
+ num_res_blocks: 2
+ channel_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ dropout: 0.1
+ num_head_channels: 64
+ transformer_depth: 1
+ context_dim: 1024
+ use_linear: true
+ use_checkpoint: True
+ temporal_conv: True
+ temporal_attention: True
+ temporal_selfatt_only: True
+ use_relative_position: False
+ use_causal_attention: False
+ temporal_length: 16
+ addition_attention: True
+ image_cross_attention: True
+ default_fs: 10
+ fs_condition: True
+ cross_attention_scale_learnable: False
+ n_obs_steps: ${model.params.n_obs_steps_imagen}
+ num_stem_token: 16
+ base_model_gen_only: False
+
+ unet_head_config:
+ target: unifolm_wma.models.diffusion_head.conditional_unet1d.ConditionalUnet1D
+ params:
+ input_dim: ${model.params.agent_action_dim}
+ n_obs_steps: ${model.params.n_obs_steps_acting}
+ diffusion_step_embed_dim: 128
+ down_dims: [256, 512, 1024, 2048]
+ kernel_size: 5
+ n_groups: 8
+ cond_predict_scale: True
+ num_head_channels: ${model.params.wma_config.params.num_head_channels}
+ horizon: ${model.params.wma_config.params.temporal_length}
+ use_linear_attn: ${model.params.wma_config.params.use_linear}
+ use_linear_act_proj: True
+ act_proj_dim: 32
+ cond_cross_attention: False
+ context_dims: []
+ image_size: ${model.params.image_size}
+ imagen_cond_gradient: True
+ last_frame_only: False
+ use_imagen_mid_only: False
+ use_z_only: False
+
+ obs_encoder_config:
+ target: unifolm_wma.models.diffusion_head.vision.multi_image_obs_encoder.MultiImageObsEncoder
+ params:
+ rgb_model_config:
+ target: unifolm_wma.models.diffusion_head.vision.model_getter.get_resnet
+ params:
+ name: resnet18
+ weights: null
+ resize_shape: null
+ crop_shape: null
+ random_crop: False
+ use_group_norm: True
+ share_rgb_model: False
+ imagenet_norm: True
+ use_spatial_softmax: True
+ spatial_softmax_kp: 128
+
+ ###################### Action Tokenization
+ stem_process_config:
+ target: unifolm_wma.modules.encoders.condition.SATokenProjector
+ params:
+ dim: 1024
+ depth: 1
+ dim_head: 64
+ heads: 16
+ num_queries: ${model.params.wma_config.params.num_stem_token}
+ output_dim: 1024
+ ff_mult: 4
+ chunk_size: ${model.params.wma_config.params.temporal_length}
+
+ first_stage_config:
+ target: unifolm_wma.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: True
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPEmbedder
+ params:
+ freeze: True
+ layer: "penultimate"
+
+ img_cond_stage_config:
+ target: unifolm_wma.modules.encoders.condition.FrozenOpenCLIPImageEmbedderV2
+ params:
+ freeze: true
+
+ image_proj_stage_config:
+ target: unifolm_wma.modules.encoders.resampler.Resampler
+ params:
+ dim: 1024
+ depth: 4
+ dim_head: 64
+ heads: 12
+ num_queries: 16
+ embedding_dim: 1280
+ output_dim: 1024
+ ff_mult: 4
+ video_length: ${model.params.wma_config.params.temporal_length}
+
+ normalization_config:
+ input_shapes:
+ observation.state: ${model.params.wma_config.params.unet_head_config.params.input_dim}
+ input_normalization_modes:
+ observation.state: 'min_max'
+ output_shapes:
+ action: ${model.params.wma_config.params.unet_head_config.params.input_dim}
+ output_normalization_modes:
+ action: 'min_max'
+
+data:
+ target: unifolm_wma.utils.data.DataModuleFromConfig
+ params:
+ batch_size: 8
+ num_workers: 12
+ wrap: False
+ train:
+ target: unifolm_wma.data.wma_data.WMAData
+ params:
+ data_dir: '/path/to/training/dataset/directory'
+ video_length: ${model.params.wma_config.params.temporal_length}
+ frame_stride: 2
+ load_raw_resolution: True
+ resolution: [320, 512]
+ spatial_transform: resize_center_crop
+ crop_resolution: [320, 512]
+ random_fs: False
+ cond_robot_label_prob: 0.0
+ normalization_mode: 'min_max'
+ individual_normalization: True
+ n_obs_steps: ${model.params.n_obs_steps_imagen}
+ max_action_dim: ${model.params.agent_action_dim}
+ max_state_dim: ${model.params.agent_state_dim}
+ dataset_and_weights:
+ unitree_z1_stackbox: 0.2
+ unitree_z1_dual_arm_stackbox: 0.2
+ unitree_z1_dual_arm_stackbox_v2: 0.2
+ unitree_z1_dual_arm_cleanup_pencils: 0.2
+ unitree_g1_pack_camera: 0.2
+
+lightning:
+ precision: 16
+ trainer:
+ benchmark: True
+ accumulate_grad_batches: 2
+ max_steps: 300000
+ log_every_n_steps: 50
+ val_check_interval: 1.0
+ gradient_clip_algorithm: 'norm'
+ gradient_clip_val: 0.5
+ enable_model_summary: False
+ callbacks:
+ model_checkpoint:
+ target: pytorch_lightning.callbacks.ModelCheckpoint
+ params:
+ every_n_train_steps: 1000
+ filename: "{epoch}-{step}"
+ save_weights_only: True
+ metrics_over_trainsteps_checkpoint:
+ target: pytorch_lightning.callbacks.ModelCheckpoint
+ params:
+ filename: '{epoch}-{step}'
+ save_weights_only: True
+ every_n_train_steps: 10000
+ batch_logger:
+ target: unifolm_wma.utils.callbacks.ImageLogger
+ params:
+ batch_frequency: 20000
+ to_local: False
+ max_images: 8
+ log_images_kwargs:
+ ddim_steps: 16
+ unconditional_guidance_scale: 1.0
+ timestep_spacing: uniform_trailing
+ guidance_rescale: 0.7
diff --git a/examples/base_model_prompts/0.png b/examples/base_model_prompts/0.png
new file mode 100644
index 0000000..468c434
Binary files /dev/null and b/examples/base_model_prompts/0.png differ
diff --git a/examples/base_model_prompts/1.png b/examples/base_model_prompts/1.png
new file mode 100644
index 0000000..899edfb
Binary files /dev/null and b/examples/base_model_prompts/1.png differ
diff --git a/examples/base_model_prompts/10.png b/examples/base_model_prompts/10.png
new file mode 100644
index 0000000..e195de4
Binary files /dev/null and b/examples/base_model_prompts/10.png differ
diff --git a/examples/base_model_prompts/11.png b/examples/base_model_prompts/11.png
new file mode 100644
index 0000000..fca3c47
Binary files /dev/null and b/examples/base_model_prompts/11.png differ
diff --git a/examples/base_model_prompts/12.png b/examples/base_model_prompts/12.png
new file mode 100644
index 0000000..53dda0e
Binary files /dev/null and b/examples/base_model_prompts/12.png differ
diff --git a/examples/base_model_prompts/13.png b/examples/base_model_prompts/13.png
new file mode 100644
index 0000000..cbf5623
Binary files /dev/null and b/examples/base_model_prompts/13.png differ
diff --git a/examples/base_model_prompts/14.png b/examples/base_model_prompts/14.png
new file mode 100644
index 0000000..317843c
Binary files /dev/null and b/examples/base_model_prompts/14.png differ
diff --git a/examples/base_model_prompts/15.png b/examples/base_model_prompts/15.png
new file mode 100644
index 0000000..3a1a4b2
Binary files /dev/null and b/examples/base_model_prompts/15.png differ
diff --git a/examples/base_model_prompts/2.png b/examples/base_model_prompts/2.png
new file mode 100644
index 0000000..b6e58fb
Binary files /dev/null and b/examples/base_model_prompts/2.png differ
diff --git a/examples/base_model_prompts/3.png b/examples/base_model_prompts/3.png
new file mode 100644
index 0000000..39611c4
Binary files /dev/null and b/examples/base_model_prompts/3.png differ
diff --git a/examples/base_model_prompts/4.png b/examples/base_model_prompts/4.png
new file mode 100644
index 0000000..8ec01a0
Binary files /dev/null and b/examples/base_model_prompts/4.png differ
diff --git a/examples/base_model_prompts/5.png b/examples/base_model_prompts/5.png
new file mode 100644
index 0000000..53a5a49
Binary files /dev/null and b/examples/base_model_prompts/5.png differ
diff --git a/examples/base_model_prompts/6.png b/examples/base_model_prompts/6.png
new file mode 100644
index 0000000..b3a6312
Binary files /dev/null and b/examples/base_model_prompts/6.png differ
diff --git a/examples/base_model_prompts/7.png b/examples/base_model_prompts/7.png
new file mode 100644
index 0000000..f4366f8
Binary files /dev/null and b/examples/base_model_prompts/7.png differ
diff --git a/examples/base_model_prompts/8.png b/examples/base_model_prompts/8.png
new file mode 100644
index 0000000..efaaa93
Binary files /dev/null and b/examples/base_model_prompts/8.png differ
diff --git a/examples/base_model_prompts/9.png b/examples/base_model_prompts/9.png
new file mode 100644
index 0000000..c86cda7
Binary files /dev/null and b/examples/base_model_prompts/9.png differ
diff --git a/examples/base_model_prompts/prompts.csv b/examples/base_model_prompts/prompts.csv
new file mode 100644
index 0000000..d6df917
--- /dev/null
+++ b/examples/base_model_prompts/prompts.csv
@@ -0,0 +1,17 @@
+videoid,instruction,fps,start_idx,fs,num_gen
+0,wash the pan,16.0,152.0,2.0,6.0
+1,pick up the blue cup and put it into the brown cup. ,5.0,0.0,2.0,4.0
+2,close top drawer,3.0,0.0,2.0,1.0
+3,Close the laptop.,10.0,0.0,2.0,4.0
+4,destack cube,5.0,0.0,2.0,3.0
+5,arrange plate and fork,20.0,40.0,2.0,4.0
+6,Place the lid on the teapot,15.0,30.0,1.0,4.0
+7,Pick up the green object and insert it.,10.0,0.0,2.0,4.0
+8,place the burger meat in the oven,10.0,0.0,2.0,2.0
+9,make a cup of coffee with the keurig machine,10.0,0.0,2.0,4.0
+10,assemble one_leg,10.0,0.0,2.0,7.0
+11,get the cloth and wipe up the spill under the wine glass,8.0,669.0,2.0,3.0
+12,place dishes in the dish rack,10.0,0.0,2.0,4.0
+13,move redbull can near green can,3.0,3.0,2.0,1.0
+14,open the drawer,5.0,5.0,1.0,2.0
+15,sweep the green cloth to the left side of the table,5.0,0.0,2.0,3.0
diff --git a/examples/world_model_interaction_prompts/images/unitree_g1_pack_camera/0.png b/examples/world_model_interaction_prompts/images/unitree_g1_pack_camera/0.png
new file mode 100644
index 0000000..b295e63
Binary files /dev/null and b/examples/world_model_interaction_prompts/images/unitree_g1_pack_camera/0.png differ
diff --git a/examples/world_model_interaction_prompts/images/unitree_z1_dual_arm_cleanup_pencils/0.png b/examples/world_model_interaction_prompts/images/unitree_z1_dual_arm_cleanup_pencils/0.png
new file mode 100644
index 0000000..d919c6f
Binary files /dev/null and b/examples/world_model_interaction_prompts/images/unitree_z1_dual_arm_cleanup_pencils/0.png differ
diff --git a/examples/world_model_interaction_prompts/images/unitree_z1_dual_arm_stackbox/0.png b/examples/world_model_interaction_prompts/images/unitree_z1_dual_arm_stackbox/0.png
new file mode 100644
index 0000000..ebcc502
Binary files /dev/null and b/examples/world_model_interaction_prompts/images/unitree_z1_dual_arm_stackbox/0.png differ
diff --git a/examples/world_model_interaction_prompts/images/unitree_z1_dual_arm_stackbox_v2/0.png b/examples/world_model_interaction_prompts/images/unitree_z1_dual_arm_stackbox_v2/0.png
new file mode 100644
index 0000000..17008a2
Binary files /dev/null and b/examples/world_model_interaction_prompts/images/unitree_z1_dual_arm_stackbox_v2/0.png differ
diff --git a/examples/world_model_interaction_prompts/images/unitree_z1_stackbox/0.png b/examples/world_model_interaction_prompts/images/unitree_z1_stackbox/0.png
new file mode 100644
index 0000000..8ec19a0
Binary files /dev/null and b/examples/world_model_interaction_prompts/images/unitree_z1_stackbox/0.png differ
diff --git a/examples/world_model_interaction_prompts/transitions/unitree_g1_pack_camera/0.h5 b/examples/world_model_interaction_prompts/transitions/unitree_g1_pack_camera/0.h5
new file mode 100644
index 0000000..3505d37
Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_g1_pack_camera/0.h5 differ
diff --git a/examples/world_model_interaction_prompts/transitions/unitree_g1_pack_camera/meta_data/stats.safetensors b/examples/world_model_interaction_prompts/transitions/unitree_g1_pack_camera/meta_data/stats.safetensors
new file mode 100644
index 0000000..3dafbcb
Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_g1_pack_camera/meta_data/stats.safetensors differ
diff --git a/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_cleanup_pencils/0.h5 b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_cleanup_pencils/0.h5
new file mode 100644
index 0000000..7835773
Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_cleanup_pencils/0.h5 differ
diff --git a/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_cleanup_pencils/meta_data/stats.safetensors b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_cleanup_pencils/meta_data/stats.safetensors
new file mode 100644
index 0000000..e3194ab
Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_cleanup_pencils/meta_data/stats.safetensors differ
diff --git a/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox/0.h5 b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox/0.h5
new file mode 100755
index 0000000..28184eb
Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox/0.h5 differ
diff --git a/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox/meta_data/episode_data_index.safetensors b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox/meta_data/episode_data_index.safetensors
new file mode 100755
index 0000000..2bfc77f
Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox/meta_data/episode_data_index.safetensors differ
diff --git a/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox/meta_data/stats.safetensors b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox/meta_data/stats.safetensors
new file mode 100755
index 0000000..fa7fd40
Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox/meta_data/stats.safetensors differ
diff --git a/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox_v2/0.h5 b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox_v2/0.h5
new file mode 100644
index 0000000..292bcfa
Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox_v2/0.h5 differ
diff --git a/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox_v2/meta_data/stats.safetensors b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox_v2/meta_data/stats.safetensors
new file mode 100644
index 0000000..6ef7a6c
Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_z1_dual_arm_stackbox_v2/meta_data/stats.safetensors differ
diff --git a/examples/world_model_interaction_prompts/transitions/unitree_z1_stackbox/0.h5 b/examples/world_model_interaction_prompts/transitions/unitree_z1_stackbox/0.h5
new file mode 100755
index 0000000..9cdae1b
Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_z1_stackbox/0.h5 differ
diff --git a/examples/world_model_interaction_prompts/transitions/unitree_z1_stackbox/meta_data/episode_data_index.safetensors b/examples/world_model_interaction_prompts/transitions/unitree_z1_stackbox/meta_data/episode_data_index.safetensors
new file mode 100644
index 0000000..62bde44
Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_z1_stackbox/meta_data/episode_data_index.safetensors differ
diff --git a/examples/world_model_interaction_prompts/transitions/unitree_z1_stackbox/meta_data/stats.safetensors b/examples/world_model_interaction_prompts/transitions/unitree_z1_stackbox/meta_data/stats.safetensors
new file mode 100644
index 0000000..1918ea0
Binary files /dev/null and b/examples/world_model_interaction_prompts/transitions/unitree_z1_stackbox/meta_data/stats.safetensors differ
diff --git a/examples/world_model_interaction_prompts/unitree_g1_pack_camera.csv b/examples/world_model_interaction_prompts/unitree_g1_pack_camera.csv
new file mode 100644
index 0000000..0883835
--- /dev/null
+++ b/examples/world_model_interaction_prompts/unitree_g1_pack_camera.csv
@@ -0,0 +1,2 @@
+videoid,contentUrl,duration,data_dir,instruction,dynamic_confidence,dynamic_wording,dynamic_source_category,embodiment,fps
+0,x,x,unitree_g1_pack_camera,Pack black camera into box.,x,x,x,Unitree G1 Robot with Gripper,30
diff --git a/examples/world_model_interaction_prompts/unitree_z1_dual_arm_cleanup_pencils.csv b/examples/world_model_interaction_prompts/unitree_z1_dual_arm_cleanup_pencils.csv
new file mode 100644
index 0000000..ca39924
--- /dev/null
+++ b/examples/world_model_interaction_prompts/unitree_z1_dual_arm_cleanup_pencils.csv
@@ -0,0 +1,2 @@
+videoid,contentUrl,duration,data_dir,instruction,dynamic_confidence,dynamic_wording,dynamic_source_category,embodiment,fps
+0,x,x,unitree_z1_dual_arm_cleanup_pencils,clean up eraser and pencils,x,x,x,Unitree Z1 Robot Dual-Arm,30
diff --git a/examples/world_model_interaction_prompts/unitree_z1_dual_arm_stackbox.csv b/examples/world_model_interaction_prompts/unitree_z1_dual_arm_stackbox.csv
new file mode 100644
index 0000000..08b30a5
--- /dev/null
+++ b/examples/world_model_interaction_prompts/unitree_z1_dual_arm_stackbox.csv
@@ -0,0 +1,2 @@
+videoid,contentUrl,duration,data_dir,instruction,dynamic_confidence,dynamic_wording,dynamic_source_category,embodiment,fps
+0,x,x,unitree_z1_dual_arm_stackbox,"Stack the blocks in the rectangular block: red at the bottom, yellow in the middle, green on top.",x,x,x,Unitree Z1 Robot Dual-Arm,30
diff --git a/examples/world_model_interaction_prompts/unitree_z1_dual_arm_stackbox_v2.csv b/examples/world_model_interaction_prompts/unitree_z1_dual_arm_stackbox_v2.csv
new file mode 100644
index 0000000..2581ded
--- /dev/null
+++ b/examples/world_model_interaction_prompts/unitree_z1_dual_arm_stackbox_v2.csv
@@ -0,0 +1,2 @@
+videoid,contentUrl,duration,data_dir,instruction,dynamic_confidence,dynamic_wording,dynamic_source_category,embodiment,fps
+0,x,x,unitree_z1_dual_arm_stackbox_v2,"Stack the blocks in the rectangular block: red at the bottom, yellow in the middle, green on top",x,x,x,Unitree Z1 Robot Dual-Arm,30
diff --git a/examples/world_model_interaction_prompts/unitree_z1_stackbox.csv b/examples/world_model_interaction_prompts/unitree_z1_stackbox.csv
new file mode 100755
index 0000000..a4505c1
--- /dev/null
+++ b/examples/world_model_interaction_prompts/unitree_z1_stackbox.csv
@@ -0,0 +1,2 @@
+videoid,contentUrl,duration,data_dir,instruction,dynamic_confidence,dynamic_wording,dynamic_source_category,embodiment,fps
+0,x,x,unitree_z1_stackbox,"Stack the blocks in the rectangular block: red at the bottom, yellow in the middle, green on top.",x,x,x,Unitree Z1 Robot Arm,30
diff --git a/external/dlimp b/external/dlimp
new file mode 160000
index 0000000..5edaa46
--- /dev/null
+++ b/external/dlimp
@@ -0,0 +1 @@
+Subproject commit 5edaa4691567873d495633f2708982b42edf1972
diff --git a/model_architecture_analysis.md b/model_architecture_analysis.md
new file mode 100644
index 0000000..4114c2c
--- /dev/null
+++ b/model_architecture_analysis.md
@@ -0,0 +1,1204 @@
+# UnifoLM World Model Action - Detailed Model Architecture Analysis
+
+## Contents
+1. [Architecture Overview](#1-architecture-overview)
+2. [Inference Pipeline Analysis](#2-inference-pipeline-analysis)
+3. [Core Components](#3-core-components)
+4. [Performance Bottleneck Analysis](#4-performance-bottleneck-analysis)
+5. [Kernel Fusion Optimization Recommendations](#5-kernel-fusion-optimization-recommendations)
+
+---
+
+## 1. Architecture Overview
+
+### 1.1 Model Hierarchy
+
+```
+DDPM (top-level model)
+├── DiffusionWrapper (conditioning wrapper)
+│   └── UNet3D (core diffusion model)
+│       ├── Time Embedding
+│       ├── Downsampling Blocks
+│       ├── Middle Blocks
+│       └── Upsampling Blocks
+├── VAE (variational autoencoder)
+│   ├── Encoder
+│   └── Decoder
+├── CLIP Image Encoder
+├── Text Encoder
+├── State Projector
+└── Action Projector
+```
+
+### 1.2 Inference-Time Data Flow
+
+```
+Input observation
+  ↓
+[1] Condition encoding
+  ├── Image  → CLIP Encoder     → image embedding
+  ├── Image  → VAE Encoder      → latent condition
+  ├── Text   → Text Encoder     → text embedding
+  ├── State  → State Projector  → state embedding
+  └── Action → Action Projector → action embedding
+  ↓
+[2] DDIM sampling (n iterative steps)
+  ├── Initialize noise x_T
+  └── For step in [0, n]:
+        ├── Model forward pass (UNet3D)
+        │     ├── Timestep embedding
+        │     ├── Condition injection (CrossAttention)
+        │     └── Noise prediction
+        ├── DDIM update rule
+        └── x_{t-1} = f(x_t, noise_pred)
+  ↓
+[3] VAE decoding
+  └── Latent → VAE Decoder → video frames
+```
+
+---
+
+## 2. Inference Pipeline Analysis
+
+### 2.1 Phase 1: Action Generation (sim_mode=False)
+
+**Purpose**: generate an action sequence from the observation and the instruction.
+
+**Inputs**:
+- `observation.images.top`: past image observations [B, C, T_obs, H, W]
+- `observation.state`: past states [B, T_obs, state_dim]
+- `action`: past actions [B, T_action, action_dim]
+- `instruction`: text instruction
+
+**Outputs**:
+- `pred_videos`: predicted video [B, C, T, H, W]
+- `pred_actions`: predicted action sequence [B, T, action_dim]
+
+**Key characteristics**:
+- The action condition is zeroed out (`cond_action_emb = torch.zeros_like(...)`)
+- The text instruction is the main source of guidance
+
+### 2.2 Phase 2: World-Model Interaction (sim_mode=True)
+
+**Purpose**: predict future observations from the actions.
+
+**Inputs**:
+- `observation.images.top`: current image
+- `observation.state`: current state
+- `action`: the action sequence generated in Phase 1
+
+**Outputs**:
+- `pred_videos`: predicted future video frames
+- `pred_states`: predicted future states
+
+**Key characteristics**:
+- The text instruction is not used (`text_input=False`)
+- The action condition is actually used
+
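+A schematic sketch of how the two phases are chained (based on the `image_guided_synthesis_sim_mode` call quoted in Section 5.4.3; the exact hand-off of the Phase-1 actions into the Phase-2 observation is an assumption and lives in `scripts/evaluation/world_model_interaction.py`):
+
+```python
+# Sketch only; argument names mirror the evaluation-script call quoted later in this document.
+with torch.no_grad():
+    # Phase 1: actions from observation + instruction (action condition zeroed internally)
+    pred_videos, pred_actions, pred_states = image_guided_synthesis_sim_mode(
+        model, sample['instruction'], observation, noise_shape,
+        action_cond_step=args.exe_steps, ddim_steps=args.ddim_steps,
+        ddim_eta=args.ddim_eta, sim_mode=False)
+
+    # Phase 2: roll the generated actions through the world model (text not used)
+    observation['action'] = pred_actions  # assumed hand-off of the Phase-1 actions
+    pred_videos, _, pred_states = image_guided_synthesis_sim_mode(
+        model, sample['instruction'], observation, noise_shape,
+        action_cond_step=args.exe_steps, ddim_steps=args.ddim_steps,
+        ddim_eta=args.ddim_eta, sim_mode=True)
+```
+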
+---
+
+## 3. Core Components
+
+### 3.1 DDIM Sampler (DDIMSampler)
+
+**Code location**: [src/unifolm_wma/models/samplers/ddim.py](src/unifolm_wma/models/samplers/ddim.py)
+
+**Core method**: `ddim_sampling()` (lines 168-300)
+
+**Actual code structure**:
+```python
+def ddim_sampling(self, cond, shape, x_T=None, ddim_steps=50, ...):
+    # Initialization
+    timesteps = self.ddim_timesteps[:ddim_steps]
+    x = x_T if x_T is not None else torch.randn(shape, device=device)
+
+    # Main loop
+    for i, step in enumerate(iterator):
+        # Current timestep
+        index = total_steps - i - 1
+        ts = torch.full((b,), step, device=device, dtype=torch.long)
+
+        # Model forward pass (the core bottleneck)
+        outs = self.p_sample_ddim(x, cond, ts, index=index, ...)
+        x, pred_x0 = outs
+
+    return x
+```
+
+**Performance data** (from profiling):
+- Single-step denoising total time: 10.71 s - 11.06 s (22 calls)
+- Model forward passes: 325.30 s (660 calls, 0.493 s/call on average)
+- DDIM updates: 0.21 s (660 calls, 0.0003 s/call on average)
+
+### 3.2 DiffusionWrapper (Condition Router)
+
+**Code location**: [src/unifolm_wma/models/ddpms.py:2413-2524](src/unifolm_wma/models/ddpms.py)
+
+**Role**: routes the inputs and conditions into the inner diffusion model.
+
+**Actual code** (lines 2469-2479):
+```python
+elif self.conditioning_key == 'hybrid':
+    xc = torch.cat([x] + c_concat, dim=1)  # concatenate latent conditions along channels
+    cc = torch.cat(c_crossattn, 1)         # concatenate cross-attention conditions
+    cc_action = c_crossattn_action
+    out = self.diffusion_model(xc, x_action, x_state, t,
+                               context=cc, context_action=cc_action, **kwargs)
+```
+
+**Condition types**:
+1. **c_concat**: channel-concatenated conditions (VAE-encoded images)
+2. **c_crossattn**: cross-attention conditions (text, image, state, and action embeddings)
+3. **c_crossattn_action**: conditions dedicated to the action head
+
+### 3.3 WMAModel (Core Diffusion Model)
+
+**Code location**: [src/unifolm_wma/modules/networks/wma_model.py:326-849](src/unifolm_wma/modules/networks/wma_model.py)
+
+**Config file**: [configs/inference/world_model_interaction.yaml:69-104](configs/inference/world_model_interaction.yaml)
+
+**Actual configuration parameters**:
+```yaml
+in_channels: 8                    # input channels (4 latent + 4 VAE condition)
+out_channels: 4                   # output channels
+model_channels: 320               # base channel count
+channel_mult: [1, 2, 4, 4]        # channel multipliers: [320, 640, 1280, 1280]
+num_res_blocks: 2                 # 2 ResBlocks per resolution
+attention_resolutions: [4, 2, 1]  # resolutions at which attention is enabled
+num_head_channels: 64             # 64 channels per attention head
+transformer_depth: 1              # transformer depth
+context_dim: 1024                 # cross-attention context dimension
+temporal_length: 16               # temporal sequence length
+```
+
+**Architecture hierarchy** (see Appendix A.1 for details):
+- 4 downsampling stages (2 ResBlocks + attention per stage)
+- 1 middle block (2 ResBlocks + attention)
+- 3 upsampling stages (2 ResBlocks + attention per stage)
+- Total: 16 ResBlocks + 32 Transformers
+
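+As a quick worked example, the per-stage channel widths follow directly from the `model_channels` and `channel_mult` values quoted above (a small self-contained check, not repository code):
+
+```python
+# Derive the UNet stage widths from the config values quoted above.
+model_channels = 320
+channel_mult = [1, 2, 4, 4]
+
+stage_channels = [model_channels * m for m in channel_mult]
+print(stage_channels)  # [320, 640, 1280, 1280]
+
+# attention_resolutions [4, 2, 1] lists the downsampling factors (1x, 2x, 4x)
+# at which spatial/temporal attention blocks are inserted.
+```
+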
+### 3.4 VAE (Variational Autoencoder)
+
+**Code location**: [src/unifolm_wma/models/autoencoder.py](src/unifolm_wma/models/autoencoder.py)
+
+**Config file**: [configs/inference/world_model_interaction.yaml:159-180](configs/inference/world_model_interaction.yaml)
+
+**Actual configuration parameters**:
+```yaml
+AutoencoderKL:
+  embed_dim: 4             # latent dimension
+  z_channels: 4            # latent channels
+  in_channels: 3           # RGB input
+  out_ch: 3                # RGB output
+  ch: 128                  # base channel count
+  ch_mult: [1, 2, 4, 4]    # channel multipliers: [128, 256, 512, 512]
+  num_res_blocks: 2        # 2 ResBlocks per level
+  attn_resolutions: []     # no attention in the VAE
+```
+
+**Encode/decode process**:
+```python
+# Encode: [B, 3, 320, 512] → [B, 4, 40, 64] (8×8 downsampling)
+z = model.encode_first_stage(img)
+
+# Decode: [B, 4, 40, 64] → [B, 3, 320, 512]
+video = model.decode_first_stage(samples)
+```
+
+**Performance data**:
+- VAE encoding: 1.03 s (22 calls, 0.047 s/call on average)
+- VAE decoding: 15.53 s (22 calls, 0.706 s/call on average)
+- Compression ratio: 8×8 = 64× spatial compression
+
+**Detailed architecture**: see Appendix A.4
+
+### 3.5 Condition Encoders
+
+#### 3.5.1 CLIP Image Encoder
+
+**Code location**: [src/unifolm_wma/modules/encoders/condition.py](src/unifolm_wma/modules/encoders/condition.py) - `FrozenOpenCLIPImageEmbedderV2`
+
+**Config file**: [configs/inference/world_model_interaction.yaml:188-204](configs/inference/world_model_interaction.yaml)
+
+**Actual configuration**:
+```yaml
+FrozenOpenCLIPImageEmbedderV2:
+  freeze: true
+  # uses OpenCLIP ViT-H/14
+  # output: [B, 1280]
+
+Resampler (image projector):
+  dim: 1024            # output dimension
+  depth: 4             # transformer depth
+  heads: 12            # 12 attention heads
+  num_queries: 16      # 16 query tokens
+  embedding_dim: 1280  # CLIP output dimension
+```
+
+**Data flow**: image [B, 3, H, W] → CLIP → [B, 1280] → Resampler → [B, 16, 1024]
+
+**Performance**: 0.71 s (22 calls, 0.032 s/call on average)
+
+#### 3.5.2 Text Encoder
+
+**Code location**: [src/unifolm_wma/modules/encoders/condition.py](src/unifolm_wma/modules/encoders/condition.py) - `FrozenOpenCLIPEmbedder`
+
+**Config file**: [configs/inference/world_model_interaction.yaml:182-186](configs/inference/world_model_interaction.yaml)
+
+**Actual configuration**:
+```yaml
+FrozenOpenCLIPEmbedder:
+  freeze: True
+  layer: "penultimate"  # use the second-to-last layer
+  # output: [B, seq_len, 1024]
+```
+
+**Performance**: 0.13 s (22 calls, 0.006 s/call on average)
+
+#### 3.5.3 State Projector
+
+**Code location**: [src/unifolm_wma/models/ddpms.py:2014-2026](src/unifolm_wma/models/ddpms.py) - `MLPProjector`
+
+**MLPProjector implementation** (src/unifolm_wma/utils/projector.py:14-37):
+```python
+class MLPProjector(nn.Module):
+    def __init__(self, input_dim: int, output_dim: int, mlp_type: str = "gelu-mlp"):
+        if mlp_type == "gelu-mlp":
+            self.projector = nn.Sequential(
+                nn.Linear(input_dim, output_dim, bias=True),
+                nn.GELU(approximate='tanh'),
+                nn.Linear(output_dim, output_dim, bias=True),
+            )
+```
+
+**Data flow**: state [B, T_obs, 16] → MLPProjector → [B, T_obs, 1024] + agent_state_pos_emb
+
+**Performance**: 0.006 s (22 calls, 0.0003 s/call on average)
+
+#### 3.5.4 Action Projector
+
+**Code location**: [src/unifolm_wma/models/ddpms.py:2020-2024](src/unifolm_wma/models/ddpms.py) - `MLPProjector`
+
+**Data flow**: action [B, T_action, 16] → MLPProjector → [B, T_action, 1024] + agent_action_pos_emb
+
+**Positional embedding definitions**:
+```python
+# ddpms.py:2023-2026
+self.agent_action_pos_emb = nn.Parameter(torch.randn(1, 16, 1024))
+self.agent_state_pos_emb = nn.Parameter(torch.randn(1, n_obs_steps, 1024))
+```
+
+**Performance**: 0.003 s (22 calls, 0.0001 s/call on average)
+
+---
+
+## 4. Performance Bottleneck Analysis
+
+### 4.1 Time Breakdown (412.39 s total)
+
+According to the profiling data, the time is distributed as follows:
+
+| Stage | Total time | Share | Notes |
+|------|--------|------|------|
+| Phase 1: action generation | 171.52 s | 41.6% | 30 DDIM sampling steps |
+| Phase 2: world-model interaction | 171.65 s | 41.6% | 30 DDIM sampling steps |
+| Model loading | 47.56 s | 11.5% | one-time cost |
+| Saving videos | 13.91 s | 3.4% | I/O |
+| Saving full videos | 7.22 s | 1.8% | I/O |
+| Dataset loading | 0.51 s | 0.1% | one-time cost |
+
+### 4.2 Detailed DDIM Sampling Analysis
+
+**DDIM sampling is the dominant bottleneck: 94.9% of the time spent in the two inference phases (325.74 s of 343.17 s).**
+
+```
+Total DDIM sampling time: 325.74 s
+├── Model forward passes:     325.30 s (99.86%) ← core bottleneck
+├── DDIM update rule:           0.21 s (0.06%)
+└── Action/state scheduling:    0.13 s (0.04%)
+```
+
+**Per-step timing**:
+- 30 denoising steps, 10.86 s per step on average
+- Each step invokes the model forward pass twice (once each for Phase 1 and Phase 2)
+- Each forward pass: ~0.493 s
+
+### 4.3 Bottleneck Summary
+
+**Key findings**:
+1. **The model forward pass accounts for 99.86% of DDIM time** - this is the core optimization target
+2. VAE decoding takes about 4.5% of the inference-phase time - a secondary optimization target
+3. Everything else (condition encoding, DDIM updates) is negligible
+
+---
+
+## 5. Kernel Fusion Optimization Recommendations
+
+### 5.1 Strategy Overview
+
+Based on the profiling results, optimization should focus on:
+1. **The UNet3D model forward pass** (highest priority)
+2. **The VAE decoder** (secondary priority)
+3. **Batching and parallelization** (supporting optimizations)
+
+### 5.2 Kernel Fusion Opportunities in WMAModel
+
+#### 5.2.1 Timestep Embedding Fusion
+
+**Code location**: [src/unifolm_wma/utils/diffusion.py](src/unifolm_wma/utils/diffusion.py) - `timestep_embedding()`
+
+**Current implementation** (actual code):
+```python
+# 1. Sinusoidal position encoding
+def timestep_embedding(timesteps, dim, max_period=10000):
+    half = dim // 2
+    freqs = torch.exp(-math.log(max_period) * torch.arange(0, half) / half)
+    args = timesteps[:, None].float() * freqs[None]
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+    return embedding
+
+# 2. Time embedding network (in WMAModel.__init__)
+self.time_embed = nn.Sequential(
+    nn.Linear(model_channels, time_embed_dim),   # 320 → 1280
+    nn.SiLU(),
+    nn.Linear(time_embed_dim, time_embed_dim),   # 1280 → 1280
+)
+```
+
+**Fusion opportunities**:
+- `Linear + SiLU + Linear` can be fused into a single kernel
+- The sinusoidal encoding can be fused with the first Linear
+
+**Expected gain**: saves 2-3 kernel launches
+
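+A minimal, self-contained sketch of one way to obtain this fusion without touching the model code: wrap the encoding and the two `Linear` layers in a single function and let `torch.compile` fuse the elementwise work (dimensions follow the config quoted above; this is illustrative, not the repository's implementation):
+
+```python
+import math
+import torch
+import torch.nn as nn
+
+model_channels, time_embed_dim = 320, 1280
+time_embed = nn.Sequential(
+    nn.Linear(model_channels, time_embed_dim),
+    nn.SiLU(),
+    nn.Linear(time_embed_dim, time_embed_dim),
+)
+
+def timestep_embed_fused(timesteps: torch.Tensor) -> torch.Tensor:
+    # Sinusoidal encoding followed by the MLP, expressed as one compilable function.
+    half = model_channels // 2
+    freqs = torch.exp(-math.log(10000) * torch.arange(half, dtype=torch.float32) / half)
+    args = timesteps[:, None].float() * freqs[None]
+    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+    return time_embed(emb)
+
+timestep_embed_fused = torch.compile(timestep_embed_fused)   # fusion happens on first call
+print(timestep_embed_fused(torch.tensor([10, 500])).shape)   # torch.Size([2, 1280])
+```
+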
+#### 5.2.2 ResBlock Kernel Fusion
+
+**Code location**: [src/unifolm_wma/modules/networks/wma_model.py:130-263](src/unifolm_wma/modules/networks/wma_model.py) - `class ResBlock`
+
+**Current implementation** (actual code):
+```python
+# in_layers: GroupNorm + SiLU + Conv
+self.in_layers = nn.Sequential(
+    normalization(channels),  # GroupNorm
+    nn.SiLU(),
+    conv_nd(dims, channels, out_channels, 3, padding=1)
+)
+
+# emb_layers: SiLU + Linear
+self.emb_layers = nn.Sequential(
+    nn.SiLU(),
+    nn.Linear(emb_channels, out_channels)
+)
+
+# out_layers: GroupNorm + SiLU + Dropout + Conv
+self.out_layers = nn.Sequential(
+    normalization(out_channels),  # GroupNorm
+    nn.SiLU(),
+    nn.Dropout(p=dropout),
+    zero_module(nn.Conv2d(out_channels, out_channels, 3, padding=1))
+)
+```
+
+**Fusion opportunities**:
+1. `GroupNorm + SiLU` can be fused (once in in_layers and once in out_layers)
+2. The `SiLU + Linear` in `emb_layers` can be fused
+3. The residual addition can be fused with the next layer's GroupNorm
+
+**Actual bottleneck**: 16 ResBlocks × 30 steps × 2 passes = **960 ResBlock calls**
+
+**Expected gain**: 50-60% less kernel-launch overhead per ResBlock
+
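+A minimal sketch of the `GroupNorm + SiLU` fusion (illustrative, not the repository's code): express the pair as one module and compile it, so the normalization epilogue and the activation run in fewer kernels.
+
+```python
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class GroupNormSiLU(nn.Module):
+    """GroupNorm followed by SiLU as a single compilable unit."""
+
+    def __init__(self, num_groups: int, num_channels: int):
+        super().__init__()
+        self.norm = nn.GroupNorm(num_groups, num_channels)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return F.silu(self.norm(x))
+
+block = torch.compile(GroupNormSiLU(num_groups=32, num_channels=320))
+y = block(torch.randn(1, 320, 40, 64))  # latent-sized feature map
+print(y.shape)                          # torch.Size([1, 320, 40, 64])
+```
+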
+#### 5.2.3 Attention Optimization
+
+**Code location**: [src/unifolm_wma/modules/attention.py](src/unifolm_wma/modules/attention.py)
+
+**Actual configuration**:
+- SpatialTransformer: attention over the spatial dimensions
+- TemporalTransformer: attention over the temporal dimension
+- Total: 32 Transformers × 30 steps × 2 passes = **1920 attention calls**
+
+**Optimization**:
+Use PyTorch's built-in Flash Attention:
+```python
+from torch.nn.functional import scaled_dot_product_attention
+
+# replace the standard attention computation
+out = scaled_dot_product_attention(Q, K, V, is_causal=False)
+```
+
+**Expected gain**: 2-3× faster attention layers, 30-40% faster overall
+
+### 5.3 VAE Decoder Optimization
+
+**Code location**: [src/unifolm_wma/models/autoencoder.py](src/unifolm_wma/models/autoencoder.py)
+
+**Current performance**: 15.53 s (22 calls, 0.706 s/call on average)
+
+**Optimizations**:
+1. **Mixed precision**: decode in FP16
+   ```python
+   with torch.cuda.amp.autocast():
+       video = vae.decode(latent)
+   ```
+
+2. **Batching**: make sure the VAE decodes frames in batches rather than one by one
+
+**Expected gain**: 20-30% speedup
+
+### 5.4 Implementation Recommendations
+
+#### 5.4.1 Use torch.compile() (simplest)
+
+**Code location**: [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py)
+
+**Where to apply it**:
+Add after the model is loaded:
+
+```python
+# add after the model has been loaded and moved to the GPU
+config = OmegaConf.load(args.config)
+model = instantiate_from_config(config.model)
+model = load_model_checkpoint(model, args.ckpt_path)
+model.eval()
+model = model.cuda()
+
+# enable torch.compile() optimization
+model.model.diffusion_model = torch.compile(
+    model.model.diffusion_model,
+    mode='max-autotune',  # or 'reduce-overhead'
+    fullgraph=True
+)
+```
+
+**Advantages**:
+- No model-code changes required
+- Operations are fused automatically
+- Dynamic shapes are supported
+
+**Expected gain**: 20-40% speedup
+
+#### 5.4.2 Use Flash Attention
+
+**Code location**: [src/unifolm_wma/modules/attention.py](src/unifolm_wma/modules/attention.py)
+
+**Current implementation**:
+The code already supports xformers (`xformers.ops.memory_efficient_attention`). When xformers is available, the `CrossAttention` class automatically switches to its `efficient_forward` method:
+
+```python
+# attention.py:90-91
+if XFORMERS_IS_AVAILBLE and temporal_length is None:
+    self.forward = self.efficient_forward
+```
+
+**Further optimization**:
+If xformers is not available, PyTorch's built-in Flash Attention can be used instead:
+```python
+from torch.nn.functional import scaled_dot_product_attention
+out = scaled_dot_product_attention(q, k, v, is_causal=False)
+```
+
+**Expected gain**: 2-3× faster attention layers
+
+
+#### 5.4.3 Mixed-Precision Inference
+
+**Code location**: [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py)
+
+**Where to apply it**:
+Wrap the inference call in an autocast context:
+
+```python
+# add around the image_guided_synthesis_sim_mode call
+with torch.cuda.amp.autocast():
+    pred_videos, pred_actions, pred_states = image_guided_synthesis_sim_mode(
+        model, sample['instruction'], observation, noise_shape,
+        action_cond_step=args.exe_steps,
+        ddim_steps=args.ddim_steps, ddim_eta=args.ddim_eta,
+        unconditional_guidance_scale=args.unconditional_guidance_scale,
+        fs=model_input_fs, timestep_spacing=args.timestep_spacing,
+        guidance_rescale=args.guidance_rescale, sim_mode=False)
+```
+
+**Notes**:
+- The model switches between FP16 and FP32 automatically
+- Some operations (e.g. LayerNorm) are automatically kept in FP32
+- No manual conversion of model weights is needed
+
+**Expected gain**: 30-50% speedup + ~50% less GPU memory
+
+### 5.5 Optimization Roadmap
+
+#### Stage 1: Quick Wins
+
+**Goal**: 20-40% speedup without touching the model code
+
+**Steps**:
+1. Enable `torch.compile()` - add after the model is loaded
+2. Enable `torch.backends.cudnn.benchmark = True` - set before inference starts
+3. Use mixed-precision inference (FP16) - wrap the inference call
+
+**Implementation code**:
+```python
+# Add at the start of the inference function
+torch.backends.cudnn.benchmark = True
+
+# Add torch.compile() after the model is loaded
+model = model.cuda()
+model.model.diffusion_model = torch.compile(
+    model.model.diffusion_model,
+    mode='max-autotune'
+)
+
+# Use mixed precision in the inference loop
+with torch.cuda.amp.autocast():
+    pred_videos, pred_actions, pred_states = image_guided_synthesis_sim_mode(...)
+```
+
+#### Stage 2: Intermediate Optimizations
+
+**Goal**: 50-70% speedup
+
+**Steps**:
+1. Make sure xformers is installed and enabled - check the `XFORMERS_IS_AVAILBLE` flag
+2. Optimize VAE decoder batching - check the `decode()` method in [src/unifolm_wma/models/autoencoder.py](src/unifolm_wma/models/autoencoder.py)
+3. Profile and optimize memory-access patterns - analyze with `torch.cuda.memory_stats()`
+
+**Key touch points** (a quick availability check is sketched below):
+- Confirm xformers is installed correctly: `pip install xformers`
+- In the `CrossAttention` class, `efficient_forward` is used automatically when xformers is available
+- Make sure the VAE decodes in batches rather than frame by frame
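+
+A quick availability check (a standalone sketch, not repository code):
+
+```python
+import torch
+
+try:
+    import xformers.ops as xops
+    q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
+    out = xops.memory_efficient_attention(q, q, q)   # self-attention smoke test
+    print("xformers OK:", out.shape)
+except Exception as err:
+    print("xformers unavailable, fall back to PyTorch SDPA:", err)
+```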
+
+#### Stage 3: Deep Optimizations
+
+**Goal**: 2-3x speedup
+
+**Steps**:
+1. Fuse key operations with custom CUDA kernels
+   - Fuse GroupNorm + SiLU + Conv (in the ResBlock at [src/unifolm_wma/modules/networks/wma_model.py:130-263](src/unifolm_wma/modules/networks/wma_model.py))
+2. Optimize the convolutions
+   - Profile the Conv2D operations (the model actually uses Conv2D, not Conv3D)
+3. Optimize the data-loading and preprocessing pipeline
+   - Check the data-preparation part of [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py)
+
+**Required skills**:
+- CUDA programming
+- PyTorch C++ extensions
+- Profiling tools (Nsight Systems, nvprof)
+
+
+### 5.6 Expected Overall Gains
+
+Based on the optimization strategies above and the actual code analysis, the expected speedups are:
+
+| Optimization stage | Expected speedup | Difficulty | Main files touched |
+|---------|-----------|---------|-------------|
+| Stage 1: quick wins | 1.2-1.4x | Low | [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py) |
+| Stage 2: intermediate | 1.5-1.7x | Medium | [src/unifolm_wma/modules/attention.py](src/unifolm_wma/modules/attention.py), [src/unifolm_wma/models/autoencoder.py](src/unifolm_wma/models/autoencoder.py) |
+| Stage 3: deep | 2.0-3.0x | High | [src/unifolm_wma/modules/networks/wma_model.py](src/unifolm_wma/modules/networks/wma_model.py) + custom CUDA kernels |
+
+**Overall target**: 2-3x end-to-end speedup through systematic optimization
+
+---
+
+## 6. Key Code Location Index
+
+For kernel-fusion work, the key code locations are:
+
+### 6.1 Core Model Files
+
+| Component | File path | Key class/function |
+|------|---------|-----------|
+| DDPM main model | `src/unifolm_wma/models/ddpms.py` | `class DDPM` |
+| Conditioning wrapper | `src/unifolm_wma/models/ddpms.py:2413` | `class DiffusionWrapper` |
+| DDIM sampler | `src/unifolm_wma/models/samplers/ddim.py` | `class DDIMSampler` |
+| VAE encode/decode | `src/unifolm_wma/models/autoencoder.py` | `encode_first_stage`, `decode_first_stage` |
+
+### 6.2 Inference Scripts
+
+| File | Description |
+|------|------|
+| `scripts/evaluation/world_model_interaction.py` | Inference script |
+
+### 6.3 Configuration Files
+
+| File | Description |
+|------|------|
+| `configs/inference/world_model_interaction.yaml` | Inference config |
+| `unitree_g1_pack_camera/case1/run_world_model_interaction.sh` | Launch script |
+
+
+---
+
+## 7. Recommended Next Steps
+
+### 7.1 Optimizations That Can Be Applied Immediately
+
+**Smallest change, largest gain**:
+
+1. **Enable torch.compile()**
+   - **Code location**: [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py)
+   - **Change**: add after the model is loaded
+   ```python
+   # Add after the model is loaded and moved to the GPU
+   model.model.diffusion_model = torch.compile(
+       model.model.diffusion_model,
+       mode='max-autotune'
+   )
+   ```
+
+2. **Enable cuDNN benchmark**
+   - **Code location**: [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py)
+   - **Change**: add at the start of the inference function
+   ```python
+   torch.backends.cudnn.benchmark = True
+   ```
+
+3. **Mixed-precision inference**
+   - **Code location**: [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py)
+   - **Change**: wrap the `image_guided_synthesis_sim_mode` call
+   ```python
+   with torch.cuda.amp.autocast():
+       pred_videos, pred_actions, pred_states = image_guided_synthesis_sim_mode(...)
+   ```
+
+**Expected gain**: 20-40% speedup, essentially risk-free
+
+
+### 7.2 Areas That Need Deeper Investigation
+
+For more precise optimization, the following should be analyzed further (a profiling sketch follows this list):
+
+1. **The concrete attention implementation**
+   - **Code location**: [src/unifolm_wma/modules/attention.py](src/unifolm_wma/modules/attention.py)
+   - **What to look at**:
+     - The `CrossAttention` class (lines 48-398) - core attention implementation
+     - The `BasicTransformerBlock` class (lines 400-469) - Transformer block
+     - Whether xformers is enabled (the `XFORMERS_IS_AVAILBLE` flag)
+
+2. **The detailed ResBlock structure**
+   - **Code location**: [src/unifolm_wma/modules/networks/wma_model.py:130-263](src/unifolm_wma/modules/networks/wma_model.py)
+   - **What to look at**:
+     - Confirm the GroupNorm + SiLU + Conv call order
+     - Identify operation sequences that can be fused
+     - Assess the feasibility of a custom CUDA kernel
+
+3. **Memory-bottleneck analysis**
+   - **Tools**: `torch.cuda.memory_stats()` and `torch.profiler`
+   - **Where**: [scripts/evaluation/world_model_interaction.py](scripts/evaluation/world_model_interaction.py)
+   - **Goals**:
+     - Identify memory-copy hotspots
+     - Optimize the lifetime of intermediate tensors
+     - Reduce unnecessary allocations
+
+4. **Locating compute bottlenecks**
+   - **Tools**: Nsight Systems or the PyTorch Profiler
+   - **Goals**:
+     - Measure kernel-launch overhead
+     - Analyze GPU utilization
+     - Find compute-heavy operations
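+
+A minimal profiling sketch for items 3 and 4 (the `model` call signature and the inputs are assumptions; adapt them to the actual UNet interface):
+
+```python
+import torch
+from torch.profiler import profile, ProfilerActivity
+
+def profile_one_step(model, x, t, context):
+    """Profile a single denoising forward pass to expose hot kernels and launch overhead."""
+    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+                 record_shapes=True, profile_memory=True) as prof:
+        with torch.no_grad():
+            model(x, t, context=context)
+    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
+    print(torch.cuda.max_memory_allocated() / 2**20, "MiB peak allocation")
+```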
+
+---
+
+## 8. References
+
+### 8.1 Optimization Documentation
+
+- [PyTorch 2.0 torch.compile()](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html)
+- [Flash Attention](https://github.com/Dao-AILab/flash-attention)
+- [CUDA Kernel Fusion](https://developer.nvidia.com/blog/cuda-pro-tip-increase-performance-with-vectorized-memory-access/)
+
+### 8.2 Related Papers
+
+- DDIM: Denoising Diffusion Implicit Models
+- Flash Attention: Fast and Memory-Efficient Exact Attention
+- Efficient Diffusion Models for Vision
+
+---
+
+## 9. Summary
+
+### 9.1 Key Findings
+
+1. **The model forward pass accounts for 99.86% of DDIM sampling time** - this is the absolute core of any optimization
+2. **The 30 DDIM steps account for 83% of total time** - reducing the step count or speeding up a single step is key
+3. **VAE decoding accounts for 4.5%** - a secondary target
+
+### 9.2 Optimization Priorities
+
+**High priority** (apply immediately):
+- ✅ torch.compile()
+- ✅ cuDNN benchmark
+- ✅ Mixed-precision inference
+
+**Medium priority** (within a week):
+- Flash Attention integration
+- VAE batching optimization
+
+**Low priority** (longer term):
+- Custom CUDA kernels
+- Model-architecture changes
+
+### 9.3 Expected Outcome
+
+With systematic optimization, inference time is expected to drop from **412s to 140-200s**, a **2-3x speedup**.
+
+---
+
+**Document version**: v1.1
+**Created**: 2026-01-17
+**Last updated**: 2026-01-17
+**Changes**: file paths, line numbers, component locations, and implementation details verified and corrected against the actual code
+
+
+---
+
+## Appendix A: Actual Model Architecture in Detail
+
+The implementation details below were taken from the code.
+
+### A.1 Actual WMAModel Configuration
+
+**Config file**: `configs/inference/world_model_interaction.yaml:69-104`
+
+```yaml
+WMAModel parameters:
+  in_channels: 8                    # input channels (4 latent + 4 concat condition)
+  out_channels: 4                   # output channels (latent space)
+  model_channels: 320               # base channel count
+  channel_mult: [1, 2, 4, 4]        # channel multipliers: [320, 640, 1280, 1280]
+  num_res_blocks: 2                 # 2 ResBlocks per resolution
+  attention_resolutions: [4, 2, 1]  # resolutions at which attention is enabled
+  num_head_channels: 64             # 64 channels per attention head
+  transformer_depth: 1              # Transformer depth
+  context_dim: 1024                 # cross-attention context dimension
+  temporal_length: 16               # temporal sequence length
+  dropout: 0.1
+```
+
+**Architecture overview**:
+```
+Input: [B, 8, 16, 40, 64]  (8 channels = 4 latent + 4 VAE condition)
+  ↓
+Downsampling path (4 stages):
+  Stage 0: [B, 320, 16, 40, 64]  - 2 ResBlocks + SpatialTransformer + TemporalTransformer
+  Stage 1: [B, 640, 16, 20, 32]  - Downsample + 2 ResBlocks + attention
+  Stage 2: [B, 1280, 16, 10, 16] - Downsample + 2 ResBlocks + attention
+  Stage 3: [B, 1280, 16, 5, 8]   - Downsample + 2 ResBlocks + attention
+  ↓
+Middle block: [B, 1280, 16, 5, 8]
+  - ResBlock + SpatialTransformer + TemporalTransformer + ResBlock
+  ↓
+Upsampling path (3 stages):
+  Stage 2: [B, 1280, 16, 10, 16] - Upsample + 2 ResBlocks + attention
+  Stage 1: [B, 640, 16, 20, 32]  - Upsample + 2 ResBlocks + attention
+  Stage 0: [B, 320, 16, 40, 64]  - Upsample + 2 ResBlocks + attention
+  ↓
+Output: [B, 4, 16, 40, 64]  (predicted noise or velocity)
+```
+
+
+### A.2 Actual ResBlock Implementation
+
+**Location**: `src/unifolm_wma/modules/networks/wma_model.py:130-263`
+
+**Actual code structure**:
+```python
+class ResBlock:
+    def __init__(self, channels, emb_channels, dropout, ...):
+        # Input layers: GroupNorm + SiLU + Conv
+        self.in_layers = nn.Sequential(
+            normalization(channels),      # GroupNorm
+            nn.SiLU(),                    # activation
+            conv_nd(dims, channels, out_channels, 3, padding=1)
+        )
+
+        # Timestep-embedding layers: SiLU + Linear
+        self.emb_layers = nn.Sequential(
+            nn.SiLU(),
+            nn.Linear(emb_channels, out_channels)
+        )
+
+        # Output layers: GroupNorm + SiLU + Dropout + Conv
+        self.out_layers = nn.Sequential(
+            normalization(out_channels),  # GroupNorm
+            nn.SiLU(),
+            nn.Dropout(p=dropout),
+            zero_module(nn.Conv2d(out_channels, out_channels, 3, padding=1))
+        )
+
+        # Residual connection
+        self.skip_connection = ...
+
+        # Optional temporal convolution
+        if use_temporal_conv:
+            self.temporal_conv = TemporalConvBlock(...)
+
+    def forward(self, x, emb):
+        h = self.in_layers(x)             # GroupNorm + SiLU + Conv
+        emb_out = self.emb_layers(emb)    # timestep embedding
+        h = h + emb_out                   # inject timestep information
+        h = self.out_layers(h)            # GroupNorm + SiLU + Dropout + Conv
+        h = self.skip_connection(x) + h   # residual connection
+
+        if use_temporal_conv:
+            h = self.temporal_conv(h)     # convolution along time
+        return h
+```
+
+**Kernel-fusion opportunities**:
+1. `GroupNorm + SiLU` can be fused (once each in in_layers and out_layers)
+2. `emb_layers` (`SiLU + Linear`) can be fused
+3. The residual addition can be fused with the next layer's GroupNorm
+
+
+### A.3 Actual Attention Implementation
+
+**SpatialTransformer** (spatial attention):
+- Location: `src/unifolm_wma/modules/attention.py:472-558`
+- Runs self-attention and cross-attention over the spatial dimensions (H×W) of the feature map
+- Uses `transformer_depth=1`, i.e. one Transformer layer per position
+- When xformers is available, uses the `efficient_forward` method for efficient attention
+
+**TemporalTransformer** (temporal attention):
+- Location: `src/unifolm_wma/modules/attention.py:561-680`
+- Runs self-attention over the temporal dimension (T=16 frames)
+- Configuration: `temporal_selfatt_only=True` (temporal self-attention only, no cross-attention)
+- Relative position encoding: `use_relative_position=False` (not actually enabled)
+
+**CrossAttention** (core attention layer):
+- Location: `src/unifolm_wma/modules/attention.py:48-398`
+- Supports several cross-attention conditions: image, text, state, action
+- Automatically uses `xformers.ops.memory_efficient_attention` when xformers is available
+
+**Attention-head configuration**:
+- `num_head_channels=64`: 64 channels per head
+- For 320 channels: 320/64 = 5 heads
+- For 640 channels: 640/64 = 10 heads
+- For 1280 channels: 1280/64 = 20 heads
+
+
+### A.4 Actual VAE Configuration
+
+**Config file**: `configs/inference/world_model_interaction.yaml:159-180`
+
+```yaml
+AutoencoderKL:
+  embed_dim: 4            # latent dimension
+  z_channels: 4           # latent channels
+  resolution: 256         # base resolution
+  in_channels: 3          # RGB input
+  out_ch: 3               # RGB output
+  ch: 128                 # base channel count
+  ch_mult: [1, 2, 4, 4]   # channel multipliers
+  num_res_blocks: 2       # 2 ResBlocks per level
+  attn_resolutions: []    # no attention in the VAE
+  dropout: 0.0
+```
+
+**Encoder architecture**:
+```
+Input: [B, 3, 320, 512]
+  ↓ Conv 3→128
+  ↓ ResBlock×2   [128, 320, 512]
+  ↓ Downsample   [128, 160, 256]
+  ↓ ResBlock×2   [256, 160, 256]
+  ↓ Downsample   [256, 80, 128]
+  ↓ ResBlock×2   [512, 80, 128]
+  ↓ Downsample   [512, 40, 64]
+  ↓ ResBlock×2   [512, 40, 64]
+  ↓ ResBlock + Conv
+Output: [B, 4, 40, 64]  (8×8 downsampling)
+```
+
+**Decoder architecture** (mirror of the encoder):
+```
+Input: [B, 4, 40, 64]
+  ↓ Conv + ResBlock
+  ↓ ResBlock×2   [512, 40, 64]
+  ↓ Upsample     [512, 80, 128]
+  ↓ ResBlock×2   [512, 80, 128]
+  ↓ Upsample     [256, 160, 256]
+  ↓ ResBlock×2   [256, 160, 256]
+  ↓ Upsample     [128, 320, 512]
+  ↓ ResBlock×2   [128, 320, 512]
+  ↓ Conv 128→3
+Output: [B, 3, 320, 512]
+```
+
+
+### A.5 Actual Condition-Encoder Configuration
+
+#### CLIP Image Encoder
+
+**Config**: `configs/inference/world_model_interaction.yaml:188-191`
+```yaml
+FrozenOpenCLIPImageEmbedderV2:
+  freeze: true
+  # uses the OpenCLIP ViT-H/14 model
+  # output dimension: 1280
+```
+
+**Image projector (Resampler)**:
+```yaml
+Resampler:
+  dim: 1024            # output dimension
+  depth: 4             # Transformer depth
+  dim_head: 64         # attention-head dimension
+  heads: 12            # 12 attention heads
+  num_queries: 16      # 16 query tokens
+  embedding_dim: 1280  # CLIP output dimension
+  output_dim: 1024     # final output dimension
+  video_length: 16     # video length
+```
+
+**Data flow**:
+```
+Image [B, 3, H, W]
+  ↓ CLIP encoder
+  ↓ [B, 1280]
+  ↓ Resampler (Perceiver-style)
+  ↓ [B, 16, 1024]  (16 tokens, 1024-d each)
+```
+
+
+#### Text Encoder
+
+**Config**: `configs/inference/world_model_interaction.yaml:182-186`
+```yaml
+FrozenOpenCLIPEmbedder:
+  freeze: True
+  layer: "penultimate"  # use the second-to-last layer
+  # output dimension: 1024
+```
+
+**Data flow**:
+```
+Text instruction "pick up the box"
+  ↓ OpenCLIP text encoder
+  ↓ [B, seq_len, 1024]
+```
+
+#### Action/State Projectors
+
+**Code location**: [src/unifolm_wma/models/ddpms.py:2014-2026](src/unifolm_wma/models/ddpms.py)
+
+**MLPProjector implementation** (src/unifolm_wma/utils/projector.py:14-37):
+```python
+class MLPProjector(nn.Module):
+    def __init__(self, input_dim: int, output_dim: int, mlp_type: str = "gelu-mlp"):
+        if mlp_type == "gelu-mlp":
+            self.projector = nn.Sequential(
+                nn.Linear(input_dim, output_dim, bias=True),
+                nn.GELU(approximate='tanh'),
+                nn.Linear(output_dim, output_dim, bias=True),
+            )
+```
+
+**Initialization** (ddpms.py:2014-2026):
+```python
+# state and action projectors
+self.state_projector = MLPProjector(agent_state_dim, 1024)    # 16 → 1024
+self.action_projector = MLPProjector(agent_action_dim, 1024)  # 16 → 1024
+
+# positional embeddings
+self.agent_action_pos_emb = nn.Parameter(torch.randn(1, 16, 1024))
+self.agent_state_pos_emb = nn.Parameter(torch.randn(1, n_obs_steps, 1024))
+```
+
+**Data flow**:
+```
+State [B, T_obs, 16]
+  ↓ MLPProjector (Linear + GELU + Linear)
+  ↓ [B, T_obs, 1024]
+  ↓ + agent_state_pos_emb
+  ↓ [B, T_obs, 1024]
+
+Action [B, T_action, 16]
+  ↓ MLPProjector (Linear + GELU + Linear)
+  ↓ [B, T_action, 1024]
+  ↓ + agent_action_pos_emb
+  ↓ [B, T_action, 1024]
+```
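+
+A self-contained usage sketch of this projector path (the `GeluMLP` module below mirrors the MLPProjector structure shown above but is illustrative, not the repository class; `n_obs_steps=2` is taken from the inference config):
+
+```python
+import torch
+import torch.nn as nn
+
+class GeluMLP(nn.Module):
+    def __init__(self, input_dim: int, output_dim: int):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(input_dim, output_dim),
+            nn.GELU(approximate='tanh'),
+            nn.Linear(output_dim, output_dim),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.net(x)
+
+state_projector = GeluMLP(16, 1024)
+state_pos_emb = torch.randn(1, 2, 1024)                  # [1, n_obs_steps, 1024] positional embedding
+state = torch.randn(1, 2, 16)                            # [B, T_obs, 16] joint positions
+state_tokens = state_projector(state) + state_pos_emb    # [B, T_obs, 1024] conditioning tokens
+print(state_tokens.shape)
+```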
+
+
+### A.6 Actual Timestep-Embedding Implementation
+
+**Location**: `src/unifolm_wma/utils/diffusion.py:timestep_embedding`
+
+**Actual code**:
+```python
+def timestep_embedding(timesteps, dim, max_period=10000):
+    """
+    Create sinusoidal timestep embeddings.
+    """
+    half = dim // 2
+    freqs = torch.exp(
+        -math.log(max_period) * torch.arange(start=0, end=half) / half
+    )
+    args = timesteps[:, None].float() * freqs[None]
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+    return embedding
+```
+
+**Usage inside WMAModel**:
+```python
+# timestep t ∈ [0, 999]
+t_emb = timestep_embedding(t, model_channels)  # [B, 320]
+t_emb = self.time_embed(t_emb)                 # Linear(320 → 1280)
+# output: [B, 1280]
+```
+
+**Time-embedding network**:
+```
+t ∈ [0, 999]
+  ↓ timestep_embedding (sinusoidal encoding)
+  ↓ [B, 320]
+  ↓ Linear(320 → 1280)
+  ↓ SiLU
+  ↓ Linear(1280 → 1280)
+  ↓ [B, 1280]
+```
+
+
+### A.7 Actual Action-Head (ConditionalUnet1D) Configuration
+
+**Config**: `configs/inference/world_model_interaction.yaml:106-127`
+
+```yaml
+ConditionalUnet1D:
+  input_dim: 16                      # action dimension
+  n_obs_steps: 2                     # number of observation steps
+  diffusion_step_embed_dim: 128      # diffusion-step embedding dimension
+  down_dims: [256, 512, 1024, 2048]  # downsampling channels
+  kernel_size: 5                     # convolution kernel size
+  n_groups: 8                        # GroupNorm groups
+  horizon: 16                        # prediction horizon
+  use_linear_attn: true              # use linear attention
+  imagen_cond_gradient: true         # use image-conditioned gradients
+```
+
+**Architecture**:
+```
+Input: noisy actions [B, 16, 16]  (16-dim action × 16 steps)
+Conditions:
+  - image features [B, C, H, W] from the WMAModel intermediate layers
+  - observation encoding [B, n_obs, obs_dim]
+
+  ↓ Conv1D + ResBlock
+  ↓ [B, 256, 16]
+  ↓ Downsample + ResBlock
+  ↓ [B, 512, 8]
+  ↓ Downsample + ResBlock
+  ↓ [B, 1024, 4]
+  ↓ Downsample + ResBlock
+  ↓ [B, 2048, 2]
+  ↓ Middle block (with attention)
+  ↓ Upsample + ResBlock
+  ↓ [B, 1024, 4]
+  ↓ Upsample + ResBlock
+  ↓ [B, 512, 8]
+  ↓ Upsample + ResBlock
+  ↓ [B, 256, 16]
+  ↓ Conv1D
+Output: predicted noise [B, 16, 16]
+```
+
+
+### A.8 Full Forward-Pass Flow
+
+Based on the actual code, the full forward pass looks like this:
+
+```python
+# 1. Condition encoding (done once, can be cached)
+cond_img_emb = clip_encoder(img) → resampler → [B, 16, 1024]
+cond_text_emb = text_encoder(text) → [B, seq_len, 1024]
+cond_state_emb = state_projector(state) + pos_emb → [B, T_obs, 1024]
+cond_action_emb = action_projector(action) + pos_emb → [B, T_action, 1024]
+cond_latent = vae.encode(img) → [B, 4, T, 40, 64]
+
+# 2. Concatenate conditions
+c_concat = [cond_latent]  # channel concatenation
+c_crossattn = [cond_text_emb, cond_img_emb, cond_state_emb, cond_action_emb]
+c_crossattn = torch.cat(c_crossattn, dim=1)  # [B, total_tokens, 1024]
+
+# 3. DDIM sampling loop (30 steps)
+x = torch.randn([B, 4, 16, 40, 64])  # initial noise
+for step in range(30):
+    # 3.1 Timestep embedding
+    t_emb = timestep_embedding(t, 320) → Linear → [B, 1280]
+
+    # 3.2 Concatenate the input
+    x_in = torch.cat([x, cond_latent], dim=1)  # [B, 8, 16, 40, 64]
+
+    # 3.3 UNet forward pass (the core bottleneck)
+    noise_pred = wma_model(x_in, t_emb, c_crossattn)
+    # contains: 4 downsampling stages + middle block + 3 upsampling stages
+    # each stage: 2 ResBlocks + SpatialTransformer + TemporalTransformer
+
+    # 3.4 DDIM update
+    x = ddim_update(x, noise_pred, t, t_prev)
+
+# 4. VAE decoding
+video = vae.decode(x) → [B, 3, 16, 320, 512]
+```
+
+
+### A.9 Optimization Recommendations Updated for the Actual Architecture
+
+#### Optimization 1: ResBlock fusion (high priority)
+
+**Actual bottleneck**:
+- Each DDIM step calls the UNet once
+- The UNet contains 4 downsampling stages + 1 middle block + 3 upsampling stages = 8 stages
+- Each stage has 2 ResBlocks
+- Total: 16 ResBlocks × 30 steps × 2 passes (phase 1 + 2) = **960 ResBlock calls**
+
+**Fusion opportunity**:
+```python
+# Today: 7 kernel launches
+h = group_norm(x)   # kernel 1
+h = silu(h)         # kernel 2
+h = conv2d(h)       # kernel 3
+h = group_norm(h)   # kernel 4
+h = silu(h)         # kernel 5
+h = conv2d(h)       # kernel 6
+out = x + h         # kernel 7
+
+# After fusion: 2-3 kernel launches
+h = fused_norm_silu_conv(x)     # kernel 1 (fused)
+h = fused_norm_silu_conv(h)     # kernel 2 (fused)
+out = fused_residual_add(x, h)  # kernel 3 (fused)
+```
+
+**Expected gain**: 50-60% less kernel-launch overhead per ResBlock
+
+
+#### Optimization 2: Attention optimization (high priority)
+
+**Actual configuration**:
+- SpatialTransformer: after every ResBlock in every stage
+- TemporalTransformer: after every ResBlock in every stage
+- Total: 16 spatial + 16 temporal = **32 Transformers × 30 steps × 2 passes = 1920 attention calls**
+
+**The current implementation already supports xformers**:
+The code detects xformers availability at `attention.py:8-13`:
+```python
+try:
+    import xformers
+    import xformers.ops
+    XFORMERS_IS_AVAILBLE = True
+except:
+    XFORMERS_IS_AVAILBLE = False
+```
+
+When xformers is available, `CrossAttention` automatically uses its `efficient_forward` method (attention.py:90-91).
+
+**Further optimization** (if xformers is unavailable):
+```python
+# use PyTorch's built-in Flash Attention
+from torch.nn.functional import scaled_dot_product_attention
+out = scaled_dot_product_attention(Q, K, V, is_causal=False)
+```
+
+**Expected gain**: if xformers is already enabled, the attention layers are essentially optimized; otherwise Flash Attention gives a 2-3x attention speedup
+
diff --git a/prepare_data/prepare_training_data.py b/prepare_data/prepare_training_data.py
new file mode 100644
index 0000000..3d899e3
--- /dev/null
+++ b/prepare_data/prepare_training_data.py
@@ -0,0 +1,199 @@
+import json
+import os
+import shutil
+import h5py
+import argparse
+import pandas as pd
+import torch
+import subprocess
+
+from pathlib import Path
+from safetensors.torch import save_file
+from tqdm import tqdm
+
+
+def flatten_dict(d, parent_key="", sep="/"):
+ """Flatten a nested dictionary structure by collapsing nested keys into one key with a separator.
+
+ For example:
+ ```
+    >>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
+    >>> print(flatten_dict(dct))
+    {"a/b": 1, "a/c/d": 2, "e": 3}
+    ```
+    """
+ items = []
+ for k, v in d.items():
+ new_key = f"{parent_key}{sep}{k}" if parent_key else k
+ if isinstance(v, dict):
+ items.extend(flatten_dict(v, new_key, sep=sep).items())
+ else:
+ items.append((new_key, v))
+ return dict(items)
+
+
+def is_av1(file_path):
+ try:
+ result = subprocess.run([
+ "ffprobe", "-v", "error", "-select_streams", "v:0",
+ "-show_entries", "stream=codec_name", "-of", "csv=p=0",
+ str(file_path)
+ ],
+ capture_output=True,
+ text=True,
+ check=True)
+ return result.stdout.strip() == "av1"
+ except subprocess.CalledProcessError:
+ return False
+
+
+def convert_to_h264(input_path, output_path):
+ subprocess.run([
+ "ffmpeg", "-i",
+ str(input_path), "-c:v", "libx264", "-preset", "slow", "-crf", "23",
+ "-c:a", "copy",
+ str(output_path)
+ ],
+ check=True)
+
+
+def main(args):
+ source_dir = Path(args.source_dir)
+ source_data_dir = source_dir / args.dataset_name / "data" / "chunk-000"
+ source_meta_dir = source_dir / args.dataset_name / "meta"
+ source_videos_dir = source_dir / args.dataset_name / "videos" / "chunk-000"
+
+ target_dir = Path(args.target_dir)
+ target_videos_dir = target_dir / "videos" / args.dataset_name
+ target_transitions_dir = target_dir / "transitions" / args.dataset_name
+ target_meta_dir = target_dir / "transitions" / args.dataset_name / "meta_data"
+
+ target_dir.mkdir(parents=True, exist_ok=True)
+ target_videos_dir.mkdir(parents=True, exist_ok=True)
+ target_transitions_dir.mkdir(parents=True, exist_ok=True)
+ target_meta_dir.mkdir(parents=True, exist_ok=True)
+
+ csv_file = target_dir / f"{args.dataset_name}.csv"
+ COLUMNS = [
+ 'videoid', 'contentUrl', 'duration', 'data_dir', 'instruction',
+ 'dynamic_confidence', 'dynamic_wording', 'dynamic_source_category',
+ 'embodiment'
+ ]
+ df = pd.DataFrame(columns=COLUMNS)
+
+ # Load info.json from source dir
+ info_json_path = source_meta_dir / "info.json"
+ with open(str(info_json_path), "r") as f:
+ info = json.load(f)
+ total_episodes = info['total_episodes']
+
+    # Load tasks.jsonl to get the language instruction
+ tasks_jsonl_path = source_meta_dir / "tasks.jsonl"
+ with open(str(tasks_jsonl_path), "r") as f:
+ tasks = [json.loads(line) for line in f]
+ instruction = tasks[0]['task']
+
+ source_video_views = [d for d in source_videos_dir.iterdir()]
+ for v_idx, source_view_dir in enumerate(source_video_views):
+
+ view_name = source_view_dir.name
+ target_videos_view_dir = target_videos_dir / view_name
+ target_videos_view_dir.mkdir(parents=True, exist_ok=True)
+
+ if v_idx == 0:
+ all_actions = []
+ all_states = []
+
+ for idx in tqdm(range(total_episodes)):
+            # Copy the source video into the target video dir
+ source_video = source_view_dir / f"episode_{idx:06d}.mp4"
+ if is_av1(source_video):
+ output_video = str(target_videos_view_dir / f"{idx}.mp4")
+ print(f"Converting episode_{idx:06d}.mp4 to H.264...")
+ convert_to_h264(source_video, output_video)
+            else:
+                # Not AV1: copy the video over without re-encoding.
+                print(f"Copying episode_{idx:06d}.mp4 (already H.264 compatible).")
+                shutil.copy(str(source_video),
+                            str(target_videos_view_dir / f"{idx}.mp4"))
+
+ # Load parquet file
+ episode_parquet_file = source_data_dir / f"episode_{idx:06d}.parquet"
+ episode_data = pd.read_parquet(episode_parquet_file)
+ actions = torch.tensor(episode_data['action'].tolist())
+ states = torch.tensor(episode_data['observation.state'].tolist())
+
+ # Save action and state into a h5 file
+ if v_idx == 0:
+ target_h5_file = target_transitions_dir / f"{idx}.h5"
+ with h5py.File(str(target_h5_file), 'w') as h5f:
+ h5f.create_dataset('observation.state', data=states)
+ h5f.create_dataset('action', data=actions)
+ h5f.attrs['action_type'] = 'joint position'
+ h5f.attrs['state_type'] = 'joint position'
+ h5f.attrs['robot_type'] = args.robot_name
+
+            # Update df
+ df = pd.concat([
+ df,
+ pd.DataFrame([{
+ 'videoid': idx,
+ 'contentUrl': 'x',
+ 'duration': 'x',
+ 'data_dir': args.dataset_name + f"/{view_name}",
+ 'instruction': instruction,
+ 'dynamic_confidence': 'x',
+ 'dynamic_wording': 'x',
+ 'dynamic_source_category': 'x',
+ 'embodiment': args.robot_name
+ }])
+ ],
+ ignore_index=True)
+
+ # Collect action and state
+ if v_idx == 0:
+ all_actions.append(actions)
+ all_states.append(states)
+
+    # Create stats.safetensors
+ actions = torch.cat(all_actions, dim=0)
+ states = torch.cat(all_states, dim=0)
+
+ stats = {'action': {}, 'observation.state': {}}
+ stats['action']['max'] = actions.max(dim=0).values
+ stats['action']['min'] = actions.min(dim=0).values
+ stats['action']['mean'] = actions.mean(dim=0)
+ stats['action']['std'] = actions.std(dim=0)
+
+ stats['observation.state']['max'] = states.max(dim=0).values
+ stats['observation.state']['min'] = states.min(dim=0).values
+ stats['observation.state']['mean'] = states.mean(dim=0)
+ stats['observation.state']['std'] = states.std(dim=0)
+
+ flattened_stats = flatten_dict(stats)
+ target_stats_file = target_meta_dir / "stats.safetensors"
+ save_file(flattened_stats, target_stats_file)
+
+ df.to_csv(csv_file, index=False)
+    print(f">>> Finished creating the {args.dataset_name} dataset ...")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--source_dir',
+ action='store',
+ type=str,
+ help='The dataset dir under lerobot 2.0 data format.',
+ required=True)
+ parser.add_argument('--target_dir',
+ action='store',
+ type=str,
+ default='./data',
+ help='The target dir to save new formatted dataset.')
+ parser.add_argument('--dataset_name',
+ action='store',
+ type=str,
+ help='dataset name',
+ required=True)
+ parser.add_argument('--robot_name',
+ action='store',
+ type=str,
+ help='robot name',
+ required=True)
+ main(parser.parse_args())
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100755
index 0000000..e08d9e6
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,53 @@
+[project]
+name = "unifolm_wma"
+version = "0.0.1"
+description = "UnifoLM-WMA-0"
+license = { text = "BSD-3-Clause" }
+authors = [
+ {name="Unitree Embodied AI R&D Team", email="rd_xyc@unitree.com" }
+]
+requires-python = "==3.10.18"
+dependencies = [
+ "decord==0.6.0",
+ "einops==0.8.0",
+ "imageio==2.35.1",
+ "numpy==1.24.2",
+ "omegaconf==2.3.0",
+ "opencv-python==4.10.0.84",
+ "pandas==2.0.0",
+ "pillow==9.5.0",
+ "pytorch-lightning==1.9.3",
+ "pyyaml==6.0",
+ "setuptools==65.6.3",
+ "torch==2.3.1",
+ "torchvision==0.18.1",
+ "tqdm==4.66.5",
+ "transformers==4.40.1",
+ "moviepy==1.0.3",
+ "av==12.3.0",
+ "xformers==0.0.27",
+ "gradio==4.39.0",
+ "timm==0.9.10",
+ "scikit-learn==1.5.1",
+ "open-clip-torch==2.22.0",
+ "kornia==0.7.3",
+ "diffusers==0.30.2",
+ "termcolor==2.4.0",
+ "draccus==0.11.5",
+ "accelerate==1.7.0",
+ "tensorflow-metadata==1.16.1",
+ "protobuf==3.20.3",
+ "datasets==3.6.0",
+ "tensorflow-graphics==2021.12.3",
+ "fairscale==0.4.13"
+]
+
+[build-system]
+requires = ["setuptools>=65.6.3", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools]
+package-dir = { "" = "src" }
+
+[tool.setuptools.packages.find]
+where = ["src"]
diff --git a/run_all_cases.sh b/run_all_cases.sh
new file mode 100755
index 0000000..6252554
--- /dev/null
+++ b/run_all_cases.sh
@@ -0,0 +1,114 @@
+#!/bin/bash
+
+# Run every case of every scenario automatically
+# 5 scenarios x 4 cases each = 20 cases in total
+# Set environment variables (offline mode)
+export HF_HUB_OFFLINE=1
+export TRANSFORMERS_OFFLINE=1
+
+# Color definitions
+RED='\033[0;31m'
+GREEN='\033[0;32m'
+YELLOW='\033[1;33m'
+BLUE='\033[0;34m'
+NC='\033[0m' # No Color
+
+# Define all scenarios
+SCENARIOS=(
+ "unitree_g1_pack_camera"
+ "unitree_z1_dual_arm_cleanup_pencils"
+ "unitree_z1_dual_arm_stackbox"
+ "unitree_z1_dual_arm_stackbox_v2"
+ "unitree_z1_stackbox"
+)
+
+# Define the case indices
+CASES=(1 2 3 4)
+
+# Record the start time
+START_TIME=$(date +%s)
+LOG_FILE="run_all_cases_$(date +%Y%m%d_%H%M%S).log"
+
+echo -e "${BLUE}========================================${NC}"
+echo -e "${BLUE}Running all cases of all scenarios${NC}"
+echo -e "${BLUE}Total: ${#SCENARIOS[@]} scenarios x ${#CASES[@]} cases = $((${#SCENARIOS[@]} * ${#CASES[@]})) tasks${NC}"
+echo -e "${BLUE}Log file: ${LOG_FILE}${NC}"
+echo -e "${BLUE}========================================${NC}"
+echo ""
+
+# Initialize counters
+TOTAL_CASES=$((${#SCENARIOS[@]} * ${#CASES[@]}))
+CURRENT_CASE=0
+SUCCESS_COUNT=0
+FAIL_COUNT=0
+
+# Track failed cases
+declare -a FAILED_CASES
+
+# Iterate over all scenarios
+for scenario in "${SCENARIOS[@]}"; do
+    echo -e "${YELLOW}>>> Scenario: ${scenario}${NC}"
+
+    # Iterate over all cases
+    for case_num in "${CASES[@]}"; do
+        CURRENT_CASE=$((CURRENT_CASE + 1))
+        case_dir="${scenario}/case${case_num}"
+        script_path="${case_dir}/run_world_model_interaction.sh"
+
+        echo -e "${BLUE}[${CURRENT_CASE}/${TOTAL_CASES}] Running: ${case_dir}${NC}"
+
+        # Check that the script exists
+        if [ ! -f "${script_path}" ]; then
+            echo -e "${RED}Error: script not found ${script_path}${NC}"
+            FAIL_COUNT=$((FAIL_COUNT + 1))
+            FAILED_CASES+=("${case_dir} (script not found)")
+            continue
+        fi
+
+        # Run the script
+        echo "Start time: $(date '+%Y-%m-%d %H:%M:%S')"
+
+        if bash "${script_path}" >> "${LOG_FILE}" 2>&1; then
+            echo -e "${GREEN}✓ Succeeded: ${case_dir}${NC}"
+            SUCCESS_COUNT=$((SUCCESS_COUNT + 1))
+        else
+            echo -e "${RED}✗ Failed: ${case_dir}${NC}"
+            FAIL_COUNT=$((FAIL_COUNT + 1))
+            FAILED_CASES+=("${case_dir}")
+        fi
+
+        echo "End time: $(date '+%Y-%m-%d %H:%M:%S')"
+ echo ""
+ done
+
+ echo ""
+done
+
+# Compute the total elapsed time
+END_TIME=$(date +%s)
+DURATION=$((END_TIME - START_TIME))
+HOURS=$((DURATION / 3600))
+MINUTES=$(((DURATION % 3600) / 60))
+SECONDS=$((DURATION % 60))
+
+# Print the summary
+echo -e "${BLUE}========================================${NC}"
+echo -e "${BLUE}All runs finished!${NC}"
+echo -e "${BLUE}========================================${NC}"
+echo -e "Total tasks: ${TOTAL_CASES}"
+echo -e "${GREEN}Succeeded: ${SUCCESS_COUNT}${NC}"
+echo -e "${RED}Failed: ${FAIL_COUNT}${NC}"
+echo -e "Total time: ${HOURS}h ${MINUTES}m ${SECONDS}s"
+echo -e "Detailed log: ${LOG_FILE}"
+echo ""
+
+# If any case failed, list it
+if [ ${FAIL_COUNT} -gt 0 ]; then
+    echo -e "${RED}Failed cases:${NC}"
+ for failed_case in "${FAILED_CASES[@]}"; do
+ echo -e "${RED} - ${failed_case}${NC}"
+ done
+ echo ""
+fi
+
+echo -e "${BLUE}========================================${NC}"
diff --git a/scripts/evaluation/base_model_inference.py b/scripts/evaluation/base_model_inference.py
new file mode 100644
index 0000000..42945a7
--- /dev/null
+++ b/scripts/evaluation/base_model_inference.py
@@ -0,0 +1,541 @@
+import argparse, os, glob
+import datetime, time
+import pandas as pd
+import torch
+import torchvision
+import torchvision.transforms as transforms
+import random
+
+from pytorch_lightning import seed_everything
+from PIL import Image
+from omegaconf import OmegaConf
+from tqdm import tqdm
+from einops import rearrange, repeat
+from collections import OrderedDict
+
+from unifolm_wma.models.samplers.ddim import DDIMSampler
+from unifolm_wma.utils.utils import instantiate_from_config
+
+
+def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
+ """
+ Get list of files in `data_dir` with extensions in `postfixes`.
+
+ Args:
+ data_dir (str): Directory path.
+ postfixes (list[str]): List of file extensions (e.g., ['csv', 'jpg']).
+
+ Returns:
+ list[str]: Sorted list of matched file paths.
+ """
+ patterns = [
+ os.path.join(data_dir, f"*.{postfix}") for postfix in postfixes
+ ]
+ file_list = []
+ for pattern in patterns:
+ file_list.extend(glob.glob(pattern))
+ file_list.sort()
+ return file_list
+
+
+def load_model_checkpoint(model: torch.nn.Module,
+ ckpt: str) -> torch.nn.Module:
+ """
+ Load model weights from checkpoint file.
+
+ Args:
+ model (torch.nn.Module): The model to load weights into.
+ ckpt (str): Path to the checkpoint file.
+
+ Returns:
+ torch.nn.Module: Model with weights loaded.
+ """
+ state_dict = torch.load(ckpt, map_location="cpu")
+ if "state_dict" in list(state_dict.keys()):
+ state_dict = state_dict["state_dict"]
+ try:
+ loaded = model.load_state_dict(state_dict, strict=False)
+ print("Missing keys:")
+ for k in loaded.missing_keys:
+ print(f" {k}")
+ print("Unexpected keys:")
+ for k in loaded.unexpected_keys:
+ print(f" {k}")
+
+ except:
+ # Rename the keys for 256x256 model
+ new_pl_sd = OrderedDict()
+ for k, v in state_dict.items():
+ new_pl_sd[k] = v
+
+ for k in list(new_pl_sd.keys()):
+ if "framestride_embed" in k:
+ new_key = k.replace("framestride_embed", "fps_embedding")
+ new_pl_sd[new_key] = new_pl_sd[k]
+ del new_pl_sd[k]
+ model.load_state_dict(new_pl_sd, strict=False)
+ else:
+ new_pl_sd = OrderedDict()
+ for key in state_dict['module'].keys():
+ new_pl_sd[key[16:]] = state_dict['module'][key]
+ model.load_state_dict(new_pl_sd)
+ print('>>> model checkpoint loaded.')
+ return model
+
+
+def load_prompts(prompt_file: str) -> list[str]:
+ """
+ Load prompts from a text file, one per line.
+
+ Args:
+ prompt_file (str): Path to the prompt file.
+
+ Returns:
+ list[str]: List of prompt strings.
+ """
+ f = open(prompt_file, 'r')
+ prompt_list = []
+ for idx, line in enumerate(f.readlines()):
+ l = line.strip()
+ if len(l) != 0:
+ prompt_list.append(l)
+ f.close()
+ return prompt_list
+
+
+def load_data_prompts(
+ data_dir: str,
+ savedir: str,
+ video_size: tuple[int, int] = (256, 256),
+ video_frames: int = 16
+) -> tuple[list[str], list[torch.Tensor], list[str], list[float], list[float],
+ list[int]]:
+ """
+ Load image prompts, process them into video format, and retrieve metadata.
+
+ Args:
+ data_dir (str): Directory containing images and CSV file.
+ savedir (str): Output directory to check if inference was already done.
+ video_size (tuple[int, int], optional): Target size of video frames.
+ video_frames (int, optional): Number of frames in each video.
+
+ Returns:
+ tuple: (filenames, video tensors, prompts, fps values, fs values, num_generations)
+ """
+
+ transform = transforms.Compose([
+ transforms.Resize(min(video_size)),
+ transforms.CenterCrop(video_size),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+ ])
+
+ # Load prompt csv
+ prompt_file = get_filelist(data_dir, ['csv'])
+ assert len(prompt_file) > 0, "Error: found NO image prompt file!"
+
+ # Load image prompts
+ file_list = get_filelist(data_dir, ['jpg', 'png', 'jpeg', 'JPEG', 'PNG'])
+ data_list = []
+ filename_list = []
+ prompt_list = []
+ fps_list = []
+ fs_list = []
+ num_gen_list = []
+ prompt_csv = pd.read_csv(prompt_file[0])
+ n_samples = len(file_list)
+
+ for idx in range(n_samples):
+ image = Image.open(file_list[idx]).convert('RGB')
+ image_tensor = transform(image).unsqueeze(1)
+ frame_tensor = repeat(image_tensor,
+ 'c t h w -> c (repeat t) h w',
+ repeat=video_frames)
+ _, filename = os.path.split(file_list[idx])
+
+ if not is_inferenced(savedir, filename):
+ video_id = filename[:-4]
+ prompt_csv['videoid'] = prompt_csv['videoid'].map(str)
+ if not (prompt_csv['videoid'] == video_id).any():
+ continue
+ data_list.append(frame_tensor)
+ filename_list.append(filename)
+ ins = prompt_csv[prompt_csv['videoid'] ==
+ video_id]['instruction'].values[0]
+ prompt_list.append(ins)
+ fps = prompt_csv[prompt_csv['videoid'] ==
+ video_id]['fps'].values[0]
+ fps_list.append(fps)
+ fs = prompt_csv[prompt_csv['videoid'] == video_id]['fs'].values[0]
+ fs_list.append(fs)
+ num_gen = prompt_csv[prompt_csv['videoid'] ==
+ video_id]['num_gen'].values[0]
+ num_gen_list.append(int(num_gen))
+
+ return filename_list, data_list, prompt_list, fps_list, fs_list, num_gen_list
+
+
+def is_inferenced(save_dir: str, filename: str) -> bool:
+ """
+ Check if a result video already exists.
+
+ Args:
+ save_dir (str): Directory where results are saved.
+ filename (str): Base filename to check.
+
+ Returns:
+ bool: True if file exists, else False.
+ """
+ video_file = os.path.join(save_dir, f"{filename[:-4]}.mp4")
+ return os.path.exists(video_file)
+
+
+def save_results_seperate(prompt: str | list[str],
+ samples: torch.Tensor,
+ filename: str,
+ fakedir: str,
+ fps: int = 8) -> None:
+ """
+ Save generated video samples as .mp4 files.
+
+ Args:
+ prompt (str | list[str]): The prompt text.
+ samples (torch.Tensor): Generated video tensor of shape [B, C, T, H, W].
+ filename (str): Output filename.
+ fakedir (str): Directory to save output videos.
+ fps (int, optional): Frames per second.
+
+ Returns:
+ None
+ """
+ prompt = prompt[0] if isinstance(prompt, list) else prompt
+
+ # Save video
+ videos = [samples]
+ savedirs = [fakedir]
+ for idx, video in enumerate(videos):
+ if video is None:
+ continue
+ video = video.detach().cpu()
+ video = torch.clamp(video.float(), -1., 1.)
+ n = video.shape[0]
+ for i in range(n):
+ grid = video[i, ...]
+ grid = (grid + 1.0) / 2.0
+ grid = (grid * 255).to(torch.uint8).permute(1, 2, 3, 0)
+ path = os.path.join(savedirs[idx], f'{filename.split(".")[0]}.mp4')
+ torchvision.io.write_video(path,
+ grid,
+ fps=fps,
+ video_codec='h264',
+ options={'crf': '0'})
+
+
+def get_latent_z(model: torch.nn.Module, videos: torch.Tensor) -> torch.Tensor:
+ """
+ Encode videos to latent space.
+
+ Args:
+ model (torch.nn.Module): Model with encode_first_stage function.
+ videos (torch.Tensor): Video tensor of shape [B, C, T, H, W].
+
+ Returns:
+ torch.Tensor: Latent representation of shape [B, C, T, H, W].
+ """
+ b, c, t, h, w = videos.shape
+ x = rearrange(videos, 'b c t h w -> (b t) c h w')
+ z = model.encode_first_stage(x)
+ z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
+ return z
+
+
+def image_guided_synthesis(model: torch.nn.Module,
+ prompts: list[str],
+ videos: torch.Tensor,
+ noise_shape: list[int],
+ ddim_steps: int = 50,
+ ddim_eta: float = 1.0,
+ unconditional_guidance_scale: float = 1.0,
+ fs: int | None = None,
+ text_input: bool = False,
+ timestep_spacing: str = 'uniform',
+ guidance_rescale: float = 0.0,
+ **kwargs) -> torch.Tensor:
+ """
+ Run DDIM-based image-to-video synthesis with hybrid/text+image guidance.
+
+ Args:
+ model (torch.nn.Module): Diffusion model.
+ prompts (list[str]): Text prompts.
+ videos (torch.Tensor): Input images/videos of shape [B, C, T, H, W].
+ noise_shape (list[int]): Latent noise shape [B, C, T, H, W].
+ ddim_steps (int, optional): Number of DDIM steps.
+ ddim_eta (float, optional): Eta value for DDIM.
+ unconditional_guidance_scale (float, optional): Guidance scale.
+ fs (int | None, optional): FPS input for sampler.
+ text_input (bool, optional): If True, use text guidance.
+ timestep_spacing (str, optional): Timestep schedule spacing.
+ guidance_rescale (float, optional): Rescale guidance effect.
+ **kwargs: Additional sampler args.
+
+ Returns:
+ torch.Tensor: Synthesized videos of shape [B, 1, C, T, H, W].
+ """
+
+ ddim_sampler = DDIMSampler(model)
+ batch_size = noise_shape[0]
+ fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
+
+ if not text_input:
+ prompts = [""] * batch_size
+
+ b, c, t, h, w = videos.shape
+ img = videos[:, :, 0]
+ img_emb = model.embedder(img)
+ img_emb = model.image_proj_model(img_emb)
+ img_emb = rearrange(img_emb, 'b (t l) c -> (b t) l c', t=t)
+ cond_emb = model.get_learned_conditioning(prompts)
+ cond_emb = cond_emb.repeat_interleave(repeats=t, dim=0)
+
+ cond = {"c_crossattn": [torch.cat([cond_emb, img_emb], dim=1)]}
+ if model.model.conditioning_key == 'hybrid':
+ z = get_latent_z(model, videos)
+ img_cat_cond = z[:, :, :1, :, :]
+ img_cat_cond = repeat(img_cat_cond,
+ 'b c t h w -> b c (repeat t) h w',
+ repeat=z.shape[2])
+ cond["c_concat"] = [img_cat_cond]
+
+ uc = None
+ cond_mask = None
+ kwargs.update({"unconditional_conditioning_img_nonetext": None})
+
+ batch_variants = []
+ if ddim_sampler is not None:
+ samples, _, _, _ = ddim_sampler.sample(
+ S=ddim_steps,
+ batch_size=batch_size,
+ shape=noise_shape[1:],
+ conditioning=cond,
+ eta=ddim_eta,
+ mask=cond_mask,
+ x0=None,
+ verbose=False,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ fs=fs,
+ timestep_spacing=timestep_spacing,
+ guidance_rescale=guidance_rescale,
+ **kwargs)
+
+ # Reconstruct from latent to pixel space
+ batch_images = model.decode_first_stage(samples)
+ batch_variants.append(batch_images)
+
+ batch_variants = torch.stack(batch_variants)
+ return batch_variants.permute(1, 0, 2, 3, 4, 5)
+
+
+def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
+ """
+ Run inference pipeline on prompts and image inputs.
+
+ Args:
+ args (argparse.Namespace): Parsed command-line arguments.
+ gpu_num (int): Number of GPUs.
+ gpu_no (int): Index of the current GPU.
+
+ Returns:
+ None
+ """
+ # Load config
+ config = OmegaConf.load(args.config)
+    # Set use_checkpoint to False: with DeepSpeed it otherwise fails with "deepspeed backend not set"
+ config['model']['params']['wma_config']['params'][
+ 'use_checkpoint'] = False
+ model = instantiate_from_config(config.model)
+ model = model.cuda(gpu_no)
+ model.perframe_ae = args.perframe_ae
+ assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
+ model = load_model_checkpoint(model, args.ckpt_path)
+ model.eval()
+
+ # Run over data
+ assert (args.height % 16 == 0) and (
+ args.width % 16
+ == 0), "Error: image size [h,w] should be multiples of 16!"
+ assert args.bs == 1, "Current implementation only support [batch size = 1]!"
+
+ # Get latent noise shape
+ h, w = args.height // 8, args.width // 8
+ channels = model.model.diffusion_model.out_channels
+ n_frames = args.video_length
+ print(f'>>> Generate {n_frames} frames under each generation ...')
+ noise_shape = [args.bs, channels, n_frames, h, w]
+
+ fakedir = os.path.join(args.savedir, "samples")
+ os.makedirs(fakedir, exist_ok=True)
+
+ # Prompt file setting
+ assert os.path.exists(args.prompt_dir), "Error: prompt file Not Found!"
+ filename_list, data_list, prompt_list, fps_list, fs_list, num_gen_list = load_data_prompts(
+ args.prompt_dir,
+ args.savedir,
+ video_size=(args.height, args.width),
+ video_frames=n_frames)
+
+ num_samples = len(prompt_list)
+ samples_split = num_samples // gpu_num
+ print('>>> Prompts testing [rank:%d] %d/%d samples loaded.' %
+ (gpu_no, samples_split, num_samples))
+
+ indices = list(range(samples_split * gpu_no, samples_split * (gpu_no + 1)))
+ fps_list_rank = [fps_list[i] for i in indices]
+ fs_list_rank = [fs_list[i] for i in indices]
+ prompt_list_rank = [prompt_list[i] for i in indices]
+ data_list_rank = [data_list[i] for i in indices]
+ filename_list_rank = [filename_list[i] for i in indices]
+
+ with torch.no_grad(), torch.cuda.amp.autocast():
+ # Create a new result csv
+ for idx, indice in enumerate(
+ tqdm(range(0, len(prompt_list_rank), args.bs),
+ desc=f'Sample batch')):
+ fps = fps_list_rank[indice:indice + args.bs]
+ fs = fs_list_rank[indice:indice + args.bs]
+ prompts = prompt_list_rank[indice:indice + args.bs]
+ num_gen = num_gen_list[indice:indice + args.bs]
+ videos = data_list_rank[indice:indice + args.bs]
+ filenames = filename_list_rank[indice:indice + args.bs]
+ if isinstance(videos, list):
+ videos = torch.stack(videos, dim=0).to("cuda")
+ else:
+ videos = videos.unsqueeze(0).to("cuda")
+
+ results = []
+ print(
+ f">>> {prompts[0]}, frame_stride:{fs[0]}, and {num_gen[0]} generation ..."
+ )
+ for _ in range(num_gen[0]):
+ batch_samples = image_guided_synthesis(
+ model, prompts, videos, noise_shape, args.ddim_steps,
+ args.ddim_eta, args.unconditional_guidance_scale,
+ fps[0] // fs[0], args.text_input, args.timestep_spacing,
+ args.guidance_rescale)
+ results.extend(batch_samples)
+ videos = repeat(batch_samples[0][:, :, -1, :, :].unsqueeze(2),
+ 'b c t h w -> b c (repeat t) h w',
+ repeat=batch_samples[0].shape[2])
+ batch_samples = [torch.concat(results, axis=2)]
+
+ # Save each example individually
+ for nn, samples in enumerate(batch_samples):
+ prompt = prompts[nn]
+ filename = filenames[nn]
+ save_results_seperate(prompt,
+ samples,
+ filename,
+ fakedir,
+ fps=8)
+
+
+def get_parser() -> argparse.ArgumentParser:
+ """
+ Create and return the argument parser.
+
+ Returns:
+ argparse.ArgumentParser: Parser for command-line arguments.
+ """
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--savedir",
+ type=str,
+ default=None,
+ help="Path to save the results.")
+ parser.add_argument("--ckpt_path",
+ type=str,
+ default=None,
+ help="Path to the model checkpoint.")
+ parser.add_argument("--config",
+ type=str,
+ help="Path to the YAML configuration file.")
+ parser.add_argument(
+ "--prompt_dir",
+ type=str,
+ default=None,
+ help="Directory containing videos and corresponding prompts.")
+ parser.add_argument(
+ "--ddim_steps",
+ type=int,
+ default=50,
+ help="Number of DDIM steps. If non-positive, DDPM is used instead.")
+ parser.add_argument(
+ "--ddim_eta",
+ type=float,
+ default=1.0,
+ help="Eta for DDIM sampling. Set to 0.0 for deterministic results.")
+ parser.add_argument("--bs",
+ type=int,
+ default=1,
+ help="Batch size for inference. Must be 1.")
+ parser.add_argument("--height",
+ type=int,
+ default=320,
+ help="Height of the generated images in pixels.")
+ parser.add_argument("--width",
+ type=int,
+ default=512,
+ help="Width of the generated images in pixels.")
+ parser.add_argument(
+ "--unconditional_guidance_scale",
+ type=float,
+ default=1.0,
+ help="Scale for classifier-free guidance during sampling.")
+ parser.add_argument("--seed",
+ type=int,
+ default=123,
+ help="Random seed for reproducibility.")
+ parser.add_argument("--video_length",
+ type=int,
+ default=16,
+ help="Number of frames in the generated video.")
+ parser.add_argument(
+ "--text_input",
+ action='store_true',
+ default=False,
+ help=
+ "Whether to provide a text prompt as input to the image-to-video model."
+ )
+ parser.add_argument(
+ "--timestep_spacing",
+ type=str,
+ default="uniform",
+ help=
+ "Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
+ )
+ parser.add_argument(
+ "--guidance_rescale",
+ type=float,
+ default=0.0,
+ help=
+ "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
+ )
+ parser.add_argument(
+ "--perframe_ae",
+ action='store_true',
+ default=False,
+ help=
+ "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
+ )
+ return parser
+
+
+if __name__ == '__main__':
+ parser = get_parser()
+ args = parser.parse_args()
+ seed = args.seed
+ if seed < 0:
+ seed = random.randint(0, 2**31)
+ seed_everything(seed)
+ rank, gpu_num = 0, 1
+ run_inference(args, gpu_num, rank)
diff --git a/scripts/evaluation/eval_utils.py b/scripts/evaluation/eval_utils.py
new file mode 100644
index 0000000..366335a
--- /dev/null
+++ b/scripts/evaluation/eval_utils.py
@@ -0,0 +1,77 @@
+import torch
+import warnings
+import torchvision
+import sys
+import pyarrow as pa
+import logging
+
+from dataclasses import dataclass, field
+from typing import Dict, Any, ClassVar, Deque, Mapping, Union
+from datasets.features.features import register_feature
+from torch.utils.tensorboard.writer import SummaryWriter
+
+logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
+
+
+@dataclass
+class VideoFrame:
+ """
+ Provides a type for a dataset containing video frames.
+
+ Example:
+
+ ```python
+ data_dict = [{"image": {"path": "videos/episode_0.mp4", "timestamp": 0.3}}]
+ features = {"image": VideoFrame()}
+ Dataset.from_dict(data_dict, features=Features(features))
+ ```
+ """
+
+ pa_type: ClassVar[Any] = pa.struct({
+ "path": pa.string(),
+ "timestamp": pa.float32()
+ })
+ _type: str = field(default="VideoFrame", init=False, repr=False)
+
+ def __call__(self):
+ return self.pa_type
+
+
+with warnings.catch_warnings():
+ warnings.filterwarnings(
+ "ignore",
+ "'register_feature' is experimental and might be subject to breaking changes in the future.",
+ category=UserWarning,
+ )
+ register_feature(VideoFrame, "VideoFrame")
+
+
+def populate_queues(
+ queues: Dict[str, Deque[Any]],
+ batch: Mapping[str, Any]) -> Dict[str, Deque[Any]]:
+
+ for key in batch:
+ if key not in queues:
+ continue
+ if len(queues[key]) != queues[key].maxlen:
+ while len(queues[key]) != queues[key].maxlen:
+ queues[key].append(batch[key])
+ else:
+ queues[key].append(batch[key])
+ return queues
+
+
+def log_to_tensorboard(
+ writer: SummaryWriter,
+ data: Union[torch.Tensor, Any],
+ tag: str,
+ fps: int = 10) -> None:
+ if isinstance(data, torch.Tensor) and data.dim() == 5:
+ video = data
+ n = video.shape[0]
+ video = video.permute(2, 0, 1, 3, 4)
+ frame_grids = [torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0) for framesheet in video]
+ grid = torch.stack(frame_grids, dim=0)
+ grid = (grid + 1.0) / 2.0
+ grid = grid.unsqueeze(dim=0)
+ writer.add_video(tag, grid, fps=fps)
diff --git a/scripts/evaluation/real_eval_server.py b/scripts/evaluation/real_eval_server.py
new file mode 100644
index 0000000..d780b5d
--- /dev/null
+++ b/scripts/evaluation/real_eval_server.py
@@ -0,0 +1,463 @@
+import argparse, os, sys
+import torch
+import torchvision
+import warnings
+import imageio
+import logging
+import matplotlib.pyplot as plt
+plt.switch_backend('agg')
+import traceback
+import uvicorn
+
+from omegaconf import OmegaConf
+from einops import rearrange, repeat
+from collections import OrderedDict
+from pytorch_lightning import seed_everything
+from torch import nn
+from fastapi import FastAPI
+from fastapi.responses import JSONResponse
+from typing import Any, Dict, Optional, Tuple, List
+from datetime import datetime
+
+from unifolm_wma.utils.utils import instantiate_from_config
+from unifolm_wma.models.samplers.ddim import DDIMSampler
+
+
+def get_device_from_parameters(module: nn.Module) -> torch.device:
+ """Get a module's device by checking one of its parameters.
+
+ Args:
+ module (nn.Module): PyTorch module.
+
+ Returns:
+ torch.device: The device where the module's parameters are stored.
+ """
+ return next(iter(module.parameters())).device
+
+
+def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
+ """Load model weights from checkpoint file.
+
+ Args:
+ model (nn.Module): Model to load weights into.
+ ckpt (str): Path to checkpoint file.
+
+ Returns:
+ nn.Module: Model with loaded weights.
+ """
+
+ state_dict = torch.load(ckpt, map_location="cpu")
+ if "state_dict" in list(state_dict.keys()):
+ state_dict = state_dict["state_dict"]
+ try:
+ model.load_state_dict(state_dict, strict=False)
+ except:
+ new_pl_sd = OrderedDict()
+ for k, v in state_dict.items():
+ new_pl_sd[k] = v
+
+ for k in list(new_pl_sd.keys()):
+ if "framestride_embed" in k:
+ new_key = k.replace("framestride_embed", "fps_embedding")
+ new_pl_sd[new_key] = new_pl_sd[k]
+ del new_pl_sd[k]
+ model.load_state_dict(new_pl_sd, strict=False)
+ else:
+ new_pl_sd = OrderedDict()
+ for key in state_dict['module'].keys():
+ new_pl_sd[key[16:]] = state_dict['module'][key]
+ model.load_state_dict(new_pl_sd)
+ print('>>> model checkpoint loaded.')
+ return model
+
+
+def write_video(video_path: str, stacked_frames: List[Any], fps: int) -> None:
+ """Write a video to disk using imageio.
+
+ Args:
+ video_path (str): Path to save the video.
+ stacked_frames (List[Any]): Frames to write.
+ fps (int): Frames per second.
+ """
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore",
+ "pkg_resources is deprecated as an API",
+ category=DeprecationWarning)
+ imageio.mimsave(video_path, stacked_frames, fps=fps)
+
+
+def save_results(video: torch.Tensor, filename: str, fps: int = 8) -> None:
+ """Save a video tensor as an MP4 file.
+
+ Args:
+ video (torch.Tensor): Video tensor of shape (B, C, T, H, W).
+ filename (str): Path to save video.
+ fps (int, optional): Frame rate. Defaults to 8.
+
+ """
+ video = video.detach().cpu()
+ video = torch.clamp(video.float(), -1., 1.)
+ n = video.shape[0]
+ video = video.permute(2, 0, 1, 3, 4)
+
+ frame_grids = [
+ torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
+ for framesheet in video
+ ]
+ grid = torch.stack(frame_grids, dim=0)
+ grid = (grid + 1.0) / 2.0
+ grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
+ torchvision.io.write_video(filename,
+ grid,
+ fps=fps,
+ video_codec='h264',
+ options={'crf': '10'})
+
+
+def get_latent_z(model: nn.Module, videos: torch.Tensor) -> torch.Tensor:
+ """Encode videos into latent space.
+
+ Args:
+ model (nn.Module): Model with `encode_first_stage` method.
+ videos (torch.Tensor): Input videos (B, C, T, H, W).
+
+ Returns:
+ torch.Tensor: Latent representation (B, C, T, H, W).
+
+ """
+ b, c, t, h, w = videos.shape
+ x = rearrange(videos, 'b c t h w -> (b t) c h w')
+ z = model.encode_first_stage(x)
+ z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
+ return z
+
+
+def image_guided_synthesis(
+ model: torch.nn.Module,
+ prompts: list[str],
+ observation: Dict[str, torch.Tensor],
+ noise_shape: tuple[int, int, int, int, int],
+ ddim_steps: int = 50,
+ ddim_eta: float = 1.0,
+ unconditional_guidance_scale: float = 1.0,
+ fs: int | None = None,
+ timestep_spacing: str = 'uniform',
+ guidance_rescale: float = 0.0,
+ **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Run inference with DDIM sampling.
+
+ Args:
+ model (nn.Module): Diffusion model.
+ prompts (Any): Conditioning text prompts.
+ observation (Dict[str, torch.Tensor]): Observation dictionary.
+ noise_shape (List[int]): Shape of noise tensor.
+ ddim_steps (int, optional): Number of DDIM steps. Defaults to 50.
+ ddim_eta (float, optional): Sampling eta. Defaults to 1.0.
+ unconditional_guidance_scale (float, optional): Guidance scale. Defaults to 1.0.
+ fs (Optional[int], optional): Frame stride or FPS. Defaults to None.
+ timestep_spacing (str, optional): Spacing strategy. Defaults to "uniform".
+ guidance_rescale (float, optional): Guidance rescale. Defaults to 0.0.
+ **kwargs (Any): Additional arguments.
+
+ Returns:
+        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Predicted videos, actions, and states.
+ """
+
+ b, _, t, _, _ = noise_shape
+ ddim_sampler = DDIMSampler(model)
+ batch_size = noise_shape[0]
+ fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
+
+ img = observation['observation.images.top']
+ cond_img = img[:, -1, ...]
+ cond_img_emb = model.embedder(cond_img)
+ cond_img_emb = model.image_proj_model(cond_img_emb)
+
+ if model.model.conditioning_key == 'hybrid':
+ z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
+ img_cat_cond = z[:, :, -1:, :, :]
+ img_cat_cond = repeat(img_cat_cond,
+ 'b c t h w -> b c (repeat t) h w',
+ repeat=noise_shape[2])
+ cond = {"c_concat": [img_cat_cond]}
+
+ cond_ins_emb = model.get_learned_conditioning(prompts)
+ cond_state = model.state_projector(observation['observation.state'])
+ cond_state_emb = model.agent_state_pos_emb + cond_state
+
+ cond_action = model.action_projector(observation['action'])
+ cond_action_emb = model.agent_action_pos_emb + cond_action
+ cond_action_emb = torch.zeros_like(cond_action_emb)
+
+ cond["c_crossattn"] = [
+ torch.cat([cond_state_emb, cond_ins_emb, cond_img_emb], dim=1)
+ ]
+ cond["c_crossattn_action"] = [
+ observation['observation.images.top'].permute(
+ 0, 2, 1, 3, 4)[:, :, -model.n_obs_steps_acting:],
+ observation['observation.state'][:, -model.n_obs_steps_acting:]
+ ]
+
+ uc = None
+
+ kwargs.update({"unconditional_conditioning_img_nonetext": None})
+
+ cond_mask = None
+ cond_z0 = None
+
+ if ddim_sampler is not None:
+
+ samples, actions, states, intermedia = ddim_sampler.sample(
+ S=ddim_steps,
+ conditioning=cond,
+ batch_size=batch_size,
+ shape=noise_shape[1:],
+ verbose=False,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ eta=ddim_eta,
+ cfg_img=None,
+ mask=cond_mask,
+ x0=cond_z0,
+ fs=fs,
+ timestep_spacing=timestep_spacing,
+ guidance_rescale=guidance_rescale,
+ **kwargs)
+
+ # Reconstruct from latent to pixel space
+ batch_images = model.decode_first_stage(samples)
+ batch_variants = batch_images
+
+ return batch_variants, actions, states
+
+
+def run_inference(args: argparse.Namespace, gpu_num: int,
+ gpu_no: int) -> Tuple[nn.Module, List[int], Any]:
+ """
+ Run inference pipeline on prompts and image inputs.
+
+ Args:
+ args (argparse.Namespace): Parsed command-line arguments.
+ gpu_num (int): Number of GPUs.
+ gpu_no (int): Index of the current GPU.
+
+    Returns:
+        Tuple[nn.Module, List[int], Any]: The loaded model, the latent noise shape, and the data module.
+ """
+ # Load config
+ config = OmegaConf.load(args.config)
+    # Set use_checkpoint to False: with DeepSpeed it otherwise fails with "deepspeed backend not set"
+ config['model']['params']['wma_config']['params']['use_checkpoint'] = False
+ model = instantiate_from_config(config.model)
+ model.perframe_ae = args.perframe_ae
+ assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
+ model = load_model_checkpoint(model, args.ckpt_path)
+ model = model.cuda(gpu_no)
+ model.eval()
+ print(">>> Model is successfully loaded ...")
+
+    # Build the unnormalizer
+ logging.info("***** Configing Data *****")
+ data = instantiate_from_config(config.data)
+ data.setup()
+ print(">>> Dataset is successfully loaded ...")
+
+ ## Run over data
+ assert (args.height % 16 == 0) and (
+ args.width % 16
+ == 0), "Error: image size [h,w] should be multiples of 16!"
+ assert args.bs == 1, "Current implementation only support [batch size = 1]!"
+
+ ## Get latent noise shape
+ h, w = args.height // 8, args.width // 8
+ channels = model.model.diffusion_model.out_channels
+ n_frames = args.video_length
+ print(f'>>> Generate {n_frames} frames under each generation ...')
+ noise_shape = [args.bs, channels, n_frames, h, w]
+
+ return model, noise_shape, data
+
+
+def get_parser() -> argparse.ArgumentParser:
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--savedir",
+ type=str,
+ default=None,
+ help="Path to save the results.")
+ parser.add_argument("--ckpt_path",
+ type=str,
+ default=None,
+ help="Path to the model checkpoint.")
+ parser.add_argument("--config", type=str, help="Path to the config file.")
+ parser.add_argument(
+ "--ddim_steps",
+ type=int,
+ default=50,
+ help="Number of DDIM steps. If non-positive, DDPM is used instead.")
+ parser.add_argument(
+ "--ddim_eta",
+ type=float,
+ default=1.0,
+ help="Eta for DDIM sampling. Set to 0.0 for deterministic results.")
+ parser.add_argument("--bs",
+ type=int,
+ default=1,
+ help="Batch size for inference. Must be 1.")
+ parser.add_argument("--height",
+ type=int,
+ default=320,
+ help="Height of the generated images in pixels.")
+ parser.add_argument("--width",
+ type=int,
+ default=512,
+ help="Width of the generated images in pixels.")
+ parser.add_argument(
+ "--frame_stride",
+ type=int,
+ default=3,
+ help=
+ "frame stride control for 256 model (larger->larger motion), FPS control for 512 or 1024 model (smaller->larger motion)"
+ )
+ parser.add_argument(
+ "--unconditional_guidance_scale",
+ type=float,
+ default=1.0,
+ help="Scale for classifier-free guidance during sampling.")
+ parser.add_argument("--seed",
+ type=int,
+ default=123,
+ help="Random seed for reproducibility.")
+ parser.add_argument("--video_length",
+ type=int,
+ default=16,
+ help="Number of frames in the generated video.")
+ parser.add_argument(
+ "--timestep_spacing",
+ type=str,
+ default="uniform",
+ help=
+ "Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
+ )
+ parser.add_argument(
+ "--guidance_rescale",
+ type=float,
+ default=0.0,
+ help=
+ "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
+ )
+ parser.add_argument(
+ "--perframe_ae",
+ action='store_true',
+ default=False,
+ help=
+ "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
+ )
+ return parser
+
+
+class Server:
+
+ def __init__(self, args: argparse.Namespace) -> None:
+ self.model_, self.noise_shape_, self.data_ = run_inference(args, 1, 0)
+ self.args_ = args
+ self.dataset_name = self.data_.dataset_configs['test']['params'][
+ 'dataset_name']
+ self.device_ = get_device_from_parameters(self.model_)
+
+ def normalize_image(self, image: torch.Tensor) -> torch.Tensor:
+ return (image / 255 - 0.5) * 2
+
+ def predict_action(self, payload: Dict[str, Any]) -> Any:
+ try:
+ images = payload['observation.images.top']
+ states = payload['observation.state']
+ actions = payload['action'] # Should be all zeros
+ language_instruction = payload['language_instruction']
+
+ images = torch.tensor(images).cuda()
+ images = self.data_.test_datasets[
+ self.dataset_name].spatial_transform(images).unsqueeze(0)
+ images = self.normalize_image(images)
+ print(f"images shape: {images.shape} ...")
+ states = torch.tensor(states)
+ states = self.data_.test_datasets[self.dataset_name].normalizer(
+ {'observation.state': states})['observation.state']
+ states, _ = self.data_.test_datasets[
+ self.dataset_name]._map_to_uni_state(states, "joint position")
+ print(f"states shape: {states.shape} ...")
+ actions = torch.tensor(actions)
+ actions, action_mask = self.data_.test_datasets[
+ self.dataset_name]._map_to_uni_action(actions,
+ "joint position")
+ print(f"actions shape: {actions.shape} ...")
+ print("=" * 20)
+ states = states.unsqueeze(0).cuda()
+ actions = actions.unsqueeze(0).cuda()
+
+ observation = {
+ 'observation.images.top': images,
+ 'observation.state': states,
+ 'action': actions
+ }
+ observation = {
+ key: observation[key].to(self.device_, non_blocking=True)
+ for key in observation
+ }
+
+ args = self.args_
+ pred_videos, pred_action, _ = image_guided_synthesis(
+ self.model_,
+ language_instruction,
+ observation,
+ self.noise_shape_,
+ ddim_steps=args.ddim_steps,
+                ddim_eta=args.ddim_eta,
+ unconditional_guidance_scale=args.unconditional_guidance_scale,
+ fs=30 / args.frame_stride,
+ timestep_spacing=args.timestep_spacing,
+ guidance_rescale=args.guidance_rescale)
+
+ pred_action = pred_action[..., action_mask[0] == 1.0][0].cpu()
+ pred_action = self.data_.test_datasets[
+ self.dataset_name].unnormalizer({'action':
+ pred_action})['action']
+
+ os.makedirs(args.savedir, exist_ok=True)
+ current_time = datetime.now().strftime("%H:%M:%S")
+ video_file = f'{args.savedir}/{current_time}.mp4'
+ save_results(pred_videos.cpu(), video_file)
+
+ response = {
+ 'result': 'ok',
+ 'action': pred_action.tolist(),
+ 'desc': 'success'
+ }
+ return JSONResponse(response)
+
+        except Exception:
+            logging.error(traceback.format_exc())
+            logging.warning(
+                "Your request threw an error; make sure it complies with the expected format:\n"
+                "{'observation.images.top': ..., 'observation.state': ..., 'action': ..., "
+                "'language_instruction': str}")
+ return {'result': 'error', 'desc': traceback.format_exc()}
+
+ def run(self, host: str = "127.0.0.1", port: int = 8000) -> None:
+ self.app = FastAPI()
+ self.app.post("/predict_action")(self.predict_action)
+ print(">>> Inference server is ready ... ")
+ uvicorn.run(self.app, host=host, port=port)
+ print(">>> Inference server stops ... ")
+ return
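+
+# A minimal (hypothetical) client sketch for this endpoint; the array shapes
+# below are placeholders and depend on the robot/dataset configuration:
+#
+#   import numpy as np
+#   import requests
+#
+#   payload = {
+#       'observation.images.top': np.zeros((3, 480, 640), dtype=np.uint8).tolist(),
+#       'observation.state': np.zeros(14, dtype=np.float32).tolist(),
+#       'action': np.zeros((16, 14), dtype=np.float32).tolist(),  # all zeros, as expected
+#       'language_instruction': 'pack the camera into the box',
+#   }
+#   response = requests.post('http://127.0.0.1:8000/predict_action', json=payload)
+#   actions = response.json()['action']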
+
+
+if __name__ == '__main__':
+ parser = get_parser()
+ args = parser.parse_args()
+ seed = args.seed
+ seed_everything(seed)
+ rank, gpu_num = 0, 1
+ print(">>> Launch inference server ... ")
+ server = Server(args)
+ server.run()
diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py
new file mode 100644
index 0000000..c87bae5
--- /dev/null
+++ b/scripts/evaluation/world_model_interaction.py
@@ -0,0 +1,1220 @@
+import argparse, os, glob
+import pandas as pd
+import random
+import torch
+import torchvision
+import h5py
+import numpy as np
+import logging
+import einops
+import warnings
+import imageio
+import time
+import json
+from contextlib import contextmanager, nullcontext
+from dataclasses import dataclass, field, asdict
+from typing import Optional, Dict, List, Any
+
+from pytorch_lightning import seed_everything
+from omegaconf import OmegaConf
+from tqdm import tqdm
+from einops import rearrange, repeat
+from collections import OrderedDict
+from torch import nn
+from eval_utils import populate_queues, log_to_tensorboard
+from collections import deque
+from torch import Tensor
+from torch.utils.tensorboard import SummaryWriter
+from PIL import Image
+
+from unifolm_wma.models.samplers.ddim import DDIMSampler
+from unifolm_wma.utils.utils import instantiate_from_config
+
+
+# ========== Profiling Infrastructure ==========
+@dataclass
+class TimingRecord:
+ """Record for a single timing measurement."""
+ name: str
+ start_time: float = 0.0
+ end_time: float = 0.0
+ cuda_time_ms: float = 0.0
+ count: int = 0
+ children: List['TimingRecord'] = field(default_factory=list)
+
+ @property
+ def cpu_time_ms(self) -> float:
+ return (self.end_time - self.start_time) * 1000
+
+ def to_dict(self) -> dict:
+ return {
+ 'name': self.name,
+ 'cpu_time_ms': self.cpu_time_ms,
+ 'cuda_time_ms': self.cuda_time_ms,
+ 'count': self.count,
+ 'children': [c.to_dict() for c in self.children]
+ }
+
+
+class ProfilerManager:
+ """Manages macro and micro-level profiling."""
+
+ def __init__(self, enabled: bool = False, output_dir: str = "./profile_output"):
+ self.enabled = enabled
+ self.output_dir = output_dir
+ self.macro_timings: Dict[str, List[float]] = {}
+ self.cuda_events: Dict[str, List[tuple]] = {}
+ self.memory_snapshots: List[Dict] = []
+ self.pytorch_profiler = None
+ self.current_iteration = 0
+ self.operator_stats: Dict[str, Dict] = {}
+
+ if enabled:
+ os.makedirs(output_dir, exist_ok=True)
+
+ @contextmanager
+ def profile_section(self, name: str, sync_cuda: bool = True):
+ """Context manager for profiling a code section."""
+ if not self.enabled:
+ yield
+ return
+
+ if sync_cuda and torch.cuda.is_available():
+ torch.cuda.synchronize()
+
+ start_event = None
+ end_event = None
+ if torch.cuda.is_available():
+ start_event = torch.cuda.Event(enable_timing=True)
+ end_event = torch.cuda.Event(enable_timing=True)
+ start_event.record()
+
+ start_time = time.perf_counter()
+
+ try:
+ yield
+ finally:
+ if sync_cuda and torch.cuda.is_available():
+ torch.cuda.synchronize()
+
+ end_time = time.perf_counter()
+ cpu_time_ms = (end_time - start_time) * 1000
+
+ cuda_time_ms = 0.0
+ if start_event is not None and end_event is not None:
+ end_event.record()
+ torch.cuda.synchronize()
+ cuda_time_ms = start_event.elapsed_time(end_event)
+
+ if name not in self.macro_timings:
+ self.macro_timings[name] = []
+ self.macro_timings[name].append(cpu_time_ms)
+
+ if name not in self.cuda_events:
+ self.cuda_events[name] = []
+ self.cuda_events[name].append((cpu_time_ms, cuda_time_ms))
+
+ def record_memory(self, tag: str = ""):
+ """Record current GPU memory state."""
+ if not self.enabled or not torch.cuda.is_available():
+ return
+
+ snapshot = {
+ 'tag': tag,
+ 'iteration': self.current_iteration,
+ 'allocated_mb': torch.cuda.memory_allocated() / 1024**2,
+ 'reserved_mb': torch.cuda.memory_reserved() / 1024**2,
+ 'max_allocated_mb': torch.cuda.max_memory_allocated() / 1024**2,
+ }
+ self.memory_snapshots.append(snapshot)
+
+ def start_pytorch_profiler(self, wait: int = 1, warmup: int = 1, active: int = 3):
+ """Start PyTorch profiler for operator-level analysis."""
+ if not self.enabled:
+ return nullcontext()
+
+ self.pytorch_profiler = torch.profiler.profile(
+ activities=[
+ torch.profiler.ProfilerActivity.CPU,
+ torch.profiler.ProfilerActivity.CUDA,
+ ],
+ schedule=torch.profiler.schedule(
+ wait=wait, warmup=warmup, active=active, repeat=1
+ ),
+ on_trace_ready=self._trace_handler,
+ record_shapes=True,
+ profile_memory=True,
+ with_stack=True,
+ with_flops=True,
+ with_modules=True,
+ )
+ return self.pytorch_profiler
+
+ def _trace_handler(self, prof):
+ """Handle profiler trace output."""
+ trace_path = os.path.join(
+ self.output_dir,
+ f"trace_iter_{self.current_iteration}.json"
+ )
+ prof.export_chrome_trace(trace_path)
+
+ # Extract operator statistics
+ key_averages = prof.key_averages(group_by_input_shape=True)
+ for evt in key_averages:
+ op_name = evt.key
+ if op_name not in self.operator_stats:
+ self.operator_stats[op_name] = {
+ 'count': 0,
+ 'cpu_time_total_us': 0,
+ 'cuda_time_total_us': 0,
+ 'self_cpu_time_total_us': 0,
+ 'self_cuda_time_total_us': 0,
+ 'cpu_memory_usage': 0,
+ 'cuda_memory_usage': 0,
+ 'flops': 0,
+ }
+ stats = self.operator_stats[op_name]
+ stats['count'] += evt.count
+ stats['cpu_time_total_us'] += evt.cpu_time_total
+ stats['cuda_time_total_us'] += evt.cuda_time_total
+ stats['self_cpu_time_total_us'] += evt.self_cpu_time_total
+ stats['self_cuda_time_total_us'] += evt.self_cuda_time_total
+ if hasattr(evt, 'cpu_memory_usage'):
+ stats['cpu_memory_usage'] += evt.cpu_memory_usage
+ if hasattr(evt, 'cuda_memory_usage'):
+ stats['cuda_memory_usage'] += evt.cuda_memory_usage
+ if hasattr(evt, 'flops') and evt.flops:
+ stats['flops'] += evt.flops
+
+ def step_profiler(self):
+ """Step the PyTorch profiler."""
+ if self.pytorch_profiler is not None:
+ self.pytorch_profiler.step()
+
+ def generate_report(self) -> str:
+ """Generate comprehensive profiling report."""
+ if not self.enabled:
+ return "Profiling disabled."
+
+ report_lines = []
+ report_lines.append("=" * 80)
+ report_lines.append("PERFORMANCE PROFILING REPORT")
+ report_lines.append("=" * 80)
+ report_lines.append("")
+
+ # Macro-level timing summary
+ report_lines.append("-" * 40)
+ report_lines.append("MACRO-LEVEL TIMING SUMMARY")
+ report_lines.append("-" * 40)
+ report_lines.append(f"{'Section':<40} {'Count':>8} {'Total(ms)':>12} {'Avg(ms)':>12} {'CUDA Avg(ms)':>14}")
+ report_lines.append("-" * 86)
+
+ total_time = 0
+ timing_data = []
+ for name, times in sorted(self.macro_timings.items()):
+ cuda_times = [ct for _, ct in self.cuda_events.get(name, [])]
+ avg_time = np.mean(times)
+ avg_cuda = np.mean(cuda_times) if cuda_times else 0
+ total = sum(times)
+ total_time += total
+ timing_data.append({
+ 'name': name,
+ 'count': len(times),
+ 'total_ms': total,
+ 'avg_ms': avg_time,
+ 'cuda_avg_ms': avg_cuda,
+ 'times': times,
+ 'cuda_times': cuda_times,
+ })
+ report_lines.append(f"{name:<40} {len(times):>8} {total:>12.2f} {avg_time:>12.2f} {avg_cuda:>14.2f}")
+
+ report_lines.append("-" * 86)
+ report_lines.append(f"{'TOTAL':<40} {'':<8} {total_time:>12.2f}")
+ report_lines.append("")
+
+ # Memory summary
+ if self.memory_snapshots:
+ report_lines.append("-" * 40)
+ report_lines.append("GPU MEMORY SUMMARY")
+ report_lines.append("-" * 40)
+ max_alloc = max(s['max_allocated_mb'] for s in self.memory_snapshots)
+ avg_alloc = np.mean([s['allocated_mb'] for s in self.memory_snapshots])
+ report_lines.append(f"Peak allocated: {max_alloc:>10.2f} MB")
+ report_lines.append(f"Average allocated: {avg_alloc:>10.2f} MB")
+ report_lines.append("")
+
+ # Top operators by CUDA time
+ if self.operator_stats:
+ report_lines.append("-" * 40)
+ report_lines.append("TOP 30 OPERATORS BY CUDA TIME")
+ report_lines.append("-" * 40)
+ sorted_ops = sorted(
+ self.operator_stats.items(),
+ key=lambda x: x[1]['cuda_time_total_us'],
+ reverse=True
+ )[:30]
+
+ report_lines.append(f"{'Operator':<50} {'Count':>8} {'CUDA(ms)':>12} {'CPU(ms)':>12} {'Self CUDA(ms)':>14}")
+ report_lines.append("-" * 96)
+
+ for op_name, stats in sorted_ops:
+ # Truncate long operator names
+ display_name = op_name[:47] + "..." if len(op_name) > 50 else op_name
+ report_lines.append(
+ f"{display_name:<50} {stats['count']:>8} "
+ f"{stats['cuda_time_total_us']/1000:>12.2f} "
+ f"{stats['cpu_time_total_us']/1000:>12.2f} "
+ f"{stats['self_cuda_time_total_us']/1000:>14.2f}"
+ )
+ report_lines.append("")
+
+ # Compute category breakdown
+ report_lines.append("-" * 40)
+ report_lines.append("OPERATOR CATEGORY BREAKDOWN")
+ report_lines.append("-" * 40)
+
+ categories = {
+ 'Attention': ['attention', 'softmax', 'bmm', 'baddbmm'],
+ 'Convolution': ['conv', 'cudnn'],
+ 'Normalization': ['norm', 'layer_norm', 'batch_norm', 'group_norm'],
+ 'Activation': ['relu', 'gelu', 'silu', 'sigmoid', 'tanh'],
+ 'Linear/GEMM': ['linear', 'addmm', 'mm', 'gemm'],
+ 'Memory': ['copy', 'contiguous', 'view', 'reshape', 'permute', 'transpose'],
+ 'Elementwise': ['add', 'mul', 'div', 'sub', 'pow', 'exp', 'sqrt'],
+ }
+
+ category_times = {cat: 0.0 for cat in categories}
+ category_times['Other'] = 0.0
+
+ for op_name, stats in self.operator_stats.items():
+ op_lower = op_name.lower()
+ categorized = False
+ for cat, keywords in categories.items():
+ if any(kw in op_lower for kw in keywords):
+ category_times[cat] += stats['cuda_time_total_us']
+ categorized = True
+ break
+ if not categorized:
+ category_times['Other'] += stats['cuda_time_total_us']
+
+ total_op_time = sum(category_times.values())
+ report_lines.append(f"{'Category':<30} {'CUDA Time(ms)':>15} {'Percentage':>12}")
+ report_lines.append("-" * 57)
+ for cat, time_us in sorted(category_times.items(), key=lambda x: -x[1]):
+ pct = (time_us / total_op_time * 100) if total_op_time > 0 else 0
+ report_lines.append(f"{cat:<30} {time_us/1000:>15.2f} {pct:>11.1f}%")
+ report_lines.append("")
+
+ report = "\n".join(report_lines)
+ return report
+
+ def save_results(self):
+ """Save all profiling results to files."""
+ if not self.enabled:
+ return
+
+ # Save report
+ report = self.generate_report()
+ report_path = os.path.join(self.output_dir, "profiling_report.txt")
+ with open(report_path, 'w') as f:
+ f.write(report)
+ print(f">>> Profiling report saved to: {report_path}")
+
+ # Save detailed JSON data
+ data = {
+ 'macro_timings': {
+ name: {
+ 'times': times,
+ 'cuda_times': [ct for _, ct in self.cuda_events.get(name, [])]
+ }
+ for name, times in self.macro_timings.items()
+ },
+ 'memory_snapshots': self.memory_snapshots,
+ 'operator_stats': self.operator_stats,
+ }
+ json_path = os.path.join(self.output_dir, "profiling_data.json")
+ with open(json_path, 'w') as f:
+ json.dump(data, f, indent=2)
+ print(f">>> Detailed profiling data saved to: {json_path}")
+
+ # Print summary to console
+ print("\n" + report)
+
+
+# Global profiler instance
+_profiler: Optional[ProfilerManager] = None
+
+def get_profiler() -> ProfilerManager:
+ """Get the global profiler instance."""
+ global _profiler
+ if _profiler is None:
+ _profiler = ProfilerManager(enabled=False)
+ return _profiler
+
+def init_profiler(enabled: bool, output_dir: str) -> ProfilerManager:
+ """Initialize the global profiler."""
+ global _profiler
+ _profiler = ProfilerManager(enabled=enabled, output_dir=output_dir)
+ return _profiler
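+
+# Minimal usage sketch of the profiling helpers (assuming a CUDA device is
+# available and profiling was enabled via `--profile`):
+#
+#   profiler = init_profiler(enabled=True, output_dir="./profile_output")
+#   with profiler.profile_section("demo/matmul"):
+#       a = torch.randn(1024, 1024, device="cuda")
+#       b = torch.randn(1024, 1024, device="cuda")
+#       _ = a @ b
+#   profiler.record_memory("after_matmul")
+#   profiler.save_results()  # writes profiling_report.txt and profiling_data.json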
+
+
+# ========== Original Functions ==========
+def get_device_from_parameters(module: nn.Module) -> torch.device:
+ """Get a module's device by checking one of its parameters.
+
+ Args:
+ module (nn.Module): The model whose device is to be inferred.
+
+ Returns:
+ torch.device: The device of the model's parameters.
+ """
+ return next(iter(module.parameters())).device
+
+
+def write_video(video_path: str, stacked_frames: list, fps: int) -> None:
+ """Save a list of frames to a video file.
+
+ Args:
+ video_path (str): Output path for the video.
+ stacked_frames (list): List of image frames.
+ fps (int): Frames per second for the video.
+ """
+ with warnings.catch_warnings():
+ warnings.filterwarnings("ignore",
+ "pkg_resources is deprecated as an API",
+ category=DeprecationWarning)
+ imageio.mimsave(video_path, stacked_frames, fps=fps)
+
+
+def get_filelist(data_dir: str, postfixes: list[str]) -> list[str]:
+ """Return sorted list of files in a directory matching specified postfixes.
+
+ Args:
+ data_dir (str): Directory path to search in.
+ postfixes (list[str]): List of file extensions to match.
+
+ Returns:
+ list[str]: Sorted list of file paths.
+ """
+ patterns = [
+ os.path.join(data_dir, f"*.{postfix}") for postfix in postfixes
+ ]
+ file_list = []
+ for pattern in patterns:
+ file_list.extend(glob.glob(pattern))
+ file_list.sort()
+ return file_list
+
+
+def load_model_checkpoint(model: nn.Module, ckpt: str) -> nn.Module:
+ """Load model weights from checkpoint file.
+
+ Args:
+ model (nn.Module): Model instance.
+ ckpt (str): Path to the checkpoint file.
+
+ Returns:
+ nn.Module: Model with loaded weights.
+ """
+ state_dict = torch.load(ckpt, map_location="cpu")
+ if "state_dict" in list(state_dict.keys()):
+ state_dict = state_dict["state_dict"]
+ try:
+ model.load_state_dict(state_dict, strict=True)
+    except Exception:
+ new_pl_sd = OrderedDict()
+ for k, v in state_dict.items():
+ new_pl_sd[k] = v
+
+ for k in list(new_pl_sd.keys()):
+ if "framestride_embed" in k:
+ new_key = k.replace("framestride_embed", "fps_embedding")
+ new_pl_sd[new_key] = new_pl_sd[k]
+ del new_pl_sd[k]
+ model.load_state_dict(new_pl_sd, strict=True)
+ else:
+ new_pl_sd = OrderedDict()
+ for key in state_dict['module'].keys():
+ new_pl_sd[key[16:]] = state_dict['module'][key]
+ model.load_state_dict(new_pl_sd)
+ print('>>> model checkpoint loaded.')
+ return model
+
+
+def is_inferenced(save_dir: str, filename: str) -> bool:
+ """Check if a given filename has already been processed and saved.
+
+ Args:
+ save_dir (str): Directory where results are saved.
+ filename (str): Name of the file to check.
+
+ Returns:
+ bool: True if processed file exists, False otherwise.
+ """
+ video_file = os.path.join(save_dir, "samples_separate",
+ f"{filename[:-4]}_sample0.mp4")
+ return os.path.exists(video_file)
+
+
+def save_results(video: Tensor, filename: str, fps: int = 8) -> None:
+ """Save video tensor to file using torchvision.
+
+ Args:
+ video (Tensor): Tensor of shape (B, C, T, H, W).
+ filename (str): Output file path.
+ fps (int, optional): Frames per second. Defaults to 8.
+ """
+ video = video.detach().cpu()
+ video = torch.clamp(video.float(), -1., 1.)
+ n = video.shape[0]
+ video = video.permute(2, 0, 1, 3, 4)
+
+ frame_grids = [
+ torchvision.utils.make_grid(framesheet, nrow=int(n), padding=0)
+ for framesheet in video
+ ]
+ grid = torch.stack(frame_grids, dim=0)
+ grid = (grid + 1.0) / 2.0
+ grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
+ torchvision.io.write_video(filename,
+ grid,
+ fps=fps,
+ video_codec='h264',
+ options={'crf': '10'})
+
+
+def get_init_frame_path(data_dir: str, sample: dict) -> str:
+ """Construct the init_frame path from directory and sample metadata.
+
+ Args:
+ data_dir (str): Base directory containing videos.
+ sample (dict): Dictionary containing 'data_dir' and 'videoid'.
+
+ Returns:
+ str: Full path to the video file.
+ """
+ rel_video_fp = os.path.join(sample['data_dir'],
+ str(sample['videoid']) + '.png')
+ full_image_fp = os.path.join(data_dir, 'images', rel_video_fp)
+ return full_image_fp
+
+
+def get_transition_path(data_dir: str, sample: dict) -> str:
+ """Construct the full transition file path from directory and sample metadata.
+
+ Args:
+ data_dir (str): Base directory containing transition files.
+ sample (dict): Dictionary containing 'data_dir' and 'videoid'.
+
+ Returns:
+ str: Full path to the HDF5 transition file.
+ """
+ rel_transition_fp = os.path.join(sample['data_dir'],
+ str(sample['videoid']) + '.h5')
+ full_transition_fp = os.path.join(data_dir, 'transitions',
+ rel_transition_fp)
+ return full_transition_fp
+
+
+def prepare_init_input(start_idx: int,
+ init_frame_path: str,
+ transition_dict: dict[str, torch.Tensor],
+ frame_stride: int,
+ wma_data,
+ video_length: int = 16,
+                       n_obs_steps: int = 2) -> tuple[dict[str, Tensor], int, int]:
+ """
+ Extracts a structured sample from a video sequence including frames, states, and actions,
+ along with properly padded observations and pre-processed tensors for model input.
+
+ Args:
+ start_idx (int): Starting frame index for the current clip.
+        init_frame_path (str): Path to the initial RGB frame image.
+ transition_dict (Dict[str, Tensor]): Dictionary containing tensors for 'action',
+ 'observation.state', 'action_type', 'state_type'.
+ frame_stride (int): Temporal stride between sampled frames.
+ wma_data: Object that holds configuration and utility functions like normalization,
+ transformation, and resolution info.
+ video_length (int, optional): Number of frames to sample from the video. Default is 16.
+ n_obs_steps (int, optional): Number of historical steps for observations. Default is 2.
+ """
+
+ indices = [start_idx + frame_stride * i for i in range(video_length)]
+ init_frame = Image.open(init_frame_path).convert('RGB')
+ init_frame = torch.tensor(np.array(init_frame)).unsqueeze(0).permute(
+ 3, 0, 1, 2).float()
+
+ if start_idx < n_obs_steps - 1:
+ state_indices = list(range(0, start_idx + 1))
+ states = transition_dict['observation.state'][state_indices, :]
+ num_padding = n_obs_steps - 1 - start_idx
+ first_slice = states[0:1, :] # (t, d)
+ padding = first_slice.repeat(num_padding, 1)
+ states = torch.cat((padding, states), dim=0)
+ else:
+ state_indices = list(range(start_idx - n_obs_steps + 1, start_idx + 1))
+ states = transition_dict['observation.state'][state_indices, :]
+
+ actions = transition_dict['action'][indices, :]
+
+ ori_state_dim = states.shape[-1]
+ ori_action_dim = actions.shape[-1]
+
+ frames_action_state_dict = {
+ 'action': actions,
+ 'observation.state': states,
+ }
+ frames_action_state_dict = wma_data.normalizer(frames_action_state_dict)
+ frames_action_state_dict = wma_data.get_uni_vec(
+ frames_action_state_dict,
+ transition_dict['action_type'],
+ transition_dict['state_type'],
+ )
+
+ if wma_data.spatial_transform is not None:
+ init_frame = wma_data.spatial_transform(init_frame)
+ init_frame = (init_frame / 255 - 0.5) * 2
+
+ data = {
+ 'observation.image': init_frame,
+ }
+ data.update(frames_action_state_dict)
+ return data, ori_state_dim, ori_action_dim
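+
+# Worked example of the history padding above: with start_idx=0 and
+# n_obs_steps=2 there is only one available state (index 0), so it is
+# repeated once to form a (2, D) observation history; with start_idx>=1
+# the last two states [start_idx-1, start_idx] are used directly.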
+
+
+def get_latent_z(model, videos: Tensor) -> Tensor:
+ """
+ Extracts latent features from a video batch using the model's first-stage encoder.
+
+ Args:
+ model: the world model.
+ videos (Tensor): Input videos of shape [B, C, T, H, W].
+
+ Returns:
+ Tensor: Latent video tensor of shape [B, C, T, H, W].
+ """
+ profiler = get_profiler()
+ with profiler.profile_section("get_latent_z/encode"):
+ b, c, t, h, w = videos.shape
+ x = rearrange(videos, 'b c t h w -> (b t) c h w')
+ z = model.encode_first_stage(x)
+ z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
+ return z
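+
+# Note: the first-stage encoder downsamples spatially by a factor of 8, which
+# is why run_inference builds the latent noise shape with h = height // 8 and
+# w = width // 8 (the latent channel count depends on the autoencoder config).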
+
+
+def preprocess_observation(
+ model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
+ """Convert environment observation to LeRobot format observation.
+ Args:
+        model: Model providing `normalize_inputs` and the target device.
+        observations: Dictionary of observation batches from a Gym vector environment.
+ Returns:
+ Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
+ """
+ # Map to expected inputs for the policy
+ return_observations = {}
+
+ if isinstance(observations["pixels"], dict):
+ imgs = {
+ f"observation.images.{key}": img
+ for key, img in observations["pixels"].items()
+ }
+ else:
+ imgs = {"observation.images.top": observations["pixels"]}
+
+ for imgkey, img in imgs.items():
+ img = torch.from_numpy(img)
+
+ # Sanity check that images are channel last
+ _, h, w, c = img.shape
+        assert c < h and c < w, f"expect channel-last images, but instead got {img.shape}"
+
+ # Sanity check that images are uint8
+ assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
+
+        # Convert to channel-first float32
+ img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
+ img = img.type(torch.float32)
+
+ return_observations[imgkey] = img
+
+ return_observations["observation.state"] = torch.from_numpy(
+ observations["agent_pos"]).float()
+ return_observations['observation.state'] = model.normalize_inputs({
+ 'observation.state':
+ return_observations['observation.state'].to(model.device)
+ })['observation.state']
+
+ return return_observations
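+
+# For reference, a (hypothetical) Gym-style batch this function expects:
+#
+#   observations = {
+#       "pixels": np.zeros((1, 480, 640, 3), dtype=np.uint8),   # channel-last uint8
+#       "agent_pos": np.zeros((1, 14), dtype=np.float32),
+#   }
+#   obs = preprocess_observation(model, observations)
+#   # obs["observation.images.top"] is a float32, channel-first tensor;
+#   # obs["observation.state"] is normalized via model.normalize_inputs.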
+
+
+def image_guided_synthesis_sim_mode(
+ model: torch.nn.Module,
+ prompts: list[str],
+ observation: dict,
+ noise_shape: tuple[int, int, int, int, int],
+ action_cond_step: int = 16,
+ n_samples: int = 1,
+ ddim_steps: int = 50,
+ ddim_eta: float = 1.0,
+ unconditional_guidance_scale: float = 1.0,
+ fs: int | None = None,
+ text_input: bool = True,
+ timestep_spacing: str = 'uniform',
+ guidance_rescale: float = 0.0,
+ sim_mode: bool = True,
+ **kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Performs image-guided video generation in a simulation-style mode with optional multimodal guidance (image, state, action, text).
+
+ Args:
+ model (torch.nn.Module): The diffusion-based generative model with multimodal conditioning.
+ prompts (list[str]): A list of textual prompts to guide the synthesis process.
+ observation (dict): A dictionary containing observed inputs including:
+ - 'observation.images.top': Tensor of shape [B, O, C, H, W] (top-down images)
+ - 'observation.state': Tensor of shape [B, O, D] (state vector)
+ - 'action': Tensor of shape [B, T, D] (action sequence)
+ noise_shape (tuple[int, int, int, int, int]): Shape of the latent variable to generate,
+ typically (B, C, T, H, W).
+ action_cond_step (int): Number of time steps where action conditioning is applied. Default is 16.
+ n_samples (int): Number of samples to generate (unused here, always generates 1). Default is 1.
+ ddim_steps (int): Number of DDIM sampling steps. Default is 50.
+ ddim_eta (float): DDIM eta parameter controlling the stochasticity. Default is 1.0.
+ unconditional_guidance_scale (float): Scale for classifier-free guidance. If 1.0, guidance is off.
+ fs (int | None): Frame index to condition on, broadcasted across the batch if specified. Default is None.
+ text_input (bool): Whether to use text prompt as conditioning. If False, uses empty strings. Default is True.
+ timestep_spacing (str): Timestep sampling method in DDIM sampler. Typically "uniform" or "linspace".
+ guidance_rescale (float): Guidance rescaling factor to mitigate overexposure from classifier-free guidance.
+        sim_mode (bool): If True, condition on the provided actions to simulate environment changes (world-model interaction); if False, zero the action conditioning and predict actions (decision-making).
+ **kwargs: Additional arguments passed to the DDIM sampler.
+
+ Returns:
+ batch_variants (torch.Tensor): Predicted pixel-space video frames [B, C, T, H, W].
+ actions (torch.Tensor): Predicted action sequences [B, T, D] from diffusion decoding.
+ states (torch.Tensor): Predicted state sequences [B, T, D] from diffusion decoding.
+ """
+ profiler = get_profiler()
+
+ b, _, t, _, _ = noise_shape
+ ddim_sampler = DDIMSampler(model)
+ batch_size = noise_shape[0]
+
+ fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
+
+ with profiler.profile_section("synthesis/conditioning_prep"):
+ img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
+ cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
+ cond_img_emb = model.embedder(cond_img)
+ cond_img_emb = model.image_proj_model(cond_img_emb)
+
+ if model.model.conditioning_key == 'hybrid':
+ z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
+ img_cat_cond = z[:, :, -1:, :, :]
+ img_cat_cond = repeat(img_cat_cond,
+ 'b c t h w -> b c (repeat t) h w',
+ repeat=noise_shape[2])
+ cond = {"c_concat": [img_cat_cond]}
+
+ if not text_input:
+ prompts = [""] * batch_size
+ cond_ins_emb = model.get_learned_conditioning(prompts)
+
+ cond_state_emb = model.state_projector(observation['observation.state'])
+ cond_state_emb = cond_state_emb + model.agent_state_pos_emb
+
+ cond_action_emb = model.action_projector(observation['action'])
+ cond_action_emb = cond_action_emb + model.agent_action_pos_emb
+
+ if not sim_mode:
+ cond_action_emb = torch.zeros_like(cond_action_emb)
+
+ cond["c_crossattn"] = [
+ torch.cat(
+ [cond_state_emb, cond_action_emb, cond_ins_emb, cond_img_emb],
+ dim=1)
+ ]
+ cond["c_crossattn_action"] = [
+ observation['observation.images.top'][:, :,
+ -model.n_obs_steps_acting:],
+ observation['observation.state'][:, -model.n_obs_steps_acting:],
+ sim_mode,
+ False,
+ ]
+
+ uc = None
+ kwargs.update({"unconditional_conditioning_img_nonetext": None})
+ cond_mask = None
+ cond_z0 = None
+
+ if ddim_sampler is not None:
+ with profiler.profile_section("synthesis/ddim_sampling"):
+ samples, actions, states, intermedia = ddim_sampler.sample(
+ S=ddim_steps,
+ conditioning=cond,
+ batch_size=batch_size,
+ shape=noise_shape[1:],
+ verbose=False,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ eta=ddim_eta,
+ cfg_img=None,
+ mask=cond_mask,
+ x0=cond_z0,
+ fs=fs,
+ timestep_spacing=timestep_spacing,
+ guidance_rescale=guidance_rescale,
+ **kwargs)
+
+ # Reconstruct from latent to pixel space
+ with profiler.profile_section("synthesis/decode_first_stage"):
+ batch_images = model.decode_first_stage(samples)
+ batch_variants = batch_images
+
+ return batch_variants, actions, states
+
+
+def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
+ """
+ Run inference pipeline on prompts and image inputs.
+
+ Args:
+ args (argparse.Namespace): Parsed command-line arguments.
+ gpu_num (int): Number of GPUs.
+ gpu_no (int): Index of the current GPU.
+
+ Returns:
+ None
+ """
+ profiler = get_profiler()
+
+ # Create inference and tensorboard dirs
+ os.makedirs(args.savedir + '/inference', exist_ok=True)
+ log_dir = args.savedir + f"/tensorboard"
+ os.makedirs(log_dir, exist_ok=True)
+ writer = SummaryWriter(log_dir=log_dir)
+
+ # Load prompt
+ csv_path = os.path.join(args.prompt_dir, f"{args.dataset}.csv")
+ df = pd.read_csv(csv_path)
+
+ # Load config
+ with profiler.profile_section("model_loading/config"):
+ config = OmegaConf.load(args.config)
+ config['model']['params']['wma_config']['params'][
+ 'use_checkpoint'] = False
+ model = instantiate_from_config(config.model)
+ model.perframe_ae = args.perframe_ae
+
+ assert os.path.exists(args.ckpt_path), "Error: checkpoint Not Found!"
+
+ with profiler.profile_section("model_loading/checkpoint"):
+ model = load_model_checkpoint(model, args.ckpt_path)
+ model.eval()
+    print('>>> Pre-trained model loaded ...')
+
+    # Build unnormalizer
+    logging.info("***** Configuring Data *****")
+ with profiler.profile_section("data_loading"):
+ data = instantiate_from_config(config.data)
+ data.setup()
+ print(">>> Dataset is successfully loaded ...")
+
+ with profiler.profile_section("model_to_cuda"):
+ model = model.cuda(gpu_no)
+ device = get_device_from_parameters(model)
+
+ profiler.record_memory("after_model_load")
+
+ # Run over data
+ assert (args.height % 16 == 0) and (
+ args.width % 16
+ == 0), "Error: image size [h,w] should be multiples of 16!"
+    assert args.bs == 1, "Current implementation only supports [batch size = 1]!"
+
+ # Get latent noise shape
+ h, w = args.height // 8, args.width // 8
+ channels = model.model.diffusion_model.out_channels
+ n_frames = args.video_length
+    print(f'>>> Generating {n_frames} frames per generation ...')
+ noise_shape = [args.bs, channels, n_frames, h, w]
+
+ # Determine profiler iterations
+ profile_active_iters = getattr(args, 'profile_iterations', 3)
+ use_pytorch_profiler = profiler.enabled and profile_active_iters > 0
+
+ # Start inference
+ for idx in range(0, len(df)):
+ sample = df.iloc[idx]
+
+        # Get the initial frame path
+ init_frame_path = get_init_frame_path(args.prompt_dir, sample)
+ ori_fps = float(sample['fps'])
+
+ video_save_dir = args.savedir + f"/inference/sample_{sample['videoid']}"
+ os.makedirs(video_save_dir, exist_ok=True)
+ os.makedirs(video_save_dir + '/dm', exist_ok=True)
+ os.makedirs(video_save_dir + '/wm', exist_ok=True)
+
+ # Load transitions to get the initial state later
+ transition_path = get_transition_path(args.prompt_dir, sample)
+ with profiler.profile_section("load_transitions"):
+ with h5py.File(transition_path, 'r') as h5f:
+ transition_dict = {}
+ for key in h5f.keys():
+ transition_dict[key] = torch.tensor(h5f[key][()])
+ for key in h5f.attrs.keys():
+ transition_dict[key] = h5f.attrs[key]
+
+        # If multiple strides are given, test different frequency controls for world-model generation
+ for fs in args.frame_stride:
+
+            # For saving imagined rollouts from the decision-making policy
+ sample_save_dir = f'{video_save_dir}/dm/{fs}'
+ os.makedirs(sample_save_dir, exist_ok=True)
+            # For saving environment changes from world-model interaction
+ sample_save_dir = f'{video_save_dir}/wm/{fs}'
+ os.makedirs(sample_save_dir, exist_ok=True)
+ # For collecting interaction videos
+ wm_video = []
+ # Initialize observation queues
+ cond_obs_queues = {
+ "observation.images.top":
+ deque(maxlen=model.n_obs_steps_imagen),
+ "observation.state": deque(maxlen=model.n_obs_steps_imagen),
+ "action": deque(maxlen=args.video_length),
+ }
+
+ # Obtain initial frame and state
+ with profiler.profile_section("prepare_init_input"):
+ start_idx = 0
+ model_input_fs = ori_fps // fs
+ batch, ori_state_dim, ori_action_dim = prepare_init_input(
+ start_idx,
+ init_frame_path,
+ transition_dict,
+ fs,
+ data.test_datasets[args.dataset],
+ n_obs_steps=model.n_obs_steps_imagen)
+ observation = {
+ 'observation.images.top':
+ batch['observation.image'].permute(1, 0, 2,
+ 3)[-1].unsqueeze(0),
+ 'observation.state':
+ batch['observation.state'][-1].unsqueeze(0),
+ 'action':
+ torch.zeros_like(batch['action'][-1]).unsqueeze(0)
+ }
+ observation = {
+ key: observation[key].to(device, non_blocking=True)
+ for key in observation
+ }
+ # Update observation queues
+ cond_obs_queues = populate_queues(cond_obs_queues, observation)
+
+ # Setup PyTorch profiler context if enabled
+ pytorch_prof_ctx = nullcontext()
+ if use_pytorch_profiler:
+ pytorch_prof_ctx = profiler.start_pytorch_profiler(
+ wait=1, warmup=1, active=profile_active_iters
+ )
+
+ # Multi-round interaction with the world-model
+ with pytorch_prof_ctx:
+ for itr in tqdm(range(args.n_iter)):
+ profiler.current_iteration = itr
+ profiler.record_memory(f"iter_{itr}_start")
+
+ with profiler.profile_section("iteration_total"):
+ # Get observation
+ with profiler.profile_section("prepare_observation"):
+ observation = {
+ 'observation.images.top':
+ torch.stack(list(
+ cond_obs_queues['observation.images.top']),
+ dim=1).permute(0, 2, 1, 3, 4),
+ 'observation.state':
+ torch.stack(list(cond_obs_queues['observation.state']),
+ dim=1),
+ 'action':
+ torch.stack(list(cond_obs_queues['action']), dim=1),
+ }
+ observation = {
+ key: observation[key].to(device, non_blocking=True)
+ for key in observation
+ }
+
+ # Use world-model in policy to generate action
+ print(f'>>> Step {itr}: generating actions ...')
+ with profiler.profile_section("action_generation"):
+ pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
+ model,
+ sample['instruction'],
+ observation,
+ noise_shape,
+ action_cond_step=args.exe_steps,
+ ddim_steps=args.ddim_steps,
+ ddim_eta=args.ddim_eta,
+ unconditional_guidance_scale=args.
+ unconditional_guidance_scale,
+ fs=model_input_fs,
+ timestep_spacing=args.timestep_spacing,
+ guidance_rescale=args.guidance_rescale,
+ sim_mode=False)
+
+ # Update future actions in the observation queues
+ with profiler.profile_section("update_action_queues"):
+ for act_idx in range(len(pred_actions[0])):
+ obs_update = {'action': pred_actions[0][act_idx:act_idx + 1]}
+ obs_update['action'][:, ori_action_dim:] = 0.0
+ cond_obs_queues = populate_queues(cond_obs_queues,
+ obs_update)
+
+ # Collect data for interacting the world-model using the predicted actions
+ with profiler.profile_section("prepare_wm_observation"):
+ observation = {
+ 'observation.images.top':
+ torch.stack(list(
+ cond_obs_queues['observation.images.top']),
+ dim=1).permute(0, 2, 1, 3, 4),
+ 'observation.state':
+ torch.stack(list(cond_obs_queues['observation.state']),
+ dim=1),
+ 'action':
+ torch.stack(list(cond_obs_queues['action']), dim=1),
+ }
+ observation = {
+ key: observation[key].to(device, non_blocking=True)
+ for key in observation
+ }
+
+ # Interaction with the world-model
+ print(f'>>> Step {itr}: interacting with world model ...')
+ with profiler.profile_section("world_model_interaction"):
+ pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
+ model,
+ "",
+ observation,
+ noise_shape,
+ action_cond_step=args.exe_steps,
+ ddim_steps=args.ddim_steps,
+ ddim_eta=args.ddim_eta,
+ unconditional_guidance_scale=args.
+ unconditional_guidance_scale,
+ fs=model_input_fs,
+ text_input=False,
+ timestep_spacing=args.timestep_spacing,
+ guidance_rescale=args.guidance_rescale)
+
+ with profiler.profile_section("update_state_queues"):
+ for step_idx in range(args.exe_steps):
+ obs_update = {
+ 'observation.images.top':
+ pred_videos_1[0][:, step_idx:step_idx + 1].permute(1, 0, 2, 3),
+ 'observation.state':
+ torch.zeros_like(pred_states[0][step_idx:step_idx + 1]) if
+ args.zero_pred_state else pred_states[0][step_idx:step_idx + 1],
+ 'action':
+ torch.zeros_like(pred_actions[0][-1:])
+ }
+ obs_update['observation.state'][:, ori_state_dim:] = 0.0
+ cond_obs_queues = populate_queues(cond_obs_queues,
+ obs_update)
+
+                    # Log the imagined decision-making videos to TensorBoard
+ with profiler.profile_section("save_results"):
+ sample_tag = f"{args.dataset}-vid{sample['videoid']}-dm-fs-{fs}/itr-{itr}"
+ log_to_tensorboard(writer,
+ pred_videos_0,
+ sample_tag,
+ fps=args.save_fps)
+                        # Log environment changes from world-model interaction to TensorBoard
+ sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/itr-{itr}"
+ log_to_tensorboard(writer,
+ pred_videos_1,
+ sample_tag,
+ fps=args.save_fps)
+
+                        # Save the imagined decision-making videos
+ sample_video_file = f'{video_save_dir}/dm/{fs}/itr-{itr}.mp4'
+ save_results(pred_videos_0.cpu(),
+ sample_video_file,
+ fps=args.save_fps)
+                        # Save videos of environment changes from world-model interaction
+ sample_video_file = f'{video_save_dir}/wm/{fs}/itr-{itr}.mp4'
+ save_results(pred_videos_1.cpu(),
+ sample_video_file,
+ fps=args.save_fps)
+
+ print('>' * 24)
+ # Collect the result of world-model interactions
+ wm_video.append(pred_videos_1[:, :, :args.exe_steps].cpu())
+
+ profiler.record_memory(f"iter_{itr}_end")
+ profiler.step_profiler()
+
+ full_video = torch.cat(wm_video, dim=2)
+ sample_tag = f"{args.dataset}-vid{sample['videoid']}-wd-fs-{fs}/full"
+ log_to_tensorboard(writer,
+ full_video,
+ sample_tag,
+ fps=args.save_fps)
+ sample_full_video_file = f"{video_save_dir}/../{sample['videoid']}_full_fs{fs}.mp4"
+ save_results(full_video, sample_full_video_file, fps=args.save_fps)
+
+ # Save profiling results
+ profiler.save_results()
+
+
+def get_parser():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--savedir",
+ type=str,
+ default=None,
+ help="Path to save the results.")
+ parser.add_argument("--ckpt_path",
+ type=str,
+ default=None,
+ help="Path to the model checkpoint.")
+ parser.add_argument("--config",
+ type=str,
+ help="Path to the model checkpoint.")
+ parser.add_argument(
+ "--prompt_dir",
+ type=str,
+ default=None,
+ help="Directory containing videos and corresponding prompts.")
+ parser.add_argument("--dataset",
+ type=str,
+ default=None,
+ help="the name of dataset to test")
+ parser.add_argument(
+ "--ddim_steps",
+ type=int,
+ default=50,
+ help="Number of DDIM steps. If non-positive, DDPM is used instead.")
+ parser.add_argument(
+ "--ddim_eta",
+ type=float,
+ default=1.0,
+ help="Eta for DDIM sampling. Set to 0.0 for deterministic results.")
+ parser.add_argument("--bs",
+ type=int,
+ default=1,
+ help="Batch size for inference. Must be 1.")
+ parser.add_argument("--height",
+ type=int,
+ default=320,
+ help="Height of the generated images in pixels.")
+ parser.add_argument("--width",
+ type=int,
+ default=512,
+ help="Width of the generated images in pixels.")
+ parser.add_argument(
+ "--frame_stride",
+ type=int,
+ nargs='+',
+ required=True,
+        help=
+        "Frame stride control for the 256 model (larger -> larger motion); FPS control for the 512 or 1024 model (smaller -> larger motion)."
+ )
+ parser.add_argument(
+ "--unconditional_guidance_scale",
+ type=float,
+ default=1.0,
+ help="Scale for classifier-free guidance during sampling.")
+ parser.add_argument("--seed",
+ type=int,
+ default=123,
+ help="Random seed for reproducibility.")
+ parser.add_argument("--video_length",
+ type=int,
+ default=16,
+ help="Number of frames in the generated video.")
+ parser.add_argument("--num_generation",
+ type=int,
+ default=1,
+ help="seed for seed_everything")
+ parser.add_argument(
+ "--timestep_spacing",
+ type=str,
+ default="uniform",
+ help=
+ "Strategy for timestep scaling. See Table 2 in the paper: 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
+ )
+ parser.add_argument(
+ "--guidance_rescale",
+ type=float,
+ default=0.0,
+ help=
+ "Rescale factor for guidance as discussed in 'Common Diffusion Noise Schedules and Sample Steps are Flawed' (https://huggingface.co/papers/2305.08891)."
+ )
+ parser.add_argument(
+ "--perframe_ae",
+ action='store_true',
+ default=False,
+ help=
+ "Use per-frame autoencoder decoding to reduce GPU memory usage. Recommended for models with resolutions like 576x1024."
+ )
+ parser.add_argument(
+ "--n_action_steps",
+ type=int,
+ default=16,
+ help="num of samples per prompt",
+ )
+ parser.add_argument(
+ "--exe_steps",
+ type=int,
+ default=16,
+ help="num of samples to execute",
+ )
+ parser.add_argument(
+ "--n_iter",
+ type=int,
+ default=40,
+ help="num of iteration to interact with the world model",
+ )
+ parser.add_argument("--zero_pred_state",
+ action='store_true',
+ default=False,
+ help="not using the predicted states as comparison")
+ parser.add_argument("--save_fps",
+ type=int,
+ default=8,
+ help="fps for the saving video")
+ # Profiling arguments
+ parser.add_argument(
+ "--profile",
+ action='store_true',
+ default=False,
+ help="Enable performance profiling (macro and operator-level analysis)."
+ )
+ parser.add_argument(
+ "--profile_output_dir",
+ type=str,
+ default=None,
+ help="Directory to save profiling results. Defaults to {savedir}/profile_output."
+ )
+ parser.add_argument(
+ "--profile_iterations",
+ type=int,
+ default=3,
+ help="Number of iterations to run PyTorch profiler's active phase for operator-level analysis."
+ )
+ return parser
+
+
+if __name__ == '__main__':
+ parser = get_parser()
+ args = parser.parse_args()
+ seed = args.seed
+ if seed < 0:
+ seed = random.randint(0, 2**31)
+ seed_everything(seed)
+
+ # Initialize profiler
+ profile_output_dir = args.profile_output_dir
+ if profile_output_dir is None:
+ profile_output_dir = os.path.join(args.savedir, "profile_output")
+ init_profiler(enabled=args.profile, output_dir=profile_output_dir)
+
+ rank, gpu_num = 0, 1
+ run_inference(args, gpu_num, rank)
diff --git a/scripts/run_base_model_inference.sh b/scripts/run_base_model_inference.sh
new file mode 100644
index 0000000..7740d14
--- /dev/null
+++ b/scripts/run_base_model_inference.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+model_name=base_model
+ckpt=/path/to/base/model
+config=configs/inference/base_model_inference.yaml
+res_dir="/path/to/result/directory"
+seed=123
+
+CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/base_model_inference.py \
+--seed ${seed} \
+--ckpt_path $ckpt \
+--config $config \
+--savedir "${res_dir}/videos" \
+--bs 1 --height 320 --width 512 \
+--unconditional_guidance_scale 1.0 \
+--ddim_steps 16 \
+--ddim_eta 1.0 \
+--prompt_dir "/path/to/examples/base_model_prompts" \
+--text_input \
+--video_length 16 \
+--timestep_spacing 'uniform_trailing' \
+--guidance_rescale 0.7 \
+--perframe_ae
diff --git a/scripts/run_real_eval_server.sh b/scripts/run_real_eval_server.sh
new file mode 100644
index 0000000..f3c1947
--- /dev/null
+++ b/scripts/run_real_eval_server.sh
@@ -0,0 +1,26 @@
+model_name=testing
+ckpt=/path/to/model/checkpoint
+config=configs/inference/world_model_decision_making.yaml
+seed=123
+res_dir="path/to/results/directory"
+datasets=(
+ "unitree_g1_pack_camera"
+)
+
+
+for dataset in "${datasets[@]}"; do
+ CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/real_eval_server.py \
+ --seed ${seed} \
+ --ckpt_path $ckpt \
+ --config $config \
+ --savedir "${res_dir}/${dataset}/${model_name}/videos" \
+ --bs 1 --height 320 --width 512 \
+ --unconditional_guidance_scale 1.0 \
+ --ddim_steps 16 \
+ --ddim_eta 1.0 \
+ --video_length 16 \
+ --frame_stride 2 \
+ --timestep_spacing 'uniform_trailing' \
+ --guidance_rescale 0.7 \
+ --perframe_ae
+done
diff --git a/scripts/run_world_model_interaction.sh b/scripts/run_world_model_interaction.sh
new file mode 100644
index 0000000..e8fc586
--- /dev/null
+++ b/scripts/run_world_model_interaction.sh
@@ -0,0 +1,42 @@
+model_name=testing
+ckpt=/path/to/model/checkpoint
+config=configs/inference/world_model_interaction.yaml
+seed=123
+res_dir="/path/to/result/directory"
+
+datasets=(
+ "unitree_z1_stackbox"
+ "unitree_z1_dual_arm_stackbox"
+ "unitree_z1_dual_arm_stackbox_v2"
+ "unitree_z1_dual_arm_cleanup_pencils"
+ "unitree_g1_pack_camera"
+)
+
+n_iters=(12 7 11 8 11)
+fses=(4 4 4 4 6)
+
+for i in "${!datasets[@]}"; do
+ dataset=${datasets[$i]}
+ n_iter=${n_iters[$i]}
+ fs=${fses[$i]}
+
+ CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
+ --seed ${seed} \
+ --ckpt_path $ckpt \
+ --config $config \
+ --savedir "${res_dir}/${model_name}/${dataset}" \
+ --bs 1 --height 320 --width 512 \
+ --unconditional_guidance_scale 1.0 \
+ --ddim_steps 50 \
+ --ddim_eta 1.0 \
+ --prompt_dir "/path/to/unifolm-world-model-action/examples/world_model_interaction_prompts" \
+ --dataset ${dataset} \
+ --video_length 16 \
+ --frame_stride ${fs} \
+ --n_action_steps 16 \
+ --exe_steps 16 \
+ --n_iter ${n_iter} \
+ --timestep_spacing 'uniform_trailing' \
+ --guidance_rescale 0.7 \
+ --perframe_ae
+done
diff --git a/scripts/train.sh b/scripts/train.sh
new file mode 100644
index 0000000..d37ee79
--- /dev/null
+++ b/scripts/train.sh
@@ -0,0 +1,32 @@
+# NCCL configuration
+# export NCCL_DEBUG=debug
+# export NCCL_IB_DISABLE=0
+# export NCCL_IB_GID_INDEX=3
+# export NCCL_NET_GDR_LEVEL=3
+# export CUDA_LAUNCH_BLOCKING=1
+
+# export NCCL_TOPO_FILE=/tmp/topo.txt
+# export MASTER_ADDR="master.ip."
+# export MASTER_PORT=12366
+
+
+# args
+name="experiment_name"
+config_file=configs/train/config.yaml
+
+# save root dir for logs, checkpoints, tensorboard record, etc.
+save_root="/path/to/savedir"
+
+mkdir -p $save_root/$name
+
+## run
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch \
+--nproc_per_node=8 --nnodes=1 --master_addr=127.0.0.1 --master_port=12366 --node_rank=0 \
+./scripts/trainer.py \
+--base $config_file \
+--train \
+--name $name \
+--logdir $save_root \
+--devices 8 \
+--total_gpus=8 \
+lightning.trainer.num_nodes=1
diff --git a/scripts/trainer.py b/scripts/trainer.py
new file mode 100644
index 0000000..87c6820
--- /dev/null
+++ b/scripts/trainer.py
@@ -0,0 +1,214 @@
+import argparse, os, datetime
+import pytorch_lightning as pl
+import torch
+
+from omegaconf import OmegaConf
+from transformers import logging as transf_logging
+from pytorch_lightning import seed_everything
+from pytorch_lightning.trainer import Trainer
+
+from unifolm_wma.utils.utils import instantiate_from_config
+from unifolm_wma.utils.train import get_trainer_callbacks, get_trainer_logger, get_trainer_strategy
+from unifolm_wma.utils.train import set_logger, init_workspace, load_checkpoints, get_num_parameters
+
+
+def get_parser(**parser_kwargs):
+ parser = argparse.ArgumentParser(**parser_kwargs)
+ parser.add_argument("--seed",
+ "-s",
+ type=int,
+ default=20250912,
+ help="seed for seed_everything")
+ parser.add_argument("--name",
+ "-n",
+ type=str,
+ default="",
+ help="experiment name, as saving folder")
+ parser.add_argument(
+ "--base",
+ "-b",
+ nargs="*",
+ metavar="base_config.yaml",
+ help="paths to base configs. Loaded from left-to-right.",
+ default=list())
+ parser.add_argument("--train",
+ "-t",
+ action='store_true',
+ default=False,
+ help='train')
+ parser.add_argument("--val",
+ "-v",
+ action='store_true',
+ default=False,
+ help='val')
+ parser.add_argument("--test",
+ action='store_true',
+ default=False,
+ help='test')
+ parser.add_argument("--logdir",
+ "-l",
+ type=str,
+ default="logs",
+ help="directory for logging dat shit")
+ parser.add_argument("--auto_resume",
+ action='store_true',
+ default=False,
+ help="resume from full-info checkpoint")
+ parser.add_argument("--auto_resume_weight_only",
+ action='store_true',
+ default=False,
+ help="resume from weight-only checkpoint")
+ parser.add_argument("--debug",
+ "-d",
+ action='store_true',
+ default=False,
+ help="enable post-mortem debugging")
+ return parser
+
+
+def get_nondefault_trainer_args(args):
+ parser = argparse.ArgumentParser()
+ parser = Trainer.add_argparse_args(parser)
+ default_trainer_args = parser.parse_args([])
+ return sorted(k for k in vars(default_trainer_args)
+ if getattr(args, k) != getattr(default_trainer_args, k))
+
+
+if __name__ == "__main__":
+ now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    local_rank = int(os.environ.get('LOCAL_RANK', 0))
+    global_rank = int(os.environ.get('RANK', 0))
+    num_rank = int(os.environ.get('WORLD_SIZE', 1))
+
+ parser = get_parser()
+ # Extends existing argparse by default Trainer attributes
+ parser = Trainer.add_argparse_args(parser)
+ args, unknown = parser.parse_known_args()
+ transf_logging.set_verbosity_error()
+ seed_everything(args.seed)
+
+ configs = [OmegaConf.load(cfg) for cfg in args.base]
+ cli = OmegaConf.from_dotlist(unknown)
+ config = OmegaConf.merge(*configs, cli)
+ lightning_config = config.pop("lightning", OmegaConf.create())
+ trainer_config = lightning_config.get("trainer", OmegaConf.create())
+
+ # Setup workspace directories
+ workdir, ckptdir, cfgdir, loginfo = init_workspace(args.name, args.logdir,
+ config,
+ lightning_config,
+ global_rank)
+ logger = set_logger(
+ logfile=os.path.join(loginfo, 'log_%d:%s.txt' % (global_rank, now)))
+ logger.info("@lightning version: %s [>=1.8 required]" % (pl.__version__))
+ logger.info("***** Configing Model *****")
+ config.model.params.logdir = workdir
+ model = instantiate_from_config(config.model)
+ # Load checkpoints
+ model = load_checkpoints(model, config.model)
+
+    # Register the schedule again so that ZTSNR (zero terminal SNR) takes effect
+ if model.rescale_betas_zero_snr:
+ model.register_schedule(given_betas=model.given_betas,
+ beta_schedule=model.beta_schedule,
+ timesteps=model.timesteps,
+ linear_start=model.linear_start,
+ linear_end=model.linear_end,
+ cosine_s=model.cosine_s)
+
+ # Update trainer config
+ for k in get_nondefault_trainer_args(args):
+ trainer_config[k] = getattr(args, k)
+
+ num_nodes = trainer_config.num_nodes
+ ngpu_per_node = trainer_config.devices
+ logger.info(f"Running on {num_rank}={num_nodes}x{ngpu_per_node} GPUs")
+
+ # Setup learning rate
+ base_lr = config.model.base_learning_rate
+ bs = config.data.params.batch_size
+ if getattr(config.model, 'scale_lr', True):
+ model.learning_rate = num_rank * bs * base_lr
+ else:
+ model.learning_rate = base_lr
+
+ logger.info("***** Configing Data *****")
+ data = instantiate_from_config(config.data)
+ data.setup()
+ for k in data.train_datasets:
+ logger.info(
+ f"{k}, {data.train_datasets[k].__class__.__name__}, {len(data.train_datasets[k])}"
+ )
+ if hasattr(data, 'val_datasets'):
+ for k in data.val_datasets:
+ logger.info(
+ f"{k}, {data.val_datasets[k].__class__.__name__}, {len(data.val_datasets[k])}"
+ )
+
+ for item in unknown:
+ if item.startswith('--total_gpus'):
+ num_gpus = int(item.split('=')[-1])
+ break
+ model.datasets_len = len(data)
+
+ logger.info("***** Configing Trainer *****")
+ if "accelerator" not in trainer_config:
+ trainer_config["accelerator"] = "gpu"
+
+ # Setup trainer args: pl-logger and callbacks
+ trainer_kwargs = dict()
+ trainer_kwargs["num_sanity_val_steps"] = 0
+ logger_cfg = get_trainer_logger(lightning_config, workdir, args.debug)
+ trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
+
+ # Setup callbacks
+ callbacks_cfg = get_trainer_callbacks(lightning_config, config, workdir,
+ ckptdir, logger)
+ trainer_kwargs["callbacks"] = [
+ instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
+ ]
+ strategy_cfg = get_trainer_strategy(lightning_config)
+ trainer_kwargs["strategy"] = strategy_cfg if type(
+ strategy_cfg) == str else instantiate_from_config(strategy_cfg)
+ trainer_kwargs['precision'] = lightning_config.get('precision', 32)
+ trainer_kwargs["sync_batchnorm"] = False
+
+ # Trainer config: others
+ trainer_args = argparse.Namespace(**trainer_config)
+ trainer = Trainer.from_argparse_args(trainer_args, **trainer_kwargs)
+
+ # Allow checkpointing via USR1
+ def melk(*args, **kwargs):
+ if trainer.global_rank == 0:
+ print("Summoning checkpoint.")
+ ckpt_path = os.path.join(ckptdir, "last_summoning.ckpt")
+ trainer.save_checkpoint(ckpt_path)
+
+ def divein(*args, **kwargs):
+ if trainer.global_rank == 0:
+ import pudb
+ pudb.set_trace()
+
+ import signal
+ signal.signal(signal.SIGUSR1, melk)
+ signal.signal(signal.SIGUSR2, divein)
+
+ # List the key model sizes
+ total_params = get_num_parameters(model)
+
+ logger.info("***** Running the Loop *****")
+ if args.train:
+ try:
+ if "strategy" in lightning_config and lightning_config[
+ 'strategy'].startswith('deepspeed'):
+ logger.info("")
+ if trainer_kwargs['precision'] == 16:
+ with torch.cuda.amp.autocast():
+ trainer.fit(model, data)
+ else:
+ trainer.fit(model, data)
+ else:
+ logger.info("")
+ trainer.fit(model, data)
+ except Exception:
+ raise
diff --git a/src/unifolm_wma/__init__.py b/src/unifolm_wma/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/unifolm_wma/data/base.py b/src/unifolm_wma/data/base.py
new file mode 100644
index 0000000..b10b13a
--- /dev/null
+++ b/src/unifolm_wma/data/base.py
@@ -0,0 +1,26 @@
+from abc import abstractmethod
+from torch.utils.data import IterableDataset
+
+
+class Txt2ImgIterableBaseDataset(IterableDataset):
+ '''
+ Define an interface to make the IterableDatasets for text2img data chainable
+ '''
+
+ def __init__(self, num_records=0, valid_ids=None, size=256):
+ super().__init__()
+ self.num_records = num_records
+ self.valid_ids = valid_ids
+ self.sample_ids = valid_ids
+ self.size = size
+
+ print(
+ f'{self.__class__.__name__} dataset contains {self.__len__()} examples.'
+ )
+
+ def __len__(self):
+ return self.num_records
+
+ @abstractmethod
+ def __iter__(self):
+ pass
diff --git a/src/unifolm_wma/data/normolize.py b/src/unifolm_wma/data/normolize.py
new file mode 100644
index 0000000..98f8d7e
--- /dev/null
+++ b/src/unifolm_wma/data/normolize.py
@@ -0,0 +1,230 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from torch import Tensor, nn
+from typing import Dict, List
+
+
+def create_stats_buffers(
+ shapes: Dict[str, List[int]],
+ modes: Dict[str, str],
+ stats: Dict[str, Dict[str, Tensor]] = None,
+) -> Dict[str, Dict[str, nn.ParameterDict]]:
+ """
+ Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
+ statistics.
+
+ Args: (see Normalize and Unnormalize)
+
+ Returns:
+ Dict: A Dictionary where keys are modalities and values are `nn.ParameterDict` containing
+        `nn.Parameters` set to `requires_grad=False`, so they are not updated during backpropagation.
+ """
+ stats_buffers = {}
+
+ for key, mode in modes.items():
+ assert mode in ["mean_std", "min_max"]
+
+ shape = tuple(shapes[key])
+
+ if "image" in key:
+ # sanity checks
+ assert len(
+ shape) == 3, f"number of dimensions of {key} != 3 ({shape=}"
+ c, h, w = shape
+ assert c < h and c < w, f"{key} is not channel first ({shape=})"
+ # override image shape to be invariant to height and width
+ shape = (c, 1, 1)
+
+ # Note: we initialize mean, std, min, max to infinity. They should be overwritten
+        # downstream by `stats` or `policy.load_state_dict`, as expected. During forward,
+ # we assert they are not infinity anymore.
+
+ if "action" in key:
+ target_key = "action"
+ elif "state" in key:
+ target_key = 'observation.state'
+ else:
+ target_key = key
+
+ buffer = {}
+ if mode == "mean_std":
+ mean = torch.ones(shape, dtype=torch.float32) * torch.inf
+ std = torch.ones(shape, dtype=torch.float32) * torch.inf
+ buffer = nn.ParameterDict({
+ "mean":
+ nn.Parameter(mean, requires_grad=False),
+ "std":
+ nn.Parameter(std, requires_grad=False),
+ })
+ elif mode == "min_max":
+ min = torch.ones(shape, dtype=torch.float32) * torch.inf
+ max = torch.ones(shape, dtype=torch.float32) * torch.inf
+ buffer = nn.ParameterDict({
+ "min":
+ nn.Parameter(min, requires_grad=False),
+ "max":
+ nn.Parameter(max, requires_grad=False),
+ })
+
+ if stats is not None:
+ # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
+ # tensors anywhere (for example, when we use the same stats for normalization and
+            # unnormalization).
+ if mode == "mean_std":
+ buffer["mean"].data = stats[target_key]["mean"].clone()
+ buffer["std"].data = stats[target_key]["std"].clone()
+ elif mode == "min_max":
+ buffer["min"].data = stats[target_key]["min"].clone()
+ buffer["max"].data = stats[target_key]["max"].clone()
+
+ stats_buffers[key] = buffer
+ return stats_buffers
+
+
+def _no_stats_error_str(name: str) -> str:
+ return (
+ f"`{name}` is infinity. You should either initialize with `stats` as an argument, or use a "
+ "pretrained model.")
+
+
+class Normalize(nn.Module):
+ """Normalizes data (e.g. "observation.image") for more stable and faster convergence during training."""
+
+ def __init__(
+ self,
+ shapes: Dict[str, List[int]],
+ modes: Dict[str, str],
+ stats: Dict[str, Dict[str, Tensor]] = None,
+ ):
+ """
+ Args:
+ shapes (Dict): A Dictionary where keys are input modalities (e.g. "observation.image") and values
+                are their shapes (e.g. `[3,96,96]`). These shapes are used to create the tensor buffer containing
+ mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
+ is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
+ modes (Dict): A Dictionary where keys are output modalities (e.g. "observation.image") and values
+ are their normalization modes among:
+ - "mean_std": subtract the mean and divide by standard deviation.
+ - "min_max": map to [-1, 1] range.
+ stats (Dict, optional): A Dictionary where keys are output modalities (e.g. "observation.image")
+ and values are Dictionaries of statistic types and their values (e.g.
+                `{"mean": torch.randn(3,1,1), "std": torch.randn(3,1,1)}`). If provided, as expected when
+                training the model for the first time, these statistics overwrite the default buffers. If
+                not provided, as expected for finetuning or evaluation, the default buffers should be
+                overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
+                dataset is not needed to get the stats, since they are already in the policy state_dict.
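+
+        Example (an illustrative sketch; keys, shapes and stats are hypothetical):
+        ```python
+        stats = {"action": {"min": torch.zeros(7), "max": torch.ones(7)}}
+        normalize = Normalize({"action": [7]}, {"action": "min_max"}, stats)
+        batch = normalize({"action": torch.rand(32, 7)})  # "action" now lies in [-1, 1]
+        ```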
+ """
+ super().__init__()
+ self.shapes = shapes
+ self.modes = modes
+ self.stats = stats
+ stats_buffers = create_stats_buffers(shapes, modes, stats)
+ for key, buffer in stats_buffers.items():
+ setattr(self, "buffer_" + key.replace(".", "_"), buffer)
+
+ @torch.no_grad()
+ def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
+ for key, mode in self.modes.items():
+ if key not in batch:
+ continue
+
+ buffer = getattr(self, "buffer_" + key.replace(".", "_"))
+
+ if mode == "mean_std":
+ mean = buffer["mean"]
+ std = buffer["std"]
+
+ assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
+ assert not torch.isinf(std).any(), _no_stats_error_str("std")
+ batch[key] = (batch[key] - mean) / (std + 1e-8)
+ elif mode == "min_max":
+ min = buffer["min"]
+ max = buffer["max"]
+
+ assert not torch.isinf(min).any(), _no_stats_error_str("min")
+ assert not torch.isinf(max).any(), _no_stats_error_str("max")
+ # normalize to [0,1]
+ batch[key] = (batch[key] - min) / (max - min + 1e-8)
+ # normalize to [-1, 1]
+ batch[key] = batch[key] * 2 - 1
+ else:
+ raise ValueError(mode)
+ return batch
+
+
+class Unnormalize(nn.Module):
+ """
+ Similar to `Normalize` but unnormalizes output data (e.g. `{"action": torch.randn(b,c)}`) in their
+ original range used by the environment.
+ """
+
+ def __init__(
+ self,
+ shapes: Dict[str, List[int]],
+ modes: Dict[str, str],
+ stats: Dict[str, Dict[str, Tensor]] = None,
+ ):
+ """
+ Args:
+ shapes (Dict): A Dictionary where keys are input modalities (e.g. "observation.image") and values
+                are their shapes (e.g. `[3,96,96]`). These shapes are used to create the tensor buffer containing
+ mean, std, min, max statistics. If the provided `shapes` contain keys related to images, the shape
+ is adjusted to be invariant to height and width, assuming a channel-first (c, h, w) format.
+ modes (Dict): A Dictionary where keys are output modalities (e.g. "observation.image") and values
+ are their normalization modes among:
+ - "mean_std": subtract the mean and divide by standard deviation.
+ - "min_max": map to [-1, 1] range.
+ stats (Dict, optional): A Dictionary where keys are output modalities (e.g. "observation.image")
+ and values are Dictionaries of statistic types and their values (e.g.
+                `{"mean": torch.randn(3,1,1), "std": torch.randn(3,1,1)}`). If provided, as expected when
+                training the model for the first time, these statistics overwrite the default buffers. If
+                not provided, as expected for finetuning or evaluation, the default buffers should be
+                overwritten by a call to `policy.load_state_dict(state_dict)`. That way, initializing the
+                dataset is not needed to get the stats, since they are already in the policy state_dict.
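+
+        Example (an illustrative sketch; keys, shapes and stats are hypothetical):
+        ```python
+        stats = {"action": {"min": torch.zeros(7), "max": torch.ones(7)}}
+        unnormalize = Unnormalize({"action": [7]}, {"action": "min_max"}, stats)
+        batch = unnormalize({"action": torch.rand(32, 7) * 2 - 1})  # back to the original range
+        ```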
+ """
+ super().__init__()
+ self.shapes = shapes
+ self.modes = modes
+ self.stats = stats
+ stats_buffers = create_stats_buffers(shapes, modes, stats)
+ for key, buffer in stats_buffers.items():
+ setattr(self, "buffer_" + key.replace(".", "_"), buffer)
+
+ @torch.no_grad()
+ def forward(self, batch: Dict[str, Tensor]) -> Dict[str, Tensor]:
+ for key, mode in self.modes.items():
+ if key not in batch:
+ continue
+
+ buffer = getattr(self, "buffer_" + key.replace(".", "_"))
+
+ if mode == "mean_std":
+ mean = buffer["mean"]
+ std = buffer["std"]
+ assert not torch.isinf(mean).any(), _no_stats_error_str("mean")
+ assert not torch.isinf(std).any(), _no_stats_error_str("std")
+ batch[key] = batch[key] * std + mean
+ elif mode == "min_max":
+ min = buffer["min"]
+ max = buffer["max"]
+ assert not torch.isinf(min).any(), _no_stats_error_str("min")
+ assert not torch.isinf(max).any(), _no_stats_error_str("max")
+ batch[key] = (batch[key] + 1) / 2
+ batch[key] = batch[key] * (max - min) + min
+ else:
+ raise ValueError(mode)
+ return batch
diff --git a/src/unifolm_wma/data/utils.py b/src/unifolm_wma/data/utils.py
new file mode 100644
index 0000000..55675ae
--- /dev/null
+++ b/src/unifolm_wma/data/utils.py
@@ -0,0 +1,60 @@
+import torch
+
+from huggingface_hub import hf_hub_download
+from typing import Dict
+from pathlib import Path
+from safetensors.torch import load_file
+
+def unflatten_dict(d, sep="/"):
+    """Convert a flat Dict with separator-joined keys into a nested Dict,
+    e.g. `{"action/mean": x}` becomes `{"action": {"mean": x}}`."""
+    outdict = {}
+ for key, value in d.items():
+ parts = key.split(sep)
+ d = outdict
+ for part in parts[:-1]:
+ if part not in d:
+ d[part] = {}
+ d = d[part]
+ d[parts[-1]] = value
+ return outdict
+
+
+def load_episode_data_index(repo_id, version, root) -> Dict[str, torch.Tensor]:
+ """episode_data_index contains the range of indices for each episode
+
+ Example:
+ ```python
+ from_id = episode_data_index["from"][episode_id].item()
+ to_id = episode_data_index["to"][episode_id].item()
+ episode_frames = [dataset[i] for i in range(from_id, to_id)]
+ ```
+ """
+ if root is not None:
+ path = Path(
+ root) / repo_id / "meta_data" / "episode_data_index.safetensors"
+ else:
+ path = hf_hub_download(repo_id,
+ "meta_data/episode_data_index.safetensors",
+ repo_type="dataset",
+ revision=version)
+
+ return load_file(path)
+
+
+def load_stats(repo_id, version, root) -> Dict[str, Dict[str, torch.Tensor]]:
+ """stats contains the statistics per modality computed over the full dataset, such as max, min, mean, std
+
+ Example:
+ ```python
+ normalized_action = (action - stats["action"]["mean"]) / stats["action"]["std"]
+ ```
+ """
+ if root is not None:
+ path = Path(root) / repo_id / "meta_data" / "stats.safetensors"
+ else:
+ path = hf_hub_download(repo_id,
+ "meta_data/stats.safetensors",
+ repo_type="dataset",
+ revision=version)
+
+ stats = load_file(path)
+ return unflatten_dict(stats)
diff --git a/src/unifolm_wma/data/wma_data.py b/src/unifolm_wma/data/wma_data.py
new file mode 100644
index 0000000..bf64b32
--- /dev/null
+++ b/src/unifolm_wma/data/wma_data.py
@@ -0,0 +1,408 @@
+import torch
+import os
+import random
+import pandas as pd
+import h5py
+
+from decord import VideoReader, cpu
+from torch.utils.data import Dataset
+from torchvision import transforms
+from pathlib import Path
+
+from unifolm_wma.data.utils import load_stats
+from unifolm_wma.data.normolize import Normalize, Unnormalize
+
+
+class WMAData(Dataset):
+ """
+ Assuming the following dataset structure:
+        dataset_dir/
+        ├── videos
+        │   ├── dataset_name
+        │   │   └── camera_view_dir
+        │   │       ├── 0.mp4
+        │   │       ├── 1.mp4
+        │   │       └── ...
+        │   └── ...
+        ├── transitions
+        │   └── dataset_name
+        │       ├── meta_data
+        │       ├── 0.h5
+        │       ├── 1.h5
+        │       └── ...
+        └── dataset_name.csv
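+
+    Example (an illustrative sketch; paths and dataset name are hypothetical):
+    ```python
+    dataset = WMAData(
+        meta_path="dataset_dir/dataset_name.csv",
+        data_dir="dataset_dir",
+        transition_dir="dataset_dir/transitions",
+        dataset_name="dataset_name",
+        video_length=16,
+        resolution=[256, 512],
+    )
+    sample = dataset[0]  # dict with 'video', 'instruction', 'observation.state', 'action', ...
+    ```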
+ """
+
+ def __init__(
+ self,
+ meta_path,
+ data_dir,
+ subsample=None,
+ video_length=16,
+ resolution=[256, 512],
+ frame_stride=1,
+ frame_stride_min=1,
+ spatial_transform=None,
+ crop_resolution=None,
+ fps_max=None,
+ load_raw_resolution=False,
+ fixed_fps=None,
+ random_fs=False,
+ cond_robot_label_prob=0.0,
+ transition_dir=None,
+ dataset_name=None,
+ normalization_mode='min_max',
+ individual_normalization=False,
+ n_obs_steps=1,
+ max_action_dim=7,
+ max_state_dim=7,
+ ):
+ self.meta_path = meta_path
+ self.data_dir = data_dir
+ self.subsample = subsample
+ self.video_length = video_length
+ self.resolution = [resolution, resolution] if isinstance(
+ resolution, int) else resolution
+ self.fps_max = fps_max
+ self.frame_stride = frame_stride
+ self.frame_stride_min = frame_stride_min
+ self.fixed_fps = fixed_fps
+ self.load_raw_resolution = load_raw_resolution
+ self.random_fs = random_fs
+ self.cond_robot_label_prob = cond_robot_label_prob
+ self.transition_dir = transition_dir
+ self.dataset_name = dataset_name
+ self.max_action_dim = max_action_dim
+ self.max_state_dim = max_state_dim
+
+ self._load_metadata()
+ if spatial_transform is not None:
+ if spatial_transform == "random_crop":
+ self.spatial_transform = transforms.RandomCrop(crop_resolution)
+ elif spatial_transform == "center_crop":
+ self.spatial_transform = transforms.Compose([
+ transforms.CenterCrop(resolution),
+ ])
+ elif spatial_transform == "resize_center_crop":
+ self.spatial_transform = transforms.Compose([
+ transforms.Resize(min(self.resolution)),
+ transforms.CenterCrop(self.resolution),
+ ])
+ elif spatial_transform == "resize":
+ self.spatial_transform = transforms.Resize(self.resolution)
+ else:
+ raise NotImplementedError
+ else:
+ self.spatial_transform = None
+
+ self.normalization_mode = normalization_mode
+ self.individual_normalization = individual_normalization
+ self.n_obs_steps = n_obs_steps
+ self._load_stats()
+ if individual_normalization:
+ self._init_normalizers()
+
+ def _load_metadata(self):
+ metadata = pd.read_csv(self.meta_path, dtype=str)
+ if self.subsample is not None:
+ metadata = metadata.sample(self.subsample, random_state=0)
+
+ self.metadata = metadata
+        # drop rows that contain NaN values
+ self.metadata.dropna(inplace=True)
+ print(
+ f">>> {metadata['data_dir'].iloc[0]}: {len(metadata)} data samples loaded."
+ )
+
+ def _load_stats(self):
+ self.stats = load_stats(self.dataset_name, None, self.transition_dir)
+ print(f">>> {self.metadata['data_dir'].iloc[0]}: data stats loaded.")
+
+ def _init_normalizers(self):
+ shape_dict = {
+ 'pre_action': [self.stats['action']['max'].shape[-1]],
+ 'action': [self.stats['action']['max'].shape[-1]],
+ 'observation.state':
+ [self.stats['observation.state']['max'].shape[-1]],
+ 'next.state': [self.stats['observation.state']['max'].shape[-1]]
+ }
+ normalization_mode_dict = {
+ 'pre_action': self.normalization_mode,
+ 'action': self.normalization_mode,
+ 'observation.state': self.normalization_mode,
+ 'next.state': self.normalization_mode
+ }
+ self.normalizer = Normalize(shape_dict, normalization_mode_dict,
+ self.stats)
+ self.unnormalizer = Unnormalize(shape_dict, normalization_mode_dict,
+ self.stats)
+ print(
+            f">>> {self.metadata['data_dir'].iloc[0]}: normalizer initialized.")
+
+ def _get_video_path(self, sample):
+ rel_video_fp = os.path.join(sample['data_dir'],
+ str(sample['videoid']) + '.mp4')
+ full_video_fp = os.path.join(self.data_dir, 'videos', rel_video_fp)
+ return full_video_fp
+
+ def _get_transition_path(self, sample):
+ data_dir = Path(sample['data_dir'])
+ if self.dataset_name == data_dir.name:
+ rel_transition_fp = os.path.join(str(data_dir),
+ str(sample['videoid']) + '.h5')
+ else:
+ rel_transition_fp = os.path.join(str(data_dir.parent),
+ str(sample['videoid']) + '.h5')
+ full_transition_fp = os.path.join(self.data_dir, 'transitions',
+ rel_transition_fp)
+ return full_transition_fp
+
+ def get_uni_vec(self, action_state_dict, action_type, state_type):
+ if 'pre_action' in action_state_dict:
+ action_state_dict['pre_action'], _ = self._map_to_uni_action(
+ action_state_dict['pre_action'], action_type)
+ if 'action' in action_state_dict:
+ action_state_dict['action'], action_state_dict[
+ 'action_mask'] = self._map_to_uni_action(
+ action_state_dict['action'], action_type)
+ if 'observation.state' in action_state_dict:
+ action_state_dict['observation.state'], _ = self._map_to_uni_state(
+ action_state_dict['observation.state'], state_type)
+ if 'next.state' in action_state_dict:
+ action_state_dict['next.state'], action_state_dict[
+ 'state_mask'] = self._map_to_uni_state(
+ action_state_dict['next.state'], state_type)
+ return action_state_dict
+
+ def _map_to_uni_action(self, action, action_type):
+ action_dim = action.shape[-1]
+ uni_action = torch.nn.functional.pad(
+ action, (0, self.max_action_dim - action_dim),
+ mode='constant',
+ value=0)
+ uni_action_mask = torch.zeros_like(uni_action)
+ uni_action_mask[:, :action_dim] = 1
+ return uni_action, uni_action_mask
+
+ def _map_to_uni_state(self, state, state_type):
+ state_dim = state.shape[-1]
+ uni_state = torch.nn.functional.pad(
+ state, (0, self.max_state_dim - state_dim),
+ mode='constant',
+ value=0)
+ uni_state_mask = torch.zeros_like(uni_state)
+ uni_state_mask[:, :state_dim] = 1
+ return uni_state, uni_state_mask
+
+ def __getitem__(self, index):
+
+ if self.random_fs:
+ frame_stride = random.randint(self.frame_stride_min,
+ self.frame_stride)
+ else:
+ frame_stride = self.frame_stride
+
+ # Get frames until success
+ while True:
+ index = index % len(self.metadata)
+ sample = self.metadata.iloc[index]
+ video_path = self._get_video_path(sample)
+
+ instruction = sample['instruction']
+ if self.cond_robot_label_prob > 0.0 and random.random(
+ ) < self.cond_robot_label_prob:
+ if sample['embodiment'] != 'x':
+ instruction = sample['embodiment'] + ' [SEP] ' + sample[
+ 'instruction']
+ try:
+ if self.load_raw_resolution:
+ video_reader = VideoReader(video_path, ctx=cpu(0))
+ else:
+ video_reader = VideoReader(video_path,
+ ctx=cpu(0),
+ width=530,
+ height=300)
+ if len(video_reader) < self.video_length:
+ print(
+                        f">>> Video length ({len(video_reader)}) is smaller than target length ({self.video_length})"
+ )
+ index += 1
+ continue
+            except Exception:
+ index += 1
+ print(f">>> Error: load video failed! path = {video_path}")
+ continue
+
+ fps_ori = video_reader.get_avg_fps()
+ if self.fixed_fps is not None:
+ frame_stride = int(frame_stride *
+ (1.0 * fps_ori / self.fixed_fps))
+
+ # To avoid extreme cases when fixed_fps is used
+ frame_stride = max(frame_stride, 1)
+
+ # Get valid range (adapting case by case)
+ required_frame_num = frame_stride * (self.video_length - 1) + 1
+ frame_num = len(video_reader)
+ if frame_num < required_frame_num:
+ # Drop extra samples if fixed fps is required
+ if self.fixed_fps is not None and frame_num < required_frame_num * 0.5:
+ index += 1
+ continue
+ else:
+ frame_stride = frame_num // self.video_length
+ required_frame_num = frame_stride * (self.video_length -
+ 1) + 1
+
+ # Select a random clip
+ random_range = frame_num - required_frame_num
+ start_idx = random.randint(
+ 0, random_range -
+ frame_stride) if random_range - frame_stride > 0 else 0
+
+ # Calculate frame indices
+ frame_indices = [
+ start_idx + frame_stride * i for i in range(self.video_length)
+ ]
+ try:
+ next_frame_indices = [
+ idx + frame_stride for idx in frame_indices
+ ]
+ frames = video_reader.get_batch(next_frame_indices)
+ break
+            except Exception:
+ print(
+ f">>> Error: Get frames failed! path = {video_path}; [max_ind vs frame_total:{max(frame_indices)} / {frame_num}]"
+ )
+ index += 1
+ continue
+
+ # Load transition data
+ transition_path = self._get_transition_path(sample)
+ with h5py.File(transition_path, 'r') as h5f:
+ transition_dict = {}
+ for key in h5f.keys():
+ transition_dict[key] = torch.tensor(h5f[key][()])
+ for key in h5f.attrs.keys():
+ transition_dict[key] = h5f.attrs[key]
+
+ # Load observable states
+ if start_idx < self.n_obs_steps - 1:
+ state_indices = list(range(0, start_idx + 1))
+ states = transition_dict['observation.state'][state_indices, :]
+ num_padding = self.n_obs_steps - 1 - start_idx
+ first_slice = states[0:1, :] # (t, d)
+ padding = first_slice.repeat(num_padding, 1)
+ states = torch.cat((padding, states), dim=0)
+ else:
+ state_indices = list(
+ range(start_idx - self.n_obs_steps + 1, start_idx + 1))
+ states = transition_dict['observation.state'][state_indices, :]
+ assert states.shape[
+ 0] == self.n_obs_steps, '>>> Do not have enough previous states as observation.'
+
+ # Load observable actions
+ if start_idx < self.n_obs_steps:
+ pre_action_indices = list(range(0, start_idx))
+ pre_actions = transition_dict['action'][pre_action_indices, :]
+ num_padding = self.n_obs_steps - start_idx
+ first_slice = torch.zeros_like(transition_dict['action'][:1, :])
+ padding = first_slice.repeat(num_padding, 1)
+ pre_actions = torch.cat((padding, pre_actions), dim=0)
+ else:
+ pre_action_indices = list(
+ range(start_idx - self.n_obs_steps, start_idx))
+ pre_actions = transition_dict['action'][pre_action_indices, :]
+ assert pre_actions.shape[
+ 0] == self.n_obs_steps, ">>> Do not have enough previous actions as observation"
+
+ # Load future actions
+ actions = transition_dict['action'][frame_indices, :]
+ # Load future states
+ next_state_indices = [idx + frame_stride for idx in frame_indices]
+ next_states = transition_dict['observation.state'][
+ next_state_indices, :]
+ frames_action_state_dict = {
+ 'pre_action': pre_actions,
+ 'action': actions,
+ 'observation.state': states,
+ 'next.state': next_states
+ }
+ if self.individual_normalization:
+ frames_action_state_dict = self.normalizer(
+ frames_action_state_dict)
+
+ # Update action and states to unified vector
+ frames_action_state_dict = self.get_uni_vec(
+ frames_action_state_dict,
+ transition_dict['action_type'],
+ transition_dict['state_type'],
+ )
+
+ # Load observable images
+ if start_idx < self.n_obs_steps - 1:
+ action_net_frame_indices = list(range(0, start_idx + 1))
+ action_net_frames = video_reader.get_batch(
+ action_net_frame_indices)
+ action_net_frames = torch.tensor(
+ action_net_frames.asnumpy()).permute(0, 3, 1, 2).float()
+ first_slice = action_net_frames[0:1, :]
+ num_padding = self.n_obs_steps - 1 - start_idx
+ padding = first_slice.repeat(num_padding, 1, 1, 1)
+ action_net_frames = torch.cat((padding, action_net_frames), dim=0)
+ assert (
+ action_net_frames.shape[0] == self.n_obs_steps
+ ), f'{len(action_net_frames)}, self.n_obs_steps={self.n_obs_steps}'
+ action_net_frames = action_net_frames.permute(1, 0, 2, 3)
+ else:
+ action_net_frame_indices = list(
+ range(start_idx - self.n_obs_steps + 1, start_idx + 1))
+ action_net_frames = video_reader.get_batch(
+ action_net_frame_indices)
+ assert (
+ action_net_frames.shape[0] == self.n_obs_steps
+ ), f'{len(action_net_frames)}, self.n_obs_steps={self.n_obs_steps}'
+ action_net_frames = torch.tensor(
+ action_net_frames.asnumpy()).permute(3, 0, 1, 2).float()
+
+ assert (frames.shape[0] == self.video_length
+ ), f'{len(frames)}, self.video_length={self.video_length}'
+ frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()
+
+ if self.spatial_transform is not None:
+ frames = self.spatial_transform(frames)
+ action_net_frames = self.spatial_transform(action_net_frames)
+
+ if self.resolution is not None:
+ assert (frames.shape[2], frames.shape[3]) == (
+ self.resolution[0], self.resolution[1]
+ ), f'frames={frames.shape}, self.resolution={self.resolution}'
+ assert (
+ action_net_frames.shape[2], action_net_frames.shape[3]
+ ) == (
+ self.resolution[0], self.resolution[1]
+ ), f'action_net_frames={action_net_frames.shape}, self.resolution={self.resolution}'
+
+ # Normalize frames tensors to [-1,1]
+ frames = (frames / 255 - 0.5) * 2
+ action_net_frames = (action_net_frames / 255 - 0.5) * 2
+ fps_clip = fps_ori // frame_stride
+ if self.fps_max is not None and fps_clip > self.fps_max:
+ fps_clip = self.fps_max
+
+ data = {
+ 'video': frames,
+ 'instruction': instruction,
+ 'path': video_path,
+ 'fps': fps_clip,
+ 'frame_stride': frame_stride,
+ 'observation.image': action_net_frames,
+ }
+ data.update(frames_action_state_dict)
+
+ return data
+
+ def __len__(self):
+ return len(self.metadata)
diff --git a/src/unifolm_wma/models/__init__.py b/src/unifolm_wma/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/unifolm_wma/models/autoencoder.py b/src/unifolm_wma/models/autoencoder.py
new file mode 100644
index 0000000..2a3b521
--- /dev/null
+++ b/src/unifolm_wma/models/autoencoder.py
@@ -0,0 +1,267 @@
+import os
+import torch
+import torch.nn.functional as F
+import pytorch_lightning as pl
+
+from einops import rearrange
+from unifolm_wma.modules.networks.ae_modules import Encoder, Decoder
+from unifolm_wma.utils.distributions import DiagonalGaussianDistribution
+from unifolm_wma.utils.utils import instantiate_from_config
+
+
+class AutoencoderKL(pl.LightningModule):
+
+ def __init__(
+ self,
+ ddconfig,
+ lossconfig,
+ embed_dim,
+ ckpt_path=None,
+ ignore_keys=[],
+ image_key="image",
+ colorize_nlabels=None,
+ monitor=None,
+ test=False,
+ logdir=None,
+ input_dim=4,
+ test_args=None,
+ ):
+ super().__init__()
+ self.image_key = image_key
+ self.encoder = Encoder(**ddconfig)
+ self.decoder = Decoder(**ddconfig)
+ self.loss = instantiate_from_config(lossconfig)
+ assert ddconfig["double_z"]
+ self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"],
+ 2 * embed_dim, 1)
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim,
+ ddconfig["z_channels"], 1)
+ self.embed_dim = embed_dim
+ self.input_dim = input_dim
+ self.test = test
+ self.test_args = test_args
+ self.logdir = logdir
+ if colorize_nlabels is not None:
+ assert type(colorize_nlabels) == int
+ self.register_buffer("colorize",
+ torch.randn(3, colorize_nlabels, 1, 1))
+ if monitor is not None:
+ self.monitor = monitor
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+ if self.test:
+ self.init_test()
+
+ def init_test(self, ):
+ self.test = True
+ save_dir = os.path.join(self.logdir, "test")
+ if 'ckpt' in self.test_args:
+ ckpt_name = os.path.basename(self.test_args.ckpt).split(
+ '.ckpt')[0] + f'_epoch{self._cur_epoch}'
+ self.root = os.path.join(save_dir, ckpt_name)
+ else:
+ self.root = save_dir
+ if 'test_subdir' in self.test_args:
+ self.root = os.path.join(save_dir, self.test_args.test_subdir)
+
+ self.root_zs = os.path.join(self.root, "zs")
+ self.root_dec = os.path.join(self.root, "reconstructions")
+ self.root_inputs = os.path.join(self.root, "inputs")
+ os.makedirs(self.root, exist_ok=True)
+
+ if self.test_args.save_z:
+ os.makedirs(self.root_zs, exist_ok=True)
+ if self.test_args.save_reconstruction:
+ os.makedirs(self.root_dec, exist_ok=True)
+ if self.test_args.save_input:
+ os.makedirs(self.root_inputs, exist_ok=True)
+ assert (self.test_args is not None)
+ self.test_maximum = getattr(self.test_args, 'test_maximum', None)
+ self.count = 0
+ self.eval_metrics = {}
+ self.decodes = []
+ self.save_decode_samples = 2048
+
+ def init_from_ckpt(self, path, ignore_keys=list()):
+ sd = torch.load(path, map_location="cpu")
+ try:
+ self._cur_epoch = sd['epoch']
+ sd = sd["state_dict"]
+        except Exception:
+ self._cur_epoch = 'null'
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ print("Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ self.load_state_dict(sd, strict=False)
+ print(f"Restored from {path}")
+
+ def encode(self, x, **kwargs):
+
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+ return posterior
+
+ def decode(self, z, **kwargs):
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+ return dec
+
+ def forward(self, input, sample_posterior=True):
+ posterior = self.encode(input)
+ if sample_posterior:
+ z = posterior.sample()
+ else:
+ z = posterior.mode()
+ dec = self.decode(z)
+ return dec, posterior
+
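+    # Illustrative usage sketch (hypothetical shapes; the latent resolution depends on ddconfig):
+    #   dec, posterior = model(torch.randn(2, 3, 256, 256))  # dec matches the input shape
+    #   z = posterior.sample()                                # latent of shape (2, embed_dim, h, w)
+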
+ def get_input(self, batch, k):
+ x = batch[k]
+ if x.dim() == 5 and self.input_dim == 4:
+ b, c, t, h, w = x.shape
+ self.b = b
+ self.t = t
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
+
+ return x
+
+ def training_step(self, batch, batch_idx, optimizer_idx):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+
+ if optimizer_idx == 0:
+ # train encoder+decoder+logvar
+ aeloss, log_dict_ae = self.loss(inputs,
+ reconstructions,
+ posterior,
+ optimizer_idx,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="train")
+ self.log("aeloss",
+ aeloss,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=True)
+ self.log_dict(log_dict_ae,
+ prog_bar=False,
+ logger=True,
+ on_step=True,
+ on_epoch=False)
+ return aeloss
+
+ if optimizer_idx == 1:
+ # train the discriminator
+ discloss, log_dict_disc = self.loss(
+ inputs,
+ reconstructions,
+ posterior,
+ optimizer_idx,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="train")
+
+ self.log("discloss",
+ discloss,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=True)
+ self.log_dict(log_dict_disc,
+ prog_bar=False,
+ logger=True,
+ on_step=True,
+ on_epoch=False)
+ return discloss
+
+ def validation_step(self, batch, batch_idx):
+ inputs = self.get_input(batch, self.image_key)
+ reconstructions, posterior = self(inputs)
+ aeloss, log_dict_ae = self.loss(inputs,
+ reconstructions,
+ posterior,
+ 0,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="val")
+
+ discloss, log_dict_disc = self.loss(inputs,
+ reconstructions,
+ posterior,
+ 1,
+ self.global_step,
+ last_layer=self.get_last_layer(),
+ split="val")
+
+ self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
+ self.log_dict(log_dict_ae)
+ self.log_dict(log_dict_disc)
+ return self.log_dict
+
+ def configure_optimizers(self):
+ lr = self.learning_rate
+ opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
+ list(self.decoder.parameters()) +
+ list(self.quant_conv.parameters()) +
+ list(self.post_quant_conv.parameters()),
+ lr=lr,
+ betas=(0.5, 0.9))
+ opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
+ lr=lr,
+ betas=(0.5, 0.9))
+ return [opt_ae, opt_disc], []
+
+ def get_last_layer(self):
+ return self.decoder.conv_out.weight
+
+ @torch.no_grad()
+ def log_images(self, batch, only_inputs=False, **kwargs):
+ log = dict()
+ x = self.get_input(batch, self.image_key)
+ x = x.to(self.device)
+ if not only_inputs:
+ xrec, posterior = self(x)
+ if x.shape[1] > 3:
+ # colorize with random projection
+ assert xrec.shape[1] > 3
+ x = self.to_rgb(x)
+ xrec = self.to_rgb(xrec)
+ log["samples"] = self.decode(torch.randn_like(posterior.sample()))
+ log["reconstructions"] = xrec
+ log["inputs"] = x
+ return log
+
+ def to_rgb(self, x):
+ assert self.image_key == "segmentation"
+ if not hasattr(self, "colorize"):
+ self.register_buffer("colorize",
+ torch.randn(3, x.shape[1], 1, 1).to(x))
+ x = F.conv2d(x, weight=self.colorize)
+ x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
+ return x
+
+
+class IdentityFirstStage(torch.nn.Module):
+
+ def __init__(self, *args, vq_interface=False, **kwargs):
+        super().__init__()
+        self.vq_interface = vq_interface  # TODO: Should be true by default but check to not break older stuff
+
+ def encode(self, x, *args, **kwargs):
+ return x
+
+ def decode(self, x, *args, **kwargs):
+ return x
+
+ def quantize(self, x, *args, **kwargs):
+ if self.vq_interface:
+ return x, None, [None, None, None]
+ return x
+
+ def forward(self, x, *args, **kwargs):
+ return x
diff --git a/src/unifolm_wma/models/ddpms.py b/src/unifolm_wma/models/ddpms.py
new file mode 100644
index 0000000..fbf2042
--- /dev/null
+++ b/src/unifolm_wma/models/ddpms.py
@@ -0,0 +1,2524 @@
+"""
+wild mixture of
+https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
+https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+https://github.com/CompVis/taming-transformers
+-- merci
+"""
+
+import random
+import torch
+import torch.nn as nn
+import copy
+import numpy as np
+import pytorch_lightning as pl
+import torch.nn.functional as F
+import logging
+
+mainlogger = logging.getLogger('mainlogger')
+
+from functools import partial
+from contextlib import contextmanager
+from tqdm import tqdm
+from einops import rearrange, repeat, reduce
+from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
+from torchvision.utils import make_grid
+from pytorch_lightning.utilities import rank_zero_only
+from omegaconf import OmegaConf
+from typing import Optional, Sequence, Any, Tuple, Union, List, Dict
+from collections.abc import Mapping, Iterable, Callable
+from torch import Tensor
+
+from unifolm_wma.utils.utils import instantiate_from_config
+from unifolm_wma.utils.ema import LitEma
+from unifolm_wma.utils.distributions import DiagonalGaussianDistribution
+from unifolm_wma.utils.diffusion import make_beta_schedule, rescale_zero_terminal_snr
+from unifolm_wma.utils.basics import disabled_train
+from unifolm_wma.utils.common import (extract_into_tensor, noise_like, exists,
+ default)
+
+from unifolm_wma.models.samplers.ddim import DDIMSampler
+from unifolm_wma.models.diffusion_head.common.lr_scheduler import get_scheduler, SelectiveLRScheduler
+from unifolm_wma.models.diffusion_head.ema_model import EMAModel
+from unifolm_wma.models.diffusion_head.positional_embedding import SinusoidalPosEmb
+from unifolm_wma.modules.encoders.condition import MLPProjector
+from unifolm_wma.data.normolize import Normalize, Unnormalize
+
+__conditioning_keys__ = {
+ 'concat': 'c_concat',
+ 'crossattn': 'c_crossattn',
+ 'adm': 'y'
+}
+
+
+class DDPM(pl.LightningModule):
+ """
+ Denoising Diffusion Probabilistic Model (DDPM) LightningModule.
+ """
+
+ def __init__(
+ self,
+ wma_config: OmegaConf,
+ timesteps: int = 1000,
+ beta_schedule: str = "linear",
+ loss_type: str = "l2",
+ ckpt_path: Optional[str] = None,
+ ignore_keys: Optional[Sequence[str]] = [],
+ load_only_unet: bool = False,
+        monitor: Optional[str] = None,
+ use_ema: bool = True,
+ first_stage_key: str = "image",
+ image_size: int = 256,
+ channels: int = 3,
+ log_every_t: int = 100,
+ clip_denoised: bool = True,
+ linear_start: float = 1e-4,
+ linear_end: float = 2e-2,
+ cosine_s: float = 8e-3,
+ given_betas: Optional[np.ndarray] = None,
+ original_elbo_weight: float = 0.0,
+ v_posterior: float = 0.0,
+ l_simple_weight: float = 1.0,
+ conditioning_key: Optional[str] = None,
+ parameterization: str = "eps",
+ scheduler_config: Optional[Mapping[str, Any]] = None,
+ use_positional_encodings: bool = False,
+ learn_logvar: bool = False,
+ logvar_init: float = 0.0,
+ rescale_betas_zero_snr: bool = False,
+ ):
+ """
+ wma_config: Config object used to build the underlying model.
+ timesteps: Number of diffusion steps.
+ beta_schedule: Schedule type for betas (e.g., 'linear', 'cosine').
+ loss_type: Loss type.
+ ckpt_path: Optional checkpoint path to restore weights.
+ ignore_keys: Keys to ignore when loading the checkpoint.
+ load_only_unet: If True, load the backbone into self.model only.
+ monitor: Metric key for monitoring.
+ use_ema: If True, maintain EMA weights.
+ first_stage_key: Key in batch dict for inputs.
+ image_size: Image size.
+ channels: Number of channels.
+ log_every_t: Interval of timesteps to log intermediates during sampling.
+ clip_denoised: Clamp x_0 predictions or not.
+ linear_start: Linear schedule start.
+ linear_end: Linear schedule end.
+ cosine_s: Cosine schedule s parameter.
+ given_betas: Externally provided betas; overrides schedule if set.
+ original_elbo_weight: Weight for VLB term.
+ v_posterior: Interpolation weight for posterior variance.
+ l_simple_weight: Weight for simple loss term.
+ conditioning_key: Conditioning mechanism key (if used by wrapper).
+ parameterization: One of {'eps','x0','v'}.
+ scheduler_config: Optional LR scheduler config.
+ use_positional_encodings: Whether to inject positional encodings.
+ learn_logvar: If True, learn per-timestep log-variance.
+ logvar_init: Initial value for log-variance.
+ rescale_betas_zero_snr: If True, apply zero-SNR rescaling to betas.
+ """
+ super().__init__()
+ assert parameterization in [
+ "eps", "x0", "v"
+        ], 'currently only supporting "eps", "x0" and "v"'
+ self.parameterization = parameterization
+ mainlogger.info(
+ f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode"
+ )
+ self.cond_stage_model = None
+ self.clip_denoised = clip_denoised
+ self.log_every_t = log_every_t
+ self.first_stage_key = first_stage_key
+ self.channels = channels
+ self.temporal_length = wma_config.params.temporal_length
+ self.image_size = image_size
+ if isinstance(self.image_size, int):
+ self.image_size = [self.image_size, self.image_size]
+ self.use_positional_encodings = use_positional_encodings
+ self.model = DiffusionWrapper(wma_config, conditioning_key)
+ self.use_ema = use_ema
+ self.rescale_betas_zero_snr = rescale_betas_zero_snr
+ if self.use_ema:
+ self.model_ema = LitEma(self.model)
+ mainlogger.info(
+ f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+ self.v_posterior = v_posterior
+ self.original_elbo_weight = original_elbo_weight
+ self.l_simple_weight = l_simple_weight
+
+ if monitor is not None:
+ self.monitor = monitor
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path,
+ ignore_keys=ignore_keys,
+ only_model=load_only_unet)
+ self.register_schedule(given_betas=given_betas,
+ beta_schedule=beta_schedule,
+ timesteps=timesteps,
+ linear_start=linear_start,
+ linear_end=linear_end,
+ cosine_s=cosine_s)
+
+ # For reschedule
+ self.given_betas = given_betas
+ self.beta_schedule = beta_schedule
+ self.timesteps = timesteps
+ self.cosine_s = cosine_s
+        self.loss_type = loss_type
+        # Keep the optional LR scheduler config; training_step checks self.use_scheduler.
+        self.use_scheduler = scheduler_config is not None
+        if self.use_scheduler:
+            self.scheduler_config = scheduler_config
+        self.learn_logvar = learn_logvar
+ self.logvar = torch.full(fill_value=logvar_init,
+ size=(self.num_timesteps, ))
+ if self.learn_logvar:
+ self.logvar = nn.Parameter(self.logvar, requires_grad=True)
+
+ def register_schedule(self,
+ given_betas: Optional[np.ndarray] = None,
+ beta_schedule: str = "linear",
+ timesteps: int = 1000,
+ linear_start: float = 1e-4,
+ linear_end: float = 2e-2,
+ cosine_s: float = 8e-3) -> None:
+ """
+ Create and register diffusion buffers (betas, alphas, posteriors, weights).
+
+ Args:
+ given_betas: If provided, use these instead of building a schedule.
+ beta_schedule: Name of schedule to create if betas not given.
+ timesteps: Number of diffusion steps.
+ linear_start: Linear schedule start.
+ linear_end: Linear schedule end.
+ cosine_s: Cosine schedule parameter
+ """
+ if exists(given_betas):
+ betas = given_betas
+ else:
+ betas = make_beta_schedule(beta_schedule,
+ timesteps,
+ linear_start=linear_start,
+ linear_end=linear_end,
+ cosine_s=cosine_s)
+ if self.rescale_betas_zero_snr:
+ betas = rescale_zero_terminal_snr(betas)
+
+ alphas = 1. - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[
+ 0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer('betas', to_torch(betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev',
+ to_torch(alphas_cumprod_prev))
+
+ self.register_buffer('sqrt_alphas_cumprod',
+ to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod',
+ to_torch(np.sqrt(1. - alphas_cumprod)))
+ self.register_buffer('log_one_minus_alphas_cumprod',
+ to_torch(np.log(1. - alphas_cumprod)))
+
+ if self.parameterization != 'v':
+ self.register_buffer('sqrt_recip_alphas_cumprod',
+ to_torch(np.sqrt(1. / alphas_cumprod)))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod',
+ to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+ else:
+ self.register_buffer('sqrt_recip_alphas_cumprod',
+ torch.zeros_like(to_torch(alphas_cumprod)))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod',
+ torch.zeros_like(to_torch(alphas_cumprod)))
+
+ posterior_variance = (1 - self.v_posterior) * betas * (
+ 1. - alphas_cumprod_prev) / (
+ 1. - alphas_cumprod) + self.v_posterior * betas
+ # Above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+ self.register_buffer('posterior_variance',
+ to_torch(posterior_variance))
+ # Below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.register_buffer(
+ 'posterior_log_variance_clipped',
+ to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
+ self.register_buffer(
+ 'posterior_mean_coef1',
+ to_torch(betas * np.sqrt(alphas_cumprod_prev) /
+ (1. - alphas_cumprod)))
+ self.register_buffer(
+ 'posterior_mean_coef2',
+ to_torch((1. - alphas_cumprod_prev) * np.sqrt(alphas) /
+ (1. - alphas_cumprod)))
+ if self.parameterization == "eps":
+ lvlb_weights = self.betas**2 / (2 * self.posterior_variance *
+ to_torch(alphas) *
+ (1 - self.alphas_cumprod))
+ elif self.parameterization == "x0":
+ lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (
+ 2. * 1 - torch.Tensor(alphas_cumprod))
+ elif self.parameterization == "v":
+ lvlb_weights = torch.ones_like(
+ self.betas**2 /
+ (2 * self.posterior_variance * to_torch(alphas) *
+ (1 - self.alphas_cumprod)))
+ else:
+ raise NotImplementedError("mu not supported")
+ lvlb_weights[0] = lvlb_weights[1]
+ self.register_buffer('lvlb_weights', lvlb_weights, persistent=False)
+ assert not torch.isnan(self.lvlb_weights).all()
+
+ @contextmanager
+ def ema_scope(self, context: Optional[str] = None) -> Iterable[None]:
+ """
+ Context manager that temporarily swaps to EMA weights (if enabled).
+
+ Args:
+ context: Optional label for logging.
+
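+        Example (a minimal sketch):
+        ```python
+        with self.ema_scope("Sampling"):
+            samples = self.sample(batch_size=4)
+        ```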
+ """
+ if self.use_ema:
+ self.model_ema.store(self.model.parameters())
+ self.model_ema.copy_to(self.model)
+ if context is not None:
+ mainlogger.info(f"{context}: Switched to EMA weights")
+ try:
+ yield None
+ finally:
+ if self.use_ema:
+ self.model_ema.restore(self.model.parameters())
+ if context is not None:
+ mainlogger.info(f"{context}: Restored training weights")
+
+ def init_from_ckpt(self,
+ path: str,
+ ignore_keys: Sequence[str] = tuple(),
+ only_model: bool = False) -> None:
+ """
+ Load a checkpoint, optionally filtering keys or loading only the inner model.
+
+ Args:
+ path: Path to checkpoint.
+ ignore_keys: State-dict keys (prefix match) to drop.
+ only_model: If True, load only into self.model.
+ """
+ sd = torch.load(path, map_location="cpu")
+ if "state_dict" in list(sd.keys()):
+ sd = sd["state_dict"]
+ keys = list(sd.keys())
+ for k in keys:
+ for ik in ignore_keys:
+ if k.startswith(ik):
+ mainlogger.info(
+ "Deleting key {} from state_dict.".format(k))
+ del sd[k]
+ missing, unexpected = self.load_state_dict(
+ sd,
+ strict=False) if not only_model else self.model.load_state_dict(
+ sd, strict=False)
+ mainlogger.info(
+ f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
+ )
+ if len(missing) > 0:
+ mainlogger.info(f"Missing Keys: {missing}")
+ if len(unexpected) > 0:
+ mainlogger.info(f"Unexpected Keys: {unexpected}")
+
+ def q_mean_variance(self, x_start: Tensor,
+ t: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+ """
+ Compute q(x_t | x_0): mean, variance, and log-variance.
+
+ Args:
+            x_start: The [N x C x ...] tensor of noiseless inputs.
+            t: The number of diffusion steps (minus 1); here, 0 means one step.
+
+ Returns:
+ (mean, variance, log_variance), each shaped like x_start.
+ """
+ mean = (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) *
+ x_start)
+ variance = extract_into_tensor(1.0 - self.alphas_cumprod, t,
+ x_start.shape)
+ log_variance = extract_into_tensor(self.log_one_minus_alphas_cumprod,
+ t, x_start.shape)
+ return mean, variance, log_variance
+
+ def predict_start_from_noise(self, x_t: Tensor, t: Tensor,
+ noise: Tensor) -> Tensor:
+ """
+ Predict x_0 from x_t and noise.
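+
+        Implements x_0 = sqrt(1 / alphas_cumprod[t]) * x_t - sqrt(1 / alphas_cumprod[t] - 1) * noise.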
+ """
+ return (
+ extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) *
+ x_t - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t,
+ x_t.shape) * noise)
+
+ def predict_start_from_z_and_v(self, x_t: Tensor, t: Tensor,
+ v: Tensor) -> Tensor:
+ """
+ Predict x_0 from z and v (v-parameterization).
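+
+        Implements x_0 = sqrt(alphas_cumprod[t]) * x_t - sqrt(1 - alphas_cumprod[t]) * v.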
+ """
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t,
+ x_t.shape) * v)
+
+ def predict_eps_from_z_and_v(self, x_t: Tensor, t: Tensor,
+ v: Tensor) -> Tensor:
+ """
+ Predict epsilon from z and v (v-parameterization).
+ """
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t,
+ x_t.shape) * x_t)
+
+ def q_posterior(self, x_start: Tensor, x_t: Tensor,
+ t: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+ """
+ Compute posterior q(x_{t-1} | x_t, x_0): mean and (log-)variance.
+ """
+ posterior_mean = (
+ extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) *
+ x_start +
+ extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t)
+ posterior_variance = extract_into_tensor(self.posterior_variance, t,
+ x_t.shape)
+ posterior_log_variance_clipped = extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x_t.shape)
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, x: Tensor, t: Tensor,
+ clip_denoised: bool) -> Tuple[Tensor, Tensor, Tensor]:
+ """
+ Predict mean and variance of p(x_{t-1} | x_t) using the model.
+ """
+ model_out = self.model(x, t)
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
+ x_start=x_recon, x_t=x, t=t)
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self,
+ x: Tensor,
+ t: Tensor,
+ clip_denoised: bool = True,
+ repeat_noise: bool = False) -> Tensor:
+ """
+ Draw a single reverse-diffusion sample step: x_{t-1} from x_t.
+
+ Args:
+ x: Current noisy sample (B, C, ...).
+ t: Current timestep indices (B,).
+ clip_denoised: Clamp x_0 predictions or not.
+ repeat_noise: Reuse the same noise across the batch.
+
+ Returns:
+ Next sample x_{t-1}.
+ """
+ b, *_, device = *x.shape, x.device
+ model_mean, _, model_log_variance = self.p_mean_variance(
+ x=x, t=t, clip_denoised=clip_denoised)
+ noise = noise_like(x.shape, device, repeat_noise)
+ # No noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(
+ b, *((1, ) * (len(x.shape) - 1)))
+ return model_mean + nonzero_mask * (0.5 *
+ model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def p_sample_loop(
+ self,
+ shape: Sequence[int],
+ return_intermediates: bool = False
+ ) -> Union[Tensor, Tuple[Tensor, List[Tensor]]]:
+ """
+ Run the full reverse process starting from Gaussian noise.
+
+ Args:
+ shape: Output tensor shape (B, C, ...).
+ return_intermediates: If True, also return intermediate frames.
+
+ Returns:
+ Final sample, and optionally the intermediate list.
+ """
+ device = self.betas.device
+ b = shape[0]
+ img = torch.randn(shape, device=device)
+ intermediates = [img]
+ for i in tqdm(reversed(range(0, self.num_timesteps)),
+ desc='Sampling t',
+ total=self.num_timesteps):
+ img = self.p_sample(img,
+ torch.full((b, ),
+ i,
+ device=device,
+ dtype=torch.long),
+ clip_denoised=self.clip_denoised)
+ if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
+ intermediates.append(img)
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(
+ self,
+ batch_size: int = 16,
+ return_intermediates: bool = False
+ ) -> Union[Tensor, Tuple[Tensor, List[Tensor]]]:
+ """
+ Convenience wrapper to sample square images of configured size.
+
+ Args:
+ batch_size: Number of samples.
+ return_intermediates: If True, also return intermediate frames.
+
+ Returns:
+ Final sample (and optionally intermediates).
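+
+        Example (a minimal sketch):
+        ```python
+        samples, intermediates = self.sample(batch_size=4, return_intermediates=True)
+        ```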
+ """
+ image_size = self.image_size
+ channels = self.channels
+ return self.p_sample_loop(
+ (batch_size, channels, image_size, image_size),
+ return_intermediates=return_intermediates)
+
+ def q_sample(self,
+ x_start: Tensor,
+ t: Tensor,
+ noise: Optional[Tensor] = None) -> Tensor:
+ """
+ Forward noising step: sample x_t ~ q(x_t | x_0).
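+
+        Implements x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise.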
+ """
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) *
+ x_start + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod,
+ t, x_start.shape) * noise)
+
+ def get_v(self, x: Tensor, noise: Tensor, t: Tensor) -> Tensor:
+ """
+ Compute v-target given x and epsilon.
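+
+        Implements v = sqrt(alphas_cumprod[t]) * noise - sqrt(1 - alphas_cumprod[t]) * x.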
+ """
+ return (
+ extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t,
+ x.shape) * x)
+
+ def get_loss(self,
+ pred: Tensor,
+ target: Tensor,
+ mean: bool = True) -> Tensor:
+ """
+ Compute training loss between prediction and target.
+
+ Args:
+ pred: Model output.
+ target: Target tensor.
+ mean: If True, reduce to mean.
+
+ Returns:
+ Loss tensor (scalar if reduced).
+ """
+ if self.loss_type == 'l1':
+ loss = (target - pred).abs()
+ if mean:
+ loss = loss.mean()
+ elif self.loss_type == 'l2':
+ if mean:
+ loss = torch.nn.functional.mse_loss(target, pred)
+ else:
+ loss = torch.nn.functional.mse_loss(target,
+ pred,
+ reduction='none')
+ else:
+            raise NotImplementedError(f"unknown loss type '{self.loss_type}'")
+
+ return loss
+
+ def p_losses(self,
+ x_start: Tensor,
+ t: Tensor,
+ noise: Optional[Tensor] = None
+ ) -> Tuple[Tensor, Dict[str, Tensor]]:
+ """
+ Compute the per-step training loss for a batch.
+
+ Args:
+ x_start: Clean inputs (B, C, ...).
+ t: Timesteps (B,).
+ noise: Optional pre-sampled epsilon.
+
+ Returns:
+ (loss, log_dict)
+ """
+
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_out = self.model(x_noisy, t)
+
+ loss_dict = {}
+ if self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "v":
+ target = self.get_v(x_start, noise, t)
+ else:
+ raise NotImplementedError(
+                f"Parameterization {self.parameterization} not yet supported")
+
+ loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])
+
+ log_prefix = 'train' if self.training else 'val'
+
+ loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
+ loss_simple = loss.mean() * self.l_simple_weight
+
+ loss_vlb = (self.lvlb_weights[t] * loss).mean()
+ loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})
+
+ loss = loss_simple + self.original_elbo_weight * loss_vlb
+
+ loss_dict.update({f'{log_prefix}/loss': loss})
+
+ return loss, loss_dict
+
+ def forward(self, x: Tensor, *args: Any,
+ **kwargs: Any) -> Tuple[Tensor, Dict[str, Tensor]]:
+ """
+ Lightning forward: sample random timesteps and compute losses.
+
+ Args:
+ x: Clean batch (B, C, ...).
+
+ Returns:
+ (loss, log_dict)
+ """
+ t = torch.randint(0,
+ self.num_timesteps, (x.shape[0], ),
+ device=self.device).long()
+ return self.p_losses(x, t, *args, **kwargs)
+
+ def get_input(self, batch: Mapping[str, Tensor], k: str) -> Tensor:
+ """
+ Fetch and format the network input from batch.
+
+ Args:
+ batch: Batch mapping.
+ k: Key for the tensor to use.
+
+ Returns:
+ (B, C, ...) float32 contiguous tensor.
+ """
+ x = batch[k]
+ '''
+ if len(x.shape) == 3:
+ x = x[..., None]
+ x = rearrange(x, 'b h w c -> b c h w')
+ '''
+ x = x.to(memory_format=torch.contiguous_format).float()
+ return x
+
+ def shared_step(
+ self, batch: Mapping[str,
+ Tensor]) -> Tuple[Tensor, Dict[str, Tensor]]:
+ """
+ Common train/val step computing loss and logs.
+ """
+ x = self.get_input(batch, self.first_stage_key)
+ loss, loss_dict = self(x)
+ return loss, loss_dict
+
+ def training_step(self, batch: Mapping[str, Tensor],
+ batch_idx: int) -> Tensor:
+ """
+ PyTorch Lightning training step.
+ """
+ loss, loss_dict = self.shared_step(batch)
+
+ self.log_dict(loss_dict,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=True)
+
+ self.log("global_step",
+ self.global_step,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=False)
+
+ if self.use_scheduler:
+ lr = self.optimizers().param_groups[0]['lr']
+ self.log('lr_abs',
+ lr,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=False)
+ return loss
+
+ @torch.no_grad()
+ def validation_step(self, batch: Mapping[str, Tensor],
+ batch_idx: int) -> None:
+ """
+ PyTorch Lightning validation step with and without EMA.
+ """
+ _, loss_dict_no_ema = self.shared_step(batch)
+ with self.ema_scope():
+ _, loss_dict_ema = self.shared_step(batch)
+ loss_dict_ema = {
+ key + '_ema': loss_dict_ema[key]
+ for key in loss_dict_ema
+ }
+ self.log_dict(loss_dict_no_ema,
+ prog_bar=False,
+ logger=True,
+ on_step=False,
+ on_epoch=True)
+ self.log_dict(loss_dict_ema,
+ prog_bar=False,
+ logger=True,
+ on_step=False,
+ on_epoch=True)
+
+ def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None:
+ """
+ Update EMA after each train batch (if enabled).
+ """
+ if self.use_ema:
+ self.model_ema(self.model)
+
+ def _get_rows_from_list(self, samples: List[Tensor]) -> Tensor:
+ """
+ Arrange a list of (B, C, ...) tensors into a grid for logging.
+
+ Args:
+ samples: List of tensors at different timesteps.
+
+ Returns:
+ Grid image tensor suitable for visualization.
+ """
+ n_imgs_per_row = len(samples)
+ denoise_grid = rearrange(samples, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_imgs_per_row)
+ return denoise_grid
+
+ @torch.no_grad()
+ def log_images(
+ self,
+ batch: Mapping[str, Tensor],
+ N: int = 8,
+ n_row: int = 2,
+ sample: bool = True,
+ return_keys: Optional[Sequence[str]] = None,
+ **kwargs: Any,
+ ) -> Dict[str, Tensor]:
+ """
+ Create tensors for image logging: inputs, diffusion row, (optional) samples.
+
+ Args:
+ batch: Batch mapping.
+ N: Number of examples to visualize.
+ n_row: Number of examples for diffusion-row visualization.
+ sample: If True, also run reverse diffusion to produce samples.
+ return_keys: If provided, filter the returned dict to these keys.
+
+ Returns:
+ Dict of image tensors.
+ """
+ log = dict()
+ x = self.get_input(batch, self.first_stage_key)
+ N = min(x.shape[0], N)
+ n_row = min(x.shape[0], n_row)
+ x = x.to(self.device)[:N]
+ log["inputs"] = x
+
+ # Get diffusion row
+ diffusion_row = list()
+ x_start = x[:n_row]
+
+ for t in range(self.num_timesteps):
+ if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
+ t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
+ t = t.to(self.device).long()
+ noise = torch.randn_like(x_start)
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ diffusion_row.append(x_noisy)
+
+ log["diffusion_row"] = self._get_rows_from_list(diffusion_row)
+
+ if sample:
+ # Get denoise row
+ with self.ema_scope("Plotting"):
+ samples, denoise_row = self.sample(batch_size=N,
+ return_intermediates=True)
+
+ log["samples"] = samples
+ log["denoise_row"] = self._get_rows_from_list(denoise_row)
+
+ if return_keys:
+ if np.intersect1d(list(log.keys()), return_keys).shape[0] == 0:
+ return log
+ else:
+ return {key: log[key] for key in return_keys}
+ return log
+
+ def configure_optimizers(self) -> torch.optim.Optimizer:
+ """
+ Build the optimizer (AdamW) over model parameters (+ logvar if learned).
+ """
+ lr = self.learning_rate
+ params = list(self.model.parameters())
+ if self.learn_logvar:
+ params = params + [self.logvar]
+ opt = torch.optim.AdamW(params, lr=lr)
+ return opt
+
+
+class LatentDiffusion(DDPM):
+ """
+ Main Class: Latent-diffusion model on top of DDPM (first/cond stages + guidance).
+ """
+
+ def __init__(self,
+ first_stage_config: OmegaConf,
+ cond_stage_config: OmegaConf,
+ num_timesteps_cond: int | None = None,
+ cond_stage_key: str = "instruction",
+ cond_stage_trainable: bool = False,
+ cond_stage_forward: str | None = None,
+ conditioning_key: str | None = None,
+ uncond_prob: float = 0.2,
+ uncond_type: str = "empty_seq",
+                 scale_factor: float = 1.0,
+ scale_by_std: bool = False,
+ encoder_type: str = "2d",
+ only_model: bool = False,
+ noise_strength: float = 0.0,
+ use_dynamic_rescale: bool = False,
+ base_scale: float = 0.7,
+ turning_step: int = 400,
+ interp_mode: bool = False,
+ fps_condition_type: str = 'fs',
+ perframe_ae: bool = False,
+ logdir: str | None = None,
+ rand_cond_frame: bool = False,
+ en_and_decode_n_samples_a_time: int | None = None,
+ *args,
+ **kwargs):
+ """
+ Args:
+ first_stage_config: OmegaConf config for the first-stage autoencoder.
+ cond_stage_config: OmegaConf config for the conditioning encoder/module.
+ num_timesteps_cond: Number of condition timesteps used for schedule shortening.
+ cond_stage_key: Batch key for conditioning input (e.g., "instruction").
+ cond_stage_trainable: Whether the conditioning module is trainable.
+ cond_stage_forward: Optional method name to call on cond model instead of default.
+ conditioning_key: Conditioning mode (e.g., "crossattn", "concat").
+ uncond_prob: Probability to drop/zero the condition for classifier-free guidance.
+ uncond_type: Strategy for unconditional condition ("zero_embed" or "empty_seq").
+ scale_factor: Fixed latent scale multiplier if not using std-scaling.
+ scale_by_std: If True, compute scale as 1/std of latents at first batch.
+ encoder_type: "2d" (per-frame) or "3d" (volumetric) first-stage behavior.
+ only_model: If True, load only inner model weights when restoring from ckpt.
+ noise_strength: Extra offset noise strength for inputs (when > 0).
+ use_dynamic_rescale: If True, apply time-dependent rescaling array.
+ base_scale: Target base scale used by dynamic rescaling after turning_step.
+ turning_step: Steps to transition from 1.0 to base_scale in dynamic rescaling.
+ interp_mode: Flag for interpolation-specific behaviors (reserved).
+ fps_condition_type: Frame-per-second conditioning mode label.
+ perframe_ae: If True, encode/decode one frame at a time.
+ logdir: Optional directory for logs.
+ rand_cond_frame: If True, randomly select conditioning frames.
+ en_and_decode_n_samples_a_time: Optional per-step batch size for (en|de)code loops.
+ """
+
+ self.num_timesteps_cond = default(num_timesteps_cond, 1)
+ self.scale_by_std = scale_by_std
+ assert self.num_timesteps_cond <= kwargs['timesteps']
+ # For backwards compatibility after implementation of DiffusionWrapper
+ ckpt_path = kwargs.pop("ckpt_path", None)
+ ignore_keys = kwargs.pop("ignore_keys", [])
+ conditioning_key = default(conditioning_key, 'crossattn')
+ super().__init__(conditioning_key=conditioning_key, *args, **kwargs)
+
+ self.cond_stage_trainable = cond_stage_trainable
+ self.cond_stage_key = cond_stage_key
+ self.noise_strength = noise_strength
+ self.use_dynamic_rescale = use_dynamic_rescale
+ self.interp_mode = interp_mode
+ self.fps_condition_type = fps_condition_type
+ self.perframe_ae = perframe_ae
+
+ self.logdir = logdir
+ self.rand_cond_frame = rand_cond_frame
+ self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
+
+ try:
+ self.num_downs = len(
+ first_stage_config.params.ddconfig.ch_mult) - 1
+        except Exception:
+ self.num_downs = 0
+ if not scale_by_std:
+ self.scale_factor = scale_factor
+ else:
+ self.register_buffer('scale_factor', torch.tensor(scale_factor))
+
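+ # Dynamic rescaling: build a per-timestep multiplier that decays linearly
+ # from 1.0 to `base_scale` over the first `turning_step` steps and then
+ # stays constant at `base_scale` for the remaining timesteps.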
+ if use_dynamic_rescale:
+ scale_arr1 = np.linspace(1.0, base_scale, turning_step)
+ scale_arr2 = np.full(self.num_timesteps, base_scale)
+ scale_arr = np.concatenate((scale_arr1, scale_arr2))
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+ self.register_buffer('scale_arr', to_torch(scale_arr))
+
+ self.instantiate_first_stage(first_stage_config)
+ self.instantiate_cond_stage(cond_stage_config)
+ self.first_stage_config = first_stage_config
+ self.cond_stage_config = cond_stage_config
+ self.clip_denoised = False
+
+ self.cond_stage_forward = cond_stage_forward
+ self.encoder_type = encoder_type
+ assert (encoder_type in ["2d", "3d"])
+ self.uncond_prob = uncond_prob
+ self.classifier_free_guidance = uncond_prob > 0
+ assert (uncond_type in ["zero_embed", "empty_seq"])
+ self.uncond_type = uncond_type
+
+ self.restarted_from_ckpt = False
+ if ckpt_path is not None:
+ self.init_from_ckpt(ckpt_path, ignore_keys, only_model=only_model)
+ self.restarted_from_ckpt = True
+
+ def make_cond_schedule(self) -> None:
+ """
+ Build the condition timestep schedule.
+ """
+ self.cond_ids = torch.full(size=(self.num_timesteps, ),
+ fill_value=self.num_timesteps - 1,
+ dtype=torch.long)
+ ids = torch.round(
+ torch.linspace(0, self.num_timesteps - 1,
+ self.num_timesteps_cond)).long()
+ self.cond_ids[:self.num_timesteps_cond] = ids
+
+ @rank_zero_only
+ @torch.no_grad()
+ def on_train_batch_start(self,
+ batch: Mapping[str, Any],
+ batch_idx: int,
+ dataloader_idx: int | None = None) -> None:
+ """
+ Args:
+ batch: Current batch mapping.
+ batch_idx: Index of the batch within the epoch.
+ dataloader_idx: Optional dataloader index in multi-loader setups.
+ """
+ # Only for very first batch, reset the self.scale_factor
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and \
+ not self.restarted_from_ckpt:
+ assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
+ # set rescale weight to 1./std of encodings
+ mainlogger.info("### USING STD-RESCALING ###")
+ x = super().get_input(batch, self.first_stage_key)
+ x = x.to(self.device)
+ encoder_posterior = self.encode_first_stage(x)
+ z = self.get_first_stage_encoding(encoder_posterior).detach()
+ del self.scale_factor
+ self.register_buffer('scale_factor', 1. / z.flatten().std())
+ mainlogger.info(
+ f"setting self.scale_factor to {self.scale_factor}")
+ mainlogger.info("### USING STD-RESCALING ###")
+ mainlogger.info(f"std={z.flatten().std()}")
+
+ def register_schedule(self,
+ given_betas: np.ndarray | None = None,
+ beta_schedule: str = "linear",
+ timesteps: int = 1000,
+ linear_start: float = 1e-4,
+ linear_end: float = 2e-2,
+ cosine_s: float = 8e-3) -> None:
+ """
+ Extend base schedule registration and optionally shorten conditioning schedule.
+
+ Args:
+ given_betas: Optional precomputed beta schedule.
+ beta_schedule: Name of schedule function ("linear", "cosine", etc.).
+ timesteps: Number of diffusion steps.
+ linear_start: Linear schedule start beta (if used).
+ linear_end: Linear schedule end beta (if used).
+ cosine_s: Cosine schedule parameter (if used).
+ """
+ super().register_schedule(given_betas, beta_schedule, timesteps,
+ linear_start, linear_end, cosine_s)
+
+ self.shorten_cond_schedule = self.num_timesteps_cond > 1
+ if self.shorten_cond_schedule:
+ self.make_cond_schedule()
+
+ def instantiate_first_stage(self, config: OmegaConf) -> None:
+ """
+ Build and freeze the first-stage (autoencoder) model.
+
+ Args:
+ config: OmegaConf config describing the first-stage model to instantiate.
+ """
+ model = instantiate_from_config(config)
+ self.first_stage_model = model.eval()
+ self.first_stage_model.train = disabled_train
+ for param in self.first_stage_model.parameters():
+ param.requires_grad = False
+
+ def instantiate_cond_stage(self, config: OmegaConf) -> None:
+ """
+ Build the conditioning stage model.
+
+ Args:
+ config: OmegaConf config describing the conditioning model to instantiate.
+
+ """
+ if not self.cond_stage_trainable:
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model.eval()
+ self.cond_stage_model.train = disabled_train
+ for param in self.cond_stage_model.parameters():
+ param.requires_grad = False
+ else:
+ model = instantiate_from_config(config)
+ self.cond_stage_model = model
+
+ def get_learned_conditioning(self, c: Any) -> Tensor:
+ """
+ Encode conditioning input into an embedding tensor.
+
+ Args:
+ c: Raw conditioning input (tensor, list/dict of strings, tokens, etc.).
+
+ Returns:
+ Conditioning embedding as a tensor (shape depends on cond model).
+ """
+ if self.cond_stage_forward is None:
+ if hasattr(self.cond_stage_model, 'encode') and callable(
+ self.cond_stage_model.encode):
+ c = self.cond_stage_model.encode(c)
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ else:
+ c = self.cond_stage_model(c)
+ else:
+ assert hasattr(self.cond_stage_model, self.cond_stage_forward)
+ c = getattr(self.cond_stage_model, self.cond_stage_forward)(c)
+ return c
+
+ def get_first_stage_encoding(
+ self,
+ encoder_posterior: DiagonalGaussianDistribution | Tensor,
+ noise: Tensor | None = None) -> Tensor:
+ """
+ Convert encoder posterior to latent z and apply scaling.
+
+ Args:
+ encoder_posterior: First-stage output; either a Gaussian posterior or a latent tensor.
+ noise: Optional noise for sampling if posterior is Gaussian.
+
+ Returns:
+ Scaled latent tensor z.
+ """
+ if isinstance(encoder_posterior, DiagonalGaussianDistribution):
+ z = encoder_posterior.sample(noise=noise)
+ elif isinstance(encoder_posterior, torch.Tensor):
+ z = encoder_posterior
+ else:
+ raise NotImplementedError(
+ f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
+ )
+ return self.scale_factor * z
+
+ @torch.no_grad()
+ def encode_first_stage(self, x: Tensor) -> Tensor:
+ """
+ Encode input frames/images into latent space.
+
+ Args:
+ x: Input tensor, either (B, C, H, W) images or (B, C, T, H, W) video.
+
+ Returns:
+ Latent tensor with shape matched to input.
+ """
+ if self.encoder_type == "2d" and x.dim() == 5:
+ b, _, t, _, _ = x.shape
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
+ reshape_back = True
+ else:
+ reshape_back = False
+
+ ## Consume more GPU memory but faster
+ if not self.perframe_ae:
+ encoder_posterior = self.first_stage_model.encode(x)
+ results = self.get_first_stage_encoding(encoder_posterior).detach()
+ else: ## Consume less GPU memory but slower
+ results = []
+ for index in range(x.shape[0]):
+ frame_batch = self.first_stage_model.encode(x[index:index +
+ 1, :, :, :])
+ frame_result = self.get_first_stage_encoding(
+ frame_batch).detach()
+ results.append(frame_result)
+ results = torch.cat(results, dim=0)
+
+ if reshape_back:
+ results = rearrange(results, '(b t) c h w -> b c t h w', b=b, t=t)
+
+ return results
+
+ def decode_core(self, z: Tensor, **kwargs: Any) -> Tensor:
+ """
+ Decode latent z back to pixel space (2D or per-frame).
+
+ Args:
+ z: Latent tensor (B, C, ...).
+
+ Returns:
+ Decoded tensor in pixel space with shape matching the input layout.
+ """
+ if self.encoder_type == "2d" and z.dim() == 5:
+ b, _, t, _, _ = z.shape
+ z = rearrange(z, 'b c t h w -> (b t) c h w')
+ reshape_back = True
+ else:
+ reshape_back = False
+
+ if not self.perframe_ae:
+ z = 1. / self.scale_factor * z
+ results = self.first_stage_model.decode(z, **kwargs)
+ else:
+ results = []
+ for index in range(z.shape[0]):
+ frame_z = 1. / self.scale_factor * z[index:index + 1, :, :, :]
+ frame_result = self.first_stage_model.decode(frame_z, **kwargs)
+ results.append(frame_result)
+ results = torch.cat(results, dim=0)
+
+ if reshape_back:
+ results = rearrange(results, '(b t) c h w -> b c t h w', b=b, t=t)
+ return results
+
+ @torch.no_grad()
+ def decode_first_stage(self, z: Tensor, **kwargs: Any) -> Tensor:
+ """
+ Decode latent with no gradient.
+
+ Args:
+ z: Latent tensor to decode.
+ **kwargs: Extra args for the decoder.
+
+ Returns:
+ Decoded tensor in pixel space.
+
+ """
+ return self.decode_core(z, **kwargs)
+
+ # Same as above but without decorator
+ def differentiable_decode_first_stage(self, z: Tensor,
+ **kwargs: Any) -> Tensor:
+ """
+ Decode latent with gradients enabled.
+
+ Args:
+ z: Latent tensor to decode.
+
+ Returns:
+ Decoded tensor in pixel space.
+
+ """
+ return self.decode_core(z, **kwargs)
+
+ @torch.no_grad()
+ def get_batch_input(self,
+ batch: Mapping[str, Any],
+ random_uncond: bool,
+ return_first_stage_outputs: bool = False,
+ return_original_cond: bool = False) -> list[Any]:
+ """
+ Prepare batch: encode inputs to latents and produce conditioning embeddings.
+
+ Args:
+ batch: Batch mapping containing first-stage inputs and conditioning.
+ random_uncond: If True and `uncond_type` allows, randomly drop conditions.
+ return_first_stage_outputs: If True, also decode z to xrec for logging.
+ return_original_cond: If True, also return the raw condition object.
+
+ Returns:
+ List [z, cond_emb], extended with xrec and/or the raw condition when requested.
+ """
+ x = super().get_input(batch, self.first_stage_key)
+
+ # Encode video frames x to z via a 2D encoder
+ z = self.encode_first_stage(x)
+
+ # Get instruction condition
+ cond = batch[self.cond_stage_key]
+ if random_uncond and self.uncond_type == 'empty_seq':
+ for i, ci in enumerate(cond):
+ if random.random() < self.uncond_prob:
+ cond[i] = ""
+ if isinstance(cond, (dict, list)):
+ cond_emb = self.get_learned_conditioning(cond)
+ else:
+ cond_emb = self.get_learned_conditioning(cond.to(self.device))
+ if random_uncond and self.uncond_type == 'zero_embed':
+ for i, ci in enumerate(cond):
+ if random.random() < self.uncond_prob:
+ cond_emb[i] = torch.zeros_like(cond_emb[i])
+
+ out = [z, cond_emb]
+ if return_first_stage_outputs:
+ xrec = self.decode_first_stage(z)
+ out.extend([xrec])
+
+ if return_original_cond:
+ out.append(cond)
+
+ return out
+
+ def forward(
+ self,
+ x: Tensor,
+ x_action: Tensor,
+ x_state: Tensor,
+ c: Any,
+ **kwargs: Any,
+ ) -> tuple[Tensor, dict[str, Tensor]]:
+ """
+ Args:
+ x: Input latent (or pixel) tensor for the primary stream.
+ x_action: Action tensor associated with the batch.
+ x_state: State tensor associated with the batch.
+ c: Conditioning object (tensor/list/dict) consumed by `apply_model`.
+
+ Returns:
+ (loss, loss_dict) tuple.
+ """
+
+ t = torch.randint(0,
+ self.num_timesteps, (x.shape[0], ),
+ device=self.device).long()
+ if self.use_dynamic_rescale:
+ x = x * extract_into_tensor(self.scale_arr, t, x.shape)
+
+ return self.p_losses(x, x_action, x_state, c, t, **kwargs)
+
+ def shared_step(self, batch: Mapping[str, Any], random_uncond: bool,
+ **kwargs: Any) -> tuple[Tensor, dict[str, Tensor]]:
+ """
+ Common train/val step: build inputs, run forward, return loss/logs.
+
+ Args:
+ batch: Input batch mapping.
+ random_uncond: Whether to apply classifier-free guidance dropout to cond.
+ **kwargs: Extra args forwarded to `forward`.
+
+ Returns:
+ (loss, loss_dict) tuple.
+ """
+ x, c = self.get_batch_input(batch, random_uncond=random_uncond)
+ loss, loss_dict = self(x, c, **kwargs)
+
+ return loss, loss_dict
+
+ def apply_model(self, x_noisy: Tensor, x_action_noisy: Tensor,
+ x_state_noisy: Tensor, t: Tensor, cond: Any,
+ **kwargs: Any) -> Tensor | tuple[Tensor, Tensor, Tensor]:
+ """
+ Apply inner diffusion model with standardized conditioning dict.
+
+ Args:
+ x_noisy: Noisy latent input for the primary stream.
+ x_action_noisy: Noisy action tensor aligned with t.
+ x_state_noisy: Noisy state tensor aligned with t.
+ t: Timestep indices (B,).
+ cond: Raw conditioning; will be wrapped into the proper key if not a dict.
+ **kwargs: Extra args forwarded to the inner model call.
+
+ Returns:
+ Either a single tensor or a tuple of tensors (x, action, state) depending on model.
+ """
+ if isinstance(cond, dict):
+ pass
+ else:
+ if not isinstance(cond, list):
+ cond = [cond]
+ key = 'c_concat' if self.model.conditioning_key == 'concat' else 'c_crossattn'
+ cond = {key: cond}
+
+ x_recon, x_action_recon, x_state_recon = self.model(
+ x_noisy, x_action_noisy, x_state_noisy, t, **cond, **kwargs)
+
+ if isinstance(x_recon, tuple):
+ return x_recon[0]
+ else:
+ return x_recon, x_action_recon, x_state_recon
+
+ def p_losses(
+ self,
+ x_start: Tensor,
+ x_action_start: Tensor,
+ x_state_start: Tensor,
+ cond: Any,
+ t: Tensor,
+ noise: Tensor | None = None,
+ action_noise: Tensor | None = None,
+ **kwargs: Any,
+ ) -> tuple[Tensor, dict[str, Tensor]]:
+ """
+ Compute the per-step training losses for latent diffusion.
+
+ Args:
+ x_start: Clean primary latent (or pixel) tensor.
+ x_action_start: Clean action tensor aligned with x_start.
+ x_state_start: Clean state tensor aligned with x_start.
+ cond: Conditioning object; may include masks for action/state losses.
+ t: Timestep indices (B,).
+ noise: Optional epsilon noise for the primary stream (else sampled).
+ action_noise: Optional epsilon noise for the action stream (else sampled).
+ **kwargs: Extra args forwarded into `apply_model` (and logged if needed).
+
+ Returns:
+ (loss, loss_dict)
+ """
+ if self.noise_strength > 0:
+ b, c, f, _, _ = x_start.shape
+ offset_noise = torch.randn(b, c, f, 1, 1, device=x_start.device)
+ noise = default(
+ noise, lambda: torch.randn_like(x_start) + self.noise_strength
+ * offset_noise)
+ else:
+ noise = default(noise, lambda: torch.randn_like(x_start))
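+ # Input perturbation: the noise actually injected into the action/state
+ # streams is the target noise plus a small extra perturbation scaled by
+ # `input_pertub`, while the loss targets remain the unperturbed noises.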
+ action_noise = torch.randn(x_action_start.shape,
+ device=x_action_start.device)
+ action_noise_new = action_noise + self.input_pertub * torch.randn(
+ x_action_start.shape, device=x_action_start.device)
+
+ state_noise = torch.randn(x_state_start.shape,
+ device=x_state_start.device)
+ state_noise_new = state_noise + self.input_pertub * torch.randn(
+ x_state_start.shape, device=x_state_start.device)
+
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ x_action_noisy = self.dp_noise_scheduler_action.add_noise(
+ x_action_start, action_noise_new, t[:x_action_start.shape[0]])
+ x_state_noisy = self.dp_noise_scheduler_state.add_noise(
+ x_state_start, state_noise_new, t[:x_state_start.shape[0]])
+
+ kwargs['x_start'] = x_start
+ model_output, model_action_output, model_state_output = self.apply_model(
+ x_noisy, x_action_noisy, x_state_noisy, t, cond, **kwargs)
+
+ loss_dict = {}
+ prefix = 'train' if self.training else 'val'
+
+ if self.parameterization == "x0":
+ target = x_start
+ elif self.parameterization == "eps":
+ target = noise
+ elif self.parameterization == "v":
+ target = self.get_v(x_start, noise, t)
+ else:
+ raise NotImplementedError()
+
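+ # The action/state heads are always trained with epsilon-prediction
+ # targets, independent of the video parameterization chosen above.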
+ target_action = action_noise
+ target_state = state_noise
+
+ loss_simple = self.get_loss(model_output, target,
+ mean=False).mean([1, 2, 3, 4])
+ loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
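+ # Action loss: per-element MSE masked by the action mask, averaged over
+ # the valid (unmasked) entries only.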
+ loss_action_simple = F.mse_loss(model_action_output,
+ target_action,
+ reduction='none')
+ action_mask = cond['c_crossattn_action'][-1]
+ loss_action_simple *= action_mask
+ loss_action_simple = loss_action_simple.type(loss_action_simple.dtype)
+ loss_action_simple = reduce(loss_action_simple, 'b ... -> b (...)',
+ 'mean')
+ loss_action_simple = loss_action_simple.sum() / action_mask.sum()
+ loss_dict.update({f'{prefix}/loss_action_simple': loss_action_simple})
+
+ loss_state_simple = F.mse_loss(model_state_output,
+ target_state,
+ reduction='none')
+ state_mask = cond['c_crossattn_action'][-2]
+ loss_state_simple *= state_mask
+ loss_state_simple = loss_state_simple.type(loss_state_simple.dtype)
+ loss_state_simple = reduce(loss_state_simple, 'b ... -> b (...)',
+ 'mean')
+ loss_state_simple = loss_state_simple.sum() / state_mask.sum()
+ loss_dict.update({f'{prefix}/loss_state_simple': loss_state_simple})
+
+ if self.logvar.device != self.device:
+ self.logvar = self.logvar.to(self.device)
+ logvar_t = self.logvar[t]
+ loss = loss_simple / torch.exp(logvar_t) + logvar_t
+ loss_action = loss_action_simple
+ loss_state = loss_state_simple
+
+ if self.learn_logvar:
+ loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
+ loss_dict.update({'logvar': self.logvar.data.mean()})
+
+ loss = self.l_simple_weight * loss.mean()
+ loss_vlb = self.get_loss(model_output, target,
+ mean=False).mean(dim=(1, 2, 3, 4))
+ loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
+ loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
+
+ loss_dict.update({f'{prefix}/loss_action_vlb': loss_action})
+ loss_dict.update({f'{prefix}/loss_state_vlb': loss_state})
+
+ loss += (self.original_elbo_weight * loss_vlb)
+ loss_dict.update({f'{prefix}/loss': loss})
+
+ loss_dict.update({f'{prefix}/loss_action': loss_action})
+ loss_dict.update({f'{prefix}/loss_state': loss_state})
+
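+ # c_crossattn_action[2] is the sim-mode flag: in simulation mode only the
+ # state head contributes to the gradient (action loss weighted by 0),
+ # otherwise only the action head does (state loss weighted by 0).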
+ if cond['c_crossattn_action'][2]:
+ return loss + loss_state + loss_action * 0.0, loss_dict
+ else:
+ return loss + loss_action + loss_state * 0.0, loss_dict
+
+ def training_step(self, batch: Mapping[str, Any],
+ batch_idx: int) -> Tensor:
+ """
+ Lightning training step: compute loss and log metrics.
+
+ Args:
+ batch: Training batch mapping.
+ batch_idx: Batch index within current epoch.
+
+ Returns:
+ Scalar loss tensor for optimization.
+ """
+ loss, loss_dict = self.shared_step(
+ batch, random_uncond=self.classifier_free_guidance)
+ loss_dict.update(
+ {'lr': self.trainer.optimizers[0].param_groups[0]['lr']})
+ loss_dict.update({
+ 'lr_action_unet':
+ self.trainer.optimizers[0].param_groups[1]['lr']
+ })
+ # Sync_dist | rank_zero_only
+ self.log_dict(loss_dict,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=True,
+ sync_dist=False)
+ if (batch_idx + 1) % self.log_every_t == 0:
+ mainlogger.info(
+ f"batch:{batch_idx}|epoch:{self.current_epoch} [globalstep:{self.global_step}]: loss={loss}"
+ )
+ return loss
+
+ @torch.no_grad()
+ def validation_step(self, batch: Mapping[str, Any],
+ batch_idx: int) -> None:
+ """
+ Lightning validation step: compute loss and log metrics.
+
+ Args:
+ batch: Validation batch mapping.
+ batch_idx: Batch index in validation loop.
+ """
+ _, loss_dict_no_ema = self.shared_step(batch, random_uncond=False)
+ self.log_dict(loss_dict_no_ema,
+ prog_bar=False,
+ logger=True,
+ on_step=False,
+ on_epoch=True)
+
+ def _get_denoise_row_from_list(self,
+ samples: Sequence[Tensor],
+ desc: str = '') -> Tensor:
+ """
+ Decode a list of latents and pack into a grid for visualization.
+
+ Args:
+ samples: Sequence of latent tensors to decode and tile.
+ desc: Optional tqdm description string.
+
+ Returns:
+ Grid image tensor suitable for logging.
+
+ """
+ denoise_row = []
+ for zd in tqdm(samples, desc=desc):
+ denoise_row.append(self.decode_first_stage(zd.to(self.device)))
+ n_log_timesteps = len(denoise_row)
+
+ denoise_row = torch.stack(denoise_row)
+
+ if denoise_row.dim() == 5:
+ denoise_grid = rearrange(denoise_row, 'n b c h w -> b n c h w')
+ denoise_grid = rearrange(denoise_grid, 'b n c h w -> (b n) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=n_log_timesteps)
+ elif denoise_row.dim() == 6:
+ video_length = denoise_row.shape[3]
+ denoise_grid = rearrange(denoise_row, 'n b c t h w -> b n c t h w')
+ denoise_grid = rearrange(denoise_grid,
+ 'b n c t h w -> (b n) c t h w')
+ denoise_grid = rearrange(denoise_grid, 'n c t h w -> (n t) c h w')
+ denoise_grid = make_grid(denoise_grid, nrow=video_length)
+ else:
+ raise ValueError
+
+ return denoise_grid
+
+ @torch.no_grad()
+ def log_images(self,
+ batch: Mapping[str, Any],
+ sample: bool = True,
+ ddim_steps: int = 200,
+ ddim_eta: float = 1.0,
+ plot_denoise_rows: bool = False,
+ unconditional_guidance_scale: float = 1.0,
+ **kwargs: Any) -> dict[str, Tensor]:
+ """ Log images for LatentDiffusion """
+ # Control the number of sampled images for logging; larger values may cause OOM
+ sampled_img_num = 2
+ for key in batch.keys():
+ batch[key] = batch[key][:sampled_img_num]
+
+ # TBD: currently, classifier_free_guidance sampling is only supported by DDIM
+ use_ddim = ddim_steps is not None
+ log = dict()
+ z, c, xrec, xc = self.get_batch_input(batch,
+ random_uncond=False,
+ return_first_stage_outputs=True,
+ return_original_cond=True,
+ logging=True)
+
+ N = xrec.shape[0]
+ log["reconst"] = xrec
+ log["condition"] = xc
+
+ if sample:
+ uc = None
+ with self.ema_scope("Plotting"):
+ samples, z_denoise_row = self.sample_log(
+ cond=c,
+ batch_size=N,
+ ddim=use_ddim,
+ ddim_steps=ddim_steps,
+ eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ x0=z,
+ **kwargs)
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ return log
+
+ def p_mean_variance(
+ self,
+ x: Tensor,
+ c: Any,
+ t: Tensor,
+ clip_denoised: bool,
+ return_x0: bool = False,
+ score_corrector: Any = None,
+ corrector_kwargs: Mapping[str, Any] | None = None,
+ **kwargs: Any
+ ) -> tuple[Tensor, Tensor, Tensor] | tuple[Tensor, Tensor, Tensor, Tensor]:
+ """
+ Predict posterior parameters (and optionally x0) at timestep t.
+
+ Args:
+ x: Current latent at timestep t.
+ c: Conditioning object passed to the inner model/score corrector.
+ t: Timestep indices (B,).
+ clip_denoised: If True, clamp predicted x0 to [-1, 1].
+ return_x0: If True, also return predicted x0.
+ score_corrector: Optional score-corrector object with `modify_score`.
+ corrector_kwargs: Extra kwargs for the score corrector.
+ **kwargs: Forwarded to `apply_model`.
+
+ Returns:
+ (mean, var, log_var) or (mean, var, log_var, x0) tensors.
+ """
+
+ t_in = t
+ model_out = self.apply_model(x, t_in, c, **kwargs)
+
+ if score_corrector is not None:
+ assert self.parameterization == "eps"
+ model_out = score_corrector.modify_score(self, model_out, x, t, c,
+ **corrector_kwargs)
+
+ if self.parameterization == "eps":
+ x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
+ elif self.parameterization == "x0":
+ x_recon = model_out
+ else:
+ raise NotImplementedError()
+
+ if clip_denoised:
+ x_recon.clamp_(-1., 1.)
+
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
+ x_start=x_recon, x_t=x, t=t)
+
+ if return_x0:
+ return model_mean, posterior_variance, posterior_log_variance, x_recon
+ else:
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self,
+ x: Tensor,
+ c: Any,
+ t: Tensor,
+ clip_denoised: bool = False,
+ repeat_noise: bool = False,
+ return_x0: bool = False,
+ temperature: float = 1.0,
+ noise_dropout: float = 0.0,
+ score_corrector: Any = None,
+ corrector_kwargs: Mapping[str, Any] | None = None,
+ **kwargs: Any) -> Tensor | tuple[Tensor, Tensor]:
+ """
+ Draw a single reverse-diffusion step (optionally return x0).
+
+ Args:
+ x: Current latent at timestep t.
+ c: Conditioning object for the model.
+ t: Timestep indices (B,).
+ clip_denoised: Clamp predicted x0 to [-1, 1] when forming the mean.
+ repeat_noise: If True, reuse the same noise across batch.
+ return_x0: If True, also return the predicted x0.
+ temperature: Temperature for sampling noise scale.
+ noise_dropout: Dropout probability applied to the sampled noise.
+ score_corrector: Optional score-corrector to adjust model outputs.
+ corrector_kwargs: Extra kwargs for the corrector.
+ **kwargs: Forwarded to `p_mean_variance`.
+
+ Returns:
+ Next latent (and optionally x0).
+ """
+
+ b, *_, device = *x.shape, x.device
+ outputs = self.p_mean_variance(x=x, c=c, t=t, clip_denoised=clip_denoised, return_x0=return_x0, \
+ score_corrector=score_corrector, corrector_kwargs=corrector_kwargs, **kwargs)
+ if return_x0:
+ model_mean, _, model_log_variance, x0 = outputs
+ else:
+ model_mean, _, model_log_variance = outputs
+
+ noise = noise_like(x.shape, device, repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+ # No noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(
+ b, *((1, ) * (len(x.shape) - 1)))
+
+ if return_x0:
+ return model_mean + nonzero_mask * (
+ 0.5 * model_log_variance).exp() * noise, x0
+ else:
+ return model_mean + nonzero_mask * (
+ 0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def p_sample_loop(self,
+ cond: Any,
+ shape: Sequence[int],
+ return_intermediates: bool = False,
+ x_T: Tensor | None = None,
+ verbose: bool = True,
+ callback: Callable[[int], Any] | None = None,
+ timesteps: int | None = None,
+ mask: Tensor | None = None,
+ x0: Tensor | None = None,
+ img_callback: Callable[[Tensor, int], Any] | None = None,
+ start_T: int | None = None,
+ log_every_t: int | None = None,
+ **kwargs: Any) -> Tensor | tuple[Tensor, list[Tensor]]:
+ """
+ Run the full reverse process from noise to sample(s).
+
+ Args:
+ cond: Conditioning object (tensor/dict/list), optionally noised when cond schedule is shortened.
+ shape: Output latent shape (B, C, ...).
+ return_intermediates: If True, also return intermediate latents.
+ x_T: Optional starting noise latent (else sampled from N(0, I)).
+ verbose: If True, show tqdm progress.
+ callback: Optional function called with the current timestep i.
+ timesteps: Number of reverse steps to perform (default: self.num_timesteps).
+ mask: Optional inpainting mask; ones keep original x0 regions.
+ x0: Optional original latent for masked regions (when using `mask`).
+ img_callback: Optional function called with (img, i) every step.
+ start_T: Optional cap to limit starting step (min(timesteps, start_T)).
+ log_every_t: Logging frequency for collecting intermediates (defaults to self.log_every_t).
+
+ Returns:
+ Final latent sample (and optionally the list of intermediates).
+ """
+
+ if not log_every_t:
+ log_every_t = self.log_every_t
+ device = self.betas.device
+ b = shape[0]
+ # Sample an initial noise
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ else:
+ img = x_T
+
+ intermediates = [img]
+ if timesteps is None:
+ timesteps = self.num_timesteps
+ if start_T is not None:
+ timesteps = min(timesteps, start_T)
+
+ iterator = tqdm(
+ reversed(range(0, timesteps)), desc='Sampling t',
+ total=timesteps) if verbose else reversed(range(0, timesteps))
+
+ if mask is not None:
+ assert x0 is not None
+ assert x0.shape[2:3] == mask.shape[2:3]
+
+ for i in iterator:
+ ts = torch.full((b, ), i, device=device, dtype=torch.long)
+ if self.shorten_cond_schedule:
+ assert self.model.conditioning_key != 'hybrid'
+ tc = self.cond_ids[ts].to(cond.device)
+ cond = self.q_sample(x_start=cond,
+ t=tc,
+ noise=torch.randn_like(cond))
+
+ img = self.p_sample(img,
+ cond,
+ ts,
+ clip_denoised=self.clip_denoised,
+ **kwargs)
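+ # Inpainting: where mask == 1 keep the (noised) original x0, elsewhere
+ # keep the freshly sampled latent.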
+ if mask is not None:
+ img_orig = self.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ if i % log_every_t == 0 or i == timesteps - 1:
+ intermediates.append(img)
+ if callback: callback(i)
+ if img_callback: img_callback(img, i)
+
+ if return_intermediates:
+ return img, intermediates
+ return img
+
+ @torch.no_grad()
+ def sample(self,
+ cond,
+ batch_size: int = 16,
+ return_intermediates: bool = False,
+ x_T: Tensor | None = None,
+ verbose: bool = True,
+ timesteps: int | None = None,
+ mask: Tensor | None = None,
+ x0: Tensor | None = None,
+ shape: Sequence[int] | None = None,
+ **kwargs: Any) -> Tensor | tuple[Tensor, list[Tensor]]:
+ """
+ Convenience wrapper to run `p_sample_loop` with a full batch.
+
+ Args:
+ cond: Conditioning object; dict/list items are truncated to batch_size.
+ batch_size: Number of samples to generate.
+ return_intermediates: If True, return intermediates as well.
+ x_T: Optional starting noise latent (else sampled).
+ verbose: Whether to print sampling progress.
+ timesteps: Number of reverse steps (default: self.num_timesteps).
+ mask: Optional mask for partial generation/inpainting.
+ x0: Optional original latent used with `mask` during sampling.
+ shape: Optional output shape; if None, uses (B, C, T, H, W) from model config.
+
+ Returns:
+ Final latent (and optionally intermediates).
+ """
+ if shape is None:
+ shape = (batch_size, self.channels, self.temporal_length,
+ *self.image_size)
+ if cond is not None:
+ if isinstance(cond, dict):
+ cond = {
+ key:
+ cond[key][:batch_size] if not isinstance(cond[key], list)
+ else list(map(lambda x: x[:batch_size], cond[key]))
+ for key in cond
+ }
+ else:
+ cond = [c[:batch_size] for c in cond] if isinstance(
+ cond, list) else cond[:batch_size]
+ return self.p_sample_loop(cond,
+ shape,
+ return_intermediates=return_intermediates,
+ x_T=x_T,
+ verbose=verbose,
+ timesteps=timesteps,
+ mask=mask,
+ x0=x0,
+ **kwargs)
+
+ @torch.no_grad()
+ def sample_log(self, cond: Any, batch_size: int, ddim: bool,
+ ddim_steps: int,
+ **kwargs: Any) -> tuple[Any, Any, Any, Any]:
+ """
+ Produce samples (and intermediates), optionally via DDIM sampler.
+
+ Args:
+ cond: Conditioning object passed to the sampler.
+ batch_size: Number of samples to generate.
+ ddim: If True, use DDIM sampler; otherwise use ancestral sampling.
+ ddim_steps: Number of DDIM steps when `ddim` is True.
+
+ Returns:
+ Tuple of (samples, actions, states, intermediates).
+ """
+ if ddim:
+ ddim_sampler = DDIMSampler(self)
+ shape = (self.channels, self.temporal_length, *self.image_size)
+ samples, actions, states, intermediates = ddim_sampler.sample(
+ ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
+
+ else:
+ samples, intermediates = self.sample(cond=cond,
+ batch_size=batch_size,
+ return_intermediates=True,
+ **kwargs)
+ # The ancestral sampling path produces no separate action/state outputs.
+ actions, states = None, None
+
+ return samples, actions, states, intermediates
+
+ def configure_schedulers(
+ self, optimizer: torch.optim.Optimizer) -> dict[str, Any]:
+ """
+ Build LR scheduler dict compatible with PyTorch Lightning.
+
+ Args:
+ optimizer: Optimizer instance for which to build the scheduler dict.
+
+ Returns:
+ Dict with keys {'scheduler', 'interval', 'frequency'} per Lightning API.
+ """
+ assert 'target' in self.scheduler_config
+ scheduler_name = self.scheduler_config.target.split('.')[-1]
+ interval = self.scheduler_config.interval
+ frequency = self.scheduler_config.frequency
+ if scheduler_name == "LambdaLRScheduler":
+ scheduler = instantiate_from_config(self.scheduler_config)
+ scheduler.start_step = self.global_step
+ lr_scheduler = {
+ 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
+ 'interval': interval,
+ 'frequency': frequency
+ }
+ elif scheduler_name == "CosineAnnealingLRScheduler":
+ scheduler = instantiate_from_config(self.scheduler_config)
+ decay_steps = scheduler.decay_steps
+ last_step = -1 if self.global_step == 0 else scheduler.start_step
+ lr_scheduler = {
+ 'scheduler':
+ CosineAnnealingLR(optimizer,
+ T_max=decay_steps,
+ last_epoch=last_step),
+ 'interval':
+ interval,
+ 'frequency':
+ frequency
+ }
+ else:
+ raise NotImplementedError
+ return lr_scheduler
+
+
+class LatentVisualDiffusion(LatentDiffusion):
+ """
+ Visual-conditioned latent diffusion with action/state heads and schedulers.
+
+ """
+
+ def __init__(self,
+ img_cond_stage_config: OmegaConf,
+ image_proj_stage_config: OmegaConf,
+ noise_scheduler_config: OmegaConf,
+ dp_optimizer_config: OmegaConf,
+ dp_ema_config: OmegaConf,
+ freeze_embedder: bool = True,
+ image_proj_model_trainable: bool = True,
+ n_obs_steps_imagen: int = 2,
+ n_obs_steps_acting: int = 2,
+ agent_state_dim: int = 14,
+ agent_action_dim: int = 14,
+ global_emb_dim: int = 1024,
+ input_pertub: float = 0.1,
+ lr_scheduler: str = 'cosine',
+ lr_warmup_steps: int = 500,
+ num_epochs: int = 15000,
+ gradient_accumulate_every: int = 1,
+ use_scheduler: bool = False,
+ dp_use_ema: bool = False,
+ pretrained_checkpoint: str | None = None,
+ decision_making_only: bool = True,
+ *args,
+ **kwargs):
+ """
+ Args:
+ img_cond_stage_config: OmegaConf for the *image* conditioning encoder.
+ image_proj_stage_config: OmegaConf for the image feature projector.
+ noise_scheduler_config: OmegaConf for DP noise schedulers (state/action).
+ dp_optimizer_config: OmegaConf for optimizer params of the UNet heads.
+ dp_ema_config: Optional EMA config for the action UNet.
+ freeze_embedder: If True, freeze the image embedder params.
+ image_proj_model_trainable: If True, train the image projector.
+ n_obs_steps_imagen: Number of observed steps for image conditions.
+ n_obs_steps_acting: Number of observed steps for acting head.
+ agent_state_dim: Dimension of agent state vector.
+ agent_action_dim: Dimension of agent action vector.
+ global_emb_dim: Embedding size for state/action/text/image fusion.
+ input_pertub: Perturbation scale added to action/state noises.
+ lr_scheduler: Name of LR scheduler (for SelectiveLRScheduler wrapper).
+ lr_warmup_steps: Warmup steps for scheduler creation.
+ num_epochs: Total training epochs.
+ gradient_accumulate_every: Gradient accumulation steps.
+ use_scheduler: If True, enable LR scheduling.
+ dp_use_ema: If True, maintain EMA for action UNet head.
+ pretrained_checkpoint: Optional path to a pretrained checkpoint.
+ decision_making_only: If True, use decision-only augmentation path.
+ """
+
+ super().__init__(*args, **kwargs)
+ self.image_proj_model_trainable = image_proj_model_trainable
+ self.agent_state_dim = agent_state_dim
+ self.agent_action_dim = agent_action_dim
+ self.global_emb_dim = global_emb_dim
+ self.n_obs_steps_imagen = n_obs_steps_imagen
+ self.n_obs_steps_acting = n_obs_steps_acting
+ self.decision_making_only = decision_making_only
+
+ self._init_embedder(img_cond_stage_config, freeze_embedder)
+ self._init_img_ctx_projector(image_proj_stage_config,
+ image_proj_model_trainable)
+ self._init_dp_noise_scheduler(noise_scheduler_config)
+ self._init_projectors()
+ if dp_use_ema:
+ self._init_dp_ema(dp_ema_config)
+
+ # Create a positional embedder for state and action info; states and actions share a unified vector space
+ self.pos_embedder = SinusoidalPosEmb(self.global_emb_dim)
+ self.register_buffer('cond_pos_emb',
+ self.pos_embedder(torch.arange(
+ 0, 16)))  # NOTE: hand-coded sequence length of 16
+
+ self.input_pertub = input_pertub
+ self.dp_optimizer_config = dp_optimizer_config
+ self.lr_scheduler = lr_scheduler
+ self.lr_warmup_steps = lr_warmup_steps
+ self.num_epochs = num_epochs
+ self.gradient_accumulate_every = gradient_accumulate_every
+ self.use_scheduler = use_scheduler
+ self.dp_use_ema = dp_use_ema
+ self.pretrained_checkpoint = pretrained_checkpoint
+
+ def _init_img_ctx_projector(self, config: OmegaConf,
+ trainable: bool) -> None:
+ """
+ Instantiate image context projector; optionally freeze.
+
+ Args:
+ config: OmegaConf for the projector module to instantiate.
+ trainable: If False, freeze the projector.
+ """
+ self.image_proj_model = instantiate_from_config(config)
+ if not trainable:
+ self.image_proj_model.eval()
+ self.image_proj_model.train = disabled_train
+ for param in self.image_proj_model.parameters():
+ param.requires_grad = False
+
+ def _init_embedder(self, config: OmegaConf, freeze: bool = True) -> None:
+ """
+ Instantiate the image embedder; optionally freeze.
+
+ Args:
+ config: OmegaConf for the embedder to instantiate.
+ freeze: If True, set to eval/disable grads.
+ """
+ self.embedder = instantiate_from_config(config)
+ if freeze:
+ self.embedder.eval()
+ self.embedder.train = disabled_train
+ for param in self.embedder.parameters():
+ param.requires_grad = False
+
+ def init_normalizers(self, normalize_config: OmegaConf,
+ dataset_stats: Mapping[str, Any]) -> None:
+ """
+ Create normalization and unnormalization utilities.
+
+ Args:
+ normalize_config: Config with shapes and normalization modes.
+ dataset_stats: Statistics dict used to compute normalization.
+ """
+ self.normalize_inputs = Normalize(
+ normalize_config.input_shapes,
+ normalize_config.input_normalization_modes, dataset_stats)
+ self.unnormalize_outputs = Unnormalize(
+ normalize_config.output_shapes,
+ normalize_config.output_normalization_modes, dataset_stats)
+
+ def _init_dp_noise_scheduler(self, config: OmegaConf) -> None:
+ """
+ Instantiate separate DP noise schedulers for action and state.
+
+ Args:
+ config: OmegaConf used to create scheduler instances.
+ """
+ self.dp_noise_scheduler_action = instantiate_from_config(config)
+ self.dp_noise_scheduler_state = instantiate_from_config(config)
+
+ def _init_dp_ema(self, config: OmegaConf | None) -> None:
+ """
+ Initialize EMA for UNet head.
+
+ Args:
+ config: EMA config, must contain 'params' sub-dict.
+ """
+ self.dp_ema_model = copy.deepcopy(
+ self.model.diffusion_model.action_unet)
+ self.dp_ema_model_on_device = False
+ self.dp_ema = EMAModel(**config['params'], model=self.dp_ema_model)
+
+ def _init_projectors(self):
+ """
+ Build small MLP projectors and positional embeddings for state/action.
+ """
+ self.state_projector = MLPProjector(self.agent_state_dim,
+ 1024)  # NOTE: hand-coded embedding size
+ self.action_projector = MLPProjector(self.agent_action_dim,
+ 1024)  # NOTE: hand-coded embedding size
+
+ self.agent_action_pos_emb = nn.Parameter(
+ torch.randn(1, 16, self.global_emb_dim))
+ self.agent_state_pos_emb = nn.Parameter(
+ torch.randn(1, self.n_obs_steps_imagen, self.global_emb_dim))
+
+ def _get_augmented_batch(
+ self,
+ z: Tensor,
+ state: Tensor,
+ obs_state: Tensor,
+ action: Tensor,
+ ins: Tensor,
+ null_ins: Tensor,
+ img: Tensor,
+ sim_mode: bool = False,
+ pre_action: Tensor | None = None,
+ logging: bool = False) -> tuple[Tensor, Tensor, list[Tensor]]:
+ """
+ Construct augmented conditioning batch for decision/simulation modes.
+
+ Args:
+ z: Latent video tensor (B, C, ...).
+ state: Full state tensor (B, T, D_s).
+ obs_state: Observed state embeddings (B, T, E).
+ action: Action embeddings (B, T, E).
+ ins: Instruction/text embeddings (B, L, E) after projector.
+ null_ins: Null/empty instruction embedding for CFG.
+ img: Image conditioning embedding (B, E_img) or batched equivalent.
+ sim_mode: If True, build simulated-mode batch; else decision-making.
+ pre_action: Optional previous action(s). (unused here; reserved)
+ logging: If True, may include extra returns for logs. (unused)
+
+ Returns:
+ Tuple of (z, state, [mode_batch]) where mode_batch is a single tensor combining the selected conditioning streams.
+ """
+
+ b, _, t, _, _ = z.shape
+ if self.decision_making_only:
+ mode_batch = torch.cat([obs_state, ins, img], dim=1)
+ return z, state, [mode_batch]
+
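+ # World-model training: outside sim mode the action stream is zeroed so
+ # conditioning comes from the instruction; in sim mode the real actions are
+ # kept and the instruction is replaced by the null prompt.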
+ if not sim_mode:
+ zero_action = torch.zeros_like(action)
+ mode_batch = torch.cat([obs_state, zero_action, ins, img], dim=1)
+ else:
+ null_ins_batch = null_ins.repeat_interleave(repeats=ins.shape[0],
+ dim=0)
+ mode_batch = torch.cat([obs_state, action, null_ins_batch, img],
+ dim=1)
+ return z, state, [mode_batch]
+
+ def on_train_batch_end(self, outputs: Any, batch: Mapping[str, Any],
+ batch_idx: int) -> None:
+ """
+ Update EMA for action UNet after each train batch (if enabled).
+
+ Args:
+ outputs: Outputs returned by `training_step` for this batch.
+ batch: Current training batch mapping.
+ batch_idx: Batch index within the epoch.
+ """
+ if self.dp_use_ema:
+ if self.dp_ema_model is not None and not self.dp_ema_model_on_device:
+ device = self.model.device
+ self.dp_ema_model.to(device)
+ self.dp_ema_model_on_device = True
+ self.dp_ema.step(self.model.diffusion_model.action_unet)
+
+ def shared_step(self, batch: Mapping[str, Any], random_uncond: bool,
+ **kwargs: Any) -> tuple[Tensor, dict[str, Tensor]]:
+ """
+ Common train/val step for visual diffusion.
+
+ Args:
+ batch: Input batch mapping.
+ random_uncond: Whether to apply classifier-free guidance dropout.
+
+ Returns:
+ (loss, loss_dict) tuple.
+ """
+ x, x_action, x_state, c, fs = self.get_batch_input(
+ batch, random_uncond=random_uncond, return_fs=True)
+ kwargs.update({"fs": fs.long()})
+ loss, loss_dict = self(x, x_action, x_state, c, **kwargs)
+ return loss, loss_dict
+
+ def get_batch_input(self,
+ batch: Mapping[str, Any],
+ random_uncond: bool,
+ return_first_stage_outputs: bool = False,
+ return_original_cond: bool = False,
+ return_fs: bool = False,
+ return_cond_frame: bool = False,
+ return_original_input: bool = False,
+ logging: bool = False,
+ **kwargs: Any) -> list[Any]:
+ """
+ Prepare model inputs & conditioning from a raw training batch.
+
+ Args:
+ batch: Batch mapping with keys like image/state/action/obs/etc.
+ random_uncond: Apply stochastic condition dropout for CFG.
+ return_first_stage_outputs: If True, also return xrec (decoded z).
+ return_original_cond: If True, also return raw instruction text.
+ return_fs: If True, return fps or frame_stride per config.
+ return_cond_frame: If True, return conditioning frames (obs images).
+ return_original_input: If True, return original x (pre-encoding).
+ logging: If True, append sim_mode flag at the end.
+
+ Returns:
+ A list [z, action, state, cond], extended with optional items (xrec, raw instruction, fs, conditioning frames, original x, sim-mode flag) per the flags above.
+ """
+ # x: b c t h w
+ x = super().get_input(batch, self.first_stage_key)
+ b, _, t, _, _ = x.shape
+ # Get actions: b t d
+ action = super().get_input(batch, 'action')
+ # Get states: b t d
+ state = super().get_input(batch, 'next.state')
+ # Get observable states: b t d
+ obs_state = super().get_input(batch, 'observation.state')
+ # Get observable images: b c t h w
+ obs = super().get_input(batch, 'observation.image')
+
+ # Encode video frames x to z via a 2D encoder
+ z = self.encode_first_stage(x)
+
+ cond = {}
+ # Get instruction condition
+ cond_ins_input = batch[self.cond_stage_key]
+ if isinstance(cond_ins_input, (dict, list)):
+ cond_ins_emb = self.get_learned_conditioning(cond_ins_input)
+ else:
+ cond_ins_emb = self.get_learned_conditioning(
+ cond_ins_input.to(self.device))
+ # To support classifier-free guidance, randomly drop out only the text
+ # conditioning (5%), only the image conditioning (5%), or both (5%).
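+ # With one uniform draw r per sample and p = uncond_prob: r < p drops only
+ # the text, p <= r < 2p drops both, and 2p <= r < 3p drops only the image
+ # (5% each when uncond_prob = 0.05).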
+ if random_uncond:
+ random_num = torch.rand(b, device=x.device)
+ else:
+ random_num = torch.ones(b, device=x.device)
+ prompt_mask = rearrange(random_num < 2 * self.uncond_prob,
+ "n -> n 1 1")
+ null_prompt = self.get_learned_conditioning([""])
+ cond_ins_emb = torch.where(prompt_mask, null_prompt,
+ cond_ins_emb.detach())
+
+ # Get conditioning frames
+ cond_frame_index = 0
+ img = obs[:, :, -1, ...]
+ input_mask = 1 - rearrange(
+ (random_num >= self.uncond_prob).float() *
+ (random_num < 3 * self.uncond_prob).float(), "n -> n 1 1 1")
+
+ cond_img = input_mask * img
+ cond_img_emb = self.embedder(cond_img)
+ cond_img_emb = self.image_proj_model(cond_img_emb)
+
+ if self.model.conditioning_key == 'hybrid':
+ if self.interp_mode:
+ img_cat_cond = torch.zeros_like(z)
+ img_cat_cond[:, :, 0, :, :] = z[:, :, 0, :, :]
+ img_cat_cond[:, :, -1, :, :] = z[:, :, -1, :, :]
+ else:
+ img_cat_cond = z[:, :, cond_frame_index, :, :]
+ img_cat_cond = img_cat_cond.unsqueeze(2)
+ img_cat_cond = repeat(img_cat_cond,
+ 'b c t h w -> b c (repeat t) h w',
+ repeat=z.shape[2])
+ cond["c_concat"] = [img_cat_cond]
+
+ cond_action = self.action_projector(action)
+ cond_action_emb = self.agent_action_pos_emb + cond_action
+ # Get conditioning states
+ cond_state = self.state_projector(obs_state)
+ cond_state_emb = self.agent_state_pos_emb + cond_state
+
+ if self.decision_making_only:
+ is_sim_mode = False
+ else:
+ is_sim_mode = torch.rand(1) < 0.5
+ z, state, cond["c_crossattn"] = self._get_augmented_batch(
+ z,
+ state,
+ cond_state_emb,
+ cond_action_emb,
+ cond_ins_emb,
+ null_prompt,
+ cond_img_emb,
+ sim_mode=is_sim_mode,
+ logging=logging)
+
+ cond["c_crossattn_action"] = [
+ obs[:, :, -self.n_obs_steps_acting:],
+ state[:, -self.n_obs_steps_acting:], is_sim_mode,
+ batch['state_mask'], batch['action_mask']
+ ]
+
+ out = [z, action, state, cond]
+ if return_first_stage_outputs:
+ xrec = self.decode_first_stage(z)
+ out.extend([xrec])
+ if return_original_cond:
+ out.append(cond_ins_input)
+ if return_fs:
+ if self.fps_condition_type == 'fs':
+ fs = super().get_input(batch, 'frame_stride')
+ elif self.fps_condition_type == 'fps':
+ fs = super().get_input(batch, 'fps')
+ out.append(fs)
+ if return_cond_frame:
+ out.append(obs)
+ if return_original_input:
+ out.append(x)
+ if logging:
+ out.append(is_sim_mode)
+ return out
+
+ @torch.no_grad()
+ def log_images(self,
+ batch: Mapping[str, Any],
+ sample: bool = True,
+ ddim_steps: int = 50,
+ ddim_eta: float = 1.0,
+ plot_denoise_rows: bool = False,
+ unconditional_guidance_scale: float = 1.0,
+ mask: Tensor | None = None,
+ **kwargs) -> dict[str, Tensor]:
+ """
+ Log images for LatentVisualDiffusion
+
+ Args:
+ batch: Batch mapping used to form inputs/conditions.
+ sample: If True, also run sampling for visualization.
+ ddim_steps: Number of DDIM steps when using DDIM.
+ ddim_eta: DDIM eta parameter (stochasticity).
+ plot_denoise_rows: If True, include denoise progression grid.
+ unconditional_guidance_scale: Guidance scale for CFG sampling.
+ mask: Optional mask for sampling-time inpainting.
+
+ Returns:
+ Dict of visualization tensors (images/actions/states/progress).
+ """
+
+ ## sampled_img_num: control the number of sampled images for logging; larger values may cause OOM
+ sampled_img_num = 1
+ for key in batch.keys():
+ batch[key] = batch[key][:sampled_img_num]
+
+ ## TBD: currently, classifier_free_guidance sampling is only supported by DDIM
+ use_ddim = ddim_steps is not None
+ log = dict()
+
+ z, act, state, c, xrec, xc, fs, cond_x, is_sim_mode = self.get_batch_input(
+ batch,
+ random_uncond=False,
+ return_first_stage_outputs=True,
+ return_original_cond=True,
+ return_fs=True,
+ return_cond_frame=True,
+ logging=True)
+
+ kwargs['x_start'] = z
+
+ N = xrec.shape[0]
+ log["image_condition"] = cond_x
+ log["reconst"] = xrec
+ if is_sim_mode:
+ xc = ["NULL"]
+ xc_with_fs = []
+ for idx, content in enumerate(xc):
+ xc_with_fs.append(content + '_fs=' + str(fs[idx].item()))
+ log['instruction'] = xc
+ log["condition"] = xc_with_fs
+ kwargs.update({"fs": fs.long()})
+
+ if sample:
+ uc = None
+ with self.ema_scope("Plotting"):
+ samples, action_samples, state_samples, z_denoise_row = self.sample_log(
+ cond=c,
+ batch_size=N,
+ ddim=use_ddim,
+ ddim_steps=ddim_steps,
+ eta=ddim_eta,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=uc,
+ x0=z,
+ **kwargs)
+
+ x_samples = self.decode_first_stage(samples)
+ log["samples"] = x_samples
+
+ # Log actions
+ mb, mt, _ = batch['action_mask'].shape
+ act_mask = batch['action_mask'] == 1.0
+ action_target = act[act_mask].reshape(mb, mt, -1)
+ action_samples = action_samples[act_mask].reshape(mb, mt, -1)
+ log["action"] = torch.cat((action_target, action_samples), dim=0)
+
+ # Log states
+ mb, mt, _ = batch['state_mask'].shape
+ state_mask = batch['state_mask'] == 1.0
+ state_target = state[state_mask].reshape(mb, mt, -1)
+ state_samples = state_samples[state_mask].reshape(mb, mt, -1)
+ log["state"] = torch.cat((state_target, state_samples), dim=0)
+
+ if plot_denoise_rows:
+ denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
+ log["denoise_row"] = denoise_grid
+
+ log["video_idx"] = batch["path"][0].split('/')[-1][:-4]
+ return log
+
+ def configure_optimizers(self):
+ """ configure_optimizers for LatentDiffusion """
+ lr = self.learning_rate
+
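+ # Two parameter groups: the main diffusion model trains at the base
+ # learning rate, while the action/state UNet heads use the optimizer
+ # settings from `dp_optimizer_config`.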
+ params = [
+ param for name, param in self.model.named_parameters()
+ if not name.startswith("diffusion_model.action_unet")
+ and not name.startswith("diffusion_model.state_unet")
+ ]
+ params_unet_head = list(
+ self.model.diffusion_model.action_unet.parameters()) + list(
+ self.model.diffusion_model.state_unet.parameters())
+
+ mainlogger.info(f"@Training [{len(params)}] Full Paramters.")
+
+ if self.cond_stage_trainable:
+ params_cond_stage = [
+ p for p in self.cond_stage_model.parameters()
+ if p.requires_grad
+ ]
+ mainlogger.info(
+ f"@Training [{len(params_cond_stage)}] Paramters for Cond_stage_model."
+ )
+ params.extend(params_cond_stage)
+
+ if self.image_proj_model_trainable:
+ mainlogger.info(
+ f"@Training [{len(list(self.image_proj_model.parameters()))}] Paramters for Image_proj_model."
+ )
+ params.extend(list(self.image_proj_model.parameters()))
+
+ if self.learn_logvar:
+ mainlogger.info('Diffusion model optimizing logvar')
+ if isinstance(params[0], dict):
+ params.append({"params": [self.logvar]})
+ else:
+ params.append(self.logvar)
+
+ params_group = [{
+ 'params': params,
+ 'lr': lr
+ }, {
+ 'params':
+ params_unet_head,
+ 'lr':
+ self.dp_optimizer_config['params']['lr'],
+ 'betas':
+ self.dp_optimizer_config['params']['betas'],
+ 'eps':
+ self.dp_optimizer_config['params']['eps'],
+ 'weight_decay':
+ self.dp_optimizer_config['params']['weight_decay']
+ }]
+ optimizer = torch.optim.AdamW(params_group, lr=lr)
+
+ if self.use_scheduler:
+
+ # mainlogger.info("Setting up scheduler...")
+ lr_scheduler = get_scheduler(
+ self.lr_scheduler,
+ optimizer=optimizer,
+ num_warmup_steps=self.lr_warmup_steps,
+ num_training_steps=(self.datasets_len * self.num_epochs) //
+ self.gradient_accumulate_every,  # NOTE: hand-coded total-step estimate
+ last_epoch=-1)
+
+ scheduler = SelectiveLRScheduler(
+ optimizer=optimizer,
+ base_scheduler=lr_scheduler,
+ group_indices=[1],
+ default_lr=[lr, self.dp_optimizer_config['params']['lr']])
+ return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]
+
+ return [optimizer]
+
+
+class DiffusionWrapper(pl.LightningModule):
+ """Thin wrapper that routes inputs/conditions to the underlying diffusion model."""
+
+ def __init__(self, diff_model_config: OmegaConf,
+ conditioning_key: str | None) -> None:
+ """
+ Args:
+ diff_model_config: OmegaConf describing the inner diffusion model to instantiate.
+ conditioning_key: How conditioning is applied.
+ """
+ super().__init__()
+ self.diffusion_model = instantiate_from_config(diff_model_config)
+ self.conditioning_key = conditioning_key
+
+ def forward(
+ self,
+ x: Tensor,
+ x_action: Tensor | None,
+ x_state: Tensor | None,
+ t: Tensor,
+ c_concat: Sequence[Tensor] | None = None,
+ c_crossattn: Sequence[Tensor] | None = None,
+ c_crossattn_action: list[Any] | None = None,
+ c_adm: Tensor | None = None,
+ s: Tensor | None = None,
+ mask: Tensor | None = None,
+ **kwargs: Any,
+ ) -> Any:
+ """
+ Route input(s) and condition(s) into the inner diffusion model based on `conditioning_key`.
+
+ Args:
+ x: Primary input tensor (e.g., latent/image) at timestep `t`.
+ x_action: Action stream tensor (used by 'hybrid' variants).
+ x_state: State stream tensor (used by 'hybrid' variants).
+ t: Timestep indices (B,).
+ c_concat: List of tensors to be concatenated channel-wise with `x` (for 'concat' / 'hybrid' modes).
+ c_crossattn: List of context tensors concatenated along sequence/channel dim for cross-attention.
+ c_crossattn_action: Mixed list used by action/state heads.
+ c_adm: Class/ADM conditioning (e.g., labels) when required by '*adm*' modes.
+ s: Optional additional time-like / scalar conditioning (e.g., fps/frame-stride) for '*time*' modes.
+ mask: Optional spatial/temporal mask (e.g., inpainting) for '*mask*' modes.
+ **kwargs: Extra keyword arguments forwarded to the inner diffusion model.
+
+ Returns:
+ Output from the inner diffusion model (tensor or tuple, depending on the model).
+ """
+
+ if self.conditioning_key is None:
+ out = self.diffusion_model(x, t)
+ elif self.conditioning_key == 'concat':
+ xc = torch.cat([x] + c_concat, dim=1)
+ out = self.diffusion_model(xc, t, **kwargs)
+ elif self.conditioning_key == 'crossattn':
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(x, t, context=cc, **kwargs)
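+ # 'hybrid' combines channel-wise concat conditioning with cross-attention
+ # context and additionally routes the action/state streams plus their own
+ # context to the dedicated heads.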
+ elif self.conditioning_key == 'hybrid':
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ cc_action = c_crossattn_action
+ out = self.diffusion_model(xc,
+ x_action,
+ x_state,
+ t,
+ context=cc,
+ context_action=cc_action,
+ **kwargs)
+ elif self.conditioning_key == 'resblockcond':
+ cc = c_crossattn[0]
+ out = self.diffusion_model(x, t, context=cc)
+ elif self.conditioning_key == 'adm':
+ cc = c_crossattn[0]
+ out = self.diffusion_model(x, t, y=cc)
+ elif self.conditioning_key == 'hybrid-adm':
+ assert c_adm is not None
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc, y=c_adm, **kwargs)
+ elif self.conditioning_key == 'hybrid-time':
+ assert s is not None
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc, s=s)
+ elif self.conditioning_key == 'concat-time-mask':
+ xc = torch.cat([x] + c_concat, dim=1)
+ out = self.diffusion_model(xc, t, context=None, s=s, mask=mask)
+ elif self.conditioning_key == 'concat-adm-mask':
+ if c_concat is not None:
+ xc = torch.cat([x] + c_concat, dim=1)
+ else:
+ xc = x
+ out = self.diffusion_model(xc, t, context=None, y=s, mask=mask)
+ elif self.conditioning_key == 'hybrid-adm-mask':
+ cc = torch.cat(c_crossattn, 1)
+ if c_concat is not None:
+ xc = torch.cat([x] + c_concat, dim=1)
+ else:
+ xc = x
+ out = self.diffusion_model(xc, t, context=cc, y=s, mask=mask)
+ elif self.conditioning_key == 'hybrid-time-adm':
+ assert c_adm is not None
+ xc = torch.cat([x] + c_concat, dim=1)
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(xc, t, context=cc, s=s, y=c_adm)
+ elif self.conditioning_key == 'crossattn-adm':
+ assert c_adm is not None
+ cc = torch.cat(c_crossattn, 1)
+ out = self.diffusion_model(x, t, context=cc, y=c_adm)
+ else:
+ raise NotImplementedError()
+
+ return out
diff --git a/src/unifolm_wma/models/diffusion_head/__init__.py b/src/unifolm_wma/models/diffusion_head/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/unifolm_wma/models/diffusion_head/base_nets.py b/src/unifolm_wma/models/diffusion_head/base_nets.py
new file mode 100644
index 0000000..8e96baf
--- /dev/null
+++ b/src/unifolm_wma/models/diffusion_head/base_nets.py
@@ -0,0 +1,217 @@
+"""
+Contains torch Modules that correspond to basic network building blocks, like
+MLP, RNN, and CNN backbones.
+"""
+
+import abc
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+class Module(torch.nn.Module):
+ """
+ Base class for networks. The only difference from torch.nn.Module is that it
+ requires implementing @output_shape.
+ """
+
+ @abc.abstractmethod
+ def output_shape(self, input_shape=None):
+ """
+ Function to compute output shape from inputs to this module.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+ raise NotImplementedError
+
+
+"""
+================================================
+Visual Backbone Networks
+================================================
+"""
+
+
+class ConvBase(Module):
+ """
+ Base class for ConvNets.
+ """
+
+ def __init__(self):
+ super(ConvBase, self).__init__()
+
+ # dirty hack - re-implement to pass the buck onto subclasses from ABC parent
+ def output_shape(self, input_shape):
+ """
+ Function to compute output shape from inputs to this module.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+ raise NotImplementedError
+
+ def forward(self, inputs):
+ x = self.nets(inputs)
+ if list(self.output_shape(list(inputs.shape)[1:])) != list(
+ x.shape)[1:]:
+ raise ValueError('Size mismatch: expect size %s, but got size %s' %
+ (str(self.output_shape(list(
+ inputs.shape)[1:])), str(list(x.shape)[1:])))
+ return x
+
+
+class SpatialSoftmax(ConvBase):
+ """
+ Spatial Softmax Layer.
+
+ Based on Deep Spatial Autoencoders for Visuomotor Learning by Finn et al.
+ https://rll.berkeley.edu/dsae/dsae.pdf
+ """
+
+ def __init__(
+ self,
+ input_shape,
+ num_kp=32,
+ temperature=1.,
+ learnable_temperature=False,
+ output_variance=False,
+ noise_std=0.0,
+ ):
+ """
+ Args:
+ input_shape (list): shape of the input feature (C, H, W)
+ num_kp (int): number of keypoints (None for not using spatialsoftmax)
+ temperature (float): temperature term for the softmax.
+ learnable_temperature (bool): whether to learn the temperature
+ output_variance (bool): treat attention as a distribution, and compute second-order statistics to return
+ noise_std (float): add random spatial noise to the predicted keypoints
+ """
+ super(SpatialSoftmax, self).__init__()
+ assert len(input_shape) == 3
+ self._in_c, self._in_h, self._in_w = input_shape # (C, H, W)
+
+ if num_kp is not None:
+ self.nets = torch.nn.Conv2d(self._in_c, num_kp, kernel_size=1)
+ self._num_kp = num_kp
+ else:
+ self.nets = None
+ self._num_kp = self._in_c
+ self.learnable_temperature = learnable_temperature
+ self.output_variance = output_variance
+ self.noise_std = noise_std
+
+ if self.learnable_temperature:
+ # temperature will be learned
+ temperature = torch.nn.Parameter(torch.ones(1) * temperature,
+ requires_grad=True)
+ self.register_parameter('temperature', temperature)
+ else:
+ # temperature held constant after initialization
+ temperature = torch.nn.Parameter(torch.ones(1) * temperature,
+ requires_grad=False)
+ self.register_buffer('temperature', temperature)
+
+ pos_x, pos_y = np.meshgrid(np.linspace(-1., 1., self._in_w),
+ np.linspace(-1., 1., self._in_h))
+ pos_x = torch.from_numpy(pos_x.reshape(1, self._in_h *
+ self._in_w)).float()
+ pos_y = torch.from_numpy(pos_y.reshape(1, self._in_h *
+ self._in_w)).float()
+ self.register_buffer('pos_x', pos_x)
+ self.register_buffer('pos_y', pos_y)
+
+ self.kps = None
+
+ def __repr__(self):
+ """Pretty print network."""
+ header = format(str(self.__class__.__name__))
+ return header + '(num_kp={}, temperature={}, noise={})'.format(
+ self._num_kp, self.temperature.item(), self.noise_std)
+
+ def output_shape(self, input_shape):
+ """
+ Function to compute output shape from inputs to this module.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+ assert (len(input_shape) == 3)
+ assert (input_shape[0] == self._in_c)
+ return [self._num_kp, 2]
+
+ def forward(self, feature):
+ """
+ Forward pass through spatial softmax layer. For each keypoint, a 2D spatial
+ probability distribution is created using a softmax, where the support is the
+ pixel locations. This distribution is used to compute the expected value of
+ the pixel location, which becomes a keypoint of dimension 2. K such keypoints
+ are created.
+
+ Returns:
+ out (torch.Tensor or tuple): mean keypoints of shape [B, K, 2], and possibly
+ keypoint variance of shape [B, K, 2, 2] corresponding to the covariance
+ under the 2D spatial softmax distribution
+ """
+ assert (feature.shape[1] == self._in_c)
+ assert (feature.shape[2] == self._in_h)
+ assert (feature.shape[3] == self._in_w)
+ if self.nets is not None:
+ feature = self.nets(feature)
+
+ # [B, K, H, W] -> [B * K, H * W] where K is number of keypoints
+ feature = feature.reshape(-1, self._in_h * self._in_w)
+ # 2d softmax normalization
+ attention = F.softmax(feature / self.temperature, dim=-1)
+ # [1, H * W] x [B * K, H * W] -> [B * K, 1] for spatial coordinate mean in x and y dimensions
+ expected_x = torch.sum(self.pos_x * attention, dim=1, keepdim=True)
+ expected_y = torch.sum(self.pos_y * attention, dim=1, keepdim=True)
+ # stack to [B * K, 2]
+ expected_xy = torch.cat([expected_x, expected_y], 1)
+ # reshape to [B, K, 2]
+ feature_keypoints = expected_xy.view(-1, self._num_kp, 2)
+
+ if self.training:
+ noise = torch.randn_like(feature_keypoints) * self.noise_std
+ feature_keypoints += noise
+
+ if self.output_variance:
+ # treat attention as a distribution, and compute second-order statistics to return
+ expected_xx = torch.sum(self.pos_x * self.pos_x * attention,
+ dim=1,
+ keepdim=True)
+ expected_yy = torch.sum(self.pos_y * self.pos_y * attention,
+ dim=1,
+ keepdim=True)
+ expected_xy = torch.sum(self.pos_x * self.pos_y * attention,
+ dim=1,
+ keepdim=True)
+ var_x = expected_xx - expected_x * expected_x
+ var_y = expected_yy - expected_y * expected_y
+ var_xy = expected_xy - expected_x * expected_y
+ # stack to [B * K, 4] and then reshape to [B, K, 2, 2] where last 2 dims are covariance matrix
+ feature_covar = torch.cat([var_x, var_xy, var_xy, var_y],
+ 1).reshape(-1, self._num_kp, 2, 2)
+ feature_keypoints = (feature_keypoints, feature_covar)
+
+ if isinstance(feature_keypoints, tuple):
+ self.kps = (feature_keypoints[0].detach(),
+ feature_keypoints[1].detach())
+ else:
+ self.kps = feature_keypoints.detach()
+ return feature_keypoints
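+
+
+# Usage sketch (illustrative, assumed shapes): reduce a (B, C, H, W) feature map
+# to (B, num_kp, 2) expected keypoint coordinates in [-1, 1].
+#   ss = SpatialSoftmax(input_shape=[64, 8, 8], num_kp=32)
+#   kp = ss(torch.randn(4, 64, 8, 8))  # -> torch.Size([4, 32, 2])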
diff --git a/src/unifolm_wma/models/diffusion_head/common/__init__.py b/src/unifolm_wma/models/diffusion_head/common/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/unifolm_wma/models/diffusion_head/common/lr_scheduler.py b/src/unifolm_wma/models/diffusion_head/common/lr_scheduler.py
new file mode 100644
index 0000000..6f4f60a
--- /dev/null
+++ b/src/unifolm_wma/models/diffusion_head/common/lr_scheduler.py
@@ -0,0 +1,83 @@
+from diffusers.optimization import (Union, SchedulerType, Optional, Optimizer,
+ TYPE_TO_SCHEDULER_FUNCTION)
+
+
+def get_scheduler(name: Union[str, SchedulerType],
+ optimizer: Optimizer,
+ num_warmup_steps: Optional[int] = None,
+ num_training_steps: Optional[int] = None,
+ **kwargs):
+ """
+ Adds **kwargs support on top of the original diffusers implementation.
+
+ Unified API to get any scheduler from its name.
+
+ Args:
+ name (`str` or `SchedulerType`):
+ The name of the scheduler to use.
+ optimizer (`torch.optim.Optimizer`):
+ The optimizer that will be used during training.
+ num_warmup_steps (`int`, *optional*):
+ The number of warmup steps to do. This is not required by all schedulers (hence the argument being
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
+ num_training_steps (`int`, *optional*):
+ The number of training steps to do. This is not required by all schedulers (hence the argument being
+ optional), the function will raise an error if it's unset and the scheduler type requires it.
+ """
+ name = SchedulerType(name)
+ schedule_func = TYPE_TO_SCHEDULER_FUNCTION[name]
+ if name == SchedulerType.CONSTANT:
+ return schedule_func(optimizer, **kwargs)
+
+ # All other schedulers require `num_warmup_steps`
+ if num_warmup_steps is None:
+ raise ValueError(
+ f"{name} requires `num_warmup_steps`, please provide that argument."
+ )
+
+ if name == SchedulerType.CONSTANT_WITH_WARMUP:
+ return schedule_func(optimizer,
+ num_warmup_steps=num_warmup_steps,
+ **kwargs)
+
+ # All other schedulers require `num_training_steps`
+ if num_training_steps is None:
+ raise ValueError(
+ f"{name} requires `num_training_steps`, please provide that argument."
+ )
+
+ return schedule_func(optimizer,
+ num_warmup_steps=num_warmup_steps,
+ num_training_steps=num_training_steps,
+ **kwargs)
+
+
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class SelectiveLRScheduler(_LRScheduler):
+
+ def __init__(self,
+ optimizer,
+ base_scheduler,
+ group_indices,
+ default_lr=(1e-5, 1e-4),
+ last_epoch=-1):
+ self.base_scheduler = base_scheduler
+ self.group_indices = group_indices # Indices of parameter groups to update
+ self.default_lr = default_lr
+ super().__init__(optimizer, last_epoch)
+
+ def step(self, epoch=None):
+ self.base_scheduler.step()
+ base_lrs = self.base_scheduler.get_last_lr()
+
+ for idx, group in enumerate(self.optimizer.param_groups):
+ if idx in self.group_indices:
+ group['lr'] = base_lrs[idx]
+ else:
+ # Reset the learning rate to its initial value
+ group['lr'] = self.default_lr[idx]
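+
+
+# Usage sketch (illustrative): with an optimizer that holds two parameter groups,
+# let only the first group follow a warmup schedule while the second keeps a
+# fixed learning rate taken from `default_lr`.
+#   base = get_scheduler('constant_with_warmup', optimizer, num_warmup_steps=500)
+#   scheduler = SelectiveLRScheduler(optimizer, base, group_indices=[0],
+#                                    default_lr=[1e-5, 1e-4])
+#   scheduler.step()  # call once per optimization step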
diff --git a/src/unifolm_wma/models/diffusion_head/common/module_attr_mixin.py b/src/unifolm_wma/models/diffusion_head/common/module_attr_mixin.py
new file mode 100644
index 0000000..e33efe2
--- /dev/null
+++ b/src/unifolm_wma/models/diffusion_head/common/module_attr_mixin.py
@@ -0,0 +1,16 @@
+import torch.nn as nn
+
+
+class ModuleAttrMixin(nn.Module):
+
+ def __init__(self):
+ super().__init__()
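+ # dummy parameter so the .device / .dtype properties below work even when a
+ # subclass registers no parameters of its own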
+ self._dummy_variable = nn.Parameter()
+
+ @property
+ def device(self):
+ return next(iter(self.parameters())).device
+
+ @property
+ def dtype(self):
+ return next(iter(self.parameters())).dtype
diff --git a/src/unifolm_wma/models/diffusion_head/common/pytorch_util.py b/src/unifolm_wma/models/diffusion_head/common/pytorch_util.py
new file mode 100644
index 0000000..a8913e7
--- /dev/null
+++ b/src/unifolm_wma/models/diffusion_head/common/pytorch_util.py
@@ -0,0 +1,91 @@
+import collections
+import torch
+import torch.nn as nn
+from typing import Dict, Callable, List
+
+
+def dict_apply(
+ x: Dict[str, torch.Tensor],
+ func: Callable[[torch.Tensor],
+ torch.Tensor]) -> Dict[str, torch.Tensor]:
+ result = dict()
+ for key, value in x.items():
+ if isinstance(value, dict):
+ result[key] = dict_apply(value, func)
+ else:
+ result[key] = func(value)
+ return result
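+
+# Example (illustrative): dict_apply({'a': torch.zeros(2)}, lambda t: t + 1)
+# returns {'a': tensor([1., 1.])}; nested dicts are traversed recursively.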
+
+
+def pad_remaining_dims(x, target):
+ assert x.shape == target.shape[:len(x.shape)]
+ return x.reshape(x.shape + (1, ) * (len(target.shape) - len(x.shape)))
+
+
+def dict_apply_split(
+ x: Dict[str, torch.Tensor], split_func: Callable[[torch.Tensor],
+ Dict[str, torch.Tensor]]
+) -> Dict[str, torch.Tensor]:
+ results = collections.defaultdict(dict)
+ for key, value in x.items():
+ result = split_func(value)
+ for k, v in result.items():
+ results[k][key] = v
+ return results
+
+
+def dict_apply_reduce(
+ x: List[Dict[str,
+ torch.Tensor]], reduce_func: Callable[[List[torch.Tensor]],
+ torch.Tensor]
+) -> Dict[str, torch.Tensor]:
+ result = dict()
+ for key in x[0].keys():
+ result[key] = reduce_func([x_[key] for x_ in x])
+ return result
+
+
+def replace_submodules(root_module: nn.Module, predicate: Callable[[nn.Module],
+ bool],
+ func: Callable[[nn.Module], nn.Module]) -> nn.Module:
+ """
+ predicate: Return true if the module is to be replaced.
+ func: Return new module to use.
+ """
+ if predicate(root_module):
+ return func(root_module)
+
+ bn_list = [
+ k.split('.')
+ for k, m in root_module.named_modules(remove_duplicate=True)
+ if predicate(m)
+ ]
+ for *parent, k in bn_list:
+ parent_module = root_module
+ if len(parent) > 0:
+ parent_module = root_module.get_submodule('.'.join(parent))
+ if isinstance(parent_module, nn.Sequential):
+ src_module = parent_module[int(k)]
+ else:
+ src_module = getattr(parent_module, k)
+ tgt_module = func(src_module)
+ if isinstance(parent_module, nn.Sequential):
+ parent_module[int(k)] = tgt_module
+ else:
+ setattr(parent_module, k, tgt_module)
+ # verify that all modules matching the predicate have been replaced
+ bn_list = [
+ k.split('.')
+ for k, m in root_module.named_modules(remove_duplicate=True)
+ if predicate(m)
+ ]
+ assert len(bn_list) == 0
+ return root_module
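+
+# Example (illustrative): replace every BatchNorm2d with GroupNorm, a common step
+# before EMA / small-batch fine-tuning of a pretrained vision encoder.
+#   replace_submodules(
+#       root_module=encoder,
+#       predicate=lambda m: isinstance(m, nn.BatchNorm2d),
+#       func=lambda m: nn.GroupNorm(num_groups=m.num_features // 16,
+#                                   num_channels=m.num_features))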
+
+
+def optimizer_to(optimizer, device):
+ for state in optimizer.state.values():
+ for k, v in state.items():
+ if isinstance(v, torch.Tensor):
+ state[k] = v.to(device=device)
+ return optimizer
diff --git a/src/unifolm_wma/models/diffusion_head/common/tensor_util.py b/src/unifolm_wma/models/diffusion_head/common/tensor_util.py
new file mode 100644
index 0000000..98e962a
--- /dev/null
+++ b/src/unifolm_wma/models/diffusion_head/common/tensor_util.py
@@ -0,0 +1,960 @@
+"""
+A collection of utilities for working with nested tensor structures consisting
+of numpy arrays and torch tensors.
+"""
+import collections
+import numpy as np
+import torch
+
+
+def recursive_dict_list_tuple_apply(x, type_func_dict):
+ """
+ Recursively apply functions to a nested dictionary or list or tuple, given a dictionary of
+ {data_type: function_to_apply}.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ type_func_dict (dict): a mapping from data types to the functions to be
+ applied for each data type.
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ assert (list not in type_func_dict)
+ assert (tuple not in type_func_dict)
+ assert (dict not in type_func_dict)
+
+ if isinstance(x, (dict, collections.OrderedDict)):
+ new_x = collections.OrderedDict() if isinstance(
+ x, collections.OrderedDict) else dict()
+ for k, v in x.items():
+ new_x[k] = recursive_dict_list_tuple_apply(v, type_func_dict)
+ return new_x
+ elif isinstance(x, (list, tuple)):
+ ret = [recursive_dict_list_tuple_apply(v, type_func_dict) for v in x]
+ if isinstance(x, tuple):
+ ret = tuple(ret)
+ return ret
+ else:
+ for t, f in type_func_dict.items():
+ if isinstance(x, t):
+ return f(x)
+ else:
+ raise NotImplementedError('Cannot handle data type %s' %
+ str(type(x)))
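+
+# Example (illustrative): cast every tensor leaf in a nested structure to float32
+# while leaving None leaves untouched.
+#   recursive_dict_list_tuple_apply(
+#       {'obs': {'rgb': torch.zeros(3, dtype=torch.uint8)}},
+#       {torch.Tensor: lambda t: t.float(), type(None): lambda t: t})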
+
+
+def map_tensor(x, func):
+ """
+ Apply function @func to torch.Tensor objects in a nested dictionary or
+ list or tuple.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ func (function): function to apply to each tensor
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(x, {
+ torch.Tensor: func,
+ type(None): lambda x: x,
+ })
+
+
+def map_ndarray(x, func):
+ """
+ Apply function @func to np.ndarray objects in a nested dictionary or
+ list or tuple.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ func (function): function to apply to each array
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(x, {
+ np.ndarray: func,
+ type(None): lambda x: x,
+ })
+
+
+def map_tensor_ndarray(x, tensor_func, ndarray_func):
+ """
+ Apply function @tensor_func to torch.Tensor objects and @ndarray_func to
+ np.ndarray objects in a nested dictionary or list or tuple.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ tensor_func (function): function to apply to each tensor
+ ndarray_func (function): function to apply to each array
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x, {
+ torch.Tensor: tensor_func,
+ np.ndarray: ndarray_func,
+ type(None): lambda x: x,
+ })
+
+
+def clone(x):
+ """
+ Clones all torch tensors and numpy arrays in nested dictionary or list
+ or tuple and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x, {
+ torch.Tensor: lambda x: x.clone(),
+ np.ndarray: lambda x: x.copy(),
+ type(None): lambda x: x,
+ })
+
+
+def detach(x):
+ """
+ Detaches all torch tensors in nested dictionary or list
+ or tuple and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(x, {
+ torch.Tensor: lambda x: x.detach(),
+ })
+
+
+def to_batch(x):
+ """
+ Introduces a leading batch dimension of 1 for all torch tensors and numpy
+ arrays in nested dictionary or list or tuple and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x, {
+ torch.Tensor: lambda x: x[None, ...],
+ np.ndarray: lambda x: x[None, ...],
+ type(None): lambda x: x,
+ })
+
+
+def to_sequence(x):
+ """
+ Introduces a time dimension of 1 at dimension 1 for all torch tensors and numpy
+ arrays in nested dictionary or list or tuple and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x, {
+ torch.Tensor: lambda x: x[:, None, ...],
+ np.ndarray: lambda x: x[:, None, ...],
+ type(None): lambda x: x,
+ })
+
+
+def index_at_time(x, ind):
+ """
+ Indexes all torch tensors and numpy arrays in dimension 1 with index @ind in
+ nested dictionary or list or tuple and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ ind (int): index
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x, {
+ torch.Tensor: lambda x: x[:, ind, ...],
+ np.ndarray: lambda x: x[:, ind, ...],
+ type(None): lambda x: x,
+ })
+
+
+def unsqueeze(x, dim):
+ """
+ Adds dimension of size 1 at dimension @dim in all torch tensors and numpy arrays
+ in nested dictionary or list or tuple and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ dim (int): dimension
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x, {
+ torch.Tensor: lambda x: x.unsqueeze(dim=dim),
+ np.ndarray: lambda x: np.expand_dims(x, axis=dim),
+ type(None): lambda x: x,
+ })
+
+
+def contiguous(x):
+ """
+ Makes all torch tensors and numpy arrays contiguous in nested dictionary or
+ list or tuple and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x, {
+ torch.Tensor: lambda x: x.contiguous(),
+ np.ndarray: lambda x: np.ascontiguousarray(x),
+ type(None): lambda x: x,
+ })
+
+
+def to_device(x, device):
+ """
+ Sends all torch tensors in nested dictionary or list or tuple to device
+ @device, and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ device (torch.Device): device to send tensors to
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x, {
+ torch.Tensor: lambda x, d=device: x.to(d),
+ type(None): lambda x: x,
+ })
+
+
+def to_tensor(x):
+ """
+ Converts all numpy arrays in nested dictionary or list or tuple to
+ torch tensors (and leaves existing torch Tensors as-is), and returns
+ a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x, {
+ torch.Tensor: lambda x: x,
+ np.ndarray: lambda x: torch.from_numpy(x),
+ type(None): lambda x: x,
+ })
+
+
+def to_numpy(x):
+ """
+ Converts all torch tensors in nested dictionary or list or tuple to
+ numpy (and leaves existing numpy arrays as-is), and returns
+ a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+
+ def f(tensor):
+ if tensor.is_cuda:
+ return tensor.detach().cpu().numpy()
+ else:
+ return tensor.detach().numpy()
+
+ return recursive_dict_list_tuple_apply(x, {
+ torch.Tensor: f,
+ np.ndarray: lambda x: x,
+ type(None): lambda x: x,
+ })
+
+
+def to_list(x):
+ """
+ Converts all torch tensors and numpy arrays in nested dictionary or list
+ or tuple to a list, and returns a new nested structure. Useful for
+ json encoding.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+
+ def f(tensor):
+ if tensor.is_cuda:
+ return tensor.detach().cpu().numpy().tolist()
+ else:
+ return tensor.detach().numpy().tolist()
+
+ return recursive_dict_list_tuple_apply(
+ x, {
+ torch.Tensor: f,
+ np.ndarray: lambda x: x.tolist(),
+ type(None): lambda x: x,
+ })
+
+
+def to_float(x):
+ """
+ Converts all torch tensors and numpy arrays in nested dictionary or list
+ or tuple to float type entries, and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x, {
+ torch.Tensor: lambda x: x.float(),
+ np.ndarray: lambda x: x.astype(np.float32),
+ type(None): lambda x: x,
+ })
+
+
+def to_uint8(x):
+ """
+ Converts all torch tensors and numpy arrays in nested dictionary or list
+ or tuple to uint8 type entries, and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x, {
+ torch.Tensor: lambda x: x.byte(),
+ np.ndarray: lambda x: x.astype(np.uint8),
+ type(None): lambda x: x,
+ })
+
+
+def to_torch(x, device):
+ """
+ Converts all numpy arrays and torch tensors in nested dictionary or list or tuple to
+ torch tensors on device @device and returns a new nested structure.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ device (torch.Device): device to send tensors to
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return to_device(to_float(to_tensor(x)), device)
+
+
+def to_one_hot_single(tensor, num_class):
+ """
+ Convert tensor to one-hot representation, assuming a certain number of total class labels.
+
+ Args:
+ tensor (torch.Tensor): tensor containing integer labels
+ num_class (int): number of classes
+
+ Returns:
+ x (torch.Tensor): tensor containing one-hot representation of labels
+ """
+ x = torch.zeros(tensor.size() + (num_class, )).to(tensor.device)
+ x.scatter_(-1, tensor.unsqueeze(-1), 1)
+ return x
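+
+# Example (illustrative):
+#   to_one_hot_single(torch.tensor([0, 2]), num_class=3)
+#   # -> tensor([[1., 0., 0.],
+#   #            [0., 0., 1.]])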
+
+
+def to_one_hot(tensor, num_class):
+ """
+ Convert all tensors in nested dictionary or list or tuple to one-hot representation,
+ assuming a certain number of total class labels.
+
+ Args:
+ tensor (dict or list or tuple): a possibly nested dictionary or list or tuple
+ num_class (int): number of classes
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return map_tensor(tensor,
+ func=lambda x, nc=num_class: to_one_hot_single(x, nc))
+
+
+def flatten_single(x, begin_axis=1):
+ """
+ Flatten a tensor in all dimensions from @begin_axis onwards.
+
+ Args:
+ x (torch.Tensor): tensor to flatten
+ begin_axis (int): which axis to flatten from
+
+ Returns:
+ y (torch.Tensor): flattened tensor
+ """
+ fixed_size = x.size()[:begin_axis]
+ _s = list(fixed_size) + [-1]
+ return x.reshape(*_s)
+
+
+def flatten(x, begin_axis=1):
+ """
+ Flatten all tensors in nested dictionary or list or tuple, from @begin_axis onwards.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ begin_axis (int): which axis to flatten from
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(x, {
+ torch.Tensor:
+ lambda x, b=begin_axis: flatten_single(x, begin_axis=b),
+ })
+
+
+def reshape_dimensions_single(x, begin_axis, end_axis, target_dims):
+ """
+ Reshape selected dimensions in a tensor to a target dimension.
+
+ Args:
+ x (torch.Tensor): tensor to reshape
+ begin_axis (int): begin dimension
+ end_axis (int): end dimension
+ target_dims (tuple or list): target shape for the range of dimensions
+ (@begin_axis, @end_axis)
+
+ Returns:
+ y (torch.Tensor): reshaped tensor
+ """
+ assert (begin_axis <= end_axis)
+ assert (begin_axis >= 0)
+ assert (end_axis < len(x.shape))
+ assert (isinstance(target_dims, (tuple, list)))
+ s = x.shape
+ final_s = []
+ for i in range(len(s)):
+ if i == begin_axis:
+ final_s.extend(target_dims)
+ elif i < begin_axis or i > end_axis:
+ final_s.append(s[i])
+ return x.reshape(*final_s)
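+
+# Example (illustrative): split the middle axis of a (4, 6, 8) tensor into (2, 3).
+#   x = torch.zeros(4, 6, 8)
+#   reshape_dimensions_single(x, begin_axis=1, end_axis=1, target_dims=(2, 3)).shape
+#   # -> torch.Size([4, 2, 3, 8])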
+
+
+def reshape_dimensions(x, begin_axis, end_axis, target_dims):
+ """
+ Reshape selected dimensions for all tensors in nested dictionary or list or tuple
+ to a target dimension.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ begin_axis (int): begin dimension
+ end_axis (int): end dimension
+ target_dims (tuple or list): target shape for the range of dimensions
+ (@begin_axis, @end_axis)
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x, {
+ torch.Tensor:
+ lambda x, b=begin_axis, e=end_axis, t=target_dims:
+ reshape_dimensions_single(
+ x, begin_axis=b, end_axis=e, target_dims=t),
+ np.ndarray:
+ lambda x, b=begin_axis, e=end_axis, t=target_dims:
+ reshape_dimensions_single(
+ x, begin_axis=b, end_axis=e, target_dims=t),
+ type(None):
+ lambda x: x,
+ })
+
+
+def join_dimensions(x, begin_axis, end_axis):
+ """
+ Joins all dimensions between dimensions (@begin_axis, @end_axis) into a flat dimension, for
+ all tensors in nested dictionary or list or tuple.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ begin_axis (int): begin dimension
+ end_axis (int): end dimension
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return recursive_dict_list_tuple_apply(
+ x, {
+ torch.Tensor:
+ lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
+ x, begin_axis=b, end_axis=e, target_dims=[-1]),
+ np.ndarray:
+ lambda x, b=begin_axis, e=end_axis: reshape_dimensions_single(
+ x, begin_axis=b, end_axis=e, target_dims=[-1]),
+ type(None):
+ lambda x: x,
+ })
+
+
+def expand_at_single(x, size, dim):
+ """
+ Expand a tensor at a single dimension @dim by @size
+
+ Args:
+ x (torch.Tensor): input tensor
+ size (int): size to expand
+ dim (int): dimension to expand
+
+ Returns:
+ y (torch.Tensor): expanded tensor
+ """
+ assert dim < x.ndimension()
+ assert x.shape[dim] == 1
+ expand_dims = [-1] * x.ndimension()
+ expand_dims[dim] = size
+ return x.expand(*expand_dims)
+
+
+def expand_at(x, size, dim):
+ """
+ Expand all tensors in nested dictionary or list or tuple at a single
+ dimension @dim by @size.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ size (int): size to expand
+ dim (int): dimension to expand
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return map_tensor(x, lambda t, s=size, d=dim: expand_at_single(t, s, d))
+
+
+def unsqueeze_expand_at(x, size, dim):
+ """
+ Unsqueeze and expand a tensor at a dimension @dim by @size.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ size (int): size to expand
+ dim (int): dimension to unsqueeze and expand
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ x = unsqueeze(x, dim)
+ return expand_at(x, size, dim)
+
+
+def repeat_by_expand_at(x, repeats, dim):
+ """
+ Repeat a dimension by combining expand and reshape operations.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ repeats (int): number of times to repeat the target dimension
+ dim (int): dimension to repeat on
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ x = unsqueeze_expand_at(x, repeats, dim + 1)
+ return join_dimensions(x, dim, dim + 1)
+
+
+def named_reduce_single(x, reduction, dim):
+ """
+ Reduce tensor at a dimension by named reduction functions.
+
+ Args:
+ x (torch.Tensor): tensor to be reduced
+ reduction (str): one of ["sum", "max", "mean", "flatten"]
+ dim (int): dimension to be reduced (or begin axis for flatten)
+
+ Returns:
+ y (torch.Tensor): reduced tensor
+ """
+ assert x.ndimension() > dim
+ assert reduction in ["sum", "max", "mean", "flatten"]
+ if reduction == "flatten":
+ x = flatten(x, begin_axis=dim)
+ elif reduction == "max":
+ x = torch.max(x, dim=dim)[0] # [B, D]
+ elif reduction == "sum":
+ x = torch.sum(x, dim=dim)
+ else:
+ x = torch.mean(x, dim=dim)
+ return x
+
+
+def named_reduce(x, reduction, dim):
+ """
+ Reduces all tensors in nested dictionary or list or tuple at a dimension
+ using a named reduction function.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ reduction (str): one of ["sum", "max", "mean", "flatten"]
+ dim (int): dimension to be reduced (or begin axis for flatten)
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return map_tensor(
+ x, func=lambda t, r=reduction, d=dim: named_reduce_single(t, r, d))
+
+
+def gather_along_dim_with_dim_single(x, target_dim, source_dim, indices):
+ """
+ This function indexes out a target dimension of a tensor in a structured way,
+ by allowing a different value to be selected for each member of a flat index
+ tensor (@indices) corresponding to a source dimension. This can be interpreted
+ as moving along the source dimension, using the corresponding index value
+ in @indices to select values for all other dimensions outside of the
+ source and target dimensions. A common use case is to gather values
+ in target dimension 1 for each batch member (target dimension 0).
+
+ Args:
+ x (torch.Tensor): tensor to gather values for
+ target_dim (int): dimension to gather values along
+ source_dim (int): dimension to hold constant and use for gathering values
+ from the other dimensions
+ indices (torch.Tensor): flat index tensor with same shape as tensor @x along
+ @source_dim
+
+ Returns:
+ y (torch.Tensor): gathered tensor, with dimension @target_dim indexed out
+ """
+ assert len(indices.shape) == 1
+ assert x.shape[source_dim] == indices.shape[0]
+
+ # unsqueeze in all dimensions except the source dimension
+ new_shape = [1] * x.ndimension()
+ new_shape[source_dim] = -1
+ indices = indices.reshape(*new_shape)
+
+ # repeat in all dimensions - but preserve shape of source dimension,
+ # and make sure target_dimension has singleton dimension
+ expand_shape = list(x.shape)
+ expand_shape[source_dim] = -1
+ expand_shape[target_dim] = 1
+ indices = indices.expand(*expand_shape)
+
+ out = x.gather(dim=target_dim, index=indices)
+ return out.squeeze(target_dim)
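+
+# Example (illustrative): pick one timestep per batch element from a (B, T, D) tensor.
+#   x = torch.randn(2, 5, 3)
+#   idx = torch.tensor([0, 4])
+#   out = gather_along_dim_with_dim_single(x, target_dim=1, source_dim=0, indices=idx)
+#   # out has shape (2, 3); out[0] == x[0, 0] and out[1] == x[1, 4]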
+
+
+def gather_along_dim_with_dim(x, target_dim, source_dim, indices):
+ """
+ Apply @gather_along_dim_with_dim_single to all tensors in a nested
+ dictionary or list or tuple.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ target_dim (int): dimension to gather values along
+ source_dim (int): dimension to hold constant and use for gathering values
+ from the other dimensions
+ indices (torch.Tensor): flat index tensor with same shape as tensor @x along
+ @source_dim
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple
+ """
+ return map_tensor(x,
+ lambda y, t=target_dim, s=source_dim, i=indices:
+ gather_along_dim_with_dim_single(y, t, s, i))
+
+
+def gather_sequence_single(seq, indices):
+ """
+ Given a tensor with leading dimensions [B, T, ...], gather an element from each sequence in
+ the batch given an index for each sequence.
+
+ Args:
+ seq (torch.Tensor): tensor with leading dimensions [B, T, ...]
+ indices (torch.Tensor): tensor indices of shape [B]
+
+ Return:
+ y (torch.Tensor): indexed tensor of shape [B, ....]
+ """
+ return gather_along_dim_with_dim_single(seq,
+ target_dim=1,
+ source_dim=0,
+ indices=indices)
+
+
+def gather_sequence(seq, indices):
+ """
+ Given a nested dictionary or list or tuple, gathers an element from each sequence of the batch
+ for tensors with leading dimensions [B, T, ...].
+
+ Args:
+ seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
+ of leading dimensions [B, T, ...]
+ indices (torch.Tensor): tensor indices of shape [B]
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple with tensors of shape [B, ...]
+ """
+ return gather_along_dim_with_dim(seq,
+ target_dim=1,
+ source_dim=0,
+ indices=indices)
+
+
+def pad_sequence_single(seq,
+ padding,
+ batched=False,
+ pad_same=True,
+ pad_values=None):
+ """
+ Pad input tensor or array @seq in the time dimension (dimension 1).
+
+ Args:
+ seq (np.ndarray or torch.Tensor): sequence to be padded
+ padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
+ batched (bool): whether the sequence has a leading batch dimension
+ pad_same (bool): whether to pad by duplicating the first/last element
+ pad_values (float): value to pad with when @pad_same is False
+
+ Returns:
+ padded sequence (np.ndarray or torch.Tensor)
+ """
+ assert isinstance(seq, (np.ndarray, torch.Tensor))
+ assert pad_same or pad_values is not None
+ if pad_values is not None:
+ assert isinstance(pad_values, float)
+ repeat_func = np.repeat if isinstance(
+ seq, np.ndarray) else torch.repeat_interleave
+ concat_func = np.concatenate if isinstance(seq, np.ndarray) else torch.cat
+ ones_like_func = np.ones_like if isinstance(
+ seq, np.ndarray) else torch.ones_like
+ seq_dim = 1 if batched else 0
+
+ begin_pad = []
+ end_pad = []
+
+ if padding[0] > 0:
+ pad = seq[[0]] if pad_same else ones_like_func(seq[[0]]) * pad_values
+ begin_pad.append(repeat_func(pad, padding[0], seq_dim))
+ if padding[1] > 0:
+ pad = seq[[-1]] if pad_same else ones_like_func(seq[[-1]]) * pad_values
+ end_pad.append(repeat_func(pad, padding[1], seq_dim))
+
+ return concat_func(begin_pad + [seq] + end_pad, seq_dim)
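+
+# Example (illustrative): pad a length-3 sequence by one duplicated frame at each end.
+#   seq = np.arange(3)[:, None]                      # shape (3, 1)
+#   pad_sequence_single(seq, padding=(1, 1))         # shape (5, 1), values 0, 0, 1, 2, 2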
+
+
+def pad_sequence(seq, padding, batched=False, pad_same=True, pad_values=None):
+ """
+ Pad a nested dictionary or list or tuple of sequence tensors in the time dimension (dimension 1).
+
+ Args:
+ seq (dict or list or tuple): a possibly nested dictionary or list or tuple with tensors
+ of leading dimensions [B, T, ...]
+ padding (tuple): begin and end padding, e.g. [1, 1] pads both begin and end of the sequence by 1
+ batched (bool): whether the sequences have a leading batch dimension
+ pad_same (bool): whether to pad by duplicating the first/last element
+ pad_values (float): value to pad with when @pad_same is False
+
+ Returns:
+ padded sequence (dict or list or tuple)
+ """
+ return recursive_dict_list_tuple_apply(
+ seq, {
+ torch.Tensor:
+ lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values:
+ pad_sequence_single(x, p, b, ps, pv),
+ np.ndarray:
+ lambda x, p=padding, b=batched, ps=pad_same, pv=pad_values:
+ pad_sequence_single(x, p, b, ps, pv),
+ type(None):
+ lambda x: x,
+ })
+
+
+def assert_size_at_dim_single(x, size, dim, msg):
+ """
+ Ensure that array or tensor @x has size @size in dim @dim.
+
+ Args:
+ x (np.ndarray or torch.Tensor): input array or tensor
+ size (int): size that tensors should have at @dim
+ dim (int): dimension to check
+ msg (str): text to display if assertion fails
+ """
+ assert x.shape[dim] == size, msg
+
+
+def assert_size_at_dim(x, size, dim, msg):
+ """
+ Ensure that arrays and tensors in nested dictionary or list or tuple have
+ size @size in dim @dim.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+ size (int): size that tensors should have at @dim
+ dim (int): dimension to check
+ """
+ map_tensor(
+ x,
+ lambda t, s=size, d=dim, m=msg: assert_size_at_dim_single(t, s, d, m))
+
+
+def get_shape(x):
+ """
+ Get all shapes of arrays and tensors in nested dictionary or list or tuple.
+
+ Args:
+ x (dict or list or tuple): a possibly nested dictionary or list or tuple
+
+ Returns:
+ y (dict or list or tuple): new nested dict-list-tuple that contains each array or
+ tensor's shape
+ """
+ return recursive_dict_list_tuple_apply(
+ x, {
+ torch.Tensor: lambda x: x.shape,
+ np.ndarray: lambda x: x.shape,
+ type(None): lambda x: x,
+ })
+
+
+def list_of_flat_dict_to_dict_of_list(list_of_dict):
+ """
+ Helper function to go from a list of flat dictionaries to a dictionary of lists.
+ By "flat" we mean that none of the values are dictionaries, but are numpy arrays,
+ floats, etc.
+
+ Args:
+ list_of_dict (list): list of flat dictionaries
+
+ Returns:
+ dict_of_list (dict): dictionary of lists
+ """
+ assert isinstance(list_of_dict, list)
+ dic = collections.OrderedDict()
+ for i in range(len(list_of_dict)):
+ for k in list_of_dict[i]:
+ if k not in dic:
+ dic[k] = []
+ dic[k].append(list_of_dict[i][k])
+ return dic
+
+
+def flatten_nested_dict_list(d, parent_key='', sep='_', item_key=''):
+ """
+ Flatten a nested dict or list to a list.
+
+ For example, given a dict
+ {
+ a: 1
+ b: {
+ c: 2
+ }
+ c: 3
+ }
+
+ the function would return [(a, 1), (b_c, 2), (c, 3)]
+
+ Args:
+ d (dict, list): a nested dict or list to be flattened
+ parent_key (str): recursion helper
+ sep (str): separator for nesting keys
+ item_key (str): recursion helper
+ Returns:
+ list: a list of (key, value) tuples
+ """
+ items = []
+ if isinstance(d, (tuple, list)):
+ new_key = parent_key + sep + item_key if len(
+ parent_key) > 0 else item_key
+ for i, v in enumerate(d):
+ items.extend(
+ flatten_nested_dict_list(v, new_key, sep=sep, item_key=str(i)))
+ return items
+ elif isinstance(d, dict):
+ new_key = parent_key + sep + item_key if len(
+ parent_key) > 0 else item_key
+ for k, v in d.items():
+ assert isinstance(k, str)
+ items.extend(
+ flatten_nested_dict_list(v, new_key, sep=sep, item_key=k))
+ return items
+ else:
+ new_key = parent_key + sep + item_key if len(
+ parent_key) > 0 else item_key
+ return [(new_key, d)]
+
+
+def time_distributed(inputs,
+ op,
+ activation=None,
+ inputs_as_kwargs=False,
+ inputs_as_args=False,
+ **kwargs):
+ """
+ Apply function @op to all tensors in nested dictionary or list or tuple @inputs in both the
+ batch (B) and time (T) dimension, where the tensors are expected to have shape [B, T, ...].
+ Will do this by reshaping tensors to [B * T, ...], passing through the op, and then reshaping
+ outputs to [B, T, ...].
+
+ Args:
+ inputs (list or tuple or dict): a possibly nested dictionary or list or tuple with tensors
+ of leading dimensions [B, T, ...]
+ op: a layer op that accepts inputs
+ activation: activation to apply at the output
+ inputs_as_kwargs (bool): whether to feed the inputs as a kwargs dict to the op
+ inputs_as_args (bool): whether to feed the inputs as an args list to the op
+ kwargs (dict): other kwargs to supply to the op
+
+ Returns:
+ outputs (dict or list or tuple): new nested dict-list-tuple with tensors of leading dimension [B, T].
+ """
+ batch_size, seq_len = flatten_nested_dict_list(inputs)[0][1].shape[:2]
+ inputs = join_dimensions(inputs, 0, 1)
+ if inputs_as_kwargs:
+ outputs = op(**inputs, **kwargs)
+ elif inputs_as_args:
+ outputs = op(*inputs, **kwargs)
+ else:
+ outputs = op(inputs, **kwargs)
+
+ if activation is not None:
+ outputs = map_tensor(outputs, activation)
+ outputs = reshape_dimensions(outputs,
+ begin_axis=0,
+ end_axis=0,
+ target_dims=(batch_size, seq_len))
+ return outputs
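+
+
+# Usage sketch (illustrative; `encoder` and the shapes are assumptions): run an
+# image encoder over a batch of frame sequences by folding time into the batch
+# dimension and unfolding it again afterwards.
+#   frames = torch.randn(8, 16, 3, 96, 96)           # (B, T, C, H, W)
+#   feats = time_distributed(frames, encoder)        # -> (8, 16, feat_dim)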
diff --git a/src/unifolm_wma/models/diffusion_head/conditional_unet1d.py b/src/unifolm_wma/models/diffusion_head/conditional_unet1d.py
new file mode 100644
index 0000000..12378a1
--- /dev/null
+++ b/src/unifolm_wma/models/diffusion_head/conditional_unet1d.py
@@ -0,0 +1,701 @@
+import logging
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import einops
+import xformers
+import xformers.ops
+
+from einops import rearrange, repeat
+from typing import Union
+
+from unifolm_wma.models.diffusion_head.conv1d_components import (
+ Downsample1d, Upsample1d, Conv1dBlock)
+from unifolm_wma.models.diffusion_head.positional_embedding import SinusoidalPosEmb
+from unifolm_wma.models.diffusion_head.base_nets import SpatialSoftmax
+
+from unifolm_wma.utils.basics import zero_module
+from unifolm_wma.utils.common import (
+ checkpoint,
+ exists,
+ default,
+)
+from unifolm_wma.utils.utils import instantiate_from_config
+
+logger = logging.getLogger(__name__)
+
+
+class GEGLU(nn.Module):
+
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(nn.Linear(
+ dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(project_in, nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out))
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class CrossAttention(nn.Module):
+
+ def __init__(self,
+ query_dim,
+ context_dim=None,
+ heads=8,
+ dim_head=64,
+ dropout=0.,
+ relative_position=False):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+ self.dim_head = dim_head
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout))
+
+ def forward(self, x, context=None, mask=None):
+ # `mask` is accepted for interface compatibility with the transformer block,
+ # but the memory-efficient attention path below does not use it.
+ spatial_self_attn = (context is None)
+ k_ip, v_ip, out_ip = None, None, None
+
+ q = self.to_q(x)
+ if spatial_self_attn:
+ context = default(context, x)
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ b, _, _ = q.shape
+ q, k, v = map(
+ lambda t: t.unsqueeze(3).reshape(b, t.shape[
+ 1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
+ b * self.heads, t.shape[1], self.dim_head).contiguous(),
+ (q, k, v),
+ )
+ # actually compute the attention, what we cannot get enough of
+ out = xformers.ops.memory_efficient_attention(q,
+ k,
+ v,
+ attn_bias=None,
+ op=None)
+ out = (out.unsqueeze(0).reshape(
+ b, self.heads, out.shape[1],
+ self.dim_head).permute(0, 2, 1,
+ 3).reshape(b, out.shape[1],
+ self.heads * self.dim_head))
+ return self.to_out(out)
+
+
+class BasicTransformerBlock(nn.Module):
+
+ def __init__(self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ disable_self_attn=False,
+ attention_cls=None):
+ super().__init__()
+ attn_cls = CrossAttention if attention_cls is None else attention_cls
+ self.disable_self_attn = disable_self_attn
+ self.attn1 = attn_cls(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ context_dim=context_dim if self.disable_self_attn else None)
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = attn_cls(query_dim=dim,
+ context_dim=context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout)
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None, **kwargs):
+ ## implementation trick: gradient checkpointing does not support non-tensor (e.g. None or scalar) arguments
+ input_tuple = (
+ x,
+ ) ## must be (x,), not (x): without the trailing comma this is not a tuple, so *input_tuple would unpack x itself
+ if context is not None:
+ input_tuple = (x, context)
+ return checkpoint(self._forward, input_tuple, self.parameters(),
+ self.checkpoint)
+
+ def _forward(self, x, context=None, mask=None):
+ x = self.attn1(self.norm1(x),
+ context=context if self.disable_self_attn else None,
+ mask=mask) + x
+ x = self.attn2(self.norm2(x), context=context, mask=mask) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class ActionLatentImageCrossAttention(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ in_dim,
+ n_heads,
+ d_head,
+ depth=1,
+ dropout=0.,
+ context_dim=None,
+ use_checkpoint=True,
+ disable_self_attn=False,
+ use_linear=True):
+ super().__init__()
+ """
+ in_channels: action input dim
+
+ """
+ self.in_channels = in_channels
+ self.in_dim = in_dim
+ inner_dim = n_heads * d_head
+ self.norm = torch.nn.GroupNorm(num_groups=8,
+ num_channels=in_channels,
+ eps=1e-6,
+ affine=True)
+
+ self.proj_in_action = nn.Linear(in_dim, inner_dim)
+ self.proj_in_cond = nn.Linear(context_dim, inner_dim)
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_dim))
+ self.use_linear = use_linear
+
+ attention_cls = None
+ self.transformer_blocks = nn.ModuleList([
+ BasicTransformerBlock(inner_dim,
+ n_heads,
+ d_head,
+ dropout=dropout,
+ context_dim=context_dim,
+ disable_self_attn=disable_self_attn,
+ checkpoint=use_checkpoint,
+ attention_cls=attention_cls)
+ for d in range(depth)
+ ])
+
+ def forward(self, x, context=None, **kwargs):
+ ba, ca, da = x.shape
+ b, t, c, h, w = context.shape
+ context = rearrange(context, 'b t c h w -> b (t h w) c').contiguous()
+
+ x_in = x
+ x = self.norm(x) # ba x ja x d_in
+ if self.use_linear:
+ x = self.proj_in_action(x)
+ context = self.proj_in_cond(context)
+ for i, block in enumerate(self.transformer_blocks):
+ x = block(x, context=context, **kwargs)
+ if self.use_linear:
+ x = self.proj_out(x)
+ return x + x_in
+
+
+class ConditionalResidualBlock1D(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ cond_dim,
+ kernel_size=3,
+ n_groups=8,
+ cond_predict_scale=True,
+ use_linear_act_proj=False):
+ super().__init__()
+
+ self.blocks = nn.ModuleList([
+ Conv1dBlock(in_channels,
+ out_channels,
+ kernel_size,
+ n_groups=n_groups),
+ Conv1dBlock(out_channels,
+ out_channels,
+ kernel_size,
+ n_groups=n_groups),
+ ])
+
+ self.cond_predict_scale = cond_predict_scale
+ self.use_linear_act_proj = use_linear_act_proj
+ self.out_channels = out_channels
+ # FiLM modulation https://arxiv.org/abs/1709.07871
+ # predicts per-channel scale and bias
+ cond_channels = out_channels
+ if cond_predict_scale and use_linear_act_proj:
+ cond_channels = out_channels * 2
+ self.cond_encoder = nn.Sequential(
+ nn.Mish(),
+ nn.Linear(cond_dim, cond_channels),
+ )
+ # make sure dimensions compatible
+ self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
+ if in_channels != out_channels else nn.Identity()
+
+ def forward(self, x, cond=None):
+ '''
+ x : [ batch_size x in_channels x horizon ]
+ cond : [ batch_size x horizon x cond_dim ]
+
+ returns:
+ out : [ batch_size x out_channels x horizon ]
+ '''
+ B, T, _ = cond.shape
+
+ out = self.blocks[0](x)
+ if self.cond_predict_scale:
+ embed = self.cond_encoder(cond)
+ if self.use_linear_act_proj:
+ embed = embed.reshape(B * T, -1)
+ embed = embed.reshape(-1, 2, self.out_channels, 1)
+ else:
+ embed = embed.reshape(embed.shape[0], 2, self.out_channels, 1)
+ scale = embed[:, 0, ...]
+ bias = embed[:, 1, ...]
+ out = scale * out + bias
+ # else:
+ # out = out + embed
+ out = self.blocks[1](out)
+ out = out + self.residual_conv(x)
+ return out
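+
+# Shape sketch for ConditionalResidualBlock1D (illustrative): with
+# use_linear_act_proj=True the FiLM conditioning is applied per (batch, step) pair.
+#   block = ConditionalResidualBlock1D(in_channels=32, out_channels=64,
+#                                      cond_dim=128, use_linear_act_proj=True)
+#   x = torch.randn(8 * 16, 32, 4)       # (B*T, in_channels, act_proj_dim)
+#   cond = torch.randn(8, 16, 128)       # (B, T, cond_dim)
+#   out = block(x, cond)                 # -> (8 * 16, 64, 4)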
+
+
+class ConditionalUnet1D(nn.Module):
+
+ def __init__(self,
+ input_dim,
+ n_obs_steps=1,
+ local_cond_dim=None,
+ global_cond_dim=None,
+ diffusion_step_embed_dim=256,
+ down_dims=[256, 512, 1024],
+ kernel_size=3,
+ n_groups=8,
+ cond_predict_scale=False,
+ horizon=16,
+ num_head_channels=64,
+ use_linear_attn=True,
+ use_linear_act_proj=True,
+ act_proj_dim=32,
+ cond_cross_attention=False,
+ context_dims=None,
+ image_size=None,
+ imagen_cond_gradient=False,
+ last_frame_only=False,
+ use_imagen_mid_only=False,
+ use_z_only=False,
+ spatial_num_kp=32,
+ obs_encoder_config=None):
+ super().__init__()
+
+ self.n_obs_steps = n_obs_steps
+ self.obs_encoder = instantiate_from_config(obs_encoder_config)
+
+ all_dims = [input_dim] + list(down_dims)
+ start_dim = down_dims[0]
+
+ dsed = diffusion_step_embed_dim
+ diffusion_step_encoder = nn.Sequential(
+ SinusoidalPosEmb(dsed),
+ nn.Linear(dsed, dsed * 4),
+ nn.Mish(),
+ nn.Linear(dsed * 4, dsed),
+ )
+ cond_dim = dsed + self.obs_encoder.output_shape()[-1] * self.n_obs_steps
+ in_out = list(zip(all_dims[:-1], all_dims[1:]))
+ local_cond_encoder = None
+ down_modules = nn.ModuleList([])
+
+ dim_a_list = []
+ for ind, (dim_in, dim_out) in enumerate(in_out):
+ is_last = ind >= (len(in_out) - 1)
+ if ind == 0:
+ dim_a = horizon
+ else:
+ dim_a = horizon // 2 * ind
+ dim_a_list.append(dim_a)
+
+ # for attention
+ num_heads = dim_out // num_head_channels
+ dim_head = num_head_channels
+ if use_linear_act_proj:
+ if use_imagen_mid_only:
+ cur_cond_dim = cond_dim + 2 * context_dims[-1]
+ elif use_z_only:
+ cur_cond_dim = cond_dim + 2 * spatial_num_kp
+ else:
+ cur_cond_dim = cond_dim + 2 * context_dims[ind]
+ else:
+ cur_cond_dim = cond_dim + horizon * context_dims[ind]
+
+ down_modules.append(
+ nn.ModuleList([
+ ConditionalResidualBlock1D(
+ dim_in,
+ dim_out,
+ cond_dim=cur_cond_dim,
+ kernel_size=kernel_size,
+ n_groups=n_groups,
+ cond_predict_scale=cond_predict_scale,
+ use_linear_act_proj=use_linear_act_proj),
+ ConditionalResidualBlock1D(
+ dim_out,
+ dim_out,
+ cond_dim=cur_cond_dim,
+ kernel_size=kernel_size,
+ n_groups=n_groups,
+ cond_predict_scale=cond_predict_scale,
+ use_linear_act_proj=use_linear_act_proj),
+ ActionLatentImageCrossAttention(
+ dim_out,
+ dim_a,
+ num_heads,
+ dim_head,
+ context_dim=context_dims[ind],
+ use_linear=use_linear_attn)
+ if cond_cross_attention else nn.Identity(),
+ Downsample1d(dim_out) if not is_last else nn.Identity()
+ ]))
+
+ mid_dim = all_dims[-1]
+ self.mid_modules = nn.ModuleList([
+ ConditionalResidualBlock1D(
+ mid_dim,
+ mid_dim,
+ cond_dim=cur_cond_dim,
+ kernel_size=kernel_size,
+ n_groups=n_groups,
+ cond_predict_scale=cond_predict_scale,
+ use_linear_act_proj=use_linear_act_proj),
+ ConditionalResidualBlock1D(
+ mid_dim,
+ mid_dim,
+ cond_dim=cur_cond_dim,
+ kernel_size=kernel_size,
+ n_groups=n_groups,
+ cond_predict_scale=cond_predict_scale,
+ use_linear_act_proj=use_linear_act_proj),
+ ActionLatentImageCrossAttention(mid_dim,
+ dim_a_list[-1],
+ num_heads,
+ dim_head,
+ context_dim=context_dims[-1],
+ use_linear=use_linear_attn)
+ if cond_cross_attention else nn.Identity(),
+ ])
+
+ up_modules = nn.ModuleList([])
+ context_dims = context_dims[::-1]
+ for ind, (dim_in, dim_out) in enumerate(
+ reversed(in_out[1:] + [(down_dims[-1], down_dims[-1])])):
+ is_last = ind >= (len(in_out) - 1)
+ if use_linear_act_proj:
+ if use_imagen_mid_only:
+ cur_cond_dim = cond_dim + 2 * context_dims[0]
+ elif use_z_only:
+ cur_cond_dim = cond_dim + 2 * spatial_num_kp
+ else:
+ cur_cond_dim = cond_dim + 2 * context_dims[ind]
+ else:
+ cur_cond_dim = cond_dim + horizon * context_dims[ind]
+ up_modules.append(
+ nn.ModuleList([
+ ConditionalResidualBlock1D(
+ dim_out + dim_in,
+ dim_in,
+ cond_dim=cur_cond_dim,
+ kernel_size=kernel_size,
+ n_groups=n_groups,
+ cond_predict_scale=cond_predict_scale,
+ use_linear_act_proj=use_linear_act_proj),
+ ConditionalResidualBlock1D(
+ dim_in,
+ dim_in,
+ cond_dim=cur_cond_dim,
+ kernel_size=kernel_size,
+ n_groups=n_groups,
+ cond_predict_scale=cond_predict_scale,
+ use_linear_act_proj=use_linear_act_proj),
+ ActionLatentImageCrossAttention(
+ dim_in,
+ dim_a_list.pop(),
+ num_heads,
+ dim_head,
+ context_dim=context_dims[ind],
+ use_linear=use_linear_attn)
+ if cond_cross_attention else nn.Identity(),
+ Upsample1d(dim_in) if not is_last else nn.Identity()
+ ]))
+
+ final_conv = nn.Sequential(
+ Conv1dBlock(start_dim, start_dim, kernel_size=kernel_size),
+ nn.Conv1d(start_dim, input_dim, 1),
+ )
+
+ if use_z_only:
+ h, w = image_size
+ self.spatial_softmax_blocks = nn.ModuleList(
+ [SpatialSoftmax((4, h, w), spatial_num_kp)])
+ else:
+ self.spatial_softmax_blocks = nn.ModuleList([])
+ context_dims = context_dims[::-1]
+ for ind, context_dim in enumerate(context_dims):
+ h, w = image_size
+ if ind != 0:
+ h //= 2**ind
+ w //= 2**ind
+ net = SpatialSoftmax((context_dim, h, w), context_dim)
+ self.spatial_softmax_blocks.append(net)
+ self.spatial_softmax_blocks.append(net)
+ self.spatial_softmax_blocks += self.spatial_softmax_blocks[
+ 0:4][::-1]
+
+ self.diffusion_step_encoder = diffusion_step_encoder
+ self.local_cond_encoder = local_cond_encoder
+ self.up_modules = up_modules
+ self.down_modules = down_modules
+ self.final_conv = final_conv
+
+ self.cond_cross_attention = cond_cross_attention
+ self.use_linear_act_proj = use_linear_act_proj
+
+ self.proj_in_action = nn.Sequential(nn.Linear(1, act_proj_dim),
+ nn.LayerNorm(act_proj_dim))
+ self.proj_in_horizon = nn.Sequential(nn.Linear(horizon, act_proj_dim),
+ nn.LayerNorm(act_proj_dim))
+ self.proj_out_action = nn.Sequential(nn.LayerNorm(act_proj_dim),
+ nn.Linear(act_proj_dim, 1))
+ self.proj_out_horizon = nn.Sequential(nn.LayerNorm(act_proj_dim),
+ nn.Linear(act_proj_dim, horizon))
+ logger.info("number of parameters: %e",
+ sum(p.numel() for p in self.parameters()))
+
+ self.imagen_cond_gradient = imagen_cond_gradient
+ self.use_imagen_mid_only = use_imagen_mid_only
+ self.use_z_only = use_z_only
+ self.spatial_num_kp = spatial_num_kp
+ self.last_frame_only = last_frame_only
+ self.horizon = horizon
+
+ def forward(self,
+ sample: torch.Tensor,
+ timestep: Union[torch.Tensor, float, int],
+ imagen_cond=None,
+ cond=None,
+ **kwargs):
+ """
+ sample: (B,T,input_dim)
+ timestep: (B,) or int, diffusion step
+ imagen_cond: a list of hidden info from video gen unet
+ cond: tuple of (image, agent_pos), converted to a dict internally:
+ image: (B, 3, To, h, w)
+ agent_pos: (B, Ta, d)
+ output: (B,T,input_dim)
+ """
+
+ if not self.imagen_cond_gradient:
+ imagen_cond = [c.detach() for c in imagen_cond]
+
+ cond = {'image': cond[0], 'agent_pos': cond[1]}
+
+ cond['image'] = cond['image'].permute(0, 2, 1, 3, 4)
+ cond['image'] = rearrange(cond['image'], 'b t c h w -> (b t) c h w')
+ cond['agent_pos'] = rearrange(cond['agent_pos'], 'b t d -> (b t) d')
+
+ B, T, D = sample.shape
+ if self.use_linear_act_proj:
+ sample = self.proj_in_action(sample.unsqueeze(-1))
+ global_cond = self.obs_encoder(cond)
+ global_cond = rearrange(global_cond,
+ '(b t) d -> b 1 (t d)',
+ b=B,
+ t=self.n_obs_steps)
+ global_cond = repeat(global_cond,
+ 'b c d -> b (repeat c) d',
+ repeat=T)
+ else:
+ sample = einops.rearrange(sample, 'b h t -> b t h')
+ sample = self.proj_in_horizon(sample)
+ robo_state_cond = rearrange(robo_state_cond, 'b t d -> b 1 (t d)')
+ robo_state_cond = repeat(robo_state_cond,
+ 'b c d -> b (repeat c) d',
+ repeat=2)
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ timesteps = torch.tensor([timesteps],
+ dtype=torch.long,
+ device=sample.device)
+ elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+ # Broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+ global_feature = self.diffusion_step_encoder(timesteps)
+ (imagen_cond_down, imagen_cond_mid, imagen_cond_up
+ ) = imagen_cond[0:4], imagen_cond[4], imagen_cond[5:]  # NOTE: hand-coded split into down, mid and up features from the video U-Net
+
+ x = sample if not self.use_linear_act_proj else sample.reshape(
+ B * T, D, -1)
+ h = []
+ for idx, modules in enumerate(self.down_modules):
+ if self.cond_cross_attention:
+ (resnet, resnet2, crossatten, downsample) = modules
+ else:
+ (resnet, resnet2, _, downsample) = modules
+
+ # Access the cond from the unet embeds from video unet
+ if self.use_imagen_mid_only:
+ imagen_cond = imagen_cond_mid
+ elif self.use_z_only:
+ imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
+ else:
+ imagen_cond = imagen_cond_down[idx]
+ if self.last_frame_only:
+ imagen_cond = imagen_cond[:, -1].unsqueeze(1)
+ imagen_cond = repeat(imagen_cond,
+ 'b t c h w -> b (repeat t) c h w',
+ repeat=self.horizon)
+ imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
+ if self.use_imagen_mid_only:
+ imagen_cond = self.spatial_softmax_blocks[len(
+ self.spatial_softmax_blocks) // 2](imagen_cond)
+ elif self.use_z_only:
+ imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
+ else:
+ imagen_cond = self.spatial_softmax_blocks[idx](imagen_cond)
+ imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
+
+ if self.use_linear_act_proj:
+ imagen_cond = imagen_cond.reshape(B, T, -1)
+ cur_global_feature = global_feature.unsqueeze(
+ 1).repeat_interleave(repeats=T, dim=1)
+ else:
+ imagen_cond = imagen_cond.permute(0, 3, 1, 2)
+ imagen_cond = imagen_cond.reshape(B, 2, -1)
+ cur_global_feature = global_feature.unsqueeze(
+ 1).repeat_interleave(repeats=2, dim=1)
+ cur_global_feature = torch.cat(
+ [cur_global_feature, global_cond, imagen_cond], axis=-1)
+ x = resnet(x, cur_global_feature)
+ x = resnet2(x, cur_global_feature)
+ h.append(x)
+ x = downsample(x)
+
+ #>>> mid blocks
+ resnet, resnet2, _ = self.mid_modules
+ # Access the cond from the unet embeds from video unet
+ if self.use_z_only:
+ imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
+ else:
+ imagen_cond = imagen_cond_mid
+ if self.last_frame_only:
+ imagen_cond = imagen_cond[:, -1].unsqueeze(1)
+ imagen_cond = repeat(imagen_cond,
+ 'b t c h w -> b (repeat t) c h w',
+ repeat=self.horizon)
+ imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
+ idx += 1
+ if self.use_z_only:
+ imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
+ else:
+ imagen_cond = self.spatial_softmax_blocks[idx](imagen_cond)
+ imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
+ if self.use_linear_act_proj:
+ imagen_cond = imagen_cond.reshape(B, T, -1)
+ cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
+ repeats=T, dim=1)
+ else:
+ imagen_cond = imagen_cond.permute(0, 3, 1, 2)
+ imagen_cond = imagen_cond.reshape(B, 2, -1)
+ cur_global_feature = global_feature.unsqueeze(1).repeat_interleave(
+ repeats=2, dim=1)
+ cur_global_feature = torch.cat(
+ [cur_global_feature, global_cond, imagen_cond], axis=-1)
+ x = resnet(x, cur_global_feature)
+ x = resnet2(x, cur_global_feature)
+
+ #>>> up blocks
+ idx += 1
+ for jdx, modules in enumerate(self.up_modules):
+ if self.cond_cross_attention:
+ (resnet, resnet2, crossatten, upsample) = modules
+ else:
+ (resnet, resnet2, _, upsample) = modules
+
+ # Access the cond from the unet embeds from video unet
+ if self.use_imagen_mid_only:
+ imagen_cond = imagen_cond_mid
+ elif self.use_z_only:
+ imagen_cond = kwargs['x_start'].permute(0, 2, 1, 3, 4)
+ else:
+ imagen_cond = imagen_cond_up[jdx]
+ if self.last_frame_only:
+ imagen_cond = imagen_cond[:, -1].unsqueeze(1)
+ imagen_cond = repeat(imagen_cond,
+ 'b t c h w -> b (repeat t) c h w',
+ repeat=self.horizon)
+ imagen_cond = rearrange(imagen_cond, 'b t c h w -> (b t) c h w')
+ if self.use_imagen_mid_only:
+ imagen_cond = self.spatial_softmax_blocks[len(
+ self.spatial_softmax_blocks) // 2](imagen_cond)
+ elif self.use_z_only:
+ imagen_cond = self.spatial_softmax_blocks[0](imagen_cond)
+ else:
+ imagen_cond = self.spatial_softmax_blocks[jdx +
+ idx](imagen_cond)
+ imagen_cond = rearrange(imagen_cond, '(b t) c d -> b t c d', b=B)
+
+ if self.use_linear_act_proj:
+ imagen_cond = imagen_cond.reshape(B, T, -1)
+ cur_global_feature = global_feature.unsqueeze(
+ 1).repeat_interleave(repeats=T, dim=1)
+ else:
+ imagen_cond = imagen_cond.permute(0, 3, 1, 2)
+ imagen_cond = imagen_cond.reshape(B, 2, -1)
+ cur_global_feature = global_feature.unsqueeze(
+ 1).repeat_interleave(repeats=2, dim=1)
+
+ cur_global_feature = torch.cat(
+ [cur_global_feature, global_cond, imagen_cond], axis=-1)
+
+ x = torch.cat((x, h.pop()), dim=1)
+ x = resnet(x, cur_global_feature)
+ x = resnet2(x, cur_global_feature)
+ x = upsample(x)
+
+ x = self.final_conv(x)
+
+ if self.use_linear_act_proj:
+ x = x.reshape(B, T, D, -1)
+ x = self.proj_out_action(x)
+ x = x.reshape(B, T, D)
+ else:
+ x = self.proj_out_horizon(x)
+ x = einops.rearrange(x, 'b t h -> b h t')
+ return x
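+
+ # Descriptive note on the forward pass above: at every down/mid/up level the conditioning
+ # feature maps from the video UNet (or x_start when use_z_only) are pooled by a per-level
+ # SpatialSoftmax block, flattened, and concatenated with global_feature and global_cond
+ # to form the per-step conditioning passed to each pair of conditional ResNet blocks.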
diff --git a/src/unifolm_wma/models/diffusion_head/conv1d_components.py b/src/unifolm_wma/models/diffusion_head/conv1d_components.py
new file mode 100644
index 0000000..713ac6c
--- /dev/null
+++ b/src/unifolm_wma/models/diffusion_head/conv1d_components.py
@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Downsample1d(nn.Module):
+
+ def __init__(self, dim):
+ super().__init__()
+ self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class Upsample1d(nn.Module):
+
+ def __init__(self, dim):
+ super().__init__()
+ self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class Conv1dBlock(nn.Module):
+ '''
+ Conv1d --> GroupNorm --> Mish
+ '''
+
+ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
+ super().__init__()
+
+ self.block = nn.Sequential(
+ nn.Conv1d(inp_channels,
+ out_channels,
+ kernel_size,
+ padding=kernel_size // 2),
+ # Rearrange('batch channels horizon -> batch channels 1 horizon'),
+ nn.GroupNorm(n_groups, out_channels),
+ # Rearrange('batch channels 1 horizon -> batch channels horizon'),
+ nn.Mish(),
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+
+def test():
+ cb = Conv1dBlock(256, 128, kernel_size=3)
+ x = torch.zeros((1, 256, 16))
+ o = cb(x)
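+
+
+# Illustrative shape check (a sketch, not part of the module above; values are arbitrary):
+# Downsample1d halves the horizon, Upsample1d doubles it, Conv1dBlock preserves it.
+#   x = torch.zeros((1, 64, 16))
+#   Downsample1d(64)(x).shape                    # (1, 64, 8)
+#   Upsample1d(64)(x).shape                      # (1, 64, 32)
+#   Conv1dBlock(64, 32, kernel_size=3)(x).shape  # (1, 32, 16)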
diff --git a/src/unifolm_wma/models/diffusion_head/ema_model.py b/src/unifolm_wma/models/diffusion_head/ema_model.py
new file mode 100644
index 0000000..c83a3b0
--- /dev/null
+++ b/src/unifolm_wma/models/diffusion_head/ema_model.py
@@ -0,0 +1,80 @@
+import copy
+import torch
+from torch.nn.modules.batchnorm import _BatchNorm
+
+
+class EMAModel:
+ """
+ Exponential Moving Average of model weights
+ """
+
+ def __init__(self,
+ model,
+ update_after_step=0,
+ inv_gamma=1.0,
+ power=2 / 3,
+ min_value=0.0,
+ max_value=0.9999):
+ """
+ @crowsonkb's notes on EMA Warmup:
+ If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
+ to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
+ gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
+ at 215.4k steps).
+ Args:
+ inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
+ power (float): Exponential factor of EMA warmup. Default: 2/3.
+ min_value (float): The minimum EMA decay rate. Default: 0.
+ """
+
+ self.averaged_model = model
+ self.averaged_model.eval()
+ self.averaged_model.requires_grad_(False)
+
+ self.update_after_step = update_after_step
+ self.inv_gamma = inv_gamma
+ self.power = power
+ self.min_value = min_value
+ self.max_value = max_value
+
+ self.decay = 0.0
+ self.optimization_step = 0
+
+ def get_decay(self, optimization_step):
+ """
+ Compute the decay factor for the exponential moving average.
+ """
+ step = max(0, optimization_step - self.update_after_step - 1)
+ value = 1 - (1 + step / self.inv_gamma)**-self.power
+
+ if step <= 0:
+ return 0.0
+
+ return max(self.min_value, min(value, self.max_value))
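+
+ # Worked example (illustrative): with the defaults inv_gamma=1.0 and power=2/3,
+ # get_decay(1_000) = 1 - (1 + 999) ** (-2 / 3) = 0.99 and get_decay(31_600) ~= 0.999,
+ # consistent with the warmup numbers quoted in the __init__ docstring.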
+
+ @torch.no_grad()
+ def step(self, new_model):
+ self.decay = self.get_decay(self.optimization_step)
+
+ all_dataptrs = set()
+ for module, ema_module in zip(new_model.modules(),
+ self.averaged_model.modules()):
+ for param, ema_param in zip(module.parameters(recurse=False),
+ ema_module.parameters(recurse=False)):
+ # iterate over immediate parameters only.
+ if isinstance(param, dict):
+ raise RuntimeError('Dict parameter not supported')
+
+ if isinstance(module, _BatchNorm):
+ # skip batchnorms
+ ema_param.copy_(param.to(dtype=ema_param.dtype).data)
+ elif not param.requires_grad:
+ ema_param.copy_(param.to(dtype=ema_param.dtype).data)
+ else:
+ ema_param.mul_(self.decay)
+ ema_param.add_(param.data.to(dtype=ema_param.dtype),
+ alpha=1 - self.decay)
+
+ # verify that iterating over module and then parameters is identical to parameters recursively.
+ # assert old_all_dataptrs == all_dataptrs
+ self.optimization_step += 1
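+
+
+# Typical usage (a sketch; the trainer names below are assumptions, not part of this repo):
+#   ema = EMAModel(copy.deepcopy(model))
+#   for batch in dataloader:
+#       loss = compute_loss(model, batch)
+#       loss.backward(); optimizer.step(); optimizer.zero_grad()
+#       ema.step(model)  # blends model weights into ema.averaged_model with the current decay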
diff --git a/src/unifolm_wma/models/diffusion_head/positional_embedding.py b/src/unifolm_wma/models/diffusion_head/positional_embedding.py
new file mode 100644
index 0000000..1b1d646
--- /dev/null
+++ b/src/unifolm_wma/models/diffusion_head/positional_embedding.py
@@ -0,0 +1,19 @@
+import math
+import torch
+import torch.nn as nn
+
+
+class SinusoidalPosEmb(nn.Module):
+
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, x):
+ device = x.device
+ half_dim = self.dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+ emb = x[:, None] * emb[None, :]
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+ return emb
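+
+
+# Shape note (illustrative): for a 1-D batch of timesteps t with t.shape == (B,),
+# SinusoidalPosEmb(dim)(t) returns a (B, dim) tensor whose first dim // 2 columns are
+# sin(t / 10000 ** (i / (dim // 2 - 1))) and whose remaining columns are the matching
+# cosines, e.g. SinusoidalPosEmb(128)(torch.arange(4)).shape == (4, 128).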
diff --git a/src/unifolm_wma/models/diffusion_head/vision/__init__.py b/src/unifolm_wma/models/diffusion_head/vision/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/unifolm_wma/models/diffusion_head/vision/crop_randomizer.py b/src/unifolm_wma/models/diffusion_head/vision/crop_randomizer.py
new file mode 100644
index 0000000..d7b5408
--- /dev/null
+++ b/src/unifolm_wma/models/diffusion_head/vision/crop_randomizer.py
@@ -0,0 +1,322 @@
+import torch
+import torch.nn as nn
+import torchvision.transforms.functional as ttf
+import unifolm_wma.models.diffusion_head.common.tensor_util as tu
+
+
+class CropRandomizer(nn.Module):
+ """
+ Randomly sample crops at input, and then average across crop features at output.
+ """
+
+ def __init__(
+ self,
+ input_shape,
+ crop_height,
+ crop_width,
+ num_crops=1,
+ pos_enc=False,
+ ):
+ """
+ Args:
+ input_shape (tuple, list): shape of input (not including batch dimension)
+ crop_height (int): crop height
+ crop_width (int): crop width
+ num_crops (int): number of random crops to take
+ pos_enc (bool): if True, add 2 channels to the output to encode the spatial
+ location of the cropped pixels in the source image
+ """
+ super().__init__()
+
+ assert len(input_shape) == 3 # (C, H, W)
+ assert crop_height < input_shape[1]
+ assert crop_width < input_shape[2]
+
+ self.input_shape = input_shape
+ self.crop_height = crop_height
+ self.crop_width = crop_width
+ self.num_crops = num_crops
+ self.pos_enc = pos_enc
+
+ def output_shape_in(self, input_shape=None):
+ """
+ Function to compute output shape from inputs to this module. Corresponds to
+ the @forward_in operation, where raw inputs (usually observation modalities)
+ are passed in.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+
+ # outputs are shape (C, CH, CW), or maybe C + 2 if using position encoding, because
+ # the number of crops are reshaped into the batch dimension, increasing the batch
+ # size from B to B * N
+ out_c = self.input_shape[0] + 2 if self.pos_enc else self.input_shape[0]
+ return [out_c, self.crop_height, self.crop_width]
+
+ def output_shape_out(self, input_shape=None):
+ """
+ Function to compute output shape from inputs to this module. Corresponds to
+ the @forward_out operation, where processed inputs (usually encoded observation
+ modalities) are passed in.
+
+ Args:
+ input_shape (iterable of int): shape of input. Does not include batch dimension.
+ Some modules may not need this argument, if their output does not depend
+ on the size of the input, or if they assume fixed size input.
+
+ Returns:
+ out_shape ([int]): list of integers corresponding to output shape
+ """
+
+ # since the forward_out operation splits [B * N, ...] -> [B, N, ...]
+ # and then pools to result in [B, ...], only the batch dimension changes,
+ # and so the other dimensions retain their shape.
+ return list(input_shape)
+
+ def forward_in(self, inputs):
+ """
+ Samples N random crops for each input in the batch, and then reshapes
+ inputs to [B * N, ...].
+ """
+ assert len(
+ inputs.shape) >= 3 # must have at least (C, H, W) dimensions
+ if self.training:
+ # generate random crops
+ out, _ = sample_random_image_crops(
+ images=inputs,
+ crop_height=self.crop_height,
+ crop_width=self.crop_width,
+ num_crops=self.num_crops,
+ pos_enc=self.pos_enc,
+ )
+ # [B, N, ...] -> [B * N, ...]
+ return tu.join_dimensions(out, 0, 1)
+ else:
+ # take center crop during eval
+ out = ttf.center_crop(img=inputs,
+ output_size=(self.crop_height,
+ self.crop_width))
+ if self.num_crops > 1:
+ B, C, H, W = out.shape
+ out = out.unsqueeze(1).expand(B, self.num_crops, C, H,
+ W).reshape(-1, C, H, W)
+ # [B * N, ...]
+ return out
+
+ def forward_out(self, inputs):
+ """
+ Splits the outputs from shape [B * N, ...] -> [B, N, ...] and then averages across N
+ to result in shape [B, ...] to make sure the network output is consistent with
+ what would have happened if there were no randomization.
+ """
+ if self.num_crops <= 1:
+ return inputs
+ else:
+ batch_size = (inputs.shape[0] // self.num_crops)
+ out = tu.reshape_dimensions(inputs,
+ begin_axis=0,
+ end_axis=0,
+ target_dims=(batch_size,
+ self.num_crops))
+ return out.mean(dim=1)
+
+ def forward(self, inputs):
+ return self.forward_in(inputs)
+
+ def __repr__(self):
+ """Pretty print network."""
+ header = '{}'.format(str(self.__class__.__name__))
+ msg = header + "(input_shape={}, crop_size=[{}, {}], num_crops={})".format(
+ self.input_shape, self.crop_height, self.crop_width,
+ self.num_crops)
+ return msg
+
+
+def crop_image_from_indices(images, crop_indices, crop_height, crop_width):
+ """
+ Crops images at the locations specified by @crop_indices. Crops will be
+ taken across all channels.
+
+ Args:
+ images (torch.Tensor): batch of images of shape [..., C, H, W]
+
+ crop_indices (torch.Tensor): batch of indices of shape [..., N, 2] where
+ N is the number of crops to take per image and each entry corresponds
+ to the pixel height and width of where to take the crop. Note that
+ the indices can also be of shape [..., 2] if only 1 crop should
+ be taken per image. Leading dimensions must be consistent with
+ @images argument. Each index specifies the top left of the crop.
+ Values must be in range [0, H - CH - 1] x [0, W - CW - 1] where
+ H and W are the height and width of @images and CH and CW are
+ @crop_height and @crop_width.
+
+ crop_height (int): height of crop to take
+
+ crop_width (int): width of crop to take
+
+ Returns:
+ crops (torch.Tensor): cropped images of shape [..., C, @crop_height, @crop_width]
+ """
+
+ # make sure length of input shapes is consistent
+ assert crop_indices.shape[-1] == 2
+ ndim_im_shape = len(images.shape)
+ ndim_indices_shape = len(crop_indices.shape)
+ assert (ndim_im_shape == ndim_indices_shape +
+ 1) or (ndim_im_shape == ndim_indices_shape + 2)
+
+ # maybe pad so that @crop_indices is shape [..., N, 2]
+ is_padded = False
+ if ndim_im_shape == ndim_indices_shape + 2:
+ crop_indices = crop_indices.unsqueeze(-2)
+ is_padded = True
+
+ # make sure leading dimensions between images and indices are consistent
+ assert images.shape[:-3] == crop_indices.shape[:-2]
+
+ device = images.device
+ image_c, image_h, image_w = images.shape[-3:]
+ num_crops = crop_indices.shape[-2]
+
+ # make sure @crop_indices are in valid range
+ assert (crop_indices[..., 0] >= 0).all().item()
+ assert (crop_indices[..., 0] < (image_h - crop_height)).all().item()
+ assert (crop_indices[..., 1] >= 0).all().item()
+ assert (crop_indices[..., 1] < (image_w - crop_width)).all().item()
+
+ # convert each crop index (ch, cw) into a list of pixel indices that correspond to the entire window.
+
+ # 2D index array with columns [0, 1, ..., CH - 1] and shape [CH, CW]
+ crop_ind_grid_h = torch.arange(crop_height).to(device)
+ crop_ind_grid_h = tu.unsqueeze_expand_at(crop_ind_grid_h,
+ size=crop_width,
+ dim=-1)
+ # 2D index array with rows [0, 1, ..., CW - 1] and shape [CH, CW]
+ crop_ind_grid_w = torch.arange(crop_width).to(device)
+ crop_ind_grid_w = tu.unsqueeze_expand_at(crop_ind_grid_w,
+ size=crop_height,
+ dim=0)
+ # combine into shape [CH, CW, 2]
+ crop_in_grid = torch.cat(
+ (crop_ind_grid_h.unsqueeze(-1), crop_ind_grid_w.unsqueeze(-1)), dim=-1)
+
+ # Add above grid with the offset index of each sampled crop to get 2d indices for each crop.
+ # After broadcasting, this will be shape [..., N, CH, CW, 2] and each crop has a [CH, CW, 2]
+ # shape array that tells us which pixels from the corresponding source image to grab.
+ grid_reshape = [1] * len(crop_indices.shape[:-1]) + [
+ crop_height, crop_width, 2
+ ]
+ all_crop_inds = crop_indices.unsqueeze(-2).unsqueeze(
+ -2) + crop_in_grid.reshape(grid_reshape)
+
+ # For using @torch.gather, convert to flat indices from 2D indices, and also
+ # repeat across the channel dimension. To get flat index of each pixel to grab for
+ # each sampled crop, we just use the mapping: ind = h_ind * @image_w + w_ind
+ all_crop_inds = all_crop_inds[..., 0] * image_w + all_crop_inds[
+ ..., 1] # shape [..., N, CH, CW]
+ all_crop_inds = tu.unsqueeze_expand_at(all_crop_inds, size=image_c,
+ dim=-3) # shape [..., N, C, CH, CW]
+ all_crop_inds = tu.flatten(all_crop_inds,
+ begin_axis=-2) # shape [..., N, C, CH * CW]
+
+ # Repeat and flatten the source images -> [..., N, C, H * W] and then use gather to index with crop pixel inds
+ images_to_crop = tu.unsqueeze_expand_at(images, size=num_crops, dim=-4)
+ images_to_crop = tu.flatten(images_to_crop, begin_axis=-2)
+ crops = torch.gather(images_to_crop, dim=-1, index=all_crop_inds)
+ # [..., N, C, CH * CW] -> [..., N, C, CH, CW]
+ reshape_axis = len(crops.shape) - 1
+ crops = tu.reshape_dimensions(crops,
+ begin_axis=reshape_axis,
+ end_axis=reshape_axis,
+ target_dims=(crop_height, crop_width))
+
+ if is_padded:
+ # undo padding -> [..., C, CH, CW]
+ crops = crops.squeeze(-4)
+ return crops
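+
+
+# Worked example (illustrative): for images of shape (2, 3, 10, 10) and crop_indices of
+# shape (2, 2) (one crop per image, so the N dimension is added and squeezed internally),
+# crop_image_from_indices(images, crop_indices, crop_height=4, crop_width=4) returns
+# crops of shape (2, 3, 4, 4).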
+
+
+def sample_random_image_crops(images,
+ crop_height,
+ crop_width,
+ num_crops,
+ pos_enc=False):
+ """
+ For each image, randomly sample @num_crops crops of size (@crop_height, @crop_width), from
+ @images.
+
+ Args:
+ images (torch.Tensor): batch of images of shape [..., C, H, W]
+
+ crop_height (int): height of crop to take
+
+ crop_width (int): width of crop to take
+
+ num_crops (int): number of crops to sample
+
+ pos_enc (bool): if True, also add 2 channels to the outputs that give a spatial
+ encoding of the original source pixel locations. This means that the
+ output crops will contain information about where in the source image
+ they were sampled from.
+
+ Returns:
+ crops (torch.Tensor): crops of shape (..., @num_crops, C, @crop_height, @crop_width)
+ if @pos_enc is False, otherwise (..., @num_crops, C + 2, @crop_height, @crop_width)
+
+ crop_inds (torch.Tensor): sampled crop indices of shape (..., N, 2)
+ """
+ device = images.device
+
+ # maybe add 2 channels of spatial encoding to the source image
+ source_im = images
+ if pos_enc:
+ # spatial encoding [y, x] in [0, 1]
+ h, w = source_im.shape[-2:]
+ pos_y, pos_x = torch.meshgrid(torch.arange(h), torch.arange(w))
+ pos_y = pos_y.float().to(device) / float(h)
+ pos_x = pos_x.float().to(device) / float(w)
+ position_enc = torch.stack((pos_y, pos_x)) # shape [C, H, W]
+
+ # unsqueeze and expand to match leading dimensions -> shape [..., C, H, W]
+ leading_shape = source_im.shape[:-3]
+ position_enc = position_enc[(None, ) * len(leading_shape)]
+ position_enc = position_enc.expand(*leading_shape, -1, -1, -1)
+
+ # concat across channel dimension with input
+ source_im = torch.cat((source_im, position_enc), dim=-3)
+
+ # make sure sample boundaries ensure crops are fully within the images
+ image_c, image_h, image_w = source_im.shape[-3:]
+ max_sample_h = image_h - crop_height
+ max_sample_w = image_w - crop_width
+
+ # Sample crop locations for all tensor dimensions up to the last 3, which are [C, H, W].
+ # Each gets @num_crops samples - typically this will just be the batch dimension (B), so
+ # we will sample [B, N] indices, but this supports having more than one leading dimension,
+ # or possibly no leading dimension.
+ #
+ # Trick: sample in [0, 1) with rand, then re-scale to [0, M) and convert to long to get sampled ints
+ crop_inds_h = (
+ max_sample_h *
+ torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
+ crop_inds_w = (
+ max_sample_w *
+ torch.rand(*source_im.shape[:-3], num_crops).to(device)).long()
+ crop_inds = torch.cat(
+ (crop_inds_h.unsqueeze(-1), crop_inds_w.unsqueeze(-1)),
+ dim=-1) # shape [..., N, 2]
+
+ crops = crop_image_from_indices(
+ images=source_im,
+ crop_indices=crop_inds,
+ crop_height=crop_height,
+ crop_width=crop_width,
+ )
+
+ return crops, crop_inds
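+
+
+# Usage sketch (illustrative; `encoder` below is a hypothetical feature extractor):
+#   rand = CropRandomizer(input_shape=(3, 96, 96), crop_height=84, crop_width=84, num_crops=4)
+#   x = torch.rand(8, 3, 96, 96)
+#   crops = rand.forward_in(x)        # (32, 3, 84, 84) in training mode (8 images x 4 crops)
+#   feats = encoder(crops)            # hypothetical encoder -> (32, D)
+#   pooled = rand.forward_out(feats)  # (8, D), mean over the 4 crops of each image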
diff --git a/src/unifolm_wma/models/diffusion_head/vision/model_getter.py b/src/unifolm_wma/models/diffusion_head/vision/model_getter.py
new file mode 100644
index 0000000..9c33ff0
--- /dev/null
+++ b/src/unifolm_wma/models/diffusion_head/vision/model_getter.py
@@ -0,0 +1,30 @@
+import torch
+import torchvision
+
+
+def get_resnet(name, weights=None, **kwargs):
+ """
+ name: resnet18, resnet34, resnet50
+ weights: "IMAGENET1K_V1", "r3m"
+ """
+ # load r3m weights
+ if (weights == "r3m") or (weights == "R3M"):
+ return get_r3m(name=name, **kwargs)
+
+ func = getattr(torchvision.models, name)
+ resnet = func(weights=weights, **kwargs)
+ resnet.fc = torch.nn.Identity()
+ return resnet
+
+
+def get_r3m(name, **kwargs):
+ """
+ name: resnet18, resnet34, resnet50
+ """
+ import r3m
+ r3m.device = 'cpu'
+ model = r3m.load_r3m(name)
+ r3m_model = model.module
+ resnet_model = r3m_model.convnet
+ resnet_model = resnet_model.to('cpu')
+ return resnet_model
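+
+
+# Usage note (illustrative): get_resnet("resnet18", weights=None) returns a torchvision
+# ResNet whose final fc layer is replaced by nn.Identity, so a (B, 3, H, W) batch maps to
+# a (B, 512) pooled feature for resnet18/34 (2048-d for resnet50) instead of class logits.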
diff --git a/src/unifolm_wma/models/diffusion_head/vision/multi_image_obs_encoder.py b/src/unifolm_wma/models/diffusion_head/vision/multi_image_obs_encoder.py
new file mode 100644
index 0000000..a46b0ed
--- /dev/null
+++ b/src/unifolm_wma/models/diffusion_head/vision/multi_image_obs_encoder.py
@@ -0,0 +1,247 @@
+import copy
+import torch
+import torch.nn as nn
+import torchvision
+import json
+import os
+
+from unifolm_wma.models.diffusion_head.vision.crop_randomizer import CropRandomizer
+from unifolm_wma.models.diffusion_head.base_nets import SpatialSoftmax
+from unifolm_wma.models.diffusion_head.common.module_attr_mixin import ModuleAttrMixin
+from unifolm_wma.models.diffusion_head.common.pytorch_util import dict_apply, replace_submodules
+from unifolm_wma.utils.utils import instantiate_from_config
+from einops import rearrange, repeat
+from typing import Dict, Tuple, Union
+from pathlib import Path
+
+
+class MultiImageObsEncoder(ModuleAttrMixin):
+
+ def __init__(
+ self,
+ rgb_model_config: Dict,
+ shape_meta_path: str | None = None,
+ resize_shape: Union[Tuple[int, int], Dict[str, tuple],
+ None] = None,
+ crop_shape: Union[Tuple[int, int], Dict[str, tuple], None] = None,
+ random_crop: bool = True,
+ # replace BatchNorm with GroupNorm
+ use_group_norm: bool = False,
+ # use single rgb model for all rgb inputs
+ share_rgb_model: bool = False,
+ # renormalize rgb input with imagenet normalization
+ # assuming input in [0,1]
+ imagenet_norm: bool = False,
+ use_spatial_softmax=False,
+ spatial_softmax_kp=32,
+ use_dinoSiglip=False):
+ """
+ Assumes rgb input: B,C,H,W
+ Assumes low_dim input: B,D
+ """
+ super().__init__()
+
+ if not shape_meta_path:
+ shape_meta_path = str(Path(os.getcwd()) / "configs/train/meta.json")
+
+ with open(shape_meta_path, 'r') as file:
+ shape_meta = json.load(file)
+
+ rgb_model = instantiate_from_config(rgb_model_config)
+
+ rgb_keys = list()
+ low_dim_keys = list()
+ key_model_map = nn.ModuleDict()
+ key_transform_map = nn.ModuleDict()
+ key_shape_map = dict()
+
+ # handle sharing vision backbone
+ if share_rgb_model:
+ assert isinstance(rgb_model, nn.Module)
+ key_model_map['rgb'] = rgb_model
+
+ obs_shape_meta = shape_meta['obs']
+ for key, attr in obs_shape_meta.items():
+ shape = tuple(attr['shape'])
+ type = attr.get('type', 'low_dim')
+ key_shape_map[key] = shape
+ if type == 'rgb':
+ rgb_keys.append(key)
+ if not use_dinoSiglip:
+ # configure model for this key
+ this_model = None
+ if not share_rgb_model:
+ if isinstance(rgb_model, dict):
+ # have provided model for each key
+ this_model = rgb_model[key]
+ else:
+ assert isinstance(rgb_model, nn.Module)
+ # have a copy of the rgb model
+ this_model = copy.deepcopy(rgb_model)
+
+ if this_model is not None:
+ if use_group_norm:
+ this_model = replace_submodules(
+ root_module=this_model,
+ predicate=lambda x: isinstance(
+ x, nn.BatchNorm2d),
+ func=lambda x: nn.GroupNorm(
+ num_groups=x.num_features // 16,
+ num_channels=x.num_features))
+ key_model_map[key] = this_model
+
+ # configure resize
+ input_shape = shape
+ this_resizer = nn.Identity()
+ if resize_shape is not None:
+ if isinstance(resize_shape, dict):
+ h, w = resize_shape[key]
+ else:
+ h, w = resize_shape
+ this_resizer = torchvision.transforms.Resize(size=(h,
+ w))
+ input_shape = (shape[0], h, w)
+
+ # configure randomizer
+ this_randomizer = nn.Identity()
+ if crop_shape is not None:
+ if isinstance(crop_shape, dict):
+ h, w = crop_shape[key]
+ else:
+ h, w = crop_shape
+ if random_crop:
+ this_randomizer = CropRandomizer(
+ input_shape=input_shape,
+ crop_height=h,
+ crop_width=w,
+ num_crops=1,
+ pos_enc=False)
+ else:
+ # apply a deterministic center crop when random_crop is disabled
+ this_randomizer = torchvision.transforms.CenterCrop(
+ size=(h, w))
+ # configure normalizer
+ this_normalizer = nn.Identity()
+ if imagenet_norm:
+ this_normalizer = torchvision.transforms.Normalize(
+ mean=[0.485, 0.456, 0.406],
+ std=[0.229, 0.224, 0.225])
+
+ this_transform = nn.Sequential(this_resizer,
+ this_randomizer,
+ this_normalizer)
+ key_transform_map[key] = this_transform
+ else:
+ key_model_map[key] = rgb_model
+ elif type == 'low_dim':
+ low_dim_keys.append(key)
+ else:
+ raise RuntimeError(f"Unsupported obs type: {type}")
+
+ rgb_keys = sorted(rgb_keys)
+ low_dim_keys = sorted(low_dim_keys)
+
+ self.shape_meta = shape_meta
+ self.key_model_map = key_model_map
+ self.key_transform_map = key_transform_map
+ self.share_rgb_model = share_rgb_model
+ self.rgb_keys = rgb_keys
+ self.low_dim_keys = low_dim_keys
+ self.key_shape_map = key_shape_map
+ self.use_dinoSiglip = use_dinoSiglip
+
+ ##NOTE add spatial softmax
+ self.use_spatial_softmax = use_spatial_softmax
+ if use_spatial_softmax and not use_dinoSiglip:
+ model = nn.Sequential(
+ key_model_map['image'].conv1,
+ key_model_map['image'].bn1,
+ key_model_map['image'].relu,
+ key_model_map['image'].maxpool,
+ key_model_map['image'].layer1,
+ key_model_map['image'].layer2,
+ key_model_map['image'].layer3,
+ key_model_map['image'].layer4,
+ )
+ key_model_map['image'] = model
+ input_shape = self.output_shape(resnet_output_shape=True)
+ self.spatial_softmax = SpatialSoftmax(input_shape,
+ num_kp=spatial_softmax_kp)
+
+ def forward(self, obs_dict, resnet_output_shape=False):
+ batch_size = None
+ features = list()
+ # process rgb input
+ if self.share_rgb_model:
+ # pass all rgb obs to rgb model
+ imgs = list()
+ for key in self.rgb_keys:
+ img = obs_dict[key]
+ if batch_size is None:
+ batch_size = img.shape[0]
+ else:
+ assert batch_size == img.shape[0]
+ assert img.shape[1:] == self.key_shape_map[key]
+ img = self.key_transform_map[key](img)
+ imgs.append(img)
+ # (N*B,C,H,W)
+ imgs = torch.cat(imgs, dim=0)
+ # (N*B,D)
+ feature = self.key_model_map['rgb'](imgs)
+ # (N,B,D)
+ feature = feature.reshape(-1, batch_size, *feature.shape[1:])
+ # (B,N,D)
+ feature = torch.moveaxis(feature, 0, 1)
+ # (B,N*D)
+ feature = feature.reshape(batch_size, -1)
+ features.append(feature)
+ else:
+ # run each rgb obs to independent models
+ for key in self.rgb_keys:
+ img = obs_dict[key]
+ if batch_size is None:
+ batch_size = img.shape[0]
+ else:
+ assert batch_size == img.shape[0]
+ assert img.shape[1:] == self.key_shape_map[key]
+ if not self.use_dinoSiglip:
+ img = self.key_transform_map[key](img)
+ feature = self.key_model_map[key](img)
+ else:
+ feature = self.key_model_map[key](img)[:, :1, :]
+
+ if resnet_output_shape:
+ return feature
+ if not self.use_dinoSiglip and self.use_spatial_softmax:
+ feature = self.spatial_softmax(feature)
+ feature = feature.reshape(batch_size, -1)
+ features.append(feature)
+
+ # process lowdim input
+ for key in self.low_dim_keys:
+ data = obs_dict[key]
+ if batch_size is None:
+ batch_size = data.shape[0]
+ else:
+ assert batch_size == data.shape[0]
+ assert data.shape[1:] == self.key_shape_map[key]
+ features.append(data)
+
+ # concatenate all features
+ result = torch.cat(features, dim=-1)
+ return result
+
+ @torch.no_grad()
+ def output_shape(self, resnet_output_shape=False):
+ example_obs_dict = dict()
+ obs_shape_meta = self.shape_meta['obs']
+ batch_size = 1
+ for key, attr in obs_shape_meta.items():
+ shape = tuple(attr['shape'])
+ this_obs = torch.zeros((batch_size, ) + shape,
+ dtype=self.dtype,
+ device=self.device)
+ example_obs_dict[key] = this_obs
+ example_output = self.forward(example_obs_dict,
+ resnet_output_shape=resnet_output_shape)
+ output_shape = example_output.shape[1:]
+ return output_shape
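+
+
+# Expected shape_meta layout (illustrative; only the fields read above are shown and the
+# keys/sizes are assumptions, not the actual configs/train/meta.json):
+#   {
+#     "obs": {
+#       "image":       {"shape": [3, 224, 224], "type": "rgb"},
+#       "agent_state": {"shape": [14],          "type": "low_dim"}
+#     }
+#   }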
diff --git a/src/unifolm_wma/models/samplers/ddim.py b/src/unifolm_wma/models/samplers/ddim.py
new file mode 100644
index 0000000..77602a1
--- /dev/null
+++ b/src/unifolm_wma/models/samplers/ddim.py
@@ -0,0 +1,473 @@
+import numpy as np
+import torch
+import copy
+
+from unifolm_wma.utils.diffusion import make_ddim_sampling_parameters, make_ddim_timesteps, rescale_noise_cfg
+from unifolm_wma.utils.common import noise_like
+from unifolm_wma.utils.common import extract_into_tensor
+from tqdm import tqdm
+
+
+class DDIMSampler(object):
+
+ def __init__(self, model, schedule="linear", **kwargs):
+ super().__init__()
+ self.model = model
+ self.ddpm_num_timesteps = model.num_timesteps
+ self.schedule = schedule
+ self.counter = 0
+
+ def register_buffer(self, name, attr):
+ if type(attr) == torch.Tensor:
+ if attr.device != torch.device("cuda"):
+ attr = attr.to(torch.device("cuda"))
+ setattr(self, name, attr)
+
+ def make_schedule(self,
+ ddim_num_steps,
+ ddim_discretize="uniform",
+ ddim_eta=0.,
+ verbose=True):
+ self.ddim_timesteps = make_ddim_timesteps(
+ ddim_discr_method=ddim_discretize,
+ num_ddim_timesteps=ddim_num_steps,
+ num_ddpm_timesteps=self.ddpm_num_timesteps,
+ verbose=verbose)
+ alphas_cumprod = self.model.alphas_cumprod
+ assert alphas_cumprod.shape[
+ 0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
+ to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model
+ .device)
+
+ if self.model.use_dynamic_rescale:
+ self.ddim_scale_arr = self.model.scale_arr[self.ddim_timesteps]
+ self.ddim_scale_arr_prev = torch.cat(
+ [self.ddim_scale_arr[0:1], self.ddim_scale_arr[:-1]])
+
+ self.register_buffer('betas', to_torch(self.model.betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev',
+ to_torch(self.model.alphas_cumprod_prev))
+
+ # Calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod',
+ to_torch(np.sqrt(alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod',
+ to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
+ self.register_buffer('log_one_minus_alphas_cumprod',
+ to_torch(np.log(1. - alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recip_alphas_cumprod',
+ to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod',
+ to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+
+ # DDIM sampling parameters
+ ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
+ alphacums=alphas_cumprod.cpu(),
+ ddim_timesteps=self.ddim_timesteps,
+ eta=ddim_eta,
+ verbose=verbose)
+ self.register_buffer('ddim_sigmas', ddim_sigmas)
+ self.register_buffer('ddim_alphas', ddim_alphas)
+ self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
+ self.register_buffer('ddim_sqrt_one_minus_alphas',
+ np.sqrt(1. - ddim_alphas))
+ sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
+ (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
+ (1 - self.alphas_cumprod / self.alphas_cumprod_prev))
+ self.register_buffer('ddim_sigmas_for_original_num_steps',
+ sigmas_for_original_sampling_steps)
+
+ @torch.no_grad()
+ def sample(
+ self,
+ S,
+ batch_size,
+ shape,
+ conditioning=None,
+ callback=None,
+ normals_sequence=None,
+ img_callback=None,
+ quantize_x0=False,
+ eta=0.,
+ mask=None,
+ x0=None,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ verbose=True,
+ schedule_verbose=False,
+ x_T=None,
+ log_every_t=100,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ precision=None,
+ fs=None,
+ timestep_spacing='uniform', #uniform_trailing for starting from last timestep
+ guidance_rescale=0.0,
+ **kwargs):
+
+ # Check condition bs
+ if conditioning is not None:
+ if isinstance(conditioning, dict):
+ try:
+ cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+ except Exception:
+ cbs = conditioning[list(
+ conditioning.keys())[0]][0].shape[0]
+
+ if cbs != batch_size:
+ print(
+ f"Warning: Got {cbs} conditionings but batch-size is {batch_size}"
+ )
+ else:
+ if conditioning.shape[0] != batch_size:
+ print(
+ f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}"
+ )
+
+ self.make_schedule(ddim_num_steps=S,
+ ddim_discretize=timestep_spacing,
+ ddim_eta=eta,
+ verbose=schedule_verbose)
+
+ # Make shape
+ if len(shape) == 3:
+ C, H, W = shape
+ size = (batch_size, C, H, W)
+ elif len(shape) == 4:
+ C, T, H, W = shape
+ size = (batch_size, C, T, H, W)
+
+ samples, actions, states, intermediates = self.ddim_sampling(
+ conditioning,
+ size,
+ callback=callback,
+ img_callback=img_callback,
+ quantize_denoised=quantize_x0,
+ mask=mask,
+ x0=x0,
+ ddim_use_original_steps=False,
+ noise_dropout=noise_dropout,
+ temperature=temperature,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ x_T=x_T,
+ log_every_t=log_every_t,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ verbose=verbose,
+ precision=precision,
+ fs=fs,
+ guidance_rescale=guidance_rescale,
+ **kwargs)
+ return samples, actions, states, intermediates
+
+ @torch.no_grad()
+ def ddim_sampling(self,
+ cond,
+ shape,
+ x_T=None,
+ ddim_use_original_steps=False,
+ callback=None,
+ timesteps=None,
+ quantize_denoised=False,
+ mask=None,
+ x0=None,
+ img_callback=None,
+ log_every_t=100,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ verbose=True,
+ precision=None,
+ fs=None,
+ guidance_rescale=0.0,
+ **kwargs):
+ device = self.model.betas.device
+ dp_ddim_scheduler_action = self.model.dp_noise_scheduler_action
+ dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state
+
+ b = shape[0]
+ if x_T is None:
+ img = torch.randn(shape, device=device)
+ action = torch.randn((b, 16, self.model.agent_action_dim),
+ device=device)
+ state = torch.randn((b, 16, self.model.agent_state_dim),
+ device=device)
+ else:
+ img = x_T
+ action = torch.randn((b, 16, self.model.agent_action_dim),
+ device=device)
+ state = torch.randn((b, 16, self.model.agent_state_dim),
+ device=device)
+
+ if precision is not None:
+ if precision == 16:
+ img = img.to(dtype=torch.float16)
+ action = action.to(dtype=torch.float16)
+ state = state.to(dtype=torch.float16)
+
+ if timesteps is None:
+ timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
+ elif timesteps is not None and not ddim_use_original_steps:
+ subset_end = int(
+ min(timesteps / self.ddim_timesteps.shape[0], 1) *
+ self.ddim_timesteps.shape[0]) - 1
+ timesteps = self.ddim_timesteps[:subset_end]
+
+ intermediates = {
+ 'x_inter': [img],
+ 'pred_x0': [img],
+ 'x_inter_action': [action],
+ 'pred_x0_action': [action],
+ 'x_inter_state': [state],
+ 'pred_x0_state': [state],
+ }
+ time_range = reversed(range(
+ 0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
+ total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
+ 0]
+ if verbose:
+ iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
+ else:
+ iterator = time_range
+
+ clean_cond = kwargs.pop("clean_cond", False)
+
+ dp_ddim_scheduler_action.set_timesteps(len(timesteps))
+ dp_ddim_scheduler_state.set_timesteps(len(timesteps))
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((b, ), step, device=device, dtype=torch.long)
+
+ # Use mask to blend noised original latent (img_orig) & new sampled latent (img)
+ if mask is not None:
+ assert x0 is not None
+ if clean_cond:
+ img_orig = x0
+ else:
+ img_orig = self.model.q_sample(x0, ts)
+ img = img_orig * mask + (1. - mask) * img
+
+ outs = self.p_sample_ddim(
+ img,
+ action,
+ state,
+ cond,
+ ts,
+ index=index,
+ use_original_steps=ddim_use_original_steps,
+ quantize_denoised=quantize_denoised,
+ temperature=temperature,
+ noise_dropout=noise_dropout,
+ score_corrector=score_corrector,
+ corrector_kwargs=corrector_kwargs,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning,
+ mask=mask,
+ x0=x0,
+ fs=fs,
+ guidance_rescale=guidance_rescale,
+ **kwargs)
+
+ img, pred_x0, model_output_action, model_output_state = outs
+
+ action = dp_ddim_scheduler_action.step(
+ model_output_action,
+ step,
+ action,
+ generator=None,
+ ).prev_sample
+ state = dp_ddim_scheduler_state.step(
+ model_output_state,
+ step,
+ state,
+ generator=None,
+ ).prev_sample
+
+ if callback: callback(i)
+ if img_callback: img_callback(pred_x0, i)
+
+ if index % log_every_t == 0 or index == total_steps - 1:
+ intermediates['x_inter'].append(img)
+ intermediates['pred_x0'].append(pred_x0)
+ intermediates['x_inter_action'].append(action)
+ intermediates['x_inter_state'].append(state)
+
+ return img, action, state, intermediates
+
+ @torch.no_grad()
+ def p_sample_ddim(self,
+ x,
+ x_action,
+ x_state,
+ c,
+ t,
+ index,
+ repeat_noise=False,
+ use_original_steps=False,
+ quantize_denoised=False,
+ temperature=1.,
+ noise_dropout=0.,
+ score_corrector=None,
+ corrector_kwargs=None,
+ unconditional_guidance_scale=1.,
+ unconditional_conditioning=None,
+ uc_type=None,
+ conditional_guidance_scale_temporal=None,
+ mask=None,
+ x0=None,
+ guidance_rescale=0.0,
+ **kwargs):
+ b, *_, device = *x.shape, x.device
+ if x.dim() == 5:
+ is_video = True
+ else:
+ is_video = False
+
+ if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
+ model_output, model_output_action, model_output_state = self.model.apply_model(
+ x, x_action, x_state, t, c, **kwargs) # unet denoiser
+ else:
+ # do_classifier_free_guidance
+ if isinstance(c, torch.Tensor) or isinstance(c, dict):
+ e_t_cond, e_t_cond_action, e_t_cond_state = self.model.apply_model(
+ x, x_action, x_state, t, c, **kwargs)
+ e_t_uncond, e_t_uncond_action, e_t_uncond_state = self.model.apply_model(
+ x, x_action, x_state, t, unconditional_conditioning,
+ **kwargs)
+ else:
+ raise NotImplementedError
+ model_output = e_t_uncond + unconditional_guidance_scale * (
+ e_t_cond - e_t_uncond)
+ model_output_action = e_t_uncond_action + unconditional_guidance_scale * (
+ e_t_cond_action - e_t_uncond_action)
+ model_output_state = e_t_uncond_state + unconditional_guidance_scale * (
+ e_t_cond_state - e_t_uncond_state)
+
+ if guidance_rescale > 0.0:
+ model_output = rescale_noise_cfg(
+ model_output, e_t_cond, guidance_rescale=guidance_rescale)
+ model_output_action = rescale_noise_cfg(
+ model_output_action,
+ e_t_cond_action,
+ guidance_rescale=guidance_rescale)
+ model_output_state = rescale_noise_cfg(
+ model_output_state,
+ e_t_cond_state,
+ guidance_rescale=guidance_rescale)
+
+ if self.model.parameterization == "v":
+ e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
+ else:
+ e_t = model_output
+
+ if score_corrector is not None:
+ assert self.model.parameterization == "eps", 'not implemented'
+ e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
+ **corrector_kwargs)
+
+ alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
+ alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
+ sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
+ sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
+
+ if is_video:
+ size = (b, 1, 1, 1, 1)
+ else:
+ size = (b, 1, 1, 1)
+
+ a_t = torch.full(size, alphas[index], device=device)
+ a_prev = torch.full(size, alphas_prev[index], device=device)
+ sigma_t = torch.full(size, sigmas[index], device=device)
+ sqrt_one_minus_at = torch.full(size,
+ sqrt_one_minus_alphas[index],
+ device=device)
+
+ if self.model.parameterization != "v":
+ pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+ else:
+ pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
+
+ if self.model.use_dynamic_rescale:
+ scale_t = torch.full(size,
+ self.ddim_scale_arr[index],
+ device=device)
+ prev_scale_t = torch.full(size,
+ self.ddim_scale_arr_prev[index],
+ device=device)
+ rescale = (prev_scale_t / scale_t)
+ pred_x0 *= rescale
+
+ if quantize_denoised:
+ pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
+
+ dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
+
+ noise = sigma_t * noise_like(x.shape, device,
+ repeat_noise) * temperature
+ if noise_dropout > 0.:
+ noise = torch.nn.functional.dropout(noise, p=noise_dropout)
+
+ x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+
+ return x_prev, pred_x0, model_output_action, model_output_state
+
+ @torch.no_grad()
+ def decode(self,
+ x_latent,
+ cond,
+ t_start,
+ unconditional_guidance_scale=1.0,
+ unconditional_conditioning=None,
+ use_original_steps=False,
+ callback=None):
+
+ timesteps = np.arange(self.ddpm_num_timesteps
+ ) if use_original_steps else self.ddim_timesteps
+ timesteps = timesteps[:t_start]
+
+ time_range = np.flip(timesteps)
+ total_steps = timesteps.shape[0]
+ print(f"Running DDIM Sampling with {total_steps} timesteps")
+
+ iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
+ x_dec = x_latent
+ for i, step in enumerate(iterator):
+ index = total_steps - i - 1
+ ts = torch.full((x_latent.shape[0], ),
+ step,
+ device=x_latent.device,
+ dtype=torch.long)
+ x_dec, _ = self.p_sample_ddim(
+ x_dec,
+ cond,
+ ts,
+ index=index,
+ use_original_steps=use_original_steps,
+ unconditional_guidance_scale=unconditional_guidance_scale,
+ unconditional_conditioning=unconditional_conditioning)
+ if callback: callback(i)
+ return x_dec
+
+ @torch.no_grad()
+ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
+ # fast, but does not allow for exact reconstruction
+ if use_original_steps:
+ sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
+ sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
+ else:
+ sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+ sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
+
+ if noise is None:
+ noise = torch.randn_like(x0)
+ return (
+ extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 +
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) *
+ noise)
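+
+
+# Usage sketch (illustrative; argument names follow sample() above, and the wrapped model
+# is assumed to expose apply_model / alphas_cumprod / dp_noise_scheduler_* as used here):
+#   sampler = DDIMSampler(model)
+#   samples, actions, states, intermediates = sampler.sample(
+#       S=50, batch_size=1, shape=(C, T, H, W), conditioning=cond, eta=0.0,
+#       unconditional_guidance_scale=7.5, unconditional_conditioning=uncond)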
diff --git a/src/unifolm_wma/modules/__init__.py b/src/unifolm_wma/modules/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/unifolm_wma/modules/attention.py b/src/unifolm_wma/modules/attention.py
new file mode 100644
index 0000000..7b21317
--- /dev/null
+++ b/src/unifolm_wma/modules/attention.py
@@ -0,0 +1,806 @@
+import torch
+import torch.nn.functional as F
+
+from torch import nn, einsum
+from einops import rearrange, repeat
+from functools import partial
+
+try:
+ import xformers
+ import xformers.ops
+ XFORMERS_IS_AVAILBLE = True
+except Exception:
+ XFORMERS_IS_AVAILBLE = False
+
+from unifolm_wma.utils.common import (
+ checkpoint,
+ exists,
+ default,
+)
+from unifolm_wma.utils.basics import zero_module
+
+
+class RelativePosition(nn.Module):
+ """ https://github.com/evelinehong/Transformer_Relative_Position_PyTorch/blob/master/relative_position.py """
+
+ def __init__(self, num_units, max_relative_position):
+ super().__init__()
+ self.num_units = num_units
+ self.max_relative_position = max_relative_position
+ self.embeddings_table = nn.Parameter(
+ torch.Tensor(max_relative_position * 2 + 1, num_units))
+ nn.init.xavier_uniform_(self.embeddings_table)
+
+ def forward(self, length_q, length_k):
+ device = self.embeddings_table.device
+ range_vec_q = torch.arange(length_q, device=device)
+ range_vec_k = torch.arange(length_k, device=device)
+ distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
+ distance_mat_clipped = torch.clamp(distance_mat,
+ -self.max_relative_position,
+ self.max_relative_position)
+ final_mat = distance_mat_clipped + self.max_relative_position
+ final_mat = final_mat.long()
+ embeddings = self.embeddings_table[final_mat]
+ return embeddings
+
+
+class CrossAttention(nn.Module):
+
+ def __init__(self,
+ query_dim,
+ context_dim=None,
+ heads=8,
+ dim_head=64,
+ dropout=0.,
+ relative_position=False,
+ temporal_length=None,
+ video_length=None,
+ agent_state_context_len=2,
+ agent_action_context_len=16,
+ image_cross_attention=False,
+ image_cross_attention_scale=1.0,
+ agent_state_cross_attention_scale=1.0,
+ agent_action_cross_attention_scale=1.0,
+ cross_attention_scale_learnable=False,
+ text_context_len=77):
+ super().__init__()
+ inner_dim = dim_head * heads
+ context_dim = default(context_dim, query_dim)
+
+ self.scale = dim_head**-0.5
+ self.heads = heads
+ self.dim_head = dim_head
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
+
+ self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim),
+ nn.Dropout(dropout))
+
+ self.relative_position = relative_position
+ if self.relative_position:
+ assert (temporal_length is not None)
+ self.relative_position_k = RelativePosition(
+ num_units=dim_head, max_relative_position=temporal_length)
+ self.relative_position_v = RelativePosition(
+ num_units=dim_head, max_relative_position=temporal_length)
+ else:
+ ## only used for spatial attention, while NOT for temporal attention
+ if XFORMERS_IS_AVAILBLE and temporal_length is None:
+ self.forward = self.efficient_forward
+
+ self.video_length = video_length
+ self.image_cross_attention = image_cross_attention
+ self.image_cross_attention_scale = image_cross_attention_scale
+ self.agent_state_cross_attention_scale = agent_state_cross_attention_scale
+ self.agent_action_cross_attention_scale = agent_action_cross_attention_scale
+ self.text_context_len = text_context_len
+ self.agent_state_context_len = agent_state_context_len
+ self.agent_action_context_len = agent_action_context_len
+ self.cross_attention_scale_learnable = cross_attention_scale_learnable
+ if self.image_cross_attention:
+ self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_k_as = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v_as = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_k_aa = nn.Linear(context_dim, inner_dim, bias=False)
+ self.to_v_aa = nn.Linear(context_dim, inner_dim, bias=False)
+ if cross_attention_scale_learnable:
+ self.register_parameter('alpha_ctx',
+ nn.Parameter(torch.tensor(0.)))
+ self.register_parameter('alpha_cas',
+ nn.Parameter(torch.tensor(0.)))
+ self.register_parameter('alpha_caa',
+ nn.Parameter(torch.tensor(0.)))
+
+ def forward(self, x, context=None, mask=None):
+ spatial_self_attn = (context is None)
+ k_ip, v_ip, out_ip = None, None, None
+ k_as, v_as, out_as = None, None, None
+ k_aa, v_aa, out_aa = None, None, None
+
+ h = self.heads
+ q = self.to_q(x)
+ context = default(context, x)
+
+ if self.image_cross_attention and not spatial_self_attn:
+ assert 1 > 2, ">>> ERROR: should setup xformers and use efficient_forward ..."
+ context_agent_state = context[:, :self.agent_state_context_len, :]
+ context_agent_action = context[:,
+ self.agent_state_context_len:self.
+ agent_state_context_len +
+ self.agent_action_context_len, :]
+ context_ins = context[:, self.agent_state_context_len +
+ self.agent_action_context_len:self.
+ agent_state_context_len +
+ self.agent_action_context_len +
+ self.text_context_len, :]
+ context_image = context[:, self.agent_state_context_len +
+ self.agent_action_context_len +
+ self.text_context_len:, :]
+
+ k = self.to_k(context_ins)
+ v = self.to_v(context_ins)
+ k_ip = self.to_k_ip(context_image)
+ v_ip = self.to_v_ip(context_image)
+ k_as = self.to_k_as(context_agent_state)
+ v_as = self.to_v_as(context_agent_state)
+ k_aa = self.to_k_aa(context_agent_action)
+ v_aa = self.to_v_aa(context_agent_action)
+ else:
+ if not spatial_self_attn:
+ context = context[:, :self.text_context_len, :]
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
+ (q, k, v))
+
+ sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
+ if self.relative_position:
+ len_q, len_k, len_v = q.shape[1], k.shape[1], v.shape[1]
+ k2 = self.relative_position_k(len_q, len_k)
+ sim2 = einsum('b t d, t s d -> b t s', q,
+ k2) * self.scale # TODO check
+ sim += sim2
+ del k
+
+ if exists(mask):
+ ## feasible for causal attention mask only
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b i j -> (b h) i j', h=h)
+ sim.masked_fill_(~(mask > 0.5), max_neg_value)
+
+ # attention, what we cannot get enough of
+ sim = sim.softmax(dim=-1)
+
+ out = torch.einsum('b i j, b j d -> b i d', sim, v)
+ if self.relative_position:
+ v2 = self.relative_position_v(len_q, len_v)
+ out2 = einsum('b t s, t s d -> b t d', sim, v2) # TODO check
+ out += out2
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+
+ if k_ip is not None and k_as is not None and k_aa is not None:
+ ## for image cross-attention
+ k_ip, v_ip = map(
+ lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
+ (k_ip, v_ip))
+ sim_ip = torch.einsum('b i d, b j d -> b i j', q,
+ k_ip) * self.scale
+ del k_ip
+ sim_ip = sim_ip.softmax(dim=-1)
+ out_ip = torch.einsum('b i j, b j d -> b i d', sim_ip, v_ip)
+ out_ip = rearrange(out_ip, '(b h) n d -> b n (h d)', h=h)
+
+ ## for agent state cross-attention
+ k_as, v_as = map(
+ lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
+ (k_as, v_as))
+ sim_as = torch.einsum('b i d, b j d -> b i j', q,
+ k_as) * self.scale
+ del k_as
+ sim_as = sim_as.softmax(dim=-1)
+ out_as = torch.einsum('b i j, b j d -> b i d', sim_as, v_as)
+ out_as = rearrange(out_as, '(b h) n d -> b n (h d)', h=h)
+
+ ## for agent action cross-attention
+ k_aa, v_aa = map(
+ lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
+ (k_aa, v_aa))
+ sim_aa = torch.einsum('b i d, b j d -> b i j', q,
+ k_aa) * self.scale
+ del k_aa
+ sim_aa = sim_aa.softmax(dim=-1)
+ out_aa = torch.einsum('b i j, b j d -> b i d', sim_aa, v_aa)
+ out_aa = rearrange(out_aa, '(b h) n d -> b n (h d)', h=h)
+
+ if out_ip is not None and out_as is not None and out_aa is not None:
+ if self.cross_attention_scale_learnable:
+ out = out + \
+ self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \
+ self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \
+ self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1)
+ else:
+ out = out + \
+ self.image_cross_attention_scale * out_ip + \
+ self.agent_state_cross_attention_scale * out_as + \
+ self.agent_action_cross_attention_scale * out_aa
+
+ return self.to_out(out)
+
+ def efficient_forward(self, x, context=None, mask=None):
+ spatial_self_attn = (context is None)
+ k, v, out = None, None, None
+ k_ip, v_ip, out_ip = None, None, None
+ k_as, v_as, out_as = None, None, None
+ k_aa, v_aa, out_aa = None, None, None
+
+ q = self.to_q(x)
+ context = default(context, x)
+
+ if self.image_cross_attention and not spatial_self_attn:
+ if context.shape[1] == self.text_context_len + self.video_length:
+ context_ins, context_image = context[:, :self.text_context_len, :], context[:,self.text_context_len:, :]
+ k = self.to_k(context)
+ v = self.to_v(context)
+ k_ip = self.to_k_ip(context_image)
+ v_ip = self.to_v_ip(context_image)
+ elif context.shape[1] == self.agent_state_context_len + self.text_context_len + self.video_length:
+ context_agent_state = context[:, :self.agent_state_context_len, :]
+ context_ins = context[:, self.agent_state_context_len:self.agent_state_context_len+self.text_context_len, :]
+ context_image = context[:, self.agent_state_context_len+self.text_context_len:, :]
+ k = self.to_k(context_ins)
+ v = self.to_v(context_ins)
+ k_ip = self.to_k_ip(context_image)
+ v_ip = self.to_v_ip(context_image)
+ k_as = self.to_k_as(context_agent_state)
+ v_as = self.to_v_as(context_agent_state)
+ else:
+ context_agent_state = context[:, :self.agent_state_context_len, :]
+ context_agent_action = context[:, self.agent_state_context_len:self.agent_state_context_len+self.agent_action_context_len, :]
+ context_ins = context[:, self.agent_state_context_len+self.agent_action_context_len:self.agent_state_context_len+self.agent_action_context_len+self.text_context_len, :]
+ context_image = context[:, self.agent_state_context_len+self.agent_action_context_len+self.text_context_len:, :]
+
+ k = self.to_k(context_ins)
+ v = self.to_v(context_ins)
+ k_ip = self.to_k_ip(context_image)
+ v_ip = self.to_v_ip(context_image)
+ k_as = self.to_k_as(context_agent_state)
+ v_as = self.to_v_as(context_agent_state)
+ k_aa = self.to_k_aa(context_agent_action)
+ v_aa = self.to_v_aa(context_agent_action)
+
+ attn_mask_aa = self._get_attn_mask_aa(x.shape[0],
+ q.shape[1],
+ k_aa.shape[1],
+ block_size=16).to(k_aa.device)
+ else:
+ if not spatial_self_attn:
+ assert 1 > 2, ">>> ERROR: you should never go into here ..."
+ context = context[:, :self.text_context_len, :]
+ k = self.to_k(context)
+ v = self.to_v(context)
+
+ b, _, _ = q.shape
+ q = q.unsqueeze(3).reshape(b, q.shape[1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(b * self.heads, q.shape[1], self.dim_head).contiguous()
+ if k is not None:
+ k, v = map(
+ lambda t: t.unsqueeze(3).reshape(b, t.shape[
+ 1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
+ b * self.heads, t.shape[1], self.dim_head).contiguous(),
+ (k, v),
+ )
+ out = xformers.ops.memory_efficient_attention(q,
+ k,
+ v,
+ attn_bias=None,
+ op=None)
+ out = (out.unsqueeze(0).reshape(
+ b, self.heads, out.shape[1],
+ self.dim_head).permute(0, 2, 1,
+ 3).reshape(b, out.shape[1],
+ self.heads * self.dim_head))
+
+ if k_ip is not None:
+ # For image cross-attention
+ k_ip, v_ip = map(
+ lambda t: t.unsqueeze(3).reshape(b, t.shape[
+ 1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
+ b * self.heads, t.shape[1], self.dim_head).contiguous(
+ ),
+ (k_ip, v_ip),
+ )
+ out_ip = xformers.ops.memory_efficient_attention(q,
+ k_ip,
+ v_ip,
+ attn_bias=None,
+ op=None)
+ out_ip = (out_ip.unsqueeze(0).reshape(
+ b, self.heads, out_ip.shape[1],
+ self.dim_head).permute(0, 2, 1,
+ 3).reshape(b, out_ip.shape[1],
+ self.heads * self.dim_head))
+
+ if k_as is not None:
+ # For agent state cross-attention
+ k_as, v_as = map(
+ lambda t: t.unsqueeze(3).reshape(b, t.shape[
+ 1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
+ b * self.heads, t.shape[1], self.dim_head).contiguous(
+ ),
+ (k_as, v_as),
+ )
+ out_as = xformers.ops.memory_efficient_attention(q,
+ k_as,
+ v_as,
+ attn_bias=None,
+ op=None)
+ out_as = (out_as.unsqueeze(0).reshape(
+ b, self.heads, out_as.shape[1],
+ self.dim_head).permute(0, 2, 1,
+ 3).reshape(b, out_as.shape[1],
+ self.heads * self.dim_head))
+ if k_aa is not None:
+ # For agent action cross-attention
+ k_aa, v_aa = map(
+ lambda t: t.unsqueeze(3).reshape(b, t.shape[
+ 1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
+ b * self.heads, t.shape[1], self.dim_head).contiguous(
+ ),
+ (k_aa, v_aa),
+ )
+
+ attn_mask_aa = attn_mask_aa.unsqueeze(1).repeat(1,self.heads,1,1).reshape(
+ b * self.heads, attn_mask_aa.shape[1], attn_mask_aa.shape[2])
+ attn_mask_aa = attn_mask_aa.to(q.dtype)
+
+ out_aa = xformers.ops.memory_efficient_attention(
+ q, k_aa, v_aa, attn_bias=attn_mask_aa, op=None)
+
+ out_aa = (out_aa.unsqueeze(0).reshape(
+ b, self.heads, out_aa.shape[1],
+ self.dim_head).permute(0, 2, 1,
+ 3).reshape(b, out_aa.shape[1],
+ self.heads * self.dim_head))
+ if exists(mask):
+ raise NotImplementedError
+
+ out = 0.0 if out is None else out
+ out_ip = 0.0 if out_ip is None else out_ip
+ out_as = 0.0 if out_as is None else out_as
+ out_aa = 0.0 if out_aa is None else out_aa
+
+ if self.cross_attention_scale_learnable:
+ out = out + \
+ self.image_cross_attention_scale * out_ip * (torch.tanh(self.alpha_ctx) + 1) + \
+ self.agent_state_cross_attention_scale * out_as * (torch.tanh(self.alpha_cas) + 1) + \
+ self.agent_action_cross_attention_scale * out_aa * (torch.tanh(self.alpha_caa) + 1)
+
+ else:
+ out = out + \
+ self.image_cross_attention_scale * out_ip + \
+ self.agent_state_cross_attention_scale * out_as + \
+ self.agent_action_cross_attention_scale * out_aa
+
+ return self.to_out(out)
+
+ def _get_attn_mask_aa(self, b, l1, l2, block_size=16):
+ num_token = l2 // block_size
+ start_positions = ((torch.arange(b) % block_size) + 1) * num_token
+ col_indices = torch.arange(l2)
+ mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
+ mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
+ attn_mask = torch.zeros_like(mask, dtype=torch.float)
+ attn_mask[mask] = float('-inf')
+ return attn_mask
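+
+ # Worked example (illustrative): with b = 2 * 16 (two clips whose 16 frames are folded
+ # into the batch), l2 = 16 action tokens and block_size = 16, num_token is 1 and row r
+ # (frame f = r % 16) keeps columns 0..f visible while columns > f are set to -inf, so
+ # each frame only attends to agent actions up to its own timestep.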
+
+
+class BasicTransformerBlock(nn.Module):
+
+ def __init__(self,
+ dim,
+ n_heads,
+ d_head,
+ dropout=0.,
+ context_dim=None,
+ gated_ff=True,
+ checkpoint=True,
+ disable_self_attn=False,
+ attention_cls=None,
+ video_length=None,
+ agent_state_context_len=2,
+ agent_action_context_len=16,
+ image_cross_attention=False,
+ image_cross_attention_scale=1.0,
+ cross_attention_scale_learnable=False,
+ text_context_len=77):
+ super().__init__()
+ attn_cls = CrossAttention if attention_cls is None else attention_cls
+ self.disable_self_attn = disable_self_attn
+ self.attn1 = attn_cls(
+ query_dim=dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ context_dim=context_dim if self.disable_self_attn else None)
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
+ self.attn2 = attn_cls(
+ query_dim=dim,
+ context_dim=context_dim,
+ heads=n_heads,
+ dim_head=d_head,
+ dropout=dropout,
+ video_length=video_length,
+ agent_state_context_len=agent_state_context_len,
+ agent_action_context_len=agent_action_context_len,
+ image_cross_attention=image_cross_attention,
+ image_cross_attention_scale=image_cross_attention_scale,
+ cross_attention_scale_learnable=cross_attention_scale_learnable,
+ text_context_len=text_context_len)
+ self.image_cross_attention = image_cross_attention
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+ self.norm3 = nn.LayerNorm(dim)
+ self.checkpoint = checkpoint
+
+ def forward(self, x, context=None, mask=None, **kwargs):
+ # implementation tricks: because checkpointing doesn't support non-tensor (e.g. None or scalar) arguments
+ input_tuple = (
+ x,
+ ) # should not be (x), otherwise *input_tuple will decouple x into multiple arguments
+ if context is not None:
+ input_tuple = (x, context)
+ if mask is not None:
+ forward_mask = partial(self._forward, mask=mask)
+ return checkpoint(forward_mask, (x, ), self.parameters(),
+ self.checkpoint)
+ return checkpoint(self._forward, input_tuple, self.parameters(),
+ self.checkpoint)
+
+ def _forward(self, x, context=None, mask=None):
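+        # Pre-norm transformer block: self-attention, (cross-)attention, and
+        # feed-forward, each wrapped in a residual connection.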
+ x = self.attn1(self.norm1(x),
+ context=context if self.disable_self_attn else None,
+ mask=mask) + x
+ x = self.attn2(self.norm2(x), context=context, mask=mask) + x
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+    """
+    Transformer block for image-like data along the spatial axes.
+    First, project the input (i.e., the embedding) and reshape it to (b, t, d).
+    Then apply standard transformer layers.
+    Finally, reshape the result back to an image.
+    NEW: `use_linear` replaces the 1x1 convolutions with linear layers for efficiency.
+    """
+
+ def __init__(self,
+ in_channels,
+ n_heads,
+ d_head,
+ depth=1,
+ dropout=0.,
+ context_dim=None,
+ use_checkpoint=True,
+ disable_self_attn=False,
+ use_linear=False,
+ video_length=None,
+ agent_state_context_len=2,
+ agent_action_context_len=16,
+ image_cross_attention=False,
+ cross_attention_scale_learnable=False):
+ super().__init__()
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = torch.nn.GroupNorm(num_groups=32,
+ num_channels=in_channels,
+ eps=1e-6,
+ affine=True)
+ if not use_linear:
+ self.proj_in = nn.Conv2d(in_channels,
+ inner_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ else:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ attention_cls = None
+ self.transformer_blocks = nn.ModuleList([
+ BasicTransformerBlock(
+ inner_dim,
+ n_heads,
+ d_head,
+ dropout=dropout,
+ context_dim=context_dim,
+ disable_self_attn=disable_self_attn,
+ checkpoint=use_checkpoint,
+ attention_cls=attention_cls,
+ video_length=video_length,
+ agent_state_context_len=agent_state_context_len,
+ agent_action_context_len=agent_action_context_len,
+ image_cross_attention=image_cross_attention,
+ cross_attention_scale_learnable=cross_attention_scale_learnable,
+ ) for d in range(depth)
+ ])
+ if not use_linear:
+ self.proj_out = zero_module(
+ nn.Conv2d(inner_dim,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0))
+ else:
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+ self.use_linear = use_linear
+
+ def forward(self, x, context=None, **kwargs):
+ b, c, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+ for i, block in enumerate(self.transformer_blocks):
+ x = block(x, context=context, **kwargs)
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = self.proj_out(x)
+ return x + x_in
+
+
+class TemporalTransformer(nn.Module):
+    """
+    Transformer block for image-like data along the temporal axis.
+    First, reshape the input to (b, t, d).
+    Then apply standard transformer layers.
+    Finally, reshape the result back to an image.
+    """
+
+ def __init__(self,
+ in_channels,
+ n_heads,
+ d_head,
+ depth=1,
+ dropout=0.,
+ context_dim=None,
+ use_checkpoint=True,
+ use_linear=False,
+ only_self_att=True,
+ causal_attention=False,
+ causal_block_size=1,
+ relative_position=False,
+ temporal_length=None):
+ super().__init__()
+ self.only_self_att = only_self_att
+ self.relative_position = relative_position
+ self.causal_attention = causal_attention
+ self.causal_block_size = causal_block_size
+
+ self.in_channels = in_channels
+ inner_dim = n_heads * d_head
+ self.norm = torch.nn.GroupNorm(num_groups=32,
+ num_channels=in_channels,
+ eps=1e-6,
+ affine=True)
+        if not use_linear:
+            self.proj_in = nn.Conv1d(in_channels,
+                                     inner_dim,
+                                     kernel_size=1,
+                                     stride=1,
+                                     padding=0)
+        else:
+            self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ if relative_position:
+ assert (temporal_length is not None)
+ attention_cls = partial(CrossAttention,
+ relative_position=True,
+ temporal_length=temporal_length)
+ else:
+ attention_cls = partial(CrossAttention,
+ temporal_length=temporal_length)
+ if self.causal_attention:
+ assert (temporal_length is not None)
+ self.mask = torch.tril(
+ torch.ones([1, temporal_length, temporal_length]))
+
+ if self.only_self_att:
+ context_dim = None
+ self.transformer_blocks = nn.ModuleList([
+ BasicTransformerBlock(inner_dim,
+ n_heads,
+ d_head,
+ dropout=dropout,
+ context_dim=context_dim,
+ attention_cls=attention_cls,
+ checkpoint=use_checkpoint)
+ for d in range(depth)
+ ])
+ if not use_linear:
+ self.proj_out = zero_module(
+ nn.Conv1d(inner_dim,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0))
+ else:
+ self.proj_out = zero_module(nn.Linear(inner_dim, in_channels))
+ self.use_linear = use_linear
+
+ def forward(self, x, context=None):
+ b, c, t, h, w = x.shape
+ x_in = x
+ x = self.norm(x)
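+        # Fold spatial positions into the batch dimension so that attention runs
+        # along the temporal axis only.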
+ x = rearrange(x, 'b c t h w -> (b h w) c t').contiguous()
+ if not self.use_linear:
+ x = self.proj_in(x)
+ x = rearrange(x, 'bhw c t -> bhw t c').contiguous()
+ if self.use_linear:
+ x = self.proj_in(x)
+
+ temp_mask = None
+ if self.causal_attention:
+            # Slice the precomputed causal mask to the current temporal length
+ temp_mask = self.mask[:, :t, :t].to(x.device)
+
+ if temp_mask is not None:
+ mask = temp_mask.to(x.device)
+ mask = repeat(mask, 'l i j -> (l bhw) i j', bhw=b * h * w)
+ else:
+ mask = None
+
+ if self.only_self_att:
+ # NOTE: if no context is given, cross-attention defaults to self-attention
+ for i, block in enumerate(self.transformer_blocks):
+ x = block(x, mask=mask)
+ x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
+ else:
+ x = rearrange(x, '(b hw) t c -> b hw t c', b=b).contiguous()
+ context = rearrange(context, '(b t) l con -> b t l con',
+ t=t).contiguous()
+ for i, block in enumerate(self.transformer_blocks):
+                # Process batch elements one at a time (some backends cannot handle a dimension larger than 65,535)
+ for j in range(b):
+ context_j = repeat(context[j],
+ 't l con -> (t r) l con',
+ r=(h * w) // t,
+ t=t).contiguous()
+                    # Note: the causal mask is not applied in the cross-attention case
+ x[j] = block(x[j], context=context_j)
+
+ if self.use_linear:
+ x = self.proj_out(x)
+ x = rearrange(x, 'b (h w) t c -> b c t h w', h=h, w=w).contiguous()
+ if not self.use_linear:
+ x = rearrange(x, 'b hw t c -> (b hw) c t').contiguous()
+ x = self.proj_out(x)
+ x = rearrange(x, '(b h w) c t -> b c t h w', b=b, h=h,
+ w=w).contiguous()
+
+ return x + x_in
+
+
+class GEGLU(nn.Module):
+
+ def __init__(self, dim_in, dim_out):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def forward(self, x):
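+        # Project to twice the output width, split into value and gate halves, and
+        # gate the value with GELU(gate) (gated GELU / GEGLU).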
+ x, gate = self.proj(x).chunk(2, dim=-1)
+ return x * F.gelu(gate)
+
+
+class FeedForward(nn.Module):
+
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = default(dim_out, dim)
+ project_in = nn.Sequential(nn.Linear(
+ dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
+
+ self.net = nn.Sequential(project_in, nn.Dropout(dropout),
+ nn.Linear(inner_dim, dim_out))
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class LinearAttention(nn.Module):
+
+ def __init__(self, dim, heads=4, dim_head=32):
+ super().__init__()
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ b, c, h, w = x.shape
+ qkv = self.to_qkv(x)
+ q, k, v = rearrange(qkv,
+ 'b (qkv heads c) h w -> qkv b heads c (h w)',
+ heads=self.heads,
+ qkv=3)
+ k = k.softmax(dim=-1)
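+        # Linear attention: softmax-normalise keys over spatial positions, then
+        # aggregate values into a (dim_head x dim_head) context before applying the
+        # queries, avoiding an explicit (hw x hw) attention map.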
+ context = torch.einsum('bhdn,bhen->bhde', k, v)
+ out = torch.einsum('bhde,bhdn->bhen', context, q)
+ out = rearrange(out,
+ 'b heads c (h w) -> b (heads c) h w',
+ heads=self.heads,
+ h=h,
+ w=w)
+ return self.to_out(out)
+
+
+class SpatialSelfAttention(nn.Module):
+
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(num_groups=32,
+ num_channels=in_channels,
+ eps=1e-6,
+ affine=True)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # Compute attention
+ b, c, h, w = q.shape
+ q = rearrange(q, 'b c h w -> b (h w) c')
+ k = rearrange(k, 'b c h w -> b c (h w)')
+ w_ = torch.einsum('bij,bjk->bik', q, k)
+
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # Attend to values
+ v = rearrange(v, 'b c h w -> b c (h w)')
+ w_ = rearrange(w_, 'b i j -> b j i')
+ h_ = torch.einsum('bij,bjk->bik', v, w_)
+ h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
+ h_ = self.proj_out(h_)
+
+ return x + h_
diff --git a/src/unifolm_wma/modules/encoders/condition.py b/src/unifolm_wma/modules/encoders/condition.py
new file mode 100644
index 0000000..44f0fdc
--- /dev/null
+++ b/src/unifolm_wma/modules/encoders/condition.py
@@ -0,0 +1,630 @@
+import torch
+import torch.nn as nn
+import kornia
+import open_clip
+import math
+
+from torch.utils.checkpoint import checkpoint
+from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
+
+from unifolm_wma.utils.common import autocast
+from unifolm_wma.utils.utils import count_params
+from unifolm_wma.modules.encoders.resampler import reshape_tensor
+
+
+class AbstractEncoder(nn.Module):
+
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class IdentityEncoder(AbstractEncoder):
+
+ def encode(self, x):
+ return x
+
+
+class ClassEmbedder(nn.Module):
+
+ def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
+ super().__init__()
+ self.key = key
+ self.embedding = nn.Embedding(n_classes, embed_dim)
+ self.n_classes = n_classes
+ self.ucg_rate = ucg_rate
+
+ def forward(self, batch, key=None, disable_dropout=False):
+ if key is None:
+ key = self.key
+        # add a token dimension so the class embedding can be used as cross-attention context
+ c = batch[key][:, None]
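+        # Classifier-free guidance dropout: with probability ucg_rate, replace the
+        # label with the reserved unconditional class (index n_classes - 1).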
+ if self.ucg_rate > 0. and not disable_dropout:
+ mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
+ c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes -
+ 1)
+ c = c.long()
+ c = self.embedding(c)
+ return c
+
+ def get_unconditional_conditioning(self, bs, device="cuda"):
+        uc_class = self.n_classes - 1  # the last class index is reserved as the unconditional (ucg) class
+ uc = torch.ones((bs, ), device=device) * uc_class
+ uc = {self.key: uc}
+ return uc
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+class FrozenT5Embedder(AbstractEncoder):
+ """Uses the T5 transformer encoder for text"""
+
+ def __init__(self,
+ version="google/t5-v1_1-xxl",
+ device="cuda",
+ max_length=77,
+ freeze=True
+ ): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+ super().__init__()
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
+ self.transformer = T5EncoderModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length # TODO: typical value?
+ if freeze:
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ # self.train = disabled_train
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens)
+
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
+ LAYERS = ["last", "pooled", "hidden"]
+
+ def __init__(self,
+ version="openai/clip-vit-large-patch14",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ layer="last",
+ layer_idx=None): # clip-vit-base-patch32
+ super().__init__()
+ assert layer in self.LAYERS
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
+ self.transformer = CLIPTextModel.from_pretrained(version)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ self.layer_idx = layer_idx
+ if layer == "hidden":
+ assert layer_idx is not None
+ assert 0 <= abs(layer_idx) <= 12
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ # self.train = disabled_train
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens,
+ output_hidden_states=self.layer == "hidden")
+ if self.layer == "last":
+ z = outputs.last_hidden_state
+ elif self.layer == "pooled":
+ z = outputs.pooler_output[:, None, :]
+ else:
+ z = outputs.hidden_states[self.layer_idx]
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class ClipImageEmbedder(nn.Module):
+
+ def __init__(self,
+ model,
+ jit=False,
+ device='cuda' if torch.cuda.is_available() else 'cpu',
+ antialias=True,
+ ucg_rate=0.):
+ super().__init__()
+ from clip import load as load_clip
+ self.model, _ = load_clip(name=model, device=device, jit=jit)
+
+ self.antialias = antialias
+
+ self.register_buffer('mean',
+ torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
+ persistent=False)
+ self.register_buffer('std',
+ torch.Tensor([0.26862954, 0.26130258,
+ 0.27577711]),
+ persistent=False)
+ self.ucg_rate = ucg_rate
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ x = kornia.geometry.resize(x, (224, 224),
+ interpolation='bicubic',
+ align_corners=True,
+ antialias=self.antialias)
+ x = (x + 1.) / 2.
+ # re-normalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def forward(self, x, no_dropout=False):
+ # x is assumed to be in range [-1,1]
+ out = self.model.encode_image(self.preprocess(x))
+ out = out.to(x.dtype)
+ if self.ucg_rate > 0. and not no_dropout:
+ out = torch.bernoulli(
+ (1. - self.ucg_rate) *
+ torch.ones(out.shape[0], device=out.device))[:, None] * out
+ return out
+
+
+class FrozenOpenCLIPEmbedder(AbstractEncoder):
+ """
+ Uses the OpenCLIP transformer encoder for text
+ """
+ LAYERS = [
+ # "pooled",
+ "last",
+ "penultimate"
+ ]
+
+ def __init__(self,
+ arch="ViT-H-14",
+ version="laion2b_s32b_b79k",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ layer="last"):
+ super().__init__()
+ assert layer in self.LAYERS
+ model, _, _ = open_clip.create_model_and_transforms(
+ arch, device=torch.device('cpu'), pretrained=version)
+ del model.visual
+ self.model = model
+
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "last":
+ self.layer_idx = 0
+ elif self.layer == "penultimate":
+ self.layer_idx = 1
+ else:
+ raise NotImplementedError()
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ tokens = open_clip.tokenize(
+ text) ## all clip models use 77 as context length
+ z = self.encode_with_transformer(tokens.to(self.device))
+ return z
+
+ def encode_with_transformer(self, text):
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
+ x = x + self.model.positional_embedding
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.model.ln_final(x)
+ return x
+
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
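+        # Run the transformer blocks, stopping `layer_idx` blocks before the end so
+        # that layer="penultimate" skips the final block.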
+ for i, r in enumerate(self.model.transformer.resblocks):
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
+ break
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
+ ):
+ x = checkpoint(r, x, attn_mask)
+ else:
+ x = r(x, attn_mask=attn_mask)
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
+ """
+ Uses the OpenCLIP vision transformer encoder for images
+ """
+
+ def __init__(self,
+ arch="ViT-H-14",
+ version="laion2b_s32b_b79k",
+ device="cuda",
+ max_length=77,
+ freeze=True,
+ layer="pooled",
+ antialias=True,
+ ucg_rate=0.):
+ super().__init__()
+ model, _, _ = open_clip.create_model_and_transforms(
+ arch,
+ device=torch.device('cpu'),
+ pretrained=version,
+ )
+ del model.transformer
+ self.model = model
+ # self.mapper = torch.nn.Linear(1280, 1024)
+ self.device = device
+ self.max_length = max_length
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "penultimate":
+ raise NotImplementedError()
+ self.layer_idx = 1
+
+ self.antialias = antialias
+
+ self.register_buffer('mean',
+ torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
+ persistent=False)
+ self.register_buffer('std',
+ torch.Tensor([0.26862954, 0.26130258,
+ 0.27577711]),
+ persistent=False)
+ self.ucg_rate = ucg_rate
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ x = kornia.geometry.resize(x, (224, 224),
+ interpolation='bicubic',
+ align_corners=True,
+ antialias=self.antialias)
+ x = (x + 1.) / 2.
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.model.parameters():
+ param.requires_grad = False
+
+ @autocast
+ def forward(self, image, no_dropout=False):
+ z = self.encode_with_vision_transformer(image)
+ if self.ucg_rate > 0. and not no_dropout:
+ z = torch.bernoulli(
+ (1. - self.ucg_rate) *
+ torch.ones(z.shape[0], device=z.device))[:, None] * z
+ return z
+
+ def encode_with_vision_transformer(self, img):
+ img = self.preprocess(img)
+ x = self.model.visual(img)
+ return x
+
+ def encode(self, text):
+ return self(text)
+
+
+class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder):
+ """
+ Uses the OpenCLIP vision transformer encoder for images
+ """
+
+ def __init__(self,
+ arch="ViT-H-14",
+ version="laion2b_s32b_b79k",
+ device="cuda",
+ freeze=True,
+ layer="pooled",
+ antialias=True):
+ super().__init__()
+ model, _, _ = open_clip.create_model_and_transforms(
+ arch,
+ device=torch.device('cpu'),
+ pretrained=version,
+ )
+ del model.transformer
+ self.model = model
+ self.device = device
+
+ if freeze:
+ self.freeze()
+ self.layer = layer
+ if self.layer == "penultimate":
+ raise NotImplementedError()
+ self.layer_idx = 1
+
+ self.antialias = antialias
+
+ self.register_buffer('mean',
+ torch.Tensor([0.48145466, 0.4578275, 0.40821073]),
+ persistent=False)
+ self.register_buffer('std',
+ torch.Tensor([0.26862954, 0.26130258,
+ 0.27577711]),
+ persistent=False)
+
+ def preprocess(self, x):
+ # normalize to [0,1]
+ x = kornia.geometry.resize(x, (224, 224),
+ interpolation='bicubic',
+ align_corners=True,
+ antialias=self.antialias)
+ x = (x + 1.) / 2.
+ # renormalize according to clip
+ x = kornia.enhance.normalize(x, self.mean, self.std)
+ return x
+
+ def freeze(self):
+ self.model = self.model.eval()
+ for param in self.model.parameters():
+ param.requires_grad = False
+
+ def forward(self, image, no_dropout=False):
+ ## image: b c h w
+ z = self.encode_with_vision_transformer(image)
+ return z
+
+ def encode_with_vision_transformer(self, x):
+ x = self.preprocess(x)
+
+ # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
+ if self.model.visual.input_patchnorm:
+ # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
+ x = x.reshape(x.shape[0], x.shape[1],
+ self.model.visual.grid_size[0],
+ self.model.visual.patch_size[0],
+ self.model.visual.grid_size[1],
+ self.model.visual.patch_size[1])
+ x = x.permute(0, 2, 4, 1, 3, 5)
+ x = x.reshape(
+ x.shape[0], self.model.visual.grid_size[0] *
+ self.model.visual.grid_size[1], -1)
+ x = self.model.visual.patchnorm_pre_ln(x)
+ x = self.model.visual.conv1(x)
+ else:
+ x = self.model.visual.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1],
+ -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+
+ # class embeddings and positional embeddings
+ x = torch.cat([
+ self.model.visual.class_embedding.to(x.dtype) + torch.zeros(
+ x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
+ ],
+ dim=1) # shape = [*, grid ** 2 + 1, width]
+ x = x + self.model.visual.positional_embedding.to(x.dtype)
+
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
+ x = self.model.visual.patch_dropout(x)
+ x = self.model.visual.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.model.visual.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+
+ return x
+
+
+class FrozenCLIPT5Encoder(AbstractEncoder):
+
+ def __init__(self,
+ clip_version="openai/clip-vit-large-patch14",
+ t5_version="google/t5-v1_1-xl",
+ device="cuda",
+ clip_max_length=77,
+ t5_max_length=77):
+ super().__init__()
+ self.clip_encoder = FrozenCLIPEmbedder(clip_version,
+ device,
+ max_length=clip_max_length)
+ self.t5_encoder = FrozenT5Embedder(t5_version,
+ device,
+ max_length=t5_max_length)
+ print(
+ f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
+ f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
+ )
+
+ def encode(self, text):
+ return self(text)
+
+ def forward(self, text):
+ clip_z = self.clip_encoder.encode(text)
+ t5_z = self.t5_encoder.encode(text)
+ return [clip_z, t5_z]
+
+
+class LinearProjector(nn.Module):
+
+ def __init__(self, input_dim: int, output_dim: int) -> None:
+ super().__init__()
+ self.projector = nn.Linear(input_dim, output_dim, bias=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.projector(x)
+
+
+class MLPProjector(nn.Module):
+
+ def __init__(self,
+ input_dim: int,
+ output_dim: int,
+ mlp_type: str = "gelu-mlp") -> None:
+ super().__init__()
+ if mlp_type == "gelu-mlp":
+ self.projector = nn.Sequential(
+ nn.Linear(input_dim, output_dim, bias=True),
+ nn.GELU(approximate='tanh'),
+ nn.Linear(output_dim, output_dim, bias=True),
+ )
+ elif mlp_type == "silu-mlp":
+ self.projector = nn.Sequential(
+ nn.Linear(input_dim, output_dim, bias=True),
+ nn.SiLU(),
+ nn.Linear(output_dim, output_dim, bias=True),
+ )
+ else:
+ raise ValueError(
+ f"Projector with `{mlp_type = }` is not supported!")
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.projector(x)
+
+
+class PerceiverAttention(nn.Module):
+
+ def __init__(self, *, dim, dim_head=64, heads=8):
+ super().__init__()
+ self.scale = dim_head**-0.5
+ self.dim_head = dim_head
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+ def forward(self, x, latents):
+ """
+ Args:
+ x (torch.Tensor): image features
+ shape (b, n1, D)
+ latent (torch.Tensor): latent features
+ shape (b, n2, D)
+ """
+
+ x = self.norm1(x)
+ latents = self.norm2(latents)
+
+ b, l, _ = latents.shape
+
+ q = self.to_q(latents)
+ kv_input = torch.cat((x, latents), dim=-2)
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+ q = reshape_tensor(q, self.heads)
+ k = reshape_tensor(k, self.heads)
+ v = reshape_tensor(v, self.heads)
+
+ # attention
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+ weight = (q * scale) @ (k * scale).transpose(
+ -2, -1) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ out = weight @ v
+
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+ return self.to_out(out)
+
+
+def FeedForward(dim, mult=4, ffd_type="gelu-ffd"):
+ inner_dim = int(dim * mult)
+ if ffd_type == "gelu-ffd":
+ return nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, inner_dim, bias=False),
+ nn.GELU(approximate='tanh'),
+ nn.Linear(inner_dim, dim, bias=False),
+ )
+ elif ffd_type == "silu-ffd":
+ return nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, inner_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(inner_dim, dim, bias=False),
+ )
+    else:
+        raise ValueError(f"FeedForward with `{ffd_type = }` is not supported!")
+
+
+class SATokenProjector(nn.Module):
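+    """Perceiver-style projector for state/action tokens: a fixed set of learned
+    latent queries (optionally one group per chunk) cross-attends over the input
+    token sequence and is projected to `output_dim`.
+
+    Minimal illustrative usage (shapes chosen only for illustration):
+
+        proj = SATokenProjector(dim=1024, num_queries=16, output_dim=1024)
+        tokens = torch.randn(2, 48, 1024)   # (batch, num_input_tokens, dim)
+        out = proj(tokens)                  # -> (2, 16, 1024)
+    """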
+
+ def __init__(self,
+ dim=1024,
+ depth=1,
+ dim_head=64,
+ heads=16,
+ num_queries=16,
+ output_dim=1024,
+ ff_mult=4,
+ chunk_size=None):
+ super().__init__()
+ self.num_queries = num_queries
+ self.chunk_size = chunk_size
+
+ if chunk_size is not None:
+ num_queries = num_queries * chunk_size
+
+ self.latents = nn.Parameter(
+ torch.randn(1, num_queries, dim) / dim**0.5)
+ self.proj_out = nn.Linear(dim, output_dim)
+ self.norm_out = nn.LayerNorm(dim)
+
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList([
+ PerceiverAttention(dim=dim, dim_head=dim_head,
+ heads=heads),
+ FeedForward(dim=dim, mult=ff_mult),
+ ]))
+
+ def forward(self, x):
+ latents = self.latents.repeat(x.size(0), 1, 1)
+ for attn, ff in self.layers:
+ latents = attn(x, latents) + latents
+ latents = ff(latents) + latents
+ latents = self.proj_out(latents)
+ latents = self.norm_out(latents)
+ return latents
diff --git a/src/unifolm_wma/modules/encoders/resampler.py b/src/unifolm_wma/modules/encoders/resampler.py
new file mode 100644
index 0000000..47e7364
--- /dev/null
+++ b/src/unifolm_wma/modules/encoders/resampler.py
@@ -0,0 +1,153 @@
+# modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
+# and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
+# and https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/resampler.py
+import math
+import torch
+import torch.nn as nn
+
+
+class ImageProjModel(nn.Module):
+ """Projection Model"""
+
+ def __init__(self,
+ cross_attention_dim=1024,
+ clip_embeddings_dim=1024,
+ clip_extra_context_tokens=4):
+ super().__init__()
+ self.cross_attention_dim = cross_attention_dim
+ self.clip_extra_context_tokens = clip_extra_context_tokens
+ self.proj = nn.Linear(
+ clip_embeddings_dim,
+ self.clip_extra_context_tokens * cross_attention_dim)
+ self.norm = nn.LayerNorm(cross_attention_dim)
+
+ def forward(self, image_embeds):
+ #embeds = image_embeds
+ embeds = image_embeds.type(list(self.proj.parameters())[0].dtype)
+ clip_extra_context_tokens = self.proj(embeds).reshape(
+ -1, self.clip_extra_context_tokens, self.cross_attention_dim)
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
+ return clip_extra_context_tokens
+
+
+# FFN
+def FeedForward(dim, mult=4):
+ inner_dim = int(dim * mult)
+ return nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, inner_dim, bias=False),
+ nn.GELU(),
+ nn.Linear(inner_dim, dim, bias=False),
+ )
+
+
+def reshape_tensor(x, heads):
+ bs, length, width = x.shape
+ #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
+ x = x.view(bs, length, heads, -1)
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
+ x = x.transpose(1, 2)
+    # keep shape (bs, n_heads, length, dim_per_head); the reshape only makes the tensor contiguous
+    x = x.reshape(bs, heads, length, -1)
+ return x
+
+
+class PerceiverAttention(nn.Module):
+
+ def __init__(self, *, dim, dim_head=64, heads=8):
+ super().__init__()
+ self.scale = dim_head**-0.5
+ self.dim_head = dim_head
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+ def forward(self, x, latents):
+ """
+ Args:
+ x (torch.Tensor): image features
+ shape (b, n1, D)
+ latent (torch.Tensor): latent features
+ shape (b, n2, D)
+ """
+ x = self.norm1(x)
+ latents = self.norm2(latents)
+
+ b, l, _ = latents.shape
+
+ q = self.to_q(latents)
+ kv_input = torch.cat((x, latents), dim=-2)
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+ q = reshape_tensor(q, self.heads)
+ k = reshape_tensor(k, self.heads)
+ v = reshape_tensor(v, self.heads)
+
+ # attention
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+ weight = (q * scale) @ (k * scale).transpose(
+ -2, -1) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ out = weight @ v
+
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+ return self.to_out(out)
+
+
+class Resampler(nn.Module):
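+    """Perceiver-style resampler: projects input tokens to `dim`, cross-attends a
+    fixed set of learned latent queries over them, and returns `num_queries` tokens
+    of size `output_dim` (one query group per frame when `video_length` is set).
+
+    Minimal illustrative usage (shapes chosen only for illustration):
+
+        resampler = Resampler(dim=1024, num_queries=16, embedding_dim=1280, output_dim=1024)
+        feats = torch.randn(2, 257, 1280)   # e.g. ViT patch tokens
+        out = resampler(feats)              # -> (2, 16, 1024)
+    """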
+
+ def __init__(
+ self,
+ dim=1024,
+ depth=8,
+ dim_head=64,
+ heads=16,
+ num_queries=8,
+ embedding_dim=768,
+ output_dim=1024,
+ ff_mult=4,
+        video_length=None,  # if set, use the frame-wise version (one query group per frame)
+ ):
+ super().__init__()
+ ## queries for a single frame / image
+ self.num_queries = num_queries
+ self.video_length = video_length
+
+ ## queries for each frame
+ if video_length is not None:
+ num_queries = num_queries * video_length
+
+ self.latents = nn.Parameter(
+ torch.randn(1, num_queries, dim) / dim**0.5)
+ self.proj_in = nn.Linear(embedding_dim, dim)
+ self.proj_out = nn.Linear(dim, output_dim)
+ self.norm_out = nn.LayerNorm(output_dim)
+
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList([
+ PerceiverAttention(dim=dim, dim_head=dim_head,
+ heads=heads),
+ FeedForward(dim=dim, mult=ff_mult),
+ ]))
+
+ def forward(self, x):
+ latents = self.latents.repeat(x.size(0), 1, 1)
+ x = self.proj_in(x)
+
+ for attn, ff in self.layers:
+ latents = attn(x, latents) + latents
+ latents = ff(latents) + latents
+
+ latents = self.proj_out(latents)
+ latents = self.norm_out(latents)
+
+ return latents
diff --git a/src/unifolm_wma/modules/networks/ae_modules.py b/src/unifolm_wma/modules/networks/ae_modules.py
new file mode 100644
index 0000000..2ec124d
--- /dev/null
+++ b/src/unifolm_wma/modules/networks/ae_modules.py
@@ -0,0 +1,1005 @@
+# pytorch_diffusion + derived encoder decoder
+import math
+import torch
+import numpy as np
+import torch.nn as nn
+
+from einops import rearrange
+from unifolm_wma.modules.attention import LinearAttention
+from unifolm_wma.utils.utils import instantiate_from_config
+
+
+def nonlinearity(x):
+ # swish
+ return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+ return torch.nn.GroupNorm(num_groups=num_groups,
+ num_channels=in_channels,
+ eps=1e-6,
+ affine=True)
+
+
+class LinAttnBlock(LinearAttention):
+ """to match AttnBlock usage"""
+
+ def __init__(self, in_channels):
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
+
+
+class AttnBlock(nn.Module):
+
+ def __init__(self, in_channels):
+ super().__init__()
+ self.in_channels = in_channels
+
+ self.norm = Normalize(in_channels)
+ self.q = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.k = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.v = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+ self.proj_out = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x):
+ h_ = x
+ h_ = self.norm(h_)
+ q = self.q(h_)
+ k = self.k(h_)
+ v = self.v(h_)
+
+ # compute attention
+ b, c, h, w = q.shape
+ q = q.reshape(b, c, h * w) # bcl
+ q = q.permute(0, 2, 1) # bcl -> blc l=hw
+ k = k.reshape(b, c, h * w) # bcl
+
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
+ w_ = w_ * (int(c)**(-0.5))
+ w_ = torch.nn.functional.softmax(w_, dim=2)
+
+ # attend to values
+ v = v.reshape(b, c, h * w)
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
+ h_ = torch.bmm(
+ v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+ h_ = h_.reshape(b, c, h, w)
+
+ h_ = self.proj_out(h_)
+
+ return x + h_
+
+
+def make_attn(in_channels, attn_type="vanilla"):
+ assert attn_type in ["vanilla", "linear",
+ "none"], f'attn_type {attn_type} unknown'
+ #print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+ if attn_type == "vanilla":
+ return AttnBlock(in_channels)
+ elif attn_type == "none":
+ return nn.Identity(in_channels)
+ else:
+ return LinAttnBlock(in_channels)
+
+
+class Downsample(nn.Module):
+
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ self.in_channels = in_channels
+ if self.with_conv:
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=2,
+ padding=0)
+
+ def forward(self, x):
+ if self.with_conv:
+ pad = (0, 1, 0, 1)
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+ x = self.conv(x)
+ else:
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+ return x
+
+
+class Upsample(nn.Module):
+
+ def __init__(self, in_channels, with_conv):
+ super().__init__()
+ self.with_conv = with_conv
+ self.in_channels = in_channels
+ if self.with_conv:
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ x = torch.nn.functional.interpolate(x,
+ scale_factor=2.0,
+ mode="nearest")
+ if self.with_conv:
+ x = self.conv(x)
+ return x
+
+
+def get_timestep_embedding(timesteps, embedding_dim):
+    """
+    Build sinusoidal timestep embeddings.
+    This matches the implementation in Denoising Diffusion Probabilistic Models
+    (originally from Fairseq / tensor2tensor), but differs slightly from the
+    description in Section 3.5 of "Attention Is All You Need".
+    """
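+    # Frequencies are geometrically spaced, freq_i = 10000^(-i / (half_dim - 1)); each
+    # timestep t maps to [sin(t*freq_0), ..., sin(t*freq_{half-1}), cos(t*freq_0), ...].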
+ assert len(timesteps.shape) == 1
+
+ half_dim = embedding_dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
+ emb = emb.to(device=timesteps.device)
+ emb = timesteps.float()[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+ return emb
+
+
+class ResnetBlock(nn.Module):
+
+ def __init__(self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout,
+ temb_channels=512):
+ super().__init__()
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+
+ self.norm1 = Normalize(in_channels)
+ self.conv1 = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if temb_channels > 0:
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+ self.norm2 = Normalize(out_channels)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = torch.nn.Conv2d(out_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ else:
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0)
+
+ def forward(self, x, temb):
+ h = x
+ h = self.norm1(h)
+ h = nonlinearity(h)
+ h = self.conv1(h)
+
+ if temb is not None:
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+ h = self.norm2(h)
+ h = nonlinearity(h)
+ h = self.dropout(h)
+ h = self.conv2(h)
+
+ if self.in_channels != self.out_channels:
+ if self.use_conv_shortcut:
+ x = self.conv_shortcut(x)
+ else:
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class Model(nn.Module):
+
+ def __init__(self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ use_timestep=True,
+ use_linear_attn=False,
+ attn_type="vanilla"):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = self.ch * 4
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ self.use_timestep = use_timestep
+ if self.use_timestep:
+ # timestep embedding
+ self.temb = nn.Module()
+ self.temb.dense = nn.ModuleList([
+ torch.nn.Linear(self.ch, self.temb_ch),
+ torch.nn.Linear(self.temb_ch, self.temb_ch),
+ ])
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1, ) + tuple(ch_mult)
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ skip_in = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ if i_block == self.num_res_blocks:
+ skip_in = ch * in_ch_mult[i_level]
+ block.append(
+ ResnetBlock(in_channels=block_in + skip_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, t=None, context=None):
+ #assert x.shape[2] == x.shape[3] == self.resolution
+ if context is not None:
+ # assume aligned context, cat along channel axis
+ x = torch.cat((x, context), dim=1)
+ if self.use_timestep:
+ # timestep embedding
+ assert t is not None
+ temb = get_timestep_embedding(t, self.ch)
+ temb = self.temb.dense[0](temb)
+ temb = nonlinearity(temb)
+ temb = self.temb.dense[1](temb)
+ else:
+ temb = None
+
+ # downsampling
+ hs = [self.conv_in(x)]
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ hs.append(self.down[i_level].downsample(hs[-1]))
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()],
+ dim=1), temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+ def get_last_layer(self):
+ return self.conv_out.weight
+
+
+class Encoder(nn.Module):
+
+ def __init__(self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ double_z=True,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ **ignore_kwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+
+ # downsampling
+ self.conv_in = torch.nn.Conv2d(in_channels,
+ self.ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ curr_res = resolution
+ in_ch_mult = (1, ) + tuple(ch_mult)
+ self.in_ch_mult = in_ch_mult
+ self.down = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_in = ch * in_ch_mult[i_level]
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks):
+ block.append(
+ ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ down = nn.Module()
+ down.block = block
+ down.attn = attn
+ if i_level != self.num_resolutions - 1:
+ down.downsample = Downsample(block_in, resamp_with_conv)
+ curr_res = curr_res // 2
+ self.down.append(down)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ 2 *
+ z_channels if double_z else z_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # timestep embedding
+ temb = None
+
+ # print(f'encoder-input={x.shape}')
+ # downsampling
+ hs = [self.conv_in(x)]
+ # print(f'encoder-conv in feat={hs[0].shape}')
+ for i_level in range(self.num_resolutions):
+ for i_block in range(self.num_res_blocks):
+ h = self.down[i_level].block[i_block](hs[-1], temb)
+ # print(f'encoder-down feat={h.shape}')
+ if len(self.down[i_level].attn) > 0:
+ h = self.down[i_level].attn[i_block](h)
+ hs.append(h)
+ if i_level != self.num_resolutions - 1:
+ # print(f'encoder-downsample (input)={hs[-1].shape}')
+ hs.append(self.down[i_level].downsample(hs[-1]))
+ # print(f'encoder-downsample (output)={hs[-1].shape}')
+
+ # middle
+ h = hs[-1]
+ h = self.mid.block_1(h, temb)
+ # print(f'encoder-mid1 feat={h.shape}')
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+ # print(f'encoder-mid2 feat={h.shape}')
+
+ # end
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ # print(f'end feat={h.shape}')
+ return h
+
+
+class Decoder(nn.Module):
+
+ def __init__(self,
+ *,
+ ch,
+ out_ch,
+ ch_mult=(1, 2, 4, 8),
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ in_channels,
+ resolution,
+ z_channels,
+ give_pre_end=False,
+ tanh_out=False,
+ use_linear_attn=False,
+ attn_type="vanilla",
+ **ignorekwargs):
+ super().__init__()
+ if use_linear_attn: attn_type = "linear"
+ self.ch = ch
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ self.resolution = resolution
+ self.in_channels = in_channels
+ self.give_pre_end = give_pre_end
+ self.tanh_out = tanh_out
+
+ # compute in_ch_mult, block_in and curr_res at lowest res
+ in_ch_mult = (1, ) + tuple(ch_mult)
+ block_in = ch * ch_mult[self.num_resolutions - 1]
+ curr_res = resolution // 2**(self.num_resolutions - 1)
+ self.z_shape = (1, z_channels, curr_res, curr_res)
+ print("AE working on z of shape {} = {} dimensions.".format(
+ self.z_shape, np.prod(self.z_shape)))
+
+ # z to block_in
+ self.conv_in = torch.nn.Conv2d(z_channels,
+ block_in,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ # middle
+ self.mid = nn.Module()
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+ self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
+ out_channels=block_in,
+ temb_channels=self.temb_ch,
+ dropout=dropout)
+
+ # upsampling
+ self.up = nn.ModuleList()
+ for i_level in reversed(range(self.num_resolutions)):
+ block = nn.ModuleList()
+ attn = nn.ModuleList()
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ block.append(
+ ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ if curr_res in attn_resolutions:
+ attn.append(make_attn(block_in, attn_type=attn_type))
+ up = nn.Module()
+ up.block = block
+ up.attn = attn
+ if i_level != 0:
+ up.upsample = Upsample(block_in, resamp_with_conv)
+ curr_res = curr_res * 2
+ self.up.insert(0, up) # prepend to get consistent order
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_ch,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, z):
+ #assert z.shape[1:] == self.z_shape[1:]
+ self.last_z_shape = z.shape
+
+ # print(f'decoder-input={z.shape}')
+ # timestep embedding
+ temb = None
+
+ # z to block_in
+ h = self.conv_in(z)
+ # print(f'decoder-conv in feat={h.shape}')
+
+ # middle
+ h = self.mid.block_1(h, temb)
+ h = self.mid.attn_1(h)
+ h = self.mid.block_2(h, temb)
+ # print(f'decoder-mid feat={h.shape}')
+
+ # upsampling
+ for i_level in reversed(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.up[i_level].block[i_block](h, temb)
+ if len(self.up[i_level].attn) > 0:
+ h = self.up[i_level].attn[i_block](h)
+ # print(f'decoder-up feat={h.shape}')
+ if i_level != 0:
+ h = self.up[i_level].upsample(h)
+ # print(f'decoder-upsample feat={h.shape}')
+
+ # end
+ if self.give_pre_end:
+ return h
+
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ # print(f'decoder-conv_out feat={h.shape}')
+ if self.tanh_out:
+ h = torch.tanh(h)
+ return h
+
+
+class SimpleDecoder(nn.Module):
+
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__()
+ self.model = nn.ModuleList([
+ nn.Conv2d(in_channels, in_channels, 1),
+ ResnetBlock(in_channels=in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0,
+ dropout=0.0),
+ ResnetBlock(in_channels=2 * in_channels,
+ out_channels=4 * in_channels,
+ temb_channels=0,
+ dropout=0.0),
+ ResnetBlock(in_channels=4 * in_channels,
+ out_channels=2 * in_channels,
+ temb_channels=0,
+ dropout=0.0),
+ nn.Conv2d(2 * in_channels, in_channels, 1),
+ Upsample(in_channels, with_conv=True)
+ ])
+ # end
+ self.norm_out = Normalize(in_channels)
+ self.conv_out = torch.nn.Conv2d(in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ for i, layer in enumerate(self.model):
+ if i in [1, 2, 3]:
+ x = layer(x, None)
+ else:
+ x = layer(x)
+
+ h = self.norm_out(x)
+ h = nonlinearity(h)
+ x = self.conv_out(h)
+ return x
+
+
+class UpsampleDecoder(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ ch,
+ num_res_blocks,
+ resolution,
+ ch_mult=(2, 2),
+ dropout=0.0):
+ super().__init__()
+ # upsampling
+ self.temb_ch = 0
+ self.num_resolutions = len(ch_mult)
+ self.num_res_blocks = num_res_blocks
+ block_in = in_channels
+ curr_res = resolution // 2**(self.num_resolutions - 1)
+ self.res_blocks = nn.ModuleList()
+ self.upsample_blocks = nn.ModuleList()
+ for i_level in range(self.num_resolutions):
+ res_block = []
+ block_out = ch * ch_mult[i_level]
+ for i_block in range(self.num_res_blocks + 1):
+ res_block.append(
+ ResnetBlock(in_channels=block_in,
+ out_channels=block_out,
+ temb_channels=self.temb_ch,
+ dropout=dropout))
+ block_in = block_out
+ self.res_blocks.append(nn.ModuleList(res_block))
+ if i_level != self.num_resolutions - 1:
+ self.upsample_blocks.append(Upsample(block_in, True))
+ curr_res = curr_res * 2
+
+ # end
+ self.norm_out = Normalize(block_in)
+ self.conv_out = torch.nn.Conv2d(block_in,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x):
+ # upsampling
+ h = x
+ for k, i_level in enumerate(range(self.num_resolutions)):
+ for i_block in range(self.num_res_blocks + 1):
+ h = self.res_blocks[i_level][i_block](h, None)
+ if i_level != self.num_resolutions - 1:
+ h = self.upsample_blocks[k](h)
+ h = self.norm_out(h)
+ h = nonlinearity(h)
+ h = self.conv_out(h)
+ return h
+
+
+class LatentRescaler(nn.Module):
+
+ def __init__(self,
+ factor,
+ in_channels,
+ mid_channels,
+ out_channels,
+ depth=2):
+ super().__init__()
+ # residual block, interpolate, residual block
+ self.factor = factor
+ self.conv_in = nn.Conv2d(in_channels,
+ mid_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+ self.res_block1 = nn.ModuleList([
+ ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)
+ ])
+ self.attn = AttnBlock(mid_channels)
+ self.res_block2 = nn.ModuleList([
+ ResnetBlock(in_channels=mid_channels,
+ out_channels=mid_channels,
+ temb_channels=0,
+ dropout=0.0) for _ in range(depth)
+ ])
+
+ self.conv_out = nn.Conv2d(
+ mid_channels,
+ out_channels,
+ kernel_size=1,
+ )
+
+ def forward(self, x):
+ x = self.conv_in(x)
+ for block in self.res_block1:
+ x = block(x, None)
+ x = torch.nn.functional.interpolate(
+ x,
+ size=(int(round(x.shape[2] * self.factor)),
+ int(round(x.shape[3] * self.factor))))
+ x = self.attn(x)
+ for block in self.res_block2:
+ x = block(x, None)
+ x = self.conv_out(x)
+ return x
+
+
+class MergedRescaleEncoder(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ ch,
+ resolution,
+ out_ch,
+ num_res_blocks,
+ attn_resolutions,
+ dropout=0.0,
+ resamp_with_conv=True,
+ ch_mult=(1, 2, 4, 8),
+ rescale_factor=1.0,
+ rescale_module_depth=1):
+ super().__init__()
+ intermediate_chn = ch * ch_mult[-1]
+ self.encoder = Encoder(in_channels=in_channels,
+ num_res_blocks=num_res_blocks,
+ ch=ch,
+ ch_mult=ch_mult,
+ z_channels=intermediate_chn,
+ double_z=False,
+ resolution=resolution,
+ attn_resolutions=attn_resolutions,
+ dropout=dropout,
+ resamp_with_conv=resamp_with_conv,
+ out_ch=None)
+ self.rescaler = LatentRescaler(factor=rescale_factor,
+ in_channels=intermediate_chn,
+ mid_channels=intermediate_chn,
+ out_channels=out_ch,
+ depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.encoder(x)
+ x = self.rescaler(x)
+ return x
+
+
+class MergedRescaleDecoder(nn.Module):
+
+ def __init__(self,
+ z_channels,
+ out_ch,
+ resolution,
+ num_res_blocks,
+ attn_resolutions,
+ ch,
+ ch_mult=(1, 2, 4, 8),
+ dropout=0.0,
+ resamp_with_conv=True,
+ rescale_factor=1.0,
+ rescale_module_depth=1):
+ super().__init__()
+ tmp_chn = z_channels * ch_mult[-1]
+ self.decoder = Decoder(out_ch=out_ch,
+ z_channels=tmp_chn,
+ attn_resolutions=attn_resolutions,
+ dropout=dropout,
+ resamp_with_conv=resamp_with_conv,
+ in_channels=None,
+ num_res_blocks=num_res_blocks,
+ ch_mult=ch_mult,
+ resolution=resolution,
+ ch=ch)
+ self.rescaler = LatentRescaler(factor=rescale_factor,
+ in_channels=z_channels,
+ mid_channels=tmp_chn,
+ out_channels=tmp_chn,
+ depth=rescale_module_depth)
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Upsampler(nn.Module):
+
+ def __init__(self,
+ in_size,
+ out_size,
+ in_channels,
+ out_channels,
+ ch_mult=2):
+ super().__init__()
+ assert out_size >= in_size
+ num_blocks = int(np.log2(out_size // in_size)) + 1
+ factor_up = 1. + (out_size % in_size)
+ print(
+ f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
+ )
+ self.rescaler = LatentRescaler(factor=factor_up,
+ in_channels=in_channels,
+ mid_channels=2 * in_channels,
+ out_channels=in_channels)
+ self.decoder = Decoder(out_ch=out_channels,
+ resolution=out_size,
+ z_channels=in_channels,
+ num_res_blocks=2,
+ attn_resolutions=[],
+ in_channels=None,
+ ch=in_channels,
+ ch_mult=[ch_mult for _ in range(num_blocks)])
+
+ def forward(self, x):
+ x = self.rescaler(x)
+ x = self.decoder(x)
+ return x
+
+
+class Resize(nn.Module):
+
+ def __init__(self, in_channels=None, learned=False, mode="bilinear"):
+ super().__init__()
+ self.with_conv = learned
+ self.mode = mode
+ if self.with_conv:
+ print(
+                f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode"
+ )
+ raise NotImplementedError()
+ assert in_channels is not None
+ # no asymmetric padding in torch conv, must do it ourselves
+ self.conv = torch.nn.Conv2d(in_channels,
+ in_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1)
+
+ def forward(self, x, scale_factor=1.0):
+ if scale_factor == 1.0:
+ return x
+ else:
+ x = torch.nn.functional.interpolate(x,
+ mode=self.mode,
+ align_corners=False,
+ scale_factor=scale_factor)
+ return x
+
+
+class FirstStagePostProcessor(nn.Module):
+
+ def __init__(self,
+ ch_mult: list,
+ in_channels,
+ pretrained_model: nn.Module = None,
+ reshape=False,
+ n_channels=None,
+ dropout=0.,
+ pretrained_config=None):
+ super().__init__()
+ if pretrained_config is None:
+ assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.pretrained_model = pretrained_model
+ else:
+ assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
+ self.instantiate_pretrained(pretrained_config)
+
+ self.do_reshape = reshape
+
+ if n_channels is None:
+ n_channels = self.pretrained_model.encoder.ch
+
+ self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
+ self.proj = nn.Conv2d(in_channels,
+ n_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ blocks = []
+ downs = []
+ ch_in = n_channels
+ for m in ch_mult:
+ blocks.append(
+ ResnetBlock(in_channels=ch_in,
+ out_channels=m * n_channels,
+ dropout=dropout))
+ ch_in = m * n_channels
+ downs.append(Downsample(ch_in, with_conv=False))
+
+ self.model = nn.ModuleList(blocks)
+ self.downsampler = nn.ModuleList(downs)
+
+ def instantiate_pretrained(self, config):
+ model = instantiate_from_config(config)
+ self.pretrained_model = model.eval()
+ # self.pretrained_model.train = False
+ for param in self.pretrained_model.parameters():
+ param.requires_grad = False
+
+ @torch.no_grad()
+ def encode_with_pretrained(self, x):
+ c = self.pretrained_model.encode(x)
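+        # NOTE: assumes DiagonalGaussianDistribution is available in this module's
+        # namespace (it is not imported above), so that a Gaussian posterior returned
+        # by encode() can be reduced to its mode.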
+ if isinstance(c, DiagonalGaussianDistribution):
+ c = c.mode()
+ return c
+
+ def forward(self, x):
+ z_fs = self.encode_with_pretrained(x)
+ z = self.proj_norm(z_fs)
+ z = self.proj(z)
+ z = nonlinearity(z)
+
+ for submodel, downmodel in zip(self.model, self.downsampler):
+ z = submodel(z, temb=None)
+ z = downmodel(z)
+
+ if self.do_reshape:
+ z = rearrange(z, 'b c h w -> b (h w) c')
+ return z
diff --git a/src/unifolm_wma/modules/networks/wma_model.py b/src/unifolm_wma/modules/networks/wma_model.py
new file mode 100644
index 0000000..e1b4838
--- /dev/null
+++ b/src/unifolm_wma/modules/networks/wma_model.py
@@ -0,0 +1,848 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch import Tensor
+from functools import partial
+from abc import abstractmethod
+from einops import rearrange
+from omegaconf import OmegaConf
+from typing import Optional, Sequence, Any, Tuple, Union, List, Dict
+from collections.abc import Mapping, Iterable, Callable
+
+from unifolm_wma.utils.diffusion import timestep_embedding
+from unifolm_wma.utils.common import checkpoint
+from unifolm_wma.utils.basics import (zero_module, conv_nd, linear,
+ avg_pool_nd, normalization)
+from unifolm_wma.modules.attention import SpatialTransformer, TemporalTransformer
+from unifolm_wma.utils.utils import instantiate_from_config
+
+
+class TimestepBlock(nn.Module):
+ """
+ Any module where forward() takes timestep embeddings as a second argument.
+ """
+
+ @abstractmethod
+ def forward(self, x, emb):
+ """
+ Apply the module to `x` given `emb` timestep embeddings.
+ """
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ """
+ A sequential module that passes timestep embeddings to the children that
+ support it as an extra input.
+ """
+
+ def forward(self, x, emb, context=None, batch_size=None):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb, batch_size=batch_size)
+ elif isinstance(layer, SpatialTransformer):
+ x = layer(x, context)
+ elif isinstance(layer, TemporalTransformer):
+ x = rearrange(x, '(b f) c h w -> b c f h w', b=batch_size)
+ x = layer(x, context)
+ x = rearrange(x, 'b c f h w -> (b f) c h w')
+ else:
+ x = layer(x)
+ return x
+
+
+class Downsample(nn.Module):
+ """
+ A downsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ downsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self,
+ channels,
+ use_conv,
+ dims=2,
+ out_channels=None,
+ padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ stride = 2 if dims != 3 else (1, 2, 2)
+ if use_conv:
+ self.op = conv_nd(dims,
+ self.channels,
+ self.out_channels,
+ 3,
+ stride=stride,
+ padding=padding)
+ else:
+ assert self.channels == self.out_channels
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ return self.op(x)
+
+
+class Upsample(nn.Module):
+ """
+ An upsampling layer with an optional convolution.
+ :param channels: channels in the inputs and outputs.
+ :param use_conv: a bool determining if a convolution is applied.
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
+ upsampling occurs in the inner-two dimensions.
+ """
+
+ def __init__(self,
+ channels,
+ use_conv,
+ dims=2,
+ out_channels=None,
+ padding=1):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.dims = dims
+ if use_conv:
+ self.conv = conv_nd(dims,
+ self.channels,
+ self.out_channels,
+ 3,
+ padding=padding)
+
+ def forward(self, x):
+ assert x.shape[1] == self.channels
+ if self.dims == 3:
+ x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
+ mode='nearest')
+ else:
+ x = F.interpolate(x, scale_factor=2, mode='nearest')
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class ResBlock(TimestepBlock):
+ """
+ A residual block that can optionally change the number of channels.
+ :param channels: the number of input channels.
+ :param emb_channels: the number of timestep embedding channels.
+ :param dropout: the rate of dropout.
+ :param out_channels: if specified, the number of out channels.
+ :param use_conv: if True and out_channels is specified, use a spatial
+ convolution instead of a smaller 1x1 convolution to change the
+ channels in the skip connection.
+ :param dims: determines if the signal is 1D, 2D, or 3D.
+ :param up: if True, use this block for upsampling.
+ :param down: if True, use this block for downsampling.
+ :param use_temporal_conv: if True, apply a temporal convolution block after the spatial path.
+ :param tempspatial_aware: if True, the temporal convolutions are also spatially aware.
+ """
+
+ def __init__(self,
+ channels,
+ emb_channels,
+ dropout,
+ out_channels=None,
+ use_scale_shift_norm=False,
+ dims=2,
+ use_checkpoint=False,
+ use_conv=False,
+ up=False,
+ down=False,
+ use_temporal_conv=False,
+ tempspatial_aware=False):
+ super().__init__()
+ self.channels = channels
+ self.emb_channels = emb_channels
+ self.dropout = dropout
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_checkpoint = use_checkpoint
+ self.use_scale_shift_norm = use_scale_shift_norm
+ self.use_temporal_conv = use_temporal_conv
+
+ self.in_layers = nn.Sequential(
+ normalization(channels),
+ nn.SiLU(),
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
+ )
+
+ self.updown = up or down
+
+ if up:
+ self.h_upd = Upsample(channels, False, dims)
+ self.x_upd = Upsample(channels, False, dims)
+ elif down:
+ self.h_upd = Downsample(channels, False, dims)
+ self.x_upd = Downsample(channels, False, dims)
+ else:
+ self.h_upd = self.x_upd = nn.Identity()
+
+ self.emb_layers = nn.Sequential(
+ nn.SiLU(),
+ nn.Linear(
+ emb_channels,
+ 2 * self.out_channels
+ if use_scale_shift_norm else self.out_channels,
+ ),
+ )
+ self.out_layers = nn.Sequential(
+ normalization(self.out_channels),
+ nn.SiLU(),
+ nn.Dropout(p=dropout),
+ zero_module(
+ nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)),
+ )
+
+ if self.out_channels == channels:
+ self.skip_connection = nn.Identity()
+ elif use_conv:
+ self.skip_connection = conv_nd(dims,
+ channels,
+ self.out_channels,
+ 3,
+ padding=1)
+ else:
+ self.skip_connection = conv_nd(dims, channels, self.out_channels,
+ 1)
+
+ if self.use_temporal_conv:
+ self.temopral_conv = TemporalConvBlock(
+ self.out_channels,
+ self.out_channels,
+ dropout=0.1,
+ spatial_aware=tempspatial_aware)
+
+ def forward(self, x, emb, batch_size=None):
+ """
+ Apply the block to a Tensor, conditioned on a timestep embedding.
+ :param x: an [N x C x ...] Tensor of features.
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
+ :return: an [N x C x ...] Tensor of outputs.
+ """
+ input_tuple = (x, emb)
+ if batch_size:
+ forward_batchsize = partial(self._forward, batch_size=batch_size)
+ return checkpoint(forward_batchsize, input_tuple,
+ self.parameters(), self.use_checkpoint)
+ return checkpoint(self._forward, input_tuple, self.parameters(),
+ self.use_checkpoint)
+
+ def _forward(self, x, emb, batch_size=None):
+ if self.updown:
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
+ h = in_rest(x)
+ h = self.h_upd(h)
+ x = self.x_upd(x)
+ h = in_conv(h)
+ else:
+ h = self.in_layers(x)
+ emb_out = self.emb_layers(emb).type(h.dtype)
+ while len(emb_out.shape) < len(h.shape):
+ emb_out = emb_out[..., None]
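+ # Inject the timestep embedding either as a FiLM-style scale/shift on the normalized features or by plain addition.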
+ if self.use_scale_shift_norm:
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
+ h = out_norm(h) * (1 + scale) + shift
+ h = out_rest(h)
+ else:
+ h = h + emb_out
+ h = self.out_layers(h)
+ h = self.skip_connection(x) + h
+
+ if self.use_temporal_conv and batch_size:
+ h = rearrange(h, '(b t) c h w -> b c t h w', b=batch_size)
+ h = self.temopral_conv(h)
+ h = rearrange(h, 'b c t h w -> (b t) c h w')
+ return h
+
+
+class TemporalConvBlock(nn.Module):
+ """
+ Adapted from modelscope: https://github.com/modelscope/modelscope/blob/master/modelscope/models/multi_modal/video_synthesis/unet_sd.py
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels=None,
+ dropout=0.0,
+ spatial_aware=False):
+ super(TemporalConvBlock, self).__init__()
+ if out_channels is None:
+ out_channels = in_channels
+ self.in_channels = in_channels
+ self.out_channels = out_channels
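+ # Kernel/padding tuples are ordered (t, h, w): size 3 along time; with spatial_aware, conv1/conv3 also span height and conv2/conv4 also span width.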
+ th_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 3, 1)
+ th_padding_shape = (1, 0, 0) if not spatial_aware else (1, 1, 0)
+ tw_kernel_shape = (3, 1, 1) if not spatial_aware else (3, 1, 3)
+ tw_padding_shape = (1, 0, 0) if not spatial_aware else (1, 0, 1)
+
+ # conv layers
+ self.conv1 = nn.Sequential(
+ nn.GroupNorm(32, in_channels), nn.SiLU(),
+ nn.Conv3d(in_channels,
+ out_channels,
+ th_kernel_shape,
+ padding=th_padding_shape))
+ self.conv2 = nn.Sequential(
+ nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
+ nn.Conv3d(out_channels,
+ in_channels,
+ tw_kernel_shape,
+ padding=tw_padding_shape))
+ self.conv3 = nn.Sequential(
+ nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
+ nn.Conv3d(out_channels,
+ in_channels,
+ th_kernel_shape,
+ padding=th_padding_shape))
+ self.conv4 = nn.Sequential(
+ nn.GroupNorm(32, out_channels), nn.SiLU(), nn.Dropout(dropout),
+ nn.Conv3d(out_channels,
+ in_channels,
+ tw_kernel_shape,
+ padding=tw_padding_shape))
+
+ # Zero out the last layer's params so the conv block is an identity mapping at initialization
+ nn.init.zeros_(self.conv4[-1].weight)
+ nn.init.zeros_(self.conv4[-1].bias)
+
+ def forward(self, x):
+ identity = x
+ x = self.conv1(x)
+ x = self.conv2(x)
+ x = self.conv3(x)
+ x = self.conv4(x)
+
+ return identity + x
+
+
+class WMAModel(nn.Module):
+ """
+ The full World-Model-Action model.
+ """
+
+ def __init__(self,
+ in_channels: int,
+ model_channels: int,
+ out_channels: int,
+ num_res_blocks: int,
+ attention_resolutions: Sequence[int],
+ dropout: float = 0.0,
+ channel_mult: Sequence[int] = (1, 2, 4, 8),
+ conv_resample: bool = True,
+ dims: int = 2,
+ context_dim: int | None = None,
+ use_scale_shift_norm: bool = False,
+ resblock_updown: bool = False,
+ num_heads: int = -1,
+ num_head_channels: int = -1,
+ transformer_depth: int = 1,
+ use_linear: bool = False,
+ use_checkpoint: bool = False,
+ temporal_conv: bool = False,
+ tempspatial_aware: bool = False,
+ temporal_attention: bool = True,
+ use_relative_position: bool = True,
+ use_causal_attention: bool = False,
+ temporal_length: int | None = None,
+ use_fp16: bool = False,
+ addition_attention: bool = False,
+ temporal_selfatt_only: bool = True,
+ image_cross_attention: bool = False,
+ cross_attention_scale_learnable: bool = False,
+ default_fs: int = 4,
+ fs_condition: bool = False,
+ n_obs_steps: int = 1,
+ num_stem_token: int = 1,
+ unet_head_config: OmegaConf | None = None,
+ stem_process_config: OmegaConf | None = None,
+ base_model_gen_only: bool = False):
+ """
+ Initialize the World-Model-Action network.
+
+ Args:
+ in_channels: Number of input channels to the backbone.
+ model_channels: Base channel width for the UNet/backbone.
+ out_channels: Number of output channels.
+ num_res_blocks: Number of residual blocks per resolution stage.
+ attention_resolutions: Resolutions at which to enable attention.
+ dropout: Dropout probability used inside residual/attention blocks.
+ channel_mult: Multipliers for channels at each resolution level.
+ conv_resample: If True, use convolutional resampling for up/down sampling.
+ dims: Spatial dimensionality of the backbone (1/2/3).
+ context_dim: Optional context embedding dimension (for cross-attention).
+ use_scale_shift_norm: Enable scale-shift (FiLM-style) normalization in blocks.
+ resblock_updown: Use residual blocks for up/down sampling (instead of plain conv).
+ num_heads: Number of attention heads (if >= 0). If -1, derive from num_head_channels.
+ num_head_channels: Channels per attention head (if >= 0). If -1, derive from num_heads.
+ transformer_depth: Number of transformer/attention blocks per stage.
+ use_linear: Use linear attention variants where applicable.
+ use_checkpoint: Enable gradient checkpointing in blocks to save memory.
+ temporal_conv: Include temporal convolution along the time dimension.
+ tempspatial_aware: If True, use time–space aware blocks.
+ temporal_attention: Enable temporal self-attention.
+ use_relative_position: Use relative position encodings in attention.
+ use_causal_attention: Use causal (uni-directional) attention along time.
+ temporal_length: Optional maximum temporal length expected by the model.
+ use_fp16: Enable half-precision layers/normalization where supported.
+ addition_attention: Add auxiliary attention modules.
+ temporal_selfatt_only: Restrict attention to temporal-only (no spatial) if True.
+ image_cross_attention: Enable cross-attention with image embeddings.
+ cross_attention_scale_learnable: Make cross-attention scaling a learnable parameter.
+ default_fs: Default frame-stride / fps.
+ fs_condition: If True, condition on frame-stride/fps features.
+ n_obs_steps: Number of observed steps used in conditioning heads.
+ num_stem_token: Number of stem tokens for action tokenization.
+ unet_head_config: OmegaConf for UNet heads (e.g., action/state heads).
+ stem_process_config: OmegaConf for stem/preprocessor module.
+ base_model_gen_only: If True, run generation with the base video model only, without producing action and state outputs.
+ """
+
+ super(WMAModel, self).__init__()
+ if num_heads == -1:
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
+ if num_head_channels == -1:
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
+
+ self.in_channels = in_channels
+ self.model_channels = model_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_resolutions = attention_resolutions
+ self.dropout = dropout
+ self.channel_mult = channel_mult
+ self.conv_resample = conv_resample
+ self.temporal_attention = temporal_attention
+ time_embed_dim = model_channels * 4
+ self.use_checkpoint = use_checkpoint
+ self.dtype = torch.float16 if use_fp16 else torch.float32
+ temporal_self_att_only = True
+ self.addition_attention = addition_attention
+ self.temporal_length = temporal_length
+ self.image_cross_attention = image_cross_attention
+ self.cross_attention_scale_learnable = cross_attention_scale_learnable
+ self.default_fs = default_fs
+ self.fs_condition = fs_condition
+ self.n_obs_steps = n_obs_steps
+ self.num_stem_token = num_stem_token
+ self.base_model_gen_only = base_model_gen_only
+
+ # Time embedding blocks
+ self.time_embed = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ if fs_condition:
+ self.fps_embedding = nn.Sequential(
+ linear(model_channels, time_embed_dim),
+ nn.SiLU(),
+ linear(time_embed_dim, time_embed_dim),
+ )
+ nn.init.zeros_(self.fps_embedding[-1].weight)
+ nn.init.zeros_(self.fps_embedding[-1].bias)
+ # Input Block
+ self.input_blocks = nn.ModuleList([
+ TimestepEmbedSequential(
+ conv_nd(dims, in_channels, model_channels, 3, padding=1))
+ ])
+ if self.addition_attention:
+ self.init_attn = TimestepEmbedSequential(
+ TemporalTransformer(model_channels,
+ n_heads=8,
+ d_head=num_head_channels,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ use_checkpoint=use_checkpoint,
+ only_self_att=temporal_selfatt_only,
+ causal_attention=False,
+ relative_position=use_relative_position,
+ temporal_length=temporal_length))
+
+ input_block_chans = [model_channels]
+ ch = model_channels
+ ds = 1
+ for level, mult in enumerate(channel_mult):
+ for _ in range(num_res_blocks):
+ layers = [
+ ResBlock(ch,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ tempspatial_aware=tempspatial_aware,
+ use_temporal_conv=temporal_conv)
+ ]
+ ch = mult * model_channels
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+
+ layers.append(
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ use_linear=use_linear,
+ use_checkpoint=use_checkpoint,
+ disable_self_attn=False,
+ video_length=temporal_length,
+ agent_state_context_len=self.n_obs_steps,
+ agent_action_context_len=self.temporal_length *
+ num_stem_token,
+ image_cross_attention=self.image_cross_attention,
+ cross_attention_scale_learnable=self.
+ cross_attention_scale_learnable,
+ ))
+ if self.temporal_attention:
+ layers.append(
+ TemporalTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ use_linear=use_linear,
+ use_checkpoint=use_checkpoint,
+ only_self_att=temporal_self_att_only,
+ causal_attention=use_causal_attention,
+ relative_position=use_relative_position,
+ temporal_length=temporal_length))
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
+ input_block_chans.append(ch)
+ if level != len(channel_mult) - 1:
+ out_ch = ch
+ self.input_blocks.append(
+ TimestepEmbedSequential(
+ ResBlock(ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ down=True)
+ if resblock_updown else Downsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch))
+ )
+ ch = out_ch
+ input_block_chans.append(ch)
+ ds *= 2
+
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ layers = [
+ ResBlock(ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ tempspatial_aware=tempspatial_aware,
+ use_temporal_conv=temporal_conv),
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ use_linear=use_linear,
+ use_checkpoint=use_checkpoint,
+ disable_self_attn=False,
+ video_length=temporal_length,
+ agent_state_context_len=self.n_obs_steps,
+ agent_action_context_len=self.temporal_length * num_stem_token,
+ image_cross_attention=self.image_cross_attention,
+ cross_attention_scale_learnable=self.
+ cross_attention_scale_learnable)
+ ]
+ if self.temporal_attention:
+ layers.append(
+ TemporalTransformer(ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ use_linear=use_linear,
+ use_checkpoint=use_checkpoint,
+ only_self_att=temporal_self_att_only,
+ causal_attention=use_causal_attention,
+ relative_position=use_relative_position,
+ temporal_length=temporal_length))
+ layers.append(
+ ResBlock(ch,
+ time_embed_dim,
+ dropout,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ tempspatial_aware=tempspatial_aware,
+ use_temporal_conv=temporal_conv))
+
+ # Middle Block
+ self.middle_block = TimestepEmbedSequential(*layers)
+
+ # Output Block
+ self.output_blocks = nn.ModuleList([])
+ for level, mult in list(enumerate(channel_mult))[::-1]:
+ for i in range(num_res_blocks + 1):
+ ich = input_block_chans.pop()
+ layers = [
+ ResBlock(ch + ich,
+ time_embed_dim,
+ dropout,
+ out_channels=mult * model_channels,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ tempspatial_aware=tempspatial_aware,
+ use_temporal_conv=temporal_conv)
+ ]
+ ch = model_channels * mult
+ if ds in attention_resolutions:
+ if num_head_channels == -1:
+ dim_head = ch // num_heads
+ else:
+ num_heads = ch // num_head_channels
+ dim_head = num_head_channels
+ layers.append(
+ SpatialTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ use_linear=use_linear,
+ use_checkpoint=use_checkpoint,
+ disable_self_attn=False,
+ video_length=temporal_length,
+ agent_state_context_len=self.n_obs_steps,
+ image_cross_attention=self.image_cross_attention,
+ cross_attention_scale_learnable=self.
+ cross_attention_scale_learnable))
+ if self.temporal_attention:
+ layers.append(
+ TemporalTransformer(
+ ch,
+ num_heads,
+ dim_head,
+ depth=transformer_depth,
+ context_dim=context_dim,
+ use_linear=use_linear,
+ use_checkpoint=use_checkpoint,
+ only_self_att=temporal_self_att_only,
+ causal_attention=use_causal_attention,
+ relative_position=use_relative_position,
+ temporal_length=temporal_length))
+ if level and i == num_res_blocks:
+ out_ch = ch
+ layers.append(
+ ResBlock(ch,
+ time_embed_dim,
+ dropout,
+ out_channels=out_ch,
+ dims=dims,
+ use_checkpoint=use_checkpoint,
+ use_scale_shift_norm=use_scale_shift_norm,
+ up=True)
+ if resblock_updown else Upsample(
+ ch, conv_resample, dims=dims, out_channels=out_ch))
+ ds //= 2
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
+
+ self.out = nn.Sequential(
+ normalization(ch),
+ nn.SiLU(),
+ zero_module(
+ conv_nd(dims, model_channels, out_channels, 3, padding=1)),
+ )
+
+ # Action and state prediction unet
+ unet_head_config['params']['context_dims'] = [
+ mult * model_channels for mult in channel_mult
+ ]
+ self.action_unet = instantiate_from_config(unet_head_config)
+ self.state_unet = instantiate_from_config(unet_head_config)
+
+ # Initialize action token_projector
+ self.action_token_projector = instantiate_from_config(
+ stem_process_config)
+
+ def forward(self,
+ x: Tensor,
+ x_action: Tensor,
+ x_state: Tensor,
+ timesteps: Tensor,
+ context: Tensor | None = None,
+ context_action: Tensor | None = None,
+ features_adapter: Any = None,
+ fs: Tensor | None = None,
+ **kwargs) -> Tensor | tuple[Tensor, ...]:
+
+ """
+ Forward pass of the World-Model-Action backbone.
+
+ Args:
+ x: Input latent video tensor of shape (B, C, T, H, W).
+ x_action: action stream input.
+ x_state: state stream input.
+ timesteps: Diffusion timesteps, shape (B,) or scalar Tensor.
+ context: conditioning context for cross-attention.
+ context_action: conditioning context specific to action/state (implementation-specific).
+ features_adapter: optional sequence of adapter feature maps added to intermediate activations.
+ fs: frame-stride / fps conditioning.
+
+ Returns:
+ Tuple of Tensors (y, a_y, s_y): the video latent, action, and state predictions.
+ """
+
+ b, _, t, _, _ = x.shape
+ t_emb = timestep_embedding(timesteps,
+ self.model_channels,
+ repeat_only=False).type(x.dtype)
+ emb = self.time_embed(t_emb)
+
+ bt, l_context, _ = context.shape
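+ # Context token layout: [agent-state tokens | optional 16 action tokens | 77 text tokens | 16 image tokens per frame];
+ # the base-model-only path expects just [77 text tokens | 16 image tokens per observation step].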
+ if self.base_model_gen_only:
+ assert l_context == 77 + self.n_obs_steps * 16, ">>> ERROR Context dim 1 ..."  # NOTE: hard-coded token layout (77 text tokens + 16 image tokens per observation step)
+ else:
+ if l_context == self.n_obs_steps + 77 + t * 16:
+ context_agent_state = context[:, :self.n_obs_steps]
+ context_text = context[:, self.n_obs_steps:self.n_obs_steps +
+ 77, :]
+ context_img = context[:, self.n_obs_steps + 77:, :]
+ context_agent_state = context_agent_state.repeat_interleave(
+ repeats=t, dim=0)
+ context_text = context_text.repeat_interleave(repeats=t, dim=0)
+ context_img = rearrange(context_img,
+ 'b (t l) c -> (b t) l c',
+ t=t)
+ context = torch.cat(
+ [context_agent_state, context_text, context_img], dim=1)
+ elif l_context == self.n_obs_steps + 16 + 77 + t * 16:
+ context_agent_state = context[:, :self.n_obs_steps]
+ context_agent_action = context[:, self.
+ n_obs_steps:self.n_obs_steps +
+ 16, :]
+ context_agent_action = rearrange(
+ context_agent_action.unsqueeze(2), 'b t l d -> (b t) l d')
+ context_agent_action = self.action_token_projector(
+ context_agent_action)
+ context_agent_action = rearrange(context_agent_action,
+ '(b o) l d -> b o l d',
+ o=t)
+ context_agent_action = rearrange(context_agent_action,
+ 'b o (t l) d -> b o t l d',
+ t=t)
+ context_agent_action = context_agent_action.permute(
+ 0, 2, 1, 3, 4)
+ context_agent_action = rearrange(context_agent_action,
+ 'b t o l d -> (b t) (o l) d')
+
+ context_text = context[:, self.n_obs_steps +
+ 16:self.n_obs_steps + 16 + 77, :]
+ context_text = context_text.repeat_interleave(repeats=t, dim=0)
+
+ context_img = context[:, self.n_obs_steps + 16 + 77:, :]
+ context_img = rearrange(context_img,
+ 'b (t l) c -> (b t) l c',
+ t=t)
+ context_agent_state = context_agent_state.repeat_interleave(
+ repeats=t, dim=0)
+ context = torch.cat([
+ context_agent_state, context_agent_action, context_text,
+ context_img
+ ],
+ dim=1)
+
+ emb = emb.repeat_interleave(repeats=t, dim=0)
+
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
+
+ # Combine emb
+ if self.fs_condition:
+ if fs is None:
+ fs = torch.tensor([self.default_fs] * b,
+ dtype=torch.long,
+ device=x.device)
+ fs_emb = timestep_embedding(fs,
+ self.model_channels,
+ repeat_only=False).type(x.dtype)
+
+ fs_embed = self.fps_embedding(fs_emb)
+ fs_embed = fs_embed.repeat_interleave(repeats=t, dim=0)
+ emb = emb + fs_embed
+
+ h = x.type(self.dtype)
+ adapter_idx = 0
+ hs = []
+ hs_a = []
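+ # hs holds skip connections for the video UNet decoder; hs_a collects multi-scale features (reshaped to (b, t, c, h, w)) consumed by the action/state heads.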
+ for id, module in enumerate(self.input_blocks):
+ h = module(h, emb, context=context, batch_size=b)
+ if id == 0 and self.addition_attention:
+ h = self.init_attn(h, emb, context=context, batch_size=b)
+ # plug-in adapter features
+ if ((id + 1) % 3 == 0) and features_adapter is not None:
+ h = h + features_adapter[adapter_idx]
+ adapter_idx += 1
+ if id != 0:
+ if isinstance(module[0], Downsample):
+ hs_a.append(
+ rearrange(hs[-1], '(b t) c h w -> b t c h w', t=t))
+ hs.append(h)
+ hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
+
+ if features_adapter is not None:
+ assert len(
+ features_adapter) == adapter_idx, 'Wrong features_adapter'
+ h = self.middle_block(h, emb, context=context, batch_size=b)
+ hs_a.append(rearrange(h, '(b t) c h w -> b t c h w', t=t))
+
+ hs_out = []
+ for module in self.output_blocks:
+ h = torch.cat([h, hs.pop()], dim=1)
+ h = module(h, emb, context=context, batch_size=b)
+ if isinstance(module[-1], Upsample):
+ hs_a.append(
+ rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
+ hs_out.append(h)
+ h = h.type(x.dtype)
+ hs_a.append(rearrange(hs_out[-1], '(b t) c h w -> b t c h w', t=t))
+
+ y = self.out(h)
+ y = rearrange(y, '(b t) c h w -> b c t h w', b=b)
+
+ if not self.base_model_gen_only:
+ ba, _, _ = x_action.shape
+ a_y = self.action_unet(x_action, timesteps[:ba], hs_a,
+ context_action[:2], **kwargs)
+ # Predict state
+ if b > 1:
+ s_y = self.state_unet(x_state, timesteps[:ba], hs_a,
+ context_action[:2], **kwargs)
+ else:
+ s_y = self.state_unet(x_state, timesteps, hs_a,
+ context_action[:2], **kwargs)
+ else:
+ a_y = torch.zeros_like(x_action)
+ s_y = torch.zeros_like(x_state)
+
+ return y, a_y, s_y
diff --git a/src/unifolm_wma/modules/vision/__init__.py b/src/unifolm_wma/modules/vision/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/unifolm_wma/modules/vision/base_vision.py b/src/unifolm_wma/modules/vision/base_vision.py
new file mode 100644
index 0000000..d6f09ef
--- /dev/null
+++ b/src/unifolm_wma/modules/vision/base_vision.py
@@ -0,0 +1,244 @@
+"""
+base_vision.py
+
+Abstract class definition of a Vision Backbone (Visual Featurizer), with full annotations of class methods, utility
+functions, and initialization logic.
+
+We also define the generic TimmViTBackbone class here, providing a default interface for loading any TIMM Vision
+Transformer model for feature extraction.
+"""
+
+import timm
+import torch
+import torch.nn as nn
+import torchvision.transforms.functional as TVF
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from functools import partial
+from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union
+
+from PIL.Image import Image
+from timm.models.vision_transformer import Block, VisionTransformer
+from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
+from torchvision.transforms import Compose, Resize
+
+
+# === Utility Functions for Monkey-Patching ===
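+# Note: timm's `get_intermediate_layers` returns a tuple of per-layer outputs; unwrapping lets the monkey-patched forward() return a single tensor.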
+def unpack_tuple(fn: Callable[[Any], Tuple[Any]]) -> Callable[[Any], Any]:
+
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
+ result = fn(*args, **kwargs)
+ return result[0] if isinstance(result, tuple) else result
+
+ return wrapper
+
+
+# === Interface for an Image Transform ===
+class ImageTransform(Protocol):
+
+ def __call__(
+ self, img: Image,
+ **kwargs: str) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+ ...
+
+
+# === Custom Torchvision Image Transforms ===
+@dataclass
+class LetterboxPad:
+ padding_fill_value: Tuple[int, int, int]
+
+ def __call__(self, image: Image) -> Image:
+ """Given a PIL.Image, pad to square by adding a symmetric border around the height/width."""
+ (w, h), max_wh = image.size, max(image.size)
+ horizontal_pad, vertical_pad = int((max_wh - w) / 2), int(
+ (max_wh - h) / 2)
+ padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad)
+ return TVF.pad(image,
+ padding,
+ fill=self.padding_fill_value,
+ padding_mode="constant")
+
+
+# === Abstract Base Class for arbitrary Vision Backbones ===
+class VisionBackbone(nn.Module, ABC):
+
+ def __init__(self,
+ vision_backbone_id: str,
+ image_resize_strategy: str,
+ default_image_size: int = 224) -> None:
+ super().__init__()
+ self.identifier: str = vision_backbone_id
+ self.image_resize_strategy: str = image_resize_strategy
+ self.default_image_size: int = default_image_size
+
+ # Instance attributes for a Vision Backbone
+ self.featurizer: nn.Module = None
+ self.image_transform: ImageTransform = None
+
+ def get_image_transform(self) -> ImageTransform:
+ return self.image_transform
+
+ @abstractmethod
+ def get_fsdp_wrapping_policy(self) -> Callable:
+ ...
+
+ @abstractmethod
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+ """Run a forward pass through the featurizer given a set of processed images, returning patch/grid features."""
+ raise NotImplementedError
+
+ @property
+ @abstractmethod
+ def default_image_resolution(self) -> Tuple[int, int, int]:
+ ...
+
+ @property
+ @abstractmethod
+ def embed_dim(self) -> int:
+ ...
+
+ @property
+ @abstractmethod
+ def num_patches(self) -> int:
+ ...
+
+ @property
+ @abstractmethod
+ def half_precision_dtype(self) -> torch.dtype:
+ ...
+
+
+# === Abstract Base Class for Arbitrary TIMM Vision Transformer Backbones ===
+class TimmViTBackbone(VisionBackbone, ABC):
+
+ def __init__(
+ self,
+ vision_backbone_id: str,
+ timm_path_or_url: str,
+ image_resize_strategy: str,
+ default_image_size: int = 224,
+ override_act_layer: Optional[str] = None,
+ ) -> None:
+ super().__init__(vision_backbone_id,
+ image_resize_strategy,
+ default_image_size=default_image_size)
+ self.timm_path_or_url = timm_path_or_url
+ self.override_act_layer = override_act_layer
+ self.dtype = torch.bfloat16
+
+ # Initialize Featurizer (ViT) by downloading from HF / TIMM Hub if necessary
+ if self.override_act_layer is None:
+ self.featurizer: VisionTransformer = timm.create_model(
+ self.timm_path_or_url,
+ pretrained=True,
+ num_classes=0,
+ img_size=self.default_image_size)
+ else:
+ self.featurizer: VisionTransformer = timm.create_model(
+ self.timm_path_or_url,
+ pretrained=True,
+ num_classes=0,
+ img_size=self.default_image_size,
+ act_layer=self.override_act_layer,
+ )
+ self.featurizer.eval()
+
+ # Monkey-Patch the `forward()` function of the featurizer to ensure FSDP-compatibility
+ # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
+ # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
+ self.featurizer.forward = unpack_tuple(
+ partial(self.featurizer.get_intermediate_layers,
+ n={len(self.featurizer.blocks) - 2}))
+
+ # Validation =>> for now, this class *only* supports TIMM Vision Transformers (but can be extended!)
+ assert isinstance(self.featurizer, VisionTransformer), (
+ "Featurizer is not a TIMM VisionTransformer; if you would like to support a new visual representation, "
+ "file an issue or implement the requisite logic (see `prismatic/models/backbones/vision/base_vision.py`)!"
+ )
+
+ # Get Config =>> Note :: Override default image size to ensure correct image transform
+ self.data_cfg = timm.data.resolve_model_data_config(self.featurizer)
+ self.data_cfg["input_size"] = (3, self.default_image_size,
+ self.default_image_size)
+
+ # Initialize Default Image Transform --> Modified by `self.image_resize_strategy`
+ default_image_transform = timm.data.create_transform(**self.data_cfg,
+ is_training=False)
+
+ # Fix =>> SigLIP & IN1K default transforms resize to *larger* than `self.default_image_size` (crops image)!
+ if "siglip" in self.timm_path_or_url or "in1k" in self.timm_path_or_url:
+ assert isinstance(default_image_transform,
+ Compose), "Unexpected `default_image_transform`!"
+ assert isinstance(default_image_transform.transforms[0], Resize)
+ default_image_transform = Compose([
+ Resize(self.default_image_size,
+ interpolation=default_image_transform.transforms[0].
+ interpolation),
+ *default_image_transform.transforms[1:],
+ ])
+
+ # Switch on `image_resize_strategy`
+ if self.image_resize_strategy == "resize-naive":
+ assert isinstance(default_image_transform,
+ Compose), "Unexpected `default_image_transform`!"
+ assert isinstance(default_image_transform.transforms[0], Resize)
+
+ target_size = (self.default_image_size, self.default_image_size)
+ self.image_transform = Compose([
+ Resize(target_size,
+ interpolation=default_image_transform.transforms[0].
+ interpolation),
+ *default_image_transform.transforms[1:],
+ ])
+
+ elif self.image_resize_strategy == "resize-crop":
+ self.image_transform = default_image_transform
+
+ elif self.image_resize_strategy == "letterbox":
+ assert isinstance(default_image_transform,
+ Compose), "Unexpected `default_image_transform`!"
+ assert "mean" in self.data_cfg, "TIMM `data_cfg` missing image normalization mean!"
+
+ # Compute Padding Fill Value (rescaled normalization mean if applicable)
+ fill = tuple([int(x * 255) for x in self.data_cfg["mean"]])
+
+ # Build New Transform
+ self.image_transform = Compose(
+ [LetterboxPad(fill), *default_image_transform.transforms])
+
+ else:
+ raise ValueError(
+ f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!"
+ )
+
+ def get_fsdp_wrapping_policy(self) -> Callable:
+ """Return a simple FSDP policy that wraps each ViT block and then the _entire_ featurizer."""
+ vit_wrap_policy = partial(_module_wrap_policy,
+ module_classes={VisionTransformer})
+ transformer_block_policy = partial(transformer_auto_wrap_policy,
+ transformer_layer_cls={Block})
+ return partial(_or_policy,
+ policies=[vit_wrap_policy, transformer_block_policy])
+
+ def forward(
+ self, pixel_values: Union[torch.Tensor, Dict[str, torch.Tensor]]
+ ) -> torch.Tensor:
+ """Runs transformed image/pixel tensor through vision backbone, returning _all_ patch features."""
+ return self.featurizer(pixel_values)
+
+ @property
+ def default_image_resolution(self) -> Tuple[int, int, int]:
+ return self.data_cfg["input_size"]
+
+ @property
+ def embed_dim(self) -> int:
+ return self.featurizer.embed_dim
+
+ @property
+ def num_patches(self) -> int:
+ return self.featurizer.patch_embed.num_patches
+
+ @property
+ def half_precision_dtype(self) -> torch.dtype:
+ return self.dtype
diff --git a/src/unifolm_wma/modules/vision/dinosiglip_vit.py b/src/unifolm_wma/modules/vision/dinosiglip_vit.py
new file mode 100644
index 0000000..c688678
--- /dev/null
+++ b/src/unifolm_wma/modules/vision/dinosiglip_vit.py
@@ -0,0 +1,273 @@
+"""
+dinosiglip_vit.py
+
+Vision backbone that returns concatenated features from both DINOv2 and SigLIP.
+"""
+
+import timm
+import torch
+import torchvision.transforms as transforms
+
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable, Dict, Tuple
+from PIL import Image
+from timm.models.vision_transformer import Block, VisionTransformer
+from torch.distributed.fsdp.wrap import _module_wrap_policy, _or_policy, transformer_auto_wrap_policy
+from torchvision.transforms import Compose, Resize, Normalize
+
+from unifolm_wma.modules.vision.base_vision import ImageTransform, LetterboxPad, VisionBackbone, unpack_tuple
+from unifolm_wma.utils.nn_utils import FusedMLPProjector, LinearProjector, MLPProjector
+
+# Registry =>> Supported DinoSigLIP Pairs (as TIMM identifiers)
+DINOSigLIP_VISION_BACKBONES = {
+ "dinosiglip-vit-so-224px": {
+ "dino": "vit_large_patch14_reg4_dinov2.lvd142m",
+ "siglip": "vit_so400m_patch14_siglip_224",
+ },
+ "dinosiglip-vit-so-384px": {
+ "dino": "vit_large_patch14_reg4_dinov2.lvd142m",
+ "siglip": "vit_so400m_patch14_siglip_384",
+ },
+}
+
+
+@dataclass
+class DinoSigLIPImageTransform:
+ dino_image_transform: ImageTransform
+ siglip_image_transform: ImageTransform
+ is_prismatic: bool = True
+
+ def __call__(self, img: Image, **kwargs: str) -> Dict[str, torch.Tensor]:
+ return {
+ "dino": self.dino_image_transform(img, **kwargs),
+ "siglip": self.siglip_image_transform(img, **kwargs)
+ }
+
+
+class DinoSigLIPViTBackbone(VisionBackbone):
+
+ def __init__(self,
+ vision_backbone_id: str,
+ image_resize_strategy: str,
+ arch_specifier: str,
+ output_dim: int,
+ pretrained_checkpoint=None,
+ freeze=True,
+ default_image_size: int = 224) -> None:
+ super().__init__(vision_backbone_id,
+ image_resize_strategy,
+ default_image_size=default_image_size)
+ self.dino_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[
+ vision_backbone_id]["dino"]
+ self.siglip_timm_path_or_url = DINOSigLIP_VISION_BACKBONES[
+ vision_backbone_id]["siglip"]
+
+ # Initialize both Featurizers (ViTs) by downloading from HF / TIMM Hub if necessary
+ self.dino_featurizer: VisionTransformer = timm.create_model(
+ self.dino_timm_path_or_url,
+ pretrained=True,
+ num_classes=0,
+ img_size=self.default_image_size)
+ if pretrained_checkpoint:
+ ckpt = pretrained_checkpoint + '/openvla_dino.pt'
+ self.dino_featurizer.load_state_dict(
+ torch.load(ckpt, weights_only=True))
+ print('>>> Loaded pretrained DINO weights')
+ if freeze:
+ self.dino_featurizer.eval()
+ for param in self.dino_featurizer.parameters():
+ param.requires_grad = False
+
+ self.siglip_featurizer: VisionTransformer = timm.create_model(
+ self.siglip_timm_path_or_url,
+ pretrained=True,
+ num_classes=0,
+ img_size=self.default_image_size)
+ if pretrained_checkpoint:
+ ckpt = pretrained_checkpoint + '/openvla_siglip.pt'
+ self.siglip_featurizer.load_state_dict(
+ torch.load(ckpt, weights_only=True))
+ print('>>> Loaded pretrained SigLIP weights')
+ if freeze:
+ self.siglip_featurizer.eval()
+ for param in self.siglip_featurizer.parameters():
+ param.requires_grad = False
+
+ # Monkey-Patch the `forward()` function of the featurizers to ensure FSDP-compatibility
+ # => Note: By default set `get_intermediate_layers` to return the *SECOND-TO-LAST* layer patches!
+ # => TODO (siddk) Remove after resolution of https://github.com/pytorch/pytorch/issues/109385
+ self.dino_featurizer.forward = unpack_tuple(
+ partial(self.dino_featurizer.get_intermediate_layers,
+ n={len(self.dino_featurizer.blocks) - 2}))
+ self.siglip_featurizer.forward = unpack_tuple(
+ partial(self.siglip_featurizer.get_intermediate_layers,
+ n={len(self.siglip_featurizer.blocks) - 2}))
+
+ # Get Configs for _both_ Featurizers =>> Note :: Override default image size for larger resolution models
+ self.dino_data_cfg = timm.data.resolve_model_data_config(
+ self.dino_featurizer)
+ self.dino_data_cfg["input_size"] = (3, self.default_image_size,
+ self.default_image_size)
+
+ self.siglip_data_cfg = timm.data.resolve_model_data_config(
+ self.siglip_featurizer)
+ self.siglip_data_cfg["input_size"] = (3, self.default_image_size,
+ self.default_image_size)
+
+ # Initialize *both* Transforms
+ self.default_dino_transform = timm.data.create_transform(
+ **self.dino_data_cfg, is_training=False)
+ self.default_siglip_transform = timm.data.create_transform(
+ **self.siglip_data_cfg, is_training=False)
+
+ # Fix =>> SigLIP default transform resizes to *larger* than `self.default_image_size` (crops image)!!
+ assert isinstance(self.default_siglip_transform,
+ Compose), "Unexpected `default_image_transform`!"
+ assert isinstance(self.default_siglip_transform.transforms[0], Resize)
+ self.default_siglip_transform = Compose([
+ Resize(self.default_image_size,
+ interpolation=self.default_siglip_transform.transforms[0].
+ interpolation),
+ *self.default_siglip_transform.transforms[1:],
+ ])
+
+ if self.image_resize_strategy == "resize-naive":
+ assert isinstance(
+ self.default_dino_transform,
+ Compose), "Unexpected `default_dino_image_transform`!"
+ assert isinstance(
+ self.default_siglip_transform,
+ Compose), "Unexpected `default_siglip_image_transform`!"
+ assert isinstance(self.default_dino_transform.transforms[0],
+ Resize)
+ assert isinstance(self.default_siglip_transform.transforms[0],
+ Resize)
+
+ self.target_size = (self.default_image_size,
+ self.default_image_size)
+ dino_transform = Compose([
+ Resize(self.target_size,
+ interpolation=self.default_dino_transform.transforms[0].
+ interpolation),
+ *self.default_dino_transform.transforms[1:],
+ ])
+ siglip_transform = Compose([
+ Resize(self.target_size,
+ interpolation=self.default_siglip_transform.
+ transforms[0].interpolation),
+ *self.default_siglip_transform.transforms[1:],
+ ])
+
+ self.image_transform = DinoSigLIPImageTransform(
+ dino_transform, siglip_transform)
+
+ elif self.image_resize_strategy == "resize-crop":
+ self.image_transform = DinoSigLIPImageTransform(
+ self.default_dino_transform, self.default_siglip_transform)
+
+ elif self.image_resize_strategy == "letterbox":
+ assert isinstance(self.default_dino_transform,
+ Compose), "Unexpected `default_dino_transform`!"
+ assert isinstance(
+ self.default_siglip_transform,
+ Compose), "Unexpected `default_siglip_transform`!"
+ assert ("mean" in self.dino_data_cfg
+ and "mean" in self.siglip_data_cfg
+ ), "DinoSigLIP `data_cfg` missing `mean`!"
+
+ # Compute Padding Fill Value(s) (rescaled normalization mean if applicable)
+ dino_fill = tuple(
+ [int(x * 255) for x in self.dino_data_cfg["mean"]])
+ siglip_fill = tuple(
+ [int(x * 255) for x in self.siglip_data_cfg["mean"]])
+
+ # Build New Transform
+ self.image_transform = DinoSigLIPImageTransform(
+ Compose([
+ LetterboxPad(dino_fill),
+ *self.default_dino_transform.transforms
+ ]),
+ Compose([
+ LetterboxPad(siglip_fill),
+ *self.default_siglip_transform.transforms
+ ]),
+ )
+
+ else:
+ raise ValueError(
+ f"Image Resize Strategy `{self.image_resize_strategy}` is not supported!"
+ )
+
+ self.arch_specifier = arch_specifier
+ if arch_specifier == "linear":
+ self.projector = LinearProjector(self.embed_dim, output_dim)
+ elif arch_specifier.endswith("fused-gelu-mlp"):
+ self.projector = FusedMLPProjector(self.embed_dim, output_dim)
+ elif arch_specifier.endswith("gelu-mlp"):
+ self.projector = MLPProjector(self.embed_dim, output_dim)
+ else:
+ raise ValueError(
+ f"PrismaticVLM with `{arch_specifier = }` is not supported!")
+
+ self.on_gpu = False
+
+ def get_fsdp_wrapping_policy(self) -> Callable:
+ """Return a simple FSDP policy that wraps each ViT block and then both of the _entire_ featurizers."""
+ vit_wrap_policy = partial(_module_wrap_policy,
+ module_classes={VisionTransformer})
+ transformer_block_policy = partial(transformer_auto_wrap_policy,
+ transformer_layer_cls={Block})
+ return partial(_or_policy,
+ policies=[vit_wrap_policy, transformer_block_policy])
+
+ def forward(self, img) -> torch.Tensor:
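+ # Inputs arrive in [-1, 1]; map them to [0, 255], resize and center-crop to the target size, then apply per-backbone normalization.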
+ img = torch.clamp(img.float(), -1., 1.)
+ img = (img + 1.0) / 2.0
+ img = img * 255
+
+ resize = transforms.Resize(min(self.target_size),
+ interpolation=self.default_dino_transform.
+ transforms[0].interpolation,
+ max_size=None,
+ antialias=True)
+ center_crop = transforms.CenterCrop(self.target_size)
+ img = center_crop(resize(img))
+
+ dino_normalizer = Normalize(mean=torch.tensor([0.4850, 0.4560,
+ 0.4060]),
+ std=torch.tensor([0.2290, 0.2240, 0.2250]))
+ siglip_normalizer = Normalize(
+ mean=torch.tensor([0.5000, 0.5000, 0.5000]),
+ std=torch.tensor([0.5000, 0.5000, 0.5000]))
+ pixel_values = {
+ 'dino': dino_normalizer(img),
+ 'siglip': siglip_normalizer(img)
+ }
+
+ if self.on_gpu:
+ pixel_values = {k: v.cuda() for k, v in pixel_values.items()}
+ elif next(self.dino_featurizer.parameters()).device.type != 'cpu':
+ self.on_gpu = True
+ """Runs the transformed image/pixel tensors through each vision backbone, returning concatenated patches."""
+ dino_patches = self.dino_featurizer(pixel_values["dino"])
+ siglip_patches = self.siglip_featurizer(pixel_values["siglip"])
+
+ return self.projector(torch.cat([dino_patches, siglip_patches], dim=2))
+
+ @property
+ def default_image_resolution(self) -> Tuple[int, int, int]:
+ return self.dino_data_cfg["input_size"]
+
+ @property
+ def embed_dim(self) -> int:
+ return self.dino_featurizer.embed_dim + self.siglip_featurizer.embed_dim
+
+ @property
+ def num_patches(self) -> int:
+ assert self.dino_featurizer.patch_embed.num_patches == self.siglip_featurizer.patch_embed.num_patches
+ return self.dino_featurizer.patch_embed.num_patches
+
+ @property
+ def half_precision_dtype(self) -> torch.dtype:
+ return torch.bfloat16
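+
+
+# Minimal usage sketch (illustrative only; `output_dim`, the resize strategy, and the batch shown
+# here are assumptions, not the project's actual configuration):
+#
+#   backbone = DinoSigLIPViTBackbone("dinosiglip-vit-so-224px", "resize-naive",
+#                                    arch_specifier="gelu-mlp", output_dim=1024)
+#   imgs = torch.rand(2, 3, 224, 224) * 2 - 1   # forward() expects inputs in [-1, 1]
+#   patch_feats = backbone(imgs)                # -> (2, num_patches, output_dim)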
diff --git a/src/unifolm_wma/utils/basics.py b/src/unifolm_wma/utils/basics.py
new file mode 100644
index 0000000..088298b
--- /dev/null
+++ b/src/unifolm_wma/utils/basics.py
@@ -0,0 +1,104 @@
+# adopted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+import torch.nn as nn
+from unifolm_wma.utils.utils import instantiate_from_config
+
+
+def disabled_train(self, mode=True):
+ """Overwrite model.train with this function to make sure train/eval mode
+ does not change anymore."""
+ return self
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def nonlinearity(type='silu'):
+ if type == 'silu':
+ return nn.SiLU()
+ elif type == 'leaky_relu':
+ return nn.LeakyReLU()
+ raise ValueError(f"unsupported nonlinearity type: {type}")
+
+
+class GroupNormSpecific(nn.GroupNorm):
+
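+ # Compute GroupNorm in float32 for numerical stability, then cast back to the input dtype.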
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
+def normalization(channels, num_groups=32):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNormSpecific(num_groups, channels)
+
+
+class HybridConditioner(nn.Module):
+
+ def __init__(self, c_concat_config, c_crossattn_config):
+ super().__init__()
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
+ self.crossattn_conditioner = instantiate_from_config(
+ c_crossattn_config)
+
+ def forward(self, c_concat, c_crossattn):
+ c_concat = self.concat_conditioner(c_concat)
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
diff --git a/src/unifolm_wma/utils/callbacks.py b/src/unifolm_wma/utils/callbacks.py
new file mode 100644
index 0000000..0458394
--- /dev/null
+++ b/src/unifolm_wma/utils/callbacks.py
@@ -0,0 +1,226 @@
+import os
+import time
+import logging
+import json
+
+mainlogger = logging.getLogger('mainlogger')
+
+import torch
+import torchvision
+import pytorch_lightning as pl
+import matplotlib.pyplot as plt
+
+from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.utilities import rank_zero_only
+from pytorch_lightning.utilities import rank_zero_info
+
+from unifolm_wma.utils.save_video import log_local, prepare_to_log
+
+STAT_DIR = '~/'
+
+
+class ImageLogger(Callback):
+
+ def __init__(self, batch_frequency, max_images=8, clamp=True, rescale=True, save_dir=None, \
+ to_local=False, log_images_kwargs=None):
+ super().__init__()
+ self.rescale = rescale
+ self.batch_freq = batch_frequency
+ self.max_images = max_images
+ self.to_local = to_local
+ self.clamp = clamp
+ self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
+ self.save_stat_dir = os.path.join(save_dir, "stat")
+ os.makedirs(self.save_stat_dir, exist_ok=True)
+ self.fps_stat = {}
+ self.fs_stat = {}
+ if self.to_local:
+ self.save_dir = os.path.join(save_dir, "images")
+ os.makedirs(os.path.join(self.save_dir, "train"), exist_ok=True)
+ os.makedirs(os.path.join(self.save_dir, "val"), exist_ok=True)
+ self.count_data = 0
+
+ def log_to_tensorboard(self,
+ pl_module,
+ batch_logs,
+ filename,
+ split,
+ save_fps=8):
+ """ log images and videos to tensorboard """
+ global_step = pl_module.global_step
+ for key in batch_logs:
+ value = batch_logs[key]
+ tag = "gs%d-%s/%s||%s||%s||%s" % (
+ global_step, split, key,
+ batch_logs['condition'][0].split('_')[0],
+ batch_logs['condition'][0].split('_')[1],
+ batch_logs['video_idx'])
+ if isinstance(value, list) and isinstance(value[0], str):
+ captions = ' |------| '.join(value)
+ pl_module.logger.experiment.add_text(tag,
+ captions,
+ global_step=global_step)
+ elif isinstance(value, torch.Tensor) and value.dim() == 5:
+ video = value
+ n = video.shape[0]
+ video = video.permute(2, 0, 1, 3, 4)
+ frame_grids = [
+ torchvision.utils.make_grid(framesheet,
+ nrow=int(n),
+ padding=0)
+ for framesheet in video
+ ]
+ grid = torch.stack(
+ frame_grids, dim=0)
+ grid = (grid + 1.0) / 2.0
+ grid = grid.unsqueeze(dim=0)
+ pl_module.logger.experiment.add_video(tag,
+ grid,
+ fps=save_fps,
+ global_step=global_step)
+ elif isinstance(value, torch.Tensor) and value.dim() == 4:
+ img = value
+ n = img.shape[0]
+ grid = torchvision.utils.make_grid(img, nrow=int(n), padding=0)
+ grid = (grid + 1.0) / 2.0
+ pl_module.logger.experiment.add_image(tag,
+ grid,
+ global_step=global_step)
+ elif isinstance(value, torch.Tensor) and value.dim() == 3:
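+ # 1-D trajectory comparison: the first half of the batch holds targets, the second half the corresponding samples; plot each dimension separately.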
+ b, _, _ = value.shape
+ value1 = value[:b // 2, ...]
+ value2 = value[b // 2:, ...]
+ _, num_points, d = value1.shape
+ for i in range(d):
+ data1 = value1[0, :, i].cpu().detach().numpy()
+ data2 = value2[0, :, i].cpu().detach().numpy()
+ fig, ax = plt.subplots()
+ ax.plot(data1, label='Target 1')
+ ax.plot(data2, label='Sample 1')
+ ax.set_title(f'Comparison at dimension {i} for {key}')
+ ax.legend()
+ pl_module.logger.experiment.add_figure(
+ tag + f"| {key}_dim_{i}", fig, global_step=global_step)
+ plt.close(fig)
+ else:
+ pass
+
+ @rank_zero_only
+ def log_batch_imgs(self, pl_module, batch, batch_idx, split="train"):
+ """ generate images, then save and log to tensorboard """
+ # Update fps and fs statistics
+ batch_fps = batch['fps'].tolist()
+ batch_fs = batch['frame_stride'].tolist()
+ for num in batch_fps:
+ self.fps_stat[num] = self.fps_stat.get(num, 0) + 1
+ for num in batch_fs:
+ self.fs_stat[num] = self.fs_stat.get(num, 0) + 1
+ skip_freq = self.batch_freq if split == "train" else 5
+ ## NOTE HAND CODE
+ self.count_data += 12.5 * 2
+ if self.count_data >= skip_freq:
+ self.count_data = 0
+
+ is_train = pl_module.training
+ if is_train:
+ pl_module.eval()
+ torch.cuda.empty_cache()
+ with torch.no_grad():
+ log_func = pl_module.log_images
+ batch_logs = log_func(batch,
+ split=split,
+ **self.log_images_kwargs)
+ # Log fps and fs statistics
+ with open(self.save_stat_dir + '/fps_fs_stat.json',
+ 'w') as file:
+ json.dump({
+ 'fps': self.fps_stat,
+ 'fs': self.fs_stat
+ },
+ file,
+ indent=4)
+
+ batch_logs = prepare_to_log(batch_logs, self.max_images,
+ self.clamp)
+ torch.cuda.empty_cache()
+
+ filename = "ep{}_idx{}_rank{}".format(pl_module.current_epoch,
+ batch_idx,
+ pl_module.global_rank)
+ if self.to_local:
+ mainlogger.info("Log [%s] batch <%s> to local ..." %
+ (split, filename))
+ filename = "gs{}_".format(pl_module.global_step) + filename
+ log_local(batch_logs,
+ os.path.join(self.save_dir, split),
+ filename,
+ save_fps=10)
+ else:
+ mainlogger.info("Log [%s] batch <%s> to tensorboard ..." %
+ (split, filename))
+ self.log_to_tensorboard(pl_module,
+ batch_logs,
+ filename,
+ split,
+ save_fps=10)
+ mainlogger.info('Finish!')
+
+ if is_train:
+ pl_module.train()
+
+ def on_train_batch_end(self,
+ trainer,
+ pl_module,
+ outputs,
+ batch,
+ batch_idx,
+ dataloader_idx=None):
+ if self.batch_freq != -1 and pl_module.logdir:
+ self.log_batch_imgs(pl_module, batch, batch_idx, split="train")
+
+ def on_validation_batch_end(self,
+ trainer,
+ pl_module,
+ outputs,
+ batch,
+ batch_idx,
+ dataloader_idx=None):
+ # Unlike validation_step(), which saves the whole validation set and keeps only the latest results,
+ # this records every validation run (without overwriting) by keeping only a subset.
+ if self.batch_freq != -1 and pl_module.logdir:
+ self.log_batch_imgs(pl_module, batch, batch_idx, split="val")
+ if hasattr(pl_module, 'calibrate_grad_norm'):
+ if (pl_module.calibrate_grad_norm
+ and batch_idx % 25 == 0) and batch_idx > 0:
+ self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
+
+
+class CUDACallback(Callback):
+ # See https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
+ def on_train_epoch_start(self, trainer, pl_module):
+ # Reset the memory use counter
+ # Lightning update
+ if int((pl.__version__).split('.')[1]) >= 7:
+ gpu_index = trainer.strategy.root_device.index
+ else:
+ gpu_index = trainer.root_gpu
+ torch.cuda.reset_peak_memory_stats(gpu_index)
+ torch.cuda.synchronize(gpu_index)
+ self.start_time = time.time()
+
+ def on_train_epoch_end(self, trainer, pl_module):
+ if int((pl.__version__).split('.')[1]) >= 7:
+ gpu_index = trainer.strategy.root_device.index
+ else:
+ gpu_index = trainer.root_gpu
+ torch.cuda.synchronize(gpu_index)
+ max_memory = torch.cuda.max_memory_allocated(gpu_index) / 2**20
+ epoch_time = time.time() - self.start_time
+
+ try:
+ max_memory = trainer.training_type_plugin.reduce(max_memory)
+ epoch_time = trainer.training_type_plugin.reduce(epoch_time)
+
+ rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
+ rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
+ except AttributeError:
+ pass
diff --git a/src/unifolm_wma/utils/common.py b/src/unifolm_wma/utils/common.py
new file mode 100644
index 0000000..f5b8a03
--- /dev/null
+++ b/src/unifolm_wma/utils/common.py
@@ -0,0 +1,111 @@
+import math
+from inspect import isfunction
+import torch
+from torch import Tensor, nn
+import torch.distributed as dist
+
+
+def gather_data(data, return_np=True):
+ ''' gather data from multiple processes to one list '''
+ data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())]
+ dist.all_gather(data_list, data) # gather not supported with NCCL
+ if return_np:
+ data_list = [data.cpu().numpy() for data in data_list]
+ return data_list
+
+
+def autocast(f):
+
+ def do_autocast(*args, **kwargs):
+ with torch.cuda.amp.autocast(
+ enabled=True,
+ dtype=torch.get_autocast_gpu_dtype(),
+ cache_enabled=torch.is_autocast_cache_enabled()):
+ return f(*args, **kwargs)
+
+ return do_autocast
+
+
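+# Gather per-timestep scalar coefficients from `a` at indices `t` and reshape them to broadcast over a tensor of shape `x_shape`.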
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1, ) * (len(x_shape) - 1)))
+
+
+def noise_like(shape, device, repeat=False):
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
+ shape[0], *((1, ) * (len(shape) - 1)))
+ noise = lambda: torch.randn(shape, device=device)
+ return repeat_noise() if repeat else noise()
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def exists(val):
+ return val is not None
+
+
+def identity(*args, **kwargs):
+ return nn.Identity()
+
+
+def uniq(arr):
+ return {el: True for el in arr}.keys()
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def ismap(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
+
+
+def isimage(x):
+ if not isinstance(x, torch.Tensor):
+ return False
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
+
+
+def max_neg_value(t):
+ return -torch.finfo(t.dtype).max
+
+
+def shape_to_str(x):
+ shape_str = "x".join([str(x) for x in x.shape])
+ return shape_str
+
+
+def init_(tensor):
+ dim = tensor.shape[-1]
+ std = 1 / math.sqrt(dim)
+ tensor.uniform_(-std, std)
+ return tensor
+
+
+ckpt = torch.utils.checkpoint.checkpoint
+
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ return ckpt(func, *inputs, use_reentrant=False)
+ else:
+ return func(*inputs)
diff --git a/src/unifolm_wma/utils/data.py b/src/unifolm_wma/utils/data.py
new file mode 100644
index 0000000..589da5c
--- /dev/null
+++ b/src/unifolm_wma/utils/data.py
@@ -0,0 +1,242 @@
+import os, sys
+import numpy as np
+import torch
+import pytorch_lightning as pl
+
+from functools import partial
+from torch.utils.data import (DataLoader, Dataset, ConcatDataset,
+ WeightedRandomSampler)
+from unifolm_wma.data.base import Txt2ImgIterableBaseDataset
+from unifolm_wma.utils.utils import instantiate_from_config
+
+
+def worker_init_fn(_):
+ worker_info = torch.utils.data.get_worker_info()
+
+ dataset = worker_info.dataset
+ worker_id = worker_info.id
+
+ if isinstance(dataset, Txt2ImgIterableBaseDataset):
+ split_size = dataset.num_records // worker_info.num_workers
+ # Assign this worker a contiguous slice of the valid sample ids
+ dataset.sample_ids = dataset.valid_ids[worker_id *
+ split_size:(worker_id + 1) *
+ split_size]
+ current_id = np.random.choice(len(np.random.get_state()[1]), 1)
+ return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
+ else:
+ return np.random.seed(np.random.get_state()[1][0] + worker_id)
+
+
+class WrappedDataset(Dataset):
+ """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
+
+ def __init__(self, dataset):
+ self.data = dataset
+
+ def __len__(self):
+ return len(self.data)
+
+ def __getitem__(self, idx):
+ return self.data[idx]
+
+
+class DataModuleFromConfig(pl.LightningDataModule):
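+ """Lightning data module that builds one dataset per name in dataset_and_weights
+ for each split and draws batches with a dataset-weighted WeightedRandomSampler."""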
+
+ def __init__(self,
+ batch_size,
+ train=None,
+ validation=None,
+ test=None,
+ predict=None,
+ wrap=False,
+ num_workers=None,
+ shuffle_test_loader=False,
+ use_worker_init_fn=False,
+ shuffle_val_dataloader=True,
+ train_img=None,
+ dataset_and_weights=None):
+ super().__init__()
+ self.batch_size = batch_size
+ self.dataset_configs = dict()
+ self.num_workers = num_workers if num_workers is not None else batch_size * 2
+ self.use_worker_init_fn = use_worker_init_fn
+ if train is not None:
+ self.dataset_configs["train"] = train
+ self.train_dataloader = self._train_dataloader
+ if validation is not None:
+ self.dataset_configs["validation"] = validation
+ self.val_dataloader = partial(self._val_dataloader,
+ shuffle=shuffle_val_dataloader)
+ if test is not None:
+ self.dataset_configs["test"] = test
+ self.test_dataloader = partial(self._test_dataloader,
+ shuffle=shuffle_test_loader)
+ if predict is not None:
+ self.dataset_configs["predict"] = predict
+ self.predict_dataloader = self._predict_dataloader
+
+ self.img_loader = None
+ self.wrap = wrap
+ self.collate_fn = None
+ self.dataset_weights = dataset_and_weights
+ assert round(sum(self.dataset_weights.values()),
+ 2) == 1.0, "The sum of dataset weights != 1.0"
+
+ def prepare_data(self):
+ pass
+
+ def setup(self, stage=None):
+ if 'train' in self.dataset_configs:
+ self.train_datasets = dict()
+ for dataname in self.dataset_weights:
+ data_dir = self.dataset_configs['train']['params']['data_dir']
+ transition_dir = '/'.join([data_dir, 'transitions'])
+ csv_file = f'{dataname}.csv'
+ meta_path = '/'.join([data_dir, csv_file])
+ self.dataset_configs['train']['params'][
+ 'meta_path'] = meta_path
+ self.dataset_configs['train']['params'][
+ 'transition_dir'] = transition_dir
+ self.dataset_configs['train']['params'][
+ 'dataset_name'] = dataname
+ self.train_datasets[dataname] = instantiate_from_config(
+ self.dataset_configs['train'])
+
+ # Setup validation dataset
+ if 'validation' in self.dataset_configs:
+ self.val_datasets = dict()
+ for dataname in self.dataset_weights:
+ data_dir = self.dataset_configs['validation']['params'][
+ 'data_dir']
+ transition_dir = '/'.join([data_dir, 'transitions'])
+ csv_file = f'{dataname}.csv'
+ meta_path = '/'.join([data_dir, csv_file])
+ self.dataset_configs['validation']['params'][
+ 'meta_path'] = meta_path
+ self.dataset_configs['validation']['params'][
+ 'transition_dir'] = transition_dir
+ self.dataset_configs['validation']['params'][
+ 'dataset_name'] = dataname
+ self.val_datasets[dataname] = instantiate_from_config(
+ self.dataset_configs['validation'])
+
+ # Setup test dataset
+ if 'test' in self.dataset_configs:
+ self.test_datasets = dict()
+ for dataname in self.dataset_weights:
+ data_dir = self.dataset_configs['test']['params']['data_dir']
+ transition_dir = '/'.join([data_dir, 'transitions'])
+ csv_file = f'{dataname}.csv'
+ meta_path = '/'.join([data_dir, csv_file])
+ self.dataset_configs['test']['params']['meta_path'] = meta_path
+ self.dataset_configs['test']['params'][
+ 'transition_dir'] = transition_dir
+ self.dataset_configs['test']['params'][
+ 'dataset_name'] = dataname
+ self.test_datasets[dataname] = instantiate_from_config(
+ self.dataset_configs['test'])
+
+ if self.wrap:
+ # Wrap whichever splits were actually built above
+ for datasets in (getattr(self, name, None) for name in ('train_datasets', 'val_datasets', 'test_datasets')):
+ if datasets is None:
+ continue
+ for k in datasets:
+ datasets[k] = WrappedDataset(datasets[k])
+
+ def _train_dataloader(self):
+ is_iterable_dataset = False # NOTE: hard-coded; iterable datasets are not used here
+ if is_iterable_dataset or self.use_worker_init_fn:
+ init_fn = worker_init_fn
+ else:
+ init_fn = None
+ combined_dataset = []
+ sample_weights = []
+ for dataname, dataset in self.train_datasets.items():
+ combined_dataset.append(dataset)
+ sample_weights.append(
+ torch.full((len(dataset), ),
+ self.dataset_weights[dataname] / len(dataset)))
+ combined_dataset = ConcatDataset(combined_dataset)
+ sample_weights = torch.cat(sample_weights)
+ sampler = WeightedRandomSampler(sample_weights,
+ num_samples=len(combined_dataset),
+ replacement=True)
+ loader = DataLoader(combined_dataset,
+ sampler=sampler,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ worker_init_fn=init_fn,
+ collate_fn=self.collate_fn,
+ drop_last=True
+ )
+ return loader
+
+ def _val_dataloader(self, shuffle=False):
+ is_iterable_dataset = False # NOTE: hard-coded; iterable datasets are not used here
+ if is_iterable_dataset or self.use_worker_init_fn:
+ init_fn = worker_init_fn
+ else:
+ init_fn = None
+ combined_dataset = []
+ sample_weights = []
+ for dataname, dataset in self.val_datasets.items():
+ combined_dataset.append(dataset)
+ sample_weights.append(
+ torch.full((len(dataset), ),
+ self.dataset_weights[dataname] / len(dataset)))
+ combined_dataset = ConcatDataset(combined_dataset)
+ sample_weights = torch.cat(sample_weights)
+ sampler = WeightedRandomSampler(sample_weights,
+ num_samples=len(combined_dataset),
+ replacement=True)
+ loader = DataLoader(combined_dataset,
+ sampler=sampler,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ worker_init_fn=init_fn,
+ collate_fn=self.collate_fn)
+ return loader
+
+ def _test_dataloader(self, shuffle=False):
+ is_iterable_dataset = False # NOTE: hard-coded; iterable datasets are not used here
+ if is_iterable_dataset or self.use_worker_init_fn:
+ init_fn = worker_init_fn
+ else:
+ init_fn = None
+ combined_dataset = []
+ sample_weights = []
+ for dataname, dataset in self.test_datasets.items():
+ combined_dataset.append(dataset)
+ sample_weights.append(
+ torch.full((len(dataset), ),
+ self.dataset_weights[dataname] / len(dataset)))
+ combined_dataset = ConcatDataset(combined_dataset)
+ sample_weights = torch.cat(sample_weights)
+ sampler = WeightedRandomSampler(sample_weights,
+ num_samples=len(combined_dataset),
+ replacement=True)
+ loader = DataLoader(combined_dataset,
+ sampler=sampler,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ worker_init_fn=init_fn,
+ collate_fn=self.collate_fn)
+ return loader
+
+ def _predict_dataloader(self, shuffle=False):
+ if isinstance(self.datasets['predict'],
+ Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
+ init_fn = worker_init_fn
+ else:
+ init_fn = None
+ return DataLoader(
+ self.datasets["predict"],
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ worker_init_fn=init_fn,
+ collate_fn=self.collate_fn,
+ )
+
+ def __len__(self):
+ count = 0
+ for _, values in self.train_datasets.items():
+ count += len(values)
+ return count
diff --git a/src/unifolm_wma/utils/diffusion.py b/src/unifolm_wma/utils/diffusion.py
new file mode 100644
index 0000000..ee92eab
--- /dev/null
+++ b/src/unifolm_wma/utils/diffusion.py
@@ -0,0 +1,191 @@
+import math
+import numpy as np
+import torch
+import torch.nn.functional as F
+from einops import repeat
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) *
+ torch.arange(start=0, end=half, dtype=torch.float32) /
+ half).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat(
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ else:
+ embedding = repeat(timesteps, 'b -> b d', d=dim)
+ return embedding
+
+
+def make_beta_schedule(schedule,
+ n_timestep,
+ linear_start=1e-4,
+ linear_end=2e-2,
+ cosine_s=8e-3):
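+ """Return a numpy array of betas for the 'linear', 'cosine', 'sqrt_linear' or 'sqrt' schedule."""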
+ if schedule == "linear":
+ betas = (torch.linspace(linear_start**0.5,
+ linear_end**0.5,
+ n_timestep,
+ dtype=torch.float64)**2)
+
+ elif schedule == "cosine":
+ timesteps = (
+ torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep +
+ cosine_s)
+ alphas = timesteps / (1 + cosine_s) * np.pi / 2
+ alphas = torch.cos(alphas).pow(2)
+ alphas = alphas / alphas[0]
+ betas = 1 - alphas[1:] / alphas[:-1]
+ betas = np.clip(betas, a_min=0, a_max=0.999)
+
+ elif schedule == "sqrt_linear":
+ betas = torch.linspace(linear_start,
+ linear_end,
+ n_timestep,
+ dtype=torch.float64)
+ elif schedule == "sqrt":
+ betas = torch.linspace(linear_start,
+ linear_end,
+ n_timestep,
+ dtype=torch.float64)**0.5
+ else:
+ raise ValueError(f"schedule '{schedule}' unknown.")
+ return betas.numpy()
+
+
+def make_ddim_timesteps(ddim_discr_method,
+ num_ddim_timesteps,
+ num_ddpm_timesteps,
+ verbose=True):
+ if ddim_discr_method == 'uniform':
+ c = num_ddpm_timesteps // num_ddim_timesteps
+ ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
+ steps_out = ddim_timesteps + 1
+ elif ddim_discr_method == 'uniform_trailing':
+ c = num_ddpm_timesteps / num_ddim_timesteps
+ ddim_timesteps = np.flip(np.round(np.arange(num_ddpm_timesteps, 0,
+ -c))).astype(np.int64)
+ steps_out = ddim_timesteps - 1
+ elif ddim_discr_method == 'quad':
+ ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8),
+ num_ddim_timesteps))**2).astype(int)
+ steps_out = ddim_timesteps + 1
+ else:
+ raise NotImplementedError(
+ f'There is no ddim discretization method called "{ddim_discr_method}"'
+ )
+
+ # assert ddim_timesteps.shape[0] == num_ddim_timesteps
+ # add one to get the final alpha values right (the ones from first scale to data during sampling)
+ # steps_out = ddim_timesteps + 1
+ if verbose:
+ print(f'Selected timesteps for ddim sampler: {steps_out}')
+ return steps_out
+
+
+def make_ddim_sampling_parameters(alphacums,
+ ddim_timesteps,
+ eta,
+ verbose=True):
+ # select alphas for computing the variance schedule
+ # print(f'ddim_timesteps={ddim_timesteps}, len_alphacums={len(alphacums)}')
+ alphas = alphacums[ddim_timesteps]
+ alphas_prev = np.asarray([alphacums[0]] +
+ alphacums[ddim_timesteps[:-1]].tolist())
+
+ # according to the formula provided in https://arxiv.org/abs/2010.02502
+ sigmas = eta * np.sqrt(
+ (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
+ if verbose:
+ print(
+ f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
+ )
+ print(
+ f'For the chosen value of eta, which is {eta}, '
+ f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
+ )
+ return sigmas, alphas, alphas_prev
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+def rescale_zero_terminal_snr(betas):
+ """
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
+
+ Args:
+ betas (`numpy.ndarray`):
+ the betas that the scheduler is being initialized with.
+
+ Returns:
+ `numpy.ndarray`: rescaled betas with zero terminal SNR
+ """
+ # Convert betas to alphas_bar_sqrt
+ alphas = 1.0 - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_bar_sqrt = np.sqrt(alphas_cumprod)
+
+ # Store old values.
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].copy()
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].copy()
+
+ # Shift so the last timestep is zero.
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
+
+ # Scale so the first timestep is back to the old value.
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 -
+ alphas_bar_sqrt_T)
+
+ # Convert alphas_bar_sqrt to betas
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
+ alphas = np.concatenate([alphas_bar[0:1], alphas])
+ betas = 1 - alphas
+
+ return betas
+
+
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+ """
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+ """
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)),
+ keepdim=True)
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+ # Rescale the results from guidance (fixes overexposure)
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+ # Mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (
+ 1 - guidance_rescale) * noise_cfg
+ return noise_cfg
diff --git a/src/unifolm_wma/utils/distributions.py b/src/unifolm_wma/utils/distributions.py
new file mode 100644
index 0000000..8b04a67
--- /dev/null
+++ b/src/unifolm_wma/utils/distributions.py
@@ -0,0 +1,94 @@
+import torch
+import numpy as np
+
+
+class AbstractDistribution:
+
+ def sample(self):
+ raise NotImplementedError()
+
+ def mode(self):
+ raise NotImplementedError()
+
+
+class DiracDistribution(AbstractDistribution):
+
+ def __init__(self, value):
+ self.value = value
+
+ def sample(self):
+ return self.value
+
+ def mode(self):
+ return self.value
+
+
+class DiagonalGaussianDistribution(object):
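+ """Diagonal Gaussian parameterized by (mean, logvar) concatenated along dim=1."""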
+
+ def __init__(self, parameters, deterministic=False):
+ self.parameters = parameters
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+ self.deterministic = deterministic
+ self.std = torch.exp(0.5 * self.logvar)
+ self.var = torch.exp(self.logvar)
+ if self.deterministic:
+ self.var = self.std = torch.zeros_like(
+ self.mean).to(device=self.parameters.device)
+
+ def sample(self, noise=None):
+ if noise is None:
+ noise = torch.randn(self.mean.shape)
+
+ x = self.mean + self.std * noise.to(device=self.parameters.device)
+ return x
+
+ def kl(self, other=None):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ else:
+ if other is None:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+ dim=[1, 2, 3])
+ else:
+ return 0.5 * torch.sum(
+ torch.pow(self.mean - other.mean, 2) / other.var +
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
+ dim=[1, 2, 3])
+
+ def nll(self, sample, dims=[1, 2, 3]):
+ if self.deterministic:
+ return torch.Tensor([0.])
+ logtwopi = np.log(2.0 * np.pi)
+ return 0.5 * torch.sum(logtwopi + self.logvar +
+ torch.pow(sample - self.mean, 2) / self.var,
+ dim=dims)
+
+ def mode(self):
+ return self.mean
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, torch.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for torch.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) +
+ ((mean1 - mean2)**2) * torch.exp(-logvar2))
diff --git a/src/unifolm_wma/utils/ema.py b/src/unifolm_wma/utils/ema.py
new file mode 100644
index 0000000..d3898a8
--- /dev/null
+++ b/src/unifolm_wma/utils/ema.py
@@ -0,0 +1,84 @@
+import torch
+from torch import nn
+
+
+class LitEma(nn.Module):
+
+ def __init__(self, model, decay=0.9999, use_num_upates=True):
+ super().__init__()
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError('Decay must be between 0 and 1')
+
+ self.m_name2s_name = {}
+ self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
+ self.register_buffer(
+ 'num_updates',
+ torch.tensor(0, dtype=torch.int)
+ if use_num_upates else torch.tensor(-1, dtype=torch.int))
+
+ for name, p in model.named_parameters():
+ if p.requires_grad:
+ # Remove '.' since it is not allowed in buffer names
+ s_name = name.replace('.', '')
+ self.m_name2s_name.update({name: s_name})
+ self.register_buffer(s_name, p.clone().detach().data)
+
+ self.collected_params = []
+
+ def forward(self, model):
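+ """Update the shadow buffers: shadow <- shadow - (1 - decay) * (shadow - param)."""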
+ decay = self.decay
+
+ if self.num_updates >= 0:
+ self.num_updates += 1
+ decay = min(self.decay,
+ (1 + self.num_updates) / (10 + self.num_updates))
+
+ one_minus_decay = 1.0 - decay
+
+ with torch.no_grad():
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+
+ for key in m_param:
+ if m_param[key].requires_grad:
+ sname = self.m_name2s_name[key]
+ shadow_params[sname] = shadow_params[sname].type_as(
+ m_param[key])
+ shadow_params[sname].sub_(
+ one_minus_decay *
+ (shadow_params[sname] - m_param[key]))
+ else:
+ assert not key in self.m_name2s_name
+
+ def copy_to(self, model):
+ m_param = dict(model.named_parameters())
+ shadow_params = dict(self.named_buffers())
+ for key in m_param:
+ if m_param[key].requires_grad:
+ m_param[key].data.copy_(
+ shadow_params[self.m_name2s_name[key]].data)
+ else:
+ assert not key in self.m_name2s_name
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
diff --git a/src/unifolm_wma/utils/nn_utils.py b/src/unifolm_wma/utils/nn_utils.py
new file mode 100644
index 0000000..15dabd3
--- /dev/null
+++ b/src/unifolm_wma/utils/nn_utils.py
@@ -0,0 +1,66 @@
+"""
+nn_utils.py
+
+Utility functions and PyTorch submodule definitions.
+"""
+
+import torch
+import torch.nn as nn
+
+
+# === Definitions for Various Projection Modules, with Signature :: [..., in_dim] --> [..., out_dim] ===
+class LinearProjector(nn.Module):
+
+ def __init__(self, vision_dim: int, llm_dim: int) -> None:
+ super().__init__()
+ self.projector = nn.Linear(vision_dim, llm_dim, bias=True)
+
+ def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
+ return self.projector(img_patches)
+
+
+class MLPProjector(nn.Module):
+
+ def __init__(self,
+ vision_dim: int,
+ llm_dim: int,
+ mlp_type: str = "gelu-mlp") -> None:
+ super().__init__()
+ if mlp_type == "gelu-mlp":
+ self.projector = nn.Sequential(
+ nn.Linear(vision_dim, llm_dim, bias=True),
+ nn.GELU(),
+ nn.Linear(llm_dim, llm_dim, bias=True),
+ )
+ else:
+ raise ValueError(
+ f"Projector with `{mlp_type = }` is not supported!")
+
+ def forward(self, img_patches: torch.Tensor) -> torch.Tensor:
+ return self.projector(img_patches)
+
+
+class FusedMLPProjector(nn.Module):
+
+ def __init__(self,
+ fused_vision_dim: int,
+ llm_dim: int,
+ mlp_type: str = "fused-gelu-mlp") -> None:
+ super().__init__()
+ self.initial_projection_dim = fused_vision_dim * 4
+ if mlp_type == "fused-gelu-mlp":
+ self.projector = nn.Sequential(
+ nn.Linear(fused_vision_dim,
+ self.initial_projection_dim,
+ bias=True),
+ nn.GELU(),
+ nn.Linear(self.initial_projection_dim, llm_dim, bias=True),
+ nn.GELU(),
+ nn.Linear(llm_dim, llm_dim, bias=True),
+ )
+ else:
+ raise ValueError(
+ f"Fused Projector with `{mlp_type = }` is not supported!")
+
+ def forward(self, fused_img_patches: torch.Tensor) -> torch.Tensor:
+ return self.projector(fused_img_patches)
diff --git a/src/unifolm_wma/utils/projector.py b/src/unifolm_wma/utils/projector.py
new file mode 100644
index 0000000..579b7c2
--- /dev/null
+++ b/src/unifolm_wma/utils/projector.py
@@ -0,0 +1,147 @@
+import math
+
+import torch
+import torch.nn as nn
+
+
+class LinearProjector(nn.Module):
+ def __init__(self, input_dim: int, output_dim: int) -> None:
+ super().__init__()
+ self.projector = nn.Linear(input_dim, output_dim, bias=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.projector(x)
+
+
+class MLPProjector(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ mlp_type: str = "gelu-mlp") -> None:
+ super().__init__()
+ if mlp_type == "gelu-mlp":
+ self.projector = nn.Sequential(
+ nn.Linear(input_dim, output_dim, bias=True),
+ nn.GELU(approximate='tanh'),
+ nn.Linear(output_dim, output_dim, bias=True),
+ )
+ elif mlp_type == "silu-mlp":
+ self.projector = nn.Sequential(
+ nn.Linear(input_dim, output_dim, bias=True),
+ nn.SiLU(),
+ nn.Linear(output_dim, output_dim, bias=True),
+ )
+ else:
+ raise ValueError(f"Projector with `{mlp_type = }` is not supported!")
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.projector(x)
+
+
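+def reshape_tensor(x, heads):
+ # Assumed helper (not present in the original diff): a minimal sketch of the
+ # standard multi-head reshape, splitting the channel dim into attention heads
+ # and returning (batch, heads, length, dim_head) as PerceiverAttention.forward expects.
+ bs, length, width = x.shape
+ x = x.view(bs, length, heads, -1)
+ x = x.transpose(1, 2)
+ x = x.reshape(bs, heads, length, -1)
+ return x
+
+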
+class PerceiverAttention(nn.Module):
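+ """Cross-attention from learned latents to input tokens; keys/values come from
+ the concatenation of inputs and latents."""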
+ def __init__(self, *, dim, dim_head=64, heads=8):
+ super().__init__()
+ self.scale = dim_head**-0.5
+ self.dim_head = dim_head
+ self.heads = heads
+ inner_dim = dim_head * heads
+
+ self.norm1 = nn.LayerNorm(dim)
+ self.norm2 = nn.LayerNorm(dim)
+
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
+
+
+ def forward(self, x, latents):
+ """
+ Args:
+ x (torch.Tensor): image features
+ shape (b, n1, D)
+ latents (torch.Tensor): latent features
+ shape (b, n2, D)
+ """
+
+ x = self.norm1(x)
+ latents = self.norm2(latents)
+
+ b, l, _ = latents.shape
+
+ q = self.to_q(latents)
+ kv_input = torch.cat((x, latents), dim=-2)
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
+
+ q = reshape_tensor(q, self.heads)
+ k = reshape_tensor(k, self.heads)
+ v = reshape_tensor(v, self.heads)
+
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ out = weight @ v
+
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
+
+ return self.to_out(out)
+
+
+def FeedForward(dim, mult=4, ffd_type="gelu-ffd"):
+ inner_dim = int(dim * mult)
+ if ffd_type == "gelu-ffd":
+ return nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, inner_dim, bias=False),
+ nn.GELU(approximate='tanh'),
+ nn.Linear(inner_dim, dim, bias=False),
+ )
+ elif ffd_type == "silu-ffd":
+ return nn.Sequential(
+ nn.LayerNorm(dim),
+ nn.Linear(dim, inner_dim, bias=False),
+ nn.SiLU(),
+ nn.Linear(inner_dim, dim, bias=False),
+ )
+ else:
+ raise ValueError(f"Projector with `{mlp_type = }` is not supported!")
+
+
+class TokenProjector(nn.Module):
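+ """Perceiver-style resampler: num_queries learned latents cross-attend to the
+ input tokens over `depth` blocks and are projected to output_dim."""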
+ def __init__(
+ self,
+ dim=1024,
+ depth=1,
+ dim_head=64,
+ heads=16,
+ num_queries=16,
+ output_dim=1024,
+ ff_mult=4,
+ chunck_size=None,
+ ):
+ super().__init__()
+ self.num_queries = num_queries
+ self.chunck_size = chunck_size
+ if chunck_size is not None:
+ num_queries = num_queries * chunck_size
+
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
+ self.proj_out = nn.Linear(dim, output_dim)
+ self.norm_out = nn.LayerNorm(dim)
+
+ self.layers = nn.ModuleList([])
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
+ FeedForward(dim=dim, mult=ff_mult),
+ ]
+ )
+ )
+
+ def forward(self, x):
+ latents = self.latents.repeat(x.size(0), 1, 1)
+ for attn, ff in self.layers:
+ latents = attn(x, latents) + latents
+ latents = ff(latents) + latents
+ latents = self.proj_out(latents)
+ latents = self.norm_out(latents)
+ return latents
diff --git a/src/unifolm_wma/utils/save_video.py b/src/unifolm_wma/utils/save_video.py
new file mode 100644
index 0000000..83f475b
--- /dev/null
+++ b/src/unifolm_wma/utils/save_video.py
@@ -0,0 +1,258 @@
+import os
+import torch
+import numpy as np
+import torchvision
+
+from tqdm import tqdm
+from PIL import Image
+from einops import rearrange
+from torch import Tensor
+from torchvision.utils import make_grid
+from torchvision.transforms.functional import to_tensor
+
+
+def frames_to_mp4(frame_dir, output_path, fps):
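+ """Read image frames from a directory (sorted by filename) and write them as an H.264 MP4 at the given fps."""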
+
+ def read_first_n_frames(d: os.PathLike, num_frames: int):
+ if num_frames:
+ images = [
+ Image.open(os.path.join(d, f))
+ for f in sorted(os.listdir(d))[:num_frames]
+ ]
+ else:
+ images = [
+ Image.open(os.path.join(d, f)) for f in sorted(os.listdir(d))
+ ]
+ images = [to_tensor(x) for x in images]
+ return torch.stack(images)
+
+ videos = read_first_n_frames(frame_dir, num_frames=None)
+ videos = videos.mul(255).to(torch.uint8).permute(0, 2, 3, 1)
+ torchvision.io.write_video(output_path,
+ videos,
+ fps=fps,
+ video_codec='h264',
+ options={'crf': '10'})
+
+
+def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None):
+ """
+ video: torch.Tensor, b,c,t,h,w, 0-1
+ if -1~1, enable rescale=True
+ """
+ n = video.shape[0]
+ video = video.permute(2, 0, 1, 3, 4)
+ nrow = int(np.sqrt(n)) if nrow is None else nrow
+ frame_grids = [
+ torchvision.utils.make_grid(framesheet, nrow=nrow, padding=0)
+ for framesheet in video
+ ]
+ grid = torch.stack(frame_grids,
+ dim=0)
+ grid = torch.clamp(grid.float(), -1., 1.)
+ if rescale:
+ grid = (grid + 1.0) / 2.0
+ grid = (grid * 255).to(torch.uint8).permute(
+ 0, 2, 3, 1)
+ torchvision.io.write_video(savepath,
+ grid,
+ fps=fps,
+ video_codec='h264',
+ options={'crf': '10'})
+
+
+def tensor2videogrids(video, root, filename, fps, rescale=True, clamp=True):
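+ """Arrange a (b, c, t, h, w) tensor into per-frame grids and save them as an H.264 video."""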
+ assert (video.dim() == 5)
+ assert (isinstance(video, torch.Tensor))
+
+ video = video.detach().cpu()
+ if clamp:
+ video = torch.clamp(video, -1., 1.)
+ n = video.shape[0]
+ video = video.permute(2, 0, 1, 3, 4)
+ frame_grids = [
+ torchvision.utils.make_grid(framesheet, nrow=int(np.sqrt(n)))
+ for framesheet in video
+ ]
+ grid = torch.stack(frame_grids,
+ dim=0)
+ if rescale:
+ grid = (grid + 1.0) / 2.0
+ grid = (grid * 255).to(torch.uint8).permute(
+ 0, 2, 3, 1)
+ path = os.path.join(root, filename)
+ torchvision.io.write_video(path,
+ grid,
+ fps=fps,
+ video_codec='h264',
+ options={'crf': '10'})
+
+
+def log_local(batch_logs, save_dir, filename, save_fps=10, rescale=True):
+ """Save images and videos from the batch_logs dict."""
+ if batch_logs is None:
+ return None
+
+ def save_img_grid(grid, path, rescale):
+ if rescale:
+ grid = (grid + 1.0) / 2.0
+ grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
+ grid = grid.numpy()
+ grid = (grid * 255).astype(np.uint8)
+ os.makedirs(os.path.split(path)[0], exist_ok=True)
+ Image.fromarray(grid).save(path)
+
+ for key in batch_logs:
+ value = batch_logs[key]
+ if isinstance(value, list) and isinstance(value[0], str):
+ # A batch of captions
+ path = os.path.join(save_dir, "%s-%s.txt" % (key, filename))
+ with open(path, 'w') as f:
+ for i, txt in enumerate(value):
+ f.write(f'idx={i}, txt={txt}\n')
+ elif isinstance(value, torch.Tensor) and value.dim() == 5:
+ # Save video grids
+ video = value
+ # Only save grayscale or rgb mode
+ if video.shape[1] != 1 and video.shape[1] != 3:
+ continue
+ n = video.shape[0]
+ video = video.permute(2, 0, 1, 3, 4)
+ frame_grids = [
+ torchvision.utils.make_grid(framesheet, nrow=int(1), padding=0)
+ for framesheet in video
+ ]
+ grid = torch.stack(frame_grids,
+ dim=0)
+ if rescale:
+ grid = (grid + 1.0) / 2.0
+ grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
+ path = os.path.join(save_dir, "%s-%s.mp4" % (key, filename))
+ torchvision.io.write_video(path,
+ grid,
+ fps=save_fps,
+ video_codec='h264',
+ options={'crf': '10'})
+
+ # Save frame sheet
+ img = value
+ video_frames = rearrange(img, 'b c t h w -> (b t) c h w')
+ t = img.shape[2]
+ grid = torchvision.utils.make_grid(video_frames, nrow=t, padding=0)
+ path = os.path.join(save_dir, "%s-%s.jpg" % (key, filename))
+ # Save_img_grid(grid, path, rescale)
+ elif isinstance(value, torch.Tensor) and value.dim() == 4:
+ # Save image grids
+ img = value
+ # Only save grayscale or rgb mode
+ if img.shape[1] != 1 and img.shape[1] != 3:
+ continue
+ n = img.shape[0]
+ grid = torchvision.utils.make_grid(img, nrow=1, padding=0)
+ path = os.path.join(save_dir, "%s-%s.jpg" % (key, filename))
+ save_img_grid(grid, path, rescale)
+ else:
+ pass
+
+
+def prepare_to_log(batch_logs, max_images=100000, clamp=True):
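+ """Truncate each logged entry to max_images, move tensors to CPU and optionally clamp to [-1, 1]."""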
+ if batch_logs is None:
+ return None
+ for key in batch_logs:
+ N = batch_logs[key].shape[0] if hasattr(
+ batch_logs[key], 'shape') else len(batch_logs[key])
+ N = min(N, max_images)
+ batch_logs[key] = batch_logs[key][:N]
+ # In batch_logs: images & instruction
+ if isinstance(batch_logs[key], torch.Tensor):
+ batch_logs[key] = batch_logs[key].detach().cpu()
+ if clamp:
+ try:
+ batch_logs[key] = torch.clamp(batch_logs[key].float(), -1.,
+ 1.)
+ except RuntimeError:
+ print("clamp_scalar_cpu not implemented for Half")
+ return batch_logs
+
+
+# ----------------------------------------------------------------------------------------------
+
+
+def fill_with_black_squares(video, desired_len: int) -> Tensor:
+ if len(video) >= desired_len:
+ return video
+
+ return torch.cat([
+ video,
+ torch.zeros_like(video[0]).unsqueeze(0).repeat(
+ desired_len - len(video), 1, 1, 1),
+ ],
+ dim=0)
+
+
+# ----------------------------------------------------------------------------------------------
+def load_num_videos(data_path, num_videos):
+ # First argument can be either a path to an .npz file or an np array
+ if isinstance(data_path, str):
+ videos = np.load(data_path)['arr_0'] # NTHWC
+ elif isinstance(data_path, np.ndarray):
+ videos = data_path
+ else:
+ raise Exception
+
+ if num_videos is not None:
+ videos = videos[:num_videos, :, :, :, :]
+ return videos
+
+
+def npz_to_video_grid(data_path,
+ out_path,
+ num_frames,
+ fps,
+ num_videos=None,
+ nrow=None,
+ verbose=True):
+ if isinstance(data_path, str):
+ videos = load_num_videos(data_path, num_videos)
+ elif isinstance(data_path, np.ndarray):
+ videos = data_path
+ else:
+ raise Exception
+ n, t, h, w, c = videos.shape
+ videos_th = []
+ for i in range(n):
+ video = videos[i, :, :, :, :]
+ images = [video[j, :, :, :] for j in range(t)]
+ images = [to_tensor(img) for img in images]
+ video = torch.stack(images)
+ videos_th.append(video)
+ if verbose:
+ videos = [
+ fill_with_black_squares(v, num_frames)
+ for v in tqdm(videos_th, desc='Adding empty frames')
+ ]
+ else:
+ videos = [fill_with_black_squares(v, num_frames)
+ for v in videos_th] # NTCHW
+
+ frame_grids = torch.stack(videos).permute(1, 0, 2, 3, 4)
+ if nrow is None:
+ nrow = int(np.ceil(np.sqrt(n)))
+ if verbose:
+ frame_grids = [
+ make_grid(fs, nrow=nrow)
+ for fs in tqdm(frame_grids, desc='Making grids')
+ ]
+ else:
+ frame_grids = [make_grid(fs, nrow=nrow) for fs in frame_grids]
+
+ if os.path.dirname(out_path) != "":
+ os.makedirs(os.path.dirname(out_path), exist_ok=True)
+ frame_grids = (torch.stack(frame_grids) * 255).to(torch.uint8).permute(
+ 0, 2, 3, 1)
+ torchvision.io.write_video(out_path,
+ frame_grids,
+ fps=fps,
+ video_codec='h264',
+ options={'crf': '10'})
diff --git a/src/unifolm_wma/utils/train.py b/src/unifolm_wma/utils/train.py
new file mode 100644
index 0000000..225fbf2
--- /dev/null
+++ b/src/unifolm_wma/utils/train.py
@@ -0,0 +1,231 @@
+import os
+import logging
+
+mainlogger = logging.getLogger('mainlogger')
+
+import torch
+import pandas as pd
+
+from omegaconf import OmegaConf
+from collections import OrderedDict
+
+
+def init_workspace(name, logdir, model_config, lightning_config, rank=0):
+ workdir = os.path.join(logdir, name)
+ ckptdir = os.path.join(workdir, "checkpoints")
+ cfgdir = os.path.join(workdir, "configs")
+ loginfo = os.path.join(workdir, "loginfo")
+
+ # Create log dirs and save configs on all ranks, to avoid a missing-directory error if rank 0 is slower
+ os.makedirs(workdir, exist_ok=True)
+ os.makedirs(ckptdir, exist_ok=True)
+ os.makedirs(cfgdir, exist_ok=True)
+ os.makedirs(loginfo, exist_ok=True)
+
+ if rank == 0:
+ if "callbacks" in lightning_config and 'metrics_over_trainsteps_checkpoint' in lightning_config.callbacks:
+ os.makedirs(os.path.join(ckptdir, 'trainstep_checkpoints'),
+ exist_ok=True)
+ OmegaConf.save(model_config, os.path.join(cfgdir, "model.yaml"))
+ OmegaConf.save(OmegaConf.create({"lightning": lightning_config}),
+ os.path.join(cfgdir, "lightning.yaml"))
+ return workdir, ckptdir, cfgdir, loginfo
+
+
+def check_config_attribute(config, name):
+ if name in config:
+ value = getattr(config, name)
+ return value
+ else:
+ return None
+
+
+def get_trainer_callbacks(lightning_config, config, logdir, ckptdir, logger):
+ default_callbacks_cfg = {
+ "model_checkpoint": {
+ "target": "pytorch_lightning.callbacks.ModelCheckpoint",
+ "params": {
+ "dirpath": ckptdir,
+ "filename": "{epoch}",
+ "verbose": True,
+ "save_last": False,
+ }
+ },
+ "batch_logger": {
+ "target": "unifolm_wma.utils.callbacks.ImageLogger",
+ "params": {
+ "save_dir": logdir,
+ "batch_frequency": 1000,
+ "max_images": 4,
+ "clamp": True,
+ }
+ },
+ "learning_rate_logger": {
+ "target": "pytorch_lightning.callbacks.LearningRateMonitor",
+ "params": {
+ "logging_interval": "step",
+ "log_momentum": False
+ }
+ },
+ "cuda_callback": {
+ "target": "unifolm_wma.utils.callbacks.CUDACallback",
+ },
+ }
+
+ # Optional setting for saving checkpoints
+ monitor_metric = check_config_attribute(config.model.params, "monitor")
+ if monitor_metric is not None:
+ mainlogger.info(f"Monitoring {monitor_metric} as checkpoint metric.")
+ default_callbacks_cfg["model_checkpoint"]["params"][
+ "monitor"] = monitor_metric
+ default_callbacks_cfg["model_checkpoint"]["params"]["save_top_k"] = 3
+ default_callbacks_cfg["model_checkpoint"]["params"]["mode"] = "min"
+
+ if 'metrics_over_trainsteps_checkpoint' in lightning_config.callbacks:
+ mainlogger.info(
+ 'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.'
+ )
+ default_metrics_over_trainsteps_ckpt_dict = {
+ 'metrics_over_trainsteps_checkpoint': {
+ "target": 'pytorch_lightning.callbacks.ModelCheckpoint',
+ 'params': {
+ "dirpath": os.path.join(ckptdir, 'trainstep_checkpoints'),
+ "filename": "{epoch}-{step}",
+ "verbose": True,
+ 'save_top_k': -1,
+ 'every_n_train_steps': 10000,
+ 'save_weights_only': True
+ }
+ }
+ }
+ default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
+
+ if "callbacks" in lightning_config:
+ callbacks_cfg = lightning_config.callbacks
+ else:
+ callbacks_cfg = OmegaConf.create()
+ callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
+
+ return callbacks_cfg
+
+
+def get_trainer_logger(lightning_config, logdir, on_debug):
+ default_logger_cfgs = {
+ "tensorboard": {
+ "target": "pytorch_lightning.loggers.TensorBoardLogger",
+ "params": {
+ "save_dir": logdir,
+ "name": "tensorboard",
+ }
+ },
+ "testtube": {
+ "target": "pytorch_lightning.loggers.CSVLogger",
+ "params": {
+ "name": "testtube",
+ "save_dir": logdir,
+ }
+ },
+ }
+ os.makedirs(os.path.join(logdir, "tensorboard"), exist_ok=True)
+ default_logger_cfg = default_logger_cfgs["tensorboard"]
+ if "logger" in lightning_config:
+ logger_cfg = lightning_config.logger
+ else:
+ logger_cfg = OmegaConf.create()
+ logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
+ return logger_cfg
+
+
+def get_trainer_strategy(lightning_config):
+ default_strategy_dict = {
+ "target": "pytorch_lightning.strategies.DDPShardedStrategy"
+ }
+ if "strategy" in lightning_config:
+ strategy_cfg = lightning_config.strategy
+ return strategy_cfg
+ else:
+ strategy_cfg = OmegaConf.create()
+
+ strategy_cfg = OmegaConf.merge(default_strategy_dict, strategy_cfg)
+ return strategy_cfg
+
+
+def load_checkpoints(model, model_cfg):
+ if check_config_attribute(model_cfg, "pretrained_checkpoint"):
+ pretrained_ckpt = model_cfg.pretrained_checkpoint
+ assert os.path.exists(
+ pretrained_ckpt
+ ), "Error: Pre-trained checkpoint NOT found at:%s" % pretrained_ckpt
+ mainlogger.info(">>> Load weights from pretrained checkpoint")
+
+ pl_sd = torch.load(pretrained_ckpt, map_location="cpu")
+ try:
+ if 'state_dict' in pl_sd.keys():
+ model.load_state_dict(pl_sd["state_dict"], strict=False)
+ mainlogger.info(
+ ">>> Loaded weights from pretrained checkpoint: %s" %
+ pretrained_ckpt)
+ else:
+ # deepspeed
+ new_pl_sd = OrderedDict()
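+ # strip the 16-character '_forward_module.' prefix from deepspeed checkpoint keys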
+ for key in pl_sd['module'].keys():
+ new_pl_sd[key[16:]] = pl_sd['module'][key]
+ model.load_state_dict(new_pl_sd, strict=False)
+ except:
+ model.load_state_dict(pl_sd)
+ else:
+ mainlogger.info(">>> Start training from scratch")
+
+ return model
+
+
+def set_logger(logfile, name='mainlogger'):
+ logger = logging.getLogger(name)
+ logger.setLevel(logging.INFO)
+ fh = logging.FileHandler(logfile, mode='w')
+ fh.setLevel(logging.INFO)
+ ch = logging.StreamHandler()
+ ch.setLevel(logging.DEBUG)
+ fh.setFormatter(
+ logging.Formatter("%(asctime)s-%(levelname)s: %(message)s"))
+ ch.setFormatter(logging.Formatter("%(message)s"))
+ logger.addHandler(fh)
+ logger.addHandler(ch)
+ return logger
+
+
+def count_parameters(model):
+ return sum(p.numel() for p in model.parameters())
+
+
+def count_trainable_parameters(model):
+ return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+def get_num_parameters(model):
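+ """Print a table of parameter counts (in M or B) for the world model, its action/state heads, and the full model."""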
+ models = [('World Model', model.model.diffusion_model),
+ ('Action Head', model.model.diffusion_model.action_unet),
+ ('State Head', model.model.diffusion_model.state_unet),
+ ('Total Trainable', model),
+ ('Total', model)]
+
+ data = []
+ for index, (name, model) in enumerate(models):
+ if name == "Total Trainable":
+ total_params = count_trainable_parameters(model)
+ else:
+ total_params = count_parameters(model)
+ if total_params < 0.1e9:
+ total_params_value = round(total_params / 1e6, 2)
+ unit = 'M'
+ else:
+ total_params_value = round(total_params / 1e9, 2)
+ unit = 'B'
+
+ data.append({
+ 'Model Name': name,
+ 'Params': f"{total_params_value} {unit}"
+ })
+
+ df = pd.DataFrame(data)
+ print(df)
diff --git a/src/unifolm_wma/utils/utils.py b/src/unifolm_wma/utils/utils.py
new file mode 100644
index 0000000..177d1bf
--- /dev/null
+++ b/src/unifolm_wma/utils/utils.py
@@ -0,0 +1,81 @@
+import importlib
+import os
+import numpy as np
+import cv2
+import torch
+import torch.distributed as dist
+
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(
+ f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params."
+ )
+ return total_params
+
+
+def check_istarget(name, para_list):
+ """
+ name: full name of source para
+ para_list: partial name of target para
+ """
+ istarget = False
+ for para in para_list:
+ if para in name:
+ return True
+ return istarget
+
+
+def instantiate_from_config(config):
+ if not "target" in config:
+ if config == '__is_first_stage__':
+ return None
+ elif config == "__is_unconditional__":
+ return None
+ raise KeyError("Expected key `target` to instantiate.")
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
+
+
+def get_obj_from_str(string, reload=False):
+ module, cls = string.rsplit(".", 1)
+ if reload:
+ module_imp = importlib.import_module(module)
+ importlib.reload(module_imp)
+ return getattr(importlib.import_module(module, package=None), cls)
+
+
+def load_npz_from_dir(data_dir):
+ data = [
+ np.load(os.path.join(data_dir, data_name))['arr_0']
+ for data_name in os.listdir(data_dir)
+ ]
+ data = np.concatenate(data, axis=0)
+ return data
+
+
+def load_npz_from_paths(data_paths):
+ data = [np.load(data_path)['arr_0'] for data_path in data_paths]
+ data = np.concatenate(data, axis=0)
+ return data
+
+
+def resize_numpy_image(image,
+ max_resolution=512 * 512,
+ resize_short_edge=None):
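+ """Resize so the area is about max_resolution (or the short edge equals resize_short_edge), rounding both dims to multiples of 64."""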
+ h, w = image.shape[:2]
+ if resize_short_edge is not None:
+ k = resize_short_edge / min(h, w)
+ else:
+ k = max_resolution / (h * w)
+ k = k**0.5
+ h = int(np.round(h * k / 64)) * 64
+ w = int(np.round(w * k / 64)) * 64
+ image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
+ return image
+
+
+def setup_dist(args):
+ if dist.is_initialized():
+ return
+ torch.cuda.set_device(args.local_rank)
+ torch.distributed.init_process_group('nccl', init_method='env://')
diff --git a/unitree_deploy/README.md b/unitree_deploy/README.md
new file mode 100644
index 0000000..c0d1ba6
--- /dev/null
+++ b/unitree_deploy/README.md
@@ -0,0 +1,223 @@
+# Unitree Deploy
+
+